Skip to content

Commit ee7e1cf

Browse files
authored
Support PEP 604 (X | Y) union notation (lovasoa#219)
Fixes lovasoa#194 Note that the PEP 604 notation is only supported in Python version 3.10 and above.
1 parent 75de725 commit ee7e1cf

File tree

2 files changed

+47
-31
lines changed

2 files changed

+47
-31
lines changed

marshmallow_dataclass/__init__.py

Lines changed: 31 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -495,8 +495,8 @@ def _field_for_generic_type(
495495
If the type is a generic interface, resolve the arguments and construct the appropriate Field.
496496
"""
497497
origin = typing_inspect.get_origin(typ)
498+
arguments = typing_inspect.get_args(typ, True)
498499
if origin:
499-
arguments = typing_inspect.get_args(typ, True)
500500
# Override base_schema.TYPE_MAPPING to change the class used for generic types below
501501
type_mapping = base_schema.TYPE_MAPPING if base_schema else {}
502502

@@ -557,38 +557,38 @@ def _field_for_generic_type(
557557
),
558558
**metadata,
559559
)
560-
elif typing_inspect.is_union_type(typ):
561-
if typing_inspect.is_optional_type(typ):
562-
metadata["allow_none"] = metadata.get("allow_none", True)
563-
metadata["dump_default"] = metadata.get("dump_default", None)
564-
if not metadata.get("required"):
565-
metadata["load_default"] = metadata.get("load_default", None)
566-
metadata.setdefault("required", False)
567-
subtypes = [t for t in arguments if t is not NoneType] # type: ignore
568-
if len(subtypes) == 1:
569-
return field_for_schema(
570-
subtypes[0],
571-
metadata=metadata,
572-
base_schema=base_schema,
573-
typ_frame=typ_frame,
574-
)
575-
from . import union_field
560+
if typing_inspect.is_union_type(typ):
561+
if typing_inspect.is_optional_type(typ):
562+
metadata["allow_none"] = metadata.get("allow_none", True)
563+
metadata["dump_default"] = metadata.get("dump_default", None)
564+
if not metadata.get("required"):
565+
metadata["load_default"] = metadata.get("load_default", None)
566+
metadata.setdefault("required", False)
567+
subtypes = [t for t in arguments if t is not NoneType] # type: ignore
568+
if len(subtypes) == 1:
569+
return field_for_schema(
570+
subtypes[0],
571+
metadata=metadata,
572+
base_schema=base_schema,
573+
typ_frame=typ_frame,
574+
)
575+
from . import union_field
576576

577-
return union_field.Union(
578-
[
579-
(
577+
return union_field.Union(
578+
[
579+
(
580+
subtyp,
581+
field_for_schema(
580582
subtyp,
581-
field_for_schema(
582-
subtyp,
583-
metadata={"required": True},
584-
base_schema=base_schema,
585-
typ_frame=typ_frame,
586-
),
587-
)
588-
for subtyp in subtypes
589-
],
590-
**metadata,
591-
)
583+
metadata={"required": True},
584+
base_schema=base_schema,
585+
typ_frame=typ_frame,
586+
),
587+
)
588+
for subtyp in subtypes
589+
],
590+
**metadata,
591+
)
592592
return None
593593

594594

tests/test_union.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from dataclasses import field
2+
import sys
23
import unittest
34
from typing import List, Optional, Union, Dict
45

@@ -180,3 +181,18 @@ class IntOrStrWithDefault:
180181
)
181182
with self.assertRaises(marshmallow.exceptions.ValidationError):
182183
schema.load({"value": None})
184+
185+
@unittest.skipIf(sys.version_info < (3, 10), "No PEP604 support in py<310")
186+
def test_pep604_union(self):
187+
@dataclass
188+
class PEP604IntOrStr:
189+
value: int | str
190+
191+
schema = PEP604IntOrStr.Schema()
192+
data_in = {"value": "hello"}
193+
loaded = schema.load(data_in)
194+
self.assertEqual(loaded, PEP604IntOrStr(value="hello"))
195+
self.assertEqual(schema.dump(loaded), data_in)
196+
197+
data_in = {"value": 42}
198+
self.assertEqual(schema.dump(schema.load(data_in)), data_in)

0 commit comments

Comments
 (0)