Skip to content

Commit 01c71da

Browse files
committed
refactor: add _SchemaContext.get_type_mapping
1 parent d95e171 commit 01c71da

File tree

1 file changed

+31
-17
lines changed

1 file changed

+31
-17
lines changed

marshmallow_dataclass/__init__.py

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ class User:
4646
from typing import (
4747
Any,
4848
Callable,
49+
ChainMap,
4950
Dict,
5051
Generic,
5152
List,
@@ -74,6 +75,8 @@ class User:
7475

7576
NoneType = type(None)
7677
_U = TypeVar("_U")
78+
_Field = TypeVar("_Field", bound=marshmallow.fields.Field)
79+
7780

7881
# Whitelist of dataclass members that will be copied to generated schema.
7982
MEMBERS_WHITELIST: Set[str] = {"Meta"}
@@ -442,6 +445,23 @@ class _SchemaContext:
442445
base_schema: Optional[Type[marshmallow.Schema]] = None
443446
seen_classes: Dict[type, str] = dataclasses.field(default_factory=dict)
444447

448+
def get_type_mapping(
449+
self, use_mro: bool = False
450+
) -> Mapping[Any, Type[marshmallow.fields.Field]]:
451+
"""Get base_schema.TYPE_MAPPING.
452+
453+
If use_mro is true, then merges the TYPE_MAPPINGs from
454+
all bases in base_schema's MRO.
455+
"""
456+
base_schema = self.base_schema
457+
if base_schema is None:
458+
base_schema = marshmallow.Schema
459+
if use_mro:
460+
return ChainMap(
461+
*(getattr(cls, "TYPE_MAPPING", {}) for cls in base_schema.__mro__)
462+
)
463+
return getattr(base_schema, "TYPE_MAPPING", {})
464+
445465
def __enter__(self) -> "_SchemaContext":
446466
_schema_ctx_stack.push(self)
447467
return self
@@ -534,10 +554,9 @@ def _internal_class_schema(
534554

535555
def _field_by_type(typ: Union[type, Any]) -> Optional[Type[marshmallow.fields.Field]]:
536556
# FIXME: remove this function
537-
base_schema = _schema_ctx_stack.top.base_schema
538-
return (
539-
base_schema and base_schema.TYPE_MAPPING.get(typ)
540-
) or marshmallow.Schema.TYPE_MAPPING.get(typ)
557+
schema_ctx = _schema_ctx_stack.top
558+
type_mapping = schema_ctx.get_type_mapping(use_mro=True)
559+
return type_mapping.get(typ)
541560

542561

543562
def _field_by_supertype(
@@ -605,15 +624,15 @@ def _field_for_generic_type(
605624
arguments = typing_inspect.get_args(typ, True)
606625
if origin:
607626
# Override base_schema.TYPE_MAPPING to change the class used for generic types below
608-
base_schema = _schema_ctx_stack.top.base_schema
609-
type_mapping = base_schema.TYPE_MAPPING if base_schema else {}
627+
schema_ctx = _schema_ctx_stack.top
628+
629+
def get_field_type(type_spec: Any, default: Type[_Field]) -> Type[_Field]:
630+
type_mapping = schema_ctx.get_type_mapping()
631+
return type_mapping.get(type_spec, default) # type: ignore[return-value]
610632

611633
if origin in (list, List):
612634
child_type = _field_for_schema(arguments[0])
613-
list_type = cast(
614-
Type[marshmallow.fields.List],
615-
type_mapping.get(List, marshmallow.fields.List),
616-
)
635+
list_type = get_field_type(List, default=marshmallow.fields.List)
617636
return list_type(child_type, **metadata)
618637
if origin in (collections.abc.Sequence, Sequence) or (
619638
origin in (tuple, Tuple)
@@ -640,15 +659,10 @@ def _field_for_generic_type(
640659
)
641660
if origin in (tuple, Tuple):
642661
children = tuple(_field_for_schema(arg) for arg in arguments)
643-
tuple_type = cast(
644-
Type[marshmallow.fields.Tuple],
645-
type_mapping.get( # type:ignore[call-overload]
646-
Tuple, marshmallow.fields.Tuple
647-
),
648-
)
662+
tuple_type = get_field_type(Tuple, default=marshmallow.fields.Tuple)
649663
return tuple_type(children, **metadata)
650664
elif origin in (dict, Dict, collections.abc.Mapping, Mapping):
651-
dict_type = type_mapping.get(Dict, marshmallow.fields.Dict)
665+
dict_type = get_field_type(Dict, default=marshmallow.fields.Dict)
652666
return dict_type(
653667
keys=_field_for_schema(arguments[0]),
654668
values=_field_for_schema(arguments[1]),

0 commit comments

Comments
 (0)