@@ -30,7 +30,7 @@ namespace cp_algo::math::fft {
3030 }
3131
3232 dft (size_t n): A(n), B(n) {init ();} // Size-only constructor: builds members A and B from n, then calls init(); no input data is read here.
33- dft (auto const & a, size_t n): A(n), B(n) {
33+ dft (auto const & a, size_t n, bool partial = true ): A(n), B(n) {
3434 init ();
3535 base b2x32 = bpow (base (2 ), 32 );
3636 u64x4 cur = {
@@ -66,35 +66,47 @@ namespace cp_algo::math::fft {
6666 }
6767 checkpoint (" dft init" );
6868 if (n) {
69- A.fft ();
70- B.fft ();
69+ if (partial) {
70+ A.fft ();
71+ B.fft ();
72+ } else {
73+ A.template fft <false >();
74+ B.template fft <false >();
75+ }
7176 }
7277 }
73- template <bool overwrite = true >
78+ template <bool overwrite = true , bool partial = true >
7479 void dot (auto const & C, auto const & D, auto &Aout, auto &Bout, auto &Cout) const {
7580 cvector::exec_on_evals<1 >(A.size () / flen, [&](size_t k, point rt) {
7681 k *= flen;
77- auto [Ax, Ay] = A.at (k);
78- auto [Bx, By] = B.at (k);
7982 vpoint AC, AD, BC, BD;
8083 AC = AD = BC = BD = vz;
8184 auto Cv = C.at (k), Dv = D.at (k);
82- for (size_t i = 0 ; i < flen; i++) {
83- vpoint Av = {vz + Ax[i], vz + Ay[i]}, Bv = {vz + Bx[i], vz + By[i]};
84- AC += Av * Cv; AD += Av * Dv;
85- BC += Bv * Cv; BD += Bv * Dv;
86- real (Cv) = rotate_right (real (Cv));
87- imag (Cv) = rotate_right (imag (Cv));
88- real (Dv) = rotate_right (real (Dv));
89- imag (Dv) = rotate_right (imag (Dv));
90- auto cx = real (Cv)[0 ], cy = imag (Cv)[0 ];
91- auto dx = real (Dv)[0 ], dy = imag (Dv)[0 ];
92- real (Cv)[0 ] = cx * real (rt) - cy * imag (rt);
93- imag (Cv)[0 ] = cx * imag (rt) + cy * real (rt);
94- real (Dv)[0 ] = dx * real (rt) - dy * imag (rt);
95- imag (Dv)[0 ] = dx * imag (rt) + dy * real (rt);
85+ if constexpr (partial) {
86+ auto [Ax, Ay] = A.at (k);
87+ auto [Bx, By] = B.at (k);
88+ for (size_t i = 0 ; i < flen; i++) {
89+ vpoint Av = {vz + Ax[i], vz + Ay[i]}, Bv = {vz + Bx[i], vz + By[i]};
90+ AC += Av * Cv; AD += Av * Dv;
91+ BC += Bv * Cv; BD += Bv * Dv;
92+ real (Cv) = rotate_right (real (Cv));
93+ imag (Cv) = rotate_right (imag (Cv));
94+ real (Dv) = rotate_right (real (Dv));
95+ imag (Dv) = rotate_right (imag (Dv));
96+ auto cx = real (Cv)[0 ], cy = imag (Cv)[0 ];
97+ auto dx = real (Dv)[0 ], dy = imag (Dv)[0 ];
98+ real (Cv)[0 ] = cx * real (rt) - cy * imag (rt);
99+ imag (Cv)[0 ] = cx * imag (rt) + cy * real (rt);
100+ real (Dv)[0 ] = dx * real (rt) - dy * imag (rt);
101+ imag (Dv)[0 ] = dx * imag (rt) + dy * real (rt);
102+ }
103+ } else {
104+ AC = A.at (k) * Cv;
105+ AD = A.at (k) * Dv;
106+ BC = B.at (k) * Cv;
107+ BD = B.at (k) * Dv;
96108 }
97- if (overwrite) {
109+ if constexpr (overwrite) {
98110 Aout.at (k) = AC;
99111 Cout.at (k) = AD + BC;
100112 Bout.at (k) = BD;
0 commit comments