|
8 | 8 | #include "../number-theory/ModularArithmetic.h" |
9 | 9 | #include "FastFourierTransform.h" |
10 | 10 | #include "FastFourierTransformMod.h" |
11 | | -// #include "NumberTheoreticTransform.h" |
| 11 | +#include "NumberTheoreticTransform.h" |
12 | 12 |
|
13 | 13 | typedef Mod num; |
14 | 14 | typedef vector<num> poly; |
15 | 15 | vector<Mod> conv(vector<Mod> a, vector<Mod> b) { |
16 | | - // auto res = convMod<mod>(vl(all(a)), vl(all(b))); |
17 | | - auto res = conv(vl(all(a)), vl(all(b))); |
18 | | - return vector<Mod>(all(res)); |
| 16 | + auto res = convMod<mod>(vl(all(a)), vl(all(b))); |
| 17 | + // auto res = conv(vl(all(a)), vl(all(b))); |
| 18 | + return vector<Mod>(all(res)); |
19 | 19 | } |
20 | 20 | poly &operator+=(poly &a, const poly &b) { |
21 | | - a.resize(max(sz(a), sz(b))); |
22 | | - rep(i, 0, sz(b)) a[i] = a[i] + b[i]; |
23 | | - return a; |
| 21 | + a.resize(max(sz(a), sz(b))); |
| 22 | + rep(i, 0, sz(b)) a[i] = a[i] + b[i]; |
| 23 | + return a; |
24 | 24 | } |
25 | 25 | poly &operator-=(poly &a, const poly &b) { |
26 | | - a.resize(max(sz(a), sz(b))); |
27 | | - rep(i, 0, sz(b)) a[i] = a[i] - b[i]; |
28 | | - return a; |
| 26 | + a.resize(max(sz(a), sz(b))); |
| 27 | + rep(i, 0, sz(b)) a[i] = a[i] - b[i]; |
| 28 | + return a; |
29 | 29 | } |
30 | 30 |
|
31 | 31 | poly &operator*=(poly &a, const poly &b) { |
32 | | - if (sz(a) + sz(b) < 100){ |
33 | | - poly res(sz(a) + sz(b) - 1); |
| 32 | + if (sz(a) + sz(b) < 100){ |
| 33 | + poly res(sz(a) + sz(b) - 1); |
34 | 34 | rep(i,0,sz(a)) rep(j,0,sz(b)) |
35 | 35 | res[i + j] = (res[i + j] + a[i] * b[j]); |
36 | | - return (a = res); |
37 | | - } |
38 | | - return a = conv(a, b); |
| 36 | + return (a = res); |
| 37 | + } |
| 38 | + return a = conv(a, b); |
39 | 39 | } |
40 | 40 | poly operator*(poly a, const num b) { |
41 | | - poly c = a; |
42 | | - trav(i, c) i = i * b; |
43 | | - return c; |
| 41 | + poly c = a; |
| 42 | + trav(i, c) i = i * b; |
| 43 | + return c; |
44 | 44 | } |
45 | 45 | #define OP(o, oe) \ |
46 | | - poly operator o(poly a, poly b) { \ |
47 | | - poly c = a; \ |
48 | | - return c oe b; \ |
49 | | - } |
| 46 | + poly operator o(poly a, poly b) { \ |
| 47 | + poly c = a; \ |
| 48 | + return c oe b; \ |
| 49 | + } |
50 | 50 | OP(*, *=) OP(+, +=) OP(-, -=); |
51 | 51 | poly modK(poly a, int k) { return {a.begin(), a.begin() + min(k, sz(a))}; } |
52 | 52 | poly inverse(poly A) { |
53 | | - poly B = poly({num(1) / A[0]}); |
54 | | - while (sz(B) < sz(A)) |
55 | | - B = modK(B * (poly({num(2)}) - modK(A, 2*sz(B)) * B), 2 * sz(B)); |
56 | | - return modK(B, sz(A)); |
| 53 | + poly B = poly({num(1) / A[0]}); |
| 54 | + while (sz(B) < sz(A)) |
| 55 | + B = modK(B * (poly({num(2)}) - modK(A, 2*sz(B)) * B), 2 * sz(B)); |
| 56 | + return modK(B, sz(A)); |
57 | 57 | } |
58 | 58 | poly &operator/=(poly &a, poly b) { |
59 | | - if (sz(a) < sz(b)) |
60 | | - return a = {}; |
61 | | - int s = sz(a) - sz(b) + 1; |
62 | | - reverse(all(a)), reverse(all(b)); |
63 | | - a.resize(s), b.resize(s); |
64 | | - a = a * inverse(b); |
65 | | - a.resize(s), reverse(all(a)); |
66 | | - return a; |
| 59 | + if (sz(a) < sz(b)) |
| 60 | + return a = {}; |
| 61 | + int s = sz(a) - sz(b) + 1; |
| 62 | + reverse(all(a)), reverse(all(b)); |
| 63 | + a.resize(s), b.resize(s); |
| 64 | + a = a * inverse(b); |
| 65 | + a.resize(s), reverse(all(a)); |
| 66 | + return a; |
67 | 67 | } |
68 | 68 | OP(/, /=) |
69 | 69 | poly &operator%=(poly &a, poly &b) { |
70 | | - if (sz(a) < sz(b)) |
71 | | - return a; |
72 | | - poly c = (a / b) * b; |
73 | | - a.resize(sz(b) - 1); |
74 | | - rep(i, 0, sz(a)) a[i] = a[i] - c[i]; |
75 | | - return a; |
| 70 | + if (sz(a) < sz(b)) |
| 71 | + return a; |
| 72 | + poly c = (a / b) * b; |
| 73 | + a.resize(sz(b) - 1); |
| 74 | + rep(i, 0, sz(a)) a[i] = a[i] - c[i]; |
| 75 | + return a; |
76 | 76 | } |
77 | 77 | OP(%, %=) |
78 | 78 | poly deriv(poly a) { |
79 | | - if (a.empty()) |
80 | | - return {}; |
81 | | - poly b(sz(a) - 1); |
82 | | - rep(i, 1, sz(a)) b[i - 1] = a[i] * num(i); |
83 | | - return b; |
| 79 | + if (a.empty()) |
| 80 | + return {}; |
| 81 | + poly b(sz(a) - 1); |
| 82 | + rep(i, 1, sz(a)) b[i - 1] = a[i] * num(i); |
| 83 | + return b; |
84 | 84 | } |
85 | 85 | poly integr(poly a) { |
86 | | - if (a.empty()) return {0}; |
87 | | - poly b(sz(a) + 1); |
88 | | - b[1] = num(1); |
89 | | - rep(i, 2, sz(b)) b[i] = b[mod%i]*Mod(-mod/i+mod); |
90 | | - rep(i, 1 ,sz(b)) b[i] = a[i-1] * b[i]; |
91 | | - return b; |
| 86 | + if (a.empty()) return {0}; |
| 87 | + poly b(sz(a) + 1); |
| 88 | + b[1] = num(1); |
| 89 | + rep(i, 2, sz(b)) b[i] = b[mod%i]*Mod(-mod/i+mod); |
| 90 | + rep(i, 1 ,sz(b)) b[i] = a[i-1] * b[i]; |
| 91 | + return b; |
92 | 92 | } |
93 | 93 | poly log(poly a) { return modK(integr(deriv(a) * inverse(a)), sz(a)); } |
94 | 94 | poly exp(poly a) { |
95 | | - poly b(1, num(1)); |
96 | | - if (a.empty()) |
97 | | - return b; |
98 | | - while (sz(b) < sz(a)) { |
99 | | - b.resize(sz(b) * 2); |
100 | | - b *= (poly({num(1)}) + modK(a, sz(b)) - log(b)); |
101 | | - b.resize(sz(b) / 2 + 1); |
102 | | - } |
103 | | - return modK(b, sz(a)); |
| 95 | + poly b(1, num(1)); |
| 96 | + if (a.empty()) |
| 97 | + return b; |
| 98 | + while (sz(b) < sz(a)) { |
| 99 | + b.resize(sz(b) * 2); |
| 100 | + b *= (poly({num(1)}) + modK(a, sz(b)) - log(b)); |
| 101 | + b.resize(sz(b) / 2 + 1); |
| 102 | + } |
| 103 | + return modK(b, sz(a)); |
104 | 104 | } |
105 | 105 | poly pow(poly a, ll m) { |
106 | | - int p = 0, n = sz(a); |
107 | | - while (p < sz(a) && a[p].x == 0) |
108 | | - ++p; |
109 | | - if (ll(m)*p >= sz(a)) return poly(sz(a)); |
110 | | - num j = a[p]; |
111 | | - a = {a.begin() + p, a.end()}; |
112 | | - a = a * (num(1) / j); |
113 | | - a.resize(n); |
114 | | - auto res = exp(log(a) * num(m)) * (j ^ m); |
115 | | - res.insert(res.begin(), p*m, 0); |
116 | | - return modK(res, n); |
| 106 | + int p = 0, n = sz(a); |
| 107 | + while (p < sz(a) && a[p].x == 0) |
| 108 | + ++p; |
| 109 | + if (ll(m)*p >= sz(a)) return poly(sz(a)); |
| 110 | + num j = a[p]; |
| 111 | + a = {a.begin() + p, a.end()}; |
| 112 | + a = a * (num(1) / j); |
| 113 | + a.resize(n); |
| 114 | + auto res = exp(log(a) * num(m)) * (j ^ m); |
| 115 | + res.insert(res.begin(), p*m, 0); |
| 116 | + return modK(res, n); |
117 | 117 | } |
118 | 118 |
|
119 | 119 | vector<num> eval(const poly &a, const vector<num> &x) { |
120 | | - int n = sz(x); |
121 | | - if (!n) return {}; |
122 | | - vector<poly> up(2 * n); |
123 | | - rep(i, 0, n) up[i + n] = poly({num(0) - x[i], 1}); |
124 | | - for (int i = n - 1; i > 0; i--) |
125 | | - up[i] = up[2 * i] * up[2 * i + 1]; |
126 | | - vector<poly> down(2 * n); |
127 | | - down[1] = a % up[1]; |
128 | | - rep(i, 2, 2 * n) down[i] = down[i / 2] % up[i]; |
129 | | - vector<num> y(n); |
130 | | - rep(i, 0, n) y[i] = down[i + n][0]; |
131 | | - return y; |
| 120 | + int n = sz(x); |
| 121 | + if (!n) return {}; |
| 122 | + vector<poly> up(2 * n); |
| 123 | + rep(i, 0, n) up[i + n] = poly({num(0) - x[i], 1}); |
| 124 | + for (int i = n - 1; i > 0; i--) |
| 125 | + up[i] = up[2 * i] * up[2 * i + 1]; |
| 126 | + vector<poly> down(2 * n); |
| 127 | + down[1] = a % up[1]; |
| 128 | + rep(i, 2, 2 * n) down[i] = down[i / 2] % up[i]; |
| 129 | + vector<num> y(n); |
| 130 | + rep(i, 0, n) y[i] = down[i + n][0]; |
| 131 | + return y; |
132 | 132 | } |
133 | 133 |
|
134 | 134 | poly interp(vector<num> x, vector<num> y) { |
|
0 commit comments