cp_library_rs/data_structure/sparse_table.rs

//! SparseTable

use std::fmt;
use std::ops::{
    Bound::{Excluded, Included, Unbounded},
    RangeBounds,
};

use crate::algebraic_structure::semilattice::Semilattice;
10
11#[derive(Debug)]
12pub struct SparseTable<S: Semilattice> {
13    pub size: usize,
14    table: Vec<Vec<S::Val>>,
15    logs: Vec<usize>,
16}
17
18impl<S: Semilattice> SparseTable<S> {
19    #[inline]
20    fn parse_range<R: RangeBounds<usize>>(&self, range: &R) -> Option<(usize, usize)> {
21        let start = match range.start_bound() {
22            Unbounded => 0,
23            Excluded(&v) => v + 1,
24            Included(&v) => v,
25        };
26        let end = match range.end_bound() {
27            Unbounded => self.size,
28            Excluded(&v) => v,
29            Included(&v) => v + 1,
30        };
31        if start <= end && end <= self.size {
32            Some((start, end))
33        } else {
34            None
35        }
36    }
37
38    /// SparseTableを構築する
39    pub fn build(arr: &[S::Val]) -> Self {
40        let size = arr.len();
41        // 区間取得用の配列
42        let mut logs = vec![0; size + 1];
43        for i in 2..=size {
44            logs[i] = logs[i >> 1] + 1;
45        }
46        // テーブルの高さ
47        let lg = logs[size] + 1;
48        // 繰り返し適用した結果
49        let mut table = vec![vec![]; lg];
50        for a in arr {
51            table[0].push(a.clone());
52        }
53        for i in 1..lg {
54            let mut j = 0;
55            while j + (1 << i) <= size {
56                let a = &table[i - 1][j];
57                let b = &table[i - 1][j + (1 << (i - 1))];
58                let res = S::op(a, b);
59                table[i].push(res);
60                j += 1;
61            }
62        }
63        Self { size, table, logs }
64    }
65
66    /// 区間取得
67    pub fn get_range<R: RangeBounds<usize> + fmt::Debug>(&self, range: R) -> S::Val {
68        let Some((start, end)) = self.parse_range(&range) else {
69            panic!("The given range is wrong: {:?}", range);
70        };
71
72        if start >= end {
73            return S::id();
74        }
75
76        let lg = self.logs[end - start];
77        let left = &self.table[lg][start];
78        let right = &self.table[lg][end - (1 << lg)];
79
80        S::op(left, right)
81    }
82}