@@ -46,6 +46,7 @@ class User:
4646from typing import (
4747 Any ,
4848 Callable ,
49+ ChainMap ,
4950 Dict ,
5051 Generic ,
5152 List ,
@@ -74,6 +75,8 @@ class User:
7475
7576NoneType = type (None )
7677_U = TypeVar ("_U" )
78+ _Field = TypeVar ("_Field" , bound = marshmallow .fields .Field )
79+
7780
7881# Whitelist of dataclass members that will be copied to generated schema.
7982MEMBERS_WHITELIST : Set [str ] = {"Meta" }
@@ -442,6 +445,23 @@ class _SchemaContext:
442445 base_schema : Optional [Type [marshmallow .Schema ]] = None
443446 seen_classes : Dict [type , str ] = dataclasses .field (default_factory = dict )
444447
448+ def get_type_mapping (
449+ self , use_mro : bool = False
450+ ) -> Mapping [Any , Type [marshmallow .fields .Field ]]:
451+ """Get base_schema.TYPE_MAPPING.
452+
453+ If use_mro is true, then merges the TYPE_MAPPINGs from
454+ all bases in base_schema's MRO.
455+ """
456+ base_schema = self .base_schema
457+ if base_schema is None :
458+ base_schema = marshmallow .Schema
459+ if use_mro :
460+ return ChainMap (
461+ * (getattr (cls , "TYPE_MAPPING" , {}) for cls in base_schema .__mro__ )
462+ )
463+ return getattr (base_schema , "TYPE_MAPPING" , {})
464+
445465 def __enter__ (self ) -> "_SchemaContext" :
446466 _schema_ctx_stack .push (self )
447467 return self
@@ -534,10 +554,9 @@ def _internal_class_schema(
534554
535555def _field_by_type (typ : Union [type , Any ]) -> Optional [Type [marshmallow .fields .Field ]]:
536556 # FIXME: remove this function
537- base_schema = _schema_ctx_stack .top .base_schema
538- return (
539- base_schema and base_schema .TYPE_MAPPING .get (typ )
540- ) or marshmallow .Schema .TYPE_MAPPING .get (typ )
557+ schema_ctx = _schema_ctx_stack .top
558+ type_mapping = schema_ctx .get_type_mapping (use_mro = True )
559+ return type_mapping .get (typ )
541560
542561
543562def _field_by_supertype (
@@ -605,15 +624,15 @@ def _field_for_generic_type(
605624 arguments = typing_inspect .get_args (typ , True )
606625 if origin :
607626 # Override base_schema.TYPE_MAPPING to change the class used for generic types below
608- base_schema = _schema_ctx_stack .top .base_schema
609- type_mapping = base_schema .TYPE_MAPPING if base_schema else {}
627+ schema_ctx = _schema_ctx_stack .top
628+
629+ def get_field_type (type_spec : Any , default : Type [_Field ]) -> Type [_Field ]:
630+ type_mapping = schema_ctx .get_type_mapping ()
631+ return type_mapping .get (type_spec , default ) # type: ignore[return-value]
610632
611633 if origin in (list , List ):
612634 child_type = _field_for_schema (arguments [0 ])
613- list_type = cast (
614- Type [marshmallow .fields .List ],
615- type_mapping .get (List , marshmallow .fields .List ),
616- )
635+ list_type = get_field_type (List , default = marshmallow .fields .List )
617636 return list_type (child_type , ** metadata )
618637 if origin in (collections .abc .Sequence , Sequence ) or (
619638 origin in (tuple , Tuple )
@@ -640,15 +659,10 @@ def _field_for_generic_type(
640659 )
641660 if origin in (tuple , Tuple ):
642661 children = tuple (_field_for_schema (arg ) for arg in arguments )
643- tuple_type = cast (
644- Type [marshmallow .fields .Tuple ],
645- type_mapping .get ( # type:ignore[call-overload]
646- Tuple , marshmallow .fields .Tuple
647- ),
648- )
662+ tuple_type = get_field_type (Tuple , default = marshmallow .fields .Tuple )
649663 return tuple_type (children , ** metadata )
650664 elif origin in (dict , Dict , collections .abc .Mapping , Mapping ):
651- dict_type = type_mapping . get (Dict , marshmallow .fields .Dict )
665+ dict_type = get_field_type (Dict , default = marshmallow .fields .Dict )
652666 return dict_type (
653667 keys = _field_for_schema (arguments [0 ]),
654668 values = _field_for_schema (arguments [1 ]),
0 commit comments