cp_library_rs/graph/
rerooting.rs

1//! 全方位木DP
2
3use crate::algebraic_structure::monoid_with_context::MonoidCtx;
4
5pub trait TreeMonoid: MonoidCtx {
6    /// DP の値
7    type T: Clone;
8    /// 辺番号 i の辺を付加する
9    fn put_edge(&self, x: &Self::T, i: usize) -> Self::Val;
10    /// 頂点番号 v の頂点を付加する
11    fn put_vertex(&self, x: &Self::Val, v: usize) -> Self::T;
12}
13
14/// 辺重みを持つグラフ
15pub type Graph = Vec<Vec<Edge>>;
16
/// One directed half of an undirected edge, stored in the adjacency list of its tail vertex.
#[derive(Debug, Clone)]
pub struct Edge {
    /// Head vertex of this half-edge.
    pub to: usize,
    /// Index of this half-edge.
    pub idx: usize,
    /// Index of the opposite half-edge (`to -> tail`).
    pub ridx: usize,
}
26
27/// 全方位木DP
28pub struct RerootingDP<M: TreeMonoid> {
29    n: usize,
30    g: Graph,
31    root: usize,
32    monoid: M,
33}
34
35impl<M: TreeMonoid> RerootingDP<M> {
36    /// 空のグラフを初期化する
37    pub fn new(n: usize, monoid: M) -> Self {
38        Self {
39            n,
40            g: vec![vec![]; n],
41            root: 0,
42            monoid,
43        }
44    }
45
46    /// 辺 (u,v) を追加する
47    pub fn add_edge(&mut self, u: usize, v: usize, idx: usize, ridx: usize) {
48        self.g[u].push(Edge { to: v, idx, ridx });
49        self.g[v].push(Edge {
50            to: u,
51            idx: ridx,
52            ridx: idx,
53        });
54    }
55
56    /// 全方位木DP を行う
57    pub fn build(&mut self, root: usize) -> Vec<M::T> {
58        self.root = root;
59
60        let (par, order) = self.rooted_order(root);
61        let agg = self.aggregate(root, &par, &order);
62
63        self.propagate(root, &par, &order, &agg)
64    }
65
66    // ========== internal ==========
67
68    /// root = r で根付けし,親 `par` と DFS 順 `order` を返す
69    fn rooted_order(&self, r: usize) -> (Vec<usize>, Vec<usize>) {
70        let n = self.n;
71        let mut par = vec![usize::MAX; n];
72        par[r] = r;
73
74        let mut order = Vec::with_capacity(n);
75        let mut st = vec![r];
76        while let Some(u) = st.pop() {
77            order.push(u);
78            for e in &self.g[u] {
79                let v = e.to;
80                if par[v] != usize::MAX {
81                    continue;
82                }
83                par[v] = u;
84                st.push(v);
85            }
86        }
87        (par, order)
88    }
89
90    /// 根に集約する
91    fn aggregate(&self, r: usize, par: &[usize], order: &[usize]) -> Vec<M::T> {
92        let n = self.n;
93
94        let mut agg = vec![self.monoid.put_vertex(&self.monoid.e(), 0); n];
95
96        for &u in order.iter().rev() {
97            let mut prod = self.monoid.e();
98            for e in &self.g[u] {
99                let v = e.to;
100                if u != r && v == par[u] {
101                    // 親方向からは集約しない
102                    continue;
103                }
104                let val = self.monoid.put_edge(&agg[v], e.idx);
105                prod = self.monoid.op(&prod, &val);
106            }
107            agg[u] = self.monoid.put_vertex(&prod, u);
108        }
109
110        agg
111    }
112
113    /// 根から伝播する
114    fn propagate(&self, r: usize, par: &[usize], order: &[usize], agg: &[M::T]) -> Vec<M::T> {
115        let n = self.n;
116
117        // 親側の部分木の集約値
118        let mut par_dp = vec![self.monoid.put_vertex(&self.monoid.e(), r); n];
119        let mut ans = vec![self.monoid.put_vertex(&self.monoid.e(), 0); n];
120
121        for &u in order {
122            let deg = self.g[u].len();
123
124            let mut vals: Vec<M::Val> = Vec::with_capacity(deg);
125            for e in &self.g[u] {
126                let v = e.to;
127                let t = if u != r && v == par[u] {
128                    &par_dp[u]
129                } else {
130                    &agg[v]
131                };
132                vals.push(self.monoid.put_edge(t, e.idx));
133            }
134
135            // 先頭からの累積
136            let mut pre: Vec<M::Val> = vec![self.monoid.e(); deg + 1];
137            for i in 0..deg {
138                pre[i + 1] = self.monoid.op(&pre[i], &vals[i]);
139            }
140            // 末尾からの累積
141            let mut suf = vec![self.monoid.e(); deg + 1];
142            for i in (0..deg).rev() {
143                suf[i] = self.monoid.op(&vals[i], &suf[i + 1]);
144            }
145
146            // ans[u] := put_vertex(隣接項の積, u)
147            ans[u] = self.monoid.put_vertex(&pre[deg], u);
148
149            // 子へ伝播
150            for i in 0..deg {
151                let v = self.g[u][i].to;
152                if u != r && v == par[u] {
153                    continue;
154                }
155                let left = &pre[i];
156                let right = &suf[i + 1];
157                let total_except_i = self.monoid.op(left, right);
158
159                // 親側の集約値を反映
160                par_dp[v] = self.monoid.put_vertex(&total_except_i, u);
161            }
162        }
163
164        ans
165    }
166}