Skip to content

Commit 75001b2

Browse files
Int wrapping multiplication improvements (#998)
This optimizes `Int::wrapping_mul` and `Uint::checked_mul_int` to make use of `Uint::wrapping_mul`, and adds `Int::saturating_mul`. Signed-off-by: Andrew Whitehead <cywolf@gmail.com>
1 parent 62b90b8 commit 75001b2

File tree

6 files changed

+247
-34
lines changed

6 files changed

+247
-34
lines changed

benches/int.rs

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,59 @@ fn bench_concatenating_mul(c: &mut Criterion) {
112112
});
113113
}
114114

115+
fn bench_wrapping_mul(c: &mut Criterion) {
116+
let mut rng = ChaCha8Rng::from_seed([7u8; 32]);
117+
let mut group = c.benchmark_group("wrapping ops");
118+
119+
group.bench_function("wrapping_mul, I128xI128", |b| {
120+
b.iter_batched(
121+
|| (I256::random(&mut rng), I256::random(&mut rng)),
122+
|(x, y)| black_box(x.wrapping_mul(&y)),
123+
BatchSize::SmallInput,
124+
)
125+
});
126+
127+
group.bench_function("wrapping_mul, I256xI256", |b| {
128+
b.iter_batched(
129+
|| (I256::random(&mut rng), I256::random(&mut rng)),
130+
|(x, y)| black_box(x.wrapping_mul(&y)),
131+
BatchSize::SmallInput,
132+
)
133+
});
134+
135+
group.bench_function("wrapping_mul, I512xI512", |b| {
136+
b.iter_batched(
137+
|| (I512::random(&mut rng), I512::random(&mut rng)),
138+
|(x, y)| black_box(x.wrapping_mul(&y)),
139+
BatchSize::SmallInput,
140+
)
141+
});
142+
143+
group.bench_function("wrapping_mul, I1024xI1024", |b| {
144+
b.iter_batched(
145+
|| (I1024::random(&mut rng), I1024::random(&mut rng)),
146+
|(x, y)| black_box(x.wrapping_mul(&y)),
147+
BatchSize::SmallInput,
148+
)
149+
});
150+
151+
group.bench_function("wrapping_mul, I2048xI2048", |b| {
152+
b.iter_batched(
153+
|| (I2048::random(&mut rng), I2048::random(&mut rng)),
154+
|(x, y)| black_box(x.wrapping_mul(&y)),
155+
BatchSize::SmallInput,
156+
)
157+
});
158+
159+
group.bench_function("wrapping_mul, I4096xI4096", |b| {
160+
b.iter_batched(
161+
|| (I4096::random(&mut rng), I4096::random(&mut rng)),
162+
|(x, y)| black_box(x.wrapping_mul(&y)),
163+
BatchSize::SmallInput,
164+
)
165+
});
166+
}
167+
115168
fn bench_div(c: &mut Criterion) {
116169
let mut rng = ChaCha8Rng::from_seed([7u8; 32]);
117170
let mut group = c.benchmark_group("wrapping ops");
@@ -341,6 +394,7 @@ criterion_group!(
341394
benches,
342395
bench_mul,
343396
bench_concatenating_mul,
397+
bench_wrapping_mul,
344398
bench_div,
345399
bench_add,
346400
bench_sub,

src/int.rs

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -56,20 +56,17 @@ impl<const LIMBS: usize> Int<LIMBS> {
5656
pub const ONE: Self = Self(Uint::ONE); // Bit sequence (be): 0000....0001
5757

5858
/// The value `-1`
59-
pub const MINUS_ONE: Self = Self::FULL_MASK; // Bit sequence (be): 1111....1111
59+
pub const MINUS_ONE: Self = Self(Uint::MAX); // Bit sequence (be): 1111....1111
6060

6161
/// Smallest value this [`Int`] can express.
62-
pub const MIN: Self = Self(Uint::MAX.bitxor(&Uint::MAX.shr(1u32))); // Bit sequence (be): 1000....0000
62+
pub const MIN: Self = Self::MAX.not(); // Bit sequence (be): 1000....0000
6363

6464
/// Maximum value this [`Int`] can express.
65-
pub const MAX: Self = Self(Uint::MAX.shr(1u32)); // Bit sequence (be): 0111....1111
65+
pub const MAX: Self = Self(Uint::MAX.shr_vartime(1u32)); // Bit sequence (be): 0111....1111
6666

6767
/// Bit mask for the sign bit of this [`Int`].
6868
pub const SIGN_MASK: Self = Self::MIN; // Bit sequence (be): 1000....0000
6969

70-
/// All-one bit mask.
71-
pub const FULL_MASK: Self = Self(Uint::MAX); // Bit sequence (be): 1111...1111
72-
7370
/// Total size of the represented integer in bits.
7471
pub const BITS: u32 = Uint::<LIMBS>::BITS;
7572

@@ -113,7 +110,7 @@ impl<const LIMBS: usize> Int<LIMBS> {
113110
}
114111

115112
/// Borrow the inner limbs as a mutable array of [`Word`]s.
116-
pub fn as_mut_words(&mut self) -> &mut [Word; LIMBS] {
113+
pub const fn as_mut_words(&mut self) -> &mut [Word; LIMBS] {
117114
self.0.as_mut_words()
118115
}
119116

@@ -182,7 +179,7 @@ impl<const LIMBS: usize> Int<LIMBS> {
182179
}
183180

184181
/// Whether this [`Int`] is equal to `Self::MAX`.
185-
pub fn is_max(&self) -> ConstChoice {
182+
pub const fn is_max(&self) -> ConstChoice {
186183
Self::eq(self, &Self::MAX)
187184
}
188185

src/int/mul.rs

Lines changed: 131 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use core::ops::{Mul, MulAssign};
44
use num_traits::WrappingMul;
55
use subtle::CtOption;
66

7-
use crate::{Checked, CheckedMul, ConcatMixed, ConstChoice, ConstCtOption, Int, Uint, Zero};
7+
use crate::{Checked, CheckedMul, ConcatMixed, ConstChoice, ConstCtOption, Int, Uint};
88

99
impl<const LIMBS: usize> Int<LIMBS> {
1010
/// Compute "wide" multiplication as a 3-tuple `(lo, hi, negate)`.
@@ -64,12 +64,41 @@ impl<const LIMBS: usize> Int<LIMBS> {
6464
Int::from_bits(product_abs.wrapping_neg_if(product_sign))
6565
}
6666

67-
/// Multiply `self` by `rhs`, wrapping the result in case of overflow.
68-
pub const fn wrapping_mul<const RHS_LIMBS: usize>(&self, rhs: &Int<RHS_LIMBS>) -> Self {
67+
/// Multiply `self` by `rhs`, returning a `ConstCtOption` which is `is_some` only if
68+
/// overflow did not occur.
69+
pub const fn checked_mul<const RHS_LIMBS: usize>(
70+
&self,
71+
rhs: &Int<RHS_LIMBS>,
72+
) -> ConstCtOption<Self> {
6973
let (abs_lhs, lhs_sgn) = self.abs_sign();
7074
let (abs_rhs, rhs_sgn) = rhs.abs_sign();
71-
let (lo, _) = abs_lhs.widening_mul(&abs_rhs);
72-
*lo.wrapping_neg_if(lhs_sgn.xor(rhs_sgn)).as_int()
75+
let maybe_res = abs_lhs.checked_mul(&abs_rhs);
76+
let (lo, is_some) = maybe_res.components_ref();
77+
Self::new_from_abs_sign(*lo, lhs_sgn.xor(rhs_sgn)).and_choice(is_some)
78+
}
79+
80+
/// Multiply `self` by `rhs`, saturating at the numeric bounds instead of overflowing.
81+
pub const fn saturating_mul<const RHS_LIMBS: usize>(&self, rhs: &Int<RHS_LIMBS>) -> Self {
82+
let (abs_lhs, lhs_sgn) = self.abs_sign();
83+
let (abs_rhs, rhs_sgn) = rhs.abs_sign();
84+
let maybe_res = abs_lhs.checked_mul(&abs_rhs);
85+
let (lo, is_some) = maybe_res.components_ref();
86+
let is_neg = lhs_sgn.xor(rhs_sgn);
87+
let bound = Self::select(&Self::MAX, &Self::MIN, is_neg);
88+
Self::new_from_abs_sign(*lo, is_neg)
89+
.and_choice(is_some)
90+
.unwrap_or(bound)
91+
}
92+
93+
/// Multiply `self` by `rhs`, wrapping the result in case of overflow.
94+
/// This is equivalent to `(self * rhs) % (Uint::<LIMBS>::MAX + 1)`.
95+
pub const fn wrapping_mul<const RHS_LIMBS: usize>(&self, rhs: &Int<RHS_LIMBS>) -> Self {
96+
if RHS_LIMBS >= LIMBS {
97+
Self(self.0.wrapping_mul(&rhs.0))
98+
} else {
99+
let (abs_rhs, rhs_sgn) = rhs.abs_sign();
100+
Self(self.0.wrapping_mul(&abs_rhs).wrapping_neg_if(rhs_sgn))
101+
}
73102
}
74103
}
75104

@@ -102,9 +131,7 @@ impl<const LIMBS: usize> Int<LIMBS> {
102131
impl<const LIMBS: usize, const RHS_LIMBS: usize> CheckedMul<Int<RHS_LIMBS>> for Int<LIMBS> {
103132
#[inline]
104133
fn checked_mul(&self, rhs: &Int<RHS_LIMBS>) -> CtOption<Self> {
105-
let (lo, hi, is_negative) = self.widening_mul(rhs);
106-
let val = Self::new_from_abs_sign(lo, is_negative);
107-
CtOption::from(val).and_then(|int| CtOption::new(int, hi.is_zero()))
134+
self.checked_mul(rhs).into()
108135
}
109136
}
110137

@@ -173,7 +200,7 @@ impl<const LIMBS: usize> MulAssign<&Checked<Int<LIMBS>>> for Checked<Int<LIMBS>>
173200

174201
#[cfg(test)]
175202
mod tests {
176-
use crate::{CheckedMul, ConstChoice, I64, I128, I256, Int, U128, U256};
203+
use crate::{ConstChoice, I64, I128, I256, Int, U64, U128, U256};
177204

178205
#[test]
179206
#[allow(clippy::init_numbered_fields)]
@@ -271,20 +298,26 @@ mod tests {
271298
#[test]
272299
fn test_wrapping_mul() {
273300
// wrapping
274-
let a = I128::from_be_hex("FFFFFFFB7B63198EF870DF1F90D9BD9E");
275-
let b = I128::from_be_hex("F20C29FA87B356AA3B4C05C4F9C24B4A");
301+
let a = 0xFFFFFFFB7B63198EF870DF1F90D9BD9Eu128 as i128;
302+
let b = 0xF20C29FA87B356AA3B4C05C4F9C24B4Au128 as i128;
303+
let z = 0xAA700D354D6CF4EE881F8FF8093A19ACu128 as i128;
304+
assert_eq!(a.wrapping_mul(b), z);
276305
assert_eq!(
277-
a.wrapping_mul(&b),
278-
I128::from_be_hex("AA700D354D6CF4EE881F8FF8093A19AC")
306+
I128::from_i128(a).wrapping_mul(&I128::from_i128(b)),
307+
I128::from_i128(z)
279308
);
280309

281310
// no wrapping
282-
let c = I64::from_i64(-12345i64);
311+
let c = -12345i64;
283312
assert_eq!(
284-
a.wrapping_mul(&c),
285-
I128::from_be_hex("0000D9DEF2248095850866CFEBF727D2")
313+
I128::from_i128(a).wrapping_mul(&I128::from_i64(c)),
314+
I128::from_i128(a.wrapping_mul(c as i128))
286315
);
287316

317+
// overflow into MSB
318+
let a = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFu128 as i128;
319+
assert!(!a.is_negative() && a.wrapping_mul(a).is_negative());
320+
288321
// core case
289322
assert_eq!(i8::MAX.wrapping_mul(2), -2);
290323
assert_eq!(i64::MAX.wrapping_mul(2), -2);
@@ -312,6 +345,88 @@ mod tests {
312345
);
313346
}
314347

348+
#[test]
349+
fn test_wrapping_mul_mixed() {
350+
let a = U64::from_u64(0x0011223344556677);
351+
let b = U128::from_u128(0x8899aabbccddeeff_8899aabbccddeeff);
352+
let expected = a.as_int().concatenating_mul(b.as_int());
353+
assert_eq!(a.as_int().wrapping_mul(b.as_int()), expected.resize());
354+
assert_eq!(b.as_int().wrapping_mul(a.as_int()), expected.resize());
355+
assert_eq!(
356+
a.as_int().wrapping_neg().wrapping_mul(b.as_int()),
357+
expected.wrapping_neg().resize()
358+
);
359+
assert_eq!(
360+
a.as_int().wrapping_mul(&b.as_int().wrapping_neg()),
361+
expected.wrapping_neg().resize()
362+
);
363+
assert_eq!(
364+
b.as_int().wrapping_neg().wrapping_mul(a.as_int()),
365+
expected.wrapping_neg().resize()
366+
);
367+
assert_eq!(
368+
b.as_int().wrapping_mul(&a.as_int().wrapping_neg()),
369+
expected.wrapping_neg().resize()
370+
);
371+
assert_eq!(
372+
a.as_int()
373+
.wrapping_neg()
374+
.wrapping_mul(&b.as_int().wrapping_neg()),
375+
expected.resize()
376+
);
377+
assert_eq!(
378+
b.as_int()
379+
.wrapping_neg()
380+
.wrapping_mul(&a.as_int().wrapping_neg()),
381+
expected.resize()
382+
);
383+
}
384+
385+
#[test]
386+
fn test_saturating_mul() {
387+
// wrapping
388+
let a = 0xFFFFFFFB7B63198EF870DF1F90D9BD9Eu128 as i128;
389+
let b = 0xF20C29FA87B356AA3B4C05C4F9C24B4Au128 as i128;
390+
assert_eq!(a.saturating_mul(b), i128::MAX);
391+
assert_eq!(
392+
I128::from_i128(a).saturating_mul(&I128::from_i128(b)),
393+
I128::MAX
394+
);
395+
396+
// no wrapping
397+
let c = -12345i64;
398+
assert_eq!(
399+
I128::from_i128(a).saturating_mul(&I128::from_i64(c)),
400+
I128::from_i128(a.saturating_mul(c as i128))
401+
);
402+
403+
// core case
404+
assert_eq!(i8::MAX.saturating_mul(2), i8::MAX);
405+
assert_eq!(i8::MAX.saturating_mul(-2), i8::MIN);
406+
assert_eq!(i64::MAX.saturating_mul(2), i64::MAX);
407+
assert_eq!(i64::MAX.saturating_mul(-2), i64::MIN);
408+
assert_eq!(I128::MAX.saturating_mul(&I128::from_i64(2i64)), I128::MAX);
409+
assert_eq!(I128::MAX.saturating_mul(&I128::from_i64(-2i64)), I128::MIN);
410+
411+
let x = -197044252290277702i64;
412+
let y = -2631691865753118366;
413+
assert_eq!(x.saturating_mul(y), i64::MAX);
414+
assert_eq!(I64::from_i64(x).saturating_mul(&I64::from_i64(y)), I64::MAX);
415+
416+
let x = -86027672844719838068326470675019902915i128;
417+
let y = 21188806580823612823777395451044967239i128;
418+
assert_eq!(x.saturating_mul(y), i128::MIN);
419+
assert_eq!(x.saturating_mul(-y), i128::MAX);
420+
assert_eq!(
421+
I128::from_i128(x).saturating_mul(&I128::from_i128(y)),
422+
I128::MIN
423+
);
424+
assert_eq!(
425+
I128::from_i128(x).saturating_mul(&I128::from_i128(-y)),
426+
I128::MAX
427+
);
428+
}
429+
315430
#[test]
316431
fn test_concatenating_mul() {
317432
assert_eq!(

src/int/mul_uint.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,10 @@ impl<const LIMBS: usize> Int<LIMBS> {
9696
&self,
9797
rhs: &Uint<RHS_LIMBS>,
9898
) -> ConstCtOption<Int<LIMBS>> {
99-
let (lo, hi, is_negative) = self.widening_mul_uint(rhs);
100-
Self::new_from_abs_sign(lo, is_negative).and_choice(hi.is_nonzero().not())
99+
let (abs_lhs, lhs_sgn) = self.abs_sign();
100+
let maybe_lo = abs_lhs.checked_mul(rhs);
101+
let (lo, is_some) = maybe_lo.components_ref();
102+
Self::new_from_abs_sign(*lo, lhs_sgn).and_choice(is_some)
101103
}
102104
}
103105

@@ -136,7 +138,7 @@ impl<const LIMBS: usize, const RHS_LIMBS: usize> Mul<&Uint<RHS_LIMBS>> for &Int<
136138
type Output = Int<LIMBS>;
137139

138140
fn mul(self, rhs: &Uint<RHS_LIMBS>) -> Self::Output {
139-
self.checked_mul(rhs)
141+
self.checked_mul_uint(rhs)
140142
.expect("attempted to multiply with overflow")
141143
}
142144
}

0 commit comments

Comments
 (0)