@@ -89,31 +89,38 @@ impl Deref for SharedSecret {
8989}
9090
9191
92- #[ cfg( feature = "std" ) ]
93- unsafe extern "C" fn hash_callback < F > ( output : * mut c_uchar , x : * const c_uchar , y : * const c_uchar , data : * mut c_void ) -> c_int
92+ unsafe fn callback_logic < F > ( output : * mut c_uchar , x : * const c_uchar , y : * const c_uchar , data : * mut c_void ) -> c_int
9493 where F : FnMut ( [ u8 ; 32 ] , [ u8 ; 32 ] ) -> SharedSecret {
94+ let callback: & mut F = & mut * ( data as * mut F ) ;
95+
96+ let mut x_arr = [ 0 ; 32 ] ;
97+ let mut y_arr = [ 0 ; 32 ] ;
98+ ptr:: copy_nonoverlapping ( x, x_arr. as_mut_ptr ( ) , 32 ) ;
99+ ptr:: copy_nonoverlapping ( y, y_arr. as_mut_ptr ( ) , 32 ) ;
95100
96- use std:: panic:: catch_unwind;
97- let res = catch_unwind ( || {
98- let callback: & mut F = & mut * ( data as * mut F ) ;
101+ let secret = callback ( x_arr, y_arr) ;
102+ ptr:: copy_nonoverlapping ( secret. as_ptr ( ) , output as * mut u8 , secret. len ( ) ) ;
99103
100- let mut x_arr = [ 0 ; 32 ] ;
101- let mut y_arr = [ 0 ; 32 ] ;
102- ptr:: copy_nonoverlapping ( x, x_arr. as_mut_ptr ( ) , 32 ) ;
103- ptr:: copy_nonoverlapping ( y, y_arr. as_mut_ptr ( ) , 32 ) ;
104+ secret. len ( ) as c_int
105+ }
104106
105- let secret = callback ( x_arr, y_arr) ;
106- ptr:: copy_nonoverlapping ( secret. as_ptr ( ) , output as * mut u8 , secret. len ( ) ) ;
107+ #[ cfg( feature = "std" ) ]
108+ unsafe extern "C" fn hash_callback_catch_unwind < F > ( output : * mut c_uchar , x : * const c_uchar , y : * const c_uchar , data : * mut c_void ) -> c_int
109+ where F : FnMut ( [ u8 ; 32 ] , [ u8 ; 32 ] ) -> SharedSecret {
107110
108- secret. len ( ) as c_int
109- } ) ;
111+ let res = :: std:: panic:: catch_unwind ( ||callback_logic :: < F > ( output, x, y, data) ) ;
110112 if let Ok ( len) = res {
111113 len
112114 } else {
113115 -1
114116 }
115117}
116118
119+ unsafe extern "C" fn hash_callback_unsafe < F > ( output : * mut c_uchar , x : * const c_uchar , y : * const c_uchar , data : * mut c_void ) -> c_int
120+ where F : FnMut ( [ u8 ; 32 ] , [ u8 ; 32 ] ) -> SharedSecret {
121+ callback_logic :: < F > ( output, x, y, data)
122+ }
123+
117124
118125impl SharedSecret {
119126 /// Creates a new shared secret from a pubkey and secret key
@@ -135,6 +142,29 @@ impl SharedSecret {
135142 ss
136143 }
137144
145+ fn new_with_callback_internal < F > ( point : & PublicKey , scalar : & SecretKey , mut closure : F , callback : ffi:: EcdhHashFn ) -> Result < SharedSecret , Error >
146+ where F : FnMut ( [ u8 ; 32 ] , [ u8 ; 32 ] ) -> SharedSecret {
147+ let mut ss = SharedSecret :: empty ( ) ;
148+
149+ let res = unsafe {
150+ ffi:: secp256k1_ecdh (
151+ ffi:: secp256k1_context_no_precomp,
152+ ss. get_data_mut_ptr ( ) ,
153+ point. as_ptr ( ) ,
154+ scalar. as_ptr ( ) ,
155+ callback,
156+ & mut closure as * mut F as * mut c_void ,
157+ )
158+ } ;
159+ if res == -1 {
160+ return Err ( Error :: CallbackPanicked ) ;
161+ }
162+ debug_assert ! ( res >= 16 ) ; // 128 bit is the minimum for a secure hash function and the minimum we let users.
163+ ss. set_len ( res as usize ) ;
164+ Ok ( ss)
165+
166+ }
167+
138168 /// Creates a new shared secret from a pubkey and secret key with applied custom hash function
139169 /// # Examples
140170 /// ```
@@ -153,28 +183,42 @@ impl SharedSecret {
153183 ///
154184 /// ```
155185 #[ cfg( feature = "std" ) ]
156- pub fn new_with_hash < F > ( point : & PublicKey , scalar : & SecretKey , mut hash_function : F ) -> Result < SharedSecret , Error >
157- where F : FnMut ( [ u8 ; 32 ] , [ u8 ; 32 ] ) -> SharedSecret
158- {
159- let mut ss = SharedSecret :: empty ( ) ;
160- let hashfp: ffi:: EcdhHashFn = hash_callback :: < F > ;
186+ pub fn new_with_hash < F > ( point : & PublicKey , scalar : & SecretKey , hash_function : F ) -> Result < SharedSecret , Error >
187+ where F : FnMut ( [ u8 ; 32 ] , [ u8 ; 32 ] ) -> SharedSecret {
188+ Self :: new_with_callback_internal ( point, scalar, hash_function, hash_callback_catch_unwind :: < F > )
189+ }
161190
162- let res = unsafe {
163- ffi:: secp256k1_ecdh (
164- ffi:: secp256k1_context_no_precomp,
165- ss. get_data_mut_ptr ( ) ,
166- point. as_ptr ( ) ,
167- scalar. as_ptr ( ) ,
168- hashfp,
169- & mut hash_function as * mut F as * mut c_void ,
170- )
171- } ;
172- if res == -1 {
173- return Err ( Error :: CallbackPanicked ) ;
174- }
175- debug_assert ! ( res >= 16 ) ; // 128 bit is the minimum for a secure hash function and the minimum we let users.
176- ss. set_len ( res as usize ) ;
177- Ok ( ss)
191+ /// Creates a new shared secret from a pubkey and secret key with applied custom hash function
192+ /// Note that this function is the same as [`new_with_hash`]
193+ ///
194+ /// # Safety
195+ /// The function doesn't wrap the callback with [`catch_unwind`]
196+ /// so if the callback panics it will panic through an FFI boundray which is [`Undefined Behavior`]
197+ /// If possible you should use [`new_with_hash`] which does wrap the callback with [`catch_unwind`] so is safe to use.
198+ ///
199+ /// [`catch_unwind`]: https://doc.rust-lang.org/std/panic/fn.catch_unwind.html
200+ /// [`Undefined Behavior`]: https://doc.rust-lang.org/nomicon/ffi.html#ffi-and-panics
201+ /// [`new_with_hash`]: #method.new_with_hash
202+ /// # Examples
203+ /// ```
204+ /// # use secp256k1::ecdh::SharedSecret;
205+ /// # use secp256k1::{Secp256k1, PublicKey, SecretKey};
206+ /// # fn sha2(_a: &[u8], _b: &[u8]) -> [u8; 32] {[0u8; 32]}
207+ /// # let secp = Secp256k1::signing_only();
208+ /// # let secret_key = SecretKey::from_slice(&[3u8; 32]).unwrap();
209+ /// # let secret_key2 = SecretKey::from_slice(&[7u8; 32]).unwrap();
210+ /// # let public_key = PublicKey::from_secret_key(&secp, &secret_key2);
211+ //
212+ /// let secret = unsafe { SharedSecret::new_with_hash_no_panic(&public_key, &secret_key, |x,y| {
213+ /// let hash: [u8; 32] = sha2(&x,&y);
214+ /// hash.into()
215+ /// })};
216+ ///
217+ ///
218+ /// ```
219+ pub unsafe fn new_with_hash_no_panic < F > ( point : & PublicKey , scalar : & SecretKey , hash_function : F ) -> Result < SharedSecret , Error >
220+ where F : FnMut ( [ u8 ; 32 ] , [ u8 ; 32 ] ) -> SharedSecret {
221+ Self :: new_with_callback_internal ( point, scalar, hash_function, hash_callback_unsafe :: < F > )
178222 }
179223}
180224
@@ -223,7 +267,13 @@ mod tests {
223267 y_out = y;
224268 expect_result. into ( )
225269 } ) . unwrap ( ) ;
270+ let result_unsafe = unsafe { SharedSecret :: new_with_hash_no_panic ( & pk1, & sk1, | x, y | {
271+ x_out = x;
272+ y_out = y;
273+ expect_result. into ( )
274+ } ) . unwrap ( ) } ;
226275 assert_eq ! ( & expect_result[ ..] , & result[ ..] ) ;
276+ assert_eq ! ( result, result_unsafe) ;
227277 assert_ne ! ( x_out, [ 0u8 ; 32 ] ) ;
228278 assert_ne ! ( y_out, [ 0u8 ; 32 ] ) ;
229279 }
0 commit comments