Skip to content

Commit 69b0f33

Browse files
Verify types in TypedData in line with SNIP-12 (#1369)
1 parent 7480837 commit 69b0f33

File tree

2 files changed

+125
-12
lines changed

2 files changed

+125
-12
lines changed

starknet_py/utils/typed_data.py

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,9 @@ def _encode_value(
132132
hashes = [self._encode_value(type_name, val) for val in value]
133133
return compute_hash_on_elements(hashes)
134134

135+
if type_name not in _get_basic_type_names(self.domain.resolved_revision):
136+
raise ValueError(f"Type [{type_name}] is not defined in types.")
137+
135138
basic_type = BasicType(type_name)
136139

137140
if basic_type == BasicType.MERKLE_TREE and isinstance(value, list):
@@ -164,12 +167,34 @@ def _encode_data(self, type_name: str, data: dict) -> List[int]:
164167
return values
165168

166169
def _verify_types(self):
167-
reserved_type_names = ["felt", "felt*", "string", "selector", "merkletree"]
170+
if self.domain.separator_name not in self.types:
171+
raise ValueError(f"Types must contain '{self.domain.separator_name}'.")
172+
173+
basic_type_names = _get_basic_type_names(self.domain.resolved_revision)
168174

169-
for type_name in reserved_type_names:
175+
for type_name in basic_type_names:
170176
if type_name in self.types:
171177
raise ValueError(f"Reserved type name: {type_name}")
172178

179+
referenced_types = {
180+
ref_type.contains
181+
if ref_type.contains is not None
182+
else strip_pointer(ref_type.type)
183+
for type_name in self.types
184+
for ref_type in self.types[type_name]
185+
}
186+
referenced_types.update([self.domain.separator_name, self.primary_type])
187+
188+
for type_name in self.types:
189+
if not type_name:
190+
raise ValueError("Type names cannot be empty.")
191+
if is_pointer(type_name):
192+
raise ValueError(f"Type names cannot end in *. {type_name} was found.")
193+
if type_name not in referenced_types:
194+
raise ValueError(
195+
f"Dangling types are not allowed. Unreferenced type {type_name} was found."
196+
)
197+
173198
def _get_dependencies(self, type_name: str) -> List[str]:
174199
if type_name not in self.types:
175200
# type_name is a primitive type, has no dependencies
@@ -280,7 +305,7 @@ def get_hex(value: Union[int, str]) -> str:
280305

281306

282307
def is_pointer(value: str) -> bool:
283-
return len(value) > 0 and value[-1] == "*"
308+
return value.endswith("*")
284309

285310

286311
def strip_pointer(value: str) -> str:
@@ -306,9 +331,26 @@ class BasicType(Enum):
306331
FELT = "felt"
307332
SELECTOR = "selector"
308333
MERKLE_TREE = "merkletree"
334+
STRING = "string"
309335
SHORT_STRING = "shortstring"
310336

311337

338+
def _get_basic_type_names(revision: Revision) -> List[str]:
339+
basic_types_v0 = [
340+
BasicType.FELT,
341+
BasicType.SELECTOR,
342+
BasicType.MERKLE_TREE,
343+
BasicType.STRING,
344+
]
345+
346+
basic_types_v1 = basic_types_v0 + [
347+
BasicType.SHORT_STRING,
348+
]
349+
350+
basic_types = basic_types_v0 if revision == Revision.V0 else basic_types_v1
351+
return [basic_type.value for basic_type in basic_types]
352+
353+
312354
# pylint: disable=unused-argument
313355
# pylint: disable=no-self-use
314356

starknet_py/utils/typed_data_test.py

Lines changed: 80 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,13 @@
99

1010
from starknet_py.net.models.typed_data import Revision
1111
from starknet_py.tests.e2e.fixtures.constants import TYPED_DATA_DIR
12-
from starknet_py.utils.typed_data import Domain, Parameter, TypedData, get_hex
12+
from starknet_py.utils.typed_data import (
13+
BasicType,
14+
Domain,
15+
Parameter,
16+
TypedData,
17+
get_hex,
18+
)
1319

1420

1521
class CasesRev0(Enum):
@@ -197,15 +203,80 @@ def _make_typed_data(included_type: str, revision: Revision):
197203

198204

199205
@pytest.mark.parametrize(
200-
"included_type",
206+
"included_type, revision",
207+
[
208+
("", Revision.V1),
209+
("myType*", Revision.V1)
210+
],
211+
)
212+
def test_invalid_type_names(included_type: str, revision: Revision):
213+
with pytest.raises(ValueError):
214+
_make_typed_data(included_type, revision)
215+
216+
217+
@pytest.mark.parametrize(
218+
"included_type, revision",
201219
[
202-
"felt",
203-
"felt*",
204-
"string",
205-
"selector",
206-
"merkletree"
220+
(BasicType.FELT.value, Revision.V0),
221+
(BasicType.STRING.value, Revision.V0),
222+
(BasicType.SELECTOR.value, Revision.V0),
223+
(BasicType.MERKLE_TREE.value, Revision.V0),
224+
(BasicType.FELT.value, Revision.V1),
225+
(BasicType.STRING.value, Revision.V1),
226+
(BasicType.SELECTOR.value, Revision.V1),
227+
(BasicType.MERKLE_TREE.value, Revision.V1),
228+
(BasicType.SHORT_STRING.value, Revision.V1),
207229
],
208230
)
209-
def test_invalid_types(included_type: str):
231+
def test_types_redefinition(included_type: str, revision: Revision):
210232
with pytest.raises(ValueError, match=f"Reserved type name: {included_type}"):
211-
_make_typed_data(included_type, Revision.V1)
233+
_make_typed_data(included_type, revision)
234+
235+
236+
def test_custom_type_definition():
237+
_make_typed_data("myType", Revision.V0)
238+
239+
240+
@pytest.mark.parametrize(
241+
"revision",
242+
list(Revision),
243+
)
244+
def test_missing_domain_type(revision: Revision):
245+
domain = domain_v0 if revision == Revision.V0 else domain_v1
246+
247+
with pytest.raises(ValueError, match=f"Types must contain '{domain.separator_name}'."):
248+
TypedData(
249+
types={},
250+
primary_type="felt",
251+
domain=domain,
252+
message={},
253+
)
254+
255+
256+
def test_dangling_type():
257+
with pytest.raises(ValueError, match="Dangling types are not allowed. Unreferenced type dangling was found."):
258+
TypedData(
259+
types={
260+
**domain_type_v1,
261+
"dangling": [],
262+
"mytype": []
263+
},
264+
primary_type="mytype",
265+
domain=domain_v1,
266+
message={"mytype": 1},
267+
)
268+
269+
270+
def test_missing_dependency():
271+
typed_data = TypedData(
272+
types={
273+
**domain_type_v1,
274+
"house": [Parameter(name="fridge", type="ice cream")]
275+
},
276+
primary_type="house",
277+
domain=domain_v1,
278+
message={"fridge": 1},
279+
)
280+
281+
with pytest.raises(ValueError, match=r"Type \[ice cream\] is not defined in types."):
282+
typed_data.struct_hash("house", {"fridge": 1})

0 commit comments

Comments
 (0)