Skip to content

Commit 2ad5d45

Browse files
Support u128, i128, timestamp in TypedData (#1380)
1 parent 944728f commit 2ad5d45

File tree

2 files changed

+180
-10
lines changed

2 files changed

+180
-10
lines changed

starknet_py/utils/typed_data.py

Lines changed: 102 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1+
import re
12
from dataclasses import dataclass
23
from enum import Enum
34
from typing import Dict, List, Optional, Union, cast
45

56
from marshmallow import Schema, fields, post_load
67

78
from starknet_py.cairo.felt import encode_shortstring
9+
from starknet_py.constants import FIELD_PRIME
810
from starknet_py.hash.hash_method import HashMethod
911
from starknet_py.hash.selector import get_selector_from_name
1012
from starknet_py.net.client_utils import _to_rpc_felt
@@ -84,6 +86,43 @@ class BasicType(Enum):
8486
CONTRACT_ADDRESS = "ContractAddress"
8587
CLASS_HASH = "ClassHash"
8688
BOOL = "bool"
89+
U128 = "u128"
90+
I128 = "i128"
91+
TIMESTAMP = "timestamp"
92+
93+
94+
def _encode_value_v1(basic_type: BasicType, value: Union[int, str]) -> Optional[int]:
95+
if basic_type in (
96+
BasicType.FELT,
97+
BasicType.SHORT_STRING,
98+
BasicType.CONTRACT_ADDRESS,
99+
BasicType.CLASS_HASH,
100+
) and isinstance(value, (int, str)):
101+
return parse_felt(value)
102+
103+
if basic_type in (
104+
BasicType.U128,
105+
BasicType.TIMESTAMP,
106+
) and isinstance(value, (int, str)):
107+
return encode_u128(value)
108+
109+
if basic_type == BasicType.I128 and isinstance(value, (int, str)):
110+
return encode_i128(value)
111+
112+
return None
113+
114+
115+
def _encode_value_v0(
116+
basic_type: BasicType,
117+
value: Union[int, str],
118+
) -> Optional[int]:
119+
if basic_type in (
120+
BasicType.FELT,
121+
BasicType.STRING,
122+
) and isinstance(value, (int, str)):
123+
return parse_felt(value)
124+
125+
return None
87126

88127

89128
@dataclass(frozen=True)
@@ -128,7 +167,6 @@ def to_dict(self) -> dict:
128167
def _is_struct(self, type_name: str) -> bool:
129168
return type_name in self.types
130169

131-
# pylint: disable=too-many-return-statements
132170
def _encode_value(
133171
self,
134172
type_name: str,
@@ -148,15 +186,18 @@ def _encode_value(
148186

149187
basic_type = BasicType(type_name)
150188

151-
if (basic_type, self.domain.resolved_revision) in [
152-
(BasicType.FELT, Revision.V0),
153-
(BasicType.FELT, Revision.V1),
154-
(BasicType.STRING, Revision.V0),
155-
(BasicType.SHORT_STRING, Revision.V1),
156-
(BasicType.CONTRACT_ADDRESS, Revision.V1),
157-
(BasicType.CLASS_HASH, Revision.V1),
158-
] and isinstance(value, (int, str)):
159-
return parse_felt(value)
189+
encoded_value = None
190+
if self.domain.resolved_revision == Revision.V0 and isinstance(
191+
value, (str, int)
192+
):
193+
encoded_value = _encode_value_v0(basic_type, value)
194+
elif self.domain.resolved_revision == Revision.V1 and isinstance(
195+
value, (str, int)
196+
):
197+
encoded_value = _encode_value_v1(basic_type, value)
198+
199+
if encoded_value is not None:
200+
return encoded_value
160201

161202
if basic_type == BasicType.BOOL and isinstance(value, (bool, str, int)):
162203
return encode_bool(value)
@@ -360,6 +401,54 @@ def encode_bool(value: Union[bool, str, int]) -> int:
360401
raise ValueError(f"Expected boolean value, got [{value}].")
361402

362403

404+
def is_digit_string(s: str, signed=False) -> bool:
405+
if signed:
406+
return bool(re.fullmatch(r"-?\d+", s))
407+
return bool(re.fullmatch(r"\d+", s))
408+
409+
410+
def encode_u128(value: Union[str, int]) -> int:
411+
def is_in_range(n: int):
412+
return 0 <= n < 2**128
413+
414+
if isinstance(value, str) and value.startswith("0x"):
415+
int_value = int(value, 16)
416+
elif isinstance(value, str) and is_digit_string(value):
417+
int_value = int(value)
418+
elif isinstance(value, int):
419+
int_value = value
420+
else:
421+
raise ValueError(f"Value [{value}] is not a valid number.")
422+
423+
if is_in_range(int_value):
424+
return int_value
425+
raise ValueError(f"Value [{value}] is out of range for '{BasicType.U128}'.")
426+
427+
428+
def encode_i128(value: Union[str, int]) -> int:
429+
def is_in_range(n: int):
430+
return (n < 2**127) or (n >= (FIELD_PRIME - (2**127)))
431+
432+
if isinstance(value, str) and value.startswith("0x"):
433+
int_value = int(value, 16)
434+
elif isinstance(value, str) and is_digit_string(value, True):
435+
int_value = int(value)
436+
elif isinstance(value, int):
437+
int_value = value
438+
else:
439+
raise ValueError(f"Value [{value}] is not a valid number.")
440+
441+
if abs(int_value) >= FIELD_PRIME:
442+
raise ValueError(
443+
f"Values outside the range (-FIELD_PRIME, FIELD_PRIME) are not allowed, [{value}] given."
444+
)
445+
int_value %= FIELD_PRIME
446+
447+
if is_in_range(int_value):
448+
return int_value
449+
raise ValueError(f"Value [{value}] is out of range for '{BasicType.I128}'.")
450+
451+
363452
def _get_basic_type_names(revision: Revision) -> List[str]:
364453
basic_types_v0 = [
365454
BasicType.FELT,
@@ -373,6 +462,9 @@ def _get_basic_type_names(revision: Revision) -> List[str]:
373462
BasicType.SHORT_STRING,
374463
BasicType.CONTRACT_ADDRESS,
375464
BasicType.CLASS_HASH,
465+
BasicType.U128,
466+
BasicType.I128,
467+
BasicType.TIMESTAMP,
376468
]
377469

378470
basic_types = basic_types_v0 if revision == Revision.V0 else basic_types_v1

starknet_py/utils/typed_data_test.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
Parameter,
1717
TypedData,
1818
encode_bool,
19+
encode_i128,
20+
encode_u128,
1921
parse_felt,
2022
)
2123

@@ -320,3 +322,79 @@ def test_encode_bool(value: Union[bool, str, int], expected: int):
320322
def test_encode_invalid_bool(value: Union[bool, str, int]):
321323
with pytest.raises(ValueError, match=fr"Expected boolean value, got \[{value}\]."):
322324
encode_bool(value)
325+
326+
327+
@pytest.mark.parametrize(
328+
"value, expected",
329+
[
330+
(0, 0),
331+
(1, 1),
332+
(1000000, 1000000),
333+
("0x0", 0),
334+
("0x1", 1),
335+
("0x64", 100),
336+
(2 ** 128 - 1, 2 ** 128 - 1),
337+
]
338+
)
339+
def test_encode_u128(value: Union[str, int], expected: str):
340+
assert encode_u128(value) == expected
341+
342+
343+
@pytest.mark.parametrize(
344+
"value",
345+
[
346+
-1,
347+
"-1",
348+
1.23,
349+
-1.23,
350+
"1.23",
351+
"-1.23",
352+
"example",
353+
"0xwrong",
354+
2 ** 128,
355+
hex(2 ** 128),
356+
]
357+
)
358+
def test_encode_invalid_u128(value: Union[str, int]):
359+
with pytest.raises(ValueError):
360+
encode_u128(value)
361+
362+
363+
@pytest.mark.parametrize(
364+
"value, expected",
365+
[
366+
(0, 0),
367+
(1, 1),
368+
(1000000, 1000000),
369+
("0x0", 0),
370+
("0x1", 1),
371+
("0x64", 100),
372+
(2 ** 127 - 1, 2 ** 127 - 1),
373+
(-1, 3618502788666131213697322783095070105623107215331596699973092056135872020480),
374+
(-1000000, 3618502788666131213697322783095070105623107215331596699973092056135871020481),
375+
(-(2 ** 127), 3618502788666131213697322783095070105452966031871127468241404752419987914753),
376+
]
377+
)
378+
def test_encode_i128(value: Union[str, int], expected: str):
379+
assert encode_i128(value) == expected
380+
381+
382+
@pytest.mark.parametrize(
383+
"value",
384+
[
385+
"example",
386+
"0xwrong",
387+
1.23,
388+
-1.23,
389+
"1.23",
390+
"-1.23",
391+
-2 ** 127 - 1,
392+
2 ** 127,
393+
str(-2 ** 127 - 1),
394+
str(2 ** 127),
395+
396+
]
397+
)
398+
def test_encode_invalid_i128(value: Union[str, int]):
399+
with pytest.raises(ValueError):
400+
encode_i128(value)

0 commit comments

Comments
 (0)