@@ -10,8 +10,7 @@ using namespace std;
1010using namespace cp_algo ;
1111
1212constexpr int mod = 998244353 ;
13- constexpr auto mod4 = u64x4() + mod;
14- constexpr auto imod4 = u64x4() - math::inv2(mod);
13+ constexpr int imod = -math::inv2(mod);
1514
1615void facts_inplace (vector<int > &args) {
1716 constexpr int block = 1 << 16 ;
@@ -26,39 +25,40 @@ void facts_inplace(vector<int> &args) {
2625 args_per_block[(mod - x - 1 ) / block].push_back (i);
2726 }
2827 }
29- uint64_t b2x32 = (1ULL << 32 ) % mod;
28+ uint32_t b2x32 = (1ULL << 32 ) % mod;
3029 uint64_t fact = 1 ;
31- const int K = 4 ;
32- for (uint64_t b = 0 ; b <= limit; b += K * block) {
33- u64x4 cur[K];
34- static array<u64x4, block / 4 > prods[K];
35- for (int z = 0 ; z < K; z++) {
36- for (int j = 0 ; j < 4 ; j++) {
37- cur[z][j] = b + z * block + j * block / 4 ;
30+ const int accum = 4 ;
31+ const int simd_size = 8 ;
32+ for (uint64_t b = 0 ; b <= limit; b += accum * block) {
33+ u32x8 cur[accum];
34+ static array<u32x8, block / simd_size> prods[accum];
35+ for (int z = 0 ; z < accum; z++) {
36+ for (int j = 0 ; j < simd_size; j++) {
37+ cur[z][j] = uint32_t (b + z * block + j * (block / simd_size));
3838 prods[z][0 ][j] = cur[z][j] + !(b || z || j);
3939#pragma GCC diagnostic push
4040#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
41- cur[z][j] = cur[z][j] * b2x32 % mod;
41+ cur[z][j] = uint32_t ( uint64_t ( cur[z][j]) * b2x32 % mod) ;
4242#pragma GCC diagnostic pop
4343 }
4444 }
45- for (int i = 1 ; i < block / 4 ; i++) {
46- for (int z = 0 ; z < K ; z++) {
45+ for (int i = 1 ; i < block / simd_size ; i++) {
46+ for (int z = 0 ; z < accum ; z++) {
4747 cur[z] += b2x32;
4848 cur[z] = cur[z] >= mod ? cur[z] - mod : cur[z];
49- prods[z][i] = montgomery_mul (prods[z][i - 1 ], cur[z], mod4, imod4 );
49+ prods[z][i] = montgomery_mul (prods[z][i - 1 ], cur[z], mod, imod );
5050 }
5151 }
52- for (int z = 0 ; z < K ; z++) {
52+ for (int z = 0 ; z < accum ; z++) {
5353 uint64_t bl = b + z * block;
5454 for (auto i: args_per_block[bl / block]) {
5555 size_t x = args[i];
5656 if (x >= mod / 2 ) {
5757 x = mod - x - 1 ;
5858 }
5959 x -= bl;
60- auto pre_blocks = x / (block / 4 );
61- auto in_block = x % (block / 4 );
60+ auto pre_blocks = x / (block / simd_size );
61+ auto in_block = x % (block / simd_size );
6262 auto ans = fact * prods[z][in_block][pre_blocks] % mod;
6363 for (size_t j = 0 ; j < pre_blocks; j++) {
6464 ans = ans * prods[z].back ()[j] % mod;
@@ -71,7 +71,7 @@ void facts_inplace(vector<int> &args) {
7171 }
7272 }
7373 args_per_block[bl / block].clear ();
74- for (int j = 0 ; j < 4 ; j++) {
74+ for (int j = 0 ; j < simd_size ; j++) {
7575 fact = fact * prods[z].back ()[j] % mod;
7676 }
7777 }
0 commit comments