@@ -106,6 +106,12 @@ class BasicType(Enum):
106106 TIMESTAMP = "timestamp"
107107
108108
109+ class PresetType (Enum ):
110+ U256 = "u256"
111+ TOKEN_AMOUNT = "TokenAmount"
112+ NFT_ID = "NftId"
113+
114+
109115@dataclass (frozen = True )
110116class TypedData :
111117 """
@@ -120,6 +126,14 @@ class TypedData:
120126 def __post_init__ (self ):
121127 self ._verify_types ()
122128
129+ @property
130+ def _all_types (self ):
131+ preset_types = _get_preset_types (self .domain .resolved_revision )
132+ return {
133+ ** preset_types ,
134+ ** self .types ,
135+ }
136+
123137 @property
124138 def _hash_method (self ) -> HashMethod :
125139 if self .domain .resolved_revision == Revision .V0 :
@@ -202,7 +216,7 @@ def _encode_value(
202216 value : Union [int , str , dict , list ],
203217 context : Optional [TypeContext ] = None ,
204218 ) -> int :
205- if type_name in self .types and isinstance (value , dict ):
219+ if type_name in self ._all_types and isinstance (value , dict ):
206220 return self .struct_hash (type_name , value )
207221
208222 if is_pointer (type_name ) and isinstance (value , list ):
@@ -245,7 +259,7 @@ def _encode_value(
245259
246260 def _encode_data (self , type_name : str , data : dict ) -> List [int ]:
247261 values = []
248- for param in self .types [type_name ]:
262+ for param in self ._all_types [type_name ]:
249263 encoded_value = self ._encode_value (
250264 param .type ,
251265 data [param .name ],
@@ -273,13 +287,21 @@ def _verify_types(self):
273287 referenced_types .update ([self .domain .separator_name , self .primary_type ])
274288
275289 basic_type_names = _get_basic_type_names (self .domain .resolved_revision )
290+ preset_type_names = _get_preset_types (self .domain .resolved_revision ).keys ()
276291
277292 for type_name in self .types :
278293 if not type_name :
279294 raise ValueError ("Type names cannot be empty." )
280295
281296 if type_name in basic_type_names :
282- raise ValueError (f"Reserved type name: { type_name } " )
297+ raise ValueError (
298+ f"Types must not contain basic types. [{ type_name } ] was found."
299+ )
300+
301+ if type_name in preset_type_names :
302+ raise ValueError (
303+ f"Types must not contain preset types. [{ type_name } ] was found."
304+ )
283305
284306 if is_pointer (type_name ):
285307 raise ValueError (
@@ -318,7 +340,7 @@ def _get_dependencies(self, type_name: str) -> List[str]:
318340
319341 while to_visit :
320342 current_type = to_visit .pop (0 )
321- params = self .types .get (current_type , [])
343+ params = self ._all_types .get (current_type , [])
322344
323345 for param in params :
324346 if isinstance (param , EnumParameter ):
@@ -333,7 +355,7 @@ def _get_dependencies(self, type_name: str) -> List[str]:
333355 ]
334356 for extracted_type in extracted_types :
335357 if (
336- extracted_type in self .types
358+ extracted_type in self ._all_types
337359 and extracted_type not in dependencies
338360 ):
339361 dependencies .append (extracted_type )
@@ -351,11 +373,11 @@ def escape(s: str) -> str:
351373 return s
352374 return f'"{ s } "'
353375
354- if dependency not in self .types :
376+ if dependency not in self ._all_types :
355377 raise ValueError (f"Dependency [{ dependency } ] is not defined in types." )
356378
357379 encoded_params = []
358- for param in self .types [dependency ]:
380+ for param in self ._all_types [dependency ]:
359381 target_type = (
360382 param .contains
361383 if isinstance (param , EnumParameter )
@@ -433,10 +455,10 @@ def _get_merkle_tree_leaves_type(self, context: TypeContext) -> str:
433455 def _resolve_type (self , context : TypeContext ) -> Parameter :
434456 parent , key = context .parent , context .key
435457
436- if parent not in self .types :
458+ if parent not in self ._all_types :
437459 raise ValueError (f"Parent { parent } is not defined in types." )
438460
439- parent_type = self .types [parent ]
461+ parent_type = self ._all_types [parent ]
440462
441463 target_type = next ((item for item in parent_type if item .name == key ), None )
442464 if target_type is None :
@@ -480,10 +502,10 @@ def _get_enum_variants(self, context: TypeContext) -> List[Parameter]:
480502 enum_type = self ._resolve_type (context )
481503 if not isinstance (enum_type , EnumParameter ):
482504 raise ValueError (f"Type [{ context .key } ] is not an enum." )
483- if enum_type .contains not in self .types :
505+ if enum_type .contains not in self ._all_types :
484506 raise ValueError (f"Type [{ enum_type .contains } ] is not defined in types" )
485507
486- return self .types [enum_type .contains ]
508+ return self ._all_types [enum_type .contains ]
487509
488510 def _encode_long_string (self , value : str ) -> int :
489511 byte_array_serializer = ByteArraySerializer ()
@@ -604,20 +626,34 @@ def _get_basic_type_names(revision: Revision) -> List[str]:
604626 BasicType .BOOL ,
605627 ]
606628
607- basic_types_v1 = basic_types_v0 + [
608- BasicType .SHORT_STRING ,
609- BasicType .CONTRACT_ADDRESS ,
610- BasicType .CLASS_HASH ,
611- BasicType .U128 ,
612- BasicType .I128 ,
613- BasicType .TIMESTAMP ,
614- BasicType .ENUM ,
615- ]
629+ basic_types_v1 = list (BasicType )
616630
617631 basic_types = basic_types_v0 if revision == Revision .V0 else basic_types_v1
618632 return [basic_type .value for basic_type in basic_types ]
619633
620634
635+ def _get_preset_types (
636+ revision : Revision ,
637+ ) -> Dict [str , List [StandardParameter ]]:
638+ if revision == Revision .V0 :
639+ return {}
640+
641+ return {
642+ PresetType .U256 .value : [
643+ StandardParameter (name = "low" , type = "u128" ),
644+ StandardParameter (name = "high" , type = "u128" ),
645+ ],
646+ PresetType .TOKEN_AMOUNT .value : [
647+ StandardParameter (name = "token_address" , type = "ContractAddress" ),
648+ StandardParameter (name = "amount" , type = "u256" ),
649+ ],
650+ PresetType .NFT_ID .value : [
651+ StandardParameter (name = "collection_address" , type = "ContractAddress" ),
652+ StandardParameter (name = "token_id" , type = "u256" ),
653+ ],
654+ }
655+
656+
621657# pylint: disable=unused-argument
622658# pylint: disable=no-self-use
623659
0 commit comments