1919 * This algorithm uses Newton's approximation
2020 * x[i+1] = x[i] - f(x[i])/f'(x[i])
2121 * which will find the root in log(N) time where
22- * each step involves a fair bit. This is not meant to
23- * find huge roots [square and cube, etc].
22+ * each step involves a fair bit.
2423 */
2524int mp_n_root_ex (const mp_int * a , mp_digit b , mp_int * c , int fast )
2625{
2726 mp_int t1 , t2 , t3 , a_ ;
28- int res ;
27+ int res , cmp ;
28+ int ilog2 ;
2929
3030 /* input must be positive if b is even */
3131 if (((b & 1u ) == 0u ) && (a -> sign == MP_NEG )) {
@@ -48,9 +48,49 @@ int mp_n_root_ex(const mp_int *a, mp_digit b, mp_int *c, int fast)
4848 a_ = * a ;
4949 a_ .sign = MP_ZPOS ;
5050
51- /* t2 = 2 */
52- mp_set (& t2 , 2uL );
53-
51+ /* Compute seed: 2^(log_2(n)/b + 2) */
52+ ilog2 = mp_count_bits (a );
53+
54+ /*
55+ GCC and clang do not understand the sizeof(bla) tests and complain,
56+ icc (the Intel compiler) seems to understand, at least it doesn't complain.
57+ 2 of 3 say these macros are necessary, so there they are.
58+ */
59+ #if ( !(defined MP_8BIT ) && !(defined MP_16BIT ) )
60+ /*
61+ The type of mp_digit might be larger than an int.
62+ If "b" is larger than INT_MAX it is also larger than
63+ log_2(n), because the bit length of "n" is measured
64+ with an int, and hence the root is always < 2 (two).
65+ */
66+ if (sizeof (mp_digit ) >= sizeof (int )) {
67+ if (b > (mp_digit )(INT_MAX /2 )) {
68+ mp_set (c , 1uL );
69+ c -> sign = a -> sign ;
70+ res = MP_OKAY ;
71+ goto LBL_T3 ;
72+ }
73+ }
74+ #endif
75+ /* "b" is smaller than INT_MAX, so we can cast safely */
76+ if (ilog2 < (int )b ) {
77+ mp_set (c , 1uL );
78+ c -> sign = a -> sign ;
79+ res = MP_OKAY ;
80+ goto LBL_T3 ;
81+ }
82+ ilog2 = ilog2 / ((int )b );
83+ if (ilog2 == 0 ) {
84+ mp_set (c , 1uL );
85+ c -> sign = a -> sign ;
86+ res = MP_OKAY ;
87+ goto LBL_T3 ;
88+ }
89+ /* Start value must be larger than root */
90+ ilog2 += 2 ;
91+ if ((res = mp_2expt (& t2 ,ilog2 )) != MP_OKAY ) {
92+ goto LBL_T3 ;
93+ }
5494 do {
5595 /* t1 = t2 */
5696 if ((res = mp_copy (& t2 , & t1 )) != MP_OKAY ) {
@@ -63,7 +103,6 @@ int mp_n_root_ex(const mp_int *a, mp_digit b, mp_int *c, int fast)
63103 if ((res = mp_expt_d_ex (& t1 , b - 1u , & t3 , fast )) != MP_OKAY ) {
64104 goto LBL_T3 ;
65105 }
66-
67106 /* numerator */
68107 /* t2 = t1**b */
69108 if ((res = mp_mul (& t3 , & t1 , & t2 )) != MP_OKAY ) {
@@ -89,14 +128,39 @@ int mp_n_root_ex(const mp_int *a, mp_digit b, mp_int *c, int fast)
89128 if ((res = mp_sub (& t1 , & t3 , & t2 )) != MP_OKAY ) {
90129 goto LBL_T3 ;
91130 }
131+ /*
132+ Number of rounds is at most log_2(root). If more are needed, the
133+ iteration got stuck, so break out of the loop and do the rest manually.
134+ */
135+ if (ilog2 -- == 0 ) {
136+ break ;
137+ }
92138 } while (mp_cmp (& t1 , & t2 ) != MP_EQ );
93139
94140 /* result can be off by a few so check */
141+ /* The loop below can overshoot by one if the root found is smaller than the actual root */
142+ for (;;) {
143+ if ((res = mp_expt_d_ex (& t1 , b , & t2 , fast )) != MP_OKAY ) {
144+ goto LBL_T3 ;
145+ }
146+ cmp = mp_cmp (& t2 , & a_ );
147+ if (cmp == MP_EQ ) {
148+ res = MP_OKAY ;
149+ goto LBL_T3 ;
150+ }
151+ if (cmp == MP_LT ) {
152+ if ((res = mp_add_d (& t1 , 1uL , & t1 )) != MP_OKAY ) {
153+ goto LBL_T3 ;
154+ }
155+ } else {
156+ break ;
157+ }
158+ }
159+ /* correct overshoot from above or from recurrence */
95160 for (;;) {
96161 if ((res = mp_expt_d_ex (& t1 , b , & t2 , fast )) != MP_OKAY ) {
97162 goto LBL_T3 ;
98163 }
99-
100164 if (mp_cmp (& t2 , & a_ ) == MP_GT ) {
101165 if ((res = mp_sub_d (& t1 , 1uL , & t1 )) != MP_OKAY ) {
102166 goto LBL_T3 ;
@@ -123,7 +187,6 @@ int mp_n_root_ex(const mp_int *a, mp_digit b, mp_int *c, int fast)
123187 return res ;
124188}
125189#endif
126-
127190/* ref: $Format:%D$ */
128191/* git commit: $Format:%H$ */
129192/* commit time: $Format:%ai$ */
0 commit comments