1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
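//! ループ検出(fold関数版) — cycle detection on the sequence `begin, next(begin), next(next(begin)), …`,
//! accumulating a value over the visited nodes with a `fold` function.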

use num_bigint::BigUint;
use num_traits::{FromPrimitive, ToPrimitive, Zero};
use std::ops::{Add, Mul, Sub};
use std::{collections::HashMap, hash::Hash};

/// Cycle found by repeatedly applying `next` starting from `begin`, plus the values
/// accumulated by `fold` over the visited nodes.
pub struct Loop<T, V, F: Fn(T) -> T, G: Fn(V, T) -> V> {
    /// Transition function: maps a node to the node that follows it.
    pub next: F,
    /// Accumulator: combines the running value with the node currently being visited.
    pub fold: G,
    /// Starting node (step 0).
    pub begin: T,
    /// Number of steps taken before the walk first reaches the cycle.
    pub loop_begin_idx: usize,
    /// First node of the cycle, i.e. the node reached after `loop_begin_idx` steps.
    pub loop_begin: T,
    /// Number of steps in one lap of the cycle.
    pub loop_len: usize,
    /// Value accumulated over the nodes before the cycle (steps `0..loop_begin_idx`).
    pub before_loop_sum: V,
    /// Value accumulated over one lap of the cycle (obtained by subtraction, so it
    /// assumes `fold` behaves additively).
    pub loop_sum: V,
    /// Every visited node mapped to the step of its first visit and the accumulated
    /// value just before that node was folded in.
    vals: HashMap<T, (usize, V)>,
}

impl<T, V, F, G> Loop<T, V, F, G>
where
    T: Copy + Hash + Eq,
    V: Copy + Zero + Add<Output = V> + Sub<Output = V> + Mul<usize, Output = V>,
    F: Fn(T) -> T,
    G: Fn(V, T) -> V,
{
    /// Walks `begin, next(begin), next(next(begin)), …` until a node repeats, and
    /// records where the cycle starts, its length, and the folded values.
    ///
    /// `before_loop_sum` is the value accumulated before the first cycle node, and
    /// `loop_sum` is computed as `total - before_loop_sum`. Both, and the query
    /// methods built on them, are only meaningful when `fold(v, t)` acts like
    /// `v + f(t)` for some `f` (an additive accumulation).
    ///
    /// The walk only terminates if the sequence eventually revisits a node. This
    /// always happens when `T` has finitely many values.
    pub fn build(begin: T, next: F, fold: G) -> Self {
        use std::collections::hash_map::Entry;

        let mut cur: T = begin;
        let mut idx: usize = 0;
        let mut sum: V = V::zero();
        // node -> (step of first visit, accumulated value before folding the node)
        let mut vals: HashMap<T, (usize, V)> = HashMap::new();

        // Single hash lookup per step. The first already-seen node is the start of
        // the cycle, and its stored entry is exactly the prefix data we need.
        let (loop_begin_idx, before_loop_sum) = loop {
            match vals.entry(cur) {
                Entry::Occupied(seen) => break *seen.get(),
                Entry::Vacant(slot) => {
                    slot.insert((idx, sum));
                }
            }
            sum = fold(sum, cur);
            cur = next(cur);
            idx += 1;
        };

        let loop_begin = cur;
        let loop_len = idx - loop_begin_idx;
        let loop_sum = sum - before_loop_sum;

        Self {
            next,
            fold,
            begin,
            loop_len,
            loop_begin,
            loop_begin_idx,
            before_loop_sum,
            loop_sum,
            vals,
        }
    }

    /// Takes `n` steps from `begin`, folding each visited node into a fresh
    /// `V::zero()` accumulator.
    ///
    /// Returns the node reached after `n` steps and the value folded over the `n`
    /// nodes visited before it (the returned node itself is not folded in).
    fn accumulate(&self, begin: T, n: usize) -> (T, V) {
        let mut res = V::zero();
        let mut cur = begin;
        for _ in 0..n {
            res = (self.fold)(res, cur);
            cur = (self.next)(cur);
        }
        (cur, res)
    }

    /// Splits a step count `n` relative to the cycle.
    ///
    /// Returns `None` when `n` lies in the prefix before the cycle. Otherwise it
    /// returns `(full_laps, remaining_steps)`, counted from `loop_begin`.
    fn split_steps(&self, n: usize) -> Option<(usize, usize)> {
        let offset = n.checked_sub(self.loop_begin_idx)?;
        // `loop_len >= 1` always holds after `build`, so this cannot divide by zero.
        Some((offset / self.loop_len, offset % self.loop_len))
    }

    /// Returns the node reached after `n` steps from `self.begin`.
    pub fn get_nth_node_usize(&self, n: usize) -> T {
        match self.split_steps(n) {
            None => self.accumulate(self.begin, n).0,
            Some((_, rem)) => self.accumulate(self.loop_begin, rem).0,
        }
    }

    /// Returns the value accumulated over the first `n` nodes starting at `self.begin`.
    ///
    /// When `n` reaches into the cycle, full laps are added as `loop_sum * laps`
    /// instead of being walked. This assumes the additive fold described on `build`.
    pub fn get_nth_val_usize(&self, n: usize) -> V {
        match self.split_steps(n) {
            None => self.accumulate(self.begin, n).1,
            Some((laps, rem)) => {
                self.before_loop_sum
                    + self.loop_sum * laps
                    + self.accumulate(self.loop_begin, rem).1
            }
        }
    }

    /// Returns the node reached after `n` steps from `self.begin`, where `n` may
    /// exceed `usize::MAX`.
    ///
    /// `n` is first reduced to an equivalent step count below
    /// `loop_begin_idx + loop_len`, which always fits in `usize`.
    pub fn get_nth_node_biguint(&self, n: BigUint) -> T {
        let loop_begin_idx =
            BigUint::from_usize(self.loop_begin_idx).expect("a usize always fits in BigUint");
        let steps = if n < loop_begin_idx {
            n.to_usize()
                .expect("n < loop_begin_idx, which is itself a usize")
        } else {
            let loop_len =
                BigUint::from_usize(self.loop_len).expect("a usize always fits in BigUint");
            let rem = ((n - loop_begin_idx) % loop_len)
                .to_usize()
                .expect("remainder < loop_len, which is itself a usize");
            self.loop_begin_idx + rem
        };
        self.get_nth_node_usize(steps)
    }
}