use crate::algebraic_structure::monoid;
pub use dynamic_segment_tree_::*;
mod dynamic_segment_tree_ {
use std::{
fmt::{self, Debug},
ops::{Bound::Unbounded, Deref, DerefMut, RangeBounds},
};
use super::{
monoid::Monoid,
node::{delete, get, get_range, insert, Node, NodeInner},
};
/// An ordered map from keys `K` to monoid values `M::Val` that can fold (`M::op`) the values
/// of any key range.
///
/// Backed by an AA tree whose nodes cache the monoid product of their subtree.
pub struct DynamicSegmentTree<K: Ord, M: Monoid> {
    /// Root of the underlying AA tree (`None` when empty).
    pub root: Node<K, M>,
    /// Number of stored keys.
    size: usize,
    /// Cached identity element so that `get` can return a reference for absent keys.
    tmp_e: M::Val,
}
impl<K: Ord, M: Monoid> DynamicSegmentTree<K, M> {
    /// Returns the value stored under `key`, or the monoid identity when `key` is absent.
    pub fn get(&self, key: &K) -> &M::Val {
        match get(&self.root, key) {
            Some(NodeInner { value, .. }) => value,
            None => &self.tmp_e,
        }
    }
    /// Returns a mutable handle to the value under `key`, creating it as `M::id()` if absent.
    ///
    /// The entry is detached from the tree while the handle lives and is re-inserted (which
    /// refreshes the cached sums) when the handle is dropped. If the handle is leaked
    /// (e.g. via `std::mem::forget`), the entry is lost while `len()` still counts it.
    pub fn get_mut(&mut self, key: K) -> NodeEntry<'_, K, M> {
        let (rest, detached) = delete(self.root.take(), &key);
        self.root = rest;
        let (key, value) = match detached {
            Some(existing) => existing,
            None => {
                self.size += 1;
                (key, M::id())
            }
        };
        NodeEntry {
            root: &mut self.root,
            key: Some(key),
            value: Some(value),
        }
    }
    /// Stores `value` under `key`, replacing any previous value.
    pub fn insert(&mut self, key: K, value: M::Val) {
        let (rest, replaced) = insert(self.root.take(), key, value);
        self.root = rest;
        if replaced.is_none() {
            self.size += 1;
        }
    }
    /// Removes `key` and returns its value, or `None` if it was not present.
    pub fn remove(&mut self, key: &K) -> Option<M::Val> {
        let (rest, removed) = delete(self.root.take(), key);
        self.root = rest;
        let (_, value) = removed?;
        self.size -= 1;
        Some(value)
    }
    /// Folds `M::op`, in key order, over the values of every stored key inside `range`.
    ///
    /// Returns `M::id()` when no stored key falls inside `range`.
    pub fn get_range<R: RangeBounds<K>>(&self, range: R) -> M::Val {
        get_range(
            &self.root,
            range.start_bound(),
            range.end_bound(),
            Unbounded,
            Unbounded,
        )
    }
    /// Number of stored keys.
    pub fn len(&self) -> usize {
        self.size
    }
    /// `true` when no key is stored.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
/// A live handle to one entry of a [`DynamicSegmentTree`], returned by
/// [`DynamicSegmentTree::get_mut`].
///
/// The handle owns the entry's key and value and dereferences to the value. When dropped, it
/// inserts the (possibly modified) pair back into the tree it borrows, which recomputes the
/// cached sums along the insertion path.
pub struct NodeEntry<'a, K: Ord, M: 'a + Monoid> {
    /// Root of the tree the entry is written back into on drop.
    root: &'a mut Node<K, M>,
    /// `Some` from construction until `drop` moves it out.
    key: Option<K>,
    /// `Some` from construction until `drop` moves it out.
    value: Option<M::Val>,
}
impl<K, M> Debug for NodeEntry<'_, K, M>
where
    K: Ord + Debug,
    M: Monoid,
    M::Val: Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("NodeEntry");
        builder.field("key", self.key.as_ref().unwrap());
        builder.field("value", self.value.as_ref().unwrap());
        builder.finish()
    }
}
impl<K: Ord, M: Monoid> Drop for NodeEntry<'_, K, M> {
    /// Writes the entry back into the borrowed tree.
    fn drop(&mut self) {
        let detached_root = self.root.take();
        let key = self.key.take().unwrap();
        let value = self.value.take().unwrap();
        let (new_root, _) = insert(detached_root, key, value);
        *self.root = new_root;
    }
}
impl<K: Ord, M: Monoid> Deref for NodeEntry<'_, K, M> {
    type Target = M::Val;
    /// Borrows the entry's current value.
    fn deref(&self) -> &M::Val {
        self.value.as_ref().unwrap()
    }
}
impl<K: Ord, M: Monoid> DerefMut for NodeEntry<'_, K, M> {
    /// Mutably borrows the entry's value; the change reaches the tree when the handle drops.
    fn deref_mut(&mut self) -> &mut M::Val {
        self.value.as_mut().unwrap()
    }
}
impl<K: Ord, M: Monoid> Default for DynamicSegmentTree<K, M> {
fn default() -> Self {
Self {
root: None,
size: 0,
tmp_e: M::id(),
}
}
}
}
mod node {
#![allow(non_snake_case)]
#![allow(clippy::type_complexity)]
use crate::algebraic_structure::monoid::Monoid;
use std::{
cmp::Ordering,
fmt::Debug,
mem,
ops::Bound::{self, *},
};
/// An optional boxed AA-tree node; `None` is the empty subtree.
pub type Node<K, M> = Option<Box<NodeInner<K, M>>>;
/// One AA-tree node, which also caches the monoid product of its whole subtree.
pub struct NodeInner<K: Ord, M: Monoid> {
    /// Key that orders this node in the tree.
    pub key: K,
    /// Value stored under `key`.
    pub value: M::Val,
    /// `M::op` folded over the subtree's values in key order (recomputed by `eval`).
    pub sum: M::Val,
    /// AA-tree level; freshly created nodes start at 1.
    pub level: usize,
    /// Subtree holding smaller keys.
    pub left: Node<K, M>,
    /// Subtree holding larger keys.
    pub right: Node<K, M>,
}
impl<K: Ord, M: Monoid> NodeInner<K, M> {
    /// Builds a childless level-1 node whose cached sum equals its own value.
    pub fn new(key: K, value: M::Val) -> Node<K, M> {
        let sum = value.clone();
        Some(Box::new(NodeInner {
            key,
            value,
            sum,
            level: 1,
            left: None,
            right: None,
        }))
    }
    /// Recomputes `sum` as `left.sum · value · right.sum`, skipping absent children.
    ///
    /// Relies on the children's `sum` fields already being up to date.
    fn eval(&mut self) {
        let own = &self.value;
        self.sum = match (self.left.as_deref(), self.right.as_deref()) {
            (None, None) => own.clone(),
            (None, Some(right)) => M::op(own, &right.sum),
            (Some(left), None) => M::op(&left.sum, own),
            (Some(left), Some(right)) => M::op(&M::op(&left.sum, own), &right.sum),
        };
    }
}
impl<K, M> Debug for NodeInner<K, M>
where
    K: Ord + Debug,
    M: Monoid,
    M::Val: Debug,
{
    /// Shows `key`, `value` and `sum`; level and children are omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("Node");
        out.field("key", &self.key);
        out.field("value", &self.value);
        out.field("sum", &self.sum);
        out.finish()
    }
}
/// AA-tree `skew`: when the left child is on the same level as this node (a horizontal left
/// link), rotates right so the horizontal link points right instead.
///
/// The demoted node's sum is recomputed before the promoted one, since the latter now
/// contains the former. Returns the (possibly new) subtree root.
fn skew<K: Ord, M: Monoid>(node: Node<K, M>) -> Node<K, M> {
    let mut top = node?;
    let horizontal_left = matches!(&top.left, Some(left) if left.level == top.level);
    if !horizontal_left {
        return Some(top);
    }
    let mut pivot = top.left.take().unwrap();
    top.left = pivot.right.take();
    top.eval();
    pivot.right = Some(top);
    pivot.eval();
    Some(pivot)
}
/// AA-tree `split`: when the right grandchild is on the same level as this node (two
/// consecutive horizontal right links), rotates left and promotes the middle node one level.
///
/// Sums are recomputed bottom-up for the two rotated nodes. Returns the new subtree root.
fn split<K: Ord, M: Monoid>(node: Node<K, M>) -> Node<K, M> {
    let mut top = node?;
    let grandchild_level = top
        .right
        .as_ref()
        .and_then(|right| right.right.as_ref())
        .map(|grandchild| grandchild.level);
    if grandchild_level != Some(top.level) {
        return Some(top);
    }
    let mut pivot = top.right.take().unwrap();
    top.right = pivot.left.take();
    top.eval();
    pivot.left = Some(top);
    pivot.eval();
    pivot.level += 1;
    Some(pivot)
}
/// Looks up the node whose key equals `key` in the subtree rooted at `root`.
///
/// Walks down from `root`, going left for smaller and right for larger keys; returns `None`
/// when the walk falls off the tree.
pub fn get<'a, K: Ord, M: Monoid>(
    root: &'a Node<K, M>,
    key: &K,
) -> Option<&'a NodeInner<K, M>> {
    let mut cursor = root.as_deref();
    while let Some(node) = cursor {
        cursor = match key.cmp(&node.key) {
            Ordering::Less => node.left.as_deref(),
            Ordering::Greater => node.right.as_deref(),
            Ordering::Equal => return Some(node),
        };
    }
    None
}
/// A pair of bounds `(start, end)` describing an interval of keys.
type Segment<K> = (Bound<K>, Bound<K>);
/// Returns `true` only when the query `(l, r)` and the segment `(begin, end)` provably share
/// no key.
///
/// For discrete keys it may answer `false` on some disjoint pairs (e.g. query
/// `(Unbounded, Excluded(2))` vs segment `(Excluded(1), Unbounded)` over integers), so callers
/// must read `false` as "may overlap".
fn has_no_intersection<K: Ord>((l, r): Segment<&K>, (begin, end): Segment<&K>) -> bool {
    // `true` when no key can satisfy both the upper bound `upper` and the lower bound `lower`.
    fn separated<K: Ord>(upper: Bound<&K>, lower: Bound<&K>) -> bool {
        match (upper, lower) {
            (Unbounded, _) | (_, Unbounded) => false,
            (Included(u), Included(lo)) => u < lo,
            (Included(u), Excluded(lo)) | (Excluded(u), Included(lo) | Excluded(lo)) => u <= lo,
        }
    }
    // Either the query ends before the segment starts, or the segment ends before the query starts.
    separated(r, begin) || separated(end, l)
}
/// Returns `true` only when every key of the segment `(begin, end)` provably lies inside the
/// query `(l, r)`.
///
/// For discrete keys it may answer `false` on some covered segments (e.g. segment
/// `(Excluded(2), Excluded(6))` inside query `(Included(3), Included(5))` over integers), so
/// callers must read `false` as "may not be covered".
fn includes<K: Ord>((l, r): Segment<&K>, (begin, end): Segment<&K>) -> bool {
    // The query's lower bound must admit everything the segment's lower bound admits.
    let lower_ok = match (l, begin) {
        (Unbounded, _) => true,
        (_, Unbounded) => false,
        (Excluded(lo), Included(b)) => lo < b,
        (Included(lo), Included(b) | Excluded(b)) | (Excluded(lo), Excluded(b)) => lo <= b,
    };
    // Likewise for the upper bounds (evaluated only when the lower side already holds).
    lower_ok
        && match (end, r) {
            (_, Unbounded) => true,
            (Unbounded, _) => false,
            (Included(e), Excluded(hi)) => e < hi,
            (Included(e) | Excluded(e), Included(hi)) | (Excluded(e), Excluded(hi)) => e <= hi,
        }
}
/// Folds `M::op`, in key order, over the values of every key in the subtree `root` that lies
/// within the query bounds `(l, r)`.
///
/// `(begin, end)` must enclose every key of `root`'s subtree (the top-level caller passes
/// `(Unbounded, Unbounded)`). They allow a whole subtree to be skipped when disjoint from
/// the query, or answered from its cached `sum` when fully covered. Returns `M::id()` for an
/// empty subtree.
pub fn get_range<K: Ord, M: Monoid>(
    root: &Node<K, M>,
    l: Bound<&K>,
    r: Bound<&K>,
    begin: Bound<&K>,
    end: Bound<&K>,
) -> M::Val {
    let query = (l, r);
    match root {
        None => M::id(),
        Some(_) if has_no_intersection(query, (begin, end)) => M::id(),
        Some(node) if includes(query, (begin, end)) => node.sum.clone(),
        Some(node) => {
            // Partial overlap: split the segment at this node's key and combine the pieces.
            let pivot = &node.key;
            let left_part = get_range(&node.left, l, r, begin, Excluded(pivot));
            let own_part = if includes(query, (Included(pivot), Included(pivot))) {
                node.value.clone()
            } else {
                M::id()
            };
            let right_part = get_range(&node.right, l, r, Excluded(pivot), end);
            M::op(&M::op(&left_part, &own_part), &right_part)
        }
    }
}
/// Inserts `key -> value` into the subtree rooted at `root`, then restores the AA invariants
/// with `skew` followed by `split` on the way back up.
///
/// Returns the new subtree root and, when `key` was already present, the replaced
/// `(key, value)` pair (the stored key is swapped for the new one as well).
pub fn insert<K: Ord, M: Monoid>(
    root: Node<K, M>,
    key: K,
    value: M::Val,
) -> (Node<K, M>, Option<(K, M::Val)>) {
    let mut node = match root {
        Some(node) => node,
        None => return (NodeInner::new(key, value), None),
    };
    let replaced = match key.cmp(&node.key) {
        Ordering::Less => {
            let (left, replaced) = insert(node.left.take(), key, value);
            node.left = left;
            replaced
        }
        Ordering::Greater => {
            let (right, replaced) = insert(node.right.take(), key, value);
            node.right = right;
            replaced
        }
        Ordering::Equal => {
            let old_key = mem::replace(&mut node.key, key);
            let old_value = mem::replace(&mut node.value, value);
            Some((old_key, old_value))
        }
    };
    node.eval();
    (split(skew(Some(node))), replaced)
}
/// Removes `key` from the AA subtree rooted at `root`.
///
/// Returns the new subtree root and the removed `(key, value)` pair. When `key` is absent,
/// the second element is `None` and the subtree is returned untouched.
///
/// A node with two children is not unlinked directly. It takes over the entry of its in-order
/// predecessor (the maximum of its left subtree), and that predecessor is detached instead.
pub fn delete<K: Ord, M: Monoid>(
    root: Node<K, M>,
    key: &K,
) -> (Node<K, M>, Option<(K, M::Val)>) {
    let Some(mut T) = root else {
        return (None, None);
    };
    let (mut new_root, old_key_value) = match key.cmp(&T.key) {
        Ordering::Less => {
            let (new_left, old_key_value) = delete(T.left.take(), key);
            T.left = new_left;
            (Some(T), old_key_value)
        }
        Ordering::Greater => {
            let (new_right, old_key_value) = delete(T.right.take(), key);
            T.right = new_right;
            (Some(T), old_key_value)
        }
        Ordering::Equal => {
            if T.left.is_none() {
                (T.right, Some((T.key, T.value)))
            } else if T.right.is_none() {
                (T.left, Some((T.key, T.value)))
            } else {
                // Pull the predecessor out of the left subtree and move its entry into `T`.
                let (new_left, right_most) = delete_and_get_max(T.left.take());
                T.left = new_left;
                let Some(right_most) = right_most else {
                    unreachable!("T.left is not None");
                };
                let old_key_value = (
                    mem::replace(&mut T.key, right_most.key),
                    mem::replace(&mut T.value, right_most.value),
                );
                (Some(T), Some(old_key_value))
            }
        }
    };
    if old_key_value.is_none() {
        // Nothing was removed below this node, so its subtree, cached `sum` and levels are
        // unchanged. Skip the redundant `eval` (clone + `M::op`) and rebalancing.
        return (new_root, None);
    }
    if let Some(T) = &mut new_root {
        T.eval();
    }
    (rebarance(new_root), old_key_value)
}
/// Restores the AA invariants at `root` after a deletion somewhere below it.
///
/// If a child's level dropped so that `min(left, right) < level - 1`, the node loses one
/// level (a right child above the new level is lowered with it). The standard AA deletion
/// fix-up then follows: three `skew`s down the right spine and two `split`s. Finally the
/// node's sum is recomputed. Otherwise the node is returned unchanged.
fn rebarance<K: Ord, M: Monoid>(root: Node<K, M>) -> Node<K, M> {
    let mut node = root?;
    let level_of = |child: &Node<K, M>| child.as_ref().map_or(0, |c| c.level);
    let left_level = level_of(&node.left);
    let right_level = level_of(&node.right);
    if left_level.min(right_level) >= node.level - 1 {
        return Some(node);
    }
    node.level -= 1;
    let lowered = node.level;
    if right_level > lowered {
        if let Some(right) = node.right.as_mut() {
            right.level = lowered;
        }
    }
    let mut node = skew(Some(node)).unwrap();
    node.right = skew(node.right.take());
    if let Some(right) = node.right.as_mut() {
        right.right = skew(right.right.take());
    }
    let mut node = split(Some(node)).unwrap();
    node.right = split(node.right.take());
    node.eval();
    Some(node)
}
/// Detaches the node with the largest key from the subtree rooted at `root`.
///
/// Returns the rebalanced remaining subtree and the detached node, or `(None, None)` for an
/// empty subtree. The detached node comes back with both links cleared; `delete` only reads
/// its `key` and `value`.
fn delete_and_get_max<K: Ord, M: Monoid>(
    root: Node<K, M>,
) -> (Node<K, M>, Option<NodeInner<K, M>>) {
    let Some(mut T) = root else {
        return (None, None);
    };
    let (new_right, right_most) = delete_and_get_max(T.right.take());
    let Some(right_most) = right_most else {
        // `T` has no right child, so it is the maximum and its left subtree takes its place.
        // In a valid AA tree that subtree is empty, but returning it (instead of `None`)
        // guarantees no keys are silently dropped. This mirrors how `delete` promotes `T.left`
        // when `T.right` is `None`.
        let rest = T.left.take();
        return (rest, Some(*T));
    };
    T.right = new_right;
    T.eval();
    (rebarance(Some(T)), Some(right_most))
}
}
mod print_util {
use super::{dynamic_segment_tree_::DynamicSegmentTree, monoid::Monoid, node::Node};
use std::fmt::Debug;
// Line-prefix fragments concatenated by `fmt_inner_binary_tree` to draw tree edges.
/// Connector placed before a left child.
const LEFT: &str = " ┌──";
/// Vertical continuation used in an ancestor's column when that ancestor's edge runs past
/// the rows being drawn.
const MID: &str = " │ ";
/// Connector placed before a right child.
const RIGHT: &str = " └──";
/// Connector for the root (draws nothing).
const NULL: &str = "";
/// Padding used in an ancestor's column when no edge passes the rows being drawn.
const BLANK: &str = " ";
impl<K, M> DynamicSegmentTree<K, M>
where
    K: Ord + Debug,
    M: Monoid,
    M::Val: Debug,
{
    /// Dumps the tree to stderr, one node per line. Left subtrees are drawn above their
    /// parent, right subtrees below, and the whole drawing is framed by a header and footer.
    ///
    /// Compiled to a no-op when `debug_assertions` are disabled.
    pub fn print_as_binary_tree(&self) {
        #[cfg(debug_assertions)]
        {
            let mut prefix: Vec<&'static str> = Vec::new();
            eprintln!("┌─ BinaryTree ──────────────────────");
            fmt_inner_binary_tree(&self.root, &mut prefix, NULL);
            eprintln!("└───────────────────────────────────");
        }
    }
}
/// Recursively draws the subtree rooted at `node` on stderr, one node per line.
///
/// The left subtree is printed above the node and the right subtree below it. Each line is
/// prefixed with the concatenation of `fill`, which ends in this node's own connector.
///
/// * `fill` - stack of connector fragments inherited from the ancestors. It is restored to
///   its original contents before returning.
/// * `last` - connector leading to this node: `LEFT` for a left child, `RIGHT` for a right
///   child, `NULL` for the root.
fn fmt_inner_binary_tree<K, M: Monoid>(
    node: &Node<K, M>,
    fill: &mut Vec<&'static str>,
    last: &'static str,
) where
    K: Ord + Debug,
    M::Val: Debug,
{
    let Some(node) = node.as_ref() else {
        return;
    };
    // Rewrite the parent's connector slot for the rows of this subtree. If this node hangs off
    // the parent on the same side the parent hangs off its own parent, the parent's edge
    // does not pass these rows (blank). Otherwise it does (vertical bar). `saved` remembers
    // the original fragment so it can be restored afterwards.
    let mut saved = None;
    if fill.last() == Some(&last) {
        saved = fill.pop();
        fill.push(BLANK);
    } else if fill.last().is_some_and(|x| x != &NULL && x != &BLANK) {
        saved = fill.pop();
        fill.push(MID);
    }
    fill.push(last);
    fmt_inner_binary_tree(&node.left, fill, LEFT);
    // `concat` sizes the prefix once instead of re-growing a String per fragment.
    eprintln!("│{} {:?}", fill.concat(), node);
    fmt_inner_binary_tree(&node.right, fill, RIGHT);
    fill.pop();
    if let Some(saved) = saved {
        fill.pop();
        fill.push(saved);
    }
}
}