cp_library_rs/data_structure/
bit.rs1use std::{
4 fmt::Debug,
5 ops::{
6 Bound::{Excluded, Included, Unbounded},
7 RangeBounds,
8 },
9};
10
11use crate::algebraic_structure::{group::Group, monoid::Monoid, ordered_monoid::OrderedMonoid};
12
13pub struct BIT<T: Monoid> {
16 pub size: usize,
17 arr: Vec<T::Val>,
18}
19
20impl<T: Monoid> BIT<T> {
21 #[inline]
22 fn lsb(x: usize) -> usize {
23 x & x.wrapping_neg()
24 }
25
26 pub fn new(n: usize) -> Self {
29 BIT {
30 size: n,
31 arr: vec![T::e(); n + 1],
32 }
33 }
34
35 pub fn add(&mut self, mut i: usize, x: T::Val) {
39 i += 1;
40 while i <= self.size {
41 self.arr[i] = T::op(&self.arr[i], &x);
42 i += Self::lsb(i);
43 }
44 }
45
46 pub fn prefix_sum(&self, mut i: usize) -> T::Val {
49 let mut res = T::e();
50 while i != 0 {
51 res = T::op(&res, &self.arr[i]);
52 i -= Self::lsb(i);
53 }
54 res
55 }
56}
57
58impl<T: Group> BIT<T> {
59 #[inline]
60 fn parse_range<R: RangeBounds<usize>>(&self, range: R) -> Option<(usize, usize)> {
61 let start = match range.start_bound() {
62 Unbounded => 0,
63 Excluded(&v) => v + 1,
64 Included(&v) => v,
65 }
66 .min(self.size);
67 let end = match range.end_bound() {
68 Unbounded => self.size,
69 Excluded(&v) => v,
70 Included(&v) => v + 1,
71 }
72 .min(self.size);
73 if start <= end {
74 Some((start, end))
75 } else {
76 None
77 }
78 }
79
80 pub fn sum<R: RangeBounds<usize>>(&self, range: R) -> T::Val {
83 if let Some((i, j)) = self.parse_range(range) {
84 T::op(&self.prefix_sum(j), &T::inv(&self.prefix_sum(i)))
85 } else {
86 T::e()
87 }
88 }
89}
90
91impl<T: Monoid> From<&Vec<T::Val>> for BIT<T> {
92 fn from(src: &Vec<T::Val>) -> Self {
94 let size = src.len();
95 let mut arr = vec![T::e(); size + 1];
96 for i in 1..=size {
97 let x = src[i - 1].clone();
98 arr[i] = T::op(&arr[i], &x);
99 let j = i + Self::lsb(i);
100 if j < size + 1 {
101 arr[j] = T::op(&arr[j], &arr[i].clone());
102 }
103 }
104 Self { size, arr }
105 }
106}
107
108impl<T: OrderedMonoid> BIT<T> {
109 fn binary_search<F>(&self, w: T::Val, compare: F) -> usize
111 where
112 F: Fn(&T::Val, &T::Val) -> bool,
113 {
114 let mut sum = T::e();
115 let mut idx = 0;
116 let mut d = self.size.next_power_of_two() / 2;
117 while d != 0 {
118 if idx + d <= self.size {
119 let nxt = T::op(&sum, &self.arr[idx + d]);
120 if compare(&nxt, &w) {
121 sum = nxt;
122 idx += d;
123 }
124 }
125 d >>= 1;
126 }
127 idx
128 }
129 pub fn lower_bound(&self, w: T::Val) -> usize {
131 self.binary_search(w, T::lt)
132 }
133 pub fn upper_bound(&self, w: T::Val) -> usize {
135 self.binary_search(w, T::le)
136 }
137}
138
139impl<T> Debug for BIT<T>
140where
141 T: Group,
142 T::Val: Debug,
143{
144 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
145 write!(f, "BIT {{ [")?;
146 for i in 0..self.size - 1 {
147 write!(f, "{:?}, ", self.sum(i..i + 1))?;
148 }
149 write!(f, "{:?}] }}", self.sum(self.size - 1..self.size))
150 }
151}