// cp_library_rs/data_structure/segment_tree_2d.rs
// Index loops mirror the implicit-tree layout (children at 2k / 2k + 1),
// so they are kept as explicit index loops on purpose.
#![allow(clippy::needless_range_loop)]

use crate::algebraic_structure::monoid::Monoid;
use std::fmt::{self, Debug};
use std::ops::{
    Bound::{Excluded, Included, Unbounded},
    Deref, DerefMut, RangeBounds,
};

/// C-style `for` loop: `cfor! {init; cond; step;; { body }}`.
///
/// Runs `init` once, then repeatedly checks `cond` and, while it holds, runs
/// `body` followed by `step`.
///
/// `step` is executed at the start of every iteration after the first, just
/// before `cond` is re-checked, rather than after `body`. This way a
/// `continue` inside `body` still advances the loop instead of spinning
/// forever. `break` leaves the loop as usual.
macro_rules! cfor {
    ($def:stmt ; $fin:expr ; $incr:stmt ;; $bl:block) => {{
        $def
        let mut first_iteration = true;
        while {
            if !first_iteration {
                $incr
            }
            first_iteration = false;
            $fin
        } {
            $bl
        }
    }};
}

23pub struct SegmentTree2D<M: Monoid> {
26 pub oh: usize,
27 pub ow: usize,
28 pub data: Vec<M::Val>,
29}
30
31impl<M: Monoid> SegmentTree2D<M> {
32 #[inline]
33 fn parse_range<R: RangeBounds<usize>>(&self, range: &R, max: usize) -> Option<(usize, usize)> {
34 let start = match range.start_bound() {
35 Unbounded => 0,
36 Excluded(&v) => v + 1,
37 Included(&v) => v,
38 };
39 let end = match range.end_bound() {
40 Unbounded => max,
41 Excluded(&v) => v,
42 Included(&v) => v + 1,
43 };
44 if start <= end && end <= max {
45 Some((start, end))
46 } else {
47 None
48 }
49 }
50
51 #[inline]
52 fn idx(&self, i: usize, j: usize) -> usize {
53 2 * self.ow * i + j
54 }
55
56 pub fn new(H: usize, W: usize) -> Self {
58 let oh = H.next_power_of_two();
59 let ow = W.next_power_of_two();
60
61 Self {
62 oh,
63 ow,
64 data: vec![M::e(); 4 * oh * ow],
65 }
66 }
67
68 pub fn update(&mut self, mut r: usize, mut c: usize, x: M::Val) {
70 r += self.oh;
71 c += self.ow;
72 let idx = self.idx(r, c);
73 self.data[idx] = x;
74 cfor! {let mut i = r >> 1; i > 0; i >>= 1;; {
76 let idx = self.idx(i, c);
77 self.data[idx] = M::op(
78 &self.data[self.idx(2 * i, c)],
79 &self.data[self.idx(2 * i + 1, c)],
80 );
81 }}
82 cfor! {let mut i = r; i > 0; i >>= 1;; {
84 cfor! {let mut j = c >> 1; j > 0; j >>= 1;; {
85 let idx = self.idx(i, j);
86 self.data[idx] = M::op(
87 &self.data[self.idx(i, 2 * j)],
88 &self.data[self.idx(i, 2 * j + 1)],
89 );
90 }}
91 }}
92 }
93
94 pub fn get_mut(&mut self, r: usize, c: usize) -> Option<ValMut<'_, M>> {
96 if r < self.oh && c < self.ow {
97 let old_val = self.data[self.idx(r + self.oh, c + self.ow)].clone();
98 Some(ValMut {
99 segtree: self,
100 r,
101 c,
102 new_val: old_val,
103 })
104 } else {
105 None
106 }
107 }
108
109 fn aggregate_row(&self, r: usize, mut cs: usize, mut ce: usize) -> M::Val {
111 let mut res = M::e();
113 while cs < ce {
114 if cs & 1 == 1 {
115 res = M::op(&res, &self.data[self.idx(r, cs)]);
116 cs += 1;
117 }
118 if ce & 1 == 1 {
119 ce -= 1;
120 res = M::op(&res, &self.data[self.idx(r, ce)]);
121 }
122 cs >>= 1;
123 ce >>= 1;
124 }
125 res
126 }
127
128 pub fn get_range<R, C>(&self, row: R, col: C) -> M::Val
130 where
131 R: RangeBounds<usize> + fmt::Debug,
132 C: RangeBounds<usize> + fmt::Debug,
133 {
134 let Some((mut rs, mut re)) = self.parse_range(&row, self.oh) else {
135 panic!("The given range is wrong (row): {:?}", row);
136 };
137 let Some((mut cs, mut ce)) = self.parse_range(&col, self.ow) else {
138 panic!("The given range is wrong (col): {:?}", col);
139 };
140 rs += self.oh;
141 re += self.oh;
142 cs += self.ow;
143 ce += self.ow;
144 let mut res = M::e();
146 while rs < re {
147 if rs & 1 == 1 {
148 res = M::op(&res, &self.aggregate_row(rs, cs, ce));
149 rs += 1;
150 }
151 if re & 1 == 1 {
152 re -= 1;
153 res = M::op(&res, &self.aggregate_row(re, cs, ce));
154 }
155 rs >>= 1;
156 re >>= 1;
157 }
158 res
159 }
160}
161
162impl<M: Monoid> From<&Vec<Vec<M::Val>>> for SegmentTree2D<M> {
163 fn from(src: &Vec<Vec<M::Val>>) -> Self {
164 let (H, W) = (src.len(), src[0].len());
165 let mut seg = SegmentTree2D::new(H, W);
166 let (oh, ow) = (seg.oh, seg.ow);
167 for i in 0..H {
169 for j in 0..W {
170 let idx = seg.idx(oh + i, ow + j);
171 seg.data[idx] = src[i][j].clone();
172 }
173 }
174 for j in ow..2 * ow {
176 for i in (1..oh).rev() {
177 let idx = seg.idx(i, j);
178 seg.data[idx] = M::op(
179 &seg.data[seg.idx(2 * i, j)],
180 &seg.data[seg.idx(2 * i + 1, j)],
181 );
182 }
183 }
184 for i in 0..2 * oh {
186 for j in (1..ow).rev() {
187 let idx = seg.idx(i, j);
188 seg.data[idx] = M::op(
189 &seg.data[seg.idx(i, 2 * j)],
190 &seg.data[seg.idx(i, 2 * j + 1)],
191 );
192 }
193 }
194 seg
195 }
196}
197
198pub struct ValMut<'a, M: 'a + Monoid> {
199 segtree: &'a mut SegmentTree2D<M>,
200 r: usize,
201 c: usize,
202 new_val: M::Val,
203}
204
205impl<M> fmt::Debug for ValMut<'_, M>
206where
207 M: Monoid,
208 M::Val: Debug,
209{
210 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
211 f.debug_struct("ValMut")
212 .field("r", &self.r)
213 .field("c", &self.c)
214 .field("new_val", &self.new_val)
215 .finish()
216 }
217}
218
219impl<M: Monoid> Drop for ValMut<'_, M> {
220 fn drop(&mut self) {
221 self.segtree.update(self.r, self.c, self.new_val.clone());
222 }
223}
224
225impl<M: Monoid> Deref for ValMut<'_, M> {
226 type Target = M::Val;
227 fn deref(&self) -> &Self::Target {
228 &self.new_val
229 }
230}
231
232impl<M: Monoid> DerefMut for ValMut<'_, M> {
233 fn deref_mut(&mut self) -> &mut Self::Target {
234 &mut self.new_val
235 }
236}
237
238impl<M> SegmentTree2D<M>
239where
240 M: Monoid,
241 M::Val: Debug,
242{
243 pub fn show(&self) {
245 if cfg!(debug_assertions) {
246 let H = self.oh;
247 let W = self.ow;
248 eprintln!("SegmentTree2D (H={}, W={}) {{", H, W);
249 for i in 0..H {
250 eprintln!(" {:?},", &self.data[2 * i * W..2 * (i + 1) * W]);
251 }
252 eprintln!("}}");
253 }
254 }
255}