cp_library_rs/data_structure/
union_find.rs

1//! ## UnionFind木
2//!
3//! モノイドを乗せるUnionFind木.
4
5use std::{collections::HashMap, fmt::Debug, mem};
6
7use crate::{algebraic_structure::commutative::CommutativeMonoid, utils::consts::NEG1};
8
9/// UnionFind木
10pub type UnionFind = UnionFindMonoid<()>;
11
12/// UnionFind木(モノイド)
13pub struct UnionFindMonoid<M: CommutativeMonoid> {
14    /// 要素数
15    n: usize,
16    /// 親の番号を格納する配列
17    parent: Vec<usize>,
18    /// 値
19    value: Vec<Option<M::Val>>,
20    /// 連結成分の個数
21    count: usize,
22}
23
24impl<M: CommutativeMonoid> UnionFindMonoid<M> {
25    /// 根を求める
26    pub fn root(&mut self, mut x: usize) -> usize {
27        // 根を探索
28        let mut root = x;
29        while self.parent[root] < self.n {
30            root = self.parent[root];
31        }
32        // 経路圧縮
33        while self.parent[x] < self.n {
34            x = mem::replace(&mut self.parent[x], root);
35        }
36        root
37    }
38
39    /// ノード`x`が属する集合の値を取得
40    pub fn value(&mut self, x: usize) -> &M::Val {
41        let root = self.root(x);
42        self.value[root].as_ref().unwrap()
43    }
44
45    /// 同一の集合に所属するか判定
46    pub fn is_same(&mut self, x: usize, y: usize) -> bool {
47        self.root(x) == self.root(y)
48    }
49
50    /// 集合`x,y`を併合する.
51    ///
52    /// **戻り値**
53    /// - すでに併合済みだった場合`None`,そうでない場合親となった要素の番号を返す
54    pub fn unite(&mut self, x: usize, y: usize) -> Option<usize> {
55        let mut parent = self.root(x);
56        let mut child = self.root(y);
57
58        if parent == child {
59            return None;
60        }
61
62        // 要素数が大きい方を親にすることで、高さを均等に保つ
63        if self.parent[parent] > self.parent[child] {
64            (parent, child) = (child, parent);
65        }
66
67        self.parent[parent] = self.parent[parent].wrapping_add(self.parent[child]);
68        self.parent[child] = parent;
69        self.count -= 1;
70
71        // 値のマージ
72        let child_val = self.value[child].take();
73        let parent_val = self.value[parent].take();
74        self.value[parent] = child_val.zip(parent_val).map(|(c, p)| M::op(&c, &p));
75
76        Some(parent)
77    }
78
79    /// 連結成分の大きさを求める
80    pub fn get_size(&mut self, x: usize) -> usize {
81        let root = self.root(x);
82        self.parent[root].wrapping_neg()
83    }
84
85    /// 連結成分の数を返す
86    pub fn group_count(&self) -> usize {
87        self.count
88    }
89
90    /// {代表元: 集合} のマップを返す
91    ///
92    /// - 時間計算量: $`O(N)`$
93    pub fn enum_groups(&mut self) -> HashMap<usize, Vec<usize>> {
94        (0..self.n).fold(HashMap::default(), |mut map, i| {
95            let root = self.root(i);
96            map.entry(root).or_default().push(i);
97            map
98        })
99    }
100}
101
102impl UnionFindMonoid<()> {
103    /// 新しいUnionFind木を生成する
104    pub fn new(n: usize) -> Self {
105        UnionFindMonoid {
106            n,
107            parent: vec![NEG1; n],
108            value: vec![None; n],
109            count: n,
110        }
111    }
112}
113
114impl<M: CommutativeMonoid> From<Vec<M::Val>> for UnionFindMonoid<M> {
115    fn from(value: Vec<M::Val>) -> Self {
116        let N = value.len();
117        UnionFindMonoid {
118            n: N,
119            parent: vec![NEG1; N],
120            value: value.into_iter().map(Some).collect(),
121            count: N,
122        }
123    }
124}
125
126impl<M: CommutativeMonoid> FromIterator<M::Val> for UnionFindMonoid<M> {
127    fn from_iter<T: IntoIterator<Item = M::Val>>(iter: T) -> Self {
128        UnionFindMonoid::from(iter.into_iter().collect::<Vec<_>>())
129    }
130}
131
132impl<M> Debug for UnionFindMonoid<M>
133where
134    M: CommutativeMonoid,
135    M::Val: Debug + Clone,
136{
137    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
138        let mut uf = UnionFindMonoid::<M> {
139            n: self.n,
140            parent: self.parent.clone(),
141            value: self.value.clone(),
142            count: self.count,
143        };
144        let groups = uf.enum_groups();
145
146        f.debug_map().entries(groups).finish()
147    }
148}