@@ -46,6 +46,7 @@ class User:
4646 Any ,
4747 Callable ,
4848 Dict ,
49+ Generic ,
4950 List ,
5051 Mapping ,
5152 NewType as typing_NewType ,
@@ -79,9 +80,6 @@ class User:
7980# Max number of generated schemas that class_schema keeps of generated schemas. Removes duplicates.
8081MAX_CLASS_SCHEMA_CACHE_SIZE = 1024
8182
82- # Recursion guard for class_schema()
83- _RECURSION_GUARD = threading .local ()
84-
8583
8684@overload
8785def dataclass (
@@ -352,20 +350,56 @@ def class_schema(
352350 clazz_frame = current_frame .f_back
353351 # Per https://docs.python.org/3/library/inspect.html#the-interpreter-stack
354352 del current_frame
355- _RECURSION_GUARD .seen_classes = {}
356- try :
357- return _internal_class_schema (clazz , base_schema , clazz_frame )
358- finally :
359- _RECURSION_GUARD .seen_classes .clear ()
353+
354+ with _SchemaContext (clazz_frame ):
355+ return _internal_class_schema (clazz , base_schema )
356+
357+
358+ class _SchemaContext :
359+ """Global context for an invocation of class_schema."""
360+
361+ def __init__ (self , frame : Optional [types .FrameType ]):
362+ self .seen_classes : Dict [type , str ] = {}
363+ self .frame = frame
364+
365+ def get_type_hints (self , cls : Type ) -> Dict [str , Any ]:
366+ frame = self .frame
367+ localns = frame .f_locals if frame is not None else None
368+ return get_type_hints (cls , localns = localns )
369+
370+ def __enter__ (self ) -> "_SchemaContext" :
371+ _schema_ctx_stack .push (self )
372+ return self
373+
374+ def __exit__ (self , _typ , _value , _tb ) -> None :
375+ _schema_ctx_stack .pop ()
376+
377+
378+ class _LocalStack (threading .local , Generic [_U ]):
379+ def __init__ (self ):
380+ self .stack : List [_U ] = []
381+
382+ def push (self , value : _U ) -> None :
383+ self .stack .append (value )
384+
385+ def pop (self ) -> None :
386+ self .stack .pop ()
387+
388+ @property
389+ def top (self ) -> _U :
390+ return self .stack [- 1 ]
391+
392+
393+ _schema_ctx_stack = _LocalStack [_SchemaContext ]()
360394
361395
362396@lru_cache (maxsize = MAX_CLASS_SCHEMA_CACHE_SIZE )
363397def _internal_class_schema (
364398 clazz : type ,
365399 base_schema : Optional [Type [marshmallow .Schema ]] = None ,
366- clazz_frame : types .FrameType = None ,
367400) -> Type [marshmallow .Schema ]:
368- _RECURSION_GUARD .seen_classes [clazz ] = clazz .__name__
401+ schema_ctx = _schema_ctx_stack .top
402+ schema_ctx .seen_classes [clazz ] = clazz .__name__
369403 try :
370404 # noinspection PyDataclass
371405 fields : Tuple [dataclasses .Field , ...] = dataclasses .fields (clazz )
@@ -383,7 +417,7 @@ def _internal_class_schema(
383417 "****** WARNING ******"
384418 )
385419 created_dataclass : type = dataclasses .dataclass (clazz )
386- return _internal_class_schema (created_dataclass , base_schema , clazz_frame )
420+ return _internal_class_schema (created_dataclass , base_schema )
387421 except Exception :
388422 raise TypeError (
389423 f"{ getattr (clazz , '__name__' , repr (clazz ))} is not a dataclass and cannot be turned into one."
@@ -397,18 +431,15 @@ def _internal_class_schema(
397431 }
398432
399433 # Update the schema members to contain marshmallow fields instead of dataclass fields
400- type_hints = get_type_hints (
401- clazz , localns = clazz_frame .f_locals if clazz_frame else None
402- )
434+ type_hints = schema_ctx .get_type_hints (clazz )
403435 attributes .update (
404436 (
405437 field .name ,
406- field_for_schema (
438+ _field_for_schema (
407439 type_hints [field .name ],
408440 _get_field_default (field ),
409441 field .metadata ,
410442 base_schema ,
411- clazz_frame ,
412443 ),
413444 )
414445 for field in fields
@@ -433,7 +464,6 @@ def _field_by_supertype(
433464 newtype_supertype : Type ,
434465 metadata : dict ,
435466 base_schema : Optional [Type [marshmallow .Schema ]],
436- typ_frame : Optional [types .FrameType ],
437467) -> marshmallow .fields .Field :
438468 """
439469 Return a new field for fields based on a super field. (Usually spawned from NewType)
@@ -459,12 +489,11 @@ def _field_by_supertype(
459489 if field :
460490 return field (** metadata )
461491 else :
462- return field_for_schema (
492+ return _field_for_schema (
463493 newtype_supertype ,
464494 metadata = metadata ,
465495 default = default ,
466496 base_schema = base_schema ,
467- typ_frame = typ_frame ,
468497 )
469498
470499
@@ -488,7 +517,6 @@ def _generic_type_add_any(typ: type) -> type:
488517def _field_for_generic_type (
489518 typ : type ,
490519 base_schema : Optional [Type [marshmallow .Schema ]],
491- typ_frame : Optional [types .FrameType ],
492520 ** metadata : Any ,
493521) -> Optional [marshmallow .fields .Field ]:
494522 """
@@ -501,9 +529,7 @@ def _field_for_generic_type(
501529 type_mapping = base_schema .TYPE_MAPPING if base_schema else {}
502530
503531 if origin in (list , List ):
504- child_type = field_for_schema (
505- arguments [0 ], base_schema = base_schema , typ_frame = typ_frame
506- )
532+ child_type = _field_for_schema (arguments [0 ], base_schema = base_schema )
507533 list_type = cast (
508534 Type [marshmallow .fields .List ],
509535 type_mapping .get (List , marshmallow .fields .List ),
@@ -512,32 +538,25 @@ def _field_for_generic_type(
512538 if origin in (collections .abc .Sequence , Sequence ):
513539 from . import collection_field
514540
515- child_type = field_for_schema (
516- arguments [0 ], base_schema = base_schema , typ_frame = typ_frame
517- )
541+ child_type = _field_for_schema (arguments [0 ], base_schema = base_schema )
518542 return collection_field .Sequence (cls_or_instance = child_type , ** metadata )
519543 if origin in (set , Set ):
520544 from . import collection_field
521545
522- child_type = field_for_schema (
523- arguments [0 ], base_schema = base_schema , typ_frame = typ_frame
524- )
546+ child_type = _field_for_schema (arguments [0 ], base_schema = base_schema )
525547 return collection_field .Set (
526548 cls_or_instance = child_type , frozen = False , ** metadata
527549 )
528550 if origin in (frozenset , FrozenSet ):
529551 from . import collection_field
530552
531- child_type = field_for_schema (
532- arguments [0 ], base_schema = base_schema , typ_frame = typ_frame
533- )
553+ child_type = _field_for_schema (arguments [0 ], base_schema = base_schema )
534554 return collection_field .Set (
535555 cls_or_instance = child_type , frozen = True , ** metadata
536556 )
537557 if origin in (tuple , Tuple ):
538558 children = tuple (
539- field_for_schema (arg , base_schema = base_schema , typ_frame = typ_frame )
540- for arg in arguments
559+ _field_for_schema (arg , base_schema = base_schema ) for arg in arguments
541560 )
542561 tuple_type = cast (
543562 Type [marshmallow .fields .Tuple ],
@@ -549,12 +568,8 @@ def _field_for_generic_type(
549568 elif origin in (dict , Dict , collections .abc .Mapping , Mapping ):
550569 dict_type = type_mapping .get (Dict , marshmallow .fields .Dict )
551570 return dict_type (
552- keys = field_for_schema (
553- arguments [0 ], base_schema = base_schema , typ_frame = typ_frame
554- ),
555- values = field_for_schema (
556- arguments [1 ], base_schema = base_schema , typ_frame = typ_frame
557- ),
571+ keys = _field_for_schema (arguments [0 ], base_schema = base_schema ),
572+ values = _field_for_schema (arguments [1 ], base_schema = base_schema ),
558573 ** metadata ,
559574 )
560575 elif typing_inspect .is_union_type (typ ):
@@ -566,23 +581,21 @@ def _field_for_generic_type(
566581 metadata .setdefault ("required" , False )
567582 subtypes = [t for t in arguments if t is not NoneType ] # type: ignore
568583 if len (subtypes ) == 1 :
569- return field_for_schema (
584+ return _field_for_schema (
570585 subtypes [0 ],
571586 metadata = metadata ,
572587 base_schema = base_schema ,
573- typ_frame = typ_frame ,
574588 )
575589 from . import union_field
576590
577591 return union_field .Union (
578592 [
579593 (
580594 subtyp ,
581- field_for_schema (
595+ _field_for_schema (
582596 subtyp ,
583597 metadata = {"required" : True },
584598 base_schema = base_schema ,
585- typ_frame = typ_frame ,
586599 ),
587600 )
588601 for subtyp in subtypes
@@ -618,6 +631,29 @@ def field_for_schema(
618631
619632 >>> field_for_schema(str, metadata={"marshmallow_field": marshmallow.fields.Url()}).__class__
620633 <class 'marshmallow.fields.Url'>
634+ """
635+ with _SchemaContext (typ_frame ):
636+ return _field_for_schema (typ , default , metadata , base_schema )
637+
638+
639+ def _field_for_schema (
640+ typ : type ,
641+ default = marshmallow .missing ,
642+ metadata : Mapping [str , Any ] = None ,
643+ base_schema : Optional [Type [marshmallow .Schema ]] = None ,
644+ ) -> marshmallow .fields .Field :
645+ """
646+ Get a marshmallow Field corresponding to the given python type.
647+ The metadata of the dataclass field is used as arguments to the marshmallow Field.
648+
649+ This is an internal version of field_for_schema. It assumes a _SchemaContext
650+ has been pushed onto the local stack.
651+
652+ :param typ: The type for which a field should be generated
653+ :param default: value to use for (de)serialization when the field is missing
654+ :param metadata: Additional parameters to pass to the marshmallow field constructor
655+ :param base_schema: marshmallow schema used as a base class when deriving dataclass schema
656+
621657 """
622658
623659 metadata = {} if metadata is None else dict (metadata )
@@ -690,10 +726,10 @@ def field_for_schema(
690726 )
691727 else :
692728 subtyp = Any
693- return field_for_schema (subtyp , default , metadata , base_schema , typ_frame )
729+ return _field_for_schema (subtyp , default , metadata , base_schema )
694730
695731 # Generic types
696- generic_field = _field_for_generic_type (typ , base_schema , typ_frame , ** metadata )
732+ generic_field = _field_for_generic_type (typ , base_schema , ** metadata )
697733 if generic_field :
698734 return generic_field
699735
@@ -707,7 +743,6 @@ def field_for_schema(
707743 newtype_supertype = newtype_supertype ,
708744 metadata = metadata ,
709745 base_schema = base_schema ,
710- typ_frame = typ_frame ,
711746 )
712747
713748 # enumerations
@@ -726,8 +761,8 @@ def field_for_schema(
726761 nested = (
727762 nested_schema
728763 or forward_reference
729- or _RECURSION_GUARD .seen_classes .get (typ )
730- or _internal_class_schema (typ , base_schema , typ_frame )
764+ or _schema_ctx_stack . top .seen_classes .get (typ )
765+ or _internal_class_schema (typ , base_schema )
731766 )
732767
733768 return marshmallow .fields .Nested (nested , ** metadata )
0 commit comments