@@ -231,8 +231,28 @@ fn prepare<M: Modulus>() -> ButterflyCache<M> {
231231
232232#[ cfg( test) ]
233233mod tests {
234- use crate :: modint:: { Mod998244353 , Modulus , StaticModInt } ;
234+ use crate :: {
235+ modint:: { Mod998244353 , Modulus , StaticModInt } ,
236+ RemEuclidU32 ,
237+ } ;
235238 use rand:: { rngs:: ThreadRng , Rng as _} ;
239+ use std:: {
240+ convert:: { TryFrom , TryInto as _} ,
241+ fmt,
242+ } ;
243+
244+ //https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L51-L71
245+ #[ test]
246+ fn empty ( ) {
247+ assert ! ( super :: convolution_raw:: <i32 , Mod998244353 >( & [ ] , & [ ] ) . is_empty( ) ) ;
248+ assert ! ( super :: convolution_raw:: <i32 , Mod998244353 >( & [ ] , & [ 1 , 2 ] ) . is_empty( ) ) ;
249+ assert ! ( super :: convolution_raw:: <i32 , Mod998244353 >( & [ 1 , 2 ] , & [ ] ) . is_empty( ) ) ;
250+ assert ! ( super :: convolution_raw:: <i32 , Mod998244353 >( & [ 1 ] , & [ ] ) . is_empty( ) ) ;
251+ assert ! ( super :: convolution_raw:: <i64 , Mod998244353 >( & [ ] , & [ ] ) . is_empty( ) ) ;
252+ assert ! ( super :: convolution_raw:: <i64 , Mod998244353 >( & [ ] , & [ 1 , 2 ] ) . is_empty( ) ) ;
253+ assert ! ( super :: convolution:: <Mod998244353 >( & [ ] , & [ ] ) . is_empty( ) ) ;
254+ assert ! ( super :: convolution:: <Mod998244353 >( & [ ] , & [ 1 . into( ) , 2 . into( ) ] ) . is_empty( ) ) ;
255+ }
236256
237257 // https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L73-L85
238258 #[ test]
@@ -267,9 +287,119 @@ mod tests {
267287 test :: < M2 > ( & mut rng) ;
268288 }
269289
290+ // https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L120-L150
291+ #[ test]
292+ fn simple_int ( ) {
293+ simple_raw :: < i32 > ( ) ;
294+ }
295+
296+ // https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L152-L182
297+ #[ test]
298+ fn simple_uint ( ) {
299+ simple_raw :: < u32 > ( ) ;
300+ }
301+
302+ // https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L184-L214
303+ #[ test]
304+ fn simple_ll ( ) {
305+ simple_raw :: < i64 > ( ) ;
306+ }
307+
308+ // https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L216-L246
309+ #[ test]
310+ fn simple_ull ( ) {
311+ simple_raw :: < u64 > ( ) ;
312+ }
313+
314+ // https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L249-L279
315+ #[ test]
316+ fn simple_int128 ( ) {
317+ simple_raw :: < i128 > ( ) ;
318+ }
319+
320+ // https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L281-L311
321+ #[ test]
322+ fn simple_uint128 ( ) {
323+ simple_raw :: < u128 > ( ) ;
324+ }
325+
326+ fn simple_raw < T > ( )
327+ where
328+ T : TryFrom < u32 > + Copy + RemEuclidU32 ,
329+ T :: Error : fmt:: Debug ,
330+ {
331+ const M1 : u32 = 998_244_353 ;
332+ const M2 : u32 = 924_844_033 ;
333+
334+ modulus ! ( M1 , M2 ) ;
335+
336+ fn test < T , M > ( rng : & mut ThreadRng )
337+ where
338+ T : TryFrom < u32 > + Copy + RemEuclidU32 ,
339+ T :: Error : fmt:: Debug ,
340+ M : Modulus ,
341+ {
342+ let mut gen_raw_values = |n| gen_raw_values :: < u32 , Mod998244353 > ( rng, n) ;
343+ for ( n, m) in ( 1 ..20 ) . flat_map ( |i| ( 1 ..20 ) . map ( move |j| ( i, j) ) ) {
344+ let ( a, b) = ( gen_raw_values ( n) , gen_raw_values ( m) ) ;
345+ assert_eq ! (
346+ conv_raw_naive:: <_, M >( & a, & b) ,
347+ super :: convolution_raw:: <_, M >( & a, & b) ,
348+ ) ;
349+ }
350+ }
351+
352+ let mut rng = rand:: thread_rng ( ) ;
353+ test :: < T , M1 > ( & mut rng) ;
354+ test :: < T , M2 > ( & mut rng) ;
355+ }
356+
357+ // https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L315-L329
358+ #[ test]
359+ fn conv_ll ( ) {
360+ let mut rng = rand:: thread_rng ( ) ;
361+ for ( n, m) in ( 1 ..20 ) . flat_map ( |i| ( 1 ..20 ) . map ( move |j| ( i, j) ) ) {
362+ let mut gen =
363+ |n : usize | -> Vec < _ > { ( 0 ..n) . map ( |_| rng. gen_range ( -500_000 , 500_000 ) ) . collect ( ) } ;
364+ let ( a, b) = ( gen ( n) , gen ( m) ) ;
365+ assert_eq ! ( conv_i64_naive( & a, & b) , super :: convolution_i64( & a, & b) ) ;
366+ }
367+ }
368+
369+ // https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L331-L356
370+ #[ test]
371+ fn conv_ll_bound ( ) {
372+ const M1 : u64 = 754_974_721 ; // 2^24
373+ const M2 : u64 = 167_772_161 ; // 2^25
374+ const M3 : u64 = 469_762_049 ; // 2^26
375+ const M2M3 : u64 = M2 * M3 ;
376+ const M1M3 : u64 = M1 * M3 ;
377+ const M1M2 : u64 = M1 * M2 ;
378+
379+ modulus ! ( M1 , M2 , M3 ) ;
380+
381+ for i in -1000 ..=1000 {
382+ let a = vec ! [ 0u64 . wrapping_sub( M1M2 + M1M3 + M2M3 ) as i64 + i] ;
383+ let b = vec ! [ 1 ] ;
384+ assert_eq ! ( a, super :: convolution_i64( & a, & b) ) ;
385+ }
386+
387+ for i in 0 ..1000 {
388+ let a = vec ! [ i64 :: min_value( ) + i] ;
389+ let b = vec ! [ 1 ] ;
390+ assert_eq ! ( a, super :: convolution_i64( & a, & b) ) ;
391+ }
392+
393+ for i in 0 ..1000 {
394+ let a = vec ! [ i64 :: max_value( ) - i] ;
395+ let b = vec ! [ 1 ] ;
396+ assert_eq ! ( a, super :: convolution_i64( & a, & b) ) ;
397+ }
398+ }
399+
270400 // https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L358-L371
271401 #[ test]
272- fn conv641 ( ) {
402+ fn conv_641 ( ) {
273403 const M : u32 = 641 ;
274404 modulus ! ( M ) ;
275405
@@ -281,7 +411,7 @@ mod tests {
281411
282412 // https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L373-L386
283413 #[ test]
284- fn conv18433 ( ) {
414+ fn conv_18433 ( ) {
285415 const M : u32 = 18433 ;
286416 modulus ! ( M ) ;
287417
@@ -304,9 +434,43 @@ mod tests {
304434 c
305435 }
306436
437+ fn conv_raw_naive < T , M > ( a : & [ T ] , b : & [ T ] ) -> Vec < T >
438+ where
439+ T : TryFrom < u32 > + Copy + RemEuclidU32 ,
440+ T :: Error : fmt:: Debug ,
441+ M : Modulus ,
442+ {
443+ conv_naive :: < M > (
444+ & a. iter ( ) . copied ( ) . map ( Into :: into) . collect :: < Vec < _ > > ( ) ,
445+ & b. iter ( ) . copied ( ) . map ( Into :: into) . collect :: < Vec < _ > > ( ) ,
446+ )
447+ . into_iter ( )
448+ . map ( |x| x. val ( ) . try_into ( ) . unwrap ( ) )
449+ . collect ( )
450+ }
451+
452+ #[ allow( clippy:: many_single_char_names) ]
453+ fn conv_i64_naive ( a : & [ i64 ] , b : & [ i64 ] ) -> Vec < i64 > {
454+ let ( n, m) = ( a. len ( ) , b. len ( ) ) ;
455+ let mut c = vec ! [ 0 ; n + m - 1 ] ;
456+ for ( i, j) in ( 0 ..n) . flat_map ( |i| ( 0 ..m) . map ( move |j| ( i, j) ) ) {
457+ c[ i + j] += a[ i] * b[ j] ;
458+ }
459+ c
460+ }
461+
307462 fn gen_values < M : Modulus > ( rng : & mut ThreadRng , n : usize ) -> Vec < StaticModInt < M > > {
463+ ( 0 ..n) . map ( |_| rng. gen_range ( 0 , M :: VALUE ) . into ( ) ) . collect ( )
464+ }
465+
466+ fn gen_raw_values < T , M > ( rng : & mut ThreadRng , n : usize ) -> Vec < T >
467+ where
468+ T : TryFrom < u32 > ,
469+ T :: Error : fmt:: Debug ,
470+ M : Modulus ,
471+ {
308472 ( 0 ..n)
309- . map ( |_| StaticModInt :: raw ( rng. gen_range ( 0 , M :: VALUE ) ) )
473+ . map ( |_| rng. gen_range ( 0 , M :: VALUE ) . try_into ( ) . unwrap ( ) )
310474 . collect ( )
311475 }
312476}
0 commit comments