Skip to content

Commit e2b654d

Browse files
committed
Support pydantic validation & serialization
Fixes #1 Unfortunately, the unique constraint I talked about in the issue will probably need to be put in a seperate package.
1 parent 436ba28 commit e2b654d

File tree

3 files changed

+75
-3
lines changed

3 files changed

+75
-3
lines changed

pyproject.toml

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,17 @@ requires-python = ">=3.9"
1515
dependencies = []
1616
dynamic = ["version"]
1717

18+
[project.optional-dependencies]
19+
pydantic = [
20+
"pydantic~=2.11",
21+
# needed for improved get_args
22+
'typing_extensions>=4.12; python_version<="3.9"',
23+
]
24+
1825
[dependency-groups]
1926
mypy = ["mypy~=1.0", {include-group = "typing"}]
20-
test = ["pytest~=8.3", "pytest-sugar~=1.0"]
21-
typing = ["typing-extensions~=4.12"]
27+
test = ["pytest~=8.3", "pytest-sugar~=1.0", "techcable.orderedset[pydantic]"]
28+
typing = ["typing-extensions~=4.12", "techcable.orderedset[pydantic]"]
2229
dev = [{include-group = "mypy"}, {include-group = "test"}]
2330

2431
[project.urls]

src/techcable/orderedset/_orderedset.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55
from typing import TYPE_CHECKING, Any, Optional, TypeVar, overload
66

77
if TYPE_CHECKING:
8-
from typing_extensions import Protocol, override
8+
from pydantic import GetCoreSchemaHandler
9+
from pydantic_core import core_schema
10+
from typing_extensions import Protocol, get_args, override
911

1012
class Comparable(Protocol):
1113
def __lt__(self, other: Comparable) -> bool:
@@ -17,12 +19,25 @@ def __gt__(self, other: Comparable) -> bool:
1719
def __eq__(self, other: object) -> bool:
1820
pass
1921

22+
2023
else:
2124
Comparable = object
2225

2326
def override(v):
2427
return v
2528

29+
try:
30+
from pydantic_core import core_schema
31+
except ImportError:
32+
pass
33+
34+
try:
35+
# prefer typing_extensions.get_args for python3.9
36+
# needed for pydantic
37+
from typing_extensions import get_args
38+
except ImportError:
39+
from typing import get_args
40+
2641

2742
T = TypeVar("T")
2843
U = TypeVar("U", bound=Comparable)
@@ -203,5 +218,35 @@ def __repr__(self) -> str:
203218
def __str__(self) -> str:
204219
return f"{{{', '.join(map(repr, self))}}}"
205220

221+
@classmethod
222+
def __get_pydantic_core_schema__(
223+
cls, source_type: Any, handler: GetCoreSchemaHandler
224+
) -> core_schema.CoreSchema:
225+
# See here: https://docs.pydantic.dev/latest/concepts/types/#generic-containers
226+
instance_schema = core_schema.is_instance_schema(cls)
227+
228+
args = get_args(source_type)
229+
if args:
230+
# replace the type and rely on Pydantic to generate the right schema for `Sequence`
231+
target_type: type = Sequence[args[0]] # type: ignore
232+
sequence_t_schema = handler.generate_schema(target_type)
233+
else:
234+
sequence_t_schema = handler.generate_schema(Sequence)
235+
236+
non_instance_schema = core_schema.no_info_after_validator_function(
237+
OrderedSet, sequence_t_schema
238+
)
239+
return core_schema.union_schema(
240+
[
241+
instance_schema,
242+
non_instance_schema,
243+
],
244+
serialization=core_schema.plain_serializer_function_ser_schema(
245+
list, # OrderedSet -> list
246+
info_arg=False,
247+
return_schema=core_schema.list_schema(),
248+
),
249+
)
250+
206251

207252
__all__ = ("OrderedSet",)

tests/test_pydantic.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import pydantic
2+
3+
from techcable.orderedset import OrderedSet
4+
5+
6+
def test_deser_orderedset() -> None:
7+
raw_data = [1, 2, 7, 8, 9, 1]
8+
res = pydantic.TypeAdapter(OrderedSet[int]).validate_python(raw_data)
9+
assert isinstance(res, OrderedSet)
10+
assert res == OrderedSet(raw_data)
11+
res = pydantic.TypeAdapter(OrderedSet[int]).validate_json("[1,2,7,8,9,1]")
12+
assert res == OrderedSet(raw_data)
13+
14+
15+
def test_ser_orderedset() -> None:
16+
raw_data = [1, 2, 7, 8, 9, 1]
17+
oset = OrderedSet(raw_data)
18+
ser = pydantic.TypeAdapter(OrderedSet[int]).dump_python(oset)
19+
assert isinstance(ser, list)
20+
assert ser == list(OrderedSet(raw_data))

0 commit comments

Comments
 (0)