Skip to content

Commit b3d3797

Browse files
committed
support nested generic dataclasses
1 parent 86f6d84 commit b3d3797

File tree

2 files changed

+73
-41
lines changed

2 files changed

+73
-41
lines changed

marshmallow_dataclass/__init__.py

Lines changed: 45 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -356,20 +356,27 @@ def class_schema(
356356
del current_frame
357357
_RECURSION_GUARD.seen_classes = {}
358358
try:
359-
return _internal_class_schema(clazz, base_schema, clazz_frame)
359+
return _internal_class_schema(clazz, base_schema, clazz_frame, None)
360360
finally:
361361
_RECURSION_GUARD.seen_classes.clear()
362362

363363

364+
def _dataclass_fields(clazz: type) -> Tuple[dataclasses.Field, ...]:
365+
if _is_generic_alias_of_dataclass(clazz):
366+
clazz = typing_inspect.get_origin(clazz)
367+
return dataclasses.fields(clazz)
368+
369+
364370
@lru_cache(maxsize=MAX_CLASS_SCHEMA_CACHE_SIZE)
365371
def _internal_class_schema(
366372
clazz: type,
367373
base_schema: Optional[Type[marshmallow.Schema]] = None,
368374
clazz_frame: types.FrameType = None,
375+
generic_params_to_args: Optional[Tuple[Tuple[type, type], ...]] = None,
369376
) -> Type[marshmallow.Schema]:
370377
_RECURSION_GUARD.seen_classes[clazz] = clazz.__name__
371378
try:
372-
class_name, fields = _dataclass_name_and_fields(clazz)
379+
fields = _dataclass_fields(clazz)
373380
except TypeError: # Not a dataclass
374381
try:
375382
warnings.warn(
@@ -384,7 +391,9 @@ def _internal_class_schema(
384391
"****** WARNING ******"
385392
)
386393
created_dataclass: type = dataclasses.dataclass(clazz)
387-
return _internal_class_schema(created_dataclass, base_schema, clazz_frame)
394+
return _internal_class_schema(
395+
created_dataclass, base_schema, clazz_frame, generic_params_to_args
396+
)
388397
except Exception as exc:
389398
raise TypeError(
390399
f"{getattr(clazz, '__name__', repr(clazz))} is not a dataclass and cannot be turned into one."
@@ -397,10 +406,11 @@ def _internal_class_schema(
397406
if hasattr(v, "__marshmallow_hook__") or k in MEMBERS_WHITELIST
398407
}
399408

409+
if _is_generic_alias_of_dataclass(clazz) and generic_params_to_args is None:
410+
generic_params_to_args = _generic_params_to_args(clazz)
411+
412+
type_hints = _dataclass_type_hints(clazz, clazz_frame, generic_params_to_args)
400413
# Update the schema members to contain marshmallow fields instead of dataclass fields
401-
type_hints = get_type_hints(
402-
clazz, localns=clazz_frame.f_locals if clazz_frame else None
403-
)
404414
attributes.update(
405415
(
406416
field.name,
@@ -410,13 +420,14 @@ def _internal_class_schema(
410420
field.metadata,
411421
base_schema,
412422
clazz_frame,
423+
generic_params_to_args,
413424
),
414425
)
415426
for field in fields
416427
if field.init
417428
)
418429

419-
schema_class = type(class_name, (_base_schema(clazz, base_schema),), attributes)
430+
schema_class = type(clazz.__name__, (_base_schema(clazz, base_schema),), attributes)
420431
return cast(Type[marshmallow.Schema], schema_class)
421432

422433

@@ -551,7 +562,7 @@ def _field_for_generic_type(
551562
),
552563
)
553564
return tuple_type(children, **metadata)
554-
elif origin in (dict, Dict, collections.abc.Mapping, Mapping):
565+
if origin in (dict, Dict, collections.abc.Mapping, Mapping):
555566
dict_type = type_mapping.get(Dict, marshmallow.fields.Dict)
556567
return dict_type(
557568
keys=field_for_schema(
@@ -603,6 +614,7 @@ def field_for_schema(
603614
metadata: Mapping[str, Any] = None,
604615
base_schema: Optional[Type[marshmallow.Schema]] = None,
605616
typ_frame: Optional[types.FrameType] = None,
617+
generic_params_to_args: Optional[Tuple[Tuple[type, type], ...]] = None,
606618
) -> marshmallow.fields.Field:
607619
"""
608620
Get a marshmallow Field corresponding to the given python type.
@@ -732,7 +744,7 @@ def field_for_schema(
732744
nested_schema
733745
or forward_reference
734746
or _RECURSION_GUARD.seen_classes.get(typ)
735-
or _internal_class_schema(typ, base_schema, typ_frame)
747+
or _internal_class_schema(typ, base_schema, typ_frame, generic_params_to_args)
736748
)
737749

738750
return marshmallow.fields.Nested(nested, **metadata)
@@ -786,35 +798,33 @@ def _is_generic_alias_of_dataclass(clazz: type) -> bool:
786798
)
787799

788800

789-
# noinspection PyDataclass
790-
def _dataclass_name_and_fields(
791-
clazz: type,
792-
) -> Tuple[str, Tuple[dataclasses.Field, ...]]:
793-
if not _is_generic_alias_of_dataclass(clazz):
794-
return clazz.__name__, dataclasses.fields(clazz)
795-
801+
def _generic_params_to_args(clazz: type) -> Tuple[Tuple[type, type], ...]:
796802
base_dataclass = typing_inspect.get_origin(clazz)
797803
base_parameters = typing_inspect.get_parameters(base_dataclass)
798804
type_arguments = typing_inspect.get_args(clazz)
799-
params_to_args = dict(zip(base_parameters, type_arguments))
800-
non_generic_fields = [ # swap generic typed fields with types in given type arguments
801-
(
802-
f.name,
803-
params_to_args.get(f.type, f.type),
804-
dataclasses.field(
805-
default=f.default,
806-
# ignoring mypy: https://github.com/python/mypy/issues/6910
807-
default_factory=f.default_factory, # type: ignore
808-
init=f.init,
809-
metadata=f.metadata,
810-
),
811-
)
812-
for f in dataclasses.fields(base_dataclass)
813-
]
814-
non_generic_dataclass = dataclasses.make_dataclass(
815-
cls_name=f"{base_dataclass.__name__}{type_arguments}", fields=non_generic_fields
816-
)
817-
return base_dataclass.__name__, dataclasses.fields(non_generic_dataclass)
805+
return tuple(zip(base_parameters, type_arguments))
806+
807+
808+
def _dataclass_type_hints(
809+
clazz: type,
810+
clazz_frame: types.FrameType = None,
811+
generic_params_to_args: Optional[Tuple[Tuple[type, type], ...]] = None,
812+
) -> Mapping[str, type]:
813+
localns = clazz_frame.f_locals if clazz_frame else None
814+
if not _is_generic_alias_of_dataclass(clazz):
815+
return get_type_hints(clazz, localns=localns)
816+
# dataclass is generic
817+
generic_type_hints = get_type_hints(typing_inspect.get_origin(clazz), localns)
818+
generic_params_map = dict(generic_params_to_args if generic_params_to_args else {})
819+
820+
def _get_hint(_t: type) -> type:
821+
if isinstance(_t, TypeVar):
822+
return generic_params_map[_t]
823+
return _t
824+
825+
return {
826+
field_name: _get_hint(typ) for field_name, typ in generic_type_hints.items()
827+
}
818828

819829

820830
def NewType(

tests/test_class_schema.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from marshmallow.fields import Field, UUID as UUIDField, List as ListField, Integer
1515
from marshmallow.validate import Validator
1616

17-
from marshmallow_dataclass import class_schema, NewType
17+
from marshmallow_dataclass import class_schema, NewType, _is_generic_alias_of_dataclass
1818

1919

2020
class TestClassSchema(unittest.TestCase):
@@ -409,24 +409,45 @@ class SimpleGeneric(typing.Generic[T]):
409409
data: T
410410

411411
@dataclasses.dataclass
412-
class Nested:
412+
class NestedFixed:
413413
data: SimpleGeneric[int]
414414

415+
@dataclasses.dataclass
416+
class NestedGeneric(typing.Generic[T]):
417+
data: SimpleGeneric[T]
418+
419+
self.assertTrue(_is_generic_alias_of_dataclass(SimpleGeneric[int]))
420+
self.assertFalse(_is_generic_alias_of_dataclass(SimpleGeneric))
421+
415422
schema_s = class_schema(SimpleGeneric[str])()
416423
self.assertEqual(SimpleGeneric(data="a"), schema_s.load({"data": "a"}))
417424
self.assertEqual(schema_s.dump(SimpleGeneric(data="a")), {"data": "a"})
418425
with self.assertRaises(ValidationError):
419426
schema_s.load({"data": 2})
420427

421-
schema_n = class_schema(Nested)()
428+
schema_nested = class_schema(NestedFixed)()
429+
self.assertEqual(
430+
NestedFixed(data=SimpleGeneric(1)),
431+
schema_nested.load({"data": {"data": 1}}),
432+
)
433+
self.assertEqual(
434+
schema_nested.dump(NestedFixed(data=SimpleGeneric(data=1))),
435+
{"data": {"data": 1}},
436+
)
437+
with self.assertRaises(ValidationError):
438+
schema_nested.load({"data": {"data": "str"}})
439+
440+
schema_nested_generic = class_schema(NestedGeneric[int])()
422441
self.assertEqual(
423-
Nested(data=SimpleGeneric(1)), schema_n.load({"data": {"data": 1}})
442+
NestedGeneric(data=SimpleGeneric(1)),
443+
schema_nested_generic.load({"data": {"data": 1}}),
424444
)
425445
self.assertEqual(
426-
schema_n.dump(Nested(data=SimpleGeneric(data=1))), {"data": {"data": 1}}
446+
schema_nested_generic.dump(NestedGeneric(data=SimpleGeneric(data=1))),
447+
{"data": {"data": 1}},
427448
)
428449
with self.assertRaises(ValidationError):
429-
schema_n.load({"data": {"data": "str"}})
450+
schema_nested_generic.load({"data": {"data": "str"}})
430451

431452
def test_recursive_reference(self):
432453
@dataclasses.dataclass
@@ -461,5 +482,6 @@ class Second:
461482
{"first": {"second": {"first": None}}},
462483
)
463484

485+
464486
if __name__ == "__main__":
465487
unittest.main()

0 commit comments

Comments
 (0)