cp_library_rs/graph/
loop_detection_fold.rs

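//! ループ検出(fold関数版) — loop (cycle) detection on the walk `x, next(x), next(next(x)), ...`, accumulating values along it with a fold function.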
2
3use num_bigint::BigUint;
4use num_traits::{FromPrimitive, ToPrimitive, Zero};
5use std::ops::{Add, Mul, Sub};
6use std::{collections::HashMap, hash::Hash};
7
8pub struct Loop<T, V, F, G>
9where
10    F: Fn(T) -> T,
11    G: Fn(V, T) -> V,
12{
13    /// ノードの移動を行う関数
14    pub next: F,
15    /// 値を更新する関数
16    pub fold: G,
17    /// 始点となるノード
18    pub begin: T,
19    /// ループの長さ
20    pub loop_len: usize,
21    /// ループ開始時の値
22    pub loop_begin: T,
23    /// ループに到達するまでの移動回数
24    pub loop_begin_idx: usize,
25    /// ループ開始時までの累積
26    pub before_loop_sum: V,
27    /// ループ内での累積
28    pub loop_sum: V,
29    /// ループの途中の値
30    vals: HashMap<T, (usize, V)>,
31}
32
33impl<T, V, F, G> Loop<T, V, F, G>
34where
35    T: Copy + Hash + Eq,
36    V: Copy + Zero + Add<Output = V> + Sub<Output = V> + Mul<usize, Output = V>,
37    F: Fn(T) -> T,
38    G: Fn(V, T) -> V,
39{
40    /// ループを検出する
41    pub fn build(begin: T, next: F, fold: G) -> Self {
42        // 初期化
43        let mut cur: T = begin;
44        let mut idx: usize = 0;
45        let mut sum: V = V::zero();
46        let mut vals: HashMap<T, (usize, V)> = HashMap::new();
47
48        // ループ検出
49        while !vals.contains_key(&cur) {
50            vals.insert(cur, (idx, sum));
51            sum = fold(sum, cur);
52            cur = next(cur);
53            idx += 1;
54        }
55
56        // ループの値を取り出す
57        let loop_begin = cur;
58        let (loop_begin_idx, before_loop_sum) = vals[&loop_begin];
59        let loop_len = idx - loop_begin_idx;
60        let loop_sum = sum - before_loop_sum;
61
62        // 返す
63        Self {
64            next,
65            fold,
66            begin,
67            loop_len,
68            loop_begin,
69            loop_begin_idx,
70            before_loop_sum,
71            loop_sum,
72            vals,
73        }
74    }
75
76    fn accumulate(&self, begin: T, n: usize) -> (T, V) {
77        let mut res = V::zero();
78        let mut cur = begin;
79        for _ in 0..n {
80            res = (self.fold)(res, cur);
81            cur = (self.next)(cur);
82        }
83        (cur, res)
84    }
85
86    /// self.beginからn個後の頂点を取り出す
87    pub fn get_nth_node_usize(&self, n: usize) -> T {
88        if n < self.loop_begin_idx {
89            self.accumulate(self.begin, n).0
90        } else {
91            let loop_rem = (n - self.loop_begin_idx) % self.loop_len;
92            self.accumulate(self.loop_begin, loop_rem).0
93        }
94    }
95
96    /// self.beginからn個後の値を取り出す
97    pub fn get_nth_val_usize(&self, n: usize) -> V {
98        if n < self.loop_begin_idx {
99            self.accumulate(self.begin, n).1
100        } else {
101            let loop_rep = (n - self.loop_begin_idx) / self.loop_len;
102            let loop_rem = (n - self.loop_begin_idx) % self.loop_len;
103            self.before_loop_sum
104                + self.loop_sum * loop_rep
105                + self.accumulate(self.loop_begin, loop_rem).1
106        }
107    }
108
109    /// self.beginからn個後の値を取り出す
110    pub fn get_nth_node_biguint(&self, n: BigUint) -> T {
111        let loop_begin_idx = BigUint::from_usize(self.loop_begin_idx).unwrap();
112        if n < loop_begin_idx {
113            let n_usize = n.to_usize().unwrap();
114            self.accumulate(self.begin, n_usize).0
115        } else {
116            let loop_len = BigUint::from_usize(self.loop_len).unwrap();
117            let loop_rem = (n - loop_begin_idx) % loop_len;
118            let loop_rem = loop_rem.to_usize().unwrap();
119            self.accumulate(self.loop_begin, loop_rem).0
120        }
121    }
122}