Skip to content

Commit b01c581

Browse files
authored
Improve BatchInvert (#1864)
This PR does a couple of things: - Changes the implementation to require two write stores instead of 3 write stores and one read. - Add a `BatchInvert::batch_invert_mut()` methods that allows users to pass one of the two stores themselves. With this a single store can be allocated instead of 3. - Reduce the bounds from `Invert + Mul + Copy + Default + ConditionallySelectable` to `Field`. This was largely inspired by [`curve25519-dalek`s implementation](https://github.com/dalek-cryptography/curve25519-dalek/blob/dd5bd108d6985491acb3de25497fb082a29c9fb7/curve25519-dalek/src/scalar.rs#L788-L837). I adjusted it a bit to save two unnecessary multiplications by one. Companion PR: RustCrypto/elliptic-curves#1205.
1 parent 1cae37d commit b01c581

File tree

1 file changed

+70
-64
lines changed

1 file changed

+70
-64
lines changed

elliptic-curve/src/ops.rs

Lines changed: 70 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,78 +1,77 @@
11
//! Traits for arithmetic operations on elliptic curve field elements.
22
3+
use core::iter;
34
pub use core::ops::{Add, AddAssign, Mul, Neg, Shr, ShrAssign, Sub, SubAssign};
45
pub use crypto_bigint::Invert;
56

67
use crypto_bigint::Integer;
7-
use subtle::{Choice, ConditionallySelectable, CtOption};
8+
use ff::Field;
9+
use subtle::{Choice, CtOption};
810

911
#[cfg(feature = "alloc")]
10-
use alloc::vec::Vec;
12+
use alloc::{borrow::ToOwned, vec::Vec};
1113

1214
/// Perform a batched inversion on a sequence of field elements (i.e. base field elements or scalars)
1315
/// at an amortized cost that should be practically as efficient as a single inversion.
14-
pub trait BatchInvert<FieldElements: ?Sized>: Invert + Sized {
16+
pub trait BatchInvert<FieldElements: ?Sized>: Field + Sized {
1517
/// The output of batch inversion. A container of field elements.
1618
type Output: AsRef<[Self]>;
1719

1820
/// Invert a batch of field elements.
1921
fn batch_invert(
2022
field_elements: &FieldElements,
2123
) -> CtOption<<Self as BatchInvert<FieldElements>>::Output>;
24+
25+
/// Invert a batch of field elements in-place.
26+
///
27+
/// # ⚠️ Warning
28+
///
29+
/// Even though `field_elements` is modified regardless of success, on failure it does not
30+
/// contain correctly inverted scalars and should be discarded instead.
31+
///
32+
/// Consider using [`Self::batch_invert()`] instead.
33+
fn batch_invert_mut(field_elements: &mut FieldElements) -> Choice;
2234
}
2335

2436
impl<const N: usize, T> BatchInvert<[T; N]> for T
2537
where
26-
T: Invert<Output = CtOption<Self>>
27-
+ Mul<Self, Output = Self>
28-
+ Copy
29-
+ Default
30-
+ ConditionallySelectable,
38+
T: Field,
3139
{
3240
type Output = [Self; N];
3341

3442
fn batch_invert(field_elements: &[Self; N]) -> CtOption<[Self; N]> {
35-
let mut field_elements_multiples = [Self::default(); N];
36-
let mut field_elements_multiples_inverses = [Self::default(); N];
37-
let mut field_elements_inverses = [Self::default(); N];
38-
39-
let inversion_succeeded = invert_batch_internal(
40-
field_elements,
41-
&mut field_elements_multiples,
42-
&mut field_elements_multiples_inverses,
43-
&mut field_elements_inverses,
44-
);
43+
let mut field_elements_inverses = *field_elements;
44+
let inversion_succeeded = Self::batch_invert_mut(&mut field_elements_inverses);
4545

4646
CtOption::new(field_elements_inverses, inversion_succeeded)
4747
}
48+
49+
fn batch_invert_mut(field_elements: &mut [T; N]) -> Choice {
50+
let mut field_elements_pad = [Self::default(); N];
51+
52+
invert_batch_internal(field_elements, &mut field_elements_pad)
53+
}
4854
}
4955

5056
#[cfg(feature = "alloc")]
5157
impl<T> BatchInvert<[T]> for T
5258
where
53-
T: Invert<Output = CtOption<Self>>
54-
+ Mul<Self, Output = Self>
55-
+ Copy
56-
+ Default
57-
+ ConditionallySelectable,
59+
T: Field,
5860
{
5961
type Output = Vec<Self>;
6062

6163
fn batch_invert(field_elements: &[Self]) -> CtOption<Vec<Self>> {
62-
let mut field_elements_multiples: Vec<Self> = vec![Self::default(); field_elements.len()];
63-
let mut field_elements_multiples_inverses: Vec<Self> =
64-
vec![Self::default(); field_elements.len()];
65-
let mut field_elements_inverses: Vec<Self> = vec![Self::default(); field_elements.len()];
66-
67-
let inversion_succeeded = invert_batch_internal(
68-
field_elements,
69-
field_elements_multiples.as_mut(),
70-
field_elements_multiples_inverses.as_mut(),
71-
field_elements_inverses.as_mut(),
72-
);
64+
let mut field_elements_inverses: Vec<Self> = field_elements.to_owned();
65+
let inversion_succeeded = Self::batch_invert_mut(field_elements_inverses.as_mut_slice());
7366

7467
CtOption::new(field_elements_inverses, inversion_succeeded)
7568
}
69+
70+
fn batch_invert_mut(field_elements: &mut [T]) -> Choice {
71+
let mut field_elements_pad: Vec<Self> = vec![Self::default(); field_elements.len()];
72+
73+
invert_batch_internal(field_elements, field_elements_pad.as_mut())
74+
}
7675
}
7776

7877
/// Implements "Montgomery's trick", a trick for computing many modular inverses at once.
@@ -81,44 +80,51 @@ where
8180
/// to computing a single inversion, plus some storage and `O(n)` extra multiplications.
8281
///
8382
/// See: https://iacr.org/archive/pkc2004/29470042/29470042.pdf section 2.2.
84-
fn invert_batch_internal<
85-
T: Invert<Output = CtOption<T>> + Mul<T, Output = T> + Default + ConditionallySelectable,
86-
>(
87-
field_elements: &[T],
88-
field_elements_multiples: &mut [T],
89-
field_elements_multiples_inverses: &mut [T],
90-
field_elements_inverses: &mut [T],
83+
fn invert_batch_internal<T: Field>(
84+
field_elements: &mut [T],
85+
field_elements_pad: &mut [T],
9186
) -> Choice {
9287
let batch_size = field_elements.len();
93-
if batch_size == 0
94-
|| batch_size != field_elements_multiples.len()
95-
|| batch_size != field_elements_multiples_inverses.len()
96-
{
88+
if batch_size == 0 || batch_size != field_elements_pad.len() {
9789
return Choice::from(0);
9890
}
9991

100-
field_elements_multiples[0] = field_elements[0];
101-
for i in 1..batch_size {
92+
let mut acc = field_elements[0];
93+
field_elements_pad[0] = acc;
94+
95+
for (field_element, field_element_pad) in field_elements
96+
.iter_mut()
97+
.zip(field_elements_pad.iter_mut())
98+
.skip(1)
99+
{
102100
// $ a_n = a_{n-1}*x_n $
103-
field_elements_multiples[i] = field_elements_multiples[i - 1] * field_elements[i];
101+
acc *= *field_element;
102+
*field_element_pad = acc;
104103
}
105104

106-
field_elements_multiples[batch_size - 1]
107-
.invert()
108-
.map(|multiple_of_inverses_of_all_field_elements| {
109-
field_elements_multiples_inverses[batch_size - 1] =
110-
multiple_of_inverses_of_all_field_elements;
111-
for i in (1..batch_size).rev() {
112-
// $ a_{n-1} = {a_n}^{-1}*x_n $
113-
field_elements_multiples_inverses[i - 1] =
114-
field_elements_multiples_inverses[i] * field_elements[i];
115-
}
116-
117-
field_elements_inverses[0] = field_elements_multiples_inverses[0];
118-
for i in 1..batch_size {
119-
// $ {x_n}^{-1} = a_{n}^{-1}*a_{n-1} $
120-
field_elements_inverses[i] =
121-
field_elements_multiples_inverses[i] * field_elements_multiples[i - 1];
105+
acc.invert()
106+
.map(|mut acc| {
107+
// Shift the iterator by one element back. The one we are skipping is served in `acc`.
108+
let field_elements_pad = field_elements_pad
109+
.iter()
110+
.rev()
111+
.skip(1)
112+
.map(Some)
113+
.chain(iter::once(None));
114+
115+
for (field_element, field_element_pad) in
116+
field_elements.iter_mut().rev().zip(field_elements_pad)
117+
{
118+
if let Some(field_element_pad) = field_element_pad {
119+
// Store in a temporary so we can overwrite `field_element`.
120+
// $ a_{n-1} = {a_n}^{-1}*x_n $
121+
let tmp = acc * *field_element;
122+
// $ {x_n}^{-1} = a_{n}^{-1}*a_{n-1} $
123+
*field_element = acc * *field_element_pad;
124+
acc = tmp;
125+
} else {
126+
*field_element = acc;
127+
}
122128
}
123129
})
124130
.is_some()

0 commit comments

Comments
 (0)