Skip to content

Commit 97f6e27

Browse files
committed
✨ Use inspect.get_annotations and add utils for it as well
1 parent a6a7bb8 commit 97f6e27

File tree

3 files changed

+221
-6
lines changed

3 files changed

+221
-6
lines changed

discord/errors.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
from .interactions import Interaction
4141

4242
__all__ = (
43+
"AnnotationMismatch",
4344
"DiscordException",
4445
"ClientException",
4546
"NoMoreItems",
@@ -97,6 +98,10 @@ class ValidationError(DiscordException):
9798
"""An Exception that is raised when there is a Validation Error."""
9899

99100

101+
class AnnotationMismatch(SyntaxError, ValidationError):
102+
"""An Exception that is raised when an annotation does not match the type of the value."""
103+
104+
100105
def _flatten_error_dict(d: dict[str, Any], key: str = "") -> dict[str, str]:
101106
items: list[tuple[str, str]] = []
102107
for k, v in d.items():

discord/gears/gear.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838

3939
from ..app.event_emitter import Event
4040
from ..utils import MISSING, Undefined
41+
from ..utils.annotations import get_annotations
4142
from ..utils.private import hybridmethod
4243

4344
_T = TypeVar("_T", bound="Gear")
@@ -152,14 +153,16 @@ def detach_gear(self, gear: "Gear") -> None:
152153
def _parse_listener_signature(
153154
callback: Callable[[E], Awaitable[None]], is_instance_function: bool = False
154155
) -> type[E]:
155-
params = list(inspect.signature(callback).parameters.values())
156+
params = get_annotations(
157+
callback,
158+
expected_types={0: type(Event)},
159+
custom_error="""Type annotation mismatch for parameter "{parameter}": expected <class 'Event'>, got {got}.""",
160+
)
156161
if is_instance_function:
157-
event = params[1].annotation
162+
event = list(params.values())[1]
158163
else:
159-
event = params[0].annotation
160-
if issubclass(event, Event):
161-
return cast(type[E], event)
162-
raise TypeError("Could not infer event type from callback. Please provide the event type explicitly.")
164+
event = next(iter(params.values()))
165+
return cast(type[E], event)
163166

164167
def add_listener(
165168
self,

discord/utils/annotations.py

Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
import ast
2+
import functools
3+
import inspect
4+
import textwrap
5+
from typing import Any, overload
6+
7+
from ..errors import AnnotationMismatch
8+
9+
10+
def _param_spans(obj: Any) -> dict[str, tuple[int, int, int, int, str]]:
11+
"""
12+
Get the source code spans for each parameter's annotation in a function.
13+
Returns a mapping of parameter name to a tuple of
14+
(start_line, start_col_1b, end_line, end_col_1b, line_text).
15+
1b = 1-based column offset.
16+
17+
Parameters
18+
----------
19+
obj:
20+
The function or method to analyze.
21+
22+
Returns
23+
-------
24+
dict[str, tuple[int, int, int, int, str]]
25+
Mapping of parameter names to their annotation spans.
26+
"""
27+
src, start_line = inspect.getsourcelines(obj) # original (indented) lines
28+
filename = inspect.getsourcefile(obj) or "<unknown>"
29+
30+
# Compute common indent that dedent will remove
31+
non_empty = [l for l in src if l.strip()]
32+
common_indent = min((len(l) - len(l.lstrip(" "))) for l in non_empty) if non_empty else 0
33+
34+
# Parse a DEDENTED copy to get stable AST coords
35+
dedented = textwrap.dedent("".join(src))
36+
mod = ast.parse(dedented, filename=filename, mode="exec", type_comments=True)
37+
38+
fn = next((n for n in mod.body if isinstance(n, (ast.FunctionDef, ast.AsyncFunctionDef))), None)
39+
if fn is None:
40+
return {}
41+
42+
def _collect_args(a: ast.arguments) -> list[tuple[ast.arg, ast.expr | None]]:
43+
out: list[tuple[ast.arg, ast.expr | None]] = []
44+
for ar in getattr(a, "posonlyargs", []):
45+
out.append((ar, ar.annotation))
46+
for ar in a.args:
47+
out.append((ar, ar.annotation))
48+
if a.vararg:
49+
out.append((a.vararg, a.vararg.annotation))
50+
for ar in a.kwonlyargs:
51+
out.append((ar, ar.annotation))
52+
if a.kwarg:
53+
out.append((a.kwarg, a.kwarg.annotation))
54+
return out
55+
56+
args = _collect_args(fn.args)
57+
58+
def _line_text_file(lineno_file: int) -> str:
59+
idx = lineno_file - start_line
60+
if 0 <= idx < len(src):
61+
return src[idx].rstrip("\n")
62+
return ""
63+
64+
spans: dict[str, tuple[int, int, int, int, str]] = {}
65+
66+
for ar, ann in args:
67+
name = ar.arg
68+
69+
# AST positions are snippet-relative: lineno 1-based, col_offset 0-based
70+
ln_snip = getattr(ar, "lineno", 1)
71+
col0_snip = getattr(ar, "col_offset", 0)
72+
73+
# Prefer annotation end if present; otherwise end at end of the name
74+
if ann is not None and hasattr(ann, "end_lineno") and hasattr(ann, "end_col_offset"):
75+
end_ln_snip = ann.end_lineno
76+
end_col0_snip = ann.end_col_offset
77+
else:
78+
end_ln_snip = ln_snip
79+
end_col0_snip = col0_snip + len(name)
80+
81+
# Convert SNIPPET positions -> FILE positions
82+
ln_file = start_line + (ln_snip - 1)
83+
end_ln_file = start_line + (end_ln_snip - 1)
84+
85+
# Add back the common indent that dedent removed; convert to 1-based
86+
col_1b_file = col0_snip + 1 + common_indent
87+
end_col_1b_file = end_col0_snip + 1 + common_indent
88+
89+
line_text = _line_text_file(ln_file)
90+
# Guard: keep columns within the line
91+
line_len_1b = len(line_text) + 1
92+
col_1b_file = max(1, min(col_1b_file, line_len_1b))
93+
end_col_1b_file = max(col_1b_file, min(end_col_1b_file, line_len_1b))
94+
95+
spans[name] = (ln_file, col_1b_file, end_ln_file, end_col_1b_file, line_text)
96+
97+
return spans
98+
99+
100+
def _unwrap_partial(func: Any) -> Any:
101+
while isinstance(func, functools.partial):
102+
func = func.func
103+
return func
104+
105+
106+
@overload
107+
def get_annotations(
108+
obj: Any,
109+
*,
110+
globals: dict[str, Any] | None = None,
111+
locals: dict[str, Any] | None = None,
112+
eval_str: bool = False,
113+
expected_types: None = None,
114+
custom_error: None = None,
115+
) -> dict[str, Any]: ...
116+
117+
118+
@overload
119+
def get_annotations(
120+
obj: Any,
121+
*,
122+
globals: dict[str, Any] | None = None,
123+
locals: dict[str, Any] | None = None,
124+
eval_str: bool = False,
125+
expected_types: dict[int, type],
126+
custom_error: str | None = None,
127+
) -> dict[str, Any]: ...
128+
129+
130+
def get_annotations(
131+
obj: Any,
132+
*,
133+
globals: dict[str, Any] | None = None,
134+
locals: dict[str, Any] | None = None,
135+
eval_str: bool = False,
136+
expected_types: dict[int, type] | None = None,
137+
custom_error: str | None = None,
138+
) -> dict[str, Any]:
139+
"""
140+
Get the type annotations of a function or method, with optional type checking.
141+
142+
This function unwraps `functools.partial` objects to access the original function.
143+
144+
This function is a modified version of `inspect.get_annotations` that adds the ability to check parameter types.
145+
146+
.. note::
147+
This function is not intended to be used by end-users.
148+
149+
Parameters
150+
----------
151+
obj:
152+
The function or method to inspect.
153+
globals:
154+
The global namespace to use for evaluating string annotations.
155+
locals:
156+
The local namespace to use for evaluating string annotations.
157+
eval_str:
158+
Whether to evaluate string annotations.
159+
expected_types:
160+
A mapping of parameter index to expected type for type checking.
161+
custom_error:
162+
A custom error message format for type mismatches. Supports the following format fields:
163+
- parameter: The name of the parameter with the mismatch.
164+
- expected: The expected type.
165+
- got: The actual type found.
166+
167+
Returns
168+
-------
169+
dict[str, Any]
170+
A mapping of parameter names to their type annotations.
171+
"""
172+
unwrapped_obj = _unwrap_partial(obj)
173+
r = inspect.get_annotations(unwrapped_obj, globals=globals, locals=locals, eval_str=eval_str)
174+
175+
if expected_types is not None:
176+
for i, (k, v) in enumerate(r.items()):
177+
if i in expected_types and not isinstance(v, expected_types[i]):
178+
error = AnnotationMismatch(
179+
(
180+
custom_error
181+
or 'Type annotation mismatch for parameter "{parameter}": expected {expected}, got {got}'
182+
).format(
183+
parameter=k,
184+
expected=repr(expected_types[i]),
185+
got=repr(r[k]),
186+
)
187+
)
188+
spans = _param_spans(unwrapped_obj)
189+
190+
if k in spans:
191+
ln, col_1b, end_ln, end_col_1b, line_text = spans[k]
192+
else:
193+
ln = unwrapped_obj.__code__.co_firstlineno
194+
line_text = inspect.getsource(unwrapped_obj).splitlines()[0]
195+
col_1b, end_ln, end_col_1b = 1, ln, len(line_text) + 1
196+
error.filename = unwrapped_obj.__code__.co_filename
197+
error.lineno = ln
198+
error.offset = col_1b
199+
error.end_lineno = end_ln
200+
error.end_offset = end_col_1b
201+
error.text = line_text
202+
raise error
203+
204+
return r
205+
206+
207+
__all__ = ("get_annotations", "AnnotationMismatch")

0 commit comments

Comments
 (0)