cp_library_rs/string/
rolling_hash.rs

1//! ローリングハッシュ
2
3#![allow(clippy::len_without_is_empty)]
4
5use std::{
6    fmt::Debug,
7    ops::{Bound, RangeBounds},
8};
9
10use crate::number_theory::modint_for_rollinghash::modint::Modint;
11
12use num_traits::{One, Zero};
13
14/// ローリングハッシュ
15///
16/// 文字列をハッシュし,連続部分列の一致判定を $`O(1)`$ で行う.
17#[derive(Debug)]
18pub struct RollingHash {
19    pub size: usize,
20    power: Vec<Modint>,
21    hash: Vec<Modint>,
22    base: Modint,
23}
24
25impl RollingHash {
26    /// 初期化
27    pub fn build(arr: &[Modint], base: Modint) -> Self {
28        let size = arr.len();
29        let mut power = vec![Modint(1); size + 1];
30        let mut hash = vec![Modint(0); size + 1];
31
32        // hashを初期化
33        let (mut h, mut p) = (Modint::zero(), Modint::one());
34        for i in 0..size {
35            h = arr[i] + (h * base);
36            p *= base;
37            hash[i + 1] = h;
38            power[i + 1] = p;
39        }
40
41        Self {
42            size,
43            power,
44            hash,
45            base,
46        }
47    }
48
49    /// 文字列から生成
50    pub fn from_str(s: &str, base: Modint) -> Self {
51        let arr: Vec<Modint> = s.chars().map(Self::ord).map(Modint).collect();
52        Self::build(&arr, base)
53    }
54
55    /// `l..r`のハッシュを取得
56    /// - 時間計算量: $`O(1)`$
57    pub fn hash<'a, R: RangeBounds<usize> + Debug>(&'a self, range: R) -> HashVal<'a> {
58        let (l, r) = self.parse_range(&range);
59        HashVal {
60            rolling_hash: self,
61            length: r - l,
62            hash: self.hash[r] - self.hash[l] * self.power[r - l],
63        }
64    }
65
66    /// `S[a..]`, `S[b..]` の最長共通接頭辞の長さを調べる
67    /// - 時間計算量: $`O(\log N)`$
68    pub fn get_LCP(&self, a: usize, b: usize) -> usize {
69        let len = self.size.saturating_sub(a.max(b));
70        let (mut lo, mut hi) = (0, len + 1);
71        while hi - lo > 1 {
72            let mid = (lo + hi) / 2;
73            if self.hash(a..a + mid) == self.hash(b..b + mid) {
74                lo = mid;
75            } else {
76                hi = mid;
77            }
78        }
79        lo
80    }
81
82    pub fn len(&self) -> usize {
83        self.size
84    }
85
86    // ========== internal ==========
87    #[inline]
88    fn parse_range<R: RangeBounds<usize> + Debug>(&self, range: &R) -> (usize, usize) {
89        let start = match range.start_bound() {
90            Bound::Unbounded => 0,
91            Bound::Excluded(&v) => v + 1,
92            Bound::Included(&v) => v,
93        };
94        let end = match range.end_bound() {
95            Bound::Unbounded => self.size,
96            Bound::Excluded(&v) => v,
97            Bound::Included(&v) => v + 1,
98        };
99        if start <= end && end <= self.size {
100            (start, end)
101        } else {
102            panic!(
103                "Index out of bounds: the len is {} but the range is {:?}",
104                self.size, range
105            );
106        }
107    }
108
109    /// 文字 c の ASCII コードを返す
110    #[inline]
111    fn ord(c: char) -> usize {
112        c as usize
113    }
114}
115
116#[derive(Clone)]
117pub struct HashVal<'a> {
118    rolling_hash: &'a RollingHash,
119    length: usize,
120    hash: Modint,
121}
122
123impl<'a> HashVal<'a> {
124    /// ハッシュの値を返す
125    /// - 時間計算量: $`O(1)`$
126    pub fn val(&self) -> Modint {
127        self.hash
128    }
129
130    /// ハッシュ同士を連結
131    /// - 時間計算量: $`O(1)`$
132    pub fn chain(&self, other: &HashVal<'a>) -> HashVal<'a> {
133        assert_eq!(
134            self.rolling_hash.base, other.rolling_hash.base,
135            "Cannot chain HashVal with different bases"
136        );
137        Self {
138            rolling_hash: self.rolling_hash,
139            length: self.len() + other.len(),
140            hash: self.val() * self.rolling_hash.power[other.len()] + other.val(),
141        }
142    }
143
144    pub fn len(&self) -> usize {
145        self.length
146    }
147}
148
149impl PartialEq for HashVal<'_> {
150    fn eq(&self, other: &Self) -> bool {
151        assert_eq!(
152            self.rolling_hash.base, other.rolling_hash.base,
153            "Cannot compare HashVal with different bases"
154        );
155        self.len() == other.len() && self.val() == other.val()
156    }
157}
158
159impl<'a> Eq for HashVal<'a> {}
160
161impl Debug for HashVal<'_> {
162    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
163        f.debug_struct("HashVal")
164            .field("length", &self.length)
165            .field("hash", &self.hash)
166            .finish()
167    }
168}