use std::{collections::HashMap, fmt::Debug, mem};
use crate::{algebraic_structure::commutative::CommutativeMonoid, utils::consts::NEG1};
/// A plain union–find (disjoint-set) structure whose sets carry no monoid value.
pub type UnionFind = UnionFindMonoid<()>;

/// Disjoint-set forest over the elements `0..n`, where every set may additionally carry
/// a value aggregated with the commutative monoid `M`.
pub struct UnionFindMonoid<M>
where
    M: CommutativeMonoid,
{
    /// Number of elements; valid element indices are `0..n`.
    n: usize,
    /// For a non-root element: the index of its parent (always `< n`).
    /// For a root: the size of its set, stored as a wrapping negation — a very large
    /// `usize` that, for any realistic `n`, is `>= n`, which is how roots are recognised.
    parent: Vec<usize>,
    /// Aggregated value of each set, kept at the set's root; entries of elements that
    /// are no longer roots are `None`.
    value: Vec<Option<M::Val>>,
    /// Current number of disjoint sets.
    count: usize,
}
impl<M> UnionFindMonoid<M>
where
    M: CommutativeMonoid,
{
    /// Returns the representative (root) of the set containing `x`.
    ///
    /// Every element on the path from `x` to the root is re-pointed directly at the
    /// root (path compression), which is why this needs `&mut self`.
    ///
    /// # Panics
    /// Panics if `x >= n`.
    pub fn root(&mut self, x: usize) -> usize {
        let n = self.n;

        // Phase 1: climb until reaching an entry that stores a size rather than a parent.
        let mut top = x;
        loop {
            let up = self.parent[top];
            if up >= n {
                break;
            }
            top = up;
        }

        // Phase 2: walk the same path again, redirecting each visited node to `top`.
        let mut cur = x;
        while self.parent[cur] < n {
            cur = mem::replace(&mut self.parent[cur], top);
        }

        top
    }

    /// Returns the aggregated monoid value of the set containing `x`.
    ///
    /// # Panics
    /// Panics if `x >= n`, or if the set's root holds no value (`None`).
    pub fn value(&mut self, x: usize) -> &M::Val {
        let r = self.root(x);
        self.value[r].as_ref().unwrap()
    }

    /// Reports whether `x` and `y` currently belong to the same set.
    ///
    /// # Panics
    /// Panics if `x >= n` or `y >= n`.
    pub fn is_same(&mut self, x: usize, y: usize) -> bool {
        let rx = self.root(x);
        let ry = self.root(y);
        rx == ry
    }

    /// Merges the sets containing `x` and `y`.
    ///
    /// Returns `None` when they are already in the same set; otherwise returns
    /// `Some(new_root)`. The root of the larger set survives (on a tie, the root of
    /// `x` survives). The merged value is `M::op(absorbed_value, surviving_value)`,
    /// or `None` if either side had no value.
    ///
    /// # Panics
    /// Panics if `x >= n` or `y >= n`.
    pub fn unite(&mut self, x: usize, y: usize) -> Option<usize> {
        let rx = self.root(x);
        let ry = self.root(y);
        if rx == ry {
            return None;
        }

        // Sizes are stored negated, so the larger set has the *smaller* raw value.
        let (keep, absorb) = if self.parent[rx] > self.parent[ry] {
            (ry, rx)
        } else {
            (rx, ry)
        };

        let merged_size = self.parent[keep].wrapping_add(self.parent[absorb]);
        self.parent[keep] = merged_size;
        self.parent[absorb] = keep;
        self.count -= 1;

        let absorbed_val = self.value[absorb].take();
        let kept_val = self.value[keep].take();
        self.value[keep] = match (absorbed_val, kept_val) {
            (Some(a), Some(k)) => Some(M::op(&a, &k)),
            _ => None,
        };

        Some(keep)
    }

    /// Returns the number of elements in the set containing `x`.
    ///
    /// # Panics
    /// Panics if `x >= n`.
    pub fn get_size(&mut self, x: usize) -> usize {
        let r = self.root(x);
        // Undo the wrapping negation used to store sizes at roots.
        0usize.wrapping_sub(self.parent[r])
    }

    /// Returns the current number of disjoint sets.
    pub fn group_count(&self) -> usize {
        self.count
    }

    /// Groups every element by its root: `root -> members` (members in ascending order).
    ///
    /// Takes `&mut self` because each lookup compresses paths.
    pub fn enum_groups(&mut self) -> HashMap<usize, Vec<usize>> {
        let mut groups: HashMap<usize, Vec<usize>> = HashMap::new();
        for i in 0..self.n {
            let r = self.root(i);
            groups.entry(r).or_insert_with(Vec::new).push(i);
        }
        groups
    }
}
impl UnionFindMonoid<()> {
    /// Creates `n` singleton sets (elements `0..n`) that carry no values.
    pub fn new(n: usize) -> Self {
        Self {
            n,
            // `NEG1` marks every element as the root of its own size-1 set.
            parent: vec![NEG1; n],
            // No values are attached to a plain union–find.
            value: (0..n).map(|_| None).collect(),
            count: n,
        }
    }
}
impl<M: CommutativeMonoid> From<Vec<M::Val>> for UnionFindMonoid<M> {
    /// Builds `value.len()` singleton sets, where element `i` starts with `value[i]` as
    /// the aggregated value of its own set.
    fn from(value: Vec<M::Val>) -> Self {
        let n = value.len();
        Self {
            n,
            // `NEG1` (the wrapping `usize` encoding of -1) marks each element as the
            // root of a size-1 set.
            parent: vec![NEG1; n],
            value: value.into_iter().map(Some).collect(),
            count: n,
        }
    }
}
impl<M: CommutativeMonoid> FromIterator<M::Val> for UnionFindMonoid<M> {
    /// Collects the yielded values and delegates to the `From<Vec<_>>` conversion, so the
    /// `i`-th yielded value becomes the initial value of element `i`'s singleton set.
    fn from_iter<I>(iter: I) -> Self
    where
        I: IntoIterator<Item = M::Val>,
    {
        let values: Vec<M::Val> = iter.into_iter().collect();
        Self::from(values)
    }
}
/// Shows the current partition as a map from each set's root to its members.
///
/// Roots are listed in ascending order and members in ascending index order, so the
/// output is deterministic. Per-set monoid values are not printed.
impl<M> Debug for UnionFindMonoid<M>
where
    M: CommutativeMonoid,
    M::Val: Debug,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        use std::collections::BTreeMap;

        // BTreeMap (not HashMap) so the printed order is stable across runs.
        let mut groups: BTreeMap<usize, Vec<usize>> = BTreeMap::new();
        for i in 0..self.n {
            // `fmt` only has `&self`: find the root by a read-only walk (no path
            // compression), which avoids cloning the whole structure and its values.
            let mut r = i;
            while self.parent[r] < self.n {
                r = self.parent[r];
            }
            groups.entry(r).or_default().push(i);
        }
        f.debug_map().entries(groups).finish()
    }
}