Skip to content

Commit 4ffcb7d

Browse files
committed
Async ready lib
1 parent 85c496b commit 4ffcb7d

File tree

5 files changed

+121
-49
lines changed

5 files changed

+121
-49
lines changed

src/errors.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,7 @@ pub enum MerkleError {
1111

1212
#[error("Levels and indices must have the same length")]
1313
LengthMismatch { levels: usize, indices: usize },
14+
15+
#[error("Lock was poisoned: {0}")]
16+
LockPoisoned(String),
1417
}

src/lib.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@ fn main() {
3131
.unwrap();
3232
3333
println!("root: {:?}", tree.root().unwrap());
34-
println!("num leaves: {:?}", tree.num_leaves());
35-
println!("proof: {:?}", tree.proof(0).unwrap().proof);
34+
println!("num leaves: {:?}", tree.num_leaves().unwrap());
35+
println!("proof: {:?}", tree.proof(0).unwrap().read().unwrap().proof);
3636
}
3737
```
3838

src/stores/memory_store.rs

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,16 @@
33
//! Simple in-memory store implementation.
44
55
use crate::{MerkleError, Node, Store};
6-
use std::collections::HashMap;
6+
use std::{collections::HashMap, sync::RwLock};
77

88
/// Simple in-memory store implementation using a `HashMap`.
99
#[derive(Default)]
1010
pub struct MemoryStore {
11+
inner: RwLock<MemoryStoreInner>,
12+
}
13+
14+
#[derive(Default)]
15+
struct MemoryStoreInner {
1116
store: HashMap<(u32, u64), Node>,
1217
num_leaves: u64,
1318
}
@@ -26,27 +31,36 @@ impl Store for MemoryStore {
2631
indices: indices.len(),
2732
});
2833
}
29-
30-
// The memory store doesnt really allow batch reads, so just get all the
31-
// indexes/levels one by one.
34+
let inner = self.inner.read().map_err(|e| {
35+
MerkleError::LockPoisoned(format!("Failed to acquire read lock on MemoryStore: {}", e))
36+
})?;
3237
let result = levels
3338
.iter()
3439
.zip(indices)
35-
.map(|(&lvl, &idx)| self.store.get(&(lvl, idx)).cloned())
40+
.map(|(&lvl, &idx)| inner.store.get(&(lvl, idx)).cloned())
3641
.collect();
37-
3842
Ok(result)
3943
}
4044

4145
fn put(&mut self, items: &[(u32, u64, Node)]) -> Result<(), MerkleError> {
46+
let mut inner = self.inner.write().map_err(|e| {
47+
MerkleError::LockPoisoned(format!(
48+
"Failed to acquire write lock on MemoryStore: {}",
49+
e
50+
))
51+
})?;
4252
for (level, index, hash) in items {
43-
self.store.insert((*level, *index), *hash);
53+
inner.store.insert((*level, *index), *hash);
4454
}
4555
let counter = items.iter().filter(|(level, _, _)| *level == 0).count();
46-
self.num_leaves += counter as u64;
56+
inner.num_leaves += counter as u64;
4757
Ok(())
4858
}
4959
fn get_num_leaves(&self) -> u64 {
50-
self.num_leaves
60+
// For get_num_leaves, we use expect since it's a simple getter and lock poisoning
61+
// would indicate a serious bug. Using expect provides a clearer panic message.
62+
self.inner.read()
63+
.expect("MemoryStore lock was poisoned - this indicates a panic occurred while holding the lock")
64+
.num_leaves
5165
}
5266
}

src/tree.rs

Lines changed: 87 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,44 @@
55
use crate::hasher::{Hasher, Keccak256Hasher};
66
use crate::{MerkleError, Node, Store};
77
use core::ops::Index;
8-
use std::collections::HashMap;
8+
use std::{collections::HashMap, sync::RwLock};
99

1010
#[cfg(feature = "memory_store")]
1111
use crate::stores::MemoryStore;
1212

1313
pub struct MerkleProof<const DEPTH: usize> {
14+
inner: RwLock<MerkleProofInner<DEPTH>>,
15+
}
16+
17+
impl<const DEPTH: usize> MerkleProof<DEPTH> {
18+
/// Get a read lock on the proof data
19+
pub fn read(&self) -> Result<std::sync::RwLockReadGuard<'_, MerkleProofInner<DEPTH>>, MerkleError> {
20+
self.inner.read().map_err(|e| MerkleError::LockPoisoned(format!("Failed to acquire read lock on MerkleProof: {}", e)))
21+
}
22+
23+
/// Get a write lock on the proof data
24+
pub fn write(&self) -> Result<std::sync::RwLockWriteGuard<'_, MerkleProofInner<DEPTH>>, MerkleError> {
25+
self.inner.write().map_err(|e| MerkleError::LockPoisoned(format!("Failed to acquire write lock on MerkleProof: {}", e)))
26+
}
27+
}
28+
29+
// Make MerkleProofInner public so the read/write methods can return it
30+
pub struct MerkleProofInner<const DEPTH: usize> {
1431
pub proof: [Node; DEPTH],
1532
pub leaf: Node,
1633
pub index: u64,
1734
pub root: Node,
1835
}
1936

2037
pub struct MerkleTree<H, S, const DEPTH: usize>
38+
where
39+
H: Hasher,
40+
S: Store,
41+
{
42+
inner: RwLock<MerkleTreeInner<H, S, DEPTH>>,
43+
}
44+
45+
struct MerkleTreeInner<H, S, const DEPTH: usize>
2146
where
2247
H: Hasher,
2348
S: Store,
@@ -78,21 +103,26 @@ where
78103
last: hasher.hash(&zero[DEPTH - 1], &zero[DEPTH - 1]),
79104
};
80105
Self {
81-
hasher,
82-
store,
83-
zeros,
106+
inner: RwLock::new(MerkleTreeInner {
107+
hasher,
108+
store,
109+
zeros,
110+
}),
84111
}
85112
}
86113

87-
pub fn add_leaves(&mut self, leaves: &[Node]) -> Result<(), MerkleError> {
114+
pub fn add_leaves(&self, leaves: &[Node]) -> Result<(), MerkleError> {
88115
// Early return
89116
if leaves.is_empty() {
90117
return Ok(());
91118
}
92119

120+
let mut inner = self.inner.write()
121+
.map_err(|e| MerkleError::LockPoisoned(format!("Failed to acquire write lock on MerkleTree: {}", e)))?;
122+
93123
// Error if leaves do not fit in the tree
94124
// TODO: Avoid calculating this. Calculate it at init or do the shifting with the generic.
95-
if self.store.get_num_leaves() + leaves.len() as u64 > (1 << DEPTH as u64) {
125+
if inner.store.get_num_leaves() + leaves.len() as u64 > (1 << DEPTH as u64) {
96126
return Err(MerkleError::TreeFull {
97127
depth: DEPTH as u32,
98128
capacity: 1 << DEPTH as u64,
@@ -107,7 +137,7 @@ where
107137
let mut cache: HashMap<(u32, u64), Node> = HashMap::new();
108138

109139
for (offset, leaf) in leaves.iter().enumerate() {
110-
let mut idx = self.store.get_num_leaves() + offset as u64;
140+
let mut idx = inner.store.get_num_leaves() + offset as u64;
111141
let mut h = *leaf;
112142

113143
// Store the leaf
@@ -132,7 +162,7 @@ where
132162

133163
// Batch-fetch the missing siblings and insert them in cache.
134164
if fetch_len != 0 {
135-
let fetched = self.store.get(
165+
let fetched = inner.store.get(
136166
&levels_to_fetch[..fetch_len],
137167
&indices_to_fetch[..fetch_len],
138168
)?;
@@ -150,15 +180,15 @@ where
150180
let sib_hash = cache
151181
.get(&(level as u32, sibling_idx))
152182
.copied()
153-
.unwrap_or(self.zeros[level]);
183+
.unwrap_or(inner.zeros[level]);
154184

155185
let (left, right) = if idx & 1 == 1 {
156186
(sib_hash, h)
157187
} else {
158188
(h, sib_hash)
159189
};
160190

161-
h = self.hasher.hash(&left, &right);
191+
h = inner.hasher.hash(&left, &right);
162192
idx >>= 1;
163193

164194
batch.push(((level + 1) as u32, idx, h));
@@ -167,19 +197,21 @@ where
167197
}
168198

169199
// Update all values in a single batch
170-
self.store.put(&batch)?;
200+
inner.store.put(&batch)?;
171201

172202
Ok(())
173203
}
174204

175205
pub fn root(&self) -> Result<Node, MerkleError> {
176-
Ok(self
206+
let inner = self.inner.read()
207+
.map_err(|e| MerkleError::LockPoisoned(format!("Failed to acquire read lock on MerkleTree: {}", e)))?;
208+
Ok(inner
177209
.store
178210
.get(&[DEPTH as u32], &[0])?
179211
.into_iter()
180212
.next()
181213
.ok_or_else(|| MerkleError::StoreError("root fetch returned empty vector".into()))?
182-
.unwrap_or(self.zeros[DEPTH]))
214+
.unwrap_or(inner.zeros[DEPTH]))
183215
}
184216

185217
pub fn proof(&self, leaf_idx: u64) -> Result<MerkleProof<DEPTH>, MerkleError> {
@@ -201,6 +233,9 @@ where
201233
});
202234
}
203235

236+
let inner = self.inner.read()
237+
.map_err(|e| MerkleError::LockPoisoned(format!("Failed to acquire read lock on MerkleTree: {}", e)))?;
238+
204239
// Build level/index lists for siblings plus the leaf.
205240
// TODO: Can't do arithmetic here with DEPTH meaning there is no
206241
// easy way to put this in the stack. Unfortunately the array size
@@ -222,40 +257,56 @@ where
222257
indices.push(leaf_idx);
223258

224259
// Batch fetch all requested nodes.
225-
let fetched = self.store.get(&levels, &indices)?;
260+
let fetched = inner.store.get(&levels, &indices)?;
226261

227262
// The first DEPTH items are the siblings.
228263
let mut proof = [Node::ZERO; DEPTH];
229264
for (d, opt) in fetched.iter().take(DEPTH).enumerate() {
230-
proof[d] = opt.unwrap_or(self.zeros[d]);
265+
proof[d] = opt.unwrap_or(inner.zeros[d]);
231266
}
232267

233268
// The last item is the leaf itself.
234-
let leaf_hash = fetched.last().copied().flatten().unwrap_or(self.zeros[0]);
269+
let leaf_hash = fetched.last().copied().flatten().unwrap_or(inner.zeros[0]);
270+
271+
// Release the lock before calling root() to avoid deadlock
272+
let root = {
273+
drop(inner);
274+
self.root()?
275+
};
235276

236277
Ok(MerkleProof {
237-
proof,
238-
leaf: leaf_hash,
239-
index: leaf_idx,
240-
root: self.root()?,
278+
inner: RwLock::new(MerkleProofInner {
279+
proof,
280+
leaf: leaf_hash,
281+
index: leaf_idx,
282+
root,
283+
}),
241284
})
242285
}
243286

244287
pub fn verify_proof(&self, proof: &MerkleProof<DEPTH>) -> Result<bool, MerkleError> {
245-
let mut computed_hash = proof.leaf;
246-
for (j, sibling_hash) in proof.proof.iter().enumerate() {
247-
let (left, right) = if proof.index & (1 << j) == 0 {
288+
let proof_inner = proof.inner.read()
289+
.map_err(|e| MerkleError::LockPoisoned(format!("Failed to acquire read lock on MerkleProof: {}", e)))?;
290+
let tree_inner = self.inner.read()
291+
.map_err(|e| MerkleError::LockPoisoned(format!("Failed to acquire read lock on MerkleTree: {}", e)))?;
292+
let mut computed_hash = proof_inner.leaf;
293+
let idx = proof_inner.index;
294+
let root = proof_inner.root;
295+
for (j, sibling_hash) in proof_inner.proof.iter().enumerate() {
296+
let (left, right) = if idx & (1 << j) == 0 {
248297
(computed_hash, *sibling_hash)
249298
} else {
250299
(*sibling_hash, computed_hash)
251300
};
252-
computed_hash = self.hasher.hash(&left, &right);
301+
computed_hash = tree_inner.hasher.hash(&left, &right);
253302
}
254-
Ok(computed_hash == proof.root)
303+
Ok(computed_hash == root)
255304
}
256305

257-
pub fn num_leaves(&self) -> u64 {
258-
self.store.get_num_leaves()
306+
pub fn num_leaves(&self) -> Result<u64, MerkleError> {
307+
Ok(self.inner.read()
308+
.map_err(|e| MerkleError::LockPoisoned(format!("Failed to acquire read lock on MerkleTree: {}", e)))?
309+
.store.get_num_leaves())
259310
}
260311
}
261312

@@ -312,10 +363,12 @@ mod tests {
312363
to_node!("0x27ae5ba08d7291c96c8cbddcc148bf48a6d68c7974b94356f53754ef6171d757"),
313364
];
314365

315-
for (i, zero) in tree.zeros.front.iter().enumerate() {
366+
let inner = tree.inner.read()
367+
.expect("Lock should not be poisoned in test");
368+
for (i, zero) in inner.zeros.front.iter().enumerate() {
316369
assert_eq!(zero, &expected_zeros[i]);
317370
}
318-
assert_eq!(tree.zeros.last, expected_zeros[32]);
371+
assert_eq!(inner.zeros.last, expected_zeros[32]);
319372
}
320373

321374
#[cfg(feature = "memory_store")]
@@ -366,18 +419,20 @@ mod tests {
366419
to_node!("0x2f68a1c58e257e42a17a6c61dff5551ed560b9922ab119d5ac8e184c9734ead9"),
367420
];
368421

369-
for (i, zero) in tree.zeros.front.iter().enumerate() {
422+
let inner = tree.inner.read()
423+
.expect("Lock should not be poisoned in test");
424+
for (i, zero) in inner.zeros.front.iter().enumerate() {
370425
assert_eq!(zero, &expected_zeros[i]);
371426
}
372-
assert_eq!(tree.zeros.last, expected_zeros[32]);
427+
assert_eq!(inner.zeros.last, expected_zeros[32]);
373428
}
374429

375430
#[cfg(feature = "memory_store")]
376431
#[test]
377432
fn test_tree_full_error() {
378433
let hasher = Keccak256Hasher;
379434
let store = MemoryStore::default();
380-
let mut tree = MerkleTree::<Keccak256Hasher, MemoryStore, 3>::new(hasher, store);
435+
let tree = MerkleTree::<Keccak256Hasher, MemoryStore, 3>::new(hasher, store);
381436

382437
tree.add_leaves(&(0..8).map(|_| Node::ZERO).collect::<Vec<Node>>())
383438
.unwrap();

tests/tree.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ fn dir_size(path: &Path) -> u64 {
3232
#[cfg(feature = "memory_store")]
3333
#[test]
3434
fn test_merkle_tree_keccak_32_memory() {
35-
let mut tree: MerkleTree32 = MerkleTree::new(Keccak256Hasher, MemoryStore::default());
35+
let tree: MerkleTree32 = MerkleTree::new(Keccak256Hasher, MemoryStore::default());
3636

3737
// create 10k leaves.
3838
let leaves = (0..10_000)
@@ -44,21 +44,21 @@ fn test_merkle_tree_keccak_32_memory() {
4444
tree.add_leaves(&[*i]).unwrap();
4545
}
4646

47-
assert_eq!(tree.num_leaves(), 10_000);
47+
assert_eq!(tree.num_leaves().unwrap(), 10_000);
4848
assert_eq!(
4949
tree.root().unwrap(),
5050
to_node!("0x532c79f3ea0f4873946d1b14770eaa1c157255a003e73da987b858cc287b0482")
5151
);
5252

5353
// reset the tree.
54-
let mut tree: MerkleTree32 = MerkleTree::new(Keccak256Hasher, MemoryStore::default());
54+
let tree: MerkleTree32 = MerkleTree::new(Keccak256Hasher, MemoryStore::default());
5555

5656
// same but add them in batches of 1_000.
5757
for batch in leaves.chunks(1_000) {
5858
tree.add_leaves(&batch).unwrap();
5959
}
6060

61-
assert_eq!(tree.num_leaves(), 10_000);
61+
assert_eq!(tree.num_leaves().unwrap(), 10_000);
6262
assert_eq!(
6363
tree.root().unwrap(),
6464
to_node!("0x532c79f3ea0f4873946d1b14770eaa1c157255a003e73da987b858cc287b0482")
@@ -67,7 +67,7 @@ fn test_merkle_tree_keccak_32_memory() {
6767
// Get proofs for each leaf and verify them.
6868
for i in 0..10_000 {
6969
let proof = tree.proof(i).unwrap();
70-
assert_eq!(proof.proof.len(), 32);
70+
assert_eq!(proof.read().unwrap().proof.len(), 32);
7171
assert_eq!(tree.verify_proof(&proof).unwrap(), true);
7272
}
7373

@@ -104,7 +104,7 @@ fn test_disk_space() {
104104
S: Store,
105105
F: FnOnce() -> S,
106106
{
107-
let mut tree: MerkleTree<Keccak256Hasher, S, 32> =
107+
let tree: MerkleTree<Keccak256Hasher, S, 32> =
108108
MerkleTree::new(Keccak256Hasher, new_store());
109109

110110
for _ in 0..NUM_BATCHES {

0 commit comments

Comments
 (0)