Skip to content

Commit f9d894b

Browse files
committed
add test for repeated fields, fix __name__ attr for py<3.10
1 parent b3d3797 commit f9d894b

File tree

2 files changed

+69
-10
lines changed

2 files changed

+69
-10
lines changed

marshmallow_dataclass/__init__.py

Lines changed: 44 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,10 @@ def _internal_class_schema(
374374
clazz_frame: types.FrameType = None,
375375
generic_params_to_args: Optional[Tuple[Tuple[type, type], ...]] = None,
376376
) -> Type[marshmallow.Schema]:
377-
_RECURSION_GUARD.seen_classes[clazz] = clazz.__name__
377+
# generic aliases do not have a __name__ prior python 3.10
378+
_name = getattr(clazz, "__name__", repr(clazz))
379+
380+
_RECURSION_GUARD.seen_classes[clazz] = _name
378381
try:
379382
fields = _dataclass_fields(clazz)
380383
except TypeError: # Not a dataclass
@@ -427,7 +430,7 @@ def _internal_class_schema(
427430
if field.init
428431
)
429432

430-
schema_class = type(clazz.__name__, (_base_schema(clazz, base_schema),), attributes)
433+
schema_class = type(_name, (_base_schema(clazz, base_schema),), attributes)
431434
return cast(Type[marshmallow.Schema], schema_class)
432435

433436

@@ -446,6 +449,7 @@ def _field_by_supertype(
446449
metadata: dict,
447450
base_schema: Optional[Type[marshmallow.Schema]],
448451
typ_frame: Optional[types.FrameType],
452+
generic_params_to_args: Optional[Tuple[Tuple[type, type], ...]] = None,
449453
) -> marshmallow.fields.Field:
450454
"""
451455
Return a new field for fields based on a super field. (Usually spawned from NewType)
@@ -477,6 +481,7 @@ def _field_by_supertype(
477481
default=default,
478482
base_schema=base_schema,
479483
typ_frame=typ_frame,
484+
generic_params_to_args=generic_params_to_args,
480485
)
481486

482487

@@ -501,6 +506,7 @@ def _field_for_generic_type(
501506
typ: type,
502507
base_schema: Optional[Type[marshmallow.Schema]],
503508
typ_frame: Optional[types.FrameType],
509+
generic_params_to_args: Optional[Tuple[Tuple[type, type], ...]] = None,
504510
**metadata: Any,
505511
) -> Optional[marshmallow.fields.Field]:
506512
"""
@@ -514,7 +520,10 @@ def _field_for_generic_type(
514520

515521
if origin in (list, List):
516522
child_type = field_for_schema(
517-
arguments[0], base_schema=base_schema, typ_frame=typ_frame
523+
arguments[0],
524+
base_schema=base_schema,
525+
typ_frame=typ_frame,
526+
generic_params_to_args=generic_params_to_args,
518527
)
519528
list_type = cast(
520529
Type[marshmallow.fields.List],
@@ -529,14 +538,20 @@ def _field_for_generic_type(
529538
from . import collection_field
530539

531540
child_type = field_for_schema(
532-
arguments[0], base_schema=base_schema, typ_frame=typ_frame
541+
arguments[0],
542+
base_schema=base_schema,
543+
typ_frame=typ_frame,
544+
generic_params_to_args=generic_params_to_args,
533545
)
534546
return collection_field.Sequence(cls_or_instance=child_type, **metadata)
535547
if origin in (set, Set):
536548
from . import collection_field
537549

538550
child_type = field_for_schema(
539-
arguments[0], base_schema=base_schema, typ_frame=typ_frame
551+
arguments[0],
552+
base_schema=base_schema,
553+
typ_frame=typ_frame,
554+
generic_params_to_args=generic_params_to_args,
540555
)
541556
return collection_field.Set(
542557
cls_or_instance=child_type, frozen=False, **metadata
@@ -545,14 +560,22 @@ def _field_for_generic_type(
545560
from . import collection_field
546561

547562
child_type = field_for_schema(
548-
arguments[0], base_schema=base_schema, typ_frame=typ_frame
563+
arguments[0],
564+
base_schema=base_schema,
565+
typ_frame=typ_frame,
566+
generic_params_to_args=generic_params_to_args,
549567
)
550568
return collection_field.Set(
551569
cls_or_instance=child_type, frozen=True, **metadata
552570
)
553571
if origin in (tuple, Tuple):
554572
children = tuple(
555-
field_for_schema(arg, base_schema=base_schema, typ_frame=typ_frame)
573+
field_for_schema(
574+
arg,
575+
base_schema=base_schema,
576+
typ_frame=typ_frame,
577+
generic_params_to_args=generic_params_to_args,
578+
)
556579
for arg in arguments
557580
)
558581
tuple_type = cast(
@@ -566,10 +589,16 @@ def _field_for_generic_type(
566589
dict_type = type_mapping.get(Dict, marshmallow.fields.Dict)
567590
return dict_type(
568591
keys=field_for_schema(
569-
arguments[0], base_schema=base_schema, typ_frame=typ_frame
592+
arguments[0],
593+
base_schema=base_schema,
594+
typ_frame=typ_frame,
595+
generic_params_to_args=generic_params_to_args,
570596
),
571597
values=field_for_schema(
572-
arguments[1], base_schema=base_schema, typ_frame=typ_frame
598+
arguments[1],
599+
base_schema=base_schema,
600+
typ_frame=typ_frame,
601+
generic_params_to_args=generic_params_to_args,
573602
),
574603
**metadata,
575604
)
@@ -587,6 +616,7 @@ def _field_for_generic_type(
587616
metadata=metadata,
588617
base_schema=base_schema,
589618
typ_frame=typ_frame,
619+
generic_params_to_args=generic_params_to_args,
590620
)
591621
from . import union_field
592622

@@ -599,6 +629,7 @@ def _field_for_generic_type(
599629
metadata={"required": True},
600630
base_schema=base_schema,
601631
typ_frame=typ_frame,
632+
generic_params_to_args=generic_params_to_args,
602633
),
603634
)
604635
for subtyp in subtypes
@@ -707,7 +738,9 @@ def field_for_schema(
707738
)
708739
else:
709740
subtyp = Any
710-
return field_for_schema(subtyp, default, metadata, base_schema, typ_frame)
741+
return field_for_schema(
742+
subtyp, default, metadata, base_schema, typ_frame, generic_params_to_args
743+
)
711744

712745
# Generic types
713746
generic_field = _field_for_generic_type(typ, base_schema, typ_frame, **metadata)
@@ -725,6 +758,7 @@ def field_for_schema(
725758
metadata=metadata,
726759
base_schema=base_schema,
727760
typ_frame=typ_frame,
761+
generic_params_to_args=generic_params_to_args,
728762
)
729763

730764
# enumerations

tests/test_class_schema.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -449,6 +449,31 @@ class NestedGeneric(typing.Generic[T]):
449449
with self.assertRaises(ValidationError):
450450
schema_nested_generic.load({"data": {"data": "str"}})
451451

452+
def test_generic_dataclass_repeated_fields(self):
453+
T = typing.TypeVar("T")
454+
455+
@dataclasses.dataclass
456+
class AA:
457+
a: int
458+
459+
@dataclasses.dataclass
460+
class BB(typing.Generic[T]):
461+
b: T
462+
463+
@dataclasses.dataclass
464+
class Nested:
465+
x: BB[float]
466+
z: BB[float]
467+
# if y is the first field in this class, deserialisation will fail.
468+
# see https://github.com/lovasoa/marshmallow_dataclass/pull/172#issuecomment-1334024027
469+
y: BB[AA]
470+
471+
schema_nested = class_schema(Nested)()
472+
self.assertEqual(
473+
Nested(x=BB(b=1), z=BB(b=1), y=BB(b=AA(1))),
474+
schema_nested.load({"x": {"b": 1}, "z": {"b": 1}, "y": {"b": {"a": 1}}}),
475+
)
476+
452477
def test_recursive_reference(self):
453478
@dataclasses.dataclass
454479
class Tree:

0 commit comments

Comments
 (0)