cp_library_rs/data_structure/
bbt_aa.rs

1//! AA木による動的セグ木
2//! - 遅延評価なし
3
4use crate::algebraic_structure::monoid::Monoid;
5
6use std::{
7    fmt::{self, Debug},
8    ops::{Bound::Unbounded, Deref, DerefMut, RangeBounds},
9};
10
11use crate::tree::aa_tree::{delete, get, get_range, insert, Node, NodeInner};
12
13/// 平衡2分木
14/// - 1点更新,区間取得
15/// - 遅延評価なし
16pub struct BalancedBinaryTree<K: Ord, M: Monoid> {
17    pub root: Node<K, M>,
18    size: usize,
19    /// getメソッドで返すための一時的な単位元
20    tmp_e: M::Val,
21}
22
23impl<K: Ord, M: Monoid> BalancedBinaryTree<K, M> {
24    /// 1点取得(不変参照)
25    /// - 値 `key` を持つノードの不変参照を取得する
26    pub fn get(&self, key: &K) -> &M::Val {
27        if let Some(NodeInner { value, .. }) = get(&self.root, key) {
28            value
29        } else {
30            &self.tmp_e
31        }
32    }
33
34    /// 1点取得(可変参照)
35    /// - 値 `key` を持つノードの可変参照を取得する
36    pub fn get_mut(&mut self, key: K) -> NodeEntry<'_, K, M> {
37        let (new_root, old_key_val) = delete(self.root.take(), &key);
38        self.root = new_root;
39
40        if let Some((key, value)) = old_key_val {
41            NodeEntry {
42                root: &mut self.root,
43                key: Some(key),
44                value: Some(value),
45            }
46        } else {
47            // ノードの新規作成
48            self.size += 1;
49            NodeEntry {
50                root: &mut self.root,
51                key: Some(key),
52                value: Some(M::e()),
53            }
54        }
55    }
56
57    /// 要素の更新
58    /// - `key`:更新するキー
59    /// - `value`:更新後の値
60    pub fn insert(&mut self, key: K, value: M::Val) {
61        let (new_root, old_key_value) = insert(self.root.take(), key, value);
62        self.root = new_root;
63        // 要素が追加された場合
64        if old_key_value.is_none() {
65            self.size += 1;
66        }
67    }
68
69    /// 要素の削除
70    /// - `key`:削除するキー
71    pub fn remove(&mut self, key: &K) -> Option<M::Val> {
72        let (new_root, old_key_value) = delete(self.root.take(), key);
73        self.root = new_root;
74        // 削除された要素を返す
75        if let Some((_, old_value)) = old_key_value {
76            self.size -= 1;
77            Some(old_value)
78        } else {
79            None
80        }
81    }
82
83    /// 区間の取得
84    /// - 区間 `range` の要素を集約する
85    pub fn get_range<R: RangeBounds<K>>(&self, range: R) -> M::Val {
86        let l = range.start_bound();
87        let r = range.end_bound();
88        get_range(&self.root, l, r, Unbounded, Unbounded)
89    }
90
91    /// 要素数を取得
92    pub fn len(&self) -> usize {
93        self.size
94    }
95
96    pub fn is_empty(&self) -> bool {
97        self.size == 0
98    }
99}
100
101/// ノードの可変参照
102pub struct NodeEntry<'a, K: Ord, M: 'a + Monoid> {
103    root: &'a mut Node<K, M>,
104    key: Option<K>,
105    value: Option<M::Val>,
106}
107
108impl<K, M> Debug for NodeEntry<'_, K, M>
109where
110    K: Ord + Debug,
111    M: Monoid,
112    M::Val: Debug,
113{
114    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
115        f.debug_struct("NodeEntry")
116            .field("key", &self.key.as_ref().unwrap())
117            .field("value", &self.value.as_ref().unwrap())
118            .finish()
119    }
120}
121
122impl<K: Ord, M: Monoid> Drop for NodeEntry<'_, K, M> {
123    fn drop(&mut self) {
124        let root = self.root.take();
125        let key = self.key.take().unwrap();
126        let value = self.value.take().unwrap();
127        (*self.root, _) = insert(root, key, value);
128    }
129}
130
131impl<K: Ord, M: Monoid> Deref for NodeEntry<'_, K, M> {
132    type Target = M::Val;
133    fn deref(&self) -> &Self::Target {
134        self.value.as_ref().unwrap()
135    }
136}
137
138impl<K: Ord, M: Monoid> DerefMut for NodeEntry<'_, K, M> {
139    fn deref_mut(&mut self) -> &mut Self::Target {
140        self.value.as_mut().unwrap()
141    }
142}
143
144impl<K: Ord, M: Monoid> Default for BalancedBinaryTree<K, M> {
145    /// 動的セグ木の初期化
146    fn default() -> Self {
147        Self {
148            root: None,
149            size: 0,
150            tmp_e: M::e(),
151        }
152    }
153}