|
| 1 | +macro_rules! modulus { |
| 2 | + ($($name:ident),*) => { |
| 3 | + $( |
| 4 | + #[derive(Copy, Clone, Eq, PartialEq)] |
| 5 | + enum $name {} |
| 6 | + |
| 7 | + impl Modulus for $name { |
| 8 | + const VALUE: u32 = $name as _; |
| 9 | + const HINT_VALUE_IS_PRIME: bool = true; |
| 10 | + |
| 11 | + fn butterfly_cache() -> &'static ::std::thread::LocalKey<::std::cell::RefCell<::std::option::Option<crate::modint::ButterflyCache<Self>>>> { |
| 12 | + thread_local! { |
| 13 | + static BUTTERFLY_CACHE: ::std::cell::RefCell<::std::option::Option<crate::modint::ButterflyCache<$name>>> = ::std::default::Default::default(); |
| 14 | + } |
| 15 | + &BUTTERFLY_CACHE |
| 16 | + } |
| 17 | + } |
| 18 | + )* |
| 19 | + }; |
| 20 | +} |
| 21 | + |
1 | 22 | use crate::{ |
2 | 23 | internal_bit, internal_math, |
3 | 24 | modint::{ButterflyCache, Modulus, RemEuclidU32, StaticModInt}, |
4 | 25 | }; |
5 | 26 | use std::{ |
6 | | - cell::RefCell, |
7 | 27 | cmp, |
8 | 28 | convert::{TryFrom, TryInto as _}, |
9 | 29 | fmt, |
10 | | - thread::LocalKey, |
11 | 30 | }; |
12 | 31 |
|
13 | 32 | #[allow(clippy::many_single_char_names)] |
@@ -77,28 +96,7 @@ pub fn convolution_i64(a: &[i64], b: &[i64]) -> Vec<i64> { |
77 | 96 | const M1M2: u64 = M1 * M2; |
78 | 97 | const M1M2M3: u64 = M1M2.wrapping_mul(M3); |
79 | 98 |
|
80 | | - macro_rules! moduli { |
81 | | - ($($name:ident),*) => { |
82 | | - $( |
83 | | - #[derive(Copy, Clone, Eq, PartialEq)] |
84 | | - enum $name {} |
85 | | - |
86 | | - impl Modulus for $name { |
87 | | - const VALUE: u32 = $name as _; |
88 | | - const HINT_VALUE_IS_PRIME: bool = true; |
89 | | - |
90 | | - fn butterfly_cache() -> &'static LocalKey<RefCell<Option<ButterflyCache<Self>>>> { |
91 | | - thread_local! { |
92 | | - static BUTTERFLY_CACHE: RefCell<Option<ButterflyCache<$name>>> = RefCell::default(); |
93 | | - } |
94 | | - &BUTTERFLY_CACHE |
95 | | - } |
96 | | - } |
97 | | - )* |
98 | | - }; |
99 | | - } |
100 | | - |
101 | | - moduli!(M1, M2, M3); |
| 99 | + modulus!(M1, M2, M3); |
102 | 100 |
|
103 | 101 | if a.is_empty() || b.is_empty() { |
104 | 102 | return vec![]; |
@@ -230,3 +228,85 @@ fn prepare<M: Modulus>() -> ButterflyCache<M> { |
230 | 228 | .collect(); |
231 | 229 | ButterflyCache { sum_e, sum_ie } |
232 | 230 | } |
| 231 | + |
| 232 | +#[cfg(test)] |
| 233 | +mod tests { |
| 234 | + use crate::modint::{Mod998244353, Modulus, StaticModInt}; |
| 235 | + use rand::{rngs::ThreadRng, Rng as _}; |
| 236 | + |
| 237 | + // https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L73-L85 |
| 238 | + #[test] |
| 239 | + fn mid() { |
| 240 | + const N: usize = 1234; |
| 241 | + const M: usize = 2345; |
| 242 | + |
| 243 | + let mut rng = rand::thread_rng(); |
| 244 | + let mut gen_values = |n| gen_values::<Mod998244353>(&mut rng, n); |
| 245 | + let (a, b) = (gen_values(N), gen_values(M)); |
| 246 | + assert_eq!(conv_naive(&a, &b), super::convolution(&a, &b)); |
| 247 | + } |
| 248 | + |
| 249 | + // https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L87-L118 |
| 250 | + #[test] |
| 251 | + fn simple_s_mod() { |
| 252 | + const M1: u32 = 998_244_353; |
| 253 | + const M2: u32 = 924_844_033; |
| 254 | + |
| 255 | + modulus!(M1, M2); |
| 256 | + |
| 257 | + fn test<M: Modulus>(rng: &mut ThreadRng) { |
| 258 | + let mut gen_values = |n| gen_values::<Mod998244353>(rng, n); |
| 259 | + for (n, m) in (1..20).flat_map(|i| (1..20).map(move |j| (i, j))) { |
| 260 | + let (a, b) = (gen_values(n), gen_values(m)); |
| 261 | + assert_eq!(conv_naive(&a, &b), super::convolution(&a, &b)); |
| 262 | + } |
| 263 | + } |
| 264 | + |
| 265 | + let mut rng = rand::thread_rng(); |
| 266 | + test::<M1>(&mut rng); |
| 267 | + test::<M2>(&mut rng); |
| 268 | + } |
| 269 | + |
| 270 | + // https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L358-L371 |
| 271 | + #[test] |
| 272 | + fn conv641() { |
| 273 | + const M: u32 = 641; |
| 274 | + modulus!(M); |
| 275 | + |
| 276 | + let mut rng = rand::thread_rng(); |
| 277 | + let mut gen_values = |n| gen_values::<M>(&mut rng, n); |
| 278 | + let (a, b) = (gen_values(64), gen_values(65)); |
| 279 | + assert_eq!(conv_naive(&a, &b), super::convolution(&a, &b)); |
| 280 | + } |
| 281 | + |
| 282 | + // https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L373-L386 |
| 283 | + #[test] |
| 284 | + fn conv18433() { |
| 285 | + const M: u32 = 18433; |
| 286 | + modulus!(M); |
| 287 | + |
| 288 | + let mut rng = rand::thread_rng(); |
| 289 | + let mut gen_values = |n| gen_values::<M>(&mut rng, n); |
| 290 | + let (a, b) = (gen_values(1024), gen_values(1025)); |
| 291 | + assert_eq!(conv_naive(&a, &b), super::convolution(&a, &b)); |
| 292 | + } |
| 293 | + |
| 294 | + #[allow(clippy::many_single_char_names)] |
| 295 | + fn conv_naive<M: Modulus>( |
| 296 | + a: &[StaticModInt<M>], |
| 297 | + b: &[StaticModInt<M>], |
| 298 | + ) -> Vec<StaticModInt<M>> { |
| 299 | + let (n, m) = (a.len(), b.len()); |
| 300 | + let mut c = vec![StaticModInt::raw(0); n + m - 1]; |
| 301 | + for (i, j) in (0..n).flat_map(|i| (0..m).map(move |j| (i, j))) { |
| 302 | + c[i + j] += a[i] * b[j]; |
| 303 | + } |
| 304 | + c |
| 305 | + } |
| 306 | + |
| 307 | + fn gen_values<M: Modulus>(rng: &mut ThreadRng, n: usize) -> Vec<StaticModInt<M>> { |
| 308 | + (0..n) |
| 309 | + .map(|_| StaticModInt::raw(rng.gen_range(0, M::VALUE))) |
| 310 | + .collect() |
| 311 | + } |
| 312 | +} |
0 commit comments