11from __future__ import annotations
22
33from abc import ABC
4- from typing import Any , Literal , get_args , get_type_hints
4+ from typing import Annotated , Any , Generic , Literal , get_args , get_origin
55
6- from pydantic import BaseModel , ConfigDict
6+ from pydantic import BaseModel , ConfigDict , GetPydanticSchema
7+ from pydantic ._internal ._generics import get_origin as get_model_origin # type: ignore[import]
8+ from pydantic_core import core_schema
79from typing_extensions import ( # noqa: UP035
810 LiteralString ,
9- TypeIs ,
11+ TypeAlias ,
12+ TypeGuard ,
13+ TypeVar ,
1014 override ,
1115)
1216
@@ -33,12 +37,58 @@ def model_dump_json(self, **kwargs: Any) -> str:
3337 return super ().model_dump_json (** kwargs )
3438
3539
36- def is_literal_str_type (value : object | None ) -> TypeIs [LiteralString ]:
37- """Check if a type is a Literal type with string values."""
40+ # We do this to get the typing module's _LiteralGenericAlias type, which is not formally exported.
41+ _LiteralStrGenericAlias : TypeAlias = type (Literal ["whatever" ]) # type: ignore[valid-type] # noqa: UP040
42+ """A generic alias for a Literal type used for internal mechanisms of this module.
43+
44+ This is opposed to LiteralStrGenericAlias which is used for typing.
45+ """
46+
47+
48+ # Set this variable here to call the function just once.
49+ _pydantic_str_schema = core_schema .str_schema ()
50+
51+ GetPydanticStrSchema = GetPydanticSchema (lambda _ts , handler : handler (_pydantic_str_schema ))
52+ """A function that returns a Pydantic schema for a string type."""
53+
54+ PydanticLiteralStrGenericAlias : TypeAlias = Annotated [ # type: ignore[valid-type] # noqa: UP040
55+ _LiteralStrGenericAlias ,
56+ GetPydanticStrSchema ,
57+ ]
58+ """A Pydantic-compatible generic alias for a Literal type.
59+
60+ Pydantic will treat a field of this type as a string schema, while static type checkers
61+ will still treat it as a _LiteralGenericAlias type.
62+
63+ Even if a subclass of EventBase uses a Literal with multiple string values,
64+ an event message will only ever have one of those values in the event field,
65+ and so we don't need to handle this with a more complex Pydantic schema.
66+ """
67+
68+
69+ # This type alias is used to handle static type checking accurately while still conveying that
70+ # a value is expected to be a Literal with string type args.
71+ LiteralStrGenericAlias : TypeAlias = Annotated [ # noqa: UP040
72+ LiteralString ,
73+ GetPydanticStrSchema ,
74+ ]
75+ """Type alias for a generic literal string type that is compatible with Pydantic."""
76+
77+
78+ # covariant=True is used to allow subclasses of EventBase to be used in place of the base class.
79+ LiteralEventName_co = TypeVar ("LiteralEventName_co" , bound = PydanticLiteralStrGenericAlias , default = PydanticLiteralStrGenericAlias , covariant = True )
80+ """Type variable for a Literal type with string args."""
81+
82+
83+ def is_literal_str_generic_alias_type (value : object | None ) -> TypeGuard [LiteralStrGenericAlias ]:
84+ """Check if a type is a concrete Literal type with string args."""
3885 if value is None :
3986 return False
4087
41- event_field_base_type = getattr (value , "__origin__" , None )
88+ if isinstance (value , TypeVar ):
89+ return False
90+
91+ event_field_base_type = get_origin (value )
4292
4393 if event_field_base_type is not Literal :
4494 return False
@@ -48,12 +98,10 @@ def is_literal_str_type(value: object | None) -> TypeIs[LiteralString]:
4898
4999## EventBase implementation model of the Stream Deck Plugin SDK events.
50100
51- class EventBase (ConfiguredBaseModel , ABC ):
101+ class EventBase (ConfiguredBaseModel , ABC , Generic [ LiteralEventName_co ] ):
52102 """Base class for event models that represent Stream Deck Plugin SDK events."""
53- # Configure to use the docstrings of the fields as the field descriptions.
54- model_config = ConfigDict (use_attribute_docstrings = True , serialize_by_alias = True )
55103
56- event : str
104+ event : LiteralEventName_co
57105 """Name of the event used to identify what occurred.
58106
59107 Subclass models must define this field as a Literal type with the event name string that the model represents.
@@ -63,25 +111,30 @@ def __init_subclass__(cls, **kwargs: Any) -> None:
63111 """Validate that the event field is a Literal[str] type."""
64112 super ().__init_subclass__ (** kwargs )
65113
66- model_event_type = cls .get_event_type_annotations ()
114+ # This is a GenericAlias (likely used in the subclass definition, i.e. `class ConcreteEvent(EventBase[Literal["event_name"]]):`) which is technically a subclass.
115+ # We can safely ignore this case, as we only want to validate the concrete subclass itself (`ConscreteEvent`).
116+ if get_model_origin (cls ) is None :
117+ return
118+
119+ model_event_type = cls .__event_type__ ()
67120
68- if not is_literal_str_type (model_event_type ):
121+ if not is_literal_str_generic_alias_type (model_event_type ):
69122 msg = f"The event field annotation must be a Literal[str] type. Given type: { model_event_type } "
70123 raise TypeError (msg )
71124
72125 @classmethod
73- def get_event_type_annotations (cls ) -> type [object ]:
126+ def __event_type__ (cls ) -> type [object ]:
74127 """Get the type annotations of the subclass model's event field."""
75- return get_type_hints ( cls ) ["event" ]
128+ return cls . model_fields ["event" ]. annotation # type: ignore[index ]
76129
77130 @classmethod
78- def get_model_event_name (cls ) -> tuple [str , ...]:
131+ def get_model_event_names (cls ) -> tuple [str , ...]:
79132 """Get the value of the subclass model's event field Literal annotation."""
80- model_event_type = cls .get_event_type_annotations ()
133+ model_event_type = cls .__event_type__ ()
81134
82135 # Ensure that the event field annotation is a Literal type.
83- if not is_literal_str_type (model_event_type ):
84- msg = "The ` event` field annotation of an Event model must be a Literal[str] type."
136+ if not is_literal_str_generic_alias_type (model_event_type ):
137+ msg = f "The event field annotation of an Event model must be a Literal[str] type. Given type: { model_event_type } "
85138 raise TypeError (msg )
86139
87140 return get_args (model_event_type )
0 commit comments