@@ -32,45 +32,44 @@ pub fn crt(r: &[i64], m: &[i64]) -> (i64, i64) {
3232 assert_eq ! ( r. len( ) , m. len( ) ) ;
3333 // Contracts: 0 <= r0 < m0
3434 let ( mut r0, mut m0) = ( 0 , 1 ) ;
35- for ( ri, mi) in r. iter ( ) . zip ( m. iter ( ) ) {
36- assert ! ( 1 < * mi) ;
37- let mut r1 = internal_math:: safe_mod ( * ri, * mi) ;
38- let mut m1 = * mi;
39- if m0 < m1 {
40- swap ( & mut r0, & mut r1) ;
41- swap ( & mut m0, & mut m1) ;
35+ for ( & ( mut ri) , & ( mut mi) ) in r. iter ( ) . zip ( m. iter ( ) ) {
36+ assert ! ( 1 < mi) ;
37+ ri = internal_math:: safe_mod ( ri, mi) ;
38+ if m0 < mi {
39+ swap ( & mut r0, & mut ri) ;
40+ swap ( & mut m0, & mut mi) ;
4241 }
43- if m0 % m1 == 0 {
44- if r0 % m1 != r1 {
42+ if m0 % mi == 0 {
43+ if r0 % mi != ri {
4544 return ( 0 , 0 ) ;
4645 }
4746 continue ;
4847 }
49- // assume: m0 > m1 , lcm(m0, m1 ) >= 2 * max(m0, m1 )
48+ // assume: m0 > mi , lcm(m0, mi ) >= 2 * max(m0, mi )
5049
51- // (r0, m0), (r1, m1 ) -> (r2, m2 = lcm(m0, m1));
50+ // (r0, m0), (ri, mi ) -> (r2, m2 = lcm(m0, m1));
5251 // r2 % m0 = r0
53- // r2 % m1 = r1
54- // -> (r0 + x*m0) % m1 = r1
55- // -> x*u0*g % (u1*g) = (r1 - r0) (u0*g = m0, u1*g = m1 )
56- // -> x = (r1 - r0) / g * inv(u0) (mod u1)
52+ // r2 % mi = ri
53+ // -> (r0 + x*m0) % mi = ri
54+ // -> x*u0*g % (u1*g) = (ri - r0) (u0*g = m0, u1*g = mi )
55+ // -> x = (ri - r0) / g * inv(u0) (mod u1)
5756
5857 // im = inv(u0) (mod u1) (0 <= im < u1)
59- let ( g, im) = internal_math:: inv_gcd ( m0, m1 ) ;
60- let u1 = m1 / g;
61- // |r1 - r0| < (m0 + m1 ) <= lcm(m0, m1 )
62- if ( r1 - r0) % g != 0 {
58+ let ( g, im) = internal_math:: inv_gcd ( m0, mi ) ;
59+ let u1 = mi / g;
60+ // |ri - r0| < (m0 + mi ) <= lcm(m0, mi )
61+ if ( ri - r0) % g != 0 {
6362 return ( 0 , 0 ) ;
6463 }
65- // u1 * u1 <= m1 * m1 / g / g <= m0 * m1 / g = lcm(m0, m1 )
66- let x = ( r1 - r0) / g % u1 * im % u1;
64+ // u1 * u1 <= mi * mi / g / g <= m0 * mi / g = lcm(m0, mi )
65+ let x = ( ri - r0) / g % u1 * im % u1;
6766
6867 // |r0| + |m0 * x|
6968 // < m0 + m0 * (u1 - 1)
70- // = m0 + m0 * m1 / g - m0
71- // = lcm(m0, m1 )
69+ // = m0 + m0 * mi / g - m0
70+ // = lcm(m0, mi )
7271 r0 += x * m0;
73- m0 *= u1; // -> lcm(m0, m1 )
72+ m0 *= u1; // -> lcm(m0, mi )
7473 if r0 < 0 {
7574 r0 += m0
7675 } ;
0 commit comments