Skip to content

Commit 7d90ff6

Browse files
committed
✨ Use inspect.get_annotations and add utils for it as well
1 parent a6a7bb8 commit 7d90ff6

File tree

3 files changed

+108
-3
lines changed

3 files changed

+108
-3
lines changed

discord/errors.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
from .interactions import Interaction
4141

4242
__all__ = (
43+
"AnnotationMismatch",
4344
"DiscordException",
4445
"ClientException",
4546
"NoMoreItems",
@@ -96,6 +97,8 @@ def __init__(self):
9697
class ValidationError(DiscordException):
9798
"""An Exception that is raised when there is a Validation Error."""
9899

100+
class AnnotationMismatch(SyntaxError, ValidationError):
101+
"""An Exception that is raised when an annotation does not match the type of the value."""
99102

100103
def _flatten_error_dict(d: dict[str, Any], key: str = "") -> dict[str, str]:
101104
items: list[tuple[str, str]] = []

discord/gears/gear.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838

3939
from ..app.event_emitter import Event
4040
from ..utils import MISSING, Undefined
41+
from ..utils.annotations import get_annotations
4142
from ..utils.private import hybridmethod
4243

4344
_T = TypeVar("_T", bound="Gear")
@@ -152,11 +153,15 @@ def detach_gear(self, gear: "Gear") -> None:
152153
def _parse_listener_signature(
153154
callback: Callable[[E], Awaitable[None]], is_instance_function: bool = False
154155
) -> type[E]:
155-
params = list(inspect.signature(callback).parameters.values())
156+
params = get_annotations(
157+
callback,
158+
expected_types={0: type(Event)},
159+
custom_error="""Type annotation mismatch for parameter "{parameter}": expected <class 'Event'>, got {got}.\nEither change the signature of the callback or provide the event type explicitly with the "listen(event=Event)" parameter.""",
160+
)
156161
if is_instance_function:
157-
event = params[1].annotation
162+
event = list(params.values())[1]
158163
else:
159-
event = params[0].annotation
164+
event = next(iter(params.values()))
160165
if issubclass(event, Event):
161166
return cast(type[E], event)
162167
raise TypeError("Could not infer event type from callback. Please provide the event type explicitly.")

discord/utils/annotations.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
import functools
2+
import inspect
3+
from typing import Any, overload
4+
5+
from ..errors import AnnotationMismatch
6+
7+
8+
def _unwrap_partial(func: Any) -> Any:
9+
while isinstance(func, functools.partial):
10+
func = func.func
11+
return func
12+
13+
14+
@overload
15+
def get_annotations(
16+
obj: Any,
17+
*,
18+
globals: dict[str, Any] | None = None,
19+
locals: dict[str, Any] | None = None,
20+
eval_str: bool = False,
21+
expected_types: None = None,
22+
custom_error: None = None,
23+
) -> dict[str, Any]: ...
24+
25+
26+
@overload
27+
def get_annotations(
28+
obj: Any,
29+
*,
30+
globals: dict[str, Any] | None = None,
31+
locals: dict[str, Any] | None = None,
32+
eval_str: bool = False,
33+
expected_types: dict[int, type],
34+
custom_error: str | None = None,
35+
) -> dict[str, Any]: ...
36+
37+
38+
def get_annotations(
39+
obj: Any,
40+
*,
41+
globals: dict[str, Any] | None = None,
42+
locals: dict[str, Any] | None = None,
43+
eval_str: bool = False,
44+
expected_types: dict[int, type] | None = None,
45+
custom_error: str | None = None,
46+
) -> dict[str, Any]:
47+
unwrapped_obj = _unwrap_partial(obj)
48+
r = inspect.get_annotations(unwrapped_obj, globals=globals, locals=locals, eval_str=eval_str)
49+
50+
if expected_types is not None:
51+
for i, (k, v) in enumerate(r.items()):
52+
if i in expected_types and not isinstance(v, expected_types[i]):
53+
filename = unwrapped_obj.__code__.co_filename
54+
source_lines, start_line = inspect.getsourcelines(unwrapped_obj)
55+
56+
param_line = start_line
57+
param_text = ""
58+
col_offset = 0
59+
end_offset = -1
60+
61+
for j, line in enumerate(source_lines):
62+
if f"{k}:" in line:
63+
param_line = start_line + j
64+
param_text = line
65+
# Find column offset (position of parameter name)
66+
col_offset = line.find(k) + 1
67+
# calculate end offset by finding the next comma or closing parenthesis
68+
end_offset = line[col_offset:].find(",")
69+
if end_offset == -1:
70+
end_offset = line[col_offset:].find(")")
71+
if end_offset != -1:
72+
end_offset += col_offset + 1
73+
break
74+
75+
error = AnnotationMismatch(
76+
(
77+
custom_error
78+
or 'Type annotation mismatch for parameter "{parameter}": expected {expected}, got {got}'
79+
).format(
80+
parameter=k,
81+
expected=repr(expected_types[i]),
82+
got=repr(r[k]),
83+
)
84+
)
85+
error.filename = filename
86+
error.lineno = param_line
87+
error.offset = col_offset
88+
error.end_offset = end_offset
89+
error.end_lineno = param_line
90+
error.text = param_text
91+
92+
raise error
93+
94+
return r
95+
96+
97+
__all__ = ("get_annotations", "AnnotationMismatch")

0 commit comments

Comments
 (0)