From 6b75caeea540167ebe6459593b43b77f8f5417f1 Mon Sep 17 00:00:00 2001 From: Jason Koh Date: Mon, 9 Dec 2024 11:01:46 -0800 Subject: [PATCH 1/5] proper enum schema for pydantic --- src/betterproto/templates/template.py.j2 | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/betterproto/templates/template.py.j2 b/src/betterproto/templates/template.py.j2 index 4a252aec..c7f9197a 100644 --- a/src/betterproto/templates/template.py.j2 +++ b/src/betterproto/templates/template.py.j2 @@ -17,7 +17,16 @@ class {{ enum.py_name }}(betterproto.Enum): def __get_pydantic_core_schema__(cls, _source_type, _handler): from pydantic_core import core_schema - return core_schema.int_schema(ge=0) + def validate(value: int) -> "{{ enum.py_name }}": + return cls(value) + + # Return the schema for validation and serialization + return core_schema.chain_schema( + [ + core_schema.int_schema(), # Validate as a string first + core_schema.no_info_plain_validator_function(validate), # Custom validation + ] + ) {% endif %} {% endfor %} From 683de47a0e22fbc13b304998fb49a3f16dfd4e77 Mon Sep 17 00:00:00 2001 From: Jason Koh Date: Mon, 9 Dec 2024 11:10:23 -0800 Subject: [PATCH 2/5] ge --- src/betterproto/templates/template.py.j2 | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/betterproto/templates/template.py.j2 b/src/betterproto/templates/template.py.j2 index c7f9197a..ce9bc52a 100644 --- a/src/betterproto/templates/template.py.j2 +++ b/src/betterproto/templates/template.py.j2 @@ -23,7 +23,7 @@ class {{ enum.py_name }}(betterproto.Enum): # Return the schema for validation and serialization return core_schema.chain_schema( [ - core_schema.int_schema(), # Validate as a string first + core_schema.int_schema(ge=0), # Validate as a string first core_schema.no_info_plain_validator_function(validate), # Custom validation ] ) From 6619065718602025d7745f4024d567755aa28e63 Mon Sep 17 00:00:00 2001 From: Jason Koh Date: Mon, 9 Dec 2024 11:17:22 -0800 Subject: [PATCH 3/5] add a test --- tests/inputs/enum/test_enum.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/inputs/enum/test_enum.py b/tests/inputs/enum/test_enum.py index 21a5ac3b..fccdf338 100644 --- a/tests/inputs/enum/test_enum.py +++ b/tests/inputs/enum/test_enum.py @@ -4,6 +4,11 @@ Test, ) +from tests.output_betterproto_pydantic.enum import ( + Test as TestPyd, + Choice as ChoicePyd, +) + def test_enum_set_and_get(): assert Test(choice=Choice.ZERO).choice == Choice.ZERO @@ -112,3 +117,7 @@ def test_renamed_enum_members(): "MINUS", "_0_PREFIXED", } + +def test_pydantic_enum_preserve_type(): + test = TestPyd(choice=ChoicePyd.ZERO) + assert isinstance(test.choice, ChoicePyd) From b197d5ed47baf3a1312ad5372f01390a91ecb1d5 Mon Sep 17 00:00:00 2001 From: Jason Koh Date: Tue, 10 Dec 2024 10:14:33 -0800 Subject: [PATCH 4/5] reformat --- tests/inputs/enum/test_enum.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/inputs/enum/test_enum.py b/tests/inputs/enum/test_enum.py index fccdf338..578aba87 100644 --- a/tests/inputs/enum/test_enum.py +++ b/tests/inputs/enum/test_enum.py @@ -3,10 +3,9 @@ Choice, Test, ) - from tests.output_betterproto_pydantic.enum import ( - Test as TestPyd, Choice as ChoicePyd, + Test as TestPyd, ) @@ -118,6 +117,7 @@ def test_renamed_enum_members(): "_0_PREFIXED", } + def test_pydantic_enum_preserve_type(): test = TestPyd(choice=ChoicePyd.ZERO) assert isinstance(test.choice, ChoicePyd) From 1982d789a5d2d273c8a6fb237087ec6df5ead245 Mon Sep 17 00:00:00 2001 From: Jason Koh Date: Tue, 18 Mar 2025 00:01:39 -0700 Subject: [PATCH 5/5] use lambda --- src/betterproto/templates/template.py.j2 | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/betterproto/templates/template.py.j2 b/src/betterproto/templates/template.py.j2 index ce9bc52a..a08b36cb 100644 --- a/src/betterproto/templates/template.py.j2 +++ b/src/betterproto/templates/template.py.j2 @@ -17,14 +17,11 @@ class {{ enum.py_name }}(betterproto.Enum): def __get_pydantic_core_schema__(cls, _source_type, _handler): from pydantic_core import core_schema - def validate(value: int) -> "{{ enum.py_name }}": - return cls(value) - # Return the schema for validation and serialization return core_schema.chain_schema( [ core_schema.int_schema(ge=0), # Validate as a string first - core_schema.no_info_plain_validator_function(validate), # Custom validation + core_schema.no_info_plain_validator_function(lambda value: cls(value)), # Custom validation ] ) {% endif %}