use std::collections::HashMap;
use crate::utils::consts::INF;
/// State for a rerooting DP on a tree: for every vertex `v` it computes the value of
/// the whole tree as if the tree were rooted at `v`.
///
/// * `T` — the DP value; `id` is used as the starting value of every fold, so it is
///   expected to be an identity element of `merge`.
/// * `M` — `merge(a, b)` combines two neighbour contributions.
/// * `FE` — `put_edge(x, e)` carries a subtree value `x` across the edge with id `e`.
/// * `FV` — `put_vertex(x, v)` turns merged neighbour contributions into `v`'s value.
///
/// Trait bounds are placed on the `impl` blocks that need them rather than on the
/// struct, so code that merely names this type does not have to repeat them.
pub struct Rerooting<T, M, FE, FV> {
    /// `dp[u][i]`: contribution of neighbour `G[u][i]` to `u`, already passed
    /// through `put_edge`.
    pub dp: Vec<Vec<T>>,
    /// `ans[u]`: the result with `u` as the root.
    pub ans: Vec<T>,
    /// Adjacency lists: `G[u]` holds the neighbours of `u` in insertion order.
    #[allow(non_snake_case)] // public field name kept for API compatibility
    pub G: Vec<Vec<usize>>,
    /// Number of edges registered so far; also the id handed to the next edge.
    edge_cnt: usize,
    /// `(u, v)` -> `(edge id, index of v inside G[u])`.
    pub edge_id: HashMap<(usize, usize), (usize, usize)>,
    /// Starting value of every fold.
    id: T,
    /// Binary combine of neighbour contributions.
    merge: M,
    /// Lifts a subtree value across an edge (given the edge id).
    put_edge: FE,
    /// Finalizes a merged value at a vertex (given the vertex index).
    put_vertex: FV,
}
impl<T, M, FE, FV> Rerooting<T, M, FE, FV>
where
    T: Clone,
    M: Fn(&T, &T) -> T,
    FE: Fn(&T, usize) -> T,
    FV: Fn(&T, usize) -> T,
{
    /// Creates a solver over vertices `0..n` with no edges.
    ///
    /// * `id` — starting value of every fold; expected to be an identity of `merge`.
    /// * `merge` — combines neighbour contributions. It is always applied in
    ///   adjacency-list order, so it need not be commutative (the prefix/suffix
    ///   decomposition does expect it to be associative).
    /// * `put_edge(x, e)` — carries a subtree value `x` across the edge with id `e`.
    /// * `put_vertex(x, v)` — turns merged contributions `x` into the value of `v`.
    pub fn new(n: usize, id: T, merge: M, put_edge: FE, put_vertex: FV) -> Self {
        Self {
            dp: vec![vec![]; n],
            ans: vec![id.clone(); n],
            G: vec![vec![]; n],
            edge_cnt: 0,
            edge_id: HashMap::default(),
            id,
            merge,
            put_edge,
            put_vertex,
        }
    }

    /// Registers the directed edge `u -> v` under a fresh edge id.
    ///
    /// The DP passes look up both `(u, v)` and `(v, u)`, so the reverse direction must
    /// be registered as well (it then receives its own id). Use
    /// [`add_edge2`](Self::add_edge2) for an undirected edge with one shared id.
    pub fn add_edge(&mut self, u: usize, v: usize) {
        let pos = self.G[u].len();
        self.G[u].push(v);
        self.edge_id.insert((u, v), (self.edge_cnt, pos));
        self.edge_cnt += 1;
    }

    /// Registers the undirected edge `u - v`; both directions share one edge id.
    pub fn add_edge2(&mut self, u: usize, v: usize) {
        let pos_u_v = self.G[u].len();
        self.G[u].push(v);
        let pos_v_u = self.G[v].len();
        self.G[v].push(u);
        self.edge_id.insert((u, v), (self.edge_cnt, pos_u_v));
        self.edge_id.insert((v, u), (self.edge_cnt, pos_v_u));
        self.edge_cnt += 1;
    }

    /// Computes `ans[v]` for every vertex `v`.
    ///
    /// Every connected component is rooted at its smallest vertex, aggregated
    /// bottom-up and then rerooted top-down. Traversal is iterative, so deep trees
    /// (long paths) cannot overflow the call stack. A graph with no vertices is a
    /// no-op. The graph must be a forest with every edge registered in both directions.
    pub fn build(&mut self) {
        let n = self.G.len();
        let mut done = vec![false; n];
        for root in 0..n {
            if done[root] {
                continue;
            }
            let order = self.subtree_preorder(INF, root);
            self.aggregate_in(&order);
            self.reroot_in(&order);
            for &(_, x) in &order {
                done[x] = true;
            }
        }
    }

    /// Bottom-up pass over the subtree of `u` whose parent is `p` (`INF` for a root).
    ///
    /// Rebuilds `dp[x]` for each vertex `x` in that subtree. For every child
    /// `c = G[x][i]`, `dp[x][i] = put_edge(value of c, id of edge c -> x)`. The
    /// parent's slot is left as `id`. Returns `put_vertex(merge of children of u, u)`.
    pub fn aggregate(&mut self, p: usize, u: usize) -> T {
        let order = self.subtree_preorder(p, u);
        self.aggregate_in(&order)
    }

    /// Top-down pass over the subtree of `u` whose parent is `p` (`INF` for a root).
    ///
    /// Requires `dp[u]` to be complete, including the slot for `p`. Sets `ans[x]` for
    /// every `x` in the subtree. For each child `v` of `x`, it fills the slot of `dp[v]`
    /// that belongs to `x`.
    pub fn reroot(&mut self, p: usize, u: usize) {
        let order = self.subtree_preorder(p, u);
        self.reroot_in(&order);
    }

    /// Lists `(parent, vertex)` pairs of the subtree of `root` entered from `parent`,
    /// ordered so that every vertex appears after its parent. Uses an explicit stack
    /// instead of recursion.
    fn subtree_preorder(&self, parent: usize, root: usize) -> Vec<(usize, usize)> {
        let mut order = Vec::new();
        let mut stack = vec![(parent, root)];
        while let Some((p, u)) = stack.pop() {
            order.push((p, u));
            stack.extend(self.G[u].iter().filter(|&&v| v != p).map(|&v| (u, v)));
        }
        order
    }

    /// Returns `(edge id, index of `to` in G[from])` for the directed edge `from -> to`.
    fn edge_entry(&self, from: usize, to: usize) -> (usize, usize) {
        *self
            .edge_id
            .get(&(from, to))
            .expect("Rerooting: edge registered in only one direction; use add_edge2")
    }

    /// Bottom-up pass over `order` (from `subtree_preorder`). Returns the value of the
    /// subtree rooted at `order[0]`'s vertex.
    fn aggregate_in(&mut self, order: &[(usize, usize)]) -> T {
        // Finished subtree values; each is consumed exactly once by its parent.
        let mut sub: Vec<Option<T>> = vec![None; self.G.len()];
        // Reverse preorder visits every child before its parent.
        for &(par, x) in order.iter().rev() {
            let deg = self.G[x].len();
            let mut row = vec![self.id.clone(); deg];
            let mut acc = self.id.clone();
            for i in 0..deg {
                let v = self.G[x][i];
                if v == par {
                    continue;
                }
                let child = sub[v]
                    .take()
                    .expect("children are finished before their parent");
                let (edge_vx, _) = self.edge_entry(v, x);
                let val = (self.put_edge)(&child, edge_vx);
                acc = (self.merge)(&acc, &val);
                row[i] = val;
            }
            self.dp[x] = row;
            sub[x] = Some((self.put_vertex)(&acc, x));
        }
        let (_, root) = order[0];
        sub[root].take().expect("root value is computed last")
    }

    /// Top-down pass over `order` (from `subtree_preorder`). Each vertex's `dp` row is
    /// complete when it is reached, because its parent comes earlier in `order`.
    fn reroot_in(&mut self, order: &[(usize, usize)]) {
        for &(par, x) in order {
            let deg = self.G[x].len();
            // prefix[i]: merge of dp[x][..i]; suffix[i]: merge of dp[x][i..].
            let mut prefix = Vec::with_capacity(deg + 1);
            prefix.push(self.id.clone());
            for i in 0..deg {
                let next = (self.merge)(&prefix[i], &self.dp[x][i]);
                prefix.push(next);
            }
            let mut suffix = vec![self.id.clone(); deg + 1];
            for i in (0..deg).rev() {
                suffix[i] = (self.merge)(&self.dp[x][i], &suffix[i + 1]);
            }
            self.ans[x] = (self.put_vertex)(&prefix[deg], x);
            for i in 0..deg {
                let v = self.G[x][i];
                if v == par {
                    continue;
                }
                // x's value with v's contribution removed: what v sees through edge x - v.
                let without_v = (self.put_vertex)(&(self.merge)(&prefix[i], &suffix[i + 1]), x);
                let (edge_xv, _) = self.edge_entry(x, v);
                let (_, slot) = self.edge_entry(v, x);
                self.dp[v][slot] = (self.put_edge)(&without_v, edge_xv);
            }
        }
    }
}