use std::iter::FromIterator;
use std::mem::{replace, swap};
use std::{cmp::Ordering, fmt::Debug};
/// A single node of the binary search tree that backs [`IndexedSet`].
///
/// Children are owned through `Option<Box<..>>` links, and `size` caches the
/// number of nodes in the subtree rooted at this node (the node itself
/// included), which is what makes rank queries possible.
#[derive(Debug, Clone)]
pub struct Node<T: Ord> {
/// Key stored in this node.
pub key: T,
/// Left child; `insert` places the keys ordered before `key` on this side.
pub left: Option<Box<Node<T>>>,
/// Right child; `insert` places the keys ordered after `key` on this side.
pub right: Option<Box<Node<T>>>,
/// Number of nodes in this subtree, counting this node (a leaf has size 1).
pub size: usize,
}
impl<T: Ord> Node<T> {
pub fn new(key: T) -> Self {
Self {
key,
left: None,
right: None,
size: 1,
}
}
}
/// Ordered set of unique keys that also supports rank queries
/// (`get_by_index`, `index`).
///
/// The keys live in a binary search tree of [`Node`]s that is restructured by
/// the `splay` / `splay_rev` helpers on every lookup, which is why most query
/// methods take `&mut self`.
pub struct IndexedSet<T: Ord> {
// Number of keys currently stored; adjusted by `insert` and `remove`.
size: usize,
/// Root of the tree, `None` while the set is empty.
pub root: Option<Box<Node<T>>>,
}
impl<T> IndexedSet<T>
where
T: Ord + Clone,
{
#[inline]
fn le(a: &T, b: &T) -> bool {
matches!(a.cmp(b), Ordering::Less | Ordering::Equal)
}
#[inline]
fn lt(a: &T, b: &T) -> bool {
matches!(a.cmp(b), Ordering::Less)
}
#[inline]
fn ge(a: &T, b: &T) -> bool {
matches!(a.cmp(b), Ordering::Equal | Ordering::Greater)
}
#[inline]
fn gt(a: &T, b: &T) -> bool {
matches!(a.cmp(b), Ordering::Greater)
}
pub fn new() -> Self {
Self {
size: 0,
root: None,
}
}
pub fn len(&self) -> usize {
self.size
}
pub fn is_empty(&self) -> bool {
self.size == 0
}
pub fn get(&mut self, key: &T) -> Option<&T> {
let lb = self.lower_bound(key);
if lb.is_some_and(|k| k == key) {
lb
} else {
None
}
}
pub fn insert(&mut self, key: T) -> Option<T> {
let root = self.root.take();
let (mut tmp_root, _) = splay(root, &key, Self::le);
if tmp_root.is_some() && tmp_root.as_ref().unwrap().key == key {
self.root = tmp_root;
let res = replace(&mut self.root.as_deref_mut().unwrap().key, key);
return Some(res);
}
self.root = Some(Box::new(Node::new(key.clone())));
if tmp_root.is_some() {
match key.cmp(&tmp_root.as_ref().unwrap().key) {
Ordering::Less | Ordering::Equal => {
let mut new_left = tmp_root.as_mut().unwrap().left.take();
update_size(&mut tmp_root);
swap(&mut self.root.as_mut().unwrap().left, &mut new_left);
swap(&mut self.root.as_mut().unwrap().right, &mut tmp_root);
}
Ordering::Greater => {
let mut new_right = tmp_root.as_mut().unwrap().right.take();
update_size(&mut tmp_root);
swap(&mut self.root.as_mut().unwrap().right, &mut new_right);
swap(&mut self.root.as_mut().unwrap().left, &mut tmp_root);
}
}
}
update_size(&mut self.root);
self.size += 1;
None
}
pub fn remove(&mut self, key: &T) -> Option<T> {
if self.is_empty() {
return None;
}
let root = self.root.take();
let (mut tmp_root, _) = splay(root, key, Self::le);
if tmp_root.is_none() || &tmp_root.as_ref().unwrap().key != key {
self.root = tmp_root;
return None;
}
if tmp_root.as_ref().unwrap().left.is_none() {
swap(&mut self.root, &mut tmp_root.as_mut().unwrap().right);
} else {
let root_left = tmp_root.as_mut().unwrap().left.take();
swap(&mut self.root, &mut splay(root_left, key, Self::lt).0);
swap(
&mut self.root.as_mut().unwrap().right,
&mut tmp_root.as_mut().unwrap().right,
);
}
update_size(&mut self.root);
self.size -= 1;
let deleted = tmp_root.take();
Some(deleted.unwrap().key)
}
pub fn contains_key(&mut self, key: &T) -> bool {
self.get(key).is_some_and(|k| k == key)
}
pub fn lower_bound(&mut self, key: &T) -> Option<&T> {
let root = self.root.take();
let (new_root, is_found) = splay(root, key, Self::le);
self.root = new_root;
if is_found {
Some(&self.root.as_ref().unwrap().key)
} else {
None
}
}
pub fn upper_bound(&mut self, key: &T) -> Option<&T> {
let root = self.root.take();
let (new_root, is_found) = splay(root, key, Self::lt);
self.root = new_root;
if is_found {
Some(&self.root.as_ref().unwrap().key)
} else {
None
}
}
pub fn lower_bound_rev(&mut self, key: &T) -> Option<&T> {
let root = self.root.take();
let (new_root, is_found) = splay_rev(root, key, Self::ge);
self.root = new_root;
if is_found {
Some(&self.root.as_ref().unwrap().key)
} else {
None
}
}
pub fn upper_bound_rev(&mut self, key: &T) -> Option<&T> {
let root = self.root.take();
let (new_root, is_found) = splay_rev(root, key, Self::gt);
self.root = new_root;
if is_found {
Some(&self.root.as_ref().unwrap().key)
} else {
None
}
}
pub fn get_by_index(&self, n: usize) -> Option<&T> {
if n > self.size {
None
} else {
get_nth(&self.root, n + 1)
}
}
pub fn index(&mut self, key: &T) -> Option<usize> {
if self.get(key).is_some() {
let left_size = self
.root
.as_ref()
.unwrap()
.left
.as_ref()
.map_or(0, |node| node.size);
Some(left_size)
} else {
None
}
}
}
fn get_nth<T: Ord>(root: &Option<Box<Node<T>>>, n: usize) -> Option<&T> {
if let Some(root) = root {
let left_size = root.left.as_ref().map_or(0, |node| node.size);
match n.cmp(&(left_size + 1)) {
Ordering::Less => get_nth(&root.left, n),
Ordering::Equal => Some(&root.key),
Ordering::Greater => get_nth(&root.right, n - left_size - 1),
}
} else {
None
}
}
/// Restructures the tree under `root` around `key` and returns the new root
/// together with a "found" flag.
///
/// `compare(key, node_key)` steers the search: when it holds, the visited node
/// is a candidate and the search continues in its left subtree; otherwise the
/// search continues to the right. The search path is handled two levels per
/// recursive call, and the result is moved upwards with `rotate_left` /
/// `rotate_right`.
///
/// Returns `(new_root, true)` when at least one visited node satisfied
/// `compare`; the new root is then the last such node on the search path. With
/// the `a <= b` / `a < b` predicates used by the callers in this file, and
/// assuming the tree is a valid search tree, that is the smallest key that is
/// `>=` / `>` `key`.
///
/// Returns `(new_root, false)` when no visited node satisfied `compare`; the
/// search then only ever descended to the right, and the node where it stopped
/// ends up as the root. An empty tree yields `(None, false)`.
fn splay<T, C>(mut root: Option<Box<Node<T>>>, key: &T, compare: C) -> (Option<Box<Node<T>>>, bool)
where
T: Ord,
C: Fn(&T, &T) -> bool,
{
// Empty subtree: nothing to restructure, nothing found.
if root.is_none() {
return (root, false);
}
if compare(key, &root.as_ref().unwrap().key) {
// The root is a candidate; keep looking on its left side.
let left = &mut root.as_mut().unwrap().left;
if left.is_none() {
return (root, true);
}
if compare(key, &left.as_ref().unwrap().key) {
// Left child is a candidate as well: recurse into left.left, reattach the
// result, then rotate right once (left child on top) or twice (the
// recursion's result on top).
let leftleft = left.as_mut().unwrap().left.take();
let (mut tmp, is_found) = splay(leftleft, key, compare);
swap(&mut left.as_mut().unwrap().left, &mut tmp);
let tmp_left = rotate_right(root);
if !is_found {
// Nothing deeper qualified, so the former left child is the answer.
return (tmp_left, true);
}
(rotate_right(tmp_left), true)
} else {
// Left child is not a candidate: recurse into left.right and reattach.
let leftright = left.as_mut().unwrap().right.take();
let (mut new_leftright, is_found) = splay(leftright, key, compare);
swap(&mut left.as_mut().unwrap().right, &mut new_leftright);
if !is_found {
// Nothing deeper qualified, so the current root stays the answer.
return (root, true);
}
// Lift the recursion's result over the left child, then over the root.
let left = root.as_mut().unwrap().left.take();
let mut tmp_child = rotate_left(left);
swap(&mut root.as_mut().unwrap().left, &mut tmp_child);
(rotate_right(root), true)
}
} else {
// The root is not a candidate; continue on its right side.
let right = &mut root.as_mut().unwrap().right;
if right.is_none() {
return (root, false);
}
if compare(key, &right.as_ref().unwrap().key) {
// Right child is a candidate: recurse into right.left and reattach.
let rightleft = right.as_mut().unwrap().left.take();
let (mut tmp, is_found) = splay(rightleft, key, compare);
swap(&mut right.as_mut().unwrap().left, &mut tmp);
if is_found {
// A deeper candidate exists: lift it over the right child first.
let right = root.as_mut().unwrap().right.take();
let mut tmp_child = rotate_right(right);
swap(&mut root.as_mut().unwrap().right, &mut tmp_child);
}
// Whatever now sits in the right slot is the answer; bring it to the top.
(rotate_left(root), true)
} else {
// Neither the root nor its right child qualifies: recurse into
// right.right, reattach, and rotate left twice. The second rotation is a
// no-op when right.right is empty. The flag comes from the recursion.
let rightright = right.as_mut().unwrap().right.take();
let (mut tmp, is_found) = splay(rightright, key, compare);
swap(&mut right.as_mut().unwrap().right, &mut tmp);
let tmp_child = rotate_left(root);
(rotate_left(tmp_child), is_found)
}
}
}
/// Mirror image of `splay`: restructures the tree under `root` around `key`
/// and returns the new root together with a "found" flag.
///
/// `compare(key, node_key)` steers the search: when it holds, the visited node
/// is a candidate and the search continues in its right subtree; otherwise the
/// search continues to the left. The search path is handled two levels per
/// recursive call, and the result is moved upwards with `rotate_left` /
/// `rotate_right`.
///
/// Returns `(new_root, true)` when at least one visited node satisfied
/// `compare`; the new root is then the last such node on the search path. With
/// the `a >= b` / `a > b` predicates used by the callers in this file, and
/// assuming the tree is a valid search tree, that is the largest key that is
/// `<=` / `<` `key`.
///
/// Returns `(new_root, false)` when no visited node satisfied `compare`; the
/// search then only ever descended to the left, and the node where it stopped
/// ends up as the root. An empty tree yields `(None, false)`.
fn splay_rev<T, C>(
mut root: Option<Box<Node<T>>>,
key: &T,
compare: C,
) -> (Option<Box<Node<T>>>, bool)
where
T: Ord,
C: Fn(&T, &T) -> bool,
{
// Empty subtree: nothing to restructure, nothing found.
if root.is_none() {
return (root, false);
}
if compare(key, &root.as_ref().unwrap().key) {
// The root is a candidate; keep looking on its right side.
let right = &mut root.as_mut().unwrap().right;
if right.is_none() {
return (root, true);
}
if compare(key, &right.as_ref().unwrap().key) {
// Right child is a candidate as well: recurse into right.right, reattach
// the result, then rotate left once (right child on top) or twice (the
// recursion's result on top).
let rightright = right.as_mut().unwrap().right.take();
let (mut tmp, is_found) = splay_rev(rightright, key, compare);
swap(&mut right.as_mut().unwrap().right, &mut tmp);
let tmp_right = rotate_left(root);
if !is_found {
// Nothing deeper qualified, so the former right child is the answer.
return (tmp_right, true);
}
(rotate_left(tmp_right), true)
} else {
// Right child is not a candidate: recurse into right.left and reattach.
let rightleft = right.as_mut().unwrap().left.take();
let (mut new_rightleft, is_found) = splay_rev(rightleft, key, compare);
swap(&mut right.as_mut().unwrap().left, &mut new_rightleft);
if !is_found {
// Nothing deeper qualified, so the current root stays the answer.
return (root, true);
}
// Lift the recursion's result over the right child, then over the root.
let right = root.as_mut().unwrap().right.take();
let mut tmp_child = rotate_right(right);
swap(&mut root.as_mut().unwrap().right, &mut tmp_child);
(rotate_left(root), true)
}
} else {
// The root is not a candidate; continue on its left side.
let left = &mut root.as_mut().unwrap().left;
if left.is_none() {
return (root, false);
}
if compare(key, &left.as_ref().unwrap().key) {
// Left child is a candidate: recurse into left.right and reattach.
let leftright = left.as_mut().unwrap().right.take();
let (mut tmp, is_found) = splay_rev(leftright, key, compare);
swap(&mut left.as_mut().unwrap().right, &mut tmp);
if is_found {
// A deeper candidate exists: lift it over the left child first.
let left = root.as_mut().unwrap().left.take();
let mut tmp_child = rotate_left(left);
swap(&mut root.as_mut().unwrap().left, &mut tmp_child);
}
// Whatever now sits in the left slot is the answer; bring it to the top.
(rotate_right(root), true)
} else {
// Neither the root nor its left child qualifies: recurse into left.left,
// reattach, and rotate right twice. The second rotation is a no-op when
// left.left is empty. The flag comes from the recursion.
let leftleft = left.as_mut().unwrap().left.take();
let (mut tmp, is_found) = splay_rev(leftleft, key, compare);
swap(&mut left.as_mut().unwrap().left, &mut tmp);
let tmp_child = rotate_right(root);
(rotate_right(tmp_child), is_found)
}
}
}
fn update_size<T: Ord>(node: &mut Option<Box<Node<T>>>) {
if let Some(node) = node {
let left_size = node.left.as_ref().map_or(0, |node| node.size);
let right_size = node.right.as_ref().map_or(0, |node| node.size);
node.size = left_size + right_size + 1;
}
}
/// Rotates the subtree `root` to the right: its left child becomes the new
/// top, the old top becomes that node's right child, and the left child's
/// former right subtree is re-hung as the old top's left subtree.
///
/// Cached sizes of the two moved nodes are refreshed. Returns `None` for an
/// empty subtree and the subtree unchanged when there is no left child.
fn rotate_right<T: Ord>(root: Option<Box<Node<T>>>) -> Option<Box<Node<T>>> {
    let mut old_top = root?;
    match old_top.left.take() {
        None => Some(old_top),
        Some(mut pivot) => {
            old_top.left = pivot.right.take();
            pivot.right = Some(old_top);
            // Refresh bottom-up: the demoted node first, then the new top.
            update_size(&mut pivot.right);
            let mut promoted = Some(pivot);
            update_size(&mut promoted);
            promoted
        }
    }
}
/// Rotates the subtree `root` to the left: its right child becomes the new
/// top, the old top becomes that node's left child, and the right child's
/// former left subtree is re-hung as the old top's right subtree.
///
/// Cached sizes of the two moved nodes are refreshed. Returns `None` for an
/// empty subtree and the subtree unchanged when there is no right child.
fn rotate_left<T: Ord>(root: Option<Box<Node<T>>>) -> Option<Box<Node<T>>> {
    let mut old_top = root?;
    match old_top.right.take() {
        None => Some(old_top),
        Some(mut pivot) => {
            old_top.right = pivot.left.take();
            pivot.left = Some(old_top);
            // Refresh bottom-up: the demoted node first, then the new top.
            update_size(&mut pivot.left);
            let mut promoted = Some(pivot);
            update_size(&mut promoted);
            promoted
        }
    }
}
impl<T: Ord + Clone> FromIterator<T> for IndexedSet<T> {
fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
let mut res = IndexedSet::new();
for item in iter {
res.insert(item);
}
res
}
}
/// Borrowing iterator over the keys of an [`IndexedSet`], created by
/// `IndexedSet::iter`.
///
/// It does not modify the tree; it keeps an explicit stack of nodes instead of
/// recursing.
pub struct SplayTreeIterator<'a, T: 'a + Ord> {
// Stack of nodes that still have to be yielded; the top entry is next.
unvisited: Vec<&'a Node<T>>,
}
impl<'a, T: Ord> SplayTreeIterator<'a, T> {
fn push_left_edge(&mut self, mut tree: &'a Option<Box<Node<T>>>) {
while let Some(node) = tree.as_deref() {
self.unvisited.push(node);
tree = &node.left;
}
}
}
impl<'a, T: Ord> Iterator for SplayTreeIterator<'a, T> {
    type Item = &'a T;

    /// Pops the next pending node, queues the left edge of its right subtree,
    /// and yields the node's key. Returns `None` once the stack is empty.
    fn next(&mut self) -> Option<Self::Item> {
        self.unvisited.pop().map(|current| {
            self.push_left_edge(&current.right);
            &current.key
        })
    }
}
impl<T: Ord> IndexedSet<T> {
    /// Returns an iterator over references to the stored keys.
    ///
    /// The iterator starts with the root's left edge on its stack; the tree
    /// itself is only borrowed, never restructured.
    pub fn iter(&self) -> SplayTreeIterator<'_, T> {
        let mut walker = SplayTreeIterator {
            unvisited: Vec::new(),
        };
        walker.push_left_edge(&self.root);
        walker
    }
}
impl<'a, T: Ord> IntoIterator for &'a IndexedSet<T> {
type IntoIter = SplayTreeIterator<'a, T>;
type Item = &'a T;
fn into_iter(self) -> Self::IntoIter {
self.iter()
}
}
impl<T: Ord + Debug> Debug for IndexedSet<T> {
    /// Formats the set as `{k1, k2, ...}`, listing keys in the order produced
    /// by `iter`; the tree shape is not shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut listing = f.debug_set();
        for key in self.iter() {
            listing.entry(key);
        }
        listing.finish()
    }
}