Skip to content

Commit 89d4c6b

Browse files
committed
support generic dataclasses
1 parent fa6c289 commit 89d4c6b

File tree

2 files changed

+73
-4
lines changed

2 files changed

+73
-4
lines changed

marshmallow_dataclass/__init__.py

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,9 @@ def class_schema(
313313
>>> class_schema(Custom)().load({})
314314
Custom(name=None)
315315
"""
316-
if not dataclasses.is_dataclass(clazz):
316+
if not dataclasses.is_dataclass(clazz) and not _is_generic_alias_of_dataclass(
317+
clazz
318+
):
317319
clazz = dataclasses.dataclass(clazz)
318320
return _internal_class_schema(clazz, base_schema)
319321

@@ -323,8 +325,7 @@ def _internal_class_schema(
323325
clazz: type, base_schema: Optional[Type[marshmallow.Schema]] = None
324326
) -> Type[marshmallow.Schema]:
325327
try:
326-
# noinspection PyDataclass
327-
fields: Tuple[dataclasses.Field, ...] = dataclasses.fields(clazz)
328+
class_name, fields = _dataclass_name_and_fields(clazz)
328329
except TypeError: # Not a dataclass
329330
try:
330331
warnings.warn(
@@ -363,7 +364,7 @@ def _internal_class_schema(
363364
if field.init
364365
)
365366

366-
schema_class = type(clazz.__name__, (_base_schema(clazz, base_schema),), attributes)
367+
schema_class = type(class_name, (_base_schema(clazz, base_schema),), attributes)
367368
return cast(Type[marshmallow.Schema], schema_class)
368369

369370

@@ -662,6 +663,47 @@ def _get_field_default(field: dataclasses.Field):
662663
return field.default
663664

664665

666+
def _is_generic_alias_of_dataclass(clazz: type) -> bool:
667+
"""
668+
Check if given class is a generic alias of a dataclass, if the dataclass is
669+
defined as `class A(Generic[T])`, this method will return true if `A[int]` is passed
670+
"""
671+
return typing_inspect.is_generic_type(clazz) and dataclasses.is_dataclass(
672+
typing_inspect.get_origin(clazz)
673+
)
674+
675+
676+
# noinspection PyDataclass
677+
def _dataclass_name_and_fields(
678+
clazz: type,
679+
) -> Tuple[str, Tuple[dataclasses.Field, ...]]:
680+
if not _is_generic_alias_of_dataclass(clazz):
681+
return clazz.__name__, dataclasses.fields(clazz)
682+
683+
base_dataclass = typing_inspect.get_origin(clazz)
684+
base_parameters = typing_inspect.get_parameters(base_dataclass)
685+
type_arguments = typing_inspect.get_args(clazz)
686+
params_to_args = dict(zip(base_parameters, type_arguments))
687+
non_generic_fields = [ # swap generic typed fields with types in given type arguments
688+
(
689+
f.name,
690+
params_to_args.get(f.type, f.type),
691+
dataclasses.field(
692+
default=f.default,
693+
# ignoring mypy: https://github.com/python/mypy/issues/6910
694+
default_factory=f.default_factory, # type: ignore
695+
init=f.init,
696+
metadata=f.metadata,
697+
),
698+
)
699+
for f in dataclasses.fields(base_dataclass)
700+
]
701+
non_generic_dataclass = dataclasses.make_dataclass(
702+
cls_name=f"{base_dataclass.__name__}{type_arguments}", fields=non_generic_fields
703+
)
704+
return base_dataclass.__name__, dataclasses.fields(non_generic_dataclass)
705+
706+
665707
def NewType(
666708
name: str,
667709
typ: Type[_U],

tests/test_class_schema.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,33 @@ class J:
324324
[validator_a, validator_b, validator_c, validator_d],
325325
)
326326

327+
def test_generic_dataclass(self):
328+
T = typing.TypeVar("T")
329+
330+
@dataclasses.dataclass
331+
class SimpleGeneric(typing.Generic[T]):
332+
data: T
333+
334+
@dataclasses.dataclass
335+
class Nested:
336+
data: SimpleGeneric[int]
337+
338+
schema_s = class_schema(SimpleGeneric[str])()
339+
self.assertEqual(SimpleGeneric(data="a"), schema_s.load({"data": "a"}))
340+
self.assertEqual(schema_s.dump(SimpleGeneric(data="a")), {"data": "a"})
341+
with self.assertRaises(ValidationError):
342+
schema_s.load({"data": 2})
343+
344+
schema_n = class_schema(Nested)()
345+
self.assertEqual(
346+
Nested(data=SimpleGeneric(1)), schema_n.load({"data": {"data": 1}})
347+
)
348+
self.assertEqual(
349+
schema_n.dump(Nested(data=SimpleGeneric(data=1))), {"data": {"data": 1}}
350+
)
351+
with self.assertRaises(ValidationError):
352+
schema_n.load({"data": {"data": "str"}})
353+
327354

328355
if __name__ == "__main__":
329356
unittest.main()

0 commit comments

Comments
 (0)