@@ -439,6 +439,7 @@ class _SchemaContext:
439439
440440 globalns : Optional [Dict [str , Any ]] = None
441441 localns : Optional [Dict [str , Any ]] = None
442+ base_schema : Optional [Type [marshmallow .Schema ]] = None
442443 seen_classes : Dict [type , str ] = dataclasses .field (default_factory = dict )
443444
444445 def __enter__ (self ) -> "_SchemaContext" :
@@ -513,27 +514,27 @@ def _internal_class_schema(
513514 type_hints = get_type_hints (
514515 clazz , globalns = schema_ctx .globalns , localns = schema_ctx .localns
515516 )
516- attributes .update (
517- (
518- field .name ,
519- _field_for_schema (
520- type_hints [field .name ],
521- _get_field_default (field ),
522- field .metadata ,
523- base_schema ,
524- ),
517+ with dataclasses .replace (schema_ctx , base_schema = base_schema ):
518+ attributes .update (
519+ (
520+ field .name ,
521+ _field_for_schema (
522+ type_hints [field .name ],
523+ _get_field_default (field ),
524+ field .metadata ,
525+ ),
526+ )
527+ for field in fields
528+ if field .init
525529 )
526- for field in fields
527- if field .init
528- )
529530
530531 schema_class = type (clazz .__name__ , (_base_schema (clazz , base_schema ),), attributes )
531532 return cast (Type [marshmallow .Schema ], schema_class )
532533
533534
534- def _field_by_type (
535- typ : Union [ type , Any ], base_schema : Optional [ Type [ marshmallow . Schema ]]
536- ) -> Optional [ Type [ marshmallow . fields . Field ]]:
535+ def _field_by_type (typ : Union [ type , Any ]) -> Optional [ Type [ marshmallow . fields . Field ]]:
536+ # FIXME: remove this function
537+ base_schema = _schema_ctx_stack . top . base_schema
537538 return (
538539 base_schema and base_schema .TYPE_MAPPING .get (typ )
539540 ) or marshmallow .Schema .TYPE_MAPPING .get (typ )
@@ -544,7 +545,6 @@ def _field_by_supertype(
544545 default : Any ,
545546 newtype_supertype : Type ,
546547 metadata : dict ,
547- base_schema : Optional [Type [marshmallow .Schema ]],
548548) -> marshmallow .fields .Field :
549549 """
550550 Return a new field for fields based on a super field. (Usually spawned from NewType)
@@ -574,7 +574,6 @@ def _field_by_supertype(
574574 newtype_supertype ,
575575 metadata = metadata ,
576576 default = default ,
577- base_schema = base_schema ,
578577 )
579578
580579
@@ -597,7 +596,6 @@ def _generic_type_add_any(typ: type) -> type:
597596
598597def _field_for_generic_type (
599598 typ : type ,
600- base_schema : Optional [Type [marshmallow .Schema ]],
601599 ** metadata : Any ,
602600) -> Optional [marshmallow .fields .Field ]:
603601 """
@@ -607,10 +605,11 @@ def _field_for_generic_type(
607605 arguments = typing_inspect .get_args (typ , True )
608606 if origin :
609607 # Override base_schema.TYPE_MAPPING to change the class used for generic types below
608+ base_schema = _schema_ctx_stack .top .base_schema
610609 type_mapping = base_schema .TYPE_MAPPING if base_schema else {}
611610
612611 if origin in (list , List ):
613- child_type = _field_for_schema (arguments [0 ], base_schema = base_schema )
612+ child_type = _field_for_schema (arguments [0 ])
614613 list_type = cast (
615614 Type [marshmallow .fields .List ],
616615 type_mapping .get (List , marshmallow .fields .List ),
@@ -623,26 +622,24 @@ def _field_for_generic_type(
623622 ):
624623 from . import collection_field
625624
626- child_type = _field_for_schema (arguments [0 ], base_schema = base_schema )
625+ child_type = _field_for_schema (arguments [0 ])
627626 return collection_field .Sequence (cls_or_instance = child_type , ** metadata )
628627 if origin in (set , Set ):
629628 from . import collection_field
630629
631- child_type = _field_for_schema (arguments [0 ], base_schema = base_schema )
630+ child_type = _field_for_schema (arguments [0 ])
632631 return collection_field .Set (
633632 cls_or_instance = child_type , frozen = False , ** metadata
634633 )
635634 if origin in (frozenset , FrozenSet ):
636635 from . import collection_field
637636
638- child_type = _field_for_schema (arguments [0 ], base_schema = base_schema )
637+ child_type = _field_for_schema (arguments [0 ])
639638 return collection_field .Set (
640639 cls_or_instance = child_type , frozen = True , ** metadata
641640 )
642641 if origin in (tuple , Tuple ):
643- children = tuple (
644- _field_for_schema (arg , base_schema = base_schema ) for arg in arguments
645- )
642+ children = tuple (_field_for_schema (arg ) for arg in arguments )
646643 tuple_type = cast (
647644 Type [marshmallow .fields .Tuple ],
648645 type_mapping .get ( # type:ignore[call-overload]
@@ -653,8 +650,8 @@ def _field_for_generic_type(
653650 elif origin in (dict , Dict , collections .abc .Mapping , Mapping ):
654651 dict_type = type_mapping .get (Dict , marshmallow .fields .Dict )
655652 return dict_type (
656- keys = _field_for_schema (arguments [0 ], base_schema = base_schema ),
657- values = _field_for_schema (arguments [1 ], base_schema = base_schema ),
653+ keys = _field_for_schema (arguments [0 ]),
654+ values = _field_for_schema (arguments [1 ]),
658655 ** metadata ,
659656 )
660657
@@ -670,7 +667,6 @@ def _field_for_generic_type(
670667 return _field_for_schema (
671668 subtypes [0 ],
672669 metadata = metadata ,
673- base_schema = base_schema ,
674670 )
675671 from . import union_field
676672
@@ -681,7 +677,6 @@ def _field_for_generic_type(
681677 _field_for_schema (
682678 subtyp ,
683679 metadata = {"required" : True },
684- base_schema = base_schema ,
685680 ),
686681 )
687682 for subtyp in subtypes
@@ -719,15 +714,15 @@ def field_for_schema(
719714 >>> field_for_schema(str, metadata={"marshmallow_field": marshmallow.fields.Url()}).__class__
720715 <class 'marshmallow.fields.Url'>
721716 """
722- with _SchemaContext (localns = typ_frame .f_locals if typ_frame is not None else None ):
723- return _field_for_schema (typ , default , metadata , base_schema )
717+ localns = typ_frame .f_locals if typ_frame is not None else None
718+ with _SchemaContext (localns = localns , base_schema = base_schema ):
719+ return _field_for_schema (typ , default , metadata )
724720
725721
726722def _field_for_schema (
727723 typ : type ,
728724 default : Any = marshmallow .missing ,
729725 metadata : Optional [Mapping [str , Any ]] = None ,
730- base_schema : Optional [Type [marshmallow .Schema ]] = None ,
731726) -> marshmallow .fields .Field :
732727 """
733728 Get a marshmallow Field corresponding to the given python type.
@@ -739,7 +734,6 @@ def _field_for_schema(
739734 :param typ: The type for which a field should be generated
740735 :param default: value to use for (de)serialization when the field is missing
741736 :param metadata: Additional parameters to pass to the marshmallow field constructor
742- :param base_schema: marshmallow schema used as a base class when deriving dataclass schema
743737
744738 """
745739
@@ -762,7 +756,7 @@ def _field_for_schema(
762756 typ = _generic_type_add_any (typ )
763757
764758 # Base types
765- field = _field_by_type (typ , base_schema )
759+ field = _field_by_type (typ )
766760 if field :
767761 return field (** metadata )
768762
@@ -813,10 +807,10 @@ def _field_for_schema(
813807 )
814808 else :
815809 subtyp = Any
816- return _field_for_schema (subtyp , default , metadata , base_schema )
810+ return _field_for_schema (subtyp , default , metadata )
817811
818812 # Generic types
819- generic_field = _field_for_generic_type (typ , base_schema , ** metadata )
813+ generic_field = _field_for_generic_type (typ , ** metadata )
820814 if generic_field :
821815 return generic_field
822816
@@ -829,7 +823,6 @@ def _field_for_schema(
829823 default = default ,
830824 newtype_supertype = newtype_supertype ,
831825 metadata = metadata ,
832- base_schema = base_schema ,
833826 )
834827
835828 # enumerations
@@ -849,6 +842,7 @@ def _field_for_schema(
849842 # Nested dataclasses
850843 forward_reference = getattr (typ , "__forward_arg__" , None )
851844
845+ base_schema = _schema_ctx_stack .top .base_schema
852846 nested = (
853847 nested_schema
854848 or forward_reference
0 commit comments