@@ -34,6 +34,7 @@ class User:
3434 })
3535 Schema: ClassVar[Type[Schema]] = Schema # For the type checker
3636"""
37+
3738import collections .abc
3839import dataclasses
3940import inspect
@@ -43,14 +44,11 @@ class User:
4344import warnings
4445from enum import Enum
4546from functools import lru_cache , partial
47+ from typing import Any , Callable , Dict , FrozenSet , List , Mapping
48+ from typing import NewType as typing_NewType
4649from typing import (
47- Any ,
48- Callable ,
49- Dict ,
50- List ,
51- Mapping ,
52- NewType as typing_NewType ,
5350 Optional ,
51+ Sequence ,
5452 Set ,
5553 Tuple ,
5654 Type ,
@@ -59,16 +57,13 @@ class User:
5957 cast ,
6058 get_type_hints ,
6159 overload ,
62- Sequence ,
63- FrozenSet ,
6460)
6561
6662import marshmallow
6763import typing_inspect
6864
6965from marshmallow_dataclass .lazy_class_attribute import lazy_class_attribute
7066
71-
7267if sys .version_info >= (3 , 11 ):
7368 from typing import dataclass_transform
7469elif sys .version_info >= (3 , 7 ):
@@ -105,8 +100,7 @@ def dataclass(
105100 frozen : bool = False ,
106101 base_schema : Optional [Type [marshmallow .Schema ]] = None ,
107102 cls_frame : Optional [types .FrameType ] = None ,
108- ) -> Type [_U ]:
109- ...
103+ ) -> Type [_U ]: ...
110104
111105
112106@overload
@@ -119,8 +113,7 @@ def dataclass(
119113 frozen : bool = False ,
120114 base_schema : Optional [Type [marshmallow .Schema ]] = None ,
121115 cls_frame : Optional [types .FrameType ] = None ,
122- ) -> Callable [[Type [_U ]], Type [_U ]]:
123- ...
116+ ) -> Callable [[Type [_U ]], Type [_U ]]: ...
124117
125118
126119# _cls should never be specified by keyword, so start it with an
@@ -179,24 +172,21 @@ def dataclass(
179172
180173
181174@overload
182- def add_schema (_cls : Type [_U ]) -> Type [_U ]:
183- ...
175+ def add_schema (_cls : Type [_U ]) -> Type [_U ]: ...
184176
185177
186178@overload
187179def add_schema (
188180 base_schema : Optional [Type [marshmallow .Schema ]] = None ,
189- ) -> Callable [[Type [_U ]], Type [_U ]]:
190- ...
181+ ) -> Callable [[Type [_U ]], Type [_U ]]: ...
191182
192183
193184@overload
194185def add_schema (
195186 _cls : Type [_U ],
196187 base_schema : Optional [Type [marshmallow .Schema ]] = None ,
197188 cls_frame : Optional [types .FrameType ] = None ,
198- ) -> Type [_U ]:
199- ...
189+ ) -> Type [_U ]: ...
200190
201191
202192def add_schema (_cls = None , base_schema = None , cls_frame = None ):
@@ -386,20 +376,27 @@ def class_schema(
386376 del current_frame
387377 _RECURSION_GUARD .seen_classes = {}
388378 try :
389- return _internal_class_schema (clazz , base_schema , clazz_frame )
379+ return _internal_class_schema (clazz , base_schema , clazz_frame , None )
390380 finally :
391381 _RECURSION_GUARD .seen_classes .clear ()
392382
393383
384+ def _dataclass_fields (clazz : type ) -> Tuple [dataclasses .Field , ...]:
385+ if _is_generic_alias_of_dataclass (clazz ):
386+ clazz = typing_inspect .get_origin (clazz )
387+ return dataclasses .fields (clazz )
388+
389+
394390@lru_cache (maxsize = MAX_CLASS_SCHEMA_CACHE_SIZE )
395391def _internal_class_schema (
396392 clazz : type ,
397393 base_schema : Optional [Type [marshmallow .Schema ]] = None ,
398394 clazz_frame : Optional [types .FrameType ] = None ,
395+ generic_params_to_args : Optional [Tuple [Tuple [type , type ], ...]] = None ,
399396) -> Type [marshmallow .Schema ]:
400397 _RECURSION_GUARD .seen_classes [clazz ] = clazz .__name__
401398 try :
402- class_name , fields = _dataclass_name_and_fields (clazz )
399+ fields = _dataclass_fields (clazz )
403400 except TypeError : # Not a dataclass
404401 try :
405402 warnings .warn (
@@ -414,7 +411,9 @@ def _internal_class_schema(
414411 "****** WARNING ******"
415412 )
416413 created_dataclass : type = dataclasses .dataclass (clazz )
417- return _internal_class_schema (created_dataclass , base_schema , clazz_frame )
414+ return _internal_class_schema (
415+ created_dataclass , base_schema , clazz_frame , generic_params_to_args
416+ )
418417 except Exception as exc :
419418 raise TypeError (
420419 f"{ getattr (clazz , '__name__' , repr (clazz ))} is not a dataclass and cannot be turned into one."
@@ -430,10 +429,11 @@ def _internal_class_schema(
430429 # Determine whether we should include non-init fields
431430 include_non_init = getattr (getattr (clazz , "Meta" , None ), "include_non_init" , False )
432431
432+ if _is_generic_alias_of_dataclass (clazz ) and generic_params_to_args is None :
433+ generic_params_to_args = _generic_params_to_args (clazz )
434+
435+ type_hints = _dataclass_type_hints (clazz , clazz_frame , generic_params_to_args )
433436 # Update the schema members to contain marshmallow fields instead of dataclass fields
434- type_hints = get_type_hints (
435- clazz , localns = clazz_frame .f_locals if clazz_frame else None
436- )
437437 attributes .update (
438438 (
439439 field .name ,
@@ -443,13 +443,14 @@ def _internal_class_schema(
443443 field .metadata ,
444444 base_schema ,
445445 clazz_frame ,
446+ generic_params_to_args ,
446447 ),
447448 )
448449 for field in fields
449450 if field .init or include_non_init
450451 )
451452
452- schema_class = type (class_name , (_base_schema (clazz , base_schema ),), attributes )
453+ schema_class = type (clazz . __name__ , (_base_schema (clazz , base_schema ),), attributes )
453454 return cast (Type [marshmallow .Schema ], schema_class )
454455
455456
@@ -584,7 +585,7 @@ def _field_for_generic_type(
584585 ),
585586 )
586587 return tuple_type (children , ** metadata )
587- elif origin in (dict , Dict , collections .abc .Mapping , Mapping ):
588+ if origin in (dict , Dict , collections .abc .Mapping , Mapping ):
588589 dict_type = type_mapping .get (Dict , marshmallow .fields .Dict )
589590 return dict_type (
590591 keys = field_for_schema (
@@ -636,6 +637,7 @@ def field_for_schema(
636637 metadata : Optional [Mapping [str , Any ]] = None ,
637638 base_schema : Optional [Type [marshmallow .Schema ]] = None ,
638639 typ_frame : Optional [types .FrameType ] = None ,
640+ generic_params_to_args : Optional [Tuple [Tuple [type , type ], ...]] = None ,
639641) -> marshmallow .fields .Field :
640642 """
641643 Get a marshmallow Field corresponding to the given python type.
@@ -769,7 +771,7 @@ def field_for_schema(
769771 nested_schema
770772 or forward_reference
771773 or _RECURSION_GUARD .seen_classes .get (typ )
772- or _internal_class_schema (typ , base_schema , typ_frame ) # type: ignore [arg-type]
774+ or _internal_class_schema (typ , base_schema , typ_frame , generic_params_to_args ) # type: ignore [arg-type]
773775 )
774776
775777 return marshmallow .fields .Nested (nested , ** metadata )
@@ -823,35 +825,33 @@ def _is_generic_alias_of_dataclass(clazz: type) -> bool:
823825 )
824826
825827
826- # noinspection PyDataclass
827- def _dataclass_name_and_fields (
828- clazz : type ,
829- ) -> Tuple [str , Tuple [dataclasses .Field , ...]]:
830- if not _is_generic_alias_of_dataclass (clazz ):
831- return clazz .__name__ , dataclasses .fields (clazz )
832-
828+ def _generic_params_to_args (clazz : type ) -> Tuple [Tuple [type , type ], ...]:
833829 base_dataclass = typing_inspect .get_origin (clazz )
834830 base_parameters = typing_inspect .get_parameters (base_dataclass )
835831 type_arguments = typing_inspect .get_args (clazz )
836- params_to_args = dict (zip (base_parameters , type_arguments ))
837- non_generic_fields = [ # swap generic typed fields with types in given type arguments
838- (
839- f .name ,
840- params_to_args .get (f .type , f .type ),
841- dataclasses .field (
842- default = f .default ,
843- # ignoring mypy: https://github.com/python/mypy/issues/6910
844- default_factory = f .default_factory , # type: ignore
845- init = f .init ,
846- metadata = f .metadata ,
847- ),
848- )
849- for f in dataclasses .fields (base_dataclass )
850- ]
851- non_generic_dataclass = dataclasses .make_dataclass (
852- cls_name = f"{ base_dataclass .__name__ } { type_arguments } " , fields = non_generic_fields
853- )
854- return base_dataclass .__name__ , dataclasses .fields (non_generic_dataclass )
832+ return tuple (zip (base_parameters , type_arguments ))
833+
834+
835+ def _dataclass_type_hints (
836+ clazz : type ,
837+ clazz_frame : types .FrameType = None ,
838+ generic_params_to_args : Optional [Tuple [Tuple [type , type ], ...]] = None ,
839+ ) -> Mapping [str , type ]:
840+ localns = clazz_frame .f_locals if clazz_frame else None
841+ if not _is_generic_alias_of_dataclass (clazz ):
842+ return get_type_hints (clazz , localns = localns )
843+ # dataclass is generic
844+ generic_type_hints = get_type_hints (typing_inspect .get_origin (clazz ), localns )
845+ generic_params_map = dict (generic_params_to_args if generic_params_to_args else {})
846+
847+ def _get_hint (_t : type ) -> type :
848+ if isinstance (_t , TypeVar ):
849+ return generic_params_map [_t ]
850+ return _t
851+
852+ return {
853+ field_name : _get_hint (typ ) for field_name , typ in generic_type_hints .items ()
854+ }
855855
856856
857857def NewType (
0 commit comments