Skip to content

Commit 1565660

Browse files
Support preset types in TypedData (#1377)
1 parent 6508f33 commit 1565660

File tree

3 files changed

+180
-82
lines changed

3 files changed

+180
-82
lines changed
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
{
2+
"types": {
3+
"StarknetDomain": [
4+
{ "name": "name", "type": "shortstring" },
5+
{ "name": "version", "type": "shortstring" },
6+
{ "name": "chainId", "type": "shortstring" },
7+
{ "name": "revision", "type": "shortstring" }
8+
],
9+
"Example": [
10+
{ "name": "n0", "type": "TokenAmount" },
11+
{ "name": "n1", "type": "NftId" }
12+
]
13+
},
14+
"primaryType": "Example",
15+
"domain": {
16+
"name": "StarkNet Mail",
17+
"version": "1",
18+
"chainId": "1",
19+
"revision": "1"
20+
},
21+
"message": {
22+
"n0": {
23+
"token_address": "0x049d36570d4e46f48e99674bd3fcc84644ddd6b96f7c741b1562b82f9e004dc7",
24+
"amount": {
25+
"low": "0x3e8",
26+
"high": "0x0"
27+
}
28+
},
29+
"n1": {
30+
"collection_address": "0x049d36570d4e46f48e99674bd3fcc84644ddd6b96f7c741b1562b82f9e004dc7",
31+
"token_id": {
32+
"low": "0x3e8",
33+
"high": "0x0"
34+
}
35+
}
36+
}
37+
}

starknet_py/utils/typed_data.py

Lines changed: 56 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,12 @@ class BasicType(Enum):
106106
TIMESTAMP = "timestamp"
107107

108108

109+
class PresetType(Enum):
110+
U256 = "u256"
111+
TOKEN_AMOUNT = "TokenAmount"
112+
NFT_ID = "NftId"
113+
114+
109115
@dataclass(frozen=True)
110116
class TypedData:
111117
"""
@@ -120,6 +126,14 @@ class TypedData:
120126
def __post_init__(self):
121127
self._verify_types()
122128

129+
@property
130+
def _all_types(self):
131+
preset_types = _get_preset_types(self.domain.resolved_revision)
132+
return {
133+
**preset_types,
134+
**self.types,
135+
}
136+
123137
@property
124138
def _hash_method(self) -> HashMethod:
125139
if self.domain.resolved_revision == Revision.V0:
@@ -202,7 +216,7 @@ def _encode_value(
202216
value: Union[int, str, dict, list],
203217
context: Optional[TypeContext] = None,
204218
) -> int:
205-
if type_name in self.types and isinstance(value, dict):
219+
if type_name in self._all_types and isinstance(value, dict):
206220
return self.struct_hash(type_name, value)
207221

208222
if is_pointer(type_name) and isinstance(value, list):
@@ -245,7 +259,7 @@ def _encode_value(
245259

246260
def _encode_data(self, type_name: str, data: dict) -> List[int]:
247261
values = []
248-
for param in self.types[type_name]:
262+
for param in self._all_types[type_name]:
249263
encoded_value = self._encode_value(
250264
param.type,
251265
data[param.name],
@@ -273,13 +287,21 @@ def _verify_types(self):
273287
referenced_types.update([self.domain.separator_name, self.primary_type])
274288

275289
basic_type_names = _get_basic_type_names(self.domain.resolved_revision)
290+
preset_type_names = _get_preset_types(self.domain.resolved_revision).keys()
276291

277292
for type_name in self.types:
278293
if not type_name:
279294
raise ValueError("Type names cannot be empty.")
280295

281296
if type_name in basic_type_names:
282-
raise ValueError(f"Reserved type name: {type_name}")
297+
raise ValueError(
298+
f"Types must not contain basic types. [{type_name}] was found."
299+
)
300+
301+
if type_name in preset_type_names:
302+
raise ValueError(
303+
f"Types must not contain preset types. [{type_name}] was found."
304+
)
283305

284306
if is_pointer(type_name):
285307
raise ValueError(
@@ -318,7 +340,7 @@ def _get_dependencies(self, type_name: str) -> List[str]:
318340

319341
while to_visit:
320342
current_type = to_visit.pop(0)
321-
params = self.types.get(current_type, [])
343+
params = self._all_types.get(current_type, [])
322344

323345
for param in params:
324346
if isinstance(param, EnumParameter):
@@ -333,7 +355,7 @@ def _get_dependencies(self, type_name: str) -> List[str]:
333355
]
334356
for extracted_type in extracted_types:
335357
if (
336-
extracted_type in self.types
358+
extracted_type in self._all_types
337359
and extracted_type not in dependencies
338360
):
339361
dependencies.append(extracted_type)
@@ -351,11 +373,11 @@ def escape(s: str) -> str:
351373
return s
352374
return f'"{s}"'
353375

354-
if dependency not in self.types:
376+
if dependency not in self._all_types:
355377
raise ValueError(f"Dependency [{dependency}] is not defined in types.")
356378

357379
encoded_params = []
358-
for param in self.types[dependency]:
380+
for param in self._all_types[dependency]:
359381
target_type = (
360382
param.contains
361383
if isinstance(param, EnumParameter)
@@ -433,10 +455,10 @@ def _get_merkle_tree_leaves_type(self, context: TypeContext) -> str:
433455
def _resolve_type(self, context: TypeContext) -> Parameter:
434456
parent, key = context.parent, context.key
435457

436-
if parent not in self.types:
458+
if parent not in self._all_types:
437459
raise ValueError(f"Parent {parent} is not defined in types.")
438460

439-
parent_type = self.types[parent]
461+
parent_type = self._all_types[parent]
440462

441463
target_type = next((item for item in parent_type if item.name == key), None)
442464
if target_type is None:
@@ -480,10 +502,10 @@ def _get_enum_variants(self, context: TypeContext) -> List[Parameter]:
480502
enum_type = self._resolve_type(context)
481503
if not isinstance(enum_type, EnumParameter):
482504
raise ValueError(f"Type [{context.key}] is not an enum.")
483-
if enum_type.contains not in self.types:
505+
if enum_type.contains not in self._all_types:
484506
raise ValueError(f"Type [{enum_type.contains}] is not defined in types")
485507

486-
return self.types[enum_type.contains]
508+
return self._all_types[enum_type.contains]
487509

488510
def _encode_long_string(self, value: str) -> int:
489511
byte_array_serializer = ByteArraySerializer()
@@ -604,20 +626,34 @@ def _get_basic_type_names(revision: Revision) -> List[str]:
604626
BasicType.BOOL,
605627
]
606628

607-
basic_types_v1 = basic_types_v0 + [
608-
BasicType.SHORT_STRING,
609-
BasicType.CONTRACT_ADDRESS,
610-
BasicType.CLASS_HASH,
611-
BasicType.U128,
612-
BasicType.I128,
613-
BasicType.TIMESTAMP,
614-
BasicType.ENUM,
615-
]
629+
basic_types_v1 = list(BasicType)
616630

617631
basic_types = basic_types_v0 if revision == Revision.V0 else basic_types_v1
618632
return [basic_type.value for basic_type in basic_types]
619633

620634

635+
def _get_preset_types(
636+
revision: Revision,
637+
) -> Dict[str, List[StandardParameter]]:
638+
if revision == Revision.V0:
639+
return {}
640+
641+
return {
642+
PresetType.U256.value: [
643+
StandardParameter(name="low", type="u128"),
644+
StandardParameter(name="high", type="u128"),
645+
],
646+
PresetType.TOKEN_AMOUNT.value: [
647+
StandardParameter(name="token_address", type="ContractAddress"),
648+
StandardParameter(name="amount", type="u256"),
649+
],
650+
PresetType.NFT_ID.value: [
651+
StandardParameter(name="collection_address", type="ContractAddress"),
652+
StandardParameter(name="token_id", type="u256"),
653+
],
654+
}
655+
656+
621657
# pylint: disable=unused-argument
622658
# pylint: disable=no-self-use
623659

0 commit comments

Comments
 (0)