Skip to content

Commit 85585b0

Browse files
Support bool, ClassHash, ContractAddress basic types in TypedData (#1370)
1 parent 69b0f33 commit 85585b0

File tree

2 files changed

+92
-24
lines changed

2 files changed

+92
-24
lines changed

starknet_py/utils/typed_data.py

Lines changed: 49 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from starknet_py.cairo.felt import encode_shortstring
88
from starknet_py.hash.hash_method import HashMethod
99
from starknet_py.hash.selector import get_selector_from_name
10-
from starknet_py.hash.utils import compute_hash_on_elements
1110
from starknet_py.net.client_utils import _to_rpc_felt
1211
from starknet_py.net.models.typed_data import DomainDict, Revision, TypedDataDict
1312
from starknet_py.net.schemas.common import RevisionField
@@ -76,6 +75,17 @@ class TypeContext:
7675
key: str
7776

7877

78+
class BasicType(Enum):
79+
FELT = "felt"
80+
SELECTOR = "selector"
81+
MERKLE_TREE = "merkletree"
82+
SHORT_STRING = "shortstring"
83+
STRING = "string"
84+
CONTRACT_ADDRESS = "ContractAddress"
85+
CLASS_HASH = "ClassHash"
86+
BOOL = "bool"
87+
88+
7989
@dataclass(frozen=True)
8090
class TypedData:
8191
"""
@@ -118,6 +128,7 @@ def to_dict(self) -> dict:
118128
def _is_struct(self, type_name: str) -> bool:
119129
return type_name in self.types
120130

131+
# pylint: disable=too-many-return-statements
121132
def _encode_value(
122133
self,
123134
type_name: str,
@@ -130,26 +141,34 @@ def _encode_value(
130141
if is_pointer(type_name) and isinstance(value, list):
131142
type_name = strip_pointer(type_name)
132143
hashes = [self._encode_value(type_name, val) for val in value]
133-
return compute_hash_on_elements(hashes)
144+
return self._hash_method.hash_many(hashes)
134145

135146
if type_name not in _get_basic_type_names(self.domain.resolved_revision):
136147
raise ValueError(f"Type [{type_name}] is not defined in types.")
137148

138149
basic_type = BasicType(type_name)
139150

140-
if basic_type == BasicType.MERKLE_TREE and isinstance(value, list):
141-
if context is None:
142-
raise ValueError(f"Context is not provided for '{type_name}' type.")
143-
return self._prepare_merkle_tree_root(value, context)
151+
if (basic_type, self.domain.resolved_revision) in [
152+
(BasicType.FELT, Revision.V0),
153+
(BasicType.FELT, Revision.V1),
154+
(BasicType.STRING, Revision.V0),
155+
(BasicType.SHORT_STRING, Revision.V1),
156+
(BasicType.CONTRACT_ADDRESS, Revision.V1),
157+
(BasicType.CLASS_HASH, Revision.V1),
158+
] and isinstance(value, (int, str)):
159+
return parse_felt(value)
144160

145-
if basic_type in (BasicType.FELT, BasicType.SHORT_STRING) and isinstance(
146-
value, (int, str, Revision)
147-
):
148-
return int(get_hex(value), 16)
161+
if basic_type == BasicType.BOOL and isinstance(value, (bool, str, int)):
162+
return encode_bool(value)
149163

150164
if basic_type == BasicType.SELECTOR and isinstance(value, str):
151165
return prepare_selector(value)
152166

167+
if basic_type == BasicType.MERKLE_TREE and isinstance(value, list):
168+
if context is None:
169+
raise ValueError(f"Context is not provided for '{type_name}' type.")
170+
return self._prepare_merkle_tree_root(value, context)
171+
153172
raise ValueError(
154173
f"Error occurred while encoding value with type name {type_name}."
155174
)
@@ -294,14 +313,14 @@ def _get_merkle_tree_leaves_type(self, context: TypeContext) -> str:
294313
return target_type.contains
295314

296315

297-
def get_hex(value: Union[int, str]) -> str:
316+
def parse_felt(value: Union[int, str]) -> int:
298317
if isinstance(value, int):
299-
return hex(value)
300-
if value[:2] == "0x":
301318
return value
319+
if value.startswith("0x"):
320+
return int(value, 16)
302321
if value.isnumeric():
303-
return hex(int(value))
304-
return hex(encode_shortstring(value))
322+
return int(value)
323+
return encode_shortstring(value)
305324

306325

307326
def is_pointer(value: str) -> bool:
@@ -327,12 +346,18 @@ def prepare_selector(name: str) -> int:
327346
return get_selector_from_name(name)
328347

329348

330-
class BasicType(Enum):
331-
FELT = "felt"
332-
SELECTOR = "selector"
333-
MERKLE_TREE = "merkletree"
334-
STRING = "string"
335-
SHORT_STRING = "shortstring"
349+
def encode_bool(value: Union[bool, str, int]) -> int:
350+
if isinstance(value, bool):
351+
return 1 if value else 0
352+
if isinstance(value, int) and value in (0, 1):
353+
return value
354+
if isinstance(value, str) and value in ("0", "1"):
355+
return int(value)
356+
if isinstance(value, str) and value in ("false", "true"):
357+
return 0 if value == "false" else 1
358+
if isinstance(value, str) and value in ("0x0", "0x1"):
359+
return int(value, 16)
360+
raise ValueError(f"Expected boolean value, got [{value}].")
336361

337362

338363
def _get_basic_type_names(revision: Revision) -> List[str]:
@@ -341,10 +366,13 @@ def _get_basic_type_names(revision: Revision) -> List[str]:
341366
BasicType.SELECTOR,
342367
BasicType.MERKLE_TREE,
343368
BasicType.STRING,
369+
BasicType.BOOL,
344370
]
345371

346372
basic_types_v1 = basic_types_v0 + [
347373
BasicType.SHORT_STRING,
374+
BasicType.CONTRACT_ADDRESS,
375+
BasicType.CLASS_HASH,
348376
]
349377

350378
basic_types = basic_types_v0 if revision == Revision.V0 else basic_types_v1

starknet_py/utils/typed_data_test.py

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import json
55
from enum import Enum
66
from pathlib import Path
7+
from typing import Union
78

89
import pytest
910

@@ -14,7 +15,8 @@
1415
Domain,
1516
Parameter,
1617
TypedData,
17-
get_hex,
18+
encode_bool,
19+
parse_felt,
1820
)
1921

2022

@@ -47,8 +49,8 @@ def load_typed_data(file_name: str) -> TypedData:
4749
"value, result",
4850
[(123, "0x7b"), ("123", "0x7b"), ("0x7b", "0x7b"), ("short_string", "0x73686f72745f737472696e67")],
4951
)
50-
def test_get_hex(value, result):
51-
assert get_hex(value) == result
52+
def test_parse_felt(value, result):
53+
assert parse_felt(value) == int(result, 16)
5254

5355

5456
@pytest.mark.parametrize(
@@ -221,10 +223,12 @@ def test_invalid_type_names(included_type: str, revision: Revision):
221223
(BasicType.STRING.value, Revision.V0),
222224
(BasicType.SELECTOR.value, Revision.V0),
223225
(BasicType.MERKLE_TREE.value, Revision.V0),
226+
(BasicType.BOOL.value, Revision.V0),
224227
(BasicType.FELT.value, Revision.V1),
225228
(BasicType.STRING.value, Revision.V1),
226229
(BasicType.SELECTOR.value, Revision.V1),
227230
(BasicType.MERKLE_TREE.value, Revision.V1),
231+
(BasicType.BOOL.value, Revision.V1),
228232
(BasicType.SHORT_STRING.value, Revision.V1),
229233
],
230234
)
@@ -280,3 +284,39 @@ def test_missing_dependency():
280284

281285
with pytest.raises(ValueError, match=r"Type \[ice cream\] is not defined in types."):
282286
typed_data.struct_hash("house", {"fridge": 1})
287+
288+
289+
@pytest.mark.parametrize(
290+
"value, expected",
291+
[
292+
(True, 1),
293+
(False, 0),
294+
("true", 1),
295+
("false", 0),
296+
("0x1", 1),
297+
("0x0", 0),
298+
("1", 1),
299+
("0", 0),
300+
(1, 1),
301+
(0, 0)
302+
303+
]
304+
)
305+
def test_encode_bool(value: Union[bool, str, int], expected: int):
306+
assert encode_bool(value) == expected
307+
308+
309+
@pytest.mark.parametrize(
310+
"value",
311+
[
312+
-2,
313+
2,
314+
"-2",
315+
"2",
316+
"0x123",
317+
"anyvalue",
318+
]
319+
)
320+
def test_encode_invalid_bool(value: Union[bool, str, int]):
321+
with pytest.raises(ValueError, match=fr"Expected boolean value, got \[{value}\]."):
322+
encode_bool(value)

0 commit comments

Comments
 (0)