Skip to content

Commit a6c4bc3

Browse files
authored
Allow setting required=True on fields of type Optional (lovasoa#159)
* Allow setting `required=True` on fields of type `Optional` * Brown bag... I made a last-minute change to the new test, and didn't check that it still ran.
1 parent 0541880 commit a6c4bc3

File tree

2 files changed

+21
-3
lines changed

2 files changed

+21
-3
lines changed

marshmallow_dataclass/__init__.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -491,8 +491,9 @@ def _field_for_generic_type(
491491
if typing_inspect.is_optional_type(typ):
492492
metadata["allow_none"] = metadata.get("allow_none", True)
493493
metadata["dump_default"] = metadata.get("dump_default", None)
494-
metadata["load_default"] = metadata.get("load_default", None)
495-
metadata["required"] = False
494+
if not metadata.get("required"):
495+
metadata["load_default"] = metadata.get("load_default", None)
496+
metadata.setdefault("required", False)
496497
subtypes = [t for t in arguments if t is not NoneType] # type: ignore
497498
if len(subtypes) == 1:
498499
return field_for_schema(
@@ -549,7 +550,7 @@ def field_for_schema(
549550
if not metadata.get("required"):
550551
metadata.setdefault("load_default", default)
551552
else:
552-
metadata.setdefault("required", True)
553+
metadata.setdefault("required", not typing_inspect.is_optional_type(typ))
553554

554555
# If the field was already defined by the user
555556
predefined_field = metadata.get("marshmallow_field")

tests/test_optional.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,20 @@ class OptionalValueNotNone:
3737
self.assertEqual(
3838
exc_cm.exception.messages, {"value": ["Field may not be null."]}
3939
)
40+
41+
def test_required_optional_field(self):
42+
@dataclass
43+
class RequiredOptionalValue:
44+
value: Optional[str] = field(default="default", metadata={"required": True})
45+
46+
schema = RequiredOptionalValue.Schema()
47+
48+
self.assertEqual(schema.load({"value": None}), RequiredOptionalValue(None))
49+
self.assertEqual(
50+
schema.load({"value": "hello"}), RequiredOptionalValue(value="hello")
51+
)
52+
with self.assertRaises(marshmallow.exceptions.ValidationError) as exc_cm:
53+
schema.load({})
54+
self.assertEqual(
55+
exc_cm.exception.messages, {"value": ["Missing data for required field."]}
56+
)

0 commit comments

Comments
 (0)