55use crate :: hasher:: { Hasher , Keccak256Hasher } ;
66use crate :: { MerkleError , Node , Store } ;
77use core:: ops:: Index ;
8- use std:: collections:: HashMap ;
8+ use std:: { collections:: HashMap , sync :: RwLock } ;
99
1010#[ cfg( feature = "memory_store" ) ]
1111use crate :: stores:: MemoryStore ;
1212
1313pub struct MerkleProof < const DEPTH : usize > {
14+ inner : RwLock < MerkleProofInner < DEPTH > > ,
15+ }
16+
17+ impl < const DEPTH : usize > MerkleProof < DEPTH > {
18+ /// Get a read lock on the proof data
19+ pub fn read (
20+ & self ,
21+ ) -> Result < std:: sync:: RwLockReadGuard < ' _ , MerkleProofInner < DEPTH > > , MerkleError > {
22+ self . inner . read ( ) . map_err ( |e| {
23+ MerkleError :: LockPoisoned ( format ! ( "Failed to acquire read lock on MerkleProof: {}" , e) )
24+ } )
25+ }
26+
27+ /// Get a write lock on the proof data
28+ pub fn write (
29+ & self ,
30+ ) -> Result < std:: sync:: RwLockWriteGuard < ' _ , MerkleProofInner < DEPTH > > , MerkleError > {
31+ self . inner . write ( ) . map_err ( |e| {
32+ MerkleError :: LockPoisoned ( format ! (
33+ "Failed to acquire write lock on MerkleProof: {}" ,
34+ e
35+ ) )
36+ } )
37+ }
38+ }
39+
40+ // Make MerkleProofInner public so the read/write methods can return it
41+ pub struct MerkleProofInner < const DEPTH : usize > {
1442 pub proof : [ Node ; DEPTH ] ,
1543 pub leaf : Node ,
1644 pub index : u64 ,
1745 pub root : Node ,
1846}
1947
2048pub struct MerkleTree < H , S , const DEPTH : usize >
49+ where
50+ H : Hasher ,
51+ S : Store ,
52+ {
53+ inner : RwLock < MerkleTreeInner < H , S , DEPTH > > ,
54+ }
55+
56+ struct MerkleTreeInner < H , S , const DEPTH : usize >
2157where
2258 H : Hasher ,
2359 S : Store ,
@@ -78,21 +114,27 @@ where
78114 last : hasher. hash ( & zero[ DEPTH - 1 ] , & zero[ DEPTH - 1 ] ) ,
79115 } ;
80116 Self {
81- hasher,
82- store,
83- zeros,
117+ inner : RwLock :: new ( MerkleTreeInner {
118+ hasher,
119+ store,
120+ zeros,
121+ } ) ,
84122 }
85123 }
86124
87- pub fn add_leaves ( & mut self , leaves : & [ Node ] ) -> Result < ( ) , MerkleError > {
125+ pub fn add_leaves ( & self , leaves : & [ Node ] ) -> Result < ( ) , MerkleError > {
88126 // Early return
89127 if leaves. is_empty ( ) {
90128 return Ok ( ( ) ) ;
91129 }
92130
131+ let mut inner = self . inner . write ( ) . map_err ( |e| {
132+ MerkleError :: LockPoisoned ( format ! ( "Failed to acquire write lock on MerkleTree: {}" , e) )
133+ } ) ?;
134+
93135 // Error if leaves do not fit in the tree
94136 // TODO: Avoid calculating this. Calculate it at init or do the shifting with the generic.
95- if self . store . get_num_leaves ( ) + leaves. len ( ) as u64 > ( 1 << DEPTH as u64 ) {
137+ if inner . store . get_num_leaves ( ) + leaves. len ( ) as u64 > ( 1 << DEPTH as u64 ) {
96138 return Err ( MerkleError :: TreeFull {
97139 depth : DEPTH as u32 ,
98140 capacity : 1 << DEPTH as u64 ,
@@ -107,7 +149,7 @@ where
107149 let mut cache: HashMap < ( u32 , u64 ) , Node > = HashMap :: new ( ) ;
108150
109151 for ( offset, leaf) in leaves. iter ( ) . enumerate ( ) {
110- let mut idx = self . store . get_num_leaves ( ) + offset as u64 ;
152+ let mut idx = inner . store . get_num_leaves ( ) + offset as u64 ;
111153 let mut h = * leaf;
112154
113155 // Store the leaf
@@ -132,7 +174,7 @@ where
132174
133175 // Batch-fetch the missing siblings and insert them in cache.
134176 if fetch_len != 0 {
135- let fetched = self . store . get (
177+ let fetched = inner . store . get (
136178 & levels_to_fetch[ ..fetch_len] ,
137179 & indices_to_fetch[ ..fetch_len] ,
138180 ) ?;
@@ -150,15 +192,15 @@ where
150192 let sib_hash = cache
151193 . get ( & ( level as u32 , sibling_idx) )
152194 . copied ( )
153- . unwrap_or ( self . zeros [ level] ) ;
195+ . unwrap_or ( inner . zeros [ level] ) ;
154196
155197 let ( left, right) = if idx & 1 == 1 {
156198 ( sib_hash, h)
157199 } else {
158200 ( h, sib_hash)
159201 } ;
160202
161- h = self . hasher . hash ( & left, & right) ;
203+ h = inner . hasher . hash ( & left, & right) ;
162204 idx >>= 1 ;
163205
164206 batch. push ( ( ( level + 1 ) as u32 , idx, h) ) ;
@@ -167,19 +209,22 @@ where
167209 }
168210
169211 // Update all values in a single batch
170- self . store . put ( & batch) ?;
212+ inner . store . put ( & batch) ?;
171213
172214 Ok ( ( ) )
173215 }
174216
175217 pub fn root ( & self ) -> Result < Node , MerkleError > {
176- Ok ( self
218+ let inner = self . inner . read ( ) . map_err ( |e| {
219+ MerkleError :: LockPoisoned ( format ! ( "Failed to acquire read lock on MerkleTree: {}" , e) )
220+ } ) ?;
221+ Ok ( inner
177222 . store
178223 . get ( & [ DEPTH as u32 ] , & [ 0 ] ) ?
179224 . into_iter ( )
180225 . next ( )
181226 . ok_or_else ( || MerkleError :: StoreError ( "root fetch returned empty vector" . into ( ) ) ) ?
182- . unwrap_or ( self . zeros [ DEPTH ] ) )
227+ . unwrap_or ( inner . zeros [ DEPTH ] ) )
183228 }
184229
185230 pub fn proof ( & self , leaf_idx : u64 ) -> Result < MerkleProof < DEPTH > , MerkleError > {
@@ -201,6 +246,10 @@ where
201246 } ) ;
202247 }
203248
249+ let inner = self . inner . read ( ) . map_err ( |e| {
250+ MerkleError :: LockPoisoned ( format ! ( "Failed to acquire read lock on MerkleTree: {}" , e) )
251+ } ) ?;
252+
204253 // Build level/index lists for siblings plus the leaf.
205254 // TODO: Can't do arithmetic here with DEPTH meaning there is no
206255 // easy way to put this in the stack. Unfortunately the array size
@@ -222,40 +271,66 @@ where
222271 indices. push ( leaf_idx) ;
223272
224273 // Batch fetch all requested nodes.
225- let fetched = self . store . get ( & levels, & indices) ?;
274+ let fetched = inner . store . get ( & levels, & indices) ?;
226275
227276 // The first DEPTH items are the siblings.
228277 let mut proof = [ Node :: ZERO ; DEPTH ] ;
229278 for ( d, opt) in fetched. iter ( ) . take ( DEPTH ) . enumerate ( ) {
230- proof[ d] = opt. unwrap_or ( self . zeros [ d] ) ;
279+ proof[ d] = opt. unwrap_or ( inner . zeros [ d] ) ;
231280 }
232281
233282 // The last item is the leaf itself.
234- let leaf_hash = fetched. last ( ) . copied ( ) . flatten ( ) . unwrap_or ( self . zeros [ 0 ] ) ;
283+ let leaf_hash = fetched. last ( ) . copied ( ) . flatten ( ) . unwrap_or ( inner. zeros [ 0 ] ) ;
284+
285+ // Release the lock before calling root() to avoid deadlock
286+ let root = {
287+ drop ( inner) ;
288+ self . root ( ) ?
289+ } ;
235290
236291 Ok ( MerkleProof {
237- proof,
238- leaf : leaf_hash,
239- index : leaf_idx,
240- root : self . root ( ) ?,
292+ inner : RwLock :: new ( MerkleProofInner {
293+ proof,
294+ leaf : leaf_hash,
295+ index : leaf_idx,
296+ root,
297+ } ) ,
241298 } )
242299 }
243300
244301 pub fn verify_proof ( & self , proof : & MerkleProof < DEPTH > ) -> Result < bool , MerkleError > {
245- let mut computed_hash = proof. leaf ;
246- for ( j, sibling_hash) in proof. proof . iter ( ) . enumerate ( ) {
247- let ( left, right) = if proof. index & ( 1 << j) == 0 {
302+ let proof_inner = proof. inner . read ( ) . map_err ( |e| {
303+ MerkleError :: LockPoisoned ( format ! ( "Failed to acquire read lock on MerkleProof: {}" , e) )
304+ } ) ?;
305+ let tree_inner = self . inner . read ( ) . map_err ( |e| {
306+ MerkleError :: LockPoisoned ( format ! ( "Failed to acquire read lock on MerkleTree: {}" , e) )
307+ } ) ?;
308+ let mut computed_hash = proof_inner. leaf ;
309+ let idx = proof_inner. index ;
310+ let root = proof_inner. root ;
311+ for ( j, sibling_hash) in proof_inner. proof . iter ( ) . enumerate ( ) {
312+ let ( left, right) = if idx & ( 1 << j) == 0 {
248313 ( computed_hash, * sibling_hash)
249314 } else {
250315 ( * sibling_hash, computed_hash)
251316 } ;
252- computed_hash = self . hasher . hash ( & left, & right) ;
317+ computed_hash = tree_inner . hasher . hash ( & left, & right) ;
253318 }
254- Ok ( computed_hash == proof . root )
319+ Ok ( computed_hash == root)
255320 }
256321
257- pub fn num_leaves ( & self ) -> u64 {
258- self . store . get_num_leaves ( )
322+ pub fn num_leaves ( & self ) -> Result < u64 , MerkleError > {
323+ Ok ( self
324+ . inner
325+ . read ( )
326+ . map_err ( |e| {
327+ MerkleError :: LockPoisoned ( format ! (
328+ "Failed to acquire read lock on MerkleTree: {}" ,
329+ e
330+ ) )
331+ } ) ?
332+ . store
333+ . get_num_leaves ( ) )
259334 }
260335}
261336
@@ -312,10 +387,14 @@ mod tests {
312387 to_node ! ( "0x27ae5ba08d7291c96c8cbddcc148bf48a6d68c7974b94356f53754ef6171d757" ) ,
313388 ] ;
314389
315- for ( i, zero) in tree. zeros . front . iter ( ) . enumerate ( ) {
390+ let inner = tree
391+ . inner
392+ . read ( )
393+ . expect ( "Lock should not be poisoned in test" ) ;
394+ for ( i, zero) in inner. zeros . front . iter ( ) . enumerate ( ) {
316395 assert_eq ! ( zero, & expected_zeros[ i] ) ;
317396 }
318- assert_eq ! ( tree . zeros. last, expected_zeros[ 32 ] ) ;
397+ assert_eq ! ( inner . zeros. last, expected_zeros[ 32 ] ) ;
319398 }
320399
321400 #[ cfg( feature = "memory_store" ) ]
@@ -366,18 +445,22 @@ mod tests {
366445 to_node ! ( "0x2f68a1c58e257e42a17a6c61dff5551ed560b9922ab119d5ac8e184c9734ead9" ) ,
367446 ] ;
368447
369- for ( i, zero) in tree. zeros . front . iter ( ) . enumerate ( ) {
448+ let inner = tree
449+ . inner
450+ . read ( )
451+ . expect ( "Lock should not be poisoned in test" ) ;
452+ for ( i, zero) in inner. zeros . front . iter ( ) . enumerate ( ) {
370453 assert_eq ! ( zero, & expected_zeros[ i] ) ;
371454 }
372- assert_eq ! ( tree . zeros. last, expected_zeros[ 32 ] ) ;
455+ assert_eq ! ( inner . zeros. last, expected_zeros[ 32 ] ) ;
373456 }
374457
375458 #[ cfg( feature = "memory_store" ) ]
376459 #[ test]
377460 fn test_tree_full_error ( ) {
378461 let hasher = Keccak256Hasher ;
379462 let store = MemoryStore :: default ( ) ;
380- let mut tree = MerkleTree :: < Keccak256Hasher , MemoryStore , 3 > :: new ( hasher, store) ;
463+ let tree = MerkleTree :: < Keccak256Hasher , MemoryStore , 3 > :: new ( hasher, store) ;
381464
382465 tree. add_leaves ( & ( 0 ..8 ) . map ( |_| Node :: ZERO ) . collect :: < Vec < Node > > ( ) )
383466 . unwrap ( ) ;
0 commit comments