77from starknet_py .cairo .felt import encode_shortstring
88from starknet_py .hash .hash_method import HashMethod
99from starknet_py .hash .selector import get_selector_from_name
10- from starknet_py .hash .utils import compute_hash_on_elements
1110from starknet_py .net .client_utils import _to_rpc_felt
1211from starknet_py .net .models .typed_data import DomainDict , Revision , TypedDataDict
1312from starknet_py .net .schemas .common import RevisionField
@@ -76,6 +75,17 @@ class TypeContext:
7675 key : str
7776
7877
78+ class BasicType (Enum ):
79+ FELT = "felt"
80+ SELECTOR = "selector"
81+ MERKLE_TREE = "merkletree"
82+ SHORT_STRING = "shortstring"
83+ STRING = "string"
84+ CONTRACT_ADDRESS = "ContractAddress"
85+ CLASS_HASH = "ClassHash"
86+ BOOL = "bool"
87+
88+
7989@dataclass (frozen = True )
8090class TypedData :
8191 """
@@ -118,6 +128,7 @@ def to_dict(self) -> dict:
118128 def _is_struct (self , type_name : str ) -> bool :
119129 return type_name in self .types
120130
131+ # pylint: disable=too-many-return-statements
121132 def _encode_value (
122133 self ,
123134 type_name : str ,
@@ -130,26 +141,34 @@ def _encode_value(
130141 if is_pointer (type_name ) and isinstance (value , list ):
131142 type_name = strip_pointer (type_name )
132143 hashes = [self ._encode_value (type_name , val ) for val in value ]
133- return compute_hash_on_elements (hashes )
144+ return self . _hash_method . hash_many (hashes )
134145
135146 if type_name not in _get_basic_type_names (self .domain .resolved_revision ):
136147 raise ValueError (f"Type [{ type_name } ] is not defined in types." )
137148
138149 basic_type = BasicType (type_name )
139150
140- if basic_type == BasicType .MERKLE_TREE and isinstance (value , list ):
141- if context is None :
142- raise ValueError (f"Context is not provided for '{ type_name } ' type." )
143- return self ._prepare_merkle_tree_root (value , context )
151+ if (basic_type , self .domain .resolved_revision ) in [
152+ (BasicType .FELT , Revision .V0 ),
153+ (BasicType .FELT , Revision .V1 ),
154+ (BasicType .STRING , Revision .V0 ),
155+ (BasicType .SHORT_STRING , Revision .V1 ),
156+ (BasicType .CONTRACT_ADDRESS , Revision .V1 ),
157+ (BasicType .CLASS_HASH , Revision .V1 ),
158+ ] and isinstance (value , (int , str )):
159+ return parse_felt (value )
144160
145- if basic_type in (BasicType .FELT , BasicType .SHORT_STRING ) and isinstance (
146- value , (int , str , Revision )
147- ):
148- return int (get_hex (value ), 16 )
161+ if basic_type == BasicType .BOOL and isinstance (value , (bool , str , int )):
162+ return encode_bool (value )
149163
150164 if basic_type == BasicType .SELECTOR and isinstance (value , str ):
151165 return prepare_selector (value )
152166
167+ if basic_type == BasicType .MERKLE_TREE and isinstance (value , list ):
168+ if context is None :
169+ raise ValueError (f"Context is not provided for '{ type_name } ' type." )
170+ return self ._prepare_merkle_tree_root (value , context )
171+
153172 raise ValueError (
154173 f"Error occurred while encoding value with type name { type_name } ."
155174 )
@@ -294,14 +313,14 @@ def _get_merkle_tree_leaves_type(self, context: TypeContext) -> str:
294313 return target_type .contains
295314
296315
297- def get_hex (value : Union [int , str ]) -> str :
316+ def parse_felt (value : Union [int , str ]) -> int :
298317 if isinstance (value , int ):
299- return hex (value )
300- if value [:2 ] == "0x" :
301318 return value
319+ if value .startswith ("0x" ):
320+ return int (value , 16 )
302321 if value .isnumeric ():
303- return hex ( int (value ) )
304- return hex ( encode_shortstring (value ) )
322+ return int (value )
323+ return encode_shortstring (value )
305324
306325
307326def is_pointer (value : str ) -> bool :
@@ -327,12 +346,18 @@ def prepare_selector(name: str) -> int:
327346 return get_selector_from_name (name )
328347
329348
330- class BasicType (Enum ):
331- FELT = "felt"
332- SELECTOR = "selector"
333- MERKLE_TREE = "merkletree"
334- STRING = "string"
335- SHORT_STRING = "shortstring"
349+ def encode_bool (value : Union [bool , str , int ]) -> int :
350+ if isinstance (value , bool ):
351+ return 1 if value else 0
352+ if isinstance (value , int ) and value in (0 , 1 ):
353+ return value
354+ if isinstance (value , str ) and value in ("0" , "1" ):
355+ return int (value )
356+ if isinstance (value , str ) and value in ("false" , "true" ):
357+ return 0 if value == "false" else 1
358+ if isinstance (value , str ) and value in ("0x0" , "0x1" ):
359+ return int (value , 16 )
360+ raise ValueError (f"Expected boolean value, got [{ value } ]." )
336361
337362
338363def _get_basic_type_names (revision : Revision ) -> List [str ]:
@@ -341,10 +366,13 @@ def _get_basic_type_names(revision: Revision) -> List[str]:
341366 BasicType .SELECTOR ,
342367 BasicType .MERKLE_TREE ,
343368 BasicType .STRING ,
369+ BasicType .BOOL ,
344370 ]
345371
346372 basic_types_v1 = basic_types_v0 + [
347373 BasicType .SHORT_STRING ,
374+ BasicType .CONTRACT_ADDRESS ,
375+ BasicType .CLASS_HASH ,
348376 ]
349377
350378 basic_types = basic_types_v0 if revision == Revision .V0 else basic_types_v1
0 commit comments