Skip to content

Commit 2e17b16

Browse files
Add MerkleTree and merkletree support in TypedData (#1363)
1 parent 8ce7035 commit 2e17b16

File tree

9 files changed

+417
-36
lines changed

9 files changed

+417
-36
lines changed

starknet_py/hash/hash_method.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
from enum import Enum
22
from typing import List
33

4-
from poseidon_py.poseidon_hash import poseidon_hash_many
4+
from poseidon_py.poseidon_hash import poseidon_hash, poseidon_hash_many
55

6-
from starknet_py.hash.utils import compute_hash_on_elements
6+
from starknet_py.hash.utils import compute_hash_on_elements, pedersen_hash
77

88

99
class HashMethod(Enum):
@@ -14,7 +14,14 @@ class HashMethod(Enum):
1414
PEDERSEN = "pedersen"
1515
POSEIDON = "poseidon"
1616

17-
def hash(self, values: List[int]):
17+
def hash(self, left: int, right: int):
18+
if self == HashMethod.PEDERSEN:
19+
return pedersen_hash(left, right)
20+
if self == HashMethod.POSEIDON:
21+
return poseidon_hash(left, right)
22+
raise ValueError(f"Unsupported hash method: {self}.")
23+
24+
def hash_many(self, values: List[int]):
1825
if self == HashMethod.PEDERSEN:
1926
return compute_hash_on_elements(values)
2027
if self == HashMethod.POSEIDON:

starknet_py/net/models/typed_data.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ class ParameterDict(TypedDict):
1414

1515
name: str
1616
type: str
17+
contains: Optional[str]
1718

1819

1920
class DomainDict(TypedDict):
@@ -36,3 +37,8 @@ class TypedDataDict(TypedDict):
3637
primaryType: str
3738
domain: DomainDict
3839
message: Dict[str, Any]
40+
41+
42+
class TypeContext(TypedDict):
43+
parent: str
44+
key: str

starknet_py/net/schemas/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -357,8 +357,8 @@ def _serialize(self, value: Any, attr: str, obj: Any, **kwargs):
357357
def _deserialize(self, value, attr, data, **kwargs) -> Revision:
358358
if isinstance(value, str):
359359
value = int(value)
360-
revisions = [revision.value for revision in Revision]
361360

361+
revisions = [revision.value for revision in Revision]
362362
if value not in revisions:
363363
allowed_revisions_str = "".join(list(map(str, revisions)))
364364
raise ValidationError(
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
{
2+
"primaryType": "Session",
3+
"types": {
4+
"Policy": [
5+
{
6+
"name": "contractAddress",
7+
"type": "felt"
8+
},
9+
{
10+
"name": "selector",
11+
"type": "selector"
12+
}
13+
],
14+
"Session": [
15+
{
16+
"name": "key",
17+
"type": "felt"
18+
},
19+
{
20+
"name": "expires",
21+
"type": "felt"
22+
},
23+
{
24+
"name": "root",
25+
"type": "merkletree",
26+
"contains": "Policy"
27+
}
28+
],
29+
"StarkNetDomain": [
30+
{
31+
"name": "name",
32+
"type": "felt"
33+
},
34+
{
35+
"name": "version",
36+
"type": "felt"
37+
},
38+
{
39+
"name": "chainId",
40+
"type": "felt"
41+
}
42+
]
43+
},
44+
"domain": {
45+
"name": "StarkNet Mail",
46+
"version": "1",
47+
"chainId": "1"
48+
},
49+
"message": {
50+
"key": "0x0000000000000000000000000000000000000000000000000000000000000000",
51+
"expires": "0x0000000000000000000000000000000000000000000000000000000000000000",
52+
"root": [
53+
{
54+
"contractAddress": "0x1",
55+
"selector": "transfer"
56+
},
57+
{
58+
"contractAddress": "0x2",
59+
"selector": "transfer"
60+
},
61+
{
62+
"contractAddress": "0x3",
63+
"selector": "transfer"
64+
}
65+
]
66+
}
67+
}
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
{
2+
"primaryType": "Example",
3+
"types": {
4+
"Example": [
5+
{ "name": "value", "type": "felt" },
6+
{ "name": "root", "type": "merkletree", "contains": "felt" }
7+
],
8+
"StarknetDomain": [
9+
{ "name": "name", "type": "shortstring" },
10+
{ "name": "version", "type": "shortstring" },
11+
{ "name": "chainId", "type": "shortstring" },
12+
{ "name": "revision", "type": "shortstring" }
13+
]
14+
},
15+
"domain": {
16+
"name": "StarkNet Mail",
17+
"version": "1",
18+
"chainId": "1",
19+
"revision": "1"
20+
},
21+
"message": {
22+
"value": "0x2137",
23+
"root": [
24+
"0x1",
25+
"0x2",
26+
"0x3"
27+
]
28+
}
29+
}

starknet_py/utils/merkle_tree.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
from dataclasses import dataclass, field
2+
from typing import List, Tuple
3+
4+
from starknet_py.hash.hash_method import HashMethod
5+
6+
7+
@dataclass
8+
class MerkleTree:
9+
"""
10+
Dataclass representing a MerkleTree object.
11+
"""
12+
13+
leaves: List[int]
14+
hash_method: HashMethod
15+
root_hash: int = field(init=False)
16+
levels: List[List[int]] = field(init=False)
17+
18+
def __post_init__(self):
19+
self.root_hash, self.levels = self._build()
20+
21+
def _build(self) -> Tuple[int, List[List[int]]]:
22+
if not self.leaves:
23+
raise ValueError("Cannot build Merkle tree from an empty list of leaves.")
24+
25+
if len(self.leaves) == 1:
26+
return self.leaves[0], [self.leaves]
27+
28+
curr_level_nodes = self.leaves[:]
29+
levels: List[List[int]] = []
30+
31+
while len(curr_level_nodes) > 1:
32+
if len(curr_level_nodes) != len(self.leaves):
33+
levels.append(curr_level_nodes[:])
34+
35+
new_nodes = []
36+
for i in range(0, len(curr_level_nodes), 2):
37+
a, b = (
38+
curr_level_nodes[i],
39+
curr_level_nodes[i + 1] if i + 1 < len(curr_level_nodes) else 0,
40+
)
41+
new_nodes.append(self.hash_method.hash(*sorted([a, b])))
42+
43+
curr_level_nodes = new_nodes
44+
levels = [self.leaves] + levels + [curr_level_nodes]
45+
return curr_level_nodes[0], levels
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
from typing import List
2+
3+
import pytest
4+
from poseidon_py.poseidon_hash import poseidon_hash
5+
6+
from starknet_py.hash.hash_method import HashMethod
7+
from starknet_py.hash.utils import pedersen_hash
8+
from starknet_py.utils.merkle_tree import MerkleTree
9+
10+
11+
@pytest.mark.parametrize(
12+
"leaves, hash_method, expected_root_hash",
13+
[
14+
(
15+
["0x12", "0xa"],
16+
HashMethod.PEDERSEN,
17+
"0x586699e3ba6f118227e094ad423313a2d51871507dcbc23116f11cdd79d80f2",
18+
),
19+
(
20+
["0x12", "0xa"],
21+
HashMethod.POSEIDON,
22+
"0x6257f1f60f7c9fd49e2718c8ad19cd8dce6b1ba4b553b2123113f22b1e9c379",
23+
),
24+
(
25+
[
26+
"0x5bb9440e27889a364bcb678b1f679ecd1347acdedcbf36e83494f857cc58026",
27+
"0x3",
28+
],
29+
HashMethod.PEDERSEN,
30+
"0x551b4adb6c35d49c686a00b9192da9332b18c9b262507cad0ece37f3b6918d2",
31+
),
32+
(
33+
[
34+
"0x5bb9440e27889a364bcb678b1f679ecd1347acdedcbf36e83494f857cc58026",
35+
"0x3",
36+
],
37+
HashMethod.POSEIDON,
38+
"0xc118a3963c12777b0717d1dc89baa8b3ceed84dfd713a6bd1354676f03f021",
39+
),
40+
],
41+
)
42+
def test_calculate_hash(
43+
leaves: List[str], hash_method: HashMethod, expected_root_hash: str
44+
):
45+
if hash_method == HashMethod.PEDERSEN:
46+
apply_hash = pedersen_hash
47+
elif hash_method == HashMethod.POSEIDON:
48+
apply_hash = poseidon_hash
49+
else:
50+
raise ValueError(f"Unsupported hash method: {hash_method}.")
51+
52+
a, b = int(leaves[0], 16), int(leaves[1], 16)
53+
merkle_hash = hash_method.hash(*sorted([b, a]))
54+
raw_hash = apply_hash(*sorted([b, a]))
55+
56+
assert raw_hash == merkle_hash
57+
assert int(expected_root_hash, 16) == merkle_hash
58+
59+
60+
@pytest.mark.parametrize(
61+
"hash_method",
62+
[
63+
HashMethod.PEDERSEN,
64+
HashMethod.POSEIDON,
65+
],
66+
)
67+
def test_build_from_0_elements(hash_method: HashMethod):
68+
with pytest.raises(
69+
ValueError, match="Cannot build Merkle tree from an empty list of leaves."
70+
):
71+
MerkleTree([], hash_method)
72+
73+
74+
@pytest.mark.parametrize(
75+
"leaves, hash_method, expected_root_hash, expected_levels_count",
76+
[
77+
(["0x1"], HashMethod.PEDERSEN, "0x1", 1),
78+
(["0x1"], HashMethod.POSEIDON, "0x1", 1),
79+
(
80+
["0x1", "0x2"],
81+
HashMethod.PEDERSEN,
82+
"0x5bb9440e27889a364bcb678b1f679ecd1347acdedcbf36e83494f857cc58026",
83+
2,
84+
),
85+
(
86+
["0x1", "0x2"],
87+
HashMethod.POSEIDON,
88+
"0x5d44a3decb2b2e0cc71071f7b802f45dd792d064f0fc7316c46514f70f9891a",
89+
2,
90+
),
91+
(
92+
["0x1", "0x2", "0x3", "0x4"],
93+
HashMethod.PEDERSEN,
94+
"0x38118a340bbba28e678413cd3b07a9436a5e60fd6a7cbda7db958a6d501e274",
95+
3,
96+
),
97+
(
98+
["0x1", "0x2", "0x3", "0x4"],
99+
HashMethod.POSEIDON,
100+
"0xa4d02f1e82fc554b062b754d3a4995e0ed8fc7e5016a7ca2894a451a4bae64",
101+
3,
102+
),
103+
(
104+
["0x1", "0x2", "0x3", "0x4", "0x5", "0x6"],
105+
HashMethod.PEDERSEN,
106+
"0x329d5b51e352537e8424bfd85b34d0f30b77d213e9b09e2976e6f6374ecb59",
107+
4,
108+
),
109+
(
110+
["0x1", "0x2", "0x3", "0x4", "0x5", "0x6"],
111+
HashMethod.POSEIDON,
112+
"0x34d525f018d8d6b3e492b1c9cda9bbdc3bc7834b408a30a417186c698c34766",
113+
4,
114+
),
115+
(
116+
["0x1", "0x2", "0x3", "0x4", "0x5", "0x6", "0x7"],
117+
HashMethod.PEDERSEN,
118+
"0x7f748c75e5bdb7ae28013f076b8ab650c4e01d3530c6e5ab665f9f1accbe7d4",
119+
4,
120+
),
121+
(
122+
["0x1", "0x2", "0x3", "0x4", "0x5", "0x6", "0x7"],
123+
HashMethod.POSEIDON,
124+
"0x3308a3c50c25883753f82b21f14c644ec375b88ea5b0f83d1e6afe74d0ed790",
125+
4,
126+
),
127+
],
128+
)
129+
def test_build_from_elements(
130+
leaves: List[str],
131+
hash_method: HashMethod,
132+
expected_root_hash: str,
133+
expected_levels_count: int,
134+
):
135+
tree = MerkleTree([int(leaf, 16) for leaf in leaves], hash_method)
136+
137+
assert tree.root_hash is not None
138+
assert tree.levels is not None
139+
assert tree.root_hash == int(expected_root_hash, 16)
140+
assert len(tree.levels) == expected_levels_count

0 commit comments

Comments
 (0)