cp_library_rs/data_structure/
bit.rs

1//! BinaryIndexedTree / FenwickTree
2
3use std::{
4    fmt::Debug,
5    ops::{
6        Bound::{Excluded, Included, Unbounded},
7        RangeBounds,
8    },
9};
10
11use crate::algebraic_structure::{group::Group, monoid::Monoid, ordered_monoid::OrderedMonoid};
12
13/// # BinaryIndexedTree
14/// - `0-indexed`なインターフェースを持つBIT
15pub struct BIT<T: Monoid> {
16    pub size: usize,
17    arr: Vec<T::Val>,
18}
19
20impl<T: Monoid> BIT<T> {
21    #[inline]
22    fn lsb(x: usize) -> usize {
23        x & x.wrapping_neg()
24    }
25
26    /// BITの初期化を行う
27    /// - `n`: 列の長さ
28    pub fn new(n: usize) -> Self {
29        BIT {
30            size: n,
31            arr: vec![T::e(); n + 1],
32        }
33    }
34
35    /// 一点加算を行う
36    /// - `i`: 加算を行うインデックス(`0-indexed`)
37    /// - `x`: 加算する値
38    pub fn add(&mut self, mut i: usize, x: T::Val) {
39        i += 1;
40        while i <= self.size {
41            self.arr[i] = T::op(&self.arr[i], &x);
42            i += Self::lsb(i);
43        }
44    }
45
46    /// 先頭からの和を求める
47    /// - `i`: 区間`[0,i)`に対しての総和(`0-indexed`)
48    pub fn prefix_sum(&self, mut i: usize) -> T::Val {
49        let mut res = T::e();
50        while i != 0 {
51            res = T::op(&res, &self.arr[i]);
52            i -= Self::lsb(i);
53        }
54        res
55    }
56}
57
58impl<T: Group> BIT<T> {
59    #[inline]
60    fn parse_range<R: RangeBounds<usize>>(&self, range: R) -> Option<(usize, usize)> {
61        let start = match range.start_bound() {
62            Unbounded => 0,
63            Excluded(&v) => v + 1,
64            Included(&v) => v,
65        }
66        .min(self.size);
67        let end = match range.end_bound() {
68            Unbounded => self.size,
69            Excluded(&v) => v,
70            Included(&v) => v + 1,
71        }
72        .min(self.size);
73        if start <= end {
74            Some((start, end))
75        } else {
76            None
77        }
78    }
79
80    /// 任意の区間の和を求める
81    /// - `range`: 区間を表すRangeオブジェクト
82    pub fn sum<R: RangeBounds<usize>>(&self, range: R) -> T::Val {
83        if let Some((i, j)) = self.parse_range(range) {
84            T::op(&self.prefix_sum(j), &T::inv(&self.prefix_sum(i)))
85        } else {
86            T::e()
87        }
88    }
89}
90
91impl<T: Monoid> From<&Vec<T::Val>> for BIT<T> {
92    /// ベクターの参照からBITを作成
93    fn from(src: &Vec<T::Val>) -> Self {
94        let size = src.len();
95        let mut arr = vec![T::e(); size + 1];
96        for i in 1..=size {
97            let x = src[i - 1].clone();
98            arr[i] = T::op(&arr[i], &x);
99            let j = i + Self::lsb(i);
100            if j < size + 1 {
101                arr[j] = T::op(&arr[j], &arr[i].clone());
102            }
103        }
104        Self { size, arr }
105    }
106}
107
108impl<T: OrderedMonoid> BIT<T> {
109    /// `lower_bound`/`upper_bound`を共通化した実装
110    fn binary_search<F>(&self, w: T::Val, compare: F) -> usize
111    where
112        F: Fn(&T::Val, &T::Val) -> bool,
113    {
114        let mut sum = T::e();
115        let mut idx = 0;
116        let mut d = self.size.next_power_of_two() / 2;
117        while d != 0 {
118            if idx + d <= self.size {
119                let nxt = T::op(&sum, &self.arr[idx + d]);
120                if compare(&nxt, &w) {
121                    sum = nxt;
122                    idx += d;
123                }
124            }
125            d >>= 1;
126        }
127        idx
128    }
129    /// `a_0 + a_1 + ... + a_i >= w`となる最小の`i`を求める
130    pub fn lower_bound(&self, w: T::Val) -> usize {
131        self.binary_search(w, T::lt)
132    }
133    /// `a_0 + a_1 + ... + a_i > w`となる最小の`i`を求める
134    pub fn upper_bound(&self, w: T::Val) -> usize {
135        self.binary_search(w, T::le)
136    }
137}
138
139impl<T> Debug for BIT<T>
140where
141    T: Group,
142    T::Val: Debug,
143{
144    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
145        write!(f, "BIT {{ [")?;
146        for i in 0..self.size - 1 {
147            write!(f, "{:?}, ", self.sum(i..i + 1))?;
148        }
149        write!(f, "{:?}] }}", self.sum(self.size - 1..self.size))
150    }
151}