cp_library_rs/data_structure/
weighted_union_find.rs

1//! 重み付きUnionFind
2
3use crate::algebraic_structure::abel::Abel;
4
5/// 重み付きUnionFind
6pub struct WeightedUnionFind<G: Abel> {
7    par: Vec<usize>,
8    rank: Vec<usize>,
9    weight: Vec<G::Val>,
10    group_count: usize,
11}
12
13impl<G: Abel> WeightedUnionFind<G>
14where
15    G::Val: Eq,
16{
17    /// UnionFindを構築
18    pub fn new(n: usize) -> Self {
19        WeightedUnionFind {
20            par: (0..n).collect(),
21            rank: vec![1; n],
22            weight: vec![G::e(); n],
23            group_count: n,
24        }
25    }
26
27    /// 根を求める
28    pub fn get_root(&mut self, x: usize) -> usize {
29        if self.par[x] == x {
30            return x;
31        }
32        let r = self.get_root(self.par[x]);
33        let parent = self.weight[self.par[x]].clone();
34        let child = self.weight.get_mut(x).unwrap();
35        *child = G::op(child, &parent);
36        self.par[x] = r; // 経路圧縮
37        r
38    }
39
40    /// 重みを求める
41    pub fn weight(&mut self, x: usize) -> G::Val {
42        self.get_root(x); // 経路圧縮
43        self.weight[x].clone()
44    }
45
46    /// 同一の集合に所属するか判定
47    pub fn is_same(&mut self, x: usize, y: usize) -> bool {
48        self.get_root(x) == self.get_root(y)
49    }
50
51    /// 重みの差を求める
52    ///
53    /// 同じグループにいない場合にはNoneを返す
54    pub fn diff(&mut self, x: usize, y: usize) -> Option<G::Val> {
55        if self.is_same(x, y) {
56            let res = G::op(&self.weight(y), &G::inv(&self.weight(x)));
57            return Some(res);
58        }
59        None
60    }
61
62    /// 集合`x,y`を`self.diff(x, y) = weight`となるように併合する.
63    ///
64    /// **戻り値**
65    /// - すでに`x,y`が併合済みだった場合
66    ///   - `self.diff(x, y) == weight` の場合 → `Some(false)`
67    ///   - `self.diff(x, y) != weight` の場合 → `Err(())`
68    /// - `x,y`が併合済みでない場合 → `Ok(true)`
69    pub fn unite(&mut self, mut x: usize, mut y: usize, mut weight: G::Val) -> Result<bool, &str> {
70        // すでにmerge済みの場合
71        if let Some(w) = self.diff(x, y) {
72            return if w == weight {
73                Ok(false)
74            } else {
75                Err("weight mismatch")
76            };
77        }
78
79        // x, yそれぞれについて重み差分を補正
80        weight = G::op(&weight, &self.weight(x));
81        weight = G::op(&weight, &G::inv(&self.weight(y)));
82
83        x = self.get_root(x);
84        y = self.get_root(y);
85
86        // 要素数が大きい方を子にすることで、高さを均等に保つ
87        if self.rank[x] < self.rank[y] {
88            std::mem::swap(&mut x, &mut y);
89            weight = G::inv(&weight);
90        }
91
92        self.par[y] = x;
93        self.rank[x] += self.rank[y];
94        self.group_count -= 1;
95
96        // 重みの更新
97        self.weight[y] = weight;
98
99        Ok(true)
100    }
101
102    /// `x`が属する集合の大きさを求める
103    pub fn get_size(&mut self, x: usize) -> usize {
104        let get_root = self.get_root(x);
105        self.rank[get_root]
106    }
107
108    /// 全体の要素数を求める
109    #[inline]
110    pub fn group_count(&self) -> usize {
111        self.group_count
112    }
113}