|
1 | 1 | use crate::{ |
2 | 2 | internal_bit, internal_math, |
3 | | - modint::{ButterflyCache, Modulus, StaticModInt}, |
| 3 | + modint::{ButterflyCache, Modulus, RemEuclidU32, StaticModInt}, |
| 4 | +}; |
| 5 | +use std::{ |
| 6 | + cell::RefCell, |
| 7 | + cmp, |
| 8 | + convert::{TryFrom, TryInto as _}, |
| 9 | + fmt, |
| 10 | + thread::LocalKey, |
4 | 11 | }; |
5 | | -use std::{cell::RefCell, cmp, thread::LocalKey}; |
6 | 12 |
|
7 | 13 | #[allow(clippy::many_single_char_names)] |
8 | 14 | pub fn convolution<M: Modulus>( |
@@ -43,6 +49,26 @@ pub fn convolution<M: Modulus>( |
43 | 49 | a |
44 | 50 | } |
45 | 51 |
|
| 52 | +pub fn convolution_raw< |
| 53 | + T: RemEuclidU32 + TryFrom<u32, Error = E> + Clone, |
| 54 | + E: fmt::Debug, |
| 55 | + M: Modulus, |
| 56 | +>( |
| 57 | + a: &[T], |
| 58 | + b: &[T], |
| 59 | +) -> Vec<T> { |
| 60 | + let a = a.iter().cloned().map(Into::into).collect::<Vec<_>>(); |
| 61 | + let b = b.iter().cloned().map(Into::into).collect::<Vec<_>>(); |
| 62 | + convolution::<M>(&a, &b) |
| 63 | + .into_iter() |
| 64 | + .map(|z| { |
| 65 | + z.val() |
| 66 | + .try_into() |
| 67 | + .expect("the numeric type is smaller than the modulus") |
| 68 | + }) |
| 69 | + .collect() |
| 70 | +} |
| 71 | + |
46 | 72 | #[allow(clippy::many_single_char_names)] |
47 | 73 | pub fn convolution_i64(a: &[i64], b: &[i64]) -> Vec<i64> { |
48 | 74 | const M1: u64 = 754_974_721; // 2^24 |
@@ -84,17 +110,9 @@ pub fn convolution_i64(a: &[i64], b: &[i64]) -> Vec<i64> { |
84 | 110 | let i2 = internal_math::inv_gcd(M1M3 as _, M2 as _).1; |
85 | 111 | let i3 = internal_math::inv_gcd(M1M2 as _, M3 as _).1; |
86 | 112 |
|
87 | | - let (c1, c2, c3) = { |
88 | | - fn c<M: Modulus>(a: &[i64], b: &[i64]) -> Vec<i64> { |
89 | | - let a = a.iter().copied().map(Into::into).collect::<Vec<_>>(); |
90 | | - let b = b.iter().copied().map(Into::into).collect::<Vec<_>>(); |
91 | | - convolution::<M>(&a, &b) |
92 | | - .into_iter() |
93 | | - .map(|z| z.val().into()) |
94 | | - .collect() |
95 | | - } |
96 | | - (c::<M1>(a, b), c::<M2>(a, b), c::<M3>(a, b)) |
97 | | - }; |
| 113 | + let c1 = convolution_raw::<i64, _, M1>(a, b); |
| 114 | + let c2 = convolution_raw::<i64, _, M2>(a, b); |
| 115 | + let c3 = convolution_raw::<i64, _, M3>(a, b); |
98 | 116 |
|
99 | 117 | c1.into_iter() |
100 | 118 | .zip(c2) |
|
0 commit comments