55use crate :: hasher:: { Hasher , Keccak256Hasher } ;
66use crate :: { MerkleError , Node , Store } ;
77use core:: ops:: Index ;
8- use std:: collections:: HashMap ;
8+ use std:: { collections:: HashMap , sync :: RwLock } ;
99
1010#[ cfg( feature = "memory_store" ) ]
1111use crate :: stores:: MemoryStore ;
1212
1313pub struct MerkleProof < const DEPTH : usize > {
14+ inner : RwLock < MerkleProofInner < DEPTH > > ,
15+ }
16+
17+ impl < const DEPTH : usize > MerkleProof < DEPTH > {
18+ /// Get a read lock on the proof data
19+ pub fn read ( & self ) -> Result < std:: sync:: RwLockReadGuard < ' _ , MerkleProofInner < DEPTH > > , MerkleError > {
20+ self . inner . read ( ) . map_err ( |e| MerkleError :: LockPoisoned ( format ! ( "Failed to acquire read lock on MerkleProof: {}" , e) ) )
21+ }
22+
23+ /// Get a write lock on the proof data
24+ pub fn write ( & self ) -> Result < std:: sync:: RwLockWriteGuard < ' _ , MerkleProofInner < DEPTH > > , MerkleError > {
25+ self . inner . write ( ) . map_err ( |e| MerkleError :: LockPoisoned ( format ! ( "Failed to acquire write lock on MerkleProof: {}" , e) ) )
26+ }
27+ }
28+
29+ // Make MerkleProofInner public so the read/write methods can return it
30+ pub struct MerkleProofInner < const DEPTH : usize > {
1431 pub proof : [ Node ; DEPTH ] ,
1532 pub leaf : Node ,
1633 pub index : u64 ,
1734 pub root : Node ,
1835}
1936
2037pub struct MerkleTree < H , S , const DEPTH : usize >
38+ where
39+ H : Hasher ,
40+ S : Store ,
41+ {
42+ inner : RwLock < MerkleTreeInner < H , S , DEPTH > > ,
43+ }
44+
45+ struct MerkleTreeInner < H , S , const DEPTH : usize >
2146where
2247 H : Hasher ,
2348 S : Store ,
@@ -78,21 +103,26 @@ where
78103 last : hasher. hash ( & zero[ DEPTH - 1 ] , & zero[ DEPTH - 1 ] ) ,
79104 } ;
80105 Self {
81- hasher,
82- store,
83- zeros,
106+ inner : RwLock :: new ( MerkleTreeInner {
107+ hasher,
108+ store,
109+ zeros,
110+ } ) ,
84111 }
85112 }
86113
87- pub fn add_leaves ( & mut self , leaves : & [ Node ] ) -> Result < ( ) , MerkleError > {
114+ pub fn add_leaves ( & self , leaves : & [ Node ] ) -> Result < ( ) , MerkleError > {
88115 // Early return
89116 if leaves. is_empty ( ) {
90117 return Ok ( ( ) ) ;
91118 }
92119
120+ let mut inner = self . inner . write ( )
121+ . map_err ( |e| MerkleError :: LockPoisoned ( format ! ( "Failed to acquire write lock on MerkleTree: {}" , e) ) ) ?;
122+
93123 // Error if leaves do not fit in the tree
94124 // TODO: Avoid calculating this. Calculate it at init or do the shifting with the generic.
95- if self . store . get_num_leaves ( ) + leaves. len ( ) as u64 > ( 1 << DEPTH as u64 ) {
125+ if inner . store . get_num_leaves ( ) + leaves. len ( ) as u64 > ( 1 << DEPTH as u64 ) {
96126 return Err ( MerkleError :: TreeFull {
97127 depth : DEPTH as u32 ,
98128 capacity : 1 << DEPTH as u64 ,
@@ -107,7 +137,7 @@ where
107137 let mut cache: HashMap < ( u32 , u64 ) , Node > = HashMap :: new ( ) ;
108138
109139 for ( offset, leaf) in leaves. iter ( ) . enumerate ( ) {
110- let mut idx = self . store . get_num_leaves ( ) + offset as u64 ;
140+ let mut idx = inner . store . get_num_leaves ( ) + offset as u64 ;
111141 let mut h = * leaf;
112142
113143 // Store the leaf
@@ -132,7 +162,7 @@ where
132162
133163 // Batch-fetch the missing siblings and insert them in cache.
134164 if fetch_len != 0 {
135- let fetched = self . store . get (
165+ let fetched = inner . store . get (
136166 & levels_to_fetch[ ..fetch_len] ,
137167 & indices_to_fetch[ ..fetch_len] ,
138168 ) ?;
@@ -150,15 +180,15 @@ where
150180 let sib_hash = cache
151181 . get ( & ( level as u32 , sibling_idx) )
152182 . copied ( )
153- . unwrap_or ( self . zeros [ level] ) ;
183+ . unwrap_or ( inner . zeros [ level] ) ;
154184
155185 let ( left, right) = if idx & 1 == 1 {
156186 ( sib_hash, h)
157187 } else {
158188 ( h, sib_hash)
159189 } ;
160190
161- h = self . hasher . hash ( & left, & right) ;
191+ h = inner . hasher . hash ( & left, & right) ;
162192 idx >>= 1 ;
163193
164194 batch. push ( ( ( level + 1 ) as u32 , idx, h) ) ;
@@ -167,19 +197,21 @@ where
167197 }
168198
169199 // Update all values in a single batch
170- self . store . put ( & batch) ?;
200+ inner . store . put ( & batch) ?;
171201
172202 Ok ( ( ) )
173203 }
174204
175205 pub fn root ( & self ) -> Result < Node , MerkleError > {
176- Ok ( self
206+ let inner = self . inner . read ( )
207+ . map_err ( |e| MerkleError :: LockPoisoned ( format ! ( "Failed to acquire read lock on MerkleTree: {}" , e) ) ) ?;
208+ Ok ( inner
177209 . store
178210 . get ( & [ DEPTH as u32 ] , & [ 0 ] ) ?
179211 . into_iter ( )
180212 . next ( )
181213 . ok_or_else ( || MerkleError :: StoreError ( "root fetch returned empty vector" . into ( ) ) ) ?
182- . unwrap_or ( self . zeros [ DEPTH ] ) )
214+ . unwrap_or ( inner . zeros [ DEPTH ] ) )
183215 }
184216
185217 pub fn proof ( & self , leaf_idx : u64 ) -> Result < MerkleProof < DEPTH > , MerkleError > {
@@ -201,6 +233,9 @@ where
201233 } ) ;
202234 }
203235
236+ let inner = self . inner . read ( )
237+ . map_err ( |e| MerkleError :: LockPoisoned ( format ! ( "Failed to acquire read lock on MerkleTree: {}" , e) ) ) ?;
238+
204239 // Build level/index lists for siblings plus the leaf.
205240 // TODO: Can't do arithmetic here with DEPTH meaning there is no
206241 // easy way to put this in the stack. Unfortunately the array size
@@ -222,40 +257,56 @@ where
222257 indices. push ( leaf_idx) ;
223258
224259 // Batch fetch all requested nodes.
225- let fetched = self . store . get ( & levels, & indices) ?;
260+ let fetched = inner . store . get ( & levels, & indices) ?;
226261
227262 // The first DEPTH items are the siblings.
228263 let mut proof = [ Node :: ZERO ; DEPTH ] ;
229264 for ( d, opt) in fetched. iter ( ) . take ( DEPTH ) . enumerate ( ) {
230- proof[ d] = opt. unwrap_or ( self . zeros [ d] ) ;
265+ proof[ d] = opt. unwrap_or ( inner . zeros [ d] ) ;
231266 }
232267
233268 // The last item is the leaf itself.
234- let leaf_hash = fetched. last ( ) . copied ( ) . flatten ( ) . unwrap_or ( self . zeros [ 0 ] ) ;
269+ let leaf_hash = fetched. last ( ) . copied ( ) . flatten ( ) . unwrap_or ( inner. zeros [ 0 ] ) ;
270+
271+ // Release the lock before calling root() to avoid deadlock
272+ let root = {
273+ drop ( inner) ;
274+ self . root ( ) ?
275+ } ;
235276
236277 Ok ( MerkleProof {
237- proof,
238- leaf : leaf_hash,
239- index : leaf_idx,
240- root : self . root ( ) ?,
278+ inner : RwLock :: new ( MerkleProofInner {
279+ proof,
280+ leaf : leaf_hash,
281+ index : leaf_idx,
282+ root,
283+ } ) ,
241284 } )
242285 }
243286
244287 pub fn verify_proof ( & self , proof : & MerkleProof < DEPTH > ) -> Result < bool , MerkleError > {
245- let mut computed_hash = proof. leaf ;
246- for ( j, sibling_hash) in proof. proof . iter ( ) . enumerate ( ) {
247- let ( left, right) = if proof. index & ( 1 << j) == 0 {
288+ let proof_inner = proof. inner . read ( )
289+ . map_err ( |e| MerkleError :: LockPoisoned ( format ! ( "Failed to acquire read lock on MerkleProof: {}" , e) ) ) ?;
290+ let tree_inner = self . inner . read ( )
291+ . map_err ( |e| MerkleError :: LockPoisoned ( format ! ( "Failed to acquire read lock on MerkleTree: {}" , e) ) ) ?;
292+ let mut computed_hash = proof_inner. leaf ;
293+ let idx = proof_inner. index ;
294+ let root = proof_inner. root ;
295+ for ( j, sibling_hash) in proof_inner. proof . iter ( ) . enumerate ( ) {
296+ let ( left, right) = if idx & ( 1 << j) == 0 {
248297 ( computed_hash, * sibling_hash)
249298 } else {
250299 ( * sibling_hash, computed_hash)
251300 } ;
252- computed_hash = self . hasher . hash ( & left, & right) ;
301+ computed_hash = tree_inner . hasher . hash ( & left, & right) ;
253302 }
254- Ok ( computed_hash == proof . root )
303+ Ok ( computed_hash == root)
255304 }
256305
257- pub fn num_leaves ( & self ) -> u64 {
258- self . store . get_num_leaves ( )
306+ pub fn num_leaves ( & self ) -> Result < u64 , MerkleError > {
307+ Ok ( self . inner . read ( )
308+ . map_err ( |e| MerkleError :: LockPoisoned ( format ! ( "Failed to acquire read lock on MerkleTree: {}" , e) ) ) ?
309+ . store . get_num_leaves ( ) )
259310 }
260311}
261312
@@ -312,10 +363,12 @@ mod tests {
312363 to_node ! ( "0x27ae5ba08d7291c96c8cbddcc148bf48a6d68c7974b94356f53754ef6171d757" ) ,
313364 ] ;
314365
315- for ( i, zero) in tree. zeros . front . iter ( ) . enumerate ( ) {
366+ let inner = tree. inner . read ( )
367+ . expect ( "Lock should not be poisoned in test" ) ;
368+ for ( i, zero) in inner. zeros . front . iter ( ) . enumerate ( ) {
316369 assert_eq ! ( zero, & expected_zeros[ i] ) ;
317370 }
318- assert_eq ! ( tree . zeros. last, expected_zeros[ 32 ] ) ;
371+ assert_eq ! ( inner . zeros. last, expected_zeros[ 32 ] ) ;
319372 }
320373
321374 #[ cfg( feature = "memory_store" ) ]
@@ -366,18 +419,20 @@ mod tests {
366419 to_node ! ( "0x2f68a1c58e257e42a17a6c61dff5551ed560b9922ab119d5ac8e184c9734ead9" ) ,
367420 ] ;
368421
369- for ( i, zero) in tree. zeros . front . iter ( ) . enumerate ( ) {
422+ let inner = tree. inner . read ( )
423+ . expect ( "Lock should not be poisoned in test" ) ;
424+ for ( i, zero) in inner. zeros . front . iter ( ) . enumerate ( ) {
370425 assert_eq ! ( zero, & expected_zeros[ i] ) ;
371426 }
372- assert_eq ! ( tree . zeros. last, expected_zeros[ 32 ] ) ;
427+ assert_eq ! ( inner . zeros. last, expected_zeros[ 32 ] ) ;
373428 }
374429
375430 #[ cfg( feature = "memory_store" ) ]
376431 #[ test]
377432 fn test_tree_full_error ( ) {
378433 let hasher = Keccak256Hasher ;
379434 let store = MemoryStore :: default ( ) ;
380- let mut tree = MerkleTree :: < Keccak256Hasher , MemoryStore , 3 > :: new ( hasher, store) ;
435+ let tree = MerkleTree :: < Keccak256Hasher , MemoryStore , 3 > :: new ( hasher, store) ;
381436
382437 tree. add_leaves ( & ( 0 ..8 ) . map ( |_| Node :: ZERO ) . collect :: < Vec < Node > > ( ) )
383438 . unwrap ( ) ;
0 commit comments