cp_library_rs/convolution/
ntt.rs1use std::{
4 marker::PhantomData,
5 ops::{Add, Div, Mul, MulAssign, Sub},
6};
7
8use crate::number_theory::modint::{Fp, M998};
9
10use num_traits::Zero;
11
12pub trait NTTFriendly<Rhs = Self, Output = Self>:
14 Clone
15 + Copy
16 + Add<Rhs, Output = Output>
17 + Sub<Rhs, Output = Output>
18 + Mul<Rhs, Output = Output>
19 + Div<Rhs, Output = Output>
20 + MulAssign<Rhs>
21 + Zero
22 + From<usize>
23 + Fp
24{
25 fn order() -> usize;
27 fn rem() -> usize;
29 fn root() -> Self;
31 fn root_pow2m(a: usize) -> Self {
33 let p = Self::rem() << (Self::order() - a);
34 Self::root().pow(p)
35 }
36}
37
38impl NTTFriendly for M998 {
39 fn order() -> usize {
40 23
41 }
42 fn rem() -> usize {
43 119
44 }
45 fn root() -> Self {
46 Self(3)
47 }
48}
49
50pub struct FFT<T: NTTFriendly>(PhantomData<T>);
52
53impl<T: NTTFriendly> FFT<T> {
54 pub fn fft(X: &[T]) -> Result<Vec<T>, &'static str> {
56 let (i, X) = Self::extend_array(X)?;
57 let w = T::root_pow2m(i);
58 Ok(Self::fft_core(X, w))
59 }
60
61 pub fn ifft(F: &[T]) -> Result<Vec<T>, &'static str> {
63 let (i, F) = Self::extend_array(F)?;
64 let w = T::root_pow2m(i);
65 let winv = w.inv();
66 let mut res = Self::fft_core(F, winv);
67 let n = res.len();
68 let inv_n = T::from(n).inv();
70 res.iter_mut().for_each(|v| *v *= inv_n);
71 Ok(res)
72 }
73
74 fn fft_core(X: Vec<T>, w: T) -> Vec<T> {
78 let n = X.len();
79
80 if n == 1 {
81 return X.to_vec();
82 }
83
84 let (X_even, X_odd): (Vec<_>, Vec<_>) = (0..n / 2)
85 .map(|i| {
86 let l = X[i];
87 let r = X[i + n / 2];
88 (l + r, w.pow(i) * (l - r))
89 })
90 .unzip();
91
92 let new_w = w.pow(2);
94 let Y_even = Self::fft_core(X_even, new_w);
95 let Y_odd = Self::fft_core(X_odd, new_w);
96
97 Y_even
99 .into_iter()
100 .zip(Y_odd)
101 .flat_map(|(e, o)| [e, o])
102 .collect()
103 }
104
105 fn extend_array(array: &[T]) -> Result<(usize, Vec<T>), &'static str> {
113 let n = array.len();
114 let mut i = 0;
116 let mut n_ = 1;
117 while n_ < n {
118 i += 1;
119 n_ *= 2;
120 }
121
122 if i > T::order() {
123 return Err("The prime p does not have enough factors of 2 in (p - 1).");
124 }
125
126 let mut res = array.to_vec();
128
129 res.extend((0..n_ - n).map(|_| T::zero()));
131
132 Ok((i, res))
133 }
134}