@@ -22,6 +22,7 @@ use core::ops::{FnMut, Deref};
2222use key:: { SecretKey , PublicKey } ;
2323use ffi:: { self , CPtr } ;
2424use secp256k1_sys:: types:: { c_int, c_uchar, c_void} ;
25+ use Error ;
2526
2627/// A tag used for recovering the public key from a compact signature
2728#[ derive( Copy , Clone ) ]
@@ -63,6 +64,7 @@ impl SharedSecret {
6364
6465 /// Set the length of the object.
6566 pub ( crate ) fn set_len ( & mut self , len : usize ) {
67+ debug_assert ! ( len <= self . data. len( ) ) ;
6668 self . len = len;
6769 }
6870}
@@ -87,19 +89,29 @@ impl Deref for SharedSecret {
8789}
8890
8991
92+ #[ cfg( feature = "std" ) ]
9093unsafe extern "C" fn hash_callback < F > ( output : * mut c_uchar , x : * const c_uchar , y : * const c_uchar , data : * mut c_void ) -> c_int
9194 where F : FnMut ( [ u8 ; 32 ] , [ u8 ; 32 ] ) -> SharedSecret {
92- let callback: & mut F = & mut * ( data as * mut F ) ;
9395
94- let mut x_arr = [ 0 ; 32 ] ;
95- let mut y_arr = [ 0 ; 32 ] ;
96- ptr:: copy_nonoverlapping ( x, x_arr. as_mut_ptr ( ) , 32 ) ;
97- ptr:: copy_nonoverlapping ( y, y_arr. as_mut_ptr ( ) , 32 ) ;
96+ use std:: panic:: catch_unwind;
97+ let res = catch_unwind ( || {
98+ let callback: & mut F = & mut * ( data as * mut F ) ;
9899
99- let secret = callback ( x_arr, y_arr) ;
100- ptr:: copy_nonoverlapping ( secret. as_ptr ( ) , output as * mut u8 , secret. len ( ) ) ;
100+ let mut x_arr = [ 0 ; 32 ] ;
101+ let mut y_arr = [ 0 ; 32 ] ;
102+ ptr:: copy_nonoverlapping ( x, x_arr. as_mut_ptr ( ) , 32 ) ;
103+ ptr:: copy_nonoverlapping ( y, y_arr. as_mut_ptr ( ) , 32 ) ;
101104
102- secret. len ( ) as c_int
105+ let secret = callback ( x_arr, y_arr) ;
106+ ptr:: copy_nonoverlapping ( secret. as_ptr ( ) , output as * mut u8 , secret. len ( ) ) ;
107+
108+ secret. len ( ) as c_int
109+ } ) ;
110+ if let Ok ( len) = res {
111+ len
112+ } else {
113+ -1
114+ }
103115}
104116
105117
@@ -140,7 +152,8 @@ impl SharedSecret {
140152 /// });
141153 ///
142154 /// ```
143- pub fn new_with_hash < F > ( point : & PublicKey , scalar : & SecretKey , mut hash_function : F ) -> SharedSecret
155+ #[ cfg( feature = "std" ) ]
156+ pub fn new_with_hash < F > ( point : & PublicKey , scalar : & SecretKey , mut hash_function : F ) -> Result < SharedSecret , Error >
144157 where F : FnMut ( [ u8 ; 32 ] , [ u8 ; 32 ] ) -> SharedSecret
145158 {
146159 let mut ss = SharedSecret :: empty ( ) ;
@@ -156,9 +169,12 @@ impl SharedSecret {
156169 & mut hash_function as * mut F as * mut c_void ,
157170 )
158171 } ;
172+ if res == -1 {
173+ return Err ( Error :: CallbackPanicked ) ;
174+ }
159175 debug_assert ! ( res >= 16 ) ; // 128 bit is the minimum for a secure hash function and the minimum we let users.
160176 ss. set_len ( res as usize ) ;
161- ss
177+ Ok ( ss )
162178 }
163179}
164180
@@ -167,6 +183,7 @@ mod tests {
167183 use rand:: thread_rng;
168184 use super :: SharedSecret ;
169185 use super :: super :: Secp256k1 ;
186+ use Error ;
170187
171188 #[ test]
172189 fn ecdh ( ) {
@@ -187,9 +204,9 @@ mod tests {
187204 let ( sk1, pk1) = s. generate_keypair ( & mut thread_rng ( ) ) ;
188205 let ( sk2, pk2) = s. generate_keypair ( & mut thread_rng ( ) ) ;
189206
190- let sec1 = SharedSecret :: new_with_hash ( & pk1, & sk2, |x, _| x. into ( ) ) ;
191- let sec2 = SharedSecret :: new_with_hash ( & pk2, & sk1, |x, _| x. into ( ) ) ;
192- let sec_odd = SharedSecret :: new_with_hash ( & pk1, & sk1, |x, _| x. into ( ) ) ;
207+ let sec1 = SharedSecret :: new_with_hash ( & pk1, & sk2, |x, _| x. into ( ) ) . unwrap ( ) ;
208+ let sec2 = SharedSecret :: new_with_hash ( & pk2, & sk1, |x, _| x. into ( ) ) . unwrap ( ) ;
209+ let sec_odd = SharedSecret :: new_with_hash ( & pk1, & sk1, |x, _| x. into ( ) ) . unwrap ( ) ;
193210 assert_eq ! ( sec1, sec2) ;
194211 assert_ne ! ( sec_odd, sec2) ;
195212 }
@@ -205,11 +222,23 @@ mod tests {
205222 x_out = x;
206223 y_out = y;
207224 expect_result. into ( )
208- } ) ;
225+ } ) . unwrap ( ) ;
209226 assert_eq ! ( & expect_result[ ..] , & result[ ..] ) ;
210227 assert_ne ! ( x_out, [ 0u8 ; 32 ] ) ;
211228 assert_ne ! ( y_out, [ 0u8 ; 32 ] ) ;
212229 }
230+
231+ #[ test]
232+ fn ecdh_with_hash_callback_panic ( ) {
233+ let s = Secp256k1 :: signing_only ( ) ;
234+ let ( sk1, pk1) = s. generate_keypair ( & mut thread_rng ( ) ) ;
235+ let mut res = [ 0u8 ; 48 ] ;
236+ let result = SharedSecret :: new_with_hash ( & pk1, & sk1, | x, _ | {
237+ res. copy_from_slice ( & x) ; // res.len() != x.len(). this will panic.
238+ res. into ( )
239+ } ) ;
240+ assert_eq ! ( result, Err ( Error :: CallbackPanicked ) ) ;
241+ }
213242}
214243
215244#[ cfg( all( test, feature = "unstable" ) ) ]
0 commit comments