@@ -38,24 +38,49 @@ namespace cp_algo {
3838 };
3939 }
4040
41- [[gnu::always_inline]] inline u64x4 montgomery_reduce (u64x4 x, u64x4 mod, u64x4 imod) {
42- auto x_ninv = u64x4 (u32x8 (x) * u32x8 (imod));
41+ [[gnu::always_inline]] inline u64x4 montgomery_reduce (u64x4 x, uint32_t mod, uint32_t imod) {
42+ auto x_ninv = u64x4 (u32x8 (x) * ( u32x8 () + imod));
4343#ifdef __AVX2__
44- x += u64x4 (_mm256_mul_epu32 (__m256i (x_ninv), __m256i (mod) ));
44+ x += u64x4 (_mm256_mul_epu32 (__m256i (x_ninv), __m256i () + mod ));
4545#else
4646 x += x_ninv * mod;
4747#endif
4848 return x >> 32 ;
4949 }
5050
51- [[gnu::always_inline]] inline u64x4 montgomery_mul (u64x4 x, u64x4 y, u64x4 mod, u64x4 imod) {
51+ [[gnu::always_inline]] inline u64x4 montgomery_mul (u64x4 x, u64x4 y, uint32_t mod, uint32_t imod) {
5252#ifdef __AVX2__
5353 return montgomery_reduce (u64x4 (_mm256_mul_epu32 (__m256i (x), __m256i (y))), mod, imod);
5454#else
5555 return montgomery_reduce (x * y, mod, imod);
5656#endif
5757 }
5858
59+ u32x8 montgomery_mul (u32x8 x, u32x8 y, uint32_t mod, uint32_t imod) {
60+ auto x0246 = u64x4 (x) & uint32_t (-1 );
61+ auto y0246 = u64x4 (y) & uint32_t (-1 );
62+ auto x1357 = u64x4 (x) >> 32 ;
63+ auto y1357 = u64x4 (y) >> 32 ;
64+ #ifdef __AVX2__
65+ auto xy0246 = u64x4 (_mm256_mul_epu32 (__m256i (x0246), __m256i (y0246)));
66+ auto xy1357 = u64x4 (_mm256_mul_epu32 (__m256i (x1357), __m256i (y1357)));
67+ #else
68+ u64x4 xy0246 = x0246 * y0246;
69+ u64x4 xy1357 = x1357 * y1357;
70+ #endif
71+ auto xy_inv = u64x4 (u32x8 (xy0246 | (xy1357 << 32 )) * (u32x8 () + imod));
72+ auto xy_inv0246 = xy_inv & uint32_t (-1 );
73+ auto xy_inv1357 = xy_inv >> 32 ;
74+ #ifdef __AVX2__
75+ xy0246 += u64x4 (_mm256_mul_epu32 (__m256i (xy_inv0246), __m256i () + mod));
76+ xy1357 += u64x4 (_mm256_mul_epu32 (__m256i (xy_inv1357), __m256i () + mod));
77+ #else
78+ xy0246 += xy_inv0246 * mod;
79+ xy1357 += xy_inv1357 * mod;
80+ #endif
81+ return u32x8 ((xy0246 >> 32 ) | (xy1357 & -1ULL << 32 ));
82+ }
83+
5984 [[gnu::always_inline]] inline dx4 rotate_right (dx4 x) {
6085 static constexpr u64x4 shuffler = {3 , 0 , 1 , 2 };
6186 return __builtin_shuffle (x, shuffler);
0 commit comments