55from typing import TYPE_CHECKING , Any , Optional , TypeVar , overload
66
77if TYPE_CHECKING :
8- from typing_extensions import Protocol , override
8+ from pydantic import GetCoreSchemaHandler
9+ from pydantic_core import core_schema
10+ from typing_extensions import Protocol , get_args , override
911
1012 class Comparable (Protocol ):
1113 def __lt__ (self , other : Comparable ) -> bool :
@@ -17,12 +19,25 @@ def __gt__(self, other: Comparable) -> bool:
1719 def __eq__ (self , other : object ) -> bool :
1820 pass
1921
22+
2023else :
2124 Comparable = object
2225
2326 def override (v ):
2427 return v
2528
29+ try :
30+ from pydantic_core import core_schema
31+ except ImportError :
32+ pass
33+
34+ try :
35+ # prefer typing_extensions.get_args for python3.9
36+ # needed for pydantic
37+ from typing_extensions import get_args
38+ except ImportError :
39+ from typing import get_args
40+
2641
2742T = TypeVar ("T" )
2843U = TypeVar ("U" , bound = Comparable )
@@ -203,5 +218,35 @@ def __repr__(self) -> str:
203218 def __str__ (self ) -> str :
204219 return f"{{{ ', ' .join (map (repr , self ))} }}"
205220
221+ @classmethod
222+ def __get_pydantic_core_schema__ (
223+ cls , source_type : Any , handler : GetCoreSchemaHandler
224+ ) -> core_schema .CoreSchema :
225+ # See here: https://docs.pydantic.dev/latest/concepts/types/#generic-containers
226+ instance_schema = core_schema .is_instance_schema (cls )
227+
228+ args = get_args (source_type )
229+ if args :
230+ # replace the type and rely on Pydantic to generate the right schema for `Sequence`
231+ target_type : type = Sequence [args [0 ]] # type: ignore
232+ sequence_t_schema = handler .generate_schema (target_type )
233+ else :
234+ sequence_t_schema = handler .generate_schema (Sequence )
235+
236+ non_instance_schema = core_schema .no_info_after_validator_function (
237+ OrderedSet , sequence_t_schema
238+ )
239+ return core_schema .union_schema (
240+ [
241+ instance_schema ,
242+ non_instance_schema ,
243+ ],
244+ serialization = core_schema .plain_serializer_function_ser_schema (
245+ list , # OrderedSet -> list
246+ info_arg = False ,
247+ return_schema = core_schema .list_schema (),
248+ ),
249+ )
250+
206251
207252__all__ = ("OrderedSet" ,)
0 commit comments