cp_library_rs/string/
rolling_hash.rs1#![allow(clippy::len_without_is_empty)]
4
5use std::{
6 fmt::Debug,
7 ops::{Bound, RangeBounds},
8};
9
10use crate::number_theory::modint_for_rollinghash::modint::Modint;
11
12use num_traits::{One, Zero};
13
14#[derive(Debug)]
18pub struct RollingHash {
19 pub size: usize,
20 power: Vec<Modint>,
21 hash: Vec<Modint>,
22 base: Modint,
23}
24
25impl RollingHash {
26 pub fn build(arr: &[Modint], base: Modint) -> Self {
28 let size = arr.len();
29 let mut power = vec![Modint(1); size + 1];
30 let mut hash = vec![Modint(0); size + 1];
31
32 let (mut h, mut p) = (Modint::zero(), Modint::one());
34 for i in 0..size {
35 h = arr[i] + (h * base);
36 p *= base;
37 hash[i + 1] = h;
38 power[i + 1] = p;
39 }
40
41 Self {
42 size,
43 power,
44 hash,
45 base,
46 }
47 }
48
49 pub fn from_str(s: &str, base: Modint) -> Self {
51 let arr: Vec<Modint> = s.chars().map(Self::ord).map(Modint).collect();
52 Self::build(&arr, base)
53 }
54
55 pub fn hash<'a, R: RangeBounds<usize> + Debug>(&'a self, range: R) -> HashVal<'a> {
58 let (l, r) = self.parse_range(&range);
59 HashVal {
60 rolling_hash: self,
61 length: r - l,
62 hash: self.hash[r] - self.hash[l] * self.power[r - l],
63 }
64 }
65
66 pub fn get_LCP(&self, a: usize, b: usize) -> usize {
69 let len = self.size.saturating_sub(a.max(b));
70 let (mut lo, mut hi) = (0, len + 1);
71 while hi - lo > 1 {
72 let mid = (lo + hi) / 2;
73 if self.hash(a..a + mid) == self.hash(b..b + mid) {
74 lo = mid;
75 } else {
76 hi = mid;
77 }
78 }
79 lo
80 }
81
82 pub fn len(&self) -> usize {
83 self.size
84 }
85
86 #[inline]
88 fn parse_range<R: RangeBounds<usize> + Debug>(&self, range: &R) -> (usize, usize) {
89 let start = match range.start_bound() {
90 Bound::Unbounded => 0,
91 Bound::Excluded(&v) => v + 1,
92 Bound::Included(&v) => v,
93 };
94 let end = match range.end_bound() {
95 Bound::Unbounded => self.size,
96 Bound::Excluded(&v) => v,
97 Bound::Included(&v) => v + 1,
98 };
99 if start <= end && end <= self.size {
100 (start, end)
101 } else {
102 panic!(
103 "Index out of bounds: the len is {} but the range is {:?}",
104 self.size, range
105 );
106 }
107 }
108
109 #[inline]
111 fn ord(c: char) -> usize {
112 c as usize
113 }
114}
115
116#[derive(Clone)]
117pub struct HashVal<'a> {
118 rolling_hash: &'a RollingHash,
119 length: usize,
120 hash: Modint,
121}
122
123impl<'a> HashVal<'a> {
124 pub fn val(&self) -> Modint {
127 self.hash
128 }
129
130 pub fn chain(&self, other: &HashVal<'a>) -> HashVal<'a> {
133 assert_eq!(
134 self.rolling_hash.base, other.rolling_hash.base,
135 "Cannot chain HashVal with different bases"
136 );
137 Self {
138 rolling_hash: self.rolling_hash,
139 length: self.len() + other.len(),
140 hash: self.val() * self.rolling_hash.power[other.len()] + other.val(),
141 }
142 }
143
144 pub fn len(&self) -> usize {
145 self.length
146 }
147}
148
149impl PartialEq for HashVal<'_> {
150 fn eq(&self, other: &Self) -> bool {
151 assert_eq!(
152 self.rolling_hash.base, other.rolling_hash.base,
153 "Cannot compare HashVal with different bases"
154 );
155 self.len() == other.len() && self.val() == other.val()
156 }
157}
158
159impl<'a> Eq for HashVal<'a> {}
160
161impl Debug for HashVal<'_> {
162 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
163 f.debug_struct("HashVal")
164 .field("length", &self.length)
165 .field("hash", &self.hash)
166 .finish()
167 }
168}