cp_library_rs/utils/
enum_comb.rs

1//! 組合せの列挙
2
3use itertools::Itertools;
4use superslice::Ext;
5
6// ========== pairs ==========
7/// ペアのベクタ型
8pub type Pairs<T> = Vec<(T, T)>;
9
10/// ペアを列挙する
11#[derive(Debug)]
12pub struct ListPairs<T: Clone> {
13    stack: Vec<(Vec<T>, Pairs<T>)>,
14}
15
16impl<T: Clone> FromIterator<T> for ListPairs<T> {
17    fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
18        Self {
19            stack: vec![(iter.into_iter().collect::<Vec<T>>(), vec![])],
20        }
21    }
22}
23
24impl<T: Clone> Iterator for ListPairs<T> {
25    type Item = Pairs<T>;
26    fn next(&mut self) -> Option<Self::Item> {
27        loop {
28            let (rem, pairs) = self.stack.pop()?;
29
30            if rem.len() < 2 {
31                return Some(pairs);
32            }
33            for i in (1..rem.len()).rev() {
34                let mut new_rem = rem.clone();
35                let snd = new_rem.remove(i);
36                let fst = new_rem.remove(0);
37                let mut new_pairs = pairs.clone();
38                new_pairs.push((fst, snd));
39                // 新しい要素を追加
40                self.stack.push((new_rem, new_pairs));
41            }
42        }
43    }
44}
45
46impl ListPairs<usize> {
47    /// (0〜n-1)のn個の要素からなる系列
48    /// をペアにする組合せを列挙する
49    pub fn pairs_usize(n: usize) -> Self {
50        (0..n).collect()
51    }
52}
53
54// ========== comb with rep ==========
55/// n 個の集合から重複を許して r 個取り出すときの組合せを列挙する.
56pub fn comb_with_rep(n: usize, r: usize) -> impl Iterator<Item = Vec<usize>> {
57    let perm: Vec<_> = std::iter::repeat_n(false, r)
58        .chain(std::iter::repeat_n(true, n - 1))
59        .collect();
60
61    std::iter::once(perm.clone())
62        .chain(std::iter::repeat(()).scan(perm, |p, _| p.next_permutation().then_some(p.clone())))
63        .map(aggregate_comb)
64}
65
66fn aggregate_comb(choose: Vec<bool>) -> Vec<usize> {
67    let mut res = vec![];
68    if choose[0] {
69        res.push(0);
70    }
71    for (k, &f) in choose.iter().dedup_with_count() {
72        if f {
73            res.extend(std::iter::repeat_n(0, k - 1));
74        } else {
75            res.push(k);
76        }
77    }
78    if *choose.last().unwrap() {
79        res.push(0);
80    }
81    res
82}