cp_library_rs/convolution/
ntt.rs

1//! 数論変換
2
3use std::{
4    marker::PhantomData,
5    ops::{Add, Div, Mul, MulAssign, Sub},
6};
7
8use crate::number_theory::modint::{Fp, M998};
9
10use num_traits::Zero;
11
12/// FFTに必要な関数
13pub trait NTTFriendly<Rhs = Self, Output = Self>:
14    Clone
15    + Copy
16    + Add<Rhs, Output = Output>
17    + Sub<Rhs, Output = Output>
18    + Mul<Rhs, Output = Output>
19    + Div<Rhs, Output = Output>
20    + MulAssign<Rhs>
21    + Zero
22    + From<usize>
23    + Fp
24{
25    /// M = 2^k * m + 1 を満たすような k
26    fn order() -> usize;
27    /// M = 2^k * m + 1 を満たすような m
28    fn rem() -> usize;
29    /// 原始根
30    fn root() -> Self;
31    /// 2^m 乗根
32    fn root_pow2m(a: usize) -> Self {
33        let p = Self::rem() << (Self::order() - a);
34        Self::root().pow(p)
35    }
36}
37
38impl NTTFriendly for M998 {
39    fn order() -> usize {
40        23
41    }
42    fn rem() -> usize {
43        119
44    }
45    fn root() -> Self {
46        Self(3)
47    }
48}
49
50/// 高速フーリエ変換の実装
51pub struct FFT<T: NTTFriendly>(PhantomData<T>);
52
53impl<T: NTTFriendly> FFT<T> {
54    /// 入力された配列をフーリエ変換する
55    pub fn fft(X: &[T]) -> Result<Vec<T>, &'static str> {
56        let (i, X) = Self::extend_array(X)?;
57        let w = T::root_pow2m(i);
58        Ok(Self::fft_core(X, w))
59    }
60
61    /// 入力された配列をフーリエ逆変換する
62    pub fn ifft(F: &[T]) -> Result<Vec<T>, &'static str> {
63        let (i, F) = Self::extend_array(F)?;
64        let w = T::root_pow2m(i);
65        let winv = w.inv();
66        let mut res = Self::fft_core(F, winv);
67        let n = res.len();
68        // 逆変換後の配列を正規化
69        let inv_n = T::from(n).inv();
70        res.iter_mut().for_each(|v| *v *= inv_n);
71        Ok(res)
72    }
73
74    /// フーリエ変換,フーリエ逆変換の共通部分
75    ///
76    /// - `w`: 回転演算子
77    fn fft_core(X: Vec<T>, w: T) -> Vec<T> {
78        let n = X.len();
79
80        if n == 1 {
81            return X.to_vec();
82        }
83
84        let (X_even, X_odd): (Vec<_>, Vec<_>) = (0..n / 2)
85            .map(|i| {
86                let l = X[i];
87                let r = X[i + n / 2];
88                (l + r, w.pow(i) * (l - r))
89            })
90            .unzip();
91
92        // 再帰的にFFT
93        let new_w = w.pow(2);
94        let Y_even = Self::fft_core(X_even, new_w);
95        let Y_odd = Self::fft_core(X_odd, new_w);
96
97        // マージ
98        Y_even
99            .into_iter()
100            .zip(Y_odd)
101            .flat_map(|(e, o)| [e, o])
102            .collect()
103    }
104
105    /// 長さが 2 べきになるように配列を生成する
106    ///
107    /// **Arguments**
108    /// - `array`: 配列
109    ///
110    /// **Returns**
111    /// - `(i, res)`: 配列の長さを 2^i に拡張した結果
112    fn extend_array(array: &[T]) -> Result<(usize, Vec<T>), &'static str> {
113        let n = array.len();
114        // 2^i >= n となるような最小の i
115        let mut i = 0;
116        let mut n_ = 1;
117        while n_ < n {
118            i += 1;
119            n_ *= 2;
120        }
121
122        if i > T::order() {
123            return Err("The prime p does not have enough factors of 2 in (p - 1).");
124        }
125
126        // 配列を生成
127        let mut res = array.to_vec();
128
129        // 残りをゼロ埋め
130        res.extend((0..n_ - n).map(|_| T::zero()));
131
132        Ok((i, res))
133    }
134}