@@ -374,7 +374,9 @@ def class_schema(
374374 >>> class_schema(Custom)().load({})
375375 Custom(name=None)
376376 """
377- if not dataclasses .is_dataclass (clazz ):
377+ if not dataclasses .is_dataclass (clazz ) and not _is_generic_alias_of_dataclass (
378+ clazz
379+ ):
378380 clazz = dataclasses .dataclass (clazz )
379381 if not clazz_frame :
380382 current_frame = inspect .currentframe ()
@@ -397,8 +399,7 @@ def _internal_class_schema(
397399) -> Type [marshmallow .Schema ]:
398400 _RECURSION_GUARD .seen_classes [clazz ] = clazz .__name__
399401 try :
400- # noinspection PyDataclass
401- fields : Tuple [dataclasses .Field , ...] = dataclasses .fields (clazz )
402+ class_name , fields = _dataclass_name_and_fields (clazz )
402403 except TypeError : # Not a dataclass
403404 try :
404405 warnings .warn (
@@ -448,7 +449,7 @@ def _internal_class_schema(
448449 if field .init or include_non_init
449450 )
450451
451- schema_class = type (clazz . __name__ , (_base_schema (clazz , base_schema ),), attributes )
452+ schema_class = type (class_name , (_base_schema (clazz , base_schema ),), attributes )
452453 return cast (Type [marshmallow .Schema ], schema_class )
453454
454455
@@ -812,6 +813,47 @@ def _get_field_default(field: dataclasses.Field):
812813 return field .default
813814
814815
816+ def _is_generic_alias_of_dataclass (clazz : type ) -> bool :
817+ """
818+ Check if given class is a generic alias of a dataclass, if the dataclass is
819+ defined as `class A(Generic[T])`, this method will return true if `A[int]` is passed
820+ """
821+ return typing_inspect .is_generic_type (clazz ) and dataclasses .is_dataclass (
822+ typing_inspect .get_origin (clazz )
823+ )
824+
825+
826+ # noinspection PyDataclass
827+ def _dataclass_name_and_fields (
828+ clazz : type ,
829+ ) -> Tuple [str , Tuple [dataclasses .Field , ...]]:
830+ if not _is_generic_alias_of_dataclass (clazz ):
831+ return clazz .__name__ , dataclasses .fields (clazz )
832+
833+ base_dataclass = typing_inspect .get_origin (clazz )
834+ base_parameters = typing_inspect .get_parameters (base_dataclass )
835+ type_arguments = typing_inspect .get_args (clazz )
836+ params_to_args = dict (zip (base_parameters , type_arguments ))
837+ non_generic_fields = [ # swap generic typed fields with types in given type arguments
838+ (
839+ f .name ,
840+ params_to_args .get (f .type , f .type ),
841+ dataclasses .field (
842+ default = f .default ,
843+ # ignoring mypy: https://github.com/python/mypy/issues/6910
844+ default_factory = f .default_factory , # type: ignore
845+ init = f .init ,
846+ metadata = f .metadata ,
847+ ),
848+ )
849+ for f in dataclasses .fields (base_dataclass )
850+ ]
851+ non_generic_dataclass = dataclasses .make_dataclass (
852+ cls_name = f"{ base_dataclass .__name__ } { type_arguments } " , fields = non_generic_fields
853+ )
854+ return base_dataclass .__name__ , dataclasses .fields (non_generic_dataclass )
855+
856+
815857def NewType (
816858 name : str ,
817859 typ : Type [_U ],
0 commit comments