Skip to content

Commit 65ed180

Browse files
authored
Add support for Dict, List, set, frozenset, Tuple[T, ...], tuple[T, ...] (lovasoa#221)
* Add support for Dict, List, set, frozenset We already supported `dict` (without explicit generic type parameters), but did not support `Dict`. (Similarly for `list` and `List`.) We supported `typing.Set`, but not the builtin `set`. (Similarly for `typing.FrozenSet` and `frozenset`.) We now support all of the above. Fixes lovasoa#181. * Add support for homogeneous tuples (Tuple[T, ...]) We currently recognize `typing.Sequence[int]` as a homogeneous tuple of ints. This adds support for the (probably more correct) concrete type `tuple[int, ...]` (or `typing.Tuple[int, ...]`).
1 parent ef95ef4 commit 65ed180

File tree

2 files changed

+97
-7
lines changed

2 files changed

+97
-7
lines changed

marshmallow_dataclass/__init__.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -470,17 +470,17 @@ def _field_by_supertype(
470470

471471
def _generic_type_add_any(typ: type) -> type:
472472
"""if typ is generic type without arguments, replace them by Any."""
473-
if typ is list:
473+
if typ is list or typ is List:
474474
typ = List[Any]
475-
elif typ is dict:
475+
elif typ is dict or typ is Dict:
476476
typ = Dict[Any, Any]
477477
elif typ is Mapping:
478478
typ = Mapping[Any, Any]
479479
elif typ is Sequence:
480480
typ = Sequence[Any]
481-
elif typ is Set:
481+
elif typ is set or typ is Set:
482482
typ = Set[Any]
483-
elif typ is FrozenSet:
483+
elif typ is frozenset or typ is FrozenSet:
484484
typ = FrozenSet[Any]
485485
return typ
486486

@@ -509,7 +509,11 @@ def _field_for_generic_type(
509509
type_mapping.get(List, marshmallow.fields.List),
510510
)
511511
return list_type(child_type, **metadata)
512-
if origin in (collections.abc.Sequence, Sequence):
512+
if origin in (collections.abc.Sequence, Sequence) or (
513+
origin in (tuple, Tuple)
514+
and len(arguments) == 2
515+
and arguments[1] is Ellipsis
516+
):
513517
from . import collection_field
514518

515519
child_type = field_for_schema(

tests/test_field_for_schema.py

Lines changed: 88 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import inspect
2+
import sys
23
import typing
34
import unittest
45
from enum import Enum
@@ -53,6 +54,16 @@ def test_dict_from_typing(self):
5354
),
5455
)
5556

57+
def test_dict_from_typing_wo_args(self):
58+
self.assertFieldsEqual(
59+
field_for_schema(Dict),
60+
fields.Dict(
61+
keys=fields.Raw(required=True, allow_none=True),
62+
values=fields.Raw(required=True, allow_none=True),
63+
required=True,
64+
),
65+
)
66+
5667
def test_builtin_dict(self):
5768
self.assertFieldsEqual(
5869
field_for_schema(dict),
@@ -63,6 +74,21 @@ def test_builtin_dict(self):
6374
),
6475
)
6576

77+
def test_list_from_typing(self):
78+
self.assertFieldsEqual(
79+
field_for_schema(List[int]),
80+
fields.List(fields.Integer(required=True), required=True),
81+
)
82+
83+
def test_list_from_typing_wo_args(self):
84+
self.assertFieldsEqual(
85+
field_for_schema(List),
86+
fields.List(
87+
fields.Raw(required=True, allow_none=True),
88+
required=True,
89+
),
90+
)
91+
6692
def test_builtin_list(self):
6793
self.assertFieldsEqual(
6894
field_for_schema(list, metadata=dict(required=False)),
@@ -217,6 +243,13 @@ def test_mapping(self):
217243
)
218244

219245
def test_sequence(self):
246+
self.maxDiff = 2000
247+
self.assertFieldsEqual(
248+
field_for_schema(typing.Sequence[int]),
249+
collection_field.Sequence(fields.Integer(required=True), required=True),
250+
)
251+
252+
def test_sequence_wo_args(self):
220253
self.assertFieldsEqual(
221254
field_for_schema(typing.Sequence),
222255
collection_field.Sequence(
@@ -225,7 +258,30 @@ def test_sequence(self):
225258
),
226259
)
227260

228-
def test_set(self):
261+
def test_homogeneous_tuple_from_typing(self):
262+
self.assertFieldsEqual(
263+
field_for_schema(Tuple[str, ...]),
264+
collection_field.Sequence(fields.String(required=True), required=True),
265+
)
266+
267+
@unittest.skipIf(sys.version_info < (3, 9), "PEP 585 unsupported")
268+
def test_homogeneous_tuple(self):
269+
self.assertFieldsEqual(
270+
field_for_schema(tuple[float, ...]),
271+
collection_field.Sequence(fields.Float(required=True), required=True),
272+
)
273+
274+
def test_set_from_typing(self):
275+
self.assertFieldsEqual(
276+
field_for_schema(typing.Set[str]),
277+
collection_field.Set(
278+
fields.String(required=True),
279+
frozen=False,
280+
required=True,
281+
),
282+
)
283+
284+
def test_set_from_typing_wo_args(self):
229285
self.assertFieldsEqual(
230286
field_for_schema(typing.Set),
231287
collection_field.Set(
@@ -235,7 +291,27 @@ def test_set(self):
235291
),
236292
)
237293

238-
def test_frozenset(self):
294+
def test_builtin_set(self):
295+
self.assertFieldsEqual(
296+
field_for_schema(set),
297+
collection_field.Set(
298+
cls_or_instance=fields.Raw(required=True, allow_none=True),
299+
frozen=False,
300+
required=True,
301+
),
302+
)
303+
304+
def test_frozenset_from_typing(self):
305+
self.assertFieldsEqual(
306+
field_for_schema(typing.FrozenSet[int]),
307+
collection_field.Set(
308+
fields.Integer(required=True),
309+
frozen=True,
310+
required=True,
311+
),
312+
)
313+
314+
def test_frozenset_from_typing_wo_args(self):
239315
self.assertFieldsEqual(
240316
field_for_schema(typing.FrozenSet),
241317
collection_field.Set(
@@ -245,6 +321,16 @@ def test_frozenset(self):
245321
),
246322
)
247323

324+
def test_builtin_frozenset(self):
325+
self.assertFieldsEqual(
326+
field_for_schema(frozenset),
327+
collection_field.Set(
328+
cls_or_instance=fields.Raw(required=True, allow_none=True),
329+
frozen=True,
330+
required=True,
331+
),
332+
)
333+
248334

249335
if __name__ == "__main__":
250336
unittest.main()

0 commit comments

Comments
 (0)