1717//!
1818
1919use core:: ptr;
20- use core:: ops:: Deref ;
20+ use core:: ops:: { FnMut , Deref } ;
2121
2222use key:: { SecretKey , PublicKey } ;
2323use ffi:: { self , CPtr } ;
24+ use secp256k1_sys:: types:: { c_int, c_uchar, c_void} ;
2425
2526/// A tag used for recovering the public key from a compact signature
2627#[ derive( Copy , Clone ) ]
@@ -68,7 +69,7 @@ impl SharedSecret {
6869
6970impl PartialEq for SharedSecret {
7071 fn eq ( & self , other : & SharedSecret ) -> bool {
71- & self . data [ .. self . len ] == & other. data [ ..other . len ]
72+ self . as_ref ( ) == other. as_ref ( )
7273 }
7374}
7475
@@ -86,6 +87,22 @@ impl Deref for SharedSecret {
8687}
8788
8889
90+ unsafe extern "C" fn hash_callback < F > ( output : * mut c_uchar , x : * const c_uchar , y : * const c_uchar , data : * mut c_void ) -> c_int
91+ where F : FnMut ( [ u8 ; 32 ] , [ u8 ; 32 ] ) -> SharedSecret {
92+ let callback: & mut F = & mut * ( data as * mut F ) ;
93+
94+ let mut x_arr = [ 0 ; 32 ] ;
95+ let mut y_arr = [ 0 ; 32 ] ;
96+ ptr:: copy_nonoverlapping ( x, x_arr. as_mut_ptr ( ) , 32 ) ;
97+ ptr:: copy_nonoverlapping ( y, y_arr. as_mut_ptr ( ) , 32 ) ;
98+
99+ let secret = callback ( x_arr, y_arr) ;
100+ ptr:: copy_nonoverlapping ( secret. as_ptr ( ) , output as * mut u8 , secret. len ( ) ) ;
101+
102+ secret. len ( ) as c_int
103+ }
104+
105+
89106impl SharedSecret {
90107 /// Creates a new shared secret from a pubkey and secret key
91108 #[ inline]
@@ -105,6 +122,44 @@ impl SharedSecret {
105122 ss. set_len ( 32 ) ; // The default hash function is SHA256, which is 32 bytes long.
106123 ss
107124 }
125+
126+ /// Creates a new shared secret from a pubkey and secret key with applied custom hash function
127+ /// # Examples
128+ /// ```
129+ /// # use secp256k1::ecdh::SharedSecret;
130+ /// # use secp256k1::{Secp256k1, PublicKey, SecretKey};
131+ /// # fn sha2(_a: &[u8], _b: &[u8]) -> [u8; 32] {[0u8; 32]}
132+ /// # let secp = Secp256k1::signing_only();
133+ /// # let secret_key = SecretKey::from_slice(&[3u8; 32]).unwrap();
134+ /// # let secret_key2 = SecretKey::from_slice(&[7u8; 32]).unwrap();
135+ /// # let public_key = PublicKey::from_secret_key(&secp, &secret_key2);
136+ ///
137+ /// let secret = SharedSecret::new_with_hash(&public_key, &secret_key, |x,y| {
138+ /// let hash: [u8; 32] = sha2(&x,&y);
139+ /// hash.into()
140+ /// });
141+ ///
142+ /// ```
143+ pub fn new_with_hash < F > ( point : & PublicKey , scalar : & SecretKey , mut hash_function : F ) -> SharedSecret
144+ where F : FnMut ( [ u8 ; 32 ] , [ u8 ; 32 ] ) -> SharedSecret
145+ {
146+ let mut ss = SharedSecret :: empty ( ) ;
147+ let hashfp: ffi:: EcdhHashFn = hash_callback :: < F > ;
148+
149+ let res = unsafe {
150+ ffi:: secp256k1_ecdh (
151+ ffi:: secp256k1_context_no_precomp,
152+ ss. get_data_mut_ptr ( ) ,
153+ point. as_ptr ( ) ,
154+ scalar. as_ptr ( ) ,
155+ hashfp,
156+ & mut hash_function as * mut F as * mut c_void ,
157+ )
158+ } ;
159+ debug_assert ! ( res >= 16 ) ; // 128 bit is the minimum for a secure hash function and the minimum we let users.
160+ ss. set_len ( res as usize ) ;
161+ ss
162+ }
108163}
109164
110165#[ cfg( test) ]
0 commit comments