@@ -433,6 +433,35 @@ def class_schema(
433433 return _internal_class_schema (clazz , base_schema )
434434
435435
436+ class _TypeMapping :
437+ """Helper for looking up field types in a chained list of TYPE_MAPPINGs"""
438+
439+ def __init__ (self , * mappings : Mapping [Any , Type [marshmallow .fields .Field ]]) -> None :
440+ self .mappings = mappings
441+
442+ _Field = TypeVar ("_Field" , bound = marshmallow .fields .Field )
443+
444+ @overload
445+ def get (self , typ : object , default : Type [_Field ]) -> Type [_Field ]:
446+ ...
447+
448+ @overload
449+ def get (
450+ self , typ : object , default : None = None
451+ ) -> Optional [Type [marshmallow .fields .Field ]]:
452+ ...
453+
454+ def get (
455+ self , typ : object , default : Optional [Type [_Field ]] = None
456+ ) -> Optional [Type [marshmallow .fields .Field ]]:
457+ for mapping in self .mappings :
458+ try :
459+ return mapping [typ ]
460+ except KeyError :
461+ pass
462+ return default
463+
464+
436465@dataclasses .dataclass
437466class _SchemaContext :
438467 """Global context for an invocation of class_schema."""
@@ -442,6 +471,18 @@ class _SchemaContext:
442471 base_schema : Optional [Type [marshmallow .Schema ]] = None
443472 seen_classes : Dict [type , str ] = dataclasses .field (default_factory = dict )
444473
474+ def get_type_mapping (
475+ self , include_marshmallow_default : bool = False
476+ ) -> _TypeMapping :
477+ default_mapping = marshmallow .Schema .TYPE_MAPPING
478+ if self .base_schema is not None :
479+ mappings = [self .base_schema .TYPE_MAPPING ]
480+ if include_marshmallow_default :
481+ mappings .append (default_mapping )
482+ else :
483+ mappings = [default_mapping ]
484+ return _TypeMapping (* mappings )
485+
445486 def __enter__ (self ) -> "_SchemaContext" :
446487 _schema_ctx_stack .push (self )
447488 return self
@@ -534,10 +575,9 @@ def _internal_class_schema(
534575
535576def _field_by_type (typ : Union [type , Any ]) -> Optional [Type [marshmallow .fields .Field ]]:
536577 # FIXME: remove this function
537- base_schema = _schema_ctx_stack .top .base_schema
538- return (
539- base_schema and base_schema .TYPE_MAPPING .get (typ )
540- ) or marshmallow .Schema .TYPE_MAPPING .get (typ )
578+ schema_ctx = _schema_ctx_stack .top
579+ type_mapping = schema_ctx .get_type_mapping (include_marshmallow_default = True )
580+ return type_mapping .get (typ )
541581
542582
543583def _field_by_supertype (
@@ -605,15 +645,12 @@ def _field_for_generic_type(
605645 arguments = typing_inspect .get_args (typ , True )
606646 if origin :
607647 # Override base_schema.TYPE_MAPPING to change the class used for generic types below
608- base_schema = _schema_ctx_stack .top . base_schema
609- type_mapping = base_schema . TYPE_MAPPING if base_schema else {}
648+ schema_ctx = _schema_ctx_stack .top
649+ type_mapping = schema_ctx . get_type_mapping ()
610650
611651 if origin in (list , List ):
612652 child_type = _field_for_schema (arguments [0 ])
613- list_type = cast (
614- Type [marshmallow .fields .List ],
615- type_mapping .get (List , marshmallow .fields .List ),
616- )
653+ list_type = type_mapping .get (List , default = marshmallow .fields .List )
617654 return list_type (child_type , ** metadata )
618655 if origin in (collections .abc .Sequence , Sequence ) or (
619656 origin in (tuple , Tuple )
@@ -640,15 +677,10 @@ def _field_for_generic_type(
640677 )
641678 if origin in (tuple , Tuple ):
642679 children = tuple (_field_for_schema (arg ) for arg in arguments )
643- tuple_type = cast (
644- Type [marshmallow .fields .Tuple ],
645- type_mapping .get ( # type:ignore[call-overload]
646- Tuple , marshmallow .fields .Tuple
647- ),
648- )
680+ tuple_type = type_mapping .get (Tuple , default = marshmallow .fields .Tuple )
649681 return tuple_type (children , ** metadata )
650682 elif origin in (dict , Dict , collections .abc .Mapping , Mapping ):
651- dict_type = type_mapping .get (Dict , marshmallow .fields .Dict )
683+ dict_type = type_mapping .get (Dict , default = marshmallow .fields .Dict )
652684 return dict_type (
653685 keys = _field_for_schema (arguments [0 ]),
654686 values = _field_for_schema (arguments [1 ]),
0 commit comments