use crate::algebraic_structure::abel::Abel;
/// Union-find (disjoint set union) that also tracks, for every element, a potential
/// ("weight") relative to the other members of its group, with values in the abelian
/// group `G`.
///
/// Writing `pot(x)` for the potential of `x`, a successful `unite(x, y, w)` records
/// `pot(y) = pot(x) + w`, and `diff(x, y)` afterwards returns `pot(y) - pot(x)`.
pub struct WeightedUnionFind<G: Abel> {
    /// `par[x]` is the parent of `x`; `x` is a root iff `par[x] == x`.
    par: Vec<usize>,
    /// Despite the name, this stores component sizes: for a root `r`, `rank[r]` is the
    /// number of elements in its group. Entries of non-root elements are never read.
    rank: Vec<usize>,
    /// `weight[x]` is `pot(x) - pot(par[x])`; after `get_root(x)` has compressed the path
    /// this is the potential of `x` relative to its root. Roots always hold `G::id()`.
    weight: Vec<G::Val>,
    /// Current number of disjoint groups.
    group_count: usize,
}

impl<G: Abel> WeightedUnionFind<G>
where
    G::Val: Eq,
{
    /// Creates `n` singleton groups `{0}, {1}, ..., {n - 1}`, each with weight `G::id()`.
    pub fn new(n: usize) -> Self {
        WeightedUnionFind {
            par: (0..n).collect(),
            rank: vec![1; n],
            weight: vec![G::id(); n],
            group_count: n,
        }
    }

    /// Returns the root of `x`'s group and compresses the path so `x` points directly at it.
    ///
    /// Afterwards `weight[x]` holds the potential of `x` relative to that root. The
    /// recursion depth equals the tree height, which union by size keeps logarithmic.
    ///
    /// # Panics
    /// Panics if `x` is out of range.
    pub fn get_root(&mut self, x: usize) -> usize {
        let p = self.par[x];
        if p == x {
            return x;
        }
        let root = self.get_root(p);
        // `weight[p]` is now relative to `root`; fold it into x's weight relative to `p`.
        self.weight[x] = G::op(&self.weight[x], &self.weight[p]);
        self.par[x] = root;
        root
    }

    /// Returns the potential of `x` relative to the root of its group.
    pub fn weight(&mut self, x: usize) -> G::Val {
        self.get_root(x);
        self.weight[x].clone()
    }

    /// Returns `true` if `x` and `y` belong to the same group.
    pub fn is_same(&mut self, x: usize, y: usize) -> bool {
        self.get_root(x) == self.get_root(y)
    }

    /// Returns `pot(y) - pot(x)` when `x` and `y` are in the same group, otherwise `None`.
    pub fn diff(&mut self, x: usize, y: usize) -> Option<G::Val> {
        if !self.is_same(x, y) {
            return None;
        }
        Some(G::op(&self.weight(y), &G::inv(&self.weight(x))))
    }

    /// Merges the groups of `x` and `y` under the constraint `pot(y) - pot(x) == weight`.
    ///
    /// Returns `Ok(true)` if two groups were merged, and `Ok(false)` if `x` and `y` were
    /// already connected with exactly this difference.
    ///
    /// # Errors
    /// Returns `Err("weight mismatch")` if `x` and `y` are already connected with a
    /// different difference; the structure is left unchanged. The message is `'static`,
    /// so holding the error does not keep `self` borrowed.
    pub fn unite(&mut self, x: usize, y: usize, weight: G::Val) -> Result<bool, &'static str> {
        if let Some(current) = self.diff(x, y) {
            return if current == weight {
                Ok(false)
            } else {
                Err("weight mismatch")
            };
        }
        // Rewrite the constraint between x and y as one between their roots:
        // pot(ry) - pot(rx) = weight + weight(x) - weight(y).
        let mut w = G::op(&weight, &self.weight(x));
        w = G::op(&w, &G::inv(&self.weight(y)));
        let mut rx = self.get_root(x);
        let mut ry = self.get_root(y);
        // Union by size: hang the smaller tree under the larger one, flipping the sign
        // of the root-to-root difference when the roles swap.
        if self.rank[rx] < self.rank[ry] {
            std::mem::swap(&mut rx, &mut ry);
            w = G::inv(&w);
        }
        self.par[ry] = rx;
        self.rank[rx] += self.rank[ry];
        self.group_count -= 1;
        self.weight[ry] = w;
        Ok(true)
    }

    /// Returns the number of elements in the group containing `x`.
    pub fn get_size(&mut self, x: usize) -> usize {
        let root = self.get_root(x);
        self.rank[root]
    }

    /// Returns the current number of disjoint groups.
    pub fn group_count(&self) -> usize {
        self.group_count
    }
}