cp_library_rs/data_structure/
multiset_splay_tree.rs

1//! スプレー木の多重集合
2
3pub use multiset_splay_tree_::MultiSet;
4
5mod multiset_splay_tree_ {
6    //! 多重集合
7    use crate::{
8        tree::show_binary_tree::ShowBinaryTree,
9        tree::splay_tree::{
10            find::{lower_bound, upper_bound},
11            insert::{insert, insert_right},
12            iterator::{prev, NodeIterator, NodePosition, NodeRangeIterator},
13            pointer::{Node, NodeOps, NodePtr},
14            remove::remove,
15            splay::splay,
16        },
17    };
18    use std::{
19        fmt::Debug,
20        ops::{Bound, RangeBounds},
21        ptr::NonNull,
22    };
23    /// MultiSet
24    /// - 多重集合
25    pub struct MultiSet<K: Ord> {
26        pub root: Option<NodePtr<K, usize>>,
27        size: usize,
28    }
29    impl<K: Ord> Default for MultiSet<K> {
30        fn default() -> Self {
31            Self::new()
32        }
33    }
34    impl<K: Ord> MultiSet<K> {
35        /// 新規作成
36        pub fn new() -> Self {
37            Self {
38                root: None,
39                size: 0,
40            }
41        }
42        /// 要素数
43        pub fn len(&self) -> usize {
44            self.size
45        }
46        /// 空判定
47        pub fn is_empty(&self) -> bool {
48            self.size == 0
49        }
50        /// 値 `x` を持つノードのうち,最も右側にあるものを探索する
51        fn find_rightmost_node(&mut self, key: &K) -> Option<NodePtr<K, usize>> {
52            let upperbound = prev(
53                {
54                    let ub;
55                    (self.root, ub) = upper_bound(self.root, key);
56                    ub
57                },
58                &self.root,
59            );
60            match upperbound {
61                NodePosition::Node(node) if node.key() == key => Some(node),
62                _ => None,
63            }
64        }
65        /// 要素の追加
66        pub fn insert(&mut self, key: K) {
67            // 最も右側の頂点を探索
68            let rightmost = self.find_rightmost_node(&key);
69            let new_node = if let Some(rightmost) = rightmost {
70                let cnt = *rightmost.value();
71                insert_right(Some(rightmost), key, cnt + 1)
72            } else {
73                insert(self.root, key, 1)
74            };
75            self.size += 1;
76            self.root = Some(splay(new_node));
77        }
78        /// 要素の削除
79        pub fn remove(&mut self, key: &K) -> bool {
80            // 最も右側の頂点を探索
81            let Some(rightmost) = self.find_rightmost_node(key) else {
82                return false;
83            };
84            (self.root, _) = remove(rightmost);
85            self.size -= 1;
86            true
87        }
88        /// `key` に一致する要素の個数を返す
89        pub fn count(&mut self, key: &K) -> usize {
90            // 最も右側の頂点を探索
91            let rightmost = self.find_rightmost_node(key);
92            if let Some(rightmost) = rightmost {
93                *rightmost.value()
94            } else {
95                0
96            }
97        }
98        /// 指定した区間のイテレータを返す
99        pub fn range<R: RangeBounds<K>>(&mut self, range: R) -> NodeRangeIterator<'_, K, usize> {
100            let left = match range.start_bound() {
101                Bound::Unbounded => NodePosition::Inf,
102                Bound::Included(x) => prev(
103                    {
104                        let lb;
105                        (self.root, lb) = lower_bound(self.root, x);
106                        lb
107                    },
108                    &self.root,
109                ),
110                Bound::Excluded(x) => prev(
111                    {
112                        let ub;
113                        (self.root, ub) = upper_bound(self.root, x);
114                        ub
115                    },
116                    &self.root,
117                ),
118            };
119            let right = match range.end_bound() {
120                Bound::Unbounded => NodePosition::Sup,
121                Bound::Included(x) => {
122                    let ub;
123                    (self.root, ub) = upper_bound(self.root, x);
124                    ub
125                }
126                Bound::Excluded(x) => {
127                    let lb;
128                    (self.root, lb) = lower_bound(self.root, x);
129                    lb
130                }
131            };
132            NodeRangeIterator::new(&self.root, left, right)
133        }
134        /// ノードのイテレータを返す
135        pub fn iter(&self) -> NodeIterator<'_, K, usize> {
136            NodeIterator::first(&self.root)
137        }
138    }
139    impl<K: Ord + Clone + Debug> Debug for MultiSet<K> {
140        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
141            f.debug_set()
142                .entries(NodeIterator::first(&self.root).map(|node| node.key().clone()))
143                .finish()
144        }
145    }
146
147    // ==================== ShowBinaryTree ====================
148    /// ShowBinaryTree 用の「ポインタ」
149    #[derive(Clone, Copy)]
150    pub struct TreePtr<K: Ord>(NonNull<Node<K, usize>>);
151
152    impl<K: Ord> TreePtr<K> {
153        #[inline]
154        fn mk_ptr(node: &Node<K, usize>) -> TreePtr<K> {
155            TreePtr(NonNull::from(node))
156        }
157    }
158
159    impl<K: Ord + Debug> ShowBinaryTree<TreePtr<K>> for MultiSet<K> {
160        fn get_left(&self, ptr: &TreePtr<K>) -> Option<TreePtr<K>> {
161            // 読み取り専用だが,trait が &mut self を要求する
162            let t = unsafe { ptr.0.as_ref() };
163            let left = unsafe { t.left.as_ref()?.as_ref() };
164            Some(TreePtr::mk_ptr(left))
165        }
166
167        fn get_right(&self, ptr: &TreePtr<K>) -> Option<TreePtr<K>> {
168            let t = unsafe { ptr.0.as_ref() };
169            let right = unsafe { t.right.as_ref()?.as_ref() };
170            Some(TreePtr::mk_ptr(right))
171        }
172
173        fn get_root(&self) -> Option<TreePtr<K>> {
174            let root = unsafe { self.root.as_ref()?.as_ref() };
175            Some(TreePtr::mk_ptr(root))
176        }
177
178        fn print_node(&self, ptr: &TreePtr<K>) -> String {
179            let t = unsafe { ptr.0.as_ref() };
180            format!("{t:?}")
181        }
182    }
183}