Skip to content

Commit 18b9fbc

Browse files
authored
hash2curve Expander improvements (#1318)
- Limit `ExpandMsg` output by `len_in_bytes` - Return `Result` from `Expander::fill_bytes()` - Optimize `ExpanderXmd` to copy whole slices - Add tests Addresses issues found in #1317
1 parent 437db4c commit 18b9fbc

File tree

5 files changed

+133
-50
lines changed

5 files changed

+133
-50
lines changed

ed448-goldilocks/src/field/element.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -536,15 +536,15 @@ mod tests {
536536
)
537537
.unwrap();
538538
let mut data = Array::<u8, U84>::default();
539-
expander.fill_bytes(&mut data);
539+
expander.fill_bytes(&mut data).unwrap();
540540
// TODO: This should be `Curve448FieldElement`.
541541
let u0 = Ed448FieldElement::from_okm(&data).0;
542542
let mut e_u0 = *expected_u0;
543543
e_u0.reverse();
544544
let mut e_u1 = *expected_u1;
545545
e_u1.reverse();
546546
assert_eq!(u0.to_bytes(), e_u0);
547-
expander.fill_bytes(&mut data);
547+
expander.fill_bytes(&mut data).unwrap();
548548
// TODO: This should be `Curve448FieldElement`.
549549
let u1 = Ed448FieldElement::from_okm(&data).0;
550550
assert_eq!(u1.to_bytes(), e_u1);
@@ -570,14 +570,14 @@ mod tests {
570570
)
571571
.unwrap();
572572
let mut data = Array::<u8, U84>::default();
573-
expander.fill_bytes(&mut data);
573+
expander.fill_bytes(&mut data).unwrap();
574574
let u0 = Ed448FieldElement::from_okm(&data).0;
575575
let mut e_u0 = *expected_u0;
576576
e_u0.reverse();
577577
let mut e_u1 = *expected_u1;
578578
e_u1.reverse();
579579
assert_eq!(u0.to_bytes(), e_u0);
580-
expander.fill_bytes(&mut data);
580+
expander.fill_bytes(&mut data).unwrap();
581581
let u1 = Ed448FieldElement::from_okm(&data).0;
582582
assert_eq!(u1.to_bytes(), e_u1);
583583
}

hash2curve/src/hash2field.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,9 @@ where
5151
let mut tmp = Array::<u8, <T as FromOkm>::Length>::default();
5252
let mut expander = E::expand_message(data, domain, len_in_bytes)?;
5353
Ok(core::array::from_fn(|_| {
54-
expander.fill_bytes(&mut tmp);
54+
expander
55+
.fill_bytes(&mut tmp)
56+
.expect("never exceeds `len_in_bytes`");
5557
T::from_okm(&tmp)
5658
}))
5759
}

hash2curve/src/hash2field/expand_msg.rs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ pub(super) mod xof;
66
use core::num::NonZero;
77

88
use digest::{Digest, ExtendableOutput, Update, XofReader};
9+
use elliptic_curve::Error;
910
use elliptic_curve::array::{Array, ArraySize};
1011
use xmd::ExpandMsgXmdError;
1112
use xof::ExpandMsgXofError;
@@ -42,8 +43,12 @@ pub trait ExpandMsg<K> {
4243

4344
/// Expander that, call `read` until enough bytes have been consumed.
4445
pub trait Expander {
45-
/// Fill the array with the expanded bytes
46-
fn fill_bytes(&mut self, okm: &mut [u8]);
46+
/// Fill the array with the expanded bytes, returning how many bytes were read.
47+
///
48+
/// # Errors
49+
///
50+
/// If no bytes are left.
51+
fn fill_bytes(&mut self, okm: &mut [u8]) -> Result<usize, Error>;
4752
}
4853

4954
/// The domain separation tag

hash2curve/src/hash2field/expand_msg/xmd.rs

Lines changed: 80 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use digest::{
1111
},
1212
block_api::BlockSizeUser,
1313
};
14+
use elliptic_curve::Error;
1415

1516
/// Implements `expand_message_xof` via the [`ExpandMsg`] trait:
1617
/// <https://www.rfc-editor.org/rfc/rfc9380.html#name-expand_message_xmd>
@@ -50,8 +51,10 @@ where
5051
return Err(ExpandMsgXmdError::Length);
5152
}
5253

53-
let ell = u8::try_from(usize::from(len_in_bytes.get()).div_ceil(b_in_bytes))
54-
.expect("should never pass the previous check");
54+
debug_assert!(
55+
usize::from(len_in_bytes.get()).div_ceil(b_in_bytes) <= u8::MAX.into(),
56+
"should never pass the previous check"
57+
);
5558

5659
let domain = Domain::xmd::<HashT>(dst)?;
5760
let mut b_0 = HashT::default();
@@ -80,7 +83,7 @@ where
8083
domain,
8184
index: 1,
8285
offset: 0,
83-
ell,
86+
remaining: len_in_bytes.get(),
8487
})
8588
}
8689
}
@@ -97,51 +100,64 @@ where
97100
domain: Domain<'a, HashT::OutputSize>,
98101
index: u8,
99102
offset: usize,
100-
ell: u8,
103+
remaining: u16,
101104
}
102105

103-
impl<HashT> ExpanderXmd<'_, HashT>
106+
impl<HashT> Expander for ExpanderXmd<'_, HashT>
104107
where
105108
HashT: BlockSizeUser + Default + FixedOutput + HashMarker,
106109
HashT::OutputSize: IsLessOrEqual<HashT::BlockSize, Output = True>,
107110
{
108-
fn next(&mut self) -> bool {
109-
if self.index < self.ell {
110-
self.index += 1;
111-
self.offset = 0;
112-
// b_0 XOR b_(idx - 1)
113-
let mut tmp = Array::<u8, HashT::OutputSize>::default();
114-
self.b_0
115-
.iter()
116-
.zip(&self.b_vals[..])
117-
.enumerate()
118-
.for_each(|(j, (b0val, bi1val))| tmp[j] = b0val ^ bi1val);
119-
let mut b_vals = HashT::default();
120-
b_vals.update(&tmp);
121-
b_vals.update(&[self.index]);
122-
self.domain.update_hash(&mut b_vals);
123-
b_vals.update(&[self.domain.len()]);
124-
self.b_vals = b_vals.finalize_fixed();
125-
true
126-
} else {
127-
false
111+
fn fill_bytes(&mut self, mut okm: &mut [u8]) -> Result<usize, Error> {
112+
if self.remaining == 0 {
113+
return Err(Error);
128114
}
129-
}
130-
}
131115

132-
impl<HashT> Expander for ExpanderXmd<'_, HashT>
133-
where
134-
HashT: BlockSizeUser + Default + FixedOutput + HashMarker,
135-
HashT::OutputSize: IsLessOrEqual<HashT::BlockSize, Output = True>,
136-
{
137-
fn fill_bytes(&mut self, okm: &mut [u8]) {
138-
for b in okm {
139-
if self.offset == self.b_vals.len() && !self.next() {
140-
return;
116+
let mut read_bytes = 0;
117+
118+
while self.remaining != 0 {
119+
if self.offset == self.b_vals.len() {
120+
self.index += 1;
121+
self.offset = 0;
122+
// b_0 XOR b_(idx - 1)
123+
let mut tmp = Array::<u8, HashT::OutputSize>::default();
124+
self.b_0
125+
.iter()
126+
.zip(&self.b_vals[..])
127+
.enumerate()
128+
.for_each(|(j, (b0val, bi1val))| tmp[j] = b0val ^ bi1val);
129+
let mut b_vals = HashT::default();
130+
b_vals.update(&tmp);
131+
b_vals.update(&[self.index]);
132+
self.domain.update_hash(&mut b_vals);
133+
b_vals.update(&[self.domain.len()]);
134+
self.b_vals = b_vals.finalize_fixed();
135+
}
136+
137+
let bytes_to_read = self
138+
.remaining
139+
.min(okm.len().try_into().unwrap_or(u16::MAX))
140+
.min(
141+
(self.b_vals.len() - self.offset)
142+
.try_into()
143+
.unwrap_or(u16::MAX),
144+
);
145+
146+
if bytes_to_read == 0 {
147+
return Ok(read_bytes);
141148
}
142-
*b = self.b_vals[self.offset];
143-
self.offset += 1;
149+
150+
okm[..bytes_to_read.into()].copy_from_slice(
151+
&self.b_vals[self.offset..self.offset + usize::from(bytes_to_read)],
152+
);
153+
okm = &mut okm[bytes_to_read.into()..];
154+
155+
self.offset += usize::from(bytes_to_read);
156+
self.remaining -= bytes_to_read;
157+
read_bytes += usize::from(bytes_to_read);
144158
}
159+
160+
Ok(read_bytes)
145161
}
146162
}
147163

@@ -181,6 +197,31 @@ mod test {
181197
use hex_literal::hex;
182198
use sha2::Sha256;
183199

200+
#[test]
201+
fn edge_cases() {
202+
fn generate() -> ExpanderXmd<'static, Sha256> {
203+
<ExpandMsgXmd<Sha256> as ExpandMsg<U4>>::expand_message(
204+
&[b"test message"],
205+
&[b"test DST"],
206+
NonZero::new(64).unwrap(),
207+
)
208+
.unwrap()
209+
}
210+
211+
assert_eq!(generate().fill_bytes(&mut [0; 0]), Ok(0));
212+
assert_eq!(generate().fill_bytes(&mut [0; 1]), Ok(1));
213+
assert_eq!(generate().fill_bytes(&mut [0; 64]), Ok(64));
214+
assert_eq!(generate().fill_bytes(&mut [0; 65]), Ok(64));
215+
216+
let mut expander = generate();
217+
assert_eq!(expander.fill_bytes(&mut [0; 0]), Ok(0));
218+
assert_eq!(expander.fill_bytes(&mut [0; 32]), Ok(32));
219+
assert_eq!(expander.fill_bytes(&mut [0; 0]), Ok(0));
220+
assert_eq!(expander.fill_bytes(&mut [0; 31]), Ok(31));
221+
assert_eq!(expander.fill_bytes(&mut [0; 2]), Ok(1));
222+
assert_eq!(expander.fill_bytes(&mut [0; 1]), Err(Error));
223+
}
224+
184225
fn assert_message<HashT>(
185226
msg: &[u8],
186227
domain: &Domain<'_, HashT::OutputSize>,
@@ -239,7 +280,7 @@ mod test {
239280
.unwrap();
240281

241282
let mut uniform_bytes = Array::<u8, L>::default();
242-
expander.fill_bytes(&mut uniform_bytes);
283+
expander.fill_bytes(&mut uniform_bytes).unwrap();
243284

244285
assert_eq!(uniform_bytes.as_slice(), self.uniform_bytes);
245286
}

hash2curve/src/hash2field/expand_msg/xof.rs

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
use super::{Domain, ExpandMsg, Expander};
44
use core::{fmt, num::NonZero, ops::Mul};
55
use digest::{CollisionResistance, ExtendableOutput, HashMarker, Update, XofReader};
6+
use elliptic_curve::Error;
67
use elliptic_curve::array::{
78
ArraySize,
89
typenum::{IsGreaterOrEqual, Prod, True, U2},
@@ -19,6 +20,7 @@ where
1920
HashT: Default + ExtendableOutput + Update + HashMarker,
2021
{
2122
reader: <HashT as ExtendableOutput>::Reader,
23+
remaining: u16,
2224
}
2325

2426
impl<HashT> fmt::Debug for ExpandMsgXof<HashT>
@@ -64,16 +66,26 @@ where
6466
domain.update_hash(&mut reader);
6567
reader.update(&[domain.len()]);
6668
let reader = reader.finalize_xof();
67-
Ok(Self { reader })
69+
Ok(Self {
70+
reader,
71+
remaining: len_in_bytes,
72+
})
6873
}
6974
}
7075

7176
impl<HashT> Expander for ExpandMsgXof<HashT>
7277
where
7378
HashT: Default + ExtendableOutput + Update + HashMarker,
7479
{
75-
fn fill_bytes(&mut self, okm: &mut [u8]) {
76-
self.reader.read(okm);
80+
fn fill_bytes(&mut self, okm: &mut [u8]) -> Result<usize, Error> {
81+
if self.remaining == 0 {
82+
return Err(Error);
83+
}
84+
85+
let bytes_to_read = self.remaining.min(okm.len().try_into().unwrap_or(u16::MAX));
86+
self.reader.read(&mut okm[..bytes_to_read.into()]);
87+
self.remaining -= bytes_to_read;
88+
Ok(bytes_to_read.into())
7789
}
7890
}
7991

@@ -109,6 +121,29 @@ mod test {
109121
use hex_literal::hex;
110122
use sha3::Shake128;
111123

124+
#[test]
125+
fn edge_cases() {
126+
fn generate() -> ExpandMsgXof<Shake128> {
127+
<ExpandMsgXof<Shake128> as ExpandMsg<U16>>::expand_message(
128+
&[b"test message"],
129+
&[b"test DST"],
130+
NonZero::new(64).unwrap(),
131+
)
132+
.unwrap()
133+
}
134+
135+
assert_eq!(generate().fill_bytes(&mut [0; 0]), Ok(0));
136+
assert_eq!(generate().fill_bytes(&mut [0; 1]), Ok(1));
137+
assert_eq!(generate().fill_bytes(&mut [0; 64]), Ok(64));
138+
assert_eq!(generate().fill_bytes(&mut [0; 65]), Ok(64));
139+
140+
let mut expander = generate();
141+
assert_eq!(expander.fill_bytes(&mut [0; 0]), Ok(0));
142+
assert_eq!(expander.fill_bytes(&mut [0; 1]), Ok(1));
143+
assert_eq!(expander.fill_bytes(&mut [0; 64]), Ok(63));
144+
assert_eq!(expander.fill_bytes(&mut [0; 1]), Err(Error));
145+
}
146+
112147
fn assert_message(msg: &[u8], domain: &Domain<'_, U32>, len_in_bytes: u16, bytes: &[u8]) {
113148
let msg_len = msg.len();
114149
assert_eq!(msg, &bytes[..msg_len]);
@@ -155,7 +190,7 @@ mod test {
155190
.unwrap();
156191

157192
let mut uniform_bytes = Array::<u8, L>::default();
158-
expander.fill_bytes(&mut uniform_bytes);
193+
expander.fill_bytes(&mut uniform_bytes).unwrap();
159194

160195
assert_eq!(uniform_bytes.as_slice(), self.uniform_bytes);
161196
}

0 commit comments

Comments
 (0)