cp_library_rs/data_structure/
segment_tree.rs

1//! ## セグメント木
2//!
3//! 集合 $`S`$ と演算 $`\circ`$ の組 $`(S,\circ)`$ がモノイド([`Monoid`])であるとき,
4//! $`S`$ の要素の列 $`A`$ に対し,
5//!
6//! - 区間積の取得 : $`A[l] \circ A[l+1] \circ \cdots \circ A[r]`$
7//! - 要素の更新 : $`A[i] \leftarrow x`$
8//!
9//! をそれぞれ $`O(\log N)`$ で行う.($`N = |A|`$)
10
11use crate::{algebraic_structure::monoid::Monoid, tree::show_binary_tree::ShowBinaryTree};
12use std::{
13    fmt::{self, Debug},
14    ops::{
15        Bound::{Excluded, Included, Unbounded},
16        Deref, DerefMut, Index, RangeBounds,
17    },
18};
19
20/// セグメント木
21pub struct SegmentTree<M: Monoid> {
22    /// 要素数
23    pub N: usize,
24    offset: usize,
25    data: Vec<M::Val>,
26}
27
28impl<M: Monoid> Index<usize> for SegmentTree<M> {
29    type Output = M::Val;
30    fn index(&self, idx: usize) -> &Self::Output {
31        &self.data[self.offset + idx]
32    }
33}
34
35impl<M: Monoid> SegmentTree<M> {
36    #[inline]
37    fn parse_range<R: RangeBounds<usize>>(&self, range: &R) -> Option<(usize, usize)> {
38        let start = match range.start_bound() {
39            Unbounded => 0,
40            Excluded(&v) => v + 1,
41            Included(&v) => v,
42        };
43        let end = match range.end_bound() {
44            Unbounded => self.N,
45            Excluded(&v) => v,
46            Included(&v) => v + 1,
47        };
48        if start <= end && end <= self.N {
49            Some((start, end))
50        } else {
51            None
52        }
53    }
54
55    /// セグメント木を初期化する
56    /// - 時間計算量: $`O(N)`$
57    pub fn new(N: usize) -> Self {
58        let offset = N.next_power_of_two();
59
60        Self {
61            N,
62            offset,
63            data: vec![M::e(); offset << 1],
64        }
65    }
66
67    /// 配列から初期化する
68    /// - 時間計算量: $`O(N)`$
69    pub fn from_vec(src: Vec<M::Val>) -> Self {
70        let mut seg = Self::new(src.len());
71        for (i, v) in src.into_iter().enumerate() {
72            seg.data[seg.offset + i] = v;
73        }
74        for i in (0..seg.offset).rev() {
75            let lch = i << 1;
76            seg.data[i] = M::op(&seg.data[lch], &seg.data[lch + 1]);
77        }
78        seg
79    }
80
81    /// `index`番目の要素を`value`に更新する
82    /// - 時間計算量: $`O(\log N)`$
83    pub fn update(&mut self, index: usize, value: M::Val) {
84        let mut i = index + self.offset;
85        self.data[i] = value;
86        while i > 1 {
87            i >>= 1;
88            let lch = i << 1;
89            self.data[i] = M::op(&self.data[lch], &self.data[lch + 1]);
90        }
91    }
92
93    /// `i`番目の要素の可変な参照を返す
94    /// - 時間計算量: $`O(\log N)`$
95    pub fn get_mut(&mut self, i: usize) -> Option<ValMut<'_, M>> {
96        if i < self.offset {
97            let default = self.index(i).clone();
98            Some(ValMut {
99                segself: self,
100                idx: i,
101                new_val: default,
102            })
103        } else {
104            None
105        }
106    }
107
108    /// 区間`range`の集約を行う
109    /// - 時間計算量: $`O(\log N)`$
110    pub fn get_range<R: RangeBounds<usize> + Debug>(&self, range: R) -> M::Val {
111        let (start, end) = match self.parse_range(&range) {
112            Some(r) => r,
113            None => panic!("The given range is wrong: {:?}", range),
114        };
115        // 値の取得
116        let mut l = self.offset + start;
117        let mut r = self.offset + end;
118        let (mut res_l, mut res_r) = (M::e(), M::e());
119
120        while l < r {
121            if l & 1 == 1 {
122                res_l = M::op(&res_l, &self.data[l]);
123                l += 1;
124            }
125            if r & 1 == 1 {
126                r -= 1;
127                res_r = M::op(&self.data[r], &res_r);
128            }
129            l >>= 1;
130            r >>= 1;
131        }
132
133        M::op(&res_l, &res_r)
134    }
135}
136
137impl<M: Monoid> FromIterator<M::Val> for SegmentTree<M> {
138    fn from_iter<T: IntoIterator<Item = M::Val>>(iter: T) -> Self {
139        // 配列にする
140        let arr: Vec<M::Val> = iter.into_iter().collect();
141        Self::from_vec(arr)
142    }
143}
144
145impl<M> Debug for SegmentTree<M>
146where
147    M: Monoid,
148    M::Val: Debug,
149{
150    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
151        write!(f, "SegmentTree {{ [").ok();
152        for i in 0..self.N {
153            if i + 1 < self.N {
154                write!(f, "{:?}, ", self.data[self.offset + i]).ok();
155            } else {
156                write!(f, "{:?}", self.data[self.offset + i]).ok();
157            }
158        }
159        write!(f, "] }}")
160    }
161}
162
163/// セグメント木の要素の可変参照
164pub struct ValMut<'a, M: 'a + Monoid> {
165    segself: &'a mut SegmentTree<M>,
166    idx: usize,
167    new_val: M::Val,
168}
169
170impl<M> Debug for ValMut<'_, M>
171where
172    M: Monoid,
173    M::Val: Debug,
174{
175    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
176        f.debug_tuple("ValMut").field(&self.new_val).finish()
177    }
178}
179
180impl<M: Monoid> Drop for ValMut<'_, M> {
181    fn drop(&mut self) {
182        self.segself.update(self.idx, self.new_val.clone());
183    }
184}
185
186impl<M: Monoid> Deref for ValMut<'_, M> {
187    type Target = M::Val;
188    fn deref(&self) -> &Self::Target {
189        &self.new_val
190    }
191}
192
193impl<M: Monoid> DerefMut for ValMut<'_, M> {
194    fn deref_mut(&mut self) -> &mut Self::Target {
195        &mut self.new_val
196    }
197}
198
199// セグ木上の2分探索
200impl<M: Monoid> SegmentTree<M> {
201    /// 左端を固定した2分探索
202    /// - 引数`l`と関数`f`に対して,
203    ///     - `f( seg.get(l..x) ) = true`
204    ///     - `f( seg.get(l..x+1) ) = false`
205    ///
206    ///   \
207    ///   を満たす`x`を返す
208    ///
209    /// **引数**
210    /// - `f` :
211    ///   - `f(e) = true`
212    ///   - 任意の`i`に対して,`f( seg.get(l..i) ) = false`ならば,`f( seg.get(l..i+1) ) = false`
213    pub fn max_right<F>(&self, mut l: usize, f: F) -> (M::Val, usize)
214    where
215        F: Fn(M::Val) -> bool,
216    {
217        assert!(f(M::e()));
218
219        if l >= self.N {
220            return (M::e(), self.N);
221        }
222
223        l += self.offset;
224        let mut sum = M::e();
225
226        // 第1段階: 条件を満たさない区間を見つける
227        'fst: loop {
228            while l & 1 == 0 {
229                l >>= 1;
230            }
231
232            let tmp = M::op(&sum, &self.data[l]);
233
234            // 満たさない区間を発見した場合
235            if !f(tmp.clone()) {
236                break 'fst;
237            }
238
239            sum = tmp;
240            l += 1;
241
242            // すべての領域を見終わったら終了
243            if (l & l.wrapping_neg()) == l {
244                return (sum, self.N);
245            }
246        }
247
248        // 第2段階: 子方向に移動しながら2分探索
249        while l < self.offset {
250            // 左に潜る
251            l <<= 1;
252
253            let tmp = M::op(&sum, &self.data[l]);
254
255            // 左に潜っても大丈夫な場合
256            if f(tmp.clone()) {
257                sum = tmp;
258                // 右に潜る
259                l += 1;
260            }
261        }
262
263        (sum, l - self.offset)
264    }
265
266    /// 右端を固定した2分探索
267    /// - 引数`r`と関数`f`に対して,
268    ///    - `f( seg.get(x..r) ) = true`
269    ///    - `f( seg.get(x-1..r) ) = false`
270    ///
271    ///   \
272    ///   となるような`x`を返す
273    ///
274    /// **引数**
275    /// - `f` :
276    ///   - `f(e) = true`
277    ///   - 任意の`i`に対して,`f( seg.get(i..r) ) = false`ならば,`f( seg.get(i-1..r) ) = false`
278    pub fn min_left<F>(&self, mut r: usize, f: F) -> (M::Val, usize)
279    where
280        F: Fn(M::Val) -> bool,
281    {
282        assert!(f(M::e()));
283
284        if r == 0 {
285            return (M::e(), 0);
286        }
287
288        r += self.offset;
289        let mut sum = M::e();
290
291        // 第1段階: 条件を満たさない区間を見つける
292        'fst: loop {
293            r -= 1;
294            while r > 1 && r & 1 == 1 {
295                r >>= 1;
296            }
297
298            let tmp = M::op(&self.data[r], &sum);
299
300            // 満たさない区間を発見した場合
301            if !f(tmp.clone()) {
302                break 'fst;
303            }
304
305            sum = tmp;
306
307            // すべての領域を見終わったら終了
308            if (r & r.wrapping_neg()) == r {
309                return (sum, 0);
310            }
311        }
312
313        // 第2段階: 子方向に移動しながら2分探索
314        while r < self.offset {
315            // 右に潜る
316            r = (r << 1) + 1;
317
318            let tmp = M::op(&self.data[r], &sum);
319
320            // 右に潜っても大丈夫な場合
321            if f(tmp.clone()) {
322                sum = tmp;
323                // 左に潜る
324                r -= 1;
325            }
326        }
327
328        (sum, r + 1 - self.offset)
329    }
330}
331
332impl<M> ShowBinaryTree<usize> for SegmentTree<M>
333where
334    M: Monoid,
335    M::Val: Debug,
336{
337    fn get_root(&self) -> Option<usize> {
338        Some(1)
339    }
340    fn get_left(&self, &i: &usize) -> Option<usize> {
341        (i * 2 < self.offset * 2).then_some(i * 2)
342    }
343    fn get_right(&self, &i: &usize) -> Option<usize> {
344        (i * 2 + 1 < self.offset * 2).then_some(i * 2 + 1)
345    }
346    fn print_node(&self, &i: &usize) -> String {
347        format!("[{:?}]", self.data[i])
348    }
349}