22#define PROBLEM " https://judge.yosupo.jp/problem/many_factorials"
33#pragma GCC optimize("Ofast,unroll-loops")
44#include < bits/stdc++.h>
5- #define CP_ALGO_CHECKPOINT
6- #include " cp-algo/util/checkpoint.hpp"
5+ // #define CP_ALGO_CHECKPOINT
76#include " blazingio/blazingio.min.hpp"
7+ #include " cp-algo/util/checkpoint.hpp"
88#include " cp-algo/util/simd.hpp"
99#include " cp-algo/math/common.hpp"
1010
@@ -14,84 +14,97 @@ using namespace cp_algo;
1414constexpr int mod = 998244353 ;
1515constexpr int imod = -math::inv2(mod);
1616
17- void facts_inplace (vector<int > &args) {
18- constexpr int block = 1 << 16 ;
19- static basic_string<size_t > args_per_block[mod / block];
20- uint64_t limit = 0 ;
21- for (auto [i, x]: args | views::enumerate) {
22- if (x < mod / 2 ) {
23- limit = max (limit, uint64_t (x));
24- args_per_block[x / block].push_back (i);
25- } else {
26- limit = max (limit, uint64_t (mod - x - 1 ));
27- args_per_block[(mod - x - 1 ) / block].push_back (i);
17+ vector<int > facts (vector<int > const & args) {
18+ constexpr int accum = 4 ;
19+ constexpr int simd_size = 8 ;
20+ constexpr int block = 1 << 18 ;
21+ constexpr int subblock = block / simd_size;
22+ static basic_string<array<int , 2 >> odd_args_per_block[mod / subblock];
23+ static basic_string<array<int , 2 >> reg_args_per_block[mod / subblock];
24+ constexpr int limit_reg = mod / 64 ;
25+ int limit_odd = 0 ;
26+
27+ vector<int > res (size (args), 1 );
28+ auto prod_mod = [&](uint64_t a, uint64_t b) {
29+ return (a * b) % mod;
30+ };
31+ for (auto [i, xy]: views::zip (args, res) | views::enumerate) {
32+ auto [x, y] = xy;
33+ auto t = x;
34+ if (t >= mod / 2 ) {
35+ t = mod - t - 1 ;
36+ y = t % 2 ? 1 : mod - 1 ;
37+ }
38+ int pw = 0 ;
39+ while (t > limit_reg) {
40+ limit_odd = max (limit_odd, (t - 1 ) / 2 );
41+ odd_args_per_block[(t - 1 ) / 2 / subblock].push_back ({int (i), (t - 1 ) / 2 });
42+ t /= 2 ;
43+ pw += t;
2844 }
45+ reg_args_per_block[t / subblock].push_back ({int (i), t});
46+ y = int (y * math::bpow (2 , pw, 1ULL , prod_mod) % mod);
2947 }
3048 cp_algo::checkpoint (" init" );
3149 uint32_t b2x32 = (1ULL << 32 ) % mod;
32- uint64_t fact = 1 ;
33- const int accum = 4 ;
34- const int simd_size = 8 ;
35- for (uint64_t b = 0 ; b <= limit; b += accum * block) {
36- u32x8 cur[accum];
37- static array<u32x8, block / simd_size> prods[accum];
38- for (int z = 0 ; z < accum; z++) {
39- for (int j = 0 ; j < simd_size; j++) {
40- cur[z][j] = uint32_t (b + z * block + j * (block / simd_size));
41- prods[z][0 ][j] = cur[z][j] + !(b || z || j);
42- #pragma GCC diagnostic push
43- #pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
44- cur[z][j] = uint32_t (uint64_t (cur[z][j]) * b2x32 % mod);
45- #pragma GCC diagnostic pop
46- }
47- }
48- for (int i = 1 ; i < block / simd_size; i++) {
50+ auto process = [&](int limit, auto &args_per_block, auto step, auto &&proj) {
51+ uint64_t fact = 1 ;
52+ for (int b = 0 ; b <= limit; b += accum * block) {
53+ u32x8 cur[accum];
54+ static array<u32x8, subblock> prods[accum];
4955 for (int z = 0 ; z < accum; z++) {
50- cur[z] += b2x32;
51- cur[z] = cur[z] >= mod ? cur[z] - mod : cur[z];
52- prods[z][i] = montgomery_mul (prods[z][i - 1 ], cur[z], mod, imod);
53- }
54- }
55- cp_algo::checkpoint (" inner loop" );
56- for (int z = 0 ; z < accum; z++) {
57- uint64_t bl = b + z * block;
58- for (auto i: args_per_block[bl / block]) {
59- size_t x = args[i];
60- if (x >= mod / 2 ) {
61- x = mod - x - 1 ;
62- }
63- x -= bl;
64- auto pre_blocks = x / (block / simd_size);
65- auto in_block = x % (block / simd_size);
66- auto ans = fact * prods[z][in_block][pre_blocks] % mod;
67- for (size_t j = 0 ; j < pre_blocks; j++) {
68- ans = ans * prods[z].back ()[j] % mod;
56+ for (int j = 0 ; j < simd_size; j++) {
57+ cur[z][j] = uint32_t (b + z * block + j * subblock);
58+ cur[z][j] = proj (cur[z][j]);
59+ prods[z][0 ][j] = cur[z][j] + !cur[z][j];
60+ #pragma GCC diagnostic push
61+ #pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
62+ cur[z][j] = uint32_t (uint64_t (cur[z][j]) * b2x32 % mod);
63+ #pragma GCC diagnostic pop
6964 }
70- if (args[i] >= mod / 2 ) {
71- ans = math::bpow (ans, mod - 2 , 1ULL , [](auto a, auto b){return a * b % mod;});
72- args[i] = int (x % 2 ? ans : mod - ans);
73- } else {
74- args[i] = int (ans);
65+ }
66+ for (int i = 1 ; i < block / simd_size; i++) {
67+ for (int z = 0 ; z < accum; z++) {
68+ cur[z] += step;
69+ cur[z] = cur[z] >= mod ? cur[z] - mod : cur[z];
70+ prods[z][i] = montgomery_mul (prods[z][i - 1 ], cur[z], mod, imod);
7571 }
7672 }
77- args_per_block[bl / block].clear ();
78- for (int j = 0 ; j < simd_size; j++) {
79- fact = fact * prods[z].back ()[j] % mod;
73+ cp_algo::checkpoint (" inner loop" );
74+ for (int z = 0 ; z < accum; z++) {
75+ for (int j = 0 ; j < simd_size; j++) {
76+ int bl = b + z * block + j * subblock;
77+ for (auto [i, x]: args_per_block[bl / subblock]) {
78+ auto ans = fact * prods[z][x - bl][j] % mod;
79+ res[i] = int (res[i] * ans % mod);
80+ }
81+ fact = fact * prods[z].back ()[j] % mod;
82+ }
8083 }
84+ cp_algo::checkpoint (" mul ans" );
85+ }
86+ };
87+ uint32_t b2x33 = (1ULL << 33 ) % mod;
88+ process (limit_reg, reg_args_per_block, b2x32, identity{});
89+ process (limit_odd, odd_args_per_block, b2x33, [](uint32_t x) {return 2 * x + 1 ;});
90+ for (auto [i, x]: res | views::enumerate) {
91+ if (args[i] >= mod / 2 ) {
92+ x = int (math::bpow (x, mod - 2 , 1ULL , prod_mod));
8193 }
82- cp_algo::checkpoint (" write ans" );
8394 }
95+ cp_algo::checkpoint (" inv ans" );
96+ return res;
8497}
8598
8699void solve () {
87100 int n;
88101 cin >> n;
89102 vector<int > args (n);
90103 for (auto &x : args) {cin >> x;}
91- cp_algo::checkpoint (" input read" );
92- facts_inplace (args);
93- for (auto it: args ) {cout << it << " \n " ;}
94- cp_algo::checkpoint (" output written " );
104+ cp_algo::checkpoint (" read" );
105+ auto res = facts (args);
106+ for (auto it: res ) {cout << it << " \n " ;}
107+ cp_algo::checkpoint (" write " );
95108 cp_algo::checkpoint<1 >();
96109}
97110
0 commit comments