@@ -34,6 +34,7 @@ class User:
3434 })
3535 Schema: ClassVar[Type[Schema]] = Schema # For the type checker
3636"""
37+
3738import collections .abc
3839import dataclasses
3940import inspect
@@ -61,13 +62,12 @@ class User:
6162 TypeVar ,
6263 Union ,
6364 cast ,
64- get_args ,
65- get_origin ,
6665 get_type_hints ,
6766 overload ,
6867)
6968
7069import marshmallow
70+ import typing_extensions
7171import typing_inspect
7272
7373from marshmallow_dataclass .lazy_class_attribute import lazy_class_attribute
@@ -151,8 +151,7 @@ def dataclass(
151151 frozen : bool = False ,
152152 base_schema : Optional [Type [marshmallow .Schema ]] = None ,
153153 cls_frame : Optional [types .FrameType ] = None ,
154- ) -> Type [_U ]:
155- ...
154+ ) -> Type [_U ]: ...
156155
157156
158157@overload
@@ -165,8 +164,7 @@ def dataclass(
165164 frozen : bool = False ,
166165 base_schema : Optional [Type [marshmallow .Schema ]] = None ,
167166 cls_frame : Optional [types .FrameType ] = None ,
168- ) -> Callable [[Type [_U ]], Type [_U ]]:
169- ...
167+ ) -> Callable [[Type [_U ]], Type [_U ]]: ...
170168
171169
172170# _cls should never be specified by keyword, so start it with an
@@ -225,15 +223,13 @@ def decorator(cls: Type[_U], stacklevel: int = 1) -> Type[_U]:
225223
226224
227225@overload
228- def add_schema (_cls : Type [_U ]) -> Type [_U ]:
229- ...
226+ def add_schema (_cls : Type [_U ]) -> Type [_U ]: ...
230227
231228
232229@overload
233230def add_schema (
234231 base_schema : Optional [Type [marshmallow .Schema ]] = None ,
235- ) -> Callable [[Type [_U ]], Type [_U ]]:
236- ...
232+ ) -> Callable [[Type [_U ]], Type [_U ]]: ...
237233
238234
239235@overload
@@ -242,8 +238,7 @@ def add_schema(
242238 base_schema : Optional [Type [marshmallow .Schema ]] = None ,
243239 cls_frame : Optional [types .FrameType ] = None ,
244240 stacklevel : int = 1 ,
245- ) -> Type [_U ]:
246- ...
241+ ) -> Type [_U ]: ...
247242
248243
249244def add_schema (_cls = None , base_schema = None , cls_frame = None , stacklevel = 1 ):
@@ -294,8 +289,7 @@ def class_schema(
294289 * ,
295290 globalns : Optional [Dict [str , Any ]] = None ,
296291 localns : Optional [Dict [str , Any ]] = None ,
297- ) -> Type [marshmallow .Schema ]:
298- ...
292+ ) -> Type [marshmallow .Schema ]: ...
299293
300294
301295@overload
@@ -305,8 +299,7 @@ def class_schema(
305299 clazz_frame : Optional [types .FrameType ] = None ,
306300 * ,
307301 globalns : Optional [Dict [str , Any ]] = None ,
308- ) -> Type [marshmallow .Schema ]:
309- ...
302+ ) -> Type [marshmallow .Schema ]: ...
310303
311304
312305def class_schema (
@@ -514,7 +507,15 @@ def _internal_class_schema(
514507 base_schema : Optional [Type [marshmallow .Schema ]] = None ,
515508) -> Type [marshmallow .Schema ]:
516509 schema_ctx = _schema_ctx_stack .top
517- schema_ctx .seen_classes [clazz ] = clazz .__name__
510+
511+ if typing_extensions .get_origin (clazz ) is Annotated and sys .version_info < (3 , 10 ):
512+ # https://github.com/python/cpython/blob/3.10/Lib/typing.py#L977
513+ class_name = clazz ._name or clazz .__origin__ .__name__ # type: ignore[attr-defined]
514+ else :
515+ class_name = clazz .__name__
516+
517+ schema_ctx .seen_classes [clazz ] = class_name
518+
518519 try :
519520 # noinspection PyDataclass
520521 fields : Tuple [dataclasses .Field , ...] = dataclasses .fields (clazz )
@@ -549,9 +550,18 @@ def _internal_class_schema(
549550 include_non_init = getattr (getattr (clazz , "Meta" , None ), "include_non_init" , False )
550551
551552 # Update the schema members to contain marshmallow fields instead of dataclass fields
552- type_hints = get_type_hints (
553- clazz , globalns = schema_ctx .globalns , localns = schema_ctx .localns , include_extras = True ,
554- )
553+
554+ if sys .version_info >= (3 , 9 ):
555+ type_hints = get_type_hints (
556+ clazz ,
557+ globalns = schema_ctx .globalns ,
558+ localns = schema_ctx .localns ,
559+ include_extras = True ,
560+ )
561+ else :
562+ type_hints = get_type_hints (
563+ clazz , globalns = schema_ctx .globalns , localns = schema_ctx .localns
564+ )
555565 attributes .update (
556566 (
557567 field .name ,
@@ -642,8 +652,8 @@ def _field_for_generic_type(
642652 """
643653 If the type is a generic interface, resolve the arguments and construct the appropriate Field.
644654 """
645- origin = get_origin (typ )
646- arguments = get_args (typ )
655+ origin = typing_extensions . get_origin (typ )
656+ arguments = typing_extensions . get_args (typ )
647657 if origin :
648658 # Override base_schema.TYPE_MAPPING to change the class used for generic types below
649659 type_mapping = base_schema .TYPE_MAPPING if base_schema else {}
@@ -889,7 +899,7 @@ def _field_for_schema(
889899 )
890900
891901 # enumerations
892- if issubclass (typ , Enum ):
902+ if inspect . isclass ( typ ) and issubclass (typ , Enum ):
893903 return marshmallow .fields .Enum (typ , ** metadata )
894904
895905 # Nested marshmallow dataclass
0 commit comments