Skip to content

Commit 54c5455

Browse files
Merge pull request #2001 from albinahlback/sqrhigh_normalized
Add flint_mpn_sqrhigh_normalised
2 parents b7cf6d6 + 9e9d2c5 commit 54c5455

File tree

7 files changed

+900
-71
lines changed

7 files changed

+900
-71
lines changed

src/mpn_extras.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -549,20 +549,25 @@ flint_mpn_sqr(mp_ptr r, mp_srcptr x, mp_size_t n)
549549
#define FLINT_HAVE_MULHIGH_FUNC(n) ((n) <= FLINT_MPN_MULHIGH_FUNC_TAB_WIDTH)
550550
#define FLINT_HAVE_SQRHIGH_FUNC(n) ((n) <= FLINT_MPN_SQRHIGH_FUNC_TAB_WIDTH)
551551
#define FLINT_HAVE_MULHIGH_NORMALISED_FUNC(n) ((n) <= FLINT_MPN_MULHIGH_NORMALISED_FUNC_TAB_WIDTH)
552+
#define FLINT_HAVE_SQRHIGH_NORMALISED_FUNC(n) ((n) <= FLINT_MPN_SQRHIGH_NORMALISED_FUNC_TAB_WIDTH)
552553

553554
typedef struct { mp_limb_t m1; mp_limb_t m2; } mp_limb_pair_t;
555+
typedef mp_limb_pair_t (* flint_mpn_sqrhigh_normalised_func_t)(mp_ptr, mp_srcptr);
554556
typedef mp_limb_pair_t (* flint_mpn_mulhigh_normalised_func_t)(mp_ptr, mp_srcptr, mp_srcptr);
555557

556558
FLINT_DLL extern const flint_mpn_mul_func_t flint_mpn_mullow_func_tab[];
557559
FLINT_DLL extern const flint_mpn_mul_func_t flint_mpn_mulhigh_func_tab[];
558560
FLINT_DLL extern const flint_mpn_sqr_func_t flint_mpn_sqrhigh_func_tab[];
559561
FLINT_DLL extern const flint_mpn_mulhigh_normalised_func_t flint_mpn_mulhigh_normalised_func_tab[];
562+
FLINT_DLL extern const flint_mpn_sqrhigh_normalised_func_t flint_mpn_sqrhigh_normalised_func_tab[];
560563

561564
#if FLINT_HAVE_ASSEMBLY_x86_64_adx
562565
# define FLINT_MPN_MULLOW_FUNC_TAB_WIDTH 8
563566
# define FLINT_MPN_MULHIGH_FUNC_TAB_WIDTH 9
564567
# define FLINT_MPN_SQRHIGH_FUNC_TAB_WIDTH 8
565568
# define FLINT_MPN_MULHIGH_NORMALISED_FUNC_TAB_WIDTH 9
569+
# define FLINT_MPN_SQRHIGH_NORMALISED_FUNC_TAB_WIDTH 8
570+
566571
# define FLINT_HAVE_NATIVE_mpn_mullow_basecase 1
567572
/* NOTE: This function only works for n >= 6 */
568573
# define FLINT_HAVE_NATIVE_mpn_mulhigh_basecase 1
@@ -574,6 +579,8 @@ FLINT_DLL extern const flint_mpn_mulhigh_normalised_func_t flint_mpn_mulhigh_nor
574579
# define FLINT_MPN_MULHIGH_FUNC_TAB_WIDTH 8
575580
# define FLINT_MPN_SQRHIGH_FUNC_TAB_WIDTH 8
576581
# define FLINT_MPN_MULHIGH_NORMALISED_FUNC_TAB_WIDTH 0
582+
# define FLINT_MPN_SQRHIGH_NORMALISED_FUNC_TAB_WIDTH 0
583+
577584
/* NOTE: This function only works for n > 8 */
578585
# define FLINT_HAVE_NATIVE_mpn_mulhigh_basecase 1
579586

@@ -583,6 +590,7 @@ FLINT_DLL extern const flint_mpn_mulhigh_normalised_func_t flint_mpn_mulhigh_nor
583590
# define FLINT_MPN_MULHIGH_FUNC_TAB_WIDTH 16
584591
# define FLINT_MPN_SQRHIGH_FUNC_TAB_WIDTH 2
585592
# define FLINT_MPN_MULHIGH_NORMALISED_FUNC_TAB_WIDTH 0
593+
# define FLINT_MPN_SQRHIGH_NORMALISED_FUNC_TAB_WIDTH 0
586594

587595
#endif
588596

@@ -715,6 +723,19 @@ mp_limb_pair_t flint_mpn_mulhigh_normalised(mp_ptr rp, mp_srcptr xp, mp_srcptr y
715723
return _flint_mpn_mulhigh_normalised(rp, xp, yp, n);
716724
}
717725

726+
mp_limb_pair_t _flint_mpn_sqrhigh_normalised(mp_ptr rp, mp_srcptr xp, mp_size_t n);
727+
728+
MPN_EXTRAS_INLINE
729+
mp_limb_pair_t flint_mpn_sqrhigh_normalised(mp_ptr rp, mp_srcptr xp, mp_size_t n)
730+
{
731+
FLINT_ASSERT(n >= 1);
732+
733+
if (FLINT_HAVE_SQRHIGH_NORMALISED_FUNC(n))
734+
return flint_mpn_sqrhigh_normalised_func_tab[n](rp, xp);
735+
else
736+
return _flint_mpn_sqrhigh_normalised(rp, xp, n);
737+
}
738+
718739
/* division ******************************************************************/
719740

720741
#if FLINT_HAVE_NATIVE_mpn_modexact_1_odd

src/mpn_extras/mulhigh_basecase.c

Lines changed: 0 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -33,15 +33,6 @@ mp_limb_pair_t flint_mpn_mulhigh_normalised_7(mp_ptr, mp_srcptr, mp_srcptr);
3333
mp_limb_pair_t flint_mpn_mulhigh_normalised_8(mp_ptr, mp_srcptr, mp_srcptr);
3434
mp_limb_pair_t flint_mpn_mulhigh_normalised_9(mp_ptr, mp_srcptr, mp_srcptr);
3535

36-
mp_limb_t flint_mpn_sqrhigh_1(mp_ptr, mp_srcptr);
37-
mp_limb_t flint_mpn_sqrhigh_2(mp_ptr, mp_srcptr);
38-
mp_limb_t flint_mpn_sqrhigh_3(mp_ptr, mp_srcptr);
39-
mp_limb_t flint_mpn_sqrhigh_4(mp_ptr, mp_srcptr);
40-
mp_limb_t flint_mpn_sqrhigh_5(mp_ptr, mp_srcptr);
41-
mp_limb_t flint_mpn_sqrhigh_6(mp_ptr, mp_srcptr);
42-
mp_limb_t flint_mpn_sqrhigh_7(mp_ptr, mp_srcptr);
43-
mp_limb_t flint_mpn_sqrhigh_8(mp_ptr, mp_srcptr);
44-
4536
const flint_mpn_mul_func_t flint_mpn_mulhigh_func_tab[] =
4637
{
4738
NULL,
@@ -69,19 +60,6 @@ const flint_mpn_mulhigh_normalised_func_t flint_mpn_mulhigh_normalised_func_tab[
6960
flint_mpn_mulhigh_normalised_8,
7061
flint_mpn_mulhigh_normalised_9
7162
};
72-
73-
const flint_mpn_sqr_func_t flint_mpn_sqrhigh_func_tab[] =
74-
{
75-
NULL,
76-
flint_mpn_sqrhigh_1,
77-
flint_mpn_sqrhigh_2,
78-
flint_mpn_sqrhigh_3,
79-
flint_mpn_sqrhigh_4,
80-
flint_mpn_sqrhigh_5,
81-
flint_mpn_sqrhigh_6,
82-
flint_mpn_sqrhigh_7,
83-
flint_mpn_sqrhigh_8
84-
};
8563
#elif FLINT_HAVE_ASSEMBLY_armv8
8664
mp_limb_t flint_mpn_mulhigh_1(mp_ptr, mp_srcptr, mp_srcptr);
8765
mp_limb_t flint_mpn_mulhigh_2(mp_ptr, mp_srcptr, mp_srcptr);
@@ -92,15 +70,6 @@ mp_limb_t flint_mpn_mulhigh_6(mp_ptr, mp_srcptr, mp_srcptr);
9270
mp_limb_t flint_mpn_mulhigh_7(mp_ptr, mp_srcptr, mp_srcptr);
9371
mp_limb_t flint_mpn_mulhigh_8(mp_ptr, mp_srcptr, mp_srcptr);
9472

95-
mp_limb_t flint_mpn_sqrhigh_1(mp_ptr, mp_srcptr);
96-
mp_limb_t flint_mpn_sqrhigh_2(mp_ptr, mp_srcptr);
97-
mp_limb_t flint_mpn_sqrhigh_3(mp_ptr, mp_srcptr);
98-
mp_limb_t flint_mpn_sqrhigh_4(mp_ptr, mp_srcptr);
99-
mp_limb_t flint_mpn_sqrhigh_5(mp_ptr, mp_srcptr);
100-
mp_limb_t flint_mpn_sqrhigh_6(mp_ptr, mp_srcptr);
101-
mp_limb_t flint_mpn_sqrhigh_7(mp_ptr, mp_srcptr);
102-
mp_limb_t flint_mpn_sqrhigh_8(mp_ptr, mp_srcptr);
103-
10473
const flint_mpn_mul_func_t flint_mpn_mulhigh_func_tab[] =
10574
{
10675
NULL,
@@ -118,25 +87,10 @@ const flint_mpn_mulhigh_normalised_func_t flint_mpn_mulhigh_normalised_func_tab[
11887
{
11988
NULL,
12089
};
121-
122-
const flint_mpn_sqr_func_t flint_mpn_sqrhigh_func_tab[] =
123-
{
124-
NULL,
125-
flint_mpn_sqrhigh_1,
126-
flint_mpn_sqrhigh_2,
127-
flint_mpn_sqrhigh_3,
128-
flint_mpn_sqrhigh_4,
129-
flint_mpn_sqrhigh_5,
130-
flint_mpn_sqrhigh_6,
131-
flint_mpn_sqrhigh_7,
132-
flint_mpn_sqrhigh_8,
133-
};
13490
#else
13591

13692
/* todo: add MPFR-like basecase for use in mulders */
13793
/* todo: squaring code */
138-
/* todo: define the generic basecase also on x86_64_adx,
139-
and use to test the assembly versions */
14094

14195
mp_limb_t _flint_mpn_mulhigh_basecase(mp_ptr res, mp_srcptr u, mp_srcptr v, mp_size_t n)
14296
{
@@ -428,28 +382,4 @@ const flint_mpn_mulhigh_normalised_func_t flint_mpn_mulhigh_normalised_func_tab[
428382
{
429383
NULL,
430384
};
431-
432-
mp_limb_t flint_mpn_sqrhigh_1(mp_ptr res, mp_srcptr u)
433-
{
434-
mp_limb_t low;
435-
umul_ppmm(res[0], low, u[0], u[0]);
436-
return low;
437-
}
438-
439-
/* todo */
440-
mp_limb_t flint_mpn_sqrhigh_2(mp_ptr res, mp_srcptr u)
441-
{
442-
mp_limb_t b, low;
443-
FLINT_MPN_MUL_2X2(res[1], res[0], low, b, u[1], u[0], u[1], u[0]);
444-
return low;
445-
}
446-
447-
/* todo: higher cases */
448-
449-
const flint_mpn_sqr_func_t flint_mpn_sqrhigh_func_tab[] = {
450-
NULL,
451-
flint_mpn_sqrhigh_1,
452-
flint_mpn_sqrhigh_2,
453-
};
454-
455385
#endif

src/mpn_extras/sqrhigh.c

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,3 +175,27 @@ _flint_mpn_sqrhigh(mp_ptr res, mp_srcptr u, mp_size_t n)
175175
else
176176
return _flint_mpn_sqrhigh_sqr(res, u, n);
177177
}
178+
179+
mp_limb_pair_t _flint_mpn_sqrhigh_normalised(mp_ptr rp, mp_srcptr xp, mp_size_t n)
180+
{
181+
mp_limb_pair_t ret;
182+
183+
FLINT_ASSERT(n >= 1);
184+
FLINT_ASSERT(rp != xp);
185+
186+
ret.m1 = flint_mpn_sqrhigh(rp, xp, n);
187+
188+
if (rp[n - 1] >> (FLINT_BITS - 1))
189+
{
190+
ret.m2 = 0;
191+
}
192+
else
193+
{
194+
ret.m2 = 1;
195+
mpn_lshift(rp, rp, n, 1);
196+
rp[0] |= (ret.m1 >> (FLINT_BITS - 1));
197+
ret.m1 <<= 1;
198+
}
199+
200+
return ret;
201+
}

src/mpn_extras/sqrhigh_basecase.c

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
/*
2+
Copyright (C) 2024 Albin Ahlbäck
3+
4+
This file is part of FLINT.
5+
6+
FLINT is free software: you can redistribute it and/or modify it under
7+
the terms of the GNU Lesser General Public License (LGPL) as published
8+
by the Free Software Foundation; either version 3 of the License, or
9+
(at your option) any later version. See <https://www.gnu.org/licenses/>.
10+
*/
11+
12+
#include "mpn_extras.h"
13+
14+
#if FLINT_HAVE_ASSEMBLY_x86_64_adx || FLINT_HAVE_ASSEMBLY_armv8
15+
mp_limb_t flint_mpn_sqrhigh_1(mp_ptr, mp_srcptr);
16+
mp_limb_t flint_mpn_sqrhigh_2(mp_ptr, mp_srcptr);
17+
mp_limb_t flint_mpn_sqrhigh_3(mp_ptr, mp_srcptr);
18+
mp_limb_t flint_mpn_sqrhigh_4(mp_ptr, mp_srcptr);
19+
mp_limb_t flint_mpn_sqrhigh_5(mp_ptr, mp_srcptr);
20+
mp_limb_t flint_mpn_sqrhigh_6(mp_ptr, mp_srcptr);
21+
mp_limb_t flint_mpn_sqrhigh_7(mp_ptr, mp_srcptr);
22+
mp_limb_t flint_mpn_sqrhigh_8(mp_ptr, mp_srcptr);
23+
24+
const flint_mpn_sqr_func_t flint_mpn_sqrhigh_func_tab[] =
25+
{
26+
NULL,
27+
flint_mpn_sqrhigh_1,
28+
flint_mpn_sqrhigh_2,
29+
flint_mpn_sqrhigh_3,
30+
flint_mpn_sqrhigh_4,
31+
flint_mpn_sqrhigh_5,
32+
flint_mpn_sqrhigh_6,
33+
flint_mpn_sqrhigh_7,
34+
flint_mpn_sqrhigh_8
35+
};
36+
#else
37+
mp_limb_t flint_mpn_sqrhigh_1(mp_ptr res, mp_srcptr u)
38+
{
39+
mp_limb_t low;
40+
umul_ppmm(res[0], low, u[0], u[0]);
41+
return low;
42+
}
43+
44+
/* todo */
45+
mp_limb_t flint_mpn_sqrhigh_2(mp_ptr res, mp_srcptr u)
46+
{
47+
mp_limb_t b, low;
48+
FLINT_MPN_MUL_2X2(res[1], res[0], low, b, u[1], u[0], u[1], u[0]);
49+
return low;
50+
}
51+
52+
/* todo: higher cases */
53+
54+
const flint_mpn_sqr_func_t flint_mpn_sqrhigh_func_tab[] = {
55+
NULL,
56+
flint_mpn_sqrhigh_1,
57+
flint_mpn_sqrhigh_2,
58+
};
59+
#endif
60+
61+
#if FLINT_HAVE_ASSEMBLY_x86_64_adx
62+
mp_limb_pair_t flint_mpn_sqrhigh_normalised_1(mp_ptr, mp_srcptr);
63+
mp_limb_pair_t flint_mpn_sqrhigh_normalised_2(mp_ptr, mp_srcptr);
64+
mp_limb_pair_t flint_mpn_sqrhigh_normalised_3(mp_ptr, mp_srcptr);
65+
mp_limb_pair_t flint_mpn_sqrhigh_normalised_4(mp_ptr, mp_srcptr);
66+
mp_limb_pair_t flint_mpn_sqrhigh_normalised_5(mp_ptr, mp_srcptr);
67+
mp_limb_pair_t flint_mpn_sqrhigh_normalised_6(mp_ptr, mp_srcptr);
68+
mp_limb_pair_t flint_mpn_sqrhigh_normalised_7(mp_ptr, mp_srcptr);
69+
mp_limb_pair_t flint_mpn_sqrhigh_normalised_8(mp_ptr, mp_srcptr);
70+
71+
const flint_mpn_sqrhigh_normalised_func_t flint_mpn_sqrhigh_normalised_func_tab[] =
72+
{
73+
NULL,
74+
flint_mpn_sqrhigh_normalised_1,
75+
flint_mpn_sqrhigh_normalised_2,
76+
flint_mpn_sqrhigh_normalised_3,
77+
flint_mpn_sqrhigh_normalised_4,
78+
flint_mpn_sqrhigh_normalised_5,
79+
flint_mpn_sqrhigh_normalised_6,
80+
flint_mpn_sqrhigh_normalised_7,
81+
flint_mpn_sqrhigh_normalised_8
82+
};
83+
#else
84+
const flint_mpn_sqrhigh_normalised_func_t flint_mpn_sqrhigh_normalised_func_tab[] =
85+
{
86+
NULL
87+
};
88+
#endif

src/mpn_extras/test/main.c

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
#include "t-remove_power.c"
3232
#include "t-sqr.c"
3333
#include "t-sqrhigh.c"
34+
#include "t-sqrhigh_normalised.c"
3435

3536
/* Array of test functions ***************************************************/
3637

@@ -55,7 +56,8 @@ test_struct tests[] =
5556
TEST_FUNCTION(flint_mpn_remove_2exp),
5657
TEST_FUNCTION(flint_mpn_remove_power),
5758
TEST_FUNCTION(flint_mpn_sqr),
58-
TEST_FUNCTION(flint_mpn_sqrhigh)
59+
TEST_FUNCTION(flint_mpn_sqrhigh),
60+
TEST_FUNCTION(flint_mpn_sqrhigh_normalised)
5961
};
6062

6163
/* main function *************************************************************/
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
/*
2+
Copyright (C) 2024 Albin Ahlbäck
3+
4+
This file is part of FLINT.
5+
6+
FLINT is free software: you can redistribute it and/or modify it under
7+
the terms of the GNU Lesser General Public License (LGPL) as published
8+
by the Free Software Foundation; either version 3 of the License, or
9+
(at your option) any later version. See <https://www.gnu.org/licenses/>.
10+
*/
11+
12+
#include "test_helpers.h"
13+
#include "mpn_extras.h"
14+
15+
#define N_MIN 1
16+
#define N_MAX 64
17+
18+
TEST_FUNCTION_START(flint_mpn_sqrhigh_normalised, state)
19+
{
20+
slong ix;
21+
int result;
22+
23+
for (ix = 0; ix < 10000 * flint_test_multiplier(); ix++)
24+
{
25+
mp_ptr rp_n, rp_u, xp;
26+
mp_size_t n;
27+
mp_limb_pair_t res_norm;
28+
mp_limb_t retlimb, normalised;
29+
30+
n = N_MIN + n_randint(state, N_MAX - N_MIN + 1);
31+
32+
rp_n = flint_malloc(sizeof(mp_limb_t) * (n + 1));
33+
rp_u = flint_malloc(sizeof(mp_limb_t) * (n + 1));
34+
xp = flint_malloc(sizeof(mp_limb_t) * n);
35+
36+
flint_mpn_rrandom(xp, state, n);
37+
xp[n - 1] |= (UWORD(1) << (FLINT_BITS - 1));
38+
39+
rp_u[0] = flint_mpn_sqrhigh(rp_u + 1, xp, n);
40+
res_norm = flint_mpn_sqrhigh_normalised(rp_n + 1, xp, n);
41+
42+
retlimb = res_norm.m1;
43+
normalised = res_norm.m2;
44+
rp_n[0] = retlimb;
45+
46+
result = ((rp_n[n] & (UWORD(1) << (FLINT_BITS - 1))) != UWORD(0));
47+
if (!result)
48+
TEST_FUNCTION_FAIL(
49+
"Top bit not set in normalised result\n"
50+
"ix = %wd\n"
51+
"n = %wd\n"
52+
"xp = %{ulong*}\n"
53+
"rp_n = %{ulong*}\n"
54+
"rp_u = %{ulong*}\n",
55+
ix, n, xp, n, rp_n, n + 1, rp_u, n + 1);
56+
57+
if (normalised)
58+
{
59+
result = (mpn_lshift(rp_u, rp_u, n + 1, 1) == 0);
60+
result = result && (mpn_cmp(rp_n, rp_u, n + 1) == 0);
61+
if (!result)
62+
TEST_FUNCTION_FAIL(
63+
"rp_n != rp_u << 1 when normalised\n"
64+
"ix = %wd\n"
65+
"n = %wd\n"
66+
"xp = %{ulong*}\n"
67+
"rp_n = %{ulong*}\n"
68+
"rp_u = %{ulong*}\n",
69+
ix, n, xp, n, rp_n, n + 1, rp_u, n + 1);
70+
}
71+
else
72+
{
73+
result = (mpn_cmp(rp_n, rp_u, n + 1) == 0);
74+
if (!result)
75+
TEST_FUNCTION_FAIL(
76+
"rp_n != rp_u when unnormalised\n"
77+
"ix = %wd\n"
78+
"n = %wd\n"
79+
"xp = %{ulong*}\n"
80+
"rp_n = %{ulong*}\n"
81+
"rp_u = %{ulong*}\n",
82+
ix, n, xp, n, rp_n, n + 1, rp_u, n + 1);
83+
}
84+
85+
flint_free(rp_n);
86+
flint_free(rp_u);
87+
flint_free(xp);
88+
}
89+
90+
TEST_FUNCTION_END(state);
91+
}
92+
#undef N_MIN
93+
#undef N_MAX

0 commit comments

Comments
 (0)