cp_library_rs/data_structure/
dynamic_segment_tree.rs

//! Segment tree that only materializes the nodes it needs (interval-splitting, arena-backed).
//!
//! - dynamic segment tree (implicit segment tree)
//! - the index range is the half-open interval [min, max)

6use std::{
7    fmt::{self, Debug},
8    ops::{Bound::*, Deref, DerefMut, RangeBounds},
9};
10
11use num::ToPrimitive;
12use num_traits::PrimInt;
13
14use crate::{
15    algebraic_structure::actedmonoid_with_size::ActedMonoidWithSize,
16    tree::arena::{Arena, ArenaNode, Ptr},
17    tree::show_binary_tree::ShowBinaryTree,
18};
19
20type A<M> = Arena<NodeInner<M>>;
21
22// ========== node ==========
23
24/// 区間分割型ノード(`I` を持たない)
25struct NodeInner<M: ActedMonoidWithSize> {
26    /// 区間の集約値
27    sum: M::Val,
28    /// 遅延作用
29    act: M::Act,
30    left: Option<Ptr>,
31    right: Option<Ptr>,
32}
33
34impl<M: ActedMonoidWithSize> ArenaNode for NodeInner<M> {}
35
36impl<M: ActedMonoidWithSize> NodeInner<M> {
37    fn with_length(len: usize) -> Self {
38        Self {
39            sum: M::e_with_size(len),
40            act: M::id(),
41            left: None,
42            right: None,
43        }
44    }
45}
46
47// ========== dynamic segment tree ==========
48
49pub struct DynamicSegmentTree<I: PrimInt, M: ActedMonoidWithSize> {
50    min_index: I,
51    max_index: I,
52    pub n: I,
53    arena: A<M>,
54    root: Option<Ptr>,
55}
56
57impl<I: PrimInt + ToPrimitive, M: ActedMonoidWithSize> DynamicSegmentTree<I, M> {
58    pub fn new(min: I, max: I) -> Self {
59        assert!(min < max);
60        Self {
61            min_index: min,
62            max_index: max,
63            n: max - min,
64            arena: A::new(),
65            root: None,
66        }
67    }
68
69    /// 点更新
70    /// - 計算量:\(O(\log (max-min))\)
71    pub fn update(&mut self, index: I, val: M::Val) {
72        assert!(self.min_index <= index && index < self.max_index);
73        let root = self.root.take();
74        self.root = self.update_inner(root, self.min_index, self.max_index, index, val);
75    }
76
77    /// 点取得(未生成は `M::e_with_size()`)
78    /// - 遅延があるので `push` するため `&mut self`
79    pub fn get(&mut self, index: I) -> M::Val {
80        assert!(self.min_index <= index && index < self.max_index);
81        self.get_inner(self.root, self.min_index, self.max_index, index)
82    }
83
84    /// 区間の集約値(RangeBounds を受け取る)
85    /// - 遅延があるので `push` するため `&mut self`
86    pub fn get_range<R: RangeBounds<I> + Debug>(&mut self, range: R) -> M::Val {
87        let (l, r) = self
88            .parse_range(&range)
89            .unwrap_or_else(|| panic!("The given range is wrong: {:?}", range));
90        self.get_range_inner(self.root, self.min_index, self.max_index, l, r)
91    }
92
93    /// 区間に作用を適用(遅延)
94    pub fn apply<R: RangeBounds<I> + Debug>(&mut self, range: R, act: M::Act) {
95        let (l, r) = self
96            .parse_range(&range)
97            .unwrap_or_else(|| panic!("The given range is wrong: {:?}", range));
98        let root = self.root.take();
99        self.root = self.apply_inner(root, self.min_index, self.max_index, l, r, &act);
100    }
101
102    /// `get_mut`(Drop で `update`)
103    pub fn get_mut(&mut self, i: I) -> Option<ValMut<'_, I, M>> {
104        if self.min_index <= i && i < self.max_index {
105            let default = self.get(i);
106            Some(ValMut {
107                segself: self,
108                idx: i,
109                new_val: default,
110            })
111        } else {
112            None
113        }
114    }
115
116    // ========== internal helpers ==========
117
118    #[inline]
119    fn is_leaf(seg_l: I, seg_r: I) -> bool {
120        seg_r - seg_l == I::one()
121    }
122
123    #[inline]
124    fn two() -> I {
125        I::one() + I::one()
126    }
127
128    #[inline]
129    fn mid(l: I, r: I) -> I {
130        l + (r - l) / Self::two()
131    }
132
133    #[inline]
134    fn len(l: I, r: I) -> usize {
135        (r - l).to_usize().unwrap()
136    }
137
138    #[inline]
139    fn parse_range<R: RangeBounds<I>>(&self, range: &R) -> Option<(I, I)> {
140        let start = match range.start_bound() {
141            Unbounded => self.min_index,
142            Excluded(&v) => v + I::one(),
143            Included(&v) => v,
144        };
145        let end = match range.end_bound() {
146            Unbounded => self.max_index,
147            Excluded(&v) => v,
148            Included(&v) => v + I::one(),
149        };
150        if self.min_index <= start && start <= end && end <= self.max_index {
151            Some((start, end))
152        } else {
153            None
154        }
155    }
156
157    #[inline]
158    fn sum_of(arena: &A<M>, node: Option<Ptr>, len: usize) -> M::Val {
159        node.map(|p| arena.get(p).sum.clone())
160            .unwrap_or_else(|| M::e_with_size(len))
161    }
162
163    #[inline]
164    fn apply_node(arena: &mut A<M>, ptr: Ptr, act: &M::Act) {
165        let nsum = {
166            let v = arena.get(ptr);
167            M::mapping(&v.sum, act)
168        };
169        let nact = {
170            let v = arena.get(ptr);
171            M::compose(&v.act, act)
172        };
173        let v = arena.get_mut(ptr);
174        v.sum = nsum;
175        v.act = nact;
176    }
177
178    #[inline]
179    fn ensure_left(arena: &mut A<M>, ptr: Ptr, len: usize) -> Ptr {
180        if let Some(lp) = arena.get(ptr).left {
181            lp
182        } else {
183            let lp = arena.alloc(NodeInner::with_length(len));
184            arena.get_mut(ptr).left = Some(lp);
185            lp
186        }
187    }
188
189    #[inline]
190    fn ensure_right(arena: &mut A<M>, ptr: Ptr, len: usize) -> Ptr {
191        if let Some(rp) = arena.get(ptr).right {
192            rp
193        } else {
194            let rp = arena.alloc(NodeInner::with_length(len));
195            arena.get_mut(ptr).right = Some(rp);
196            rp
197        }
198    }
199
200    /// 子へ遅延伝播
201    #[inline]
202    fn push(&mut self, ptr: Ptr, seg_l: I, seg_r: I) {
203        if Self::is_leaf(seg_l, seg_r) {
204            return;
205        }
206        let act = { self.arena.get(ptr).act.clone() };
207        if act == M::id() {
208            return;
209        }
210
211        let mid = Self::mid(seg_l, seg_r);
212        let llen = Self::len(seg_l, mid);
213        let rlen = Self::len(mid, seg_r);
214        let lp = Self::ensure_left(&mut self.arena, ptr, llen);
215        let rp = Self::ensure_right(&mut self.arena, ptr, rlen);
216
217        Self::apply_node(&mut self.arena, lp, &act);
218        Self::apply_node(&mut self.arena, rp, &act);
219
220        self.arena.get_mut(ptr).act = M::id();
221    }
222
223    /// 子の情報を吸い上げ
224    #[inline]
225    fn pull(&mut self, ptr: Ptr, l: I, r: I) {
226        let lp = self.arena.get(ptr).left;
227        let rp = self.arena.get(ptr).right;
228
229        let mid = Self::mid(l, r);
230        let llen = Self::len(l, mid);
231        let rlen = Self::len(mid, r);
232
233        let lsum = Self::sum_of(&self.arena, lp, llen);
234        let rsum = Self::sum_of(&self.arena, rp, rlen);
235
236        self.arena.get_mut(ptr).sum = M::op(&lsum, &rsum);
237    }
238
239    // ========== recursions ==========
240
241    fn update_inner(
242        &mut self,
243        node: Option<Ptr>,
244        seg_l: I,
245        seg_r: I,
246        index: I,
247        val: M::Val,
248    ) -> Option<Ptr> {
249        let len = Self::len(seg_l, seg_r);
250        let ptr = node.unwrap_or_else(|| self.arena.alloc(NodeInner::with_length(len)));
251
252        if Self::is_leaf(seg_l, seg_r) {
253            let v = self.arena.get_mut(ptr);
254            v.sum = val;
255            v.act = M::id();
256            v.left = None;
257            v.right = None;
258            return Some(ptr);
259        }
260
261        self.push(ptr, seg_l, seg_r);
262
263        let mid = Self::mid(seg_l, seg_r);
264        if index < mid {
265            let left = self.arena.get(ptr).left;
266            let nl = self.update_inner(left, seg_l, mid, index, val);
267            self.arena.get_mut(ptr).left = nl;
268        } else {
269            let right = self.arena.get(ptr).right;
270            let nr = self.update_inner(right, mid, seg_r, index, val);
271            self.arena.get_mut(ptr).right = nr;
272        }
273
274        self.pull(ptr, seg_l, seg_r);
275        Some(ptr)
276    }
277
278    fn get_inner(&mut self, node: Option<Ptr>, seg_l: I, seg_r: I, index: I) -> M::Val {
279        let len = Self::len(seg_l, seg_r);
280
281        let Some(ptr) = node else {
282            return M::e_with_size(len);
283        };
284        if Self::is_leaf(seg_l, seg_r) {
285            return self.arena.get(ptr).sum.clone();
286        }
287
288        self.push(ptr, seg_l, seg_r);
289
290        let mid = Self::mid(seg_l, seg_r);
291        if index < mid {
292            self.get_inner(self.arena.get(ptr).left, seg_l, mid, index)
293        } else {
294            self.get_inner(self.arena.get(ptr).right, mid, seg_r, index)
295        }
296    }
297
298    fn get_range_inner(&mut self, node: Option<Ptr>, seg_l: I, seg_r: I, ql: I, qr: I) -> M::Val {
299        let len = Self::len(seg_l, seg_r);
300
301        if qr <= seg_l || seg_r <= ql {
302            return M::e_with_size(len);
303        }
304
305        let Some(ptr) = node else {
306            return M::e_with_size(len);
307        };
308
309        if ql <= seg_l && seg_r <= qr {
310            return self.arena.get(ptr).sum.clone();
311        }
312
313        if Self::is_leaf(seg_l, seg_r) {
314            return self.arena.get(ptr).sum.clone();
315        }
316
317        self.push(ptr, seg_l, seg_r);
318
319        let mid = Self::mid(seg_l, seg_r);
320        let a = self.get_range_inner(self.arena.get(ptr).left, seg_l, mid, ql, qr);
321        let b = self.get_range_inner(self.arena.get(ptr).right, mid, seg_r, ql, qr);
322        M::op(&a, &b)
323    }
324
325    fn apply_inner(
326        &mut self,
327        node: Option<Ptr>,
328        seg_l: I,
329        seg_r: I,
330        ql: I,
331        qr: I,
332        act: &M::Act,
333    ) -> Option<Ptr> {
334        if qr <= seg_l || seg_r <= ql {
335            return node;
336        }
337
338        let len = Self::len(seg_l, seg_r);
339        let ptr = node.unwrap_or_else(|| self.arena.alloc(NodeInner::with_length(len)));
340
341        if ql <= seg_l && seg_r <= qr {
342            Self::apply_node(&mut self.arena, ptr, act);
343            return Some(ptr);
344        }
345
346        if Self::is_leaf(seg_l, seg_r) {
347            Self::apply_node(&mut self.arena, ptr, act);
348            return Some(ptr);
349        }
350
351        self.push(ptr, seg_l, seg_r);
352
353        let mid = Self::mid(seg_l, seg_r);
354
355        let left = self.arena.get(ptr).left;
356        let nl = self.apply_inner(left, seg_l, mid, ql, qr, act);
357        self.arena.get_mut(ptr).left = nl;
358
359        let right = self.arena.get(ptr).right;
360        let nr = self.apply_inner(right, mid, seg_r, ql, qr, act);
361        self.arena.get_mut(ptr).right = nr;
362
363        self.pull(ptr, seg_l, seg_r);
364        Some(ptr)
365    }
366}
367
368// ========== ValMut ==========
369
370pub struct ValMut<'a, I, M>
371where
372    I: PrimInt,
373    M: ActedMonoidWithSize,
374{
375    segself: &'a mut DynamicSegmentTree<I, M>,
376    idx: I,
377    new_val: M::Val,
378}
379
380impl<I, M> Debug for ValMut<'_, I, M>
381where
382    I: PrimInt,
383    M: ActedMonoidWithSize,
384    M::Val: Debug,
385{
386    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
387        f.debug_tuple("ValMut").field(&self.new_val).finish()
388    }
389}
390
391impl<I, M> Drop for ValMut<'_, I, M>
392where
393    I: PrimInt,
394    M: ActedMonoidWithSize,
395{
396    fn drop(&mut self) {
397        self.segself.update(self.idx, self.new_val.clone());
398    }
399}
400
401impl<I, M> Deref for ValMut<'_, I, M>
402where
403    I: PrimInt,
404    M: ActedMonoidWithSize,
405{
406    type Target = M::Val;
407    fn deref(&self) -> &Self::Target {
408        &self.new_val
409    }
410}
411
412impl<I, M> DerefMut for ValMut<'_, I, M>
413where
414    I: PrimInt,
415    M: ActedMonoidWithSize,
416{
417    fn deref_mut(&mut self) -> &mut Self::Target {
418        &mut self.new_val
419    }
420}
421
422// ========== max_right / min_left ==========
423//
424// 遅延があるので `push` が必要.そのため `&mut self` にする.
425
426impl<I, M> DynamicSegmentTree<I, M>
427where
428    I: PrimInt,
429    M: ActedMonoidWithSize,
430    M::Val: Debug,
431{
432    /// 左端固定二分探索
433    /// 返り値:(`get_range(l..x)`, `x`)
434    pub fn max_right<F>(&mut self, l: I, f: F) -> (M::Val, I)
435    where
436        F: Fn(M::Val) -> bool,
437    {
438        assert!(f(M::e_with_size(0)));
439        assert!(self.min_index <= l && l <= self.max_index);
440
441        let mut acc = M::e_with_size(0);
442        let x = self.max_right_inner(self.root, self.min_index, self.max_index, l, &f, &mut acc);
443        (acc, x)
444    }
445
446    fn max_right_inner<F>(
447        &mut self,
448        node: Option<Ptr>,
449        seg_l: I,
450        seg_r: I,
451        ql: I,
452        f: &F,
453        acc: &mut M::Val,
454    ) -> I
455    where
456        F: Fn(M::Val) -> bool,
457    {
458        if seg_r <= ql {
459            return seg_r;
460        }
461
462        let Some(ptr) = node else {
463            // 未生成区間は全て `M::e_with_size()` なので,どこまで進んでも `acc` は変わらない
464            return seg_r;
465        };
466
467        if ql <= seg_l {
468            let tmp = M::op(acc, &self.arena.get(ptr).sum);
469            if f(tmp.clone()) {
470                *acc = tmp;
471                return seg_r;
472            }
473        }
474
475        if Self::is_leaf(seg_l, seg_r) {
476            return seg_l;
477        }
478
479        self.push(ptr, seg_l, seg_r);
480
481        let mid = Self::mid(seg_l, seg_r);
482        let left_res = self.max_right_inner(self.arena.get(ptr).left, seg_l, mid, ql, f, acc);
483        if left_res != mid {
484            return left_res;
485        }
486        self.max_right_inner(self.arena.get(ptr).right, mid, seg_r, ql, f, acc)
487    }
488
489    /// 右端固定二分探索
490    /// 返り値:(`get_range(x..r)`, `x`)
491    pub fn min_left<F>(&mut self, r: I, f: F) -> (M::Val, I)
492    where
493        F: Fn(M::Val) -> bool,
494    {
495        assert!(f(M::e_with_size(0)));
496        assert!(self.min_index <= r && r <= self.max_index);
497
498        let mut acc = M::e_with_size(0);
499        let x = self.min_left_inner(self.root, self.min_index, self.max_index, r, &f, &mut acc);
500        (acc, x)
501    }
502
503    fn min_left_inner<F>(
504        &mut self,
505        node: Option<Ptr>,
506        seg_l: I,
507        seg_r: I,
508        qr: I,
509        f: &F,
510        acc: &mut M::Val,
511    ) -> I
512    where
513        F: Fn(M::Val) -> bool,
514    {
515        if qr <= seg_l {
516            return seg_l;
517        }
518
519        let Some(ptr) = node else {
520            return seg_l;
521        };
522
523        if seg_r <= qr {
524            let tmp = M::op(&self.arena.get(ptr).sum, acc);
525            if f(tmp.clone()) {
526                *acc = tmp;
527                return seg_l;
528            }
529        }
530
531        if Self::is_leaf(seg_l, seg_r) {
532            return seg_r;
533        }
534
535        self.push(ptr, seg_l, seg_r);
536
537        let mid = Self::mid(seg_l, seg_r);
538        let right_res = self.min_left_inner(self.arena.get(ptr).right, mid, seg_r, qr, f, acc);
539        if right_res != mid {
540            return right_res;
541        }
542        self.min_left_inner(self.arena.get(ptr).left, seg_l, mid, qr, f, acc)
543    }
544}
545
546// ========== ShowBinaryTree ==========
547
548impl<I, M> ShowBinaryTree<Ptr> for DynamicSegmentTree<I, M>
549where
550    I: PrimInt,
551    M: ActedMonoidWithSize,
552    M::Val: Debug,
553    M::Act: Debug,
554{
555    fn get_root(&self) -> Option<Ptr> {
556        self.root
557    }
558
559    fn get_left(&self, ptr: &Ptr) -> Option<Ptr> {
560        self.arena.get(*ptr).left
561    }
562
563    fn get_right(&self, ptr: &Ptr) -> Option<Ptr> {
564        self.arena.get(*ptr).right
565    }
566
567    fn print_node(&self, ptr: &Ptr) -> String {
568        let node = self.arena.get(*ptr);
569        format!("[val:{:?}, act:{:?}]", node.sum, node.act)
570    }
571}