Skip to content

Commit f685886

Browse files
committed
Async ready lib
1 parent 85c496b commit f685886

File tree

5 files changed

+149
-49
lines changed

5 files changed

+149
-49
lines changed

src/errors.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,7 @@ pub enum MerkleError {
1111

1212
#[error("Levels and indices must have the same length")]
1313
LengthMismatch { levels: usize, indices: usize },
14+
15+
#[error("Lock was poisoned: {0}")]
16+
LockPoisoned(String),
1417
}

src/lib.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@ fn main() {
3131
.unwrap();
3232
3333
println!("root: {:?}", tree.root().unwrap());
34-
println!("num leaves: {:?}", tree.num_leaves());
35-
println!("proof: {:?}", tree.proof(0).unwrap().proof);
34+
println!("num leaves: {:?}", tree.num_leaves().unwrap());
35+
println!("proof: {:?}", tree.proof(0).unwrap().read().unwrap().proof);
3636
}
3737
```
3838

src/stores/memory_store.rs

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,16 @@
33
//! Simple in-memory store implementation.
44
55
use crate::{MerkleError, Node, Store};
6-
use std::collections::HashMap;
6+
use std::{collections::HashMap, sync::RwLock};
77

88
/// Simple in-memory store implementation using a `HashMap`.
99
#[derive(Default)]
1010
pub struct MemoryStore {
11+
inner: RwLock<MemoryStoreInner>,
12+
}
13+
14+
#[derive(Default)]
15+
struct MemoryStoreInner {
1116
store: HashMap<(u32, u64), Node>,
1217
num_leaves: u64,
1318
}
@@ -26,27 +31,36 @@ impl Store for MemoryStore {
2631
indices: indices.len(),
2732
});
2833
}
29-
30-
// The memory store doesnt really allow batch reads, so just get all the
31-
// indexes/levels one by one.
34+
let inner = self.inner.read().map_err(|e| {
35+
MerkleError::LockPoisoned(format!("Failed to acquire read lock on MemoryStore: {}", e))
36+
})?;
3237
let result = levels
3338
.iter()
3439
.zip(indices)
35-
.map(|(&lvl, &idx)| self.store.get(&(lvl, idx)).cloned())
40+
.map(|(&lvl, &idx)| inner.store.get(&(lvl, idx)).cloned())
3641
.collect();
37-
3842
Ok(result)
3943
}
4044

4145
fn put(&mut self, items: &[(u32, u64, Node)]) -> Result<(), MerkleError> {
46+
let mut inner = self.inner.write().map_err(|e| {
47+
MerkleError::LockPoisoned(format!(
48+
"Failed to acquire write lock on MemoryStore: {}",
49+
e
50+
))
51+
})?;
4252
for (level, index, hash) in items {
43-
self.store.insert((*level, *index), *hash);
53+
inner.store.insert((*level, *index), *hash);
4454
}
4555
let counter = items.iter().filter(|(level, _, _)| *level == 0).count();
46-
self.num_leaves += counter as u64;
56+
inner.num_leaves += counter as u64;
4757
Ok(())
4858
}
4959
fn get_num_leaves(&self) -> u64 {
50-
self.num_leaves
60+
// For get_num_leaves, we use expect since it's a simple getter and lock poisoning
61+
// would indicate a serious bug. Using expect provides a clearer panic message.
62+
self.inner.read()
63+
.expect("MemoryStore lock was poisoned - this indicates a panic occurred while holding the lock")
64+
.num_leaves
5165
}
5266
}

src/tree.rs

Lines changed: 115 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,55 @@
55
use crate::hasher::{Hasher, Keccak256Hasher};
66
use crate::{MerkleError, Node, Store};
77
use core::ops::Index;
8-
use std::collections::HashMap;
8+
use std::{collections::HashMap, sync::RwLock};
99

1010
#[cfg(feature = "memory_store")]
1111
use crate::stores::MemoryStore;
1212

1313
pub struct MerkleProof<const DEPTH: usize> {
14+
inner: RwLock<MerkleProofInner<DEPTH>>,
15+
}
16+
17+
impl<const DEPTH: usize> MerkleProof<DEPTH> {
18+
/// Get a read lock on the proof data
19+
pub fn read(
20+
&self,
21+
) -> Result<std::sync::RwLockReadGuard<'_, MerkleProofInner<DEPTH>>, MerkleError> {
22+
self.inner.read().map_err(|e| {
23+
MerkleError::LockPoisoned(format!("Failed to acquire read lock on MerkleProof: {}", e))
24+
})
25+
}
26+
27+
/// Get a write lock on the proof data
28+
pub fn write(
29+
&self,
30+
) -> Result<std::sync::RwLockWriteGuard<'_, MerkleProofInner<DEPTH>>, MerkleError> {
31+
self.inner.write().map_err(|e| {
32+
MerkleError::LockPoisoned(format!(
33+
"Failed to acquire write lock on MerkleProof: {}",
34+
e
35+
))
36+
})
37+
}
38+
}
39+
40+
// Make MerkleProofInner public so the read/write methods can return it
41+
pub struct MerkleProofInner<const DEPTH: usize> {
1442
pub proof: [Node; DEPTH],
1543
pub leaf: Node,
1644
pub index: u64,
1745
pub root: Node,
1846
}
1947

2048
pub struct MerkleTree<H, S, const DEPTH: usize>
49+
where
50+
H: Hasher,
51+
S: Store,
52+
{
53+
inner: RwLock<MerkleTreeInner<H, S, DEPTH>>,
54+
}
55+
56+
struct MerkleTreeInner<H, S, const DEPTH: usize>
2157
where
2258
H: Hasher,
2359
S: Store,
@@ -78,21 +114,27 @@ where
78114
last: hasher.hash(&zero[DEPTH - 1], &zero[DEPTH - 1]),
79115
};
80116
Self {
81-
hasher,
82-
store,
83-
zeros,
117+
inner: RwLock::new(MerkleTreeInner {
118+
hasher,
119+
store,
120+
zeros,
121+
}),
84122
}
85123
}
86124

87-
pub fn add_leaves(&mut self, leaves: &[Node]) -> Result<(), MerkleError> {
125+
pub fn add_leaves(&self, leaves: &[Node]) -> Result<(), MerkleError> {
88126
// Early return
89127
if leaves.is_empty() {
90128
return Ok(());
91129
}
92130

131+
let mut inner = self.inner.write().map_err(|e| {
132+
MerkleError::LockPoisoned(format!("Failed to acquire write lock on MerkleTree: {}", e))
133+
})?;
134+
93135
// Error if leaves do not fit in the tree
94136
// TODO: Avoid calculating this. Calculate it at init or do the shifting with the generic.
95-
if self.store.get_num_leaves() + leaves.len() as u64 > (1 << DEPTH as u64) {
137+
if inner.store.get_num_leaves() + leaves.len() as u64 > (1 << DEPTH as u64) {
96138
return Err(MerkleError::TreeFull {
97139
depth: DEPTH as u32,
98140
capacity: 1 << DEPTH as u64,
@@ -107,7 +149,7 @@ where
107149
let mut cache: HashMap<(u32, u64), Node> = HashMap::new();
108150

109151
for (offset, leaf) in leaves.iter().enumerate() {
110-
let mut idx = self.store.get_num_leaves() + offset as u64;
152+
let mut idx = inner.store.get_num_leaves() + offset as u64;
111153
let mut h = *leaf;
112154

113155
// Store the leaf
@@ -132,7 +174,7 @@ where
132174

133175
// Batch-fetch the missing siblings and insert them in cache.
134176
if fetch_len != 0 {
135-
let fetched = self.store.get(
177+
let fetched = inner.store.get(
136178
&levels_to_fetch[..fetch_len],
137179
&indices_to_fetch[..fetch_len],
138180
)?;
@@ -150,15 +192,15 @@ where
150192
let sib_hash = cache
151193
.get(&(level as u32, sibling_idx))
152194
.copied()
153-
.unwrap_or(self.zeros[level]);
195+
.unwrap_or(inner.zeros[level]);
154196

155197
let (left, right) = if idx & 1 == 1 {
156198
(sib_hash, h)
157199
} else {
158200
(h, sib_hash)
159201
};
160202

161-
h = self.hasher.hash(&left, &right);
203+
h = inner.hasher.hash(&left, &right);
162204
idx >>= 1;
163205

164206
batch.push(((level + 1) as u32, idx, h));
@@ -167,19 +209,22 @@ where
167209
}
168210

169211
// Update all values in a single batch
170-
self.store.put(&batch)?;
212+
inner.store.put(&batch)?;
171213

172214
Ok(())
173215
}
174216

175217
pub fn root(&self) -> Result<Node, MerkleError> {
176-
Ok(self
218+
let inner = self.inner.read().map_err(|e| {
219+
MerkleError::LockPoisoned(format!("Failed to acquire read lock on MerkleTree: {}", e))
220+
})?;
221+
Ok(inner
177222
.store
178223
.get(&[DEPTH as u32], &[0])?
179224
.into_iter()
180225
.next()
181226
.ok_or_else(|| MerkleError::StoreError("root fetch returned empty vector".into()))?
182-
.unwrap_or(self.zeros[DEPTH]))
227+
.unwrap_or(inner.zeros[DEPTH]))
183228
}
184229

185230
pub fn proof(&self, leaf_idx: u64) -> Result<MerkleProof<DEPTH>, MerkleError> {
@@ -201,6 +246,10 @@ where
201246
});
202247
}
203248

249+
let inner = self.inner.read().map_err(|e| {
250+
MerkleError::LockPoisoned(format!("Failed to acquire read lock on MerkleTree: {}", e))
251+
})?;
252+
204253
// Build level/index lists for siblings plus the leaf.
205254
// TODO: Can't do arithmetic here with DEPTH meaning there is no
206255
// easy way to put this in the stack. Unfortunately the array size
@@ -222,40 +271,66 @@ where
222271
indices.push(leaf_idx);
223272

224273
// Batch fetch all requested nodes.
225-
let fetched = self.store.get(&levels, &indices)?;
274+
let fetched = inner.store.get(&levels, &indices)?;
226275

227276
// The first DEPTH items are the siblings.
228277
let mut proof = [Node::ZERO; DEPTH];
229278
for (d, opt) in fetched.iter().take(DEPTH).enumerate() {
230-
proof[d] = opt.unwrap_or(self.zeros[d]);
279+
proof[d] = opt.unwrap_or(inner.zeros[d]);
231280
}
232281

233282
// The last item is the leaf itself.
234-
let leaf_hash = fetched.last().copied().flatten().unwrap_or(self.zeros[0]);
283+
let leaf_hash = fetched.last().copied().flatten().unwrap_or(inner.zeros[0]);
284+
285+
// Release the lock before calling root() to avoid deadlock
286+
let root = {
287+
drop(inner);
288+
self.root()?
289+
};
235290

236291
Ok(MerkleProof {
237-
proof,
238-
leaf: leaf_hash,
239-
index: leaf_idx,
240-
root: self.root()?,
292+
inner: RwLock::new(MerkleProofInner {
293+
proof,
294+
leaf: leaf_hash,
295+
index: leaf_idx,
296+
root,
297+
}),
241298
})
242299
}
243300

244301
pub fn verify_proof(&self, proof: &MerkleProof<DEPTH>) -> Result<bool, MerkleError> {
245-
let mut computed_hash = proof.leaf;
246-
for (j, sibling_hash) in proof.proof.iter().enumerate() {
247-
let (left, right) = if proof.index & (1 << j) == 0 {
302+
let proof_inner = proof.inner.read().map_err(|e| {
303+
MerkleError::LockPoisoned(format!("Failed to acquire read lock on MerkleProof: {}", e))
304+
})?;
305+
let tree_inner = self.inner.read().map_err(|e| {
306+
MerkleError::LockPoisoned(format!("Failed to acquire read lock on MerkleTree: {}", e))
307+
})?;
308+
let mut computed_hash = proof_inner.leaf;
309+
let idx = proof_inner.index;
310+
let root = proof_inner.root;
311+
for (j, sibling_hash) in proof_inner.proof.iter().enumerate() {
312+
let (left, right) = if idx & (1 << j) == 0 {
248313
(computed_hash, *sibling_hash)
249314
} else {
250315
(*sibling_hash, computed_hash)
251316
};
252-
computed_hash = self.hasher.hash(&left, &right);
317+
computed_hash = tree_inner.hasher.hash(&left, &right);
253318
}
254-
Ok(computed_hash == proof.root)
319+
Ok(computed_hash == root)
255320
}
256321

257-
pub fn num_leaves(&self) -> u64 {
258-
self.store.get_num_leaves()
322+
pub fn num_leaves(&self) -> Result<u64, MerkleError> {
323+
Ok(self
324+
.inner
325+
.read()
326+
.map_err(|e| {
327+
MerkleError::LockPoisoned(format!(
328+
"Failed to acquire read lock on MerkleTree: {}",
329+
e
330+
))
331+
})?
332+
.store
333+
.get_num_leaves())
259334
}
260335
}
261336

@@ -312,10 +387,14 @@ mod tests {
312387
to_node!("0x27ae5ba08d7291c96c8cbddcc148bf48a6d68c7974b94356f53754ef6171d757"),
313388
];
314389

315-
for (i, zero) in tree.zeros.front.iter().enumerate() {
390+
let inner = tree
391+
.inner
392+
.read()
393+
.expect("Lock should not be poisoned in test");
394+
for (i, zero) in inner.zeros.front.iter().enumerate() {
316395
assert_eq!(zero, &expected_zeros[i]);
317396
}
318-
assert_eq!(tree.zeros.last, expected_zeros[32]);
397+
assert_eq!(inner.zeros.last, expected_zeros[32]);
319398
}
320399

321400
#[cfg(feature = "memory_store")]
@@ -366,18 +445,22 @@ mod tests {
366445
to_node!("0x2f68a1c58e257e42a17a6c61dff5551ed560b9922ab119d5ac8e184c9734ead9"),
367446
];
368447

369-
for (i, zero) in tree.zeros.front.iter().enumerate() {
448+
let inner = tree
449+
.inner
450+
.read()
451+
.expect("Lock should not be poisoned in test");
452+
for (i, zero) in inner.zeros.front.iter().enumerate() {
370453
assert_eq!(zero, &expected_zeros[i]);
371454
}
372-
assert_eq!(tree.zeros.last, expected_zeros[32]);
455+
assert_eq!(inner.zeros.last, expected_zeros[32]);
373456
}
374457

375458
#[cfg(feature = "memory_store")]
376459
#[test]
377460
fn test_tree_full_error() {
378461
let hasher = Keccak256Hasher;
379462
let store = MemoryStore::default();
380-
let mut tree = MerkleTree::<Keccak256Hasher, MemoryStore, 3>::new(hasher, store);
463+
let tree = MerkleTree::<Keccak256Hasher, MemoryStore, 3>::new(hasher, store);
381464

382465
tree.add_leaves(&(0..8).map(|_| Node::ZERO).collect::<Vec<Node>>())
383466
.unwrap();

0 commit comments

Comments
 (0)