@@ -11,6 +11,7 @@ use digest::{
1111 } ,
1212 block_api:: BlockSizeUser ,
1313} ;
14+ use elliptic_curve:: Error ;
1415
1516/// Implements `expand_message_xof` via the [`ExpandMsg`] trait:
1617/// <https://www.rfc-editor.org/rfc/rfc9380.html#name-expand_message_xmd>
5051 return Err ( ExpandMsgXmdError :: Length ) ;
5152 }
5253
53- let ell = u8:: try_from ( usize:: from ( len_in_bytes. get ( ) ) . div_ceil ( b_in_bytes) )
54- . expect ( "should never pass the previous check" ) ;
54+ debug_assert ! (
55+ usize :: from( len_in_bytes. get( ) ) . div_ceil( b_in_bytes) <= u8 :: MAX . into( ) ,
56+ "should never pass the previous check"
57+ ) ;
5558
5659 let domain = Domain :: xmd :: < HashT > ( dst) ?;
5760 let mut b_0 = HashT :: default ( ) ;
8083 domain,
8184 index : 1 ,
8285 offset : 0 ,
83- ell ,
86+ remaining : len_in_bytes . get ( ) ,
8487 } )
8588 }
8689}
@@ -97,51 +100,64 @@ where
97100 domain : Domain < ' a , HashT :: OutputSize > ,
98101 index : u8 ,
99102 offset : usize ,
100- ell : u8 ,
103+ remaining : u16 ,
101104}
102105
103- impl < HashT > ExpanderXmd < ' _ , HashT >
106+ impl < HashT > Expander for ExpanderXmd < ' _ , HashT >
104107where
105108 HashT : BlockSizeUser + Default + FixedOutput + HashMarker ,
106109 HashT :: OutputSize : IsLessOrEqual < HashT :: BlockSize , Output = True > ,
107110{
108- fn next ( & mut self ) -> bool {
109- if self . index < self . ell {
110- self . index += 1 ;
111- self . offset = 0 ;
112- // b_0 XOR b_(idx - 1)
113- let mut tmp = Array :: < u8 , HashT :: OutputSize > :: default ( ) ;
114- self . b_0
115- . iter ( )
116- . zip ( & self . b_vals [ ..] )
117- . enumerate ( )
118- . for_each ( |( j, ( b0val, bi1val) ) | tmp[ j] = b0val ^ bi1val) ;
119- let mut b_vals = HashT :: default ( ) ;
120- b_vals. update ( & tmp) ;
121- b_vals. update ( & [ self . index ] ) ;
122- self . domain . update_hash ( & mut b_vals) ;
123- b_vals. update ( & [ self . domain . len ( ) ] ) ;
124- self . b_vals = b_vals. finalize_fixed ( ) ;
125- true
126- } else {
127- false
111+ fn fill_bytes ( & mut self , mut okm : & mut [ u8 ] ) -> Result < usize , Error > {
112+ if self . remaining == 0 {
113+ return Err ( Error ) ;
128114 }
129- }
130- }
131115
132- impl < HashT > Expander for ExpanderXmd < ' _ , HashT >
133- where
134- HashT : BlockSizeUser + Default + FixedOutput + HashMarker ,
135- HashT :: OutputSize : IsLessOrEqual < HashT :: BlockSize , Output = True > ,
136- {
137- fn fill_bytes ( & mut self , okm : & mut [ u8 ] ) {
138- for b in okm {
139- if self . offset == self . b_vals . len ( ) && !self . next ( ) {
140- return ;
116+ let mut read_bytes = 0 ;
117+
118+ while self . remaining != 0 {
119+ if self . offset == self . b_vals . len ( ) {
120+ self . index += 1 ;
121+ self . offset = 0 ;
122+ // b_0 XOR b_(idx - 1)
123+ let mut tmp = Array :: < u8 , HashT :: OutputSize > :: default ( ) ;
124+ self . b_0
125+ . iter ( )
126+ . zip ( & self . b_vals [ ..] )
127+ . enumerate ( )
128+ . for_each ( |( j, ( b0val, bi1val) ) | tmp[ j] = b0val ^ bi1val) ;
129+ let mut b_vals = HashT :: default ( ) ;
130+ b_vals. update ( & tmp) ;
131+ b_vals. update ( & [ self . index ] ) ;
132+ self . domain . update_hash ( & mut b_vals) ;
133+ b_vals. update ( & [ self . domain . len ( ) ] ) ;
134+ self . b_vals = b_vals. finalize_fixed ( ) ;
135+ }
136+
137+ let bytes_to_read = self
138+ . remaining
139+ . min ( okm. len ( ) . try_into ( ) . unwrap_or ( u16:: MAX ) )
140+ . min (
141+ ( self . b_vals . len ( ) - self . offset )
142+ . try_into ( )
143+ . unwrap_or ( u16:: MAX ) ,
144+ ) ;
145+
146+ if bytes_to_read == 0 {
147+ return Ok ( read_bytes) ;
141148 }
142- * b = self . b_vals [ self . offset ] ;
143- self . offset += 1 ;
149+
150+ okm[ ..bytes_to_read. into ( ) ] . copy_from_slice (
151+ & self . b_vals [ self . offset ..self . offset + usize:: from ( bytes_to_read) ] ,
152+ ) ;
153+ okm = & mut okm[ bytes_to_read. into ( ) ..] ;
154+
155+ self . offset += usize:: from ( bytes_to_read) ;
156+ self . remaining -= bytes_to_read;
157+ read_bytes += usize:: from ( bytes_to_read) ;
144158 }
159+
160+ Ok ( read_bytes)
145161 }
146162}
147163
@@ -181,6 +197,31 @@ mod test {
181197 use hex_literal:: hex;
182198 use sha2:: Sha256 ;
183199
200+ #[ test]
201+ fn edge_cases ( ) {
202+ fn generate ( ) -> ExpanderXmd < ' static , Sha256 > {
203+ <ExpandMsgXmd < Sha256 > as ExpandMsg < U4 > >:: expand_message (
204+ & [ b"test message" ] ,
205+ & [ b"test DST" ] ,
206+ NonZero :: new ( 64 ) . unwrap ( ) ,
207+ )
208+ . unwrap ( )
209+ }
210+
211+ assert_eq ! ( generate( ) . fill_bytes( & mut [ 0 ; 0 ] ) , Ok ( 0 ) ) ;
212+ assert_eq ! ( generate( ) . fill_bytes( & mut [ 0 ; 1 ] ) , Ok ( 1 ) ) ;
213+ assert_eq ! ( generate( ) . fill_bytes( & mut [ 0 ; 64 ] ) , Ok ( 64 ) ) ;
214+ assert_eq ! ( generate( ) . fill_bytes( & mut [ 0 ; 65 ] ) , Ok ( 64 ) ) ;
215+
216+ let mut expander = generate ( ) ;
217+ assert_eq ! ( expander. fill_bytes( & mut [ 0 ; 0 ] ) , Ok ( 0 ) ) ;
218+ assert_eq ! ( expander. fill_bytes( & mut [ 0 ; 32 ] ) , Ok ( 32 ) ) ;
219+ assert_eq ! ( expander. fill_bytes( & mut [ 0 ; 0 ] ) , Ok ( 0 ) ) ;
220+ assert_eq ! ( expander. fill_bytes( & mut [ 0 ; 31 ] ) , Ok ( 31 ) ) ;
221+ assert_eq ! ( expander. fill_bytes( & mut [ 0 ; 2 ] ) , Ok ( 1 ) ) ;
222+ assert_eq ! ( expander. fill_bytes( & mut [ 0 ; 1 ] ) , Err ( Error ) ) ;
223+ }
224+
184225 fn assert_message < HashT > (
185226 msg : & [ u8 ] ,
186227 domain : & Domain < ' _ , HashT :: OutputSize > ,
@@ -239,7 +280,7 @@ mod test {
239280 . unwrap ( ) ;
240281
241282 let mut uniform_bytes = Array :: < u8 , L > :: default ( ) ;
242- expander. fill_bytes ( & mut uniform_bytes) ;
283+ expander. fill_bytes ( & mut uniform_bytes) . unwrap ( ) ;
243284
244285 assert_eq ! ( uniform_bytes. as_slice( ) , self . uniform_bytes) ;
245286 }
0 commit comments