Skip to content

Commit e585ce4

Browse files
committed
refactor: move base_schema into _SchemaContext
1 parent cd6caaa commit e585ce4

File tree

1 file changed

+31
-37
lines changed

1 file changed

+31
-37
lines changed

marshmallow_dataclass/__init__.py

Lines changed: 31 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -439,6 +439,7 @@ class _SchemaContext:
439439

440440
globalns: Optional[Dict[str, Any]] = None
441441
localns: Optional[Dict[str, Any]] = None
442+
base_schema: Optional[Type[marshmallow.Schema]] = None
442443
seen_classes: Dict[type, str] = dataclasses.field(default_factory=dict)
443444

444445
def __enter__(self) -> "_SchemaContext":
@@ -513,27 +514,27 @@ def _internal_class_schema(
513514
type_hints = get_type_hints(
514515
clazz, globalns=schema_ctx.globalns, localns=schema_ctx.localns
515516
)
516-
attributes.update(
517-
(
518-
field.name,
519-
_field_for_schema(
520-
type_hints[field.name],
521-
_get_field_default(field),
522-
field.metadata,
523-
base_schema,
524-
),
517+
with dataclasses.replace(schema_ctx, base_schema=base_schema):
518+
attributes.update(
519+
(
520+
field.name,
521+
_field_for_schema(
522+
type_hints[field.name],
523+
_get_field_default(field),
524+
field.metadata,
525+
),
526+
)
527+
for field in fields
528+
if field.init
525529
)
526-
for field in fields
527-
if field.init
528-
)
529530

530531
schema_class = type(clazz.__name__, (_base_schema(clazz, base_schema),), attributes)
531532
return cast(Type[marshmallow.Schema], schema_class)
532533

533534

534-
def _field_by_type(
535-
typ: Union[type, Any], base_schema: Optional[Type[marshmallow.Schema]]
536-
) -> Optional[Type[marshmallow.fields.Field]]:
535+
def _field_by_type(typ: Union[type, Any]) -> Optional[Type[marshmallow.fields.Field]]:
536+
# FIXME: remove this function
537+
base_schema = _schema_ctx_stack.top.base_schema
537538
return (
538539
base_schema and base_schema.TYPE_MAPPING.get(typ)
539540
) or marshmallow.Schema.TYPE_MAPPING.get(typ)
@@ -544,7 +545,6 @@ def _field_by_supertype(
544545
default: Any,
545546
newtype_supertype: Type,
546547
metadata: dict,
547-
base_schema: Optional[Type[marshmallow.Schema]],
548548
) -> marshmallow.fields.Field:
549549
"""
550550
Return a new field for fields based on a super field. (Usually spawned from NewType)
@@ -574,7 +574,6 @@ def _field_by_supertype(
574574
newtype_supertype,
575575
metadata=metadata,
576576
default=default,
577-
base_schema=base_schema,
578577
)
579578

580579

@@ -597,7 +596,6 @@ def _generic_type_add_any(typ: type) -> type:
597596

598597
def _field_for_generic_type(
599598
typ: type,
600-
base_schema: Optional[Type[marshmallow.Schema]],
601599
**metadata: Any,
602600
) -> Optional[marshmallow.fields.Field]:
603601
"""
@@ -607,10 +605,11 @@ def _field_for_generic_type(
607605
arguments = typing_inspect.get_args(typ, True)
608606
if origin:
609607
# Override base_schema.TYPE_MAPPING to change the class used for generic types below
608+
base_schema = _schema_ctx_stack.top.base_schema
610609
type_mapping = base_schema.TYPE_MAPPING if base_schema else {}
611610

612611
if origin in (list, List):
613-
child_type = _field_for_schema(arguments[0], base_schema=base_schema)
612+
child_type = _field_for_schema(arguments[0])
614613
list_type = cast(
615614
Type[marshmallow.fields.List],
616615
type_mapping.get(List, marshmallow.fields.List),
@@ -623,26 +622,24 @@ def _field_for_generic_type(
623622
):
624623
from . import collection_field
625624

626-
child_type = _field_for_schema(arguments[0], base_schema=base_schema)
625+
child_type = _field_for_schema(arguments[0])
627626
return collection_field.Sequence(cls_or_instance=child_type, **metadata)
628627
if origin in (set, Set):
629628
from . import collection_field
630629

631-
child_type = _field_for_schema(arguments[0], base_schema=base_schema)
630+
child_type = _field_for_schema(arguments[0])
632631
return collection_field.Set(
633632
cls_or_instance=child_type, frozen=False, **metadata
634633
)
635634
if origin in (frozenset, FrozenSet):
636635
from . import collection_field
637636

638-
child_type = _field_for_schema(arguments[0], base_schema=base_schema)
637+
child_type = _field_for_schema(arguments[0])
639638
return collection_field.Set(
640639
cls_or_instance=child_type, frozen=True, **metadata
641640
)
642641
if origin in (tuple, Tuple):
643-
children = tuple(
644-
_field_for_schema(arg, base_schema=base_schema) for arg in arguments
645-
)
642+
children = tuple(_field_for_schema(arg) for arg in arguments)
646643
tuple_type = cast(
647644
Type[marshmallow.fields.Tuple],
648645
type_mapping.get( # type:ignore[call-overload]
@@ -653,8 +650,8 @@ def _field_for_generic_type(
653650
elif origin in (dict, Dict, collections.abc.Mapping, Mapping):
654651
dict_type = type_mapping.get(Dict, marshmallow.fields.Dict)
655652
return dict_type(
656-
keys=_field_for_schema(arguments[0], base_schema=base_schema),
657-
values=_field_for_schema(arguments[1], base_schema=base_schema),
653+
keys=_field_for_schema(arguments[0]),
654+
values=_field_for_schema(arguments[1]),
658655
**metadata,
659656
)
660657

@@ -670,7 +667,6 @@ def _field_for_generic_type(
670667
return _field_for_schema(
671668
subtypes[0],
672669
metadata=metadata,
673-
base_schema=base_schema,
674670
)
675671
from . import union_field
676672

@@ -681,7 +677,6 @@ def _field_for_generic_type(
681677
_field_for_schema(
682678
subtyp,
683679
metadata={"required": True},
684-
base_schema=base_schema,
685680
),
686681
)
687682
for subtyp in subtypes
@@ -719,15 +714,15 @@ def field_for_schema(
719714
>>> field_for_schema(str, metadata={"marshmallow_field": marshmallow.fields.Url()}).__class__
720715
<class 'marshmallow.fields.Url'>
721716
"""
722-
with _SchemaContext(localns=typ_frame.f_locals if typ_frame is not None else None):
723-
return _field_for_schema(typ, default, metadata, base_schema)
717+
localns = typ_frame.f_locals if typ_frame is not None else None
718+
with _SchemaContext(localns=localns, base_schema=base_schema):
719+
return _field_for_schema(typ, default, metadata)
724720

725721

726722
def _field_for_schema(
727723
typ: type,
728724
default: Any = marshmallow.missing,
729725
metadata: Optional[Mapping[str, Any]] = None,
730-
base_schema: Optional[Type[marshmallow.Schema]] = None,
731726
) -> marshmallow.fields.Field:
732727
"""
733728
Get a marshmallow Field corresponding to the given python type.
@@ -739,7 +734,6 @@ def _field_for_schema(
739734
:param typ: The type for which a field should be generated
740735
:param default: value to use for (de)serialization when the field is missing
741736
:param metadata: Additional parameters to pass to the marshmallow field constructor
742-
:param base_schema: marshmallow schema used as a base class when deriving dataclass schema
743737
744738
"""
745739

@@ -762,7 +756,7 @@ def _field_for_schema(
762756
typ = _generic_type_add_any(typ)
763757

764758
# Base types
765-
field = _field_by_type(typ, base_schema)
759+
field = _field_by_type(typ)
766760
if field:
767761
return field(**metadata)
768762

@@ -813,10 +807,10 @@ def _field_for_schema(
813807
)
814808
else:
815809
subtyp = Any
816-
return _field_for_schema(subtyp, default, metadata, base_schema)
810+
return _field_for_schema(subtyp, default, metadata)
817811

818812
# Generic types
819-
generic_field = _field_for_generic_type(typ, base_schema, **metadata)
813+
generic_field = _field_for_generic_type(typ, **metadata)
820814
if generic_field:
821815
return generic_field
822816

@@ -829,7 +823,6 @@ def _field_for_schema(
829823
default=default,
830824
newtype_supertype=newtype_supertype,
831825
metadata=metadata,
832-
base_schema=base_schema,
833826
)
834827

835828
# enumerations
@@ -849,6 +842,7 @@ def _field_for_schema(
849842
# Nested dataclasses
850843
forward_reference = getattr(typ, "__forward_arg__", None)
851844

845+
base_schema = _schema_ctx_stack.top.base_schema
852846
nested = (
853847
nested_schema
854848
or forward_reference

0 commit comments

Comments
 (0)