@@ -10,7 +10,7 @@ use digest::{
1010 typenum:: { IsLess , IsLessOrEqual , Unsigned , U256 } ,
1111 GenericArray ,
1212 } ,
13- Digest ,
13+ FixedOutput , HashMarker ,
1414} ;
1515
1616/// Placeholder type for implementing `expand_message_xmd` based on a hash function
@@ -22,14 +22,14 @@ use digest::{
2222/// - `len_in_bytes > 255 * HashT::OutputSize`
2323pub struct ExpandMsgXmd < HashT > ( PhantomData < HashT > )
2424where
25- HashT : Digest + BlockSizeUser ,
25+ HashT : BlockSizeUser + Default + FixedOutput + HashMarker ,
2626 HashT :: OutputSize : IsLess < U256 > ,
2727 HashT :: OutputSize : IsLessOrEqual < HashT :: BlockSize > ;
2828
2929/// ExpandMsgXmd implements expand_message_xmd for the ExpandMsg trait
3030impl < ' a , HashT > ExpandMsg < ' a > for ExpandMsgXmd < HashT >
3131where
32- HashT : Digest + BlockSizeUser ,
32+ HashT : BlockSizeUser + Default + FixedOutput + HashMarker ,
3333 // If `len_in_bytes` is bigger then 256, length of the `DST` will depend on
3434 // the output size of the hash, which is still not allowed to be bigger then 256:
3535 // https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-13.html#section-5.4.1-6
4242
4343 fn expand_message (
4444 msgs : & [ & [ u8 ] ] ,
45- dst : & ' a [ u8 ] ,
45+ dsts : & ' a [ & ' a [ u8 ] ] ,
4646 len_in_bytes : usize ,
4747 ) -> Result < Self :: Expander > {
4848 if len_in_bytes == 0 {
@@ -54,26 +54,26 @@ where
5454 let b_in_bytes = HashT :: OutputSize :: to_usize ( ) ;
5555 let ell = u8:: try_from ( ( len_in_bytes + b_in_bytes - 1 ) / b_in_bytes) . map_err ( |_| Error ) ?;
5656
57- let domain = Domain :: xmd :: < HashT > ( dst ) ?;
58- let mut b_0 = HashT :: new ( ) ;
59- b_0. update ( GenericArray :: < u8 , HashT :: BlockSize > :: default ( ) ) ;
57+ let domain = Domain :: xmd :: < HashT > ( dsts ) ?;
58+ let mut b_0 = HashT :: default ( ) ;
59+ b_0. update ( & GenericArray :: < u8 , HashT :: BlockSize > :: default ( ) ) ;
6060
6161 for msg in msgs {
6262 b_0. update ( msg) ;
6363 }
6464
65- b_0. update ( len_in_bytes_u16. to_be_bytes ( ) ) ;
66- b_0. update ( [ 0 ] ) ;
67- b_0 . update ( domain. data ( ) ) ;
68- b_0. update ( [ domain. len ( ) ] ) ;
69- let b_0 = b_0. finalize ( ) ;
65+ b_0. update ( & len_in_bytes_u16. to_be_bytes ( ) ) ;
66+ b_0. update ( & [ 0 ] ) ;
67+ domain. update_hash ( & mut b_0 ) ;
68+ b_0. update ( & [ domain. len ( ) ] ) ;
69+ let b_0 = b_0. finalize_fixed ( ) ;
7070
71- let mut b_vals = HashT :: new ( ) ;
71+ let mut b_vals = HashT :: default ( ) ;
7272 b_vals. update ( & b_0[ ..] ) ;
73- b_vals. update ( [ 1u8 ] ) ;
74- b_vals . update ( domain. data ( ) ) ;
75- b_vals. update ( [ domain. len ( ) ] ) ;
76- let b_vals = b_vals. finalize ( ) ;
73+ b_vals. update ( & [ 1u8 ] ) ;
74+ domain. update_hash ( & mut b_vals ) ;
75+ b_vals. update ( & [ domain. len ( ) ] ) ;
76+ let b_vals = b_vals. finalize_fixed ( ) ;
7777
7878 Ok ( ExpanderXmd {
7979 b_0,
8989/// [`Expander`] type for [`ExpandMsgXmd`].
9090pub struct ExpanderXmd < ' a , HashT >
9191where
92- HashT : Digest + BlockSizeUser ,
92+ HashT : BlockSizeUser + Default + FixedOutput + HashMarker ,
9393 HashT :: OutputSize : IsLess < U256 > ,
9494 HashT :: OutputSize : IsLessOrEqual < HashT :: BlockSize > ,
9595{
@@ -103,7 +103,7 @@ where
103103
104104impl < ' a , HashT > ExpanderXmd < ' a , HashT >
105105where
106- HashT : Digest + BlockSizeUser ,
106+ HashT : BlockSizeUser + Default + FixedOutput + HashMarker ,
107107 HashT :: OutputSize : IsLess < U256 > ,
108108 HashT :: OutputSize : IsLessOrEqual < HashT :: BlockSize > ,
109109{
@@ -118,12 +118,12 @@ where
118118 . zip ( & self . b_vals [ ..] )
119119 . enumerate ( )
120120 . for_each ( |( j, ( b0val, bi1val) ) | tmp[ j] = b0val ^ bi1val) ;
121- let mut b_vals = HashT :: new ( ) ;
122- b_vals. update ( tmp) ;
123- b_vals. update ( [ self . index ] ) ;
124- b_vals . update ( self . domain . data ( ) ) ;
125- b_vals. update ( [ self . domain . len ( ) ] ) ;
126- self . b_vals = b_vals. finalize ( ) ;
121+ let mut b_vals = HashT :: default ( ) ;
122+ b_vals. update ( & tmp) ;
123+ b_vals. update ( & [ self . index ] ) ;
124+ self . domain . update_hash ( & mut b_vals ) ;
125+ b_vals. update ( & [ self . domain . len ( ) ] ) ;
126+ self . b_vals = b_vals. finalize_fixed ( ) ;
127127 true
128128 } else {
129129 false
@@ -133,7 +133,7 @@ where
133133
134134impl < ' a , HashT > Expander for ExpanderXmd < ' a , HashT >
135135where
136- HashT : Digest + BlockSizeUser ,
136+ HashT : BlockSizeUser + Default + FixedOutput + HashMarker ,
137137 HashT :: OutputSize : IsLess < U256 > ,
138138 HashT :: OutputSize : IsLessOrEqual < HashT :: BlockSize > ,
139139{
@@ -165,7 +165,7 @@ mod test {
165165 len_in_bytes : u16 ,
166166 bytes : & [ u8 ] ,
167167 ) where
168- HashT : Digest + BlockSizeUser ,
168+ HashT : BlockSizeUser + Default + FixedOutput + HashMarker ,
169169 HashT :: OutputSize : IsLess < U256 > ,
170170 {
171171 let block = HashT :: BlockSize :: to_usize ( ) ;
@@ -183,8 +183,8 @@ mod test {
183183 let pad = l + mem:: size_of :: < u8 > ( ) ;
184184 assert_eq ! ( [ 0 ] , & bytes[ l..pad] ) ;
185185
186- let dst = pad + domain. data ( ) . len ( ) ;
187- assert_eq ! ( domain. data ( ) , & bytes[ pad..dst] ) ;
186+ let dst = pad + usize :: from ( domain. len ( ) ) ;
187+ domain. assert ( & bytes[ pad..dst] ) ;
188188
189189 let dst_len = dst + mem:: size_of :: < u8 > ( ) ;
190190 assert_eq ! ( [ domain. len( ) ] , & bytes[ dst..dst_len] ) ;
@@ -205,13 +205,14 @@ mod test {
205205 domain : & Domain < ' _ , HashT :: OutputSize > ,
206206 ) -> Result < ( ) >
207207 where
208- HashT : Digest + BlockSizeUser ,
208+ HashT : BlockSizeUser + Default + FixedOutput + HashMarker ,
209209 HashT :: OutputSize : IsLess < U256 > + IsLessOrEqual < HashT :: BlockSize > ,
210210 {
211211 assert_message :: < HashT > ( self . msg , domain, L :: to_u16 ( ) , self . msg_prime ) ;
212212
213+ let dst = [ dst] ;
213214 let mut expander =
214- ExpandMsgXmd :: < HashT > :: expand_message ( & [ self . msg ] , dst, L :: to_usize ( ) ) ?;
215+ ExpandMsgXmd :: < HashT > :: expand_message ( & [ self . msg ] , & dst, L :: to_usize ( ) ) ?;
215216
216217 let mut uniform_bytes = GenericArray :: < u8 , L > :: default ( ) ;
217218 expander. fill_bytes ( & mut uniform_bytes) ;
@@ -227,8 +228,8 @@ mod test {
227228 const DST_PRIME : & [ u8 ] =
228229 & hex ! ( "515555582d5630312d435330322d776974682d657870616e6465722d5348413235362d31323826" ) ;
229230
230- let dst_prime = Domain :: xmd :: < Sha256 > ( DST ) ?;
231- dst_prime. assert ( DST_PRIME ) ;
231+ let dst_prime = Domain :: xmd :: < Sha256 > ( & [ DST ] ) ?;
232+ dst_prime. assert_dst ( DST_PRIME ) ;
232233
233234 const TEST_VECTORS_32 : & [ TestVector ] = & [
234235 TestVector {
@@ -299,8 +300,8 @@ mod test {
299300 const DST_PRIME : & [ u8 ] =
300301 & hex ! ( "412717974da474d0f8c420f320ff81e8432adb7c927d9bd082b4fb4d16c0a23620" ) ;
301302
302- let dst_prime = Domain :: xmd :: < Sha256 > ( DST ) ?;
303- dst_prime. assert ( DST_PRIME ) ;
303+ let dst_prime = Domain :: xmd :: < Sha256 > ( & [ DST ] ) ?;
304+ dst_prime. assert_dst ( DST_PRIME ) ;
304305
305306 const TEST_VECTORS_32 : & [ TestVector ] = & [
306307 TestVector {
@@ -377,8 +378,8 @@ mod test {
377378 const DST_PRIME : & [ u8 ] =
378379 & hex ! ( "515555582d5630312d435330322d776974682d657870616e6465722d5348413531322d32353626" ) ;
379380
380- let dst_prime = Domain :: xmd :: < Sha512 > ( DST ) ?;
381- dst_prime. assert ( DST_PRIME ) ;
381+ let dst_prime = Domain :: xmd :: < Sha512 > ( & [ DST ] ) ?;
382+ dst_prime. assert_dst ( DST_PRIME ) ;
382383
383384 const TEST_VECTORS_32 : & [ TestVector ] = & [
384385 TestVector {
0 commit comments