cp_library_rs/graph/
loop_detection_fold.rs1use num_bigint::BigUint;
4use num_traits::{FromPrimitive, ToPrimitive, Zero};
5use std::ops::{Add, Mul, Sub};
6use std::{collections::HashMap, hash::Hash};
7
/// Tail/cycle ("rho") decomposition of the functional sequence
/// `begin, next(begin), next(next(begin)), ...`, together with the folded
/// values needed to answer prefix queries over that sequence.
///
/// The sequence consists of a (possibly empty) tail of `loop_begin_idx`
/// nodes followed by a cycle of `loop_len` nodes starting at `loop_begin`.
/// The cycle arithmetic (`loop_sum = total - before_loop_sum`) assumes that
/// `fold` behaves additively, i.e. like `acc + g(node)`.
pub struct Loop<T, V, F: Fn(T) -> T, G: Fn(V, T) -> V> {
    /// Successor function that defines the sequence.
    pub next: F,
    /// Accumulator step: combines the running value with one more node.
    pub fold: G,
    /// Element at index 0 of the sequence.
    pub begin: T,
    /// Number of nodes on the cycle (at least 1 after `build`).
    pub loop_len: usize,
    /// First node of the sequence that lies on the cycle.
    pub loop_begin: T,
    /// Index at which `loop_begin` first occurs, i.e. the tail length.
    pub loop_begin_idx: usize,
    /// Folded value of the first `loop_begin_idx` nodes (the tail).
    pub before_loop_sum: V,
    /// Folded value of one full pass around the cycle.
    pub loop_sum: V,
    /// Nodes visited by `build` before the cycle closed, each mapped to
    /// (index of first visit, folded value of all nodes before it).
    vals: HashMap<T, (usize, V)>,
}
32
33impl<T, V, F, G> Loop<T, V, F, G>
34where
35 T: Copy + Hash + Eq,
36 V: Copy + Zero + Add<Output = V> + Sub<Output = V> + Mul<usize, Output = V>,
37 F: Fn(T) -> T,
38 G: Fn(V, T) -> V,
39{
40 pub fn build(begin: T, next: F, fold: G) -> Self {
42 let mut cur: T = begin;
44 let mut idx: usize = 0;
45 let mut sum: V = V::zero();
46 let mut vals: HashMap<T, (usize, V)> = HashMap::new();
47
48 while !vals.contains_key(&cur) {
50 vals.insert(cur, (idx, sum));
51 sum = fold(sum, cur);
52 cur = next(cur);
53 idx += 1;
54 }
55
56 let loop_begin = cur;
58 let (loop_begin_idx, before_loop_sum) = vals[&loop_begin];
59 let loop_len = idx - loop_begin_idx;
60 let loop_sum = sum - before_loop_sum;
61
62 Self {
64 next,
65 fold,
66 begin,
67 loop_len,
68 loop_begin,
69 loop_begin_idx,
70 before_loop_sum,
71 loop_sum,
72 vals,
73 }
74 }
75
76 fn accumulate(&self, begin: T, n: usize) -> (T, V) {
77 let mut res = V::zero();
78 let mut cur = begin;
79 for _ in 0..n {
80 res = (self.fold)(res, cur);
81 cur = (self.next)(cur);
82 }
83 (cur, res)
84 }
85
86 pub fn get_nth_node_usize(&self, n: usize) -> T {
88 if n < self.loop_begin_idx {
89 self.accumulate(self.begin, n).0
90 } else {
91 let loop_rem = (n - self.loop_begin_idx) % self.loop_len;
92 self.accumulate(self.loop_begin, loop_rem).0
93 }
94 }
95
96 pub fn get_nth_val_usize(&self, n: usize) -> V {
98 if n < self.loop_begin_idx {
99 self.accumulate(self.begin, n).1
100 } else {
101 let loop_rep = (n - self.loop_begin_idx) / self.loop_len;
102 let loop_rem = (n - self.loop_begin_idx) % self.loop_len;
103 self.before_loop_sum
104 + self.loop_sum * loop_rep
105 + self.accumulate(self.loop_begin, loop_rem).1
106 }
107 }
108
109 pub fn get_nth_node_biguint(&self, n: BigUint) -> T {
111 let loop_begin_idx = BigUint::from_usize(self.loop_begin_idx).unwrap();
112 if n < loop_begin_idx {
113 let n_usize = n.to_usize().unwrap();
114 self.accumulate(self.begin, n_usize).0
115 } else {
116 let loop_len = BigUint::from_usize(self.loop_len).unwrap();
117 let loop_rem = (n - loop_begin_idx) % loop_len;
118 let loop_rem = loop_rem.to_usize().unwrap();
119 self.accumulate(self.loop_begin, loop_rem).0
120 }
121 }
122}