cp_library_rs/data_structure/
segment_tree_2d.rs

1//! 二次元セグメント木
2//! - 参考:[二次元セグメント木 - Nyaan's Library](https://nyaannyaan.github.io/library/data-structure-2d/2d-segment-tree.hpp.html)
3
4#![allow(clippy::needless_range_loop)]
5
6use crate::algebraic_structure::monoid::Monoid;
7use std::fmt::{self, Debug};
8use std::ops::{
9    Bound::{Excluded, Included, Unbounded},
10    Deref, DerefMut, RangeBounds,
11};
12
/// C-style `for` loop: `cfor! {init; cond; step;; { body }}`.
///
/// Expands to `init` followed by a `while cond { body step }` loop, so `init`'s
/// bindings are visible in `cond`, `body` and `step`.
/// Note: a `continue` inside `body` skips `step` (just like a hand-written
/// `while` loop would), so avoid `continue` in bodies that rely on `step`.
macro_rules! cfor {
    ($init:stmt ; $cond:expr ; $step:stmt ;; $body:block) => {{
        $init
        while $cond {
            $body
            $step
        }
    }}
}
22
23/// # SegmentTree2D (Monoid)
24/// - 2次元セグメント木
25pub struct SegmentTree2D<M: Monoid> {
26    pub oh: usize,
27    pub ow: usize,
28    pub data: Vec<M::Val>,
29}
30
31impl<M: Monoid> SegmentTree2D<M> {
32    #[inline]
33    fn parse_range<R: RangeBounds<usize>>(&self, range: &R, max: usize) -> Option<(usize, usize)> {
34        let start = match range.start_bound() {
35            Unbounded => 0,
36            Excluded(&v) => v + 1,
37            Included(&v) => v,
38        };
39        let end = match range.end_bound() {
40            Unbounded => max,
41            Excluded(&v) => v,
42            Included(&v) => v + 1,
43        };
44        if start <= end && end <= max {
45            Some((start, end))
46        } else {
47            None
48        }
49    }
50
51    #[inline]
52    fn idx(&self, i: usize, j: usize) -> usize {
53        2 * self.ow * i + j
54    }
55
56    /// セグメント木を初期化する
57    pub fn new(H: usize, W: usize) -> Self {
58        let oh = H.next_power_of_two();
59        let ow = W.next_power_of_two();
60
61        Self {
62            oh,
63            ow,
64            data: vec![M::e(); 4 * oh * ow],
65        }
66    }
67
68    /// 座標 `(r,c)` の値を `x` に更新する
69    pub fn update(&mut self, mut r: usize, mut c: usize, x: M::Val) {
70        r += self.oh;
71        c += self.ow;
72        let idx = self.idx(r, c);
73        self.data[idx] = x;
74        // col方向の更新
75        cfor! {let mut i = r >> 1; i > 0; i >>= 1;; {
76            let idx = self.idx(i, c);
77            self.data[idx] = M::op(
78                &self.data[self.idx(2 * i, c)],
79                &self.data[self.idx(2 * i + 1, c)],
80            );
81        }}
82        // row方向の更新
83        cfor! {let mut i = r; i > 0; i >>= 1;; {
84            cfor! {let mut j = c >> 1; j > 0; j >>= 1;; {
85                let idx = self.idx(i, j);
86                self.data[idx] = M::op(
87                    &self.data[self.idx(i, 2 * j)],
88                    &self.data[self.idx(i, 2 * j + 1)],
89                );
90            }}
91        }}
92    }
93
94    /// 可変な参照を返す
95    pub fn get_mut(&mut self, r: usize, c: usize) -> Option<ValMut<'_, M>> {
96        if r < self.oh && c < self.ow {
97            let old_val = self.data[self.idx(r + self.oh, c + self.ow)].clone();
98            Some(ValMut {
99                segtree: self,
100                r,
101                c,
102                new_val: old_val,
103            })
104        } else {
105            None
106        }
107    }
108
109    /// row方向での集約を行う
110    fn aggregate_row(&self, r: usize, mut cs: usize, mut ce: usize) -> M::Val {
111        // 集約
112        let mut res = M::e();
113        while cs < ce {
114            if cs & 1 == 1 {
115                res = M::op(&res, &self.data[self.idx(r, cs)]);
116                cs += 1;
117            }
118            if ce & 1 == 1 {
119                ce -= 1;
120                res = M::op(&res, &self.data[self.idx(r, ce)]);
121            }
122            cs >>= 1;
123            ce >>= 1;
124        }
125        res
126    }
127
128    /// 区間の集約を行う
129    pub fn get_range<R, C>(&self, row: R, col: C) -> M::Val
130    where
131        R: RangeBounds<usize> + fmt::Debug,
132        C: RangeBounds<usize> + fmt::Debug,
133    {
134        let Some((mut rs, mut re)) = self.parse_range(&row, self.oh) else {
135            panic!("The given range is wrong (row): {:?}", row);
136        };
137        let Some((mut cs, mut ce)) = self.parse_range(&col, self.ow) else {
138            panic!("The given range is wrong (col): {:?}", col);
139        };
140        rs += self.oh;
141        re += self.oh;
142        cs += self.ow;
143        ce += self.ow;
144        // 値の取得
145        let mut res = M::e();
146        while rs < re {
147            if rs & 1 == 1 {
148                res = M::op(&res, &self.aggregate_row(rs, cs, ce));
149                rs += 1;
150            }
151            if re & 1 == 1 {
152                re -= 1;
153                res = M::op(&res, &self.aggregate_row(re, cs, ce));
154            }
155            rs >>= 1;
156            re >>= 1;
157        }
158        res
159    }
160}
161
162impl<M: Monoid> From<&Vec<Vec<M::Val>>> for SegmentTree2D<M> {
163    fn from(src: &Vec<Vec<M::Val>>) -> Self {
164        let (H, W) = (src.len(), src[0].len());
165        let mut seg = SegmentTree2D::new(H, W);
166        let (oh, ow) = (seg.oh, seg.ow);
167        // セグ木の値を埋める
168        for i in 0..H {
169            for j in 0..W {
170                let idx = seg.idx(oh + i, ow + j);
171                seg.data[idx] = src[i][j].clone();
172            }
173        }
174        // col方向の集約
175        for j in ow..2 * ow {
176            for i in (1..oh).rev() {
177                let idx = seg.idx(i, j);
178                seg.data[idx] = M::op(
179                    &seg.data[seg.idx(2 * i, j)],
180                    &seg.data[seg.idx(2 * i + 1, j)],
181                );
182            }
183        }
184        // row方向の集約
185        for i in 0..2 * oh {
186            for j in (1..ow).rev() {
187                let idx = seg.idx(i, j);
188                seg.data[idx] = M::op(
189                    &seg.data[seg.idx(i, 2 * j)],
190                    &seg.data[seg.idx(i, 2 * j + 1)],
191                );
192            }
193        }
194        seg
195    }
196}
197
198pub struct ValMut<'a, M: 'a + Monoid> {
199    segtree: &'a mut SegmentTree2D<M>,
200    r: usize,
201    c: usize,
202    new_val: M::Val,
203}
204
205impl<M> fmt::Debug for ValMut<'_, M>
206where
207    M: Monoid,
208    M::Val: Debug,
209{
210    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
211        f.debug_struct("ValMut")
212            .field("r", &self.r)
213            .field("c", &self.c)
214            .field("new_val", &self.new_val)
215            .finish()
216    }
217}
218
219impl<M: Monoid> Drop for ValMut<'_, M> {
220    fn drop(&mut self) {
221        self.segtree.update(self.r, self.c, self.new_val.clone());
222    }
223}
224
225impl<M: Monoid> Deref for ValMut<'_, M> {
226    type Target = M::Val;
227    fn deref(&self) -> &Self::Target {
228        &self.new_val
229    }
230}
231
232impl<M: Monoid> DerefMut for ValMut<'_, M> {
233    fn deref_mut(&mut self) -> &mut Self::Target {
234        &mut self.new_val
235    }
236}
237
238impl<M> SegmentTree2D<M>
239where
240    M: Monoid,
241    M::Val: Debug,
242{
243    /// テーブルとして表示する
244    pub fn show(&self) {
245        if cfg!(debug_assertions) {
246            let H = self.oh;
247            let W = self.ow;
248            eprintln!("SegmentTree2D (H={}, W={}) {{", H, W);
249            for i in 0..H {
250                eprintln!("  {:?},", &self.data[2 * i * W..2 * (i + 1) * W]);
251            }
252            eprintln!("}}");
253        }
254    }
255}