cp_library_rs/data_structure/
segment_tree_ctx.rs

1//! ## セグメント木
2//!
3//! 集合 $`S`$ と演算 $`\circ`$ の組 $`(S,\circ)`$ がモノイド([`MonoidCtx`])であるとき,
4//! $`S`$ の要素の列 $`A`$ に対し,
5//!
6//! - 区間積の取得 : $`A[l] \circ A[l+1] \circ \cdots \circ A[r]`$
7//! - 要素の更新 : $`A[i] \leftarrow x`$
8//!
9//! をそれぞれ $`O(\log N)`$ で行う.($`N = |A|`$)
10
11use crate::{
12    algebraic_structure::monoid_with_context::MonoidCtx, tree::show_binary_tree::ShowBinaryTree,
13};
14use std::{
15    fmt::{self, Debug},
16    ops::{
17        Bound::{Excluded, Included, Unbounded},
18        Deref, DerefMut, Index, RangeBounds,
19    },
20};
21
22/// セグメント木
23pub struct SegmentTreeCtx<M: MonoidCtx> {
24    /// 要素数
25    pub N: usize,
26    offset: usize,
27    data: Vec<M::Val>,
28    monoid: M,
29}
30
31impl<M: MonoidCtx> Index<usize> for SegmentTreeCtx<M> {
32    type Output = M::Val;
33    fn index(&self, idx: usize) -> &Self::Output {
34        &self.data[self.offset + idx]
35    }
36}
37
38impl<M: MonoidCtx> SegmentTreeCtx<M> {
39    #[inline]
40    fn parse_range<R: RangeBounds<usize>>(&self, range: &R) -> Option<(usize, usize)> {
41        let start = match range.start_bound() {
42            Unbounded => 0,
43            Excluded(&v) => v + 1,
44            Included(&v) => v,
45        };
46        let end = match range.end_bound() {
47            Unbounded => self.N,
48            Excluded(&v) => v,
49            Included(&v) => v + 1,
50        };
51        if start <= end && end <= self.N {
52            Some((start, end))
53        } else {
54            None
55        }
56    }
57
58    /// セグメント木を初期化する
59    /// - 時間計算量: $`O(N)`$
60    pub fn new(N: usize, monoid: M) -> Self {
61        let offset = N.next_power_of_two();
62
63        Self {
64            N,
65            offset,
66            data: vec![monoid.e(); offset << 1],
67            monoid,
68        }
69    }
70
71    /// 配列から初期化する
72    /// - 時間計算量: $`O(N)`$
73    pub fn from_vec(src: Vec<M::Val>, monoid: M) -> Self {
74        let mut seg = Self::new(src.len(), monoid);
75        for (i, v) in src.into_iter().enumerate() {
76            seg.data[seg.offset + i] = v;
77        }
78        for i in (0..seg.offset).rev() {
79            let lch = i << 1;
80            seg.data[i] = seg.monoid.op(&seg.data[lch], &seg.data[lch + 1]);
81        }
82        seg
83    }
84
85    /// `index`番目の要素を`value`に更新する
86    /// - 時間計算量: $`O(\log N)`$
87    pub fn update(&mut self, index: usize, value: M::Val) {
88        let mut i = index + self.offset;
89        self.data[i] = value;
90        while i > 1 {
91            i >>= 1;
92            let lch = i << 1;
93            self.data[i] = self.monoid.op(&self.data[lch], &self.data[lch + 1]);
94        }
95    }
96
97    /// `i`番目の要素の可変な参照を返す
98    /// - 時間計算量: $`O(\log N)`$
99    pub fn get_mut(&mut self, i: usize) -> Option<ValMut<'_, M>> {
100        if i < self.offset {
101            let default = self.index(i).clone();
102            Some(ValMut {
103                segself: self,
104                idx: i,
105                new_val: default,
106            })
107        } else {
108            None
109        }
110    }
111
112    /// 区間`range`の集約を行う
113    /// - 時間計算量: $`O(\log N)`$
114    pub fn get_range<R: RangeBounds<usize> + Debug>(&self, range: R) -> M::Val {
115        let (start, end) = match self.parse_range(&range) {
116            Some(r) => r,
117            None => panic!("The given range is wrong: {:?}", range),
118        };
119        // 値の取得
120        let mut l = self.offset + start;
121        let mut r = self.offset + end;
122        let (mut res_l, mut res_r) = (self.monoid.e(), self.monoid.e());
123
124        while l < r {
125            if l & 1 == 1 {
126                res_l = self.monoid.op(&res_l, &self.data[l]);
127                l += 1;
128            }
129            if r & 1 == 1 {
130                r -= 1;
131                res_r = self.monoid.op(&self.data[r], &res_r);
132            }
133            l >>= 1;
134            r >>= 1;
135        }
136
137        self.monoid.op(&res_l, &res_r)
138    }
139}
140
141impl<M> Debug for SegmentTreeCtx<M>
142where
143    M: MonoidCtx,
144    M::Val: Debug,
145{
146    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
147        write!(f, "SegmentTreeCtx {{ [").ok();
148        for i in 0..self.N {
149            if i + 1 < self.N {
150                write!(f, "{:?}, ", self.data[self.offset + i]).ok();
151            } else {
152                write!(f, "{:?}", self.data[self.offset + i]).ok();
153            }
154        }
155        write!(f, "] }}")
156    }
157}
158
159/// セグメント木の要素の可変参照
160pub struct ValMut<'s, M: MonoidCtx> {
161    segself: &'s mut SegmentTreeCtx<M>,
162    idx: usize,
163    new_val: M::Val,
164}
165
166impl<M> Debug for ValMut<'_, M>
167where
168    M: MonoidCtx,
169    M::Val: Debug,
170{
171    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
172        f.debug_tuple("ValMut").field(&self.new_val).finish()
173    }
174}
175
176impl<M: MonoidCtx> Drop for ValMut<'_, M> {
177    fn drop(&mut self) {
178        self.segself.update(self.idx, self.new_val.clone());
179    }
180}
181
182impl<M: MonoidCtx> Deref for ValMut<'_, M> {
183    type Target = M::Val;
184    fn deref(&self) -> &Self::Target {
185        &self.new_val
186    }
187}
188
189impl<M: MonoidCtx> DerefMut for ValMut<'_, M> {
190    fn deref_mut(&mut self) -> &mut Self::Target {
191        &mut self.new_val
192    }
193}
194
195// セグ木上の2分探索
196impl<M: MonoidCtx> SegmentTreeCtx<M> {
197    /// 左端を固定した2分探索
198    /// - 引数`l`と関数`f`に対して,
199    ///     - `f( seg.get(l..x) ) = true`
200    ///     - `f( seg.get(l..x+1) ) = false`
201    ///
202    ///   \
203    ///   を満たす`x`を返す
204    ///
205    /// **引数**
206    /// - `f` :
207    ///   - `f(e) = true`
208    ///   - 任意の`i`に対して,`f( seg.get(l..i) ) = false`ならば,`f( seg.get(l..i+1) ) = false`
209    pub fn max_right<F>(&self, mut l: usize, f: F) -> (M::Val, usize)
210    where
211        F: Fn(M::Val) -> bool,
212    {
213        assert!(f(self.monoid.e()));
214
215        if l >= self.N {
216            return (self.monoid.e(), self.N);
217        }
218
219        l += self.offset;
220        let mut sum = self.monoid.e();
221
222        // 第1段階: 条件を満たさない区間を見つける
223        'fst: loop {
224            while l & 1 == 0 {
225                l >>= 1;
226            }
227
228            let tmp = self.monoid.op(&sum, &self.data[l]);
229
230            // 満たさない区間を発見した場合
231            if !f(tmp.clone()) {
232                break 'fst;
233            }
234
235            sum = tmp;
236            l += 1;
237
238            // すべての領域を見終わったら終了
239            if (l & l.wrapping_neg()) == l {
240                return (sum, self.N);
241            }
242        }
243
244        // 第2段階: 子方向に移動しながら2分探索
245        while l < self.offset {
246            // 左に潜る
247            l <<= 1;
248
249            let tmp = self.monoid.op(&sum, &self.data[l]);
250
251            // 左に潜っても大丈夫な場合
252            if f(tmp.clone()) {
253                sum = tmp;
254                // 右に潜る
255                l += 1;
256            }
257        }
258
259        (sum, l - self.offset)
260    }
261
262    /// 右端を固定した2分探索
263    /// - 引数`r`と関数`f`に対して,
264    ///    - `f( seg.get(x..r) ) = true`
265    ///    - `f( seg.get(x-1..r) ) = false`
266    ///
267    ///   \
268    ///   となるような`x`を返す
269    ///
270    /// **引数**
271    /// - `f` :
272    ///   - `f(e) = true`
273    ///   - 任意の`i`に対して,`f( seg.get(i..r) ) = false`ならば,`f( seg.get(i-1..r) ) = false`
274    pub fn min_left<F>(&self, mut r: usize, f: F) -> (M::Val, usize)
275    where
276        F: Fn(M::Val) -> bool,
277    {
278        assert!(f(self.monoid.e()));
279
280        if r == 0 {
281            return (self.monoid.e(), 0);
282        }
283
284        r += self.offset;
285        let mut sum = self.monoid.e();
286
287        // 第1段階: 条件を満たさない区間を見つける
288        'fst: loop {
289            r -= 1;
290            while r > 1 && r & 1 == 1 {
291                r >>= 1;
292            }
293
294            let tmp = self.monoid.op(&self.data[r], &sum);
295
296            // 満たさない区間を発見した場合
297            if !f(tmp.clone()) {
298                break 'fst;
299            }
300
301            sum = tmp;
302
303            // すべての領域を見終わったら終了
304            if (r & r.wrapping_neg()) == r {
305                return (sum, 0);
306            }
307        }
308
309        // 第2段階: 子方向に移動しながら2分探索
310        while r < self.offset {
311            // 右に潜る
312            r = (r << 1) + 1;
313
314            let tmp = self.monoid.op(&self.data[r], &sum);
315
316            // 右に潜っても大丈夫な場合
317            if f(tmp.clone()) {
318                sum = tmp;
319                // 左に潜る
320                r -= 1;
321            }
322        }
323
324        (sum, r + 1 - self.offset)
325    }
326}
327
328impl<M> ShowBinaryTree<usize> for SegmentTreeCtx<M>
329where
330    M: MonoidCtx,
331    M::Val: Debug,
332{
333    fn get_root(&self) -> Option<usize> {
334        Some(1)
335    }
336    fn get_left(&self, &i: &usize) -> Option<usize> {
337        (i * 2 < self.offset * 2).then_some(i * 2)
338    }
339    fn get_right(&self, &i: &usize) -> Option<usize> {
340        (i * 2 + 1 < self.offset * 2).then_some(i * 2 + 1)
341    }
342    fn print_node(&self, &i: &usize) -> String {
343        format!("[{:?}]", self.data[i])
344    }
345}