use num_bigint::BigUint;
use num_traits::{FromPrimitive, ToPrimitive, Zero};
use std::ops::{Add, Mul, Sub};
use std::{collections::HashMap, hash::Hash};
/// A sequence `begin, next(begin), next(next(begin)), …` that eventually enters a cycle,
/// together with a value attached to each node and precomputed sums of those values.
///
/// As computed by `Loop::build`, indices are 0-based (node `0` is `begin`). Nodes
/// `0..loop_begin_idx` come before the cycle. From `loop_begin_idx` onward the nodes
/// repeat with period `loop_len`.
pub struct Loop<T, V, F: Fn(T) -> T, G: Fn(T) -> V> {
    /// Node `0` of the sequence.
    pub begin: T,
    /// Successor function: the node after `x` is `next(x)`.
    pub next: F,
    /// Value attached to each node.
    pub get_val: G,
    /// Index of the first node that lies on the cycle.
    pub loop_begin_idx: usize,
    /// The first node that lies on the cycle (the node at `loop_begin_idx`).
    pub loop_begin: T,
    /// Number of nodes on the cycle.
    pub loop_len: usize,
    /// Sum of the values of nodes `0..loop_begin_idx`.
    pub before_loop_sum: V,
    /// Sum of the values over one full trip around the cycle.
    pub loop_sum: V,
    /// Each visited node -> (index of its first visit, sum of the values of all earlier nodes).
    vals: HashMap<T, (usize, V)>,
}
impl<T, V, F, G> Loop<T, V, F, G>
where
    T: Copy + Hash + Eq,
    V: Copy + Zero + Add<Output = V> + Sub<Output = V> + Mul<usize, Output = V>,
    F: Fn(T) -> T,
    G: Fn(T) -> V,
{
    /// Walks `begin, next(begin), next(next(begin)), …` until a node repeats. It records
    /// where the cycle starts, how long it is, and the value sums before and inside it.
    ///
    /// Every visited node is stored in a `HashMap`. The walk therefore terminates only if
    /// the sequence eventually revisits a node. Time and memory grow with
    /// `loop_begin_idx + loop_len`.
    pub fn build(begin: T, next: F, get_val: G) -> Self {
        use std::collections::hash_map::Entry;

        let mut cur: T = begin;
        let mut idx: usize = 0;
        // Sum of `get_val` over the nodes at indices `0..idx`.
        let mut sum: V = V::zero();
        // node -> (index of its first visit, sum of the values of all earlier nodes)
        let mut vals: HashMap<T, (usize, V)> = HashMap::new();
        // One hash lookup per step. The entry says whether `cur` was seen before and,
        // if so, hands back the data recorded at the cycle's entry point.
        let (loop_begin_idx, before_loop_sum) = loop {
            match vals.entry(cur) {
                Entry::Occupied(seen) => break *seen.get(),
                Entry::Vacant(slot) => {
                    slot.insert((idx, sum));
                    sum = sum + get_val(cur);
                    cur = next(cur);
                    idx += 1;
                }
            }
        };
        // `cur` is the first repeated node, i.e. the entry point of the cycle.
        let loop_begin = cur;
        // `loop_begin_idx` was recorded on an earlier step, so `loop_len >= 1`.
        let loop_len = idx - loop_begin_idx;
        let loop_sum = sum - before_loop_sum;
        Self {
            next,
            get_val,
            begin,
            loop_len,
            loop_begin,
            loop_begin_idx,
            before_loop_sum,
            loop_sum,
            vals,
        }
    }

    /// Starting at `begin`, takes `n` steps.
    ///
    /// Returns the node reached and the sum of the values of the `n` nodes passed.
    /// The starting node is included; the node reached is not.
    fn accumulate(&self, begin: T, n: usize) -> (T, V) {
        let mut res = V::zero();
        let mut cur = begin;
        for _ in 0..n {
            res = res + (self.get_val)(cur);
            cur = (self.next)(cur);
        }
        (cur, res)
    }

    /// Returns the node reached after applying `next` `steps` times, starting at `from`.
    ///
    /// Unlike `accumulate`, this never evaluates `get_val`. Node lookups therefore do no
    /// value arithmetic and cannot trip over `V` overflow.
    fn advance(&self, from: T, steps: usize) -> T {
        (0..steps).fold(from, |node, _| (self.next)(node))
    }

    /// Splits index `n` into (complete trips around the cycle, offset inside the cycle).
    ///
    /// The caller must ensure `n >= self.loop_begin_idx`. `loop_len >= 1` always holds
    /// after `build`, so the division cannot panic.
    fn loop_position(&self, n: usize) -> (usize, usize) {
        let offset = n - self.loop_begin_idx;
        (offset / self.loop_len, offset % self.loop_len)
    }

    /// Returns the node at 0-based index `n` (index `0` is `begin`).
    ///
    /// Walks at most `max(loop_begin_idx, loop_len)` steps.
    pub fn get_nth_node_usize(&self, n: usize) -> T {
        if n < self.loop_begin_idx {
            self.advance(self.begin, n)
        } else {
            let (_, rem) = self.loop_position(n);
            self.advance(self.loop_begin, rem)
        }
    }

    /// Returns the sum of the values of the first `n` nodes (indices `0..n`).
    ///
    /// Full trips around the cycle are counted with a single multiplication. At most
    /// `max(loop_begin_idx, loop_len)` values are summed one by one.
    pub fn get_nth_val_usize(&self, n: usize) -> V {
        if n < self.loop_begin_idx {
            self.accumulate(self.begin, n).1
        } else {
            let (loop_rep, loop_rem) = self.loop_position(n);
            self.before_loop_sum
                + self.loop_sum * loop_rep
                + self.accumulate(self.loop_begin, loop_rem).1
        }
    }

    /// Returns the node at 0-based index `n`, where `n` may exceed `usize::MAX`.
    pub fn get_nth_node_biguint(&self, n: BigUint) -> T {
        // Any index that fits in a usize can use the plain-integer path.
        if let Some(small) = n.to_usize() {
            return self.get_nth_node_usize(small);
        }
        // Here `n > usize::MAX >= loop_begin_idx`, so `n` lies on the cycle.
        let loop_begin_idx = BigUint::from_usize(self.loop_begin_idx)
            .expect("every usize is representable as a BigUint");
        let loop_len = BigUint::from_usize(self.loop_len)
            .expect("every usize is representable as a BigUint");
        let loop_rem = ((n - loop_begin_idx) % loop_len)
            .to_usize()
            .expect("remainder is below loop_len, which is a usize");
        self.advance(self.loop_begin, loop_rem)
    }
}