@@ -356,20 +356,27 @@ def class_schema(
356356 del current_frame
357357 _RECURSION_GUARD .seen_classes = {}
358358 try :
359- return _internal_class_schema (clazz , base_schema , clazz_frame )
359+ return _internal_class_schema (clazz , base_schema , clazz_frame , None )
360360 finally :
361361 _RECURSION_GUARD .seen_classes .clear ()
362362
363363
364+ def _dataclass_fields (clazz : type ) -> Tuple [dataclasses .Field , ...]:
365+ if _is_generic_alias_of_dataclass (clazz ):
366+ clazz = typing_inspect .get_origin (clazz )
367+ return dataclasses .fields (clazz )
368+
369+
364370@lru_cache (maxsize = MAX_CLASS_SCHEMA_CACHE_SIZE )
365371def _internal_class_schema (
366372 clazz : type ,
367373 base_schema : Optional [Type [marshmallow .Schema ]] = None ,
368374 clazz_frame : types .FrameType = None ,
375+ generic_params_to_args : Optional [Tuple [Tuple [type , type ], ...]] = None ,
369376) -> Type [marshmallow .Schema ]:
370377 _RECURSION_GUARD .seen_classes [clazz ] = clazz .__name__
371378 try :
372- class_name , fields = _dataclass_name_and_fields (clazz )
379+ fields = _dataclass_fields (clazz )
373380 except TypeError : # Not a dataclass
374381 try :
375382 warnings .warn (
@@ -384,7 +391,9 @@ def _internal_class_schema(
384391 "****** WARNING ******"
385392 )
386393 created_dataclass : type = dataclasses .dataclass (clazz )
387- return _internal_class_schema (created_dataclass , base_schema , clazz_frame )
394+ return _internal_class_schema (
395+ created_dataclass , base_schema , clazz_frame , generic_params_to_args
396+ )
388397 except Exception as exc :
389398 raise TypeError (
390399 f"{ getattr (clazz , '__name__' , repr (clazz ))} is not a dataclass and cannot be turned into one."
@@ -397,10 +406,11 @@ def _internal_class_schema(
397406 if hasattr (v , "__marshmallow_hook__" ) or k in MEMBERS_WHITELIST
398407 }
399408
409+ if _is_generic_alias_of_dataclass (clazz ) and generic_params_to_args is None :
410+ generic_params_to_args = _generic_params_to_args (clazz )
411+
412+ type_hints = _dataclass_type_hints (clazz , clazz_frame , generic_params_to_args )
400413 # Update the schema members to contain marshmallow fields instead of dataclass fields
401- type_hints = get_type_hints (
402- clazz , localns = clazz_frame .f_locals if clazz_frame else None
403- )
404414 attributes .update (
405415 (
406416 field .name ,
@@ -410,13 +420,14 @@ def _internal_class_schema(
410420 field .metadata ,
411421 base_schema ,
412422 clazz_frame ,
423+ generic_params_to_args ,
413424 ),
414425 )
415426 for field in fields
416427 if field .init
417428 )
418429
419- schema_class = type (class_name , (_base_schema (clazz , base_schema ),), attributes )
430+ schema_class = type (clazz . __name__ , (_base_schema (clazz , base_schema ),), attributes )
420431 return cast (Type [marshmallow .Schema ], schema_class )
421432
422433
@@ -551,7 +562,7 @@ def _field_for_generic_type(
551562 ),
552563 )
553564 return tuple_type (children , ** metadata )
554- elif origin in (dict , Dict , collections .abc .Mapping , Mapping ):
565+ if origin in (dict , Dict , collections .abc .Mapping , Mapping ):
555566 dict_type = type_mapping .get (Dict , marshmallow .fields .Dict )
556567 return dict_type (
557568 keys = field_for_schema (
@@ -603,6 +614,7 @@ def field_for_schema(
603614 metadata : Mapping [str , Any ] = None ,
604615 base_schema : Optional [Type [marshmallow .Schema ]] = None ,
605616 typ_frame : Optional [types .FrameType ] = None ,
617+ generic_params_to_args : Optional [Tuple [Tuple [type , type ], ...]] = None ,
606618) -> marshmallow .fields .Field :
607619 """
608620 Get a marshmallow Field corresponding to the given python type.
@@ -732,7 +744,7 @@ def field_for_schema(
732744 nested_schema
733745 or forward_reference
734746 or _RECURSION_GUARD .seen_classes .get (typ )
735- or _internal_class_schema (typ , base_schema , typ_frame )
747+ or _internal_class_schema (typ , base_schema , typ_frame , generic_params_to_args )
736748 )
737749
738750 return marshmallow .fields .Nested (nested , ** metadata )
@@ -786,35 +798,33 @@ def _is_generic_alias_of_dataclass(clazz: type) -> bool:
786798 )
787799
788800
789- # noinspection PyDataclass
790- def _dataclass_name_and_fields (
791- clazz : type ,
792- ) -> Tuple [str , Tuple [dataclasses .Field , ...]]:
793- if not _is_generic_alias_of_dataclass (clazz ):
794- return clazz .__name__ , dataclasses .fields (clazz )
795-
801+ def _generic_params_to_args (clazz : type ) -> Tuple [Tuple [type , type ], ...]:
796802 base_dataclass = typing_inspect .get_origin (clazz )
797803 base_parameters = typing_inspect .get_parameters (base_dataclass )
798804 type_arguments = typing_inspect .get_args (clazz )
799- params_to_args = dict (zip (base_parameters , type_arguments ))
800- non_generic_fields = [ # swap generic typed fields with types in given type arguments
801- (
802- f .name ,
803- params_to_args .get (f .type , f .type ),
804- dataclasses .field (
805- default = f .default ,
806- # ignoring mypy: https://github.com/python/mypy/issues/6910
807- default_factory = f .default_factory , # type: ignore
808- init = f .init ,
809- metadata = f .metadata ,
810- ),
811- )
812- for f in dataclasses .fields (base_dataclass )
813- ]
814- non_generic_dataclass = dataclasses .make_dataclass (
815- cls_name = f"{ base_dataclass .__name__ } { type_arguments } " , fields = non_generic_fields
816- )
817- return base_dataclass .__name__ , dataclasses .fields (non_generic_dataclass )
805+ return tuple (zip (base_parameters , type_arguments ))
806+
807+
808+ def _dataclass_type_hints (
809+ clazz : type ,
810+ clazz_frame : types .FrameType = None ,
811+ generic_params_to_args : Optional [Tuple [Tuple [type , type ], ...]] = None ,
812+ ) -> Mapping [str , type ]:
813+ localns = clazz_frame .f_locals if clazz_frame else None
814+ if not _is_generic_alias_of_dataclass (clazz ):
815+ return get_type_hints (clazz , localns = localns )
816+ # dataclass is generic
817+ generic_type_hints = get_type_hints (typing_inspect .get_origin (clazz ), localns )
818+ generic_params_map = dict (generic_params_to_args if generic_params_to_args else {})
819+
820+ def _get_hint (_t : type ) -> type :
821+ if isinstance (_t , TypeVar ):
822+ return generic_params_map [_t ]
823+ return _t
824+
825+ return {
826+ field_name : _get_hint (typ ) for field_name , typ in generic_type_hints .items ()
827+ }
818828
819829
820830def NewType (
0 commit comments