1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
//! SparseTable: O(1) range queries over a semilattice after O(n log n) preprocessing.

use std::fmt;
use std::ops::{
    Bound::{Excluded, Included, Unbounded},
    RangeBounds,
};

use crate::algebraic_structure::semilattice::Semilattice;

/// Static sparse table answering range queries over a semilattice.
///
/// Construction takes O(n log n) time and memory. Each [`SparseTable::get_range`]
/// query combines two (possibly overlapping) precomputed windows, so it runs in
/// O(1). The overlap means results are only correct when `S::op` is associative
/// and idempotent (`op(x, x) == x`), i.e. the semilattice laws.
#[derive(Debug)]
pub struct SparseTable<S: Semilattice> {
    /// Number of elements the table was built from.
    /// Should not be modified after construction: queries use it for bounds checking.
    pub size: usize,
    /// `table[k][j]` is the fold of `S::op` over `arr[j..j + 2^k]`.
    table: Vec<Vec<S::Val>>,
    /// `logs[len]` is `floor(log2(len))` for `len >= 1` (and `0` for `len == 0`).
    logs: Vec<usize>,
}

impl<S: Semilattice> SparseTable<S> {
    /// Converts `range` into a half-open `(start, end)` pair.
    ///
    /// Returns `None` if the range is inverted, extends past `self.size`,
    /// or one of its bounds overflows `usize` (e.g. `..=usize::MAX`).
    #[inline]
    fn parse_range<R: RangeBounds<usize>>(&self, range: &R) -> Option<(usize, usize)> {
        let start = match range.start_bound() {
            Unbounded => 0,
            // `checked_add` so an extreme bound is rejected instead of wrapping.
            Excluded(&v) => v.checked_add(1)?,
            Included(&v) => v,
        };
        let end = match range.end_bound() {
            Unbounded => self.size,
            Excluded(&v) => v,
            Included(&v) => v.checked_add(1)?,
        };
        (start <= end && end <= self.size).then_some((start, end))
    }

    /// Builds a sparse table over `arr`.
    ///
    /// Runs in O(n log n) time and stores O(n log n) values, where `n = arr.len()`.
    pub fn build(arr: &[S::Val]) -> Self {
        let size = arr.len();

        // logs[len] = floor(log2(len)); picks the window size used by queries.
        let mut logs = vec![0; size + 1];
        for len in 2..=size {
            logs[len] = logs[len >> 1] + 1;
        }

        // Number of levels: level k holds windows of length 2^k with 2^k <= size.
        let levels = logs[size] + 1;
        let mut table: Vec<Vec<S::Val>> = Vec::with_capacity(levels);
        table.push(arr.to_vec());
        for k in 1..levels {
            let half = 1 << (k - 1);
            let prev = &table[k - 1];
            // Window [j, j + 2^k) = [j, j + half) ++ [j + half, j + 2^k).
            // Zipping `prev` with `prev[half..]` yields exactly `size - 2^k + 1`
            // pairs, one per valid start index `j`.
            let row: Vec<S::Val> = prev
                .iter()
                .zip(&prev[half..])
                .map(|(a, b)| S::op(a, b))
                .collect();
            table.push(row);
        }

        Self { size, table, logs }
    }

    /// Returns the fold of `S::op` over the elements selected by `range`.
    ///
    /// An empty range yields `S::id()`.
    ///
    /// # Panics
    ///
    /// Panics if `range` is inverted, extends past `self.size`, or has a bound
    /// that overflows `usize`.
    pub fn get_range<R: RangeBounds<usize> + fmt::Debug>(&self, range: R) -> S::Val {
        let Some((start, end)) = self.parse_range(&range) else {
            panic!("The given range is wrong: {:?}", range);
        };

        if start >= end {
            return S::id();
        }

        // Cover [start, end) with two windows of length 2^k that may overlap;
        // the overlap is harmless because `S::op` is idempotent.
        let k = self.logs[end - start];
        let left = &self.table[k][start];
        let right = &self.table[k][end - (1 << k)];

        S::op(left, right)
    }
}