Skip to content

Commit 7c9d0e8

Browse files
committed
Implement generic version of flint_mpn_mulhigh
Currently only available if x86_64 with ADX
1 parent 62f1e9b commit 7c9d0e8

File tree

11 files changed

+482
-273
lines changed

11 files changed

+482
-273
lines changed

src/mpn_extras.h

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -275,19 +275,18 @@ mp_limb_t _flint_mpn_mulhigh_basecase(mp_ptr, mp_srcptr, mp_srcptr, mp_size_t);
275275
mp_limb_t _flint_mpn_sqrhigh_basecase_even(mp_ptr, mp_srcptr, mp_size_t);
276276
mp_limb_t _flint_mpn_sqrhigh_basecase_odd(mp_ptr, mp_srcptr, mp_size_t);
277277

278+
mp_limb_t _flint_mpn_mulhigh(mp_ptr, mp_srcptr, mp_srcptr, mp_size_t);
279+
278280
/* TODO: Proceed with higher cases */
279281
MPN_EXTRAS_INLINE
280-
mp_limb_t flint_mpn_mulhigh_basecase(mp_ptr rp, mp_srcptr xp, mp_srcptr yp, mp_size_t n)
282+
mp_limb_t flint_mpn_mulhigh(mp_ptr rp, mp_srcptr xp, mp_srcptr yp, mp_size_t n)
281283
{
282284
FLINT_ASSERT(n >= 1);
283285

284286
if (FLINT_HAVE_MULHIGH_FUNC(n)) /* NOTE: Aliasing allowed here */
285287
return flint_mpn_mulhigh_func_tab[n](rp, xp, yp);
286288
else
287-
{
288-
FLINT_ASSERT(rp != xp && rp != yp);
289-
return _flint_mpn_mulhigh_basecase(rp, xp, yp, n);
290-
}
289+
return _flint_mpn_mulhigh(rp, xp, yp, n);
291290
}
292291

293292
/* TODO: Proceed with higher cases */
@@ -321,9 +320,7 @@ struct mp_limb_pair_t flint_mpn_mulhigh_normalised(mp_ptr rp, mp_srcptr xp, mp_s
321320

322321
FLINT_ASSERT(rp != xp && rp != yp);
323322

324-
/* TODO */
325-
/* ret.m1 = flint_mpn_mulhigh(rp, xp, yp, n); */
326-
ret.m1 = flint_mpn_mulhigh_basecase(rp, xp, yp, n);
323+
ret.m1 = _flint_mpn_mulhigh(rp, xp, yp, n);
327324

328325
if (rp[n - 1] >> (FLINT_BITS - 1))
329326
{

src/mpn_extras/mulhigh.c

Lines changed: 181 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -9,90 +9,191 @@
99
(at your option) any later version. See <https://www.gnu.org/licenses/>.
1010
*/
1111

12+
#include <string.h> /* For memcpy */
1213
#include "mpn_extras.h"
1314

14-
#if FLINT_HAVE_ADX
15-
mp_limb_t flint_mpn_mulhigh_1(mp_ptr, mp_srcptr, mp_srcptr);
16-
mp_limb_t flint_mpn_mulhigh_2(mp_ptr, mp_srcptr, mp_srcptr);
17-
mp_limb_t flint_mpn_mulhigh_3(mp_ptr, mp_srcptr, mp_srcptr);
18-
mp_limb_t flint_mpn_mulhigh_4(mp_ptr, mp_srcptr, mp_srcptr);
19-
mp_limb_t flint_mpn_mulhigh_5(mp_ptr, mp_srcptr, mp_srcptr);
20-
mp_limb_t flint_mpn_mulhigh_6(mp_ptr, mp_srcptr, mp_srcptr);
21-
mp_limb_t flint_mpn_mulhigh_7(mp_ptr, mp_srcptr, mp_srcptr);
22-
mp_limb_t flint_mpn_mulhigh_8(mp_ptr, mp_srcptr, mp_srcptr);
23-
mp_limb_t flint_mpn_mulhigh_9(mp_ptr, mp_srcptr, mp_srcptr);
24-
mp_limb_t flint_mpn_mulhigh_10(mp_ptr, mp_srcptr, mp_srcptr);
25-
mp_limb_t flint_mpn_mulhigh_11(mp_ptr, mp_srcptr, mp_srcptr);
26-
mp_limb_t flint_mpn_mulhigh_12(mp_ptr, mp_srcptr, mp_srcptr);
27-
28-
struct mp_limb_pair_t flint_mpn_mulhigh_normalised_1(mp_ptr, mp_srcptr, mp_srcptr);
29-
struct mp_limb_pair_t flint_mpn_mulhigh_normalised_2(mp_ptr, mp_srcptr, mp_srcptr);
30-
struct mp_limb_pair_t flint_mpn_mulhigh_normalised_3(mp_ptr, mp_srcptr, mp_srcptr);
31-
struct mp_limb_pair_t flint_mpn_mulhigh_normalised_4(mp_ptr, mp_srcptr, mp_srcptr);
32-
struct mp_limb_pair_t flint_mpn_mulhigh_normalised_5(mp_ptr, mp_srcptr, mp_srcptr);
33-
struct mp_limb_pair_t flint_mpn_mulhigh_normalised_6(mp_ptr, mp_srcptr, mp_srcptr);
34-
struct mp_limb_pair_t flint_mpn_mulhigh_normalised_7(mp_ptr, mp_srcptr, mp_srcptr);
35-
struct mp_limb_pair_t flint_mpn_mulhigh_normalised_8(mp_ptr, mp_srcptr, mp_srcptr);
36-
struct mp_limb_pair_t flint_mpn_mulhigh_normalised_9(mp_ptr, mp_srcptr, mp_srcptr);
37-
struct mp_limb_pair_t flint_mpn_mulhigh_normalised_10(mp_ptr, mp_srcptr, mp_srcptr);
38-
struct mp_limb_pair_t flint_mpn_mulhigh_normalised_11(mp_ptr, mp_srcptr, mp_srcptr);
39-
struct mp_limb_pair_t flint_mpn_mulhigh_normalised_12(mp_ptr, mp_srcptr, mp_srcptr);
40-
41-
mp_limb_t flint_mpn_sqrhigh_1(mp_ptr, mp_srcptr);
42-
mp_limb_t flint_mpn_sqrhigh_2(mp_ptr, mp_srcptr);
43-
mp_limb_t flint_mpn_sqrhigh_3(mp_ptr, mp_srcptr);
44-
mp_limb_t flint_mpn_sqrhigh_4(mp_ptr, mp_srcptr);
45-
mp_limb_t flint_mpn_sqrhigh_5(mp_ptr, mp_srcptr);
46-
mp_limb_t flint_mpn_sqrhigh_6(mp_ptr, mp_srcptr);
47-
mp_limb_t flint_mpn_sqrhigh_7(mp_ptr, mp_srcptr);
48-
mp_limb_t flint_mpn_sqrhigh_8(mp_ptr, mp_srcptr);
49-
50-
const flint_mpn_mul_func_t flint_mpn_mulhigh_func_tab[] =
51-
{
52-
NULL,
53-
flint_mpn_mulhigh_1,
54-
flint_mpn_mulhigh_2,
55-
flint_mpn_mulhigh_3,
56-
flint_mpn_mulhigh_4,
57-
flint_mpn_mulhigh_5,
58-
flint_mpn_mulhigh_6,
59-
flint_mpn_mulhigh_7,
60-
flint_mpn_mulhigh_8,
61-
flint_mpn_mulhigh_9,
62-
flint_mpn_mulhigh_10,
63-
flint_mpn_mulhigh_11,
64-
flint_mpn_mulhigh_12
65-
};
66-
67-
const flint_mpn_mulhigh_normalised_func_t flint_mpn_mulhigh_normalised_func_tab[] =
15+
/*
16+
17+
We will now define what we consider as a high multiplication, and how we go from
18+
basecase to bigger cases where Toom-Cook or Schönhage-Strassen multiplication is
19+
required to outperform `flint_mpn_mul'.
20+
21+
Let {a, n} and {b, n} be two n-limbed (positive) integers, of which the product
22+
of a and b is {c, 2 n}. For some applications, only the higher part is required,
23+
and so we only calculate an "approximation" of the most significant n + 1 limbs.
24+
With a sloppy wording, what this means is that we only calculate the part of the
25+
multiplication that contributes to the most significant n + 1 limbs, so carries
26+
are disregarded. This means that the approximation of c[n - 1, ..., 2 n - 1]
27+
is smaller than the real value, and that the error is at most ~n ULPs in the
28+
least significant limb in the approximation.
29+
30+
With {c, 2 n} denoting the product of {a, n} times {b, n}, let {d, n + 1} denote
31+
the high product. We visualise the high multiplication of two 10-limbed integers
32+
with the following:
33+
34+
0 1 2 3 4 5 6 7 8 9
35+
0 h x
36+
1 h x x
37+
2 h x x x
38+
3 h x x x x
39+
4 h x x x x x
40+
5 h x x x x x x
41+
6 h x x x x x x x
42+
7 h x x x x x x x x
43+
8 h x x x x x x x x x
44+
9 x x x x x x x x x x
45+
46+
Here `h' means that only the higher part of this entry was calculated, and `x'
47+
means that the full product of the limbs was calculated.
48+
49+
To utilise multiplication algorithms that exploit symmetries, we divide this
50+
figure into four different parts:
51+
52+
0 1 2 3 4 5 6 7 8
53+
0 | h x
54+
1 | h x x
55+
2 |h x x x
56+
3 _ _ _ _h|x_x_x_x
57+
n = 9: 4 h x|x x x x
58+
5 h x x|x x x x
59+
6 h x x x|x x x x
60+
7 h x x x x|x x x x
61+
8 x x x x x|x x x x
62+
63+
0 1 2 3 4 5 6 7 8 9
64+
0 | h x
65+
1 | h x x
66+
2 | h x x x
67+
3 |h x x x x
68+
n = 10: 4 _ _ _ _h|x_x_x_x_x
69+
5 h x|x x x x x
70+
6 h x x|x x x x x
71+
7 h x x x|x x x x x
72+
8 h x x x x|x x x x x
73+
9 x x x x x|x x x x x
74+
75+
Observe that we have only one multi-limbed full multiplication, two multi-limbed
76+
high multiplications and one single-limbed high multiplication.
77+
78+
*/
79+
80+
#if FLINT_HAVE_NATIVE_MPN_MULHIGH_BASECASE && FLINT_HAVE_NATIVE_2ADD_N_INPLACE
81+
82+
#if !defined(__amd64__)
83+
# error
84+
#endif
85+
86+
/* NOTE: As we will not reuse factors in mulhigh, we utilize mul instead of mulx
87+
* to save a few bytes. */
88+
#define mulhigh(p, u, v) \
89+
do { \
90+
ulong _scr; \
91+
__asm__("mulq\t%3" \
92+
: "=a" (_scr), "=d" (p) \
93+
: "%0" ((ulong)(u)), "rm" ((ulong)(v))); \
94+
} while (0)
95+
96+
/* NOTE: Assumes the carry never propagates past the end of rp (unbounded loop) */
97+
#define flint_mpn_add_1(rp, x) \
98+
do { \
99+
ulong __rp_save = (rp)[0]; \
100+
(rp)[0] += (x); \
101+
if (__rp_save > (rp)[0]) \
102+
{ \
103+
slong __ix = 0; \
104+
do \
105+
{ \
106+
__ix++; \
107+
(rp)[__ix] += 1; \
108+
} while ((rp)[__ix] == UWORD(0)); \
109+
} \
110+
} while (0)
111+
112+
#define RECURSIVE_THRESHOLD 59
113+
#define _RECURSIVE_THRESHOLD 47
114+
#define FALLBACK_THRESHOLD 330
115+
116+
FLINT_STATIC_NOINLINE
117+
mp_limb_t _flint_mpn_mulhigh_rec(mp_ptr rp, mp_srcptr xp, mp_srcptr yp, mp_size_t n, mp_ptr scr)
68118
{
69-
NULL,
70-
flint_mpn_mulhigh_normalised_1,
71-
flint_mpn_mulhigh_normalised_2,
72-
flint_mpn_mulhigh_normalised_3,
73-
flint_mpn_mulhigh_normalised_4,
74-
flint_mpn_mulhigh_normalised_5,
75-
flint_mpn_mulhigh_normalised_6,
76-
flint_mpn_mulhigh_normalised_7,
77-
flint_mpn_mulhigh_normalised_8,
78-
flint_mpn_mulhigh_normalised_9,
79-
flint_mpn_mulhigh_normalised_10,
80-
flint_mpn_mulhigh_normalised_11,
81-
flint_mpn_mulhigh_normalised_12
82-
};
83-
84-
const flint_mpn_sqr_func_t flint_mpn_sqrhigh_func_tab[] =
119+
if (n < _RECURSIVE_THRESHOLD)
120+
return _flint_mpn_mulhigh_basecase(rp, xp, yp, n);
121+
else
122+
{
123+
mp_size_t np1o2 = (n + 1) / 2;
124+
mp_size_t no2 = n / 2;
125+
mp_limb_t c0, c1 = 0, ret;
126+
mp_ptr hl, hr;
127+
128+
/* Top left */
129+
mulhigh(c0, xp[no2 - 1], yp[np1o2 - 1]);
130+
131+
/* Bottom right */
132+
_flint_mpn_mul(rp, xp + no2, np1o2, yp + np1o2, no2);
133+
134+
/* Bottom left */
135+
hr = scr;
136+
ret = _flint_mpn_mulhigh_rec(hr, xp + no2, yp, np1o2, hr + np1o2);
137+
add_ssaaaa(c1, c0, c1, c0, 0, ret);
138+
139+
/* Top right */
140+
hl = scr + np1o2;
141+
hl[np1o2 - 1] = 0;
142+
ret = _flint_mpn_mulhigh_rec(hl, xp, yp + np1o2, no2, hl + np1o2);
143+
add_ssaaaa(c1, c0, c1, c0, 0, ret);
144+
145+
/* Add c1 to rp */
146+
flint_mpn_add_1(rp, c1);
147+
148+
/* Add both high multiplications to rp */
149+
ret = flint_mpn_2add_n_inplace(rp, hr, hl, np1o2);
150+
151+
/* Add carry from addition to rp */
152+
flint_mpn_add_1(rp + np1o2, ret);
153+
154+
return c0;
155+
}
156+
}
157+
158+
mp_limb_t _flint_mpn_mulhigh(mp_ptr rp, mp_srcptr xp, mp_srcptr yp, mp_size_t n)
85159
{
86-
NULL,
87-
flint_mpn_sqrhigh_1,
88-
flint_mpn_sqrhigh_2,
89-
flint_mpn_sqrhigh_3,
90-
flint_mpn_sqrhigh_4,
91-
flint_mpn_sqrhigh_5,
92-
flint_mpn_sqrhigh_6,
93-
flint_mpn_sqrhigh_7,
94-
flint_mpn_sqrhigh_8
95-
};
160+
FLINT_ASSERT(n > FLINT_MPN_MULHIGH_FUNC_TAB_WIDTH);
161+
162+
if (n < RECURSIVE_THRESHOLD)
163+
{
164+
FLINT_ASSERT(rp != xp && rp != yp);
165+
return _flint_mpn_mulhigh_basecase(rp, xp, yp, n);
166+
}
167+
else if (n < FALLBACK_THRESHOLD)
168+
{
169+
mp_limb_t ret;
170+
mp_ptr scr;
171+
172+
FLINT_ASSERT(rp != xp && rp != yp);
173+
174+
scr = flint_malloc(2 * sizeof(mp_limb_t) * n);
175+
ret = _flint_mpn_mulhigh_rec(rp, xp, yp, n, scr);
176+
flint_free(scr);
177+
178+
return ret;
179+
}
180+
else
181+
{
182+
/* Aliasing is okay */
183+
mp_ptr tmp;
184+
mp_limb_t ret;
185+
186+
tmp = flint_malloc(2 * sizeof(mp_limb_t) * n);
187+
188+
_flint_mpn_mul_n(tmp, xp, yp, n);
189+
memcpy(rp, tmp + n, sizeof(mp_limb_t) * n);
190+
ret = tmp[n - 1];
191+
192+
flint_free(tmp);
193+
194+
return ret;
195+
}
196+
}
96197
#else
97198
typedef int this_file_is_empty;
98199
#endif

0 commit comments

Comments
 (0)