|
| 1 | +import ast |
| 2 | +import functools |
| 3 | +import inspect |
| 4 | +import textwrap |
| 5 | +from typing import Any, overload |
| 6 | + |
| 7 | +from ..errors import AnnotationMismatch |
| 8 | + |
| 9 | + |
| 10 | +def _param_spans(obj: Any) -> dict[str, tuple[int, int, int, int, str]]: |
| 11 | + """ |
| 12 | + Get the source code spans for each parameter's annotation in a function. |
| 13 | + Returns a mapping of parameter name to a tuple of |
| 14 | + (start_line, start_col_1b, end_line, end_col_1b, line_text). |
| 15 | + 1b = 1-based column offset. |
| 16 | +
|
| 17 | + Parameters |
| 18 | + ---------- |
| 19 | + obj: |
| 20 | + The function or method to analyze. |
| 21 | +
|
| 22 | + Returns |
| 23 | + ------- |
| 24 | + dict[str, tuple[int, int, int, int, str]] |
| 25 | + Mapping of parameter names to their annotation spans. |
| 26 | + """ |
| 27 | + src, start_line = inspect.getsourcelines(obj) # original (indented) lines |
| 28 | + filename = inspect.getsourcefile(obj) or "<unknown>" |
| 29 | + |
| 30 | + # Compute common indent that dedent will remove |
| 31 | + non_empty = [l for l in src if l.strip()] |
| 32 | + common_indent = min((len(l) - len(l.lstrip(" "))) for l in non_empty) if non_empty else 0 |
| 33 | + |
| 34 | + # Parse a DEDENTED copy to get stable AST coords |
| 35 | + dedented = textwrap.dedent("".join(src)) |
| 36 | + mod = ast.parse(dedented, filename=filename, mode="exec", type_comments=True) |
| 37 | + |
| 38 | + fn = next((n for n in mod.body if isinstance(n, (ast.FunctionDef, ast.AsyncFunctionDef))), None) |
| 39 | + if fn is None: |
| 40 | + return {} |
| 41 | + |
| 42 | + def _collect_args(a: ast.arguments) -> list[tuple[ast.arg, ast.expr | None]]: |
| 43 | + out: list[tuple[ast.arg, ast.expr | None]] = [] |
| 44 | + for ar in getattr(a, "posonlyargs", []): |
| 45 | + out.append((ar, ar.annotation)) |
| 46 | + for ar in a.args: |
| 47 | + out.append((ar, ar.annotation)) |
| 48 | + if a.vararg: |
| 49 | + out.append((a.vararg, a.vararg.annotation)) |
| 50 | + for ar in a.kwonlyargs: |
| 51 | + out.append((ar, ar.annotation)) |
| 52 | + if a.kwarg: |
| 53 | + out.append((a.kwarg, a.kwarg.annotation)) |
| 54 | + return out |
| 55 | + |
| 56 | + args = _collect_args(fn.args) |
| 57 | + |
| 58 | + def _line_text_file(lineno_file: int) -> str: |
| 59 | + idx = lineno_file - start_line |
| 60 | + if 0 <= idx < len(src): |
| 61 | + return src[idx].rstrip("\n") |
| 62 | + return "" |
| 63 | + |
| 64 | + spans: dict[str, tuple[int, int, int, int, str]] = {} |
| 65 | + |
| 66 | + for ar, ann in args: |
| 67 | + name = ar.arg |
| 68 | + |
| 69 | + # AST positions are snippet-relative: lineno 1-based, col_offset 0-based |
| 70 | + ln_snip = getattr(ar, "lineno", 1) |
| 71 | + col0_snip = getattr(ar, "col_offset", 0) |
| 72 | + |
| 73 | + # Prefer annotation end if present; otherwise end at end of the name |
| 74 | + if ann is not None and hasattr(ann, "end_lineno") and hasattr(ann, "end_col_offset"): |
| 75 | + end_ln_snip = ann.end_lineno |
| 76 | + end_col0_snip = ann.end_col_offset |
| 77 | + else: |
| 78 | + end_ln_snip = ln_snip |
| 79 | + end_col0_snip = col0_snip + len(name) |
| 80 | + |
| 81 | + # Convert SNIPPET positions -> FILE positions |
| 82 | + ln_file = start_line + (ln_snip - 1) |
| 83 | + end_ln_file = start_line + (end_ln_snip - 1) |
| 84 | + |
| 85 | + # Add back the common indent that dedent removed; convert to 1-based |
| 86 | + col_1b_file = col0_snip + 1 + common_indent |
| 87 | + end_col_1b_file = end_col0_snip + 1 + common_indent |
| 88 | + |
| 89 | + line_text = _line_text_file(ln_file) |
| 90 | + # Guard: keep columns within the line |
| 91 | + line_len_1b = len(line_text) + 1 |
| 92 | + col_1b_file = max(1, min(col_1b_file, line_len_1b)) |
| 93 | + end_col_1b_file = max(col_1b_file, min(end_col_1b_file, line_len_1b)) |
| 94 | + |
| 95 | + spans[name] = (ln_file, col_1b_file, end_ln_file, end_col_1b_file, line_text) |
| 96 | + |
| 97 | + return spans |
| 98 | + |
| 99 | + |
| 100 | +def _unwrap_partial(func: Any) -> Any: |
| 101 | + while isinstance(func, functools.partial): |
| 102 | + func = func.func |
| 103 | + return func |
| 104 | + |
| 105 | + |
| 106 | +@overload |
| 107 | +def get_annotations( |
| 108 | + obj: Any, |
| 109 | + *, |
| 110 | + globals: dict[str, Any] | None = None, |
| 111 | + locals: dict[str, Any] | None = None, |
| 112 | + eval_str: bool = False, |
| 113 | + expected_types: None = None, |
| 114 | + custom_error: None = None, |
| 115 | +) -> dict[str, Any]: ... |
| 116 | + |
| 117 | + |
| 118 | +@overload |
| 119 | +def get_annotations( |
| 120 | + obj: Any, |
| 121 | + *, |
| 122 | + globals: dict[str, Any] | None = None, |
| 123 | + locals: dict[str, Any] | None = None, |
| 124 | + eval_str: bool = False, |
| 125 | + expected_types: dict[int, type], |
| 126 | + custom_error: str | None = None, |
| 127 | +) -> dict[str, Any]: ... |
| 128 | + |
| 129 | + |
| 130 | +def get_annotations( |
| 131 | + obj: Any, |
| 132 | + *, |
| 133 | + globals: dict[str, Any] | None = None, |
| 134 | + locals: dict[str, Any] | None = None, |
| 135 | + eval_str: bool = False, |
| 136 | + expected_types: dict[int, type] | None = None, |
| 137 | + custom_error: str | None = None, |
| 138 | +) -> dict[str, Any]: |
| 139 | + """ |
| 140 | + Get the type annotations of a function or method, with optional type checking. |
| 141 | +
|
| 142 | + This function unwraps `functools.partial` objects to access the original function. |
| 143 | +
|
| 144 | + This function is a modified version of `inspect.get_annotations` that adds the ability to check parameter types. |
| 145 | +
|
| 146 | + .. note:: |
| 147 | + This function is not intended to be used by end-users. |
| 148 | +
|
| 149 | + Parameters |
| 150 | + ---------- |
| 151 | + obj: |
| 152 | + The function or method to inspect. |
| 153 | + globals: |
| 154 | + The global namespace to use for evaluating string annotations. |
| 155 | + locals: |
| 156 | + The local namespace to use for evaluating string annotations. |
| 157 | + eval_str: |
| 158 | + Whether to evaluate string annotations. |
| 159 | + expected_types: |
| 160 | + A mapping of parameter index to expected type for type checking. |
| 161 | + custom_error: |
| 162 | + A custom error message format for type mismatches. Supports the following format fields: |
| 163 | + - parameter: The name of the parameter with the mismatch. |
| 164 | + - expected: The expected type. |
| 165 | + - got: The actual type found. |
| 166 | +
|
| 167 | + Returns |
| 168 | + ------- |
| 169 | + dict[str, Any] |
| 170 | + A mapping of parameter names to their type annotations. |
| 171 | + """ |
| 172 | + unwrapped_obj = _unwrap_partial(obj) |
| 173 | + r = inspect.get_annotations(unwrapped_obj, globals=globals, locals=locals, eval_str=eval_str) |
| 174 | + |
| 175 | + if expected_types is not None: |
| 176 | + for i, (k, v) in enumerate(r.items()): |
| 177 | + if i in expected_types and not isinstance(v, expected_types[i]): |
| 178 | + error = AnnotationMismatch( |
| 179 | + ( |
| 180 | + custom_error |
| 181 | + or 'Type annotation mismatch for parameter "{parameter}": expected {expected}, got {got}' |
| 182 | + ).format( |
| 183 | + parameter=k, |
| 184 | + expected=repr(expected_types[i]), |
| 185 | + got=repr(r[k]), |
| 186 | + ) |
| 187 | + ) |
| 188 | + spans = _param_spans(unwrapped_obj) |
| 189 | + |
| 190 | + if k in spans: |
| 191 | + ln, col_1b, end_ln, end_col_1b, line_text = spans[k] |
| 192 | + else: |
| 193 | + ln = unwrapped_obj.__code__.co_firstlineno |
| 194 | + line_text = inspect.getsource(unwrapped_obj).splitlines()[0] |
| 195 | + col_1b, end_ln, end_col_1b = 1, ln, len(line_text) + 1 |
| 196 | + error.filename = unwrapped_obj.__code__.co_filename |
| 197 | + error.lineno = ln |
| 198 | + error.offset = col_1b |
| 199 | + error.end_lineno = end_ln |
| 200 | + error.end_offset = end_col_1b |
| 201 | + error.text = line_text |
| 202 | + raise error |
| 203 | + |
| 204 | + return r |
| 205 | + |
| 206 | + |
| 207 | +__all__ = ("get_annotations", "AnnotationMismatch") |
0 commit comments