@@ -46,6 +46,7 @@ class User:
4646 Any ,
4747 Callable ,
4848 Dict ,
49+ Generic ,
4950 List ,
5051 Mapping ,
5152 NewType as typing_NewType ,
@@ -79,9 +80,6 @@ class User:
7980# Max number of generated schemas that class_schema keeps of generated schemas. Removes duplicates.
8081MAX_CLASS_SCHEMA_CACHE_SIZE = 1024
8182
82- # Recursion guard for class_schema()
83- _RECURSION_GUARD = threading .local ()
84-
8583
8684@overload
8785def dataclass (
@@ -352,20 +350,61 @@ def class_schema(
352350 clazz_frame = current_frame .f_back
353351 # Per https://docs.python.org/3/library/inspect.html#the-interpreter-stack
354352 del current_frame
355- _RECURSION_GUARD .seen_classes = {}
356- try :
357- return _internal_class_schema (clazz , base_schema , clazz_frame )
358- finally :
359- _RECURSION_GUARD .seen_classes .clear ()
353+
354+ with _SchemaContext (clazz_frame ):
355+ return _internal_class_schema (clazz , base_schema )
356+
357+
358+ class _SchemaContext :
359+ """Global context for an invocation of class_schema."""
360+
361+ def __init__ (self , frame : Optional [types .FrameType ]):
362+ self .seen_classes : Dict [type , str ] = {}
363+ self .frame = frame
364+
365+ def get_type_hints (self , cls : Type ) -> Dict [str , Any ]:
366+ frame = self .frame
367+ localns = frame .f_locals if frame is not None else None
368+ return get_type_hints (cls , localns = localns )
369+
370+ def __enter__ (self ) -> "_SchemaContext" :
371+ _schema_ctx_stack .push (self )
372+ return self
373+
374+ def __exit__ (
375+ self ,
376+ _typ : Optional [Type [BaseException ]],
377+ _value : Optional [BaseException ],
378+ _tb : Optional [types .TracebackType ],
379+ ) -> None :
380+ _schema_ctx_stack .pop ()
381+
382+
383+ class _LocalStack (threading .local , Generic [_U ]):
384+ def __init__ (self ) -> None :
385+ self .stack : List [_U ] = []
386+
387+ def push (self , value : _U ) -> None :
388+ self .stack .append (value )
389+
390+ def pop (self ) -> None :
391+ self .stack .pop ()
392+
393+ @property
394+ def top (self ) -> _U :
395+ return self .stack [- 1 ]
396+
397+
398+ _schema_ctx_stack = _LocalStack [_SchemaContext ]()
360399
361400
362401@lru_cache (maxsize = MAX_CLASS_SCHEMA_CACHE_SIZE )
363402def _internal_class_schema (
364403 clazz : type ,
365404 base_schema : Optional [Type [marshmallow .Schema ]] = None ,
366- clazz_frame : Optional [types .FrameType ] = None ,
367405) -> Type [marshmallow .Schema ]:
368- _RECURSION_GUARD .seen_classes [clazz ] = clazz .__name__
406+ schema_ctx = _schema_ctx_stack .top
407+ schema_ctx .seen_classes [clazz ] = clazz .__name__
369408 try :
370409 # noinspection PyDataclass
371410 fields : Tuple [dataclasses .Field , ...] = dataclasses .fields (clazz )
@@ -383,7 +422,7 @@ def _internal_class_schema(
383422 "****** WARNING ******"
384423 )
385424 created_dataclass : type = dataclasses .dataclass (clazz )
386- return _internal_class_schema (created_dataclass , base_schema , clazz_frame )
425+ return _internal_class_schema (created_dataclass , base_schema )
387426 except Exception as exc :
388427 raise TypeError (
389428 f"{ getattr (clazz , '__name__' , repr (clazz ))} is not a dataclass and cannot be turned into one."
@@ -397,18 +436,15 @@ def _internal_class_schema(
397436 }
398437
399438 # Update the schema members to contain marshmallow fields instead of dataclass fields
400- type_hints = get_type_hints (
401- clazz , localns = clazz_frame .f_locals if clazz_frame else None
402- )
439+ type_hints = schema_ctx .get_type_hints (clazz )
403440 attributes .update (
404441 (
405442 field .name ,
406- field_for_schema (
443+ _field_for_schema (
407444 type_hints [field .name ],
408445 _get_field_default (field ),
409446 field .metadata ,
410447 base_schema ,
411- clazz_frame ,
412448 ),
413449 )
414450 for field in fields
@@ -433,7 +469,6 @@ def _field_by_supertype(
433469 newtype_supertype : Type ,
434470 metadata : dict ,
435471 base_schema : Optional [Type [marshmallow .Schema ]],
436- typ_frame : Optional [types .FrameType ],
437472) -> marshmallow .fields .Field :
438473 """
439474 Return a new field for fields based on a super field. (Usually spawned from NewType)
@@ -459,12 +494,11 @@ def _field_by_supertype(
459494 if field :
460495 return field (** metadata )
461496 else :
462- return field_for_schema (
497+ return _field_for_schema (
463498 newtype_supertype ,
464499 metadata = metadata ,
465500 default = default ,
466501 base_schema = base_schema ,
467- typ_frame = typ_frame ,
468502 )
469503
470504
@@ -488,7 +522,6 @@ def _generic_type_add_any(typ: type) -> type:
488522def _field_for_generic_type (
489523 typ : type ,
490524 base_schema : Optional [Type [marshmallow .Schema ]],
491- typ_frame : Optional [types .FrameType ],
492525 ** metadata : Any ,
493526) -> Optional [marshmallow .fields .Field ]:
494527 """
@@ -501,9 +534,7 @@ def _field_for_generic_type(
501534 type_mapping = base_schema .TYPE_MAPPING if base_schema else {}
502535
503536 if origin in (list , List ):
504- child_type = field_for_schema (
505- arguments [0 ], base_schema = base_schema , typ_frame = typ_frame
506- )
537+ child_type = _field_for_schema (arguments [0 ], base_schema = base_schema )
507538 list_type = cast (
508539 Type [marshmallow .fields .List ],
509540 type_mapping .get (List , marshmallow .fields .List ),
@@ -516,32 +547,25 @@ def _field_for_generic_type(
516547 ):
517548 from . import collection_field
518549
519- child_type = field_for_schema (
520- arguments [0 ], base_schema = base_schema , typ_frame = typ_frame
521- )
550+ child_type = _field_for_schema (arguments [0 ], base_schema = base_schema )
522551 return collection_field .Sequence (cls_or_instance = child_type , ** metadata )
523552 if origin in (set , Set ):
524553 from . import collection_field
525554
526- child_type = field_for_schema (
527- arguments [0 ], base_schema = base_schema , typ_frame = typ_frame
528- )
555+ child_type = _field_for_schema (arguments [0 ], base_schema = base_schema )
529556 return collection_field .Set (
530557 cls_or_instance = child_type , frozen = False , ** metadata
531558 )
532559 if origin in (frozenset , FrozenSet ):
533560 from . import collection_field
534561
535- child_type = field_for_schema (
536- arguments [0 ], base_schema = base_schema , typ_frame = typ_frame
537- )
562+ child_type = _field_for_schema (arguments [0 ], base_schema = base_schema )
538563 return collection_field .Set (
539564 cls_or_instance = child_type , frozen = True , ** metadata
540565 )
541566 if origin in (tuple , Tuple ):
542567 children = tuple (
543- field_for_schema (arg , base_schema = base_schema , typ_frame = typ_frame )
544- for arg in arguments
568+ _field_for_schema (arg , base_schema = base_schema ) for arg in arguments
545569 )
546570 tuple_type = cast (
547571 Type [marshmallow .fields .Tuple ],
@@ -553,14 +577,11 @@ def _field_for_generic_type(
553577 elif origin in (dict , Dict , collections .abc .Mapping , Mapping ):
554578 dict_type = type_mapping .get (Dict , marshmallow .fields .Dict )
555579 return dict_type (
556- keys = field_for_schema (
557- arguments [0 ], base_schema = base_schema , typ_frame = typ_frame
558- ),
559- values = field_for_schema (
560- arguments [1 ], base_schema = base_schema , typ_frame = typ_frame
561- ),
580+ keys = _field_for_schema (arguments [0 ], base_schema = base_schema ),
581+ values = _field_for_schema (arguments [1 ], base_schema = base_schema ),
562582 ** metadata ,
563583 )
584+
564585 if typing_inspect .is_union_type (typ ):
565586 if typing_inspect .is_optional_type (typ ):
566587 metadata ["allow_none" ] = metadata .get ("allow_none" , True )
@@ -570,23 +591,21 @@ def _field_for_generic_type(
570591 metadata .setdefault ("required" , False )
571592 subtypes = [t for t in arguments if t is not NoneType ] # type: ignore
572593 if len (subtypes ) == 1 :
573- return field_for_schema (
594+ return _field_for_schema (
574595 subtypes [0 ],
575596 metadata = metadata ,
576597 base_schema = base_schema ,
577- typ_frame = typ_frame ,
578598 )
579599 from . import union_field
580600
581601 return union_field .Union (
582602 [
583603 (
584604 subtyp ,
585- field_for_schema (
605+ _field_for_schema (
586606 subtyp ,
587607 metadata = {"required" : True },
588608 base_schema = base_schema ,
589- typ_frame = typ_frame ,
590609 ),
591610 )
592611 for subtyp in subtypes
@@ -598,7 +617,7 @@ def _field_for_generic_type(
598617
599618def field_for_schema (
600619 typ : type ,
601- default = marshmallow .missing ,
620+ default : Any = marshmallow .missing ,
602621 metadata : Optional [Mapping [str , Any ]] = None ,
603622 base_schema : Optional [Type [marshmallow .Schema ]] = None ,
604623 typ_frame : Optional [types .FrameType ] = None ,
@@ -622,6 +641,29 @@ def field_for_schema(
622641
623642 >>> field_for_schema(str, metadata={"marshmallow_field": marshmallow.fields.Url()}).__class__
624643 <class 'marshmallow.fields.Url'>
644+ """
645+ with _SchemaContext (typ_frame ):
646+ return _field_for_schema (typ , default , metadata , base_schema )
647+
648+
649+ def _field_for_schema (
650+ typ : type ,
651+ default : Any = marshmallow .missing ,
652+ metadata : Optional [Mapping [str , Any ]] = None ,
653+ base_schema : Optional [Type [marshmallow .Schema ]] = None ,
654+ ) -> marshmallow .fields .Field :
655+ """
656+ Get a marshmallow Field corresponding to the given python type.
657+ The metadata of the dataclass field is used as arguments to the marshmallow Field.
658+
659+ This is an internal version of field_for_schema. It assumes a _SchemaContext
660+ has been pushed onto the local stack.
661+
662+ :param typ: The type for which a field should be generated
663+ :param default: value to use for (de)serialization when the field is missing
664+ :param metadata: Additional parameters to pass to the marshmallow field constructor
665+ :param base_schema: marshmallow schema used as a base class when deriving dataclass schema
666+
625667 """
626668
627669 metadata = {} if metadata is None else dict (metadata )
@@ -694,10 +736,10 @@ def field_for_schema(
694736 )
695737 else :
696738 subtyp = Any
697- return field_for_schema (subtyp , default , metadata , base_schema , typ_frame )
739+ return _field_for_schema (subtyp , default , metadata , base_schema )
698740
699741 # Generic types
700- generic_field = _field_for_generic_type (typ , base_schema , typ_frame , ** metadata )
742+ generic_field = _field_for_generic_type (typ , base_schema , ** metadata )
701743 if generic_field :
702744 return generic_field
703745
@@ -711,7 +753,6 @@ def field_for_schema(
711753 newtype_supertype = newtype_supertype ,
712754 metadata = metadata ,
713755 base_schema = base_schema ,
714- typ_frame = typ_frame ,
715756 )
716757
717758 # enumerations
@@ -734,8 +775,8 @@ def field_for_schema(
734775 nested = (
735776 nested_schema
736777 or forward_reference
737- or _RECURSION_GUARD .seen_classes .get (typ )
738- or _internal_class_schema (typ , base_schema , typ_frame ) # type: ignore [arg-type]
778+ or _schema_ctx_stack . top .seen_classes .get (typ )
779+ or _internal_class_schema (typ , base_schema ) # type: ignore[arg-type] # FIXME
739780 )
740781
741782 return marshmallow .fields .Nested (nested , ** metadata )
0 commit comments