@@ -313,7 +313,9 @@ def class_schema(
313313 >>> class_schema(Custom)().load({})
314314 Custom(name=None)
315315 """
316- if not dataclasses .is_dataclass (clazz ):
316+ if not dataclasses .is_dataclass (clazz ) and not _is_generic_alias_of_dataclass (
317+ clazz
318+ ):
317319 clazz = dataclasses .dataclass (clazz )
318320 return _internal_class_schema (clazz , base_schema )
319321
@@ -323,8 +325,7 @@ def _internal_class_schema(
323325 clazz : type , base_schema : Optional [Type [marshmallow .Schema ]] = None
324326) -> Type [marshmallow .Schema ]:
325327 try :
326- # noinspection PyDataclass
327- fields : Tuple [dataclasses .Field , ...] = dataclasses .fields (clazz )
328+ class_name , fields = _dataclass_name_and_fields (clazz )
328329 except TypeError : # Not a dataclass
329330 try :
330331 warnings .warn (
@@ -363,7 +364,7 @@ def _internal_class_schema(
363364 if field .init
364365 )
365366
366- schema_class = type (clazz . __name__ , (_base_schema (clazz , base_schema ),), attributes )
367+ schema_class = type (class_name , (_base_schema (clazz , base_schema ),), attributes )
367368 return cast (Type [marshmallow .Schema ], schema_class )
368369
369370
@@ -662,6 +663,47 @@ def _get_field_default(field: dataclasses.Field):
662663 return field .default
663664
664665
666+ def _is_generic_alias_of_dataclass (clazz : type ) -> bool :
667+ """
668+ Check if given class is a generic alias of a dataclass, if the dataclass is
669+ defined as `class A(Generic[T])`, this method will return true if `A[int]` is passed
670+ """
671+ return typing_inspect .is_generic_type (clazz ) and dataclasses .is_dataclass (
672+ typing_inspect .get_origin (clazz )
673+ )
674+
675+
676+ # noinspection PyDataclass
677+ def _dataclass_name_and_fields (
678+ clazz : type ,
679+ ) -> Tuple [str , Tuple [dataclasses .Field , ...]]:
680+ if not _is_generic_alias_of_dataclass (clazz ):
681+ return clazz .__name__ , dataclasses .fields (clazz )
682+
683+ base_dataclass = typing_inspect .get_origin (clazz )
684+ base_parameters = typing_inspect .get_parameters (base_dataclass )
685+ type_arguments = typing_inspect .get_args (clazz )
686+ params_to_args = dict (zip (base_parameters , type_arguments ))
687+ non_generic_fields = [ # swap generic typed fields with types in given type arguments
688+ (
689+ f .name ,
690+ params_to_args .get (f .type , f .type ),
691+ dataclasses .field (
692+ default = f .default ,
693+ # ignoring mypy: https://github.com/python/mypy/issues/6910
694+ default_factory = f .default_factory , # type: ignore
695+ init = f .init ,
696+ metadata = f .metadata ,
697+ ),
698+ )
699+ for f in dataclasses .fields (base_dataclass )
700+ ]
701+ non_generic_dataclass = dataclasses .make_dataclass (
702+ cls_name = f"{ base_dataclass .__name__ } { type_arguments } " , fields = non_generic_fields
703+ )
704+ return base_dataclass .__name__ , dataclasses .fields (non_generic_dataclass )
705+
706+
665707def NewType (
666708 name : str ,
667709 typ : Type [_U ],
0 commit comments