From c9216524f0471820eb3a45bfb2605a96d68b4fde Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lars=20Gr=C3=BCter?= Date: Sat, 19 Jul 2025 17:30:52 +0200 Subject: [PATCH 1/2] WIP Collect types into a tree data structure This would likely preserve more information and would make name matching a lot more efficient. Long-term it might even help with collecting from the typeshed. --- src/docstub/_analysis.py | 193 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 192 insertions(+), 1 deletion(-) diff --git a/src/docstub/_analysis.py b/src/docstub/_analysis.py index a5ff9f3..b1467d8 100644 --- a/src/docstub/_analysis.py +++ b/src/docstub/_analysis.py @@ -5,9 +5,10 @@ import json import logging import re -from dataclasses import asdict, dataclass +from dataclasses import asdict, dataclass, field from functools import cache from pathlib import Path +from typing import Self import libcst as cst import libcst.matchers as cstm @@ -291,6 +292,196 @@ def common_known_types(): return known_imports +class Node: + def __init__(self, name): + self.name = name + self._parent = None + self._children = [] + + @property + def parent(self): + return self._parent + + @property + def children(self): + return self._children.copy() + + @property + def is_leaf(self): + return not self._children + + @property + def fullname(self): + names = [node.name for node in self.walk_up()][::-1] + return ".".join(names) + + def add_child(self, child): + assert child.parent is None + child._parent = self + self._children.append(child) + + def walk_down(self): + yield self + for child in self._children: + yield from child.walk_down() + + def walk_up(self): + current = self + while current.parent is not None: + yield current + current = current.parent + + def __repr__(self): + return f"{type(self).__name__}({self.name!r})" + + +class Tree(Node): + def __init__(self): + super().__init__(name=None) + + def add_child(self, child): + if not isinstance(child, ModuleNode): + raise TypeError("expected new child to 
by a module") + return super().add_child(child) + + def get(self): + + +class ModuleNode(Node): + + def __init__(self, name, file_path): + super().__init__(name=name) + self.file_path = file_path + + +class _InModuleNode(Node): + pass + + +class ClassNode(_InModuleNode): + pass + + +class TypeAliasNode(_InModuleNode): + pass + + +class ImportFromNode(_InModuleNode): + pass + + +class NodeCollector(cst.CSTVisitor): + """Collect types from a given Python file. + + Examples + -------- + >>> types = NodeCollector.collect(__file__) + >>> types[f"{__name__}.TypeCollector"] + + """ + + class ImportSerializer: + """Implements the `FuncSerializer` protocol to cache `TypeCollector.collect`.""" + + suffix = ".json" + encoding = "utf-8" + + def hash_args(self, path: Path) -> str: + """Compute a unique hash from the path passed to `TypeCollector.collect`.""" + key = pyfile_checksum(path) + return key + + def serialize(self, data: dict[str, KnownImport]) -> bytes: + """Serialize results from `TypeCollector.collect`.""" + primitives = {qualname: asdict(imp) for qualname, imp in data.items()} + raw = json.dumps(primitives, separators=(",", ":")).encode(self.encoding) + return raw + + def deserialize(self, raw: bytes) -> dict[str, KnownImport]: + """Deserialize results from `TypeCollector.collect`.""" + primitives = json.loads(raw.decode(self.encoding)) + data = {qualname: KnownImport(**kw) for qualname, kw in primitives.items()} + return data + + @classmethod + def collect(cls, file_path): + """Collect importable type annotations in given file. + + Parameters + ---------- + file : Path + + Returns + ------- + collected : dict[str, KnownImport] + """ + file_path = Path(file_path) + with file_path.open("r") as fo: + source = fo.read() + + tree = cst.parse_module(source) + collector = cls(file_path=file_path) + tree.visit(collector) + return collector._root_node + + def __init__(self, *, file_path): + """Initialize type collector. 
+ + Parameters + ---------- + module_name : str + """ + assert "." not in file_path.stem + self._root_node = ModuleNode(name=file_path.stem, file_path=file_path) + self._current_node = self._root_node + + def visit_ClassDef(self, node: cst.ClassDef) -> bool: + node = ClassNode(name=node.name.value) + self._current_node.add_child(node) + self._current_node = node + return True + + def leave_ClassDef(self, original_node: cst.ClassDef) -> None: + self._current_node = self._current_node.parent + + def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: + return False + + def visit_TypeAlias(self, node: cst.TypeAlias) -> bool: + """Collect type alias with 3.12 syntax.""" + node = TypeAliasNode(name=node.name.value) + self._current_node.add_child(node) + return False + + def visit_AnnAssign(self, node: cst.AnnAssign) -> bool: + """Collect type alias annotated with `TypeAlias`.""" + is_type_alias = cstm.matches( + node, + cstm.AnnAssign( + annotation=cstm.Annotation(annotation=cstm.Name(value="TypeAlias")) + ), + ) + if is_type_alias and node.value is not None: + names = cstm.findall(node.target, cstm.Name()) + assert len(names) == 1 + node = Node(name=names[0].value) + self._current_node.add_child(node) + return False + + def visit_ImportFrom(self, node: cst.ImportFrom) -> None: + """Collect "from import" targets as usable types within each module.""" + for import_alias in node.names: + if cstm.matches(import_alias, cstm.ImportStar()): + continue + name = import_alias.evaluated_alias + if name is None: + name = import_alias.evaluated_name + assert isinstance(name, str) + + node = ImportFromNode(name=name) + self._current_node.add_child(node) + + class TypeCollector(cst.CSTVisitor): """Collect types from a given Python file. 
From 58d7037a8462670167c5b6a20535f698546b8a71 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lars=20Gr=C3=BCter?= Date: Sun, 20 Jul 2025 15:40:02 +0200 Subject: [PATCH 2/2] WIP --- src/docstub/_analysis.py | 346 ++++++++++++++++++++----------------- src/docstub/_cli.py | 26 ++- src/docstub/_docstrings.py | 9 +- src/docstub/_utils.py | 4 +- 4 files changed, 213 insertions(+), 172 deletions(-) diff --git a/src/docstub/_analysis.py b/src/docstub/_analysis.py index b1467d8..f8cbd5b 100644 --- a/src/docstub/_analysis.py +++ b/src/docstub/_analysis.py @@ -5,15 +5,17 @@ import json import logging import re -from dataclasses import asdict, dataclass, field +import dataclasses as dc +from itertools import pairwise from functools import cache from pathlib import Path -from typing import Self +from typing import Self, ClassVar import libcst as cst import libcst.matchers as cstm from ._utils import accumulate_qualname, module_name_from_path, pyfile_checksum +from . import __version__ logger = logging.getLogger(__name__) @@ -50,7 +52,7 @@ def _shared_leading_qualname(*qualnames): return ".".join(shared) -@dataclass(slots=True, frozen=True) +@dc.dataclass(slots=True, frozen=True) class KnownImport: """Import information associated with a single known type annotation. 
@@ -209,6 +211,96 @@ def __str__(self) -> str: return out +@dc.dataclass(slots=True, kw_only=True) +class PyNode: + _TYPE_KINDS: ClassVar[set[str]] = { + "builtin", + "class", + "type_alias", + "ann_assign", + "import_from", + "generic_type", + } + _KINDS: ClassVar[set[str]] = {"module"} | _TYPE_KINDS + + name: str + kind: str + loc: str | None = None + parent: Self | None = None + children: list[Self] = dc.field(default_factory=list) + + @property + def is_leaf(self): + return not self.children + + @property + def fullname(self): + names = [node.name for node in self.walk_parents()][::-1] + return ".".join(names + [self.name]) + + @property + def is_type(self): + return self.kind in self._TYPE_KINDS + + @property + def import_statement(self): + module = [] + qualname = [self.name] + for parent in self.walk_parents(): + if parent.kind == "module": + module.insert(0, parent.name) + else: + qualname.insert(0, parent.name) + + if module: + return f"from {'.'.join(module)} import {'.'.join(qualname)}" + else: + return None + + def add_child(self, child): + assert child.parent is None + child.parent = self + self.children.append(child) + + def _walk_tree(self, names=()): + names = names + (self.name,) + yield names, self + for child in self.children: + yield from child._walk_tree(names) + + def walk_tree(self): + yield from self._walk_tree() + + def walk_parents(self): + current = self.parent + while current is not None: + yield current + current = current.parent + + def serialize_tree(self): + raw = {field.name: getattr(self, field.name) for field in dc.fields(self)} + del raw["parent"] + raw["children"] = [child.serialize_tree() for child in self.children] + return raw + + @classmethod + def from_serialized_tree(cls, primitives): + self = cls(**primitives) + if self.parent: + self.parent = cls.from_serialized_tree(self.parent) + self.children = [cls.from_serialized_tree(child) for child in self.children] + return self + + def __repr__(self): + return 
f"{type(self).__name__}({self.name!r}, kind={self.kind!r})" + + def __post_init__(self): + unsupported_kind = {self.kind} - self._KINDS + if unsupported_kind: + msg = f"unsupported kind {unsupported_kind}, supported are {self._KINDS}" + raise ValueError(msg) + + def _is_type(value): """Check if value is a type. @@ -228,28 +320,33 @@ def _is_type(value): def _builtin_types(): - """Return known imports for all builtins (in the current runtime). + """Builtin types in the current runtime. Returns ------- - known_imports : dict[str, KnownImport] + types : dict[str, PyNode] """ - known_builtins = set(dir(builtins)) + builtins_names = set(dir(builtins)) - known_imports = {} - for name in known_builtins: + types = {} + for name in builtins_names: if name.startswith("_"): continue value = getattr(builtins, name) if not _is_type(value): continue - known_imports[name] = KnownImport(builtin_name=name) + types[name] = PyNode(name=name, kind="builtin") - return known_imports + return types def _runtime_types_in_module(module_name): module = importlib.import_module(module_name) + + modules = [PyNode(name=name, kind="module") for name in module_name.split(".")] + for parent, child in pairwise(modules): + parent.add_child(child) + types = {} for name in module.__all__: if name.startswith("_"): @@ -258,14 +355,14 @@ def _runtime_types_in_module(module_name): if not _is_type(value): continue - import_ = KnownImport(import_path=module_name, import_name=name) - types[name] = import_ - types[f"{module_name}.{name}"] = import_ + pynode = PyNode(name=name, kind="generic_type") + modules[-1].add_child(pynode) + types[pynode.fullname] = pynode return types -def common_known_types(): +def common_types_nicknames(): """Return known imports for commonly supported types. 
This includes builtin types, and types from the `typing` or @@ -273,113 +370,43 @@ def common_known_types(): Returns ------- - known_imports : dict[str, KnownImport] + types : list[PyNode] + type_nicknames : dict[str, str] Examples -------- >>> types = common_known_types() >>> types["str"] - + PyNode('str', kind='builtin') >>> types["Iterable"] - + PyNode('Iterable', kind='generic_type') + >>> types["Iterable"].fullname + 'collections.abc.Iterable' >>> types["collections.abc.Iterable"] - + PyNode('Iterable', kind='generic_type') """ - known_imports = _builtin_types() - known_imports |= _runtime_types_in_module("typing") - # Overrides containers from typing - known_imports |= _runtime_types_in_module("collections.abc") - return known_imports - + pynodes = _builtin_types() + pynodes |= _runtime_types_in_module("typing") + collections_abc = _runtime_types_in_module("collections.abc") + pynodes |= collections_abc -class Node: - def __init__(self, name): - self.name = name - self._parent = None - self._children = [] + type_nicknames = {node.name: fullname for fullname, node in collections_abc.items()} - @property - def parent(self): - return self._parent + return pynodes, type_nicknames - @property - def children(self): - return self._children.copy() - @property - def is_leaf(self): - return not self._children - - @property - def fullname(self): - names = [node.name for node in self.walk_up()][::-1] - return ".".join(names) - - def add_child(self, child): - assert child.parent is None - child._parent = self - self._children.append(child) - - def walk_down(self): - yield self - for child in self._children: - yield from child.walk_down() - - def walk_up(self): - current = self - while current.parent is not None: - yield current - current = current.parent - - def __repr__(self): - return f"{type(self).__name__}({self.name!r})" - - -class Tree(Node): - def __init__(self): - super().__init__(name=None) - - def add_child(self, child): - if not isinstance(child, ModuleNode): - 
raise TypeError("expected new child to by a module") - return super().add_child(child) - - def get(self): - - -class ModuleNode(Node): - - def __init__(self, name, file_path): - super().__init__(name=name) - self.file_path = file_path - - -class _InModuleNode(Node): - pass - - -class ClassNode(_InModuleNode): - pass - - -class TypeAliasNode(_InModuleNode): - pass - - -class ImportFromNode(_InModuleNode): - pass - - -class NodeCollector(cst.CSTVisitor): +class PythonCollector(cst.CSTVisitor): """Collect types from a given Python file. Examples -------- - >>> types = NodeCollector.collect(__file__) + >>> types = PythonCollector.collect(__file__) >>> types[f"{__name__}.TypeCollector"] """ + METADATA_DEPENDENCIES = (cst.metadata.PositionProvider,) + class ImportSerializer: """Implements the `FuncSerializer` protocol to cache `TypeCollector.collect`.""" @@ -388,20 +415,20 @@ class ImportSerializer: def hash_args(self, path: Path) -> str: """Compute a unique hash from the path passed to `TypeCollector.collect`.""" - key = pyfile_checksum(path) + key = pyfile_checksum(path, salt=__version__) return key - def serialize(self, data: dict[str, KnownImport]) -> bytes: + def serialize(self, pynode: PyNode) -> bytes: """Serialize results from `TypeCollector.collect`.""" - primitives = {qualname: asdict(imp) for qualname, imp in data.items()} + primitives = pynode.serialize_tree() raw = json.dumps(primitives, separators=(",", ":")).encode(self.encoding) return raw - def deserialize(self, raw: bytes) -> dict[str, KnownImport]: + def deserialize(self, raw: bytes) -> PyNode: """Deserialize results from `TypeCollector.collect`.""" primitives = json.loads(raw.decode(self.encoding)) - data = {qualname: KnownImport(**kw) for qualname, kw in primitives.items()} - return data + pynode = PyNode.from_serialized_tree(primitives) + return pynode @classmethod def collect(cls, file_path): @@ -409,20 +436,22 @@ def collect(cls, file_path): Parameters ---------- - file : Path + file_path : Path 
Returns ------- - collected : dict[str, KnownImport] + module_tree : PyNode """ file_path = Path(file_path) with file_path.open("r") as fo: source = fo.read() tree = cst.parse_module(source) + meta_tree = cst.metadata.MetadataWrapper(tree) collector = cls(file_path=file_path) - tree.visit(collector) - return collector._root_node + meta_tree.visit(collector) + + return collector._root_pynode def __init__(self, *, file_path): """Initialize type collector. @@ -431,26 +460,44 @@ def __init__(self, *, file_path): ---------- module_name : str """ - assert "." not in file_path.stem - self._root_node = ModuleNode(name=file_path.stem, file_path=file_path) - self._current_node = self._root_node + full_module_name = module_name_from_path(file_path) + current_module, *parent_modules = full_module_name.split(".")[::-1] + + self._file_path = file_path + self._root_pynode = PyNode( + name=current_module, kind="module", loc=str(file_path) + ) + self._current_pynode = self._root_pynode + + for name in parent_modules: + # TODO set location for parent modules too + parent = PyNode(name=name, kind="module") + parent.add_child(self._root_pynode) + self._root_pynode = parent + + def _get_loc(self, node): + pos = self.get_metadata(cst.metadata.PositionProvider, node).start + loc = f"{self._file_path}:{pos.line}:{pos.column}" + return loc def visit_ClassDef(self, node: cst.ClassDef) -> bool: - node = ClassNode(name=node.name.value) - self._current_node.add_child(node) - self._current_node = node + pynode = PyNode(name=node.name.value, kind="class", loc=self._get_loc(node)) + self._current_pynode.add_child(pynode) + self._current_pynode = pynode return True def leave_ClassDef(self, original_node: cst.ClassDef) -> None: - self._current_node = self._current_node.parent + self._current_pynode = self._current_pynode.parent def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: return False def visit_TypeAlias(self, node: cst.TypeAlias) -> bool: """Collect type alias with 3.12 syntax.""" - 
node = TypeAliasNode(name=node.name.value) - self._current_node.add_child(node) + pynode = PyNode( + name=node.name.value, kind="type_alias", loc=self._get_loc(node) + ) + self._current_pynode.add_child(pynode) return False def visit_AnnAssign(self, node: cst.AnnAssign) -> bool: @@ -464,8 +511,10 @@ def visit_AnnAssign(self, node: cst.AnnAssign) -> bool: if is_type_alias and node.value is not None: names = cstm.findall(node.target, cstm.Name()) assert len(names) == 1 - node = Node(name=names[0].value) - self._current_node.add_child(node) + pynode = PyNode( + name=names[0].value, kind="ann_assign", loc=self._get_loc(node) + ) + self._current_pynode.add_child(pynode) return False def visit_ImportFrom(self, node: cst.ImportFrom) -> None: @@ -478,8 +527,8 @@ def visit_ImportFrom(self, node: cst.ImportFrom) -> None: name = import_alias.evaluated_name assert isinstance(name, str) - node = ImportFromNode(name=name) - self._current_node.add_child(node) + pynode = PyNode(name=name, kind="import_from", loc=self._get_loc(node)) + self._current_pynode.add_child(pynode) class TypeCollector(cst.CSTVisitor): @@ -597,7 +646,7 @@ class TypeMatcher: Attributes ---------- - types : dict[str, KnownImport] + types : dict[str, PyNode] type_prefixes : dict[str, KnownImport] type_nicknames : dict[str, str] successful_queries : int @@ -606,7 +655,7 @@ class TypeMatcher: Examples -------- - >>> from docstub._analysis import TypeMatcher, common_known_types + >>> from docstub._analysis import TypeMatcher >>> db = TypeMatcher() >>> db.match("Any") ('Any', ) @@ -626,12 +675,11 @@ def __init__( type_prefixes : dict[str, KnownImport] type_nicknames : dict[str, str] """ - self.types = types or common_known_types() + self.types = types or {} self.type_prefixes = type_prefixes or {} self.type_nicknames = type_nicknames or {} self.successful_queries = 0 self.unknown_qualnames = [] - self.current_module = None def match(self, search_name): @@ -644,11 +692,9 @@ def match(self, search_name): Returns 
        -------
-        type_name : str | None
-        type_origin : KnownImport | None
+        pynode : PyNode | None
         """
-        type_name = None
-        type_origin = None
+        pynode = None
 
         if search_name.startswith("~."):
             # Sphinx like matching with abbreviated name
             pattern = re.compile(search_name.replace(".", r"\.").replace("~", ".*"))
             matches = {
                 key: value for key, value in self.types.items() if pattern.match(key)
             }
             if len(matches) > 1:
                 shortest_key = sorted(matches.keys(), key=lambda x: len(x))[0]
-                type_origin = matches[shortest_key]
-                type_name = shortest_key
+                pynode = matches[shortest_key]
                 logger.warning(
                     "%r in %s matches multiple types %r, using %r",
                     search_name,
@@ -671,7 +716,7 @@
                     shortest_key,
                 )
             elif len(matches) == 1:
-                type_name, type_origin = matches.popitem()
+                _, pynode = matches.popitem()
             else:
                 search_name = search_name[2:]
                 logger.debug(
@@ -683,38 +728,25 @@
         # Replace alias
         search_name = self.type_nicknames.get(search_name, search_name)
 
-        if type_origin is None and self.current_module:
+        if pynode is None and self.current_module:
             # Try scope of current module
             module_name = module_name_from_path(self.current_module)
             try_qualname = f"{module_name}.{search_name}"
-            type_origin = self.types.get(try_qualname)
-            if type_origin:
-                type_name = search_name
+            pynode = self.types.get(try_qualname)
 
-        if type_origin is None and search_name in self.types:
-            type_name = search_name
-            type_origin = self.types[search_name]
+        if pynode is None and search_name in self.types:
+            pynode = self.types[search_name]
 
-        if type_origin is None:
+        if pynode is None:
             # Try a subset of the qualname (first 'a.b.c', then 'a.b' and 'a')
             for partial_qualname in reversed(accumulate_qualname(search_name)):
-                type_origin = self.type_prefixes.get(partial_qualname)
-                if type_origin:
-                    type_name = search_name
+                pynode = self.type_prefixes.get(partial_qualname)
+                if pynode:
                     break
 
-        if (
-            type_origin is not None
-            and type_name is not None
-            and type_name != type_origin.target
-            and not type_name.startswith(type_origin.target)
-        ):
-            # Ensure that the annotation 
matches the import target - type_name = type_name[type_name.find(type_origin.target) :] - - if type_name is not None: + if pynode is not None: self.successful_queries += 1 else: self.unknown_qualnames.append(search_name) - return type_name, type_origin + return pynode diff --git a/src/docstub/_cli.py b/src/docstub/_cli.py index 3bcbf76..d1aef03 100644 --- a/src/docstub/_cli.py +++ b/src/docstub/_cli.py @@ -9,9 +9,9 @@ from ._analysis import ( KnownImport, - TypeCollector, + PythonCollector, TypeMatcher, - common_known_types, + common_types_nicknames, ) from ._cache import FileCache from ._config import Config @@ -93,20 +93,26 @@ def _collect_types(root_path, *, ignore=()): Returns ------- - types : dict[str, ~.KnownImport] + types : dict[str, ~.PyNode] """ - types = common_known_types() + types = {} collect_cached_types = FileCache( - func=TypeCollector.collect, - serializer=TypeCollector.ImportSerializer(), + func=PythonCollector.collect, + serializer=PythonCollector.ImportSerializer(), cache_dir=Path.cwd() / ".docstub_cache", name=f"{__version__}/collected_types", ) if root_path.is_dir(): for source_path in walk_python_package(root_path, ignore=ignore): logger.info("collecting types in %s", source_path) - types_in_source = collect_cached_types(source_path) + + module_tree = collect_cached_types(source_path) + types_in_source = { + ".".join(fullname): pynode + for fullname, pynode in module_tree.walk_tree() + if pynode.is_type + } types.update(types_in_source) return types @@ -228,7 +234,7 @@ def run(root_path, out_dir, config_paths, ignore, group_errors, allow_errors, ve config = _load_configuration(config_paths) config = config.merge(Config(ignore_files=list(ignore))) - types = common_known_types() + types, type_nicknames = common_types_nicknames() types |= _collect_types(root_path, ignore=config.ignore_files) types |= { type_name: KnownImport(import_path=module, import_name=type_name) @@ -244,9 +250,11 @@ def run(root_path, out_dir, config_paths, ignore, 
group_errors, allow_errors, ve for prefix, module in config.type_prefixes.items() } + type_nicknames |= config.type_nicknames + reporter = GroupedErrorReporter() if group_errors else ErrorReporter() matcher = TypeMatcher( - types=types, type_prefixes=type_prefixes, type_nicknames=config.type_nicknames + types=types, type_prefixes=type_prefixes, type_nicknames=type_nicknames ) stub_transformer = Py2StubTransformer(matcher=matcher, reporter=reporter) diff --git a/src/docstub/_docstrings.py b/src/docstub/_docstrings.py index 1591cd4..3ee20e4 100644 --- a/src/docstub/_docstrings.py +++ b/src/docstub/_docstrings.py @@ -524,13 +524,14 @@ def _match_import(self, qualname, *, meta): Possibly modified or normalized qualname. """ if self.matcher is not None: - annotation_name, known_import = self.matcher.match(qualname) + pynode = self.matcher.match(qualname) + annotation_name = pynode.fullname else: annotation_name = None - known_import = None + pynode = None - if known_import and known_import.has_import: - self._collected_imports.add(known_import) + if pynode and pynode.import_statement: + self._collected_imports.add(pynode.import_statement) if annotation_name: matched_qualname = annotation_name diff --git a/src/docstub/_utils.py b/src/docstub/_utils.py index bbd55bd..0ed2a42 100644 --- a/src/docstub/_utils.py +++ b/src/docstub/_utils.py @@ -106,7 +106,7 @@ def module_name_from_path(path): return name -def pyfile_checksum(path): +def pyfile_checksum(path, salt=""): """Compute a unique key for a Python file. The key takes into account the given `path`, the relative position if the @@ -124,7 +124,7 @@ def pyfile_checksum(path): absolute_path = str(path.resolve()).encode() with open(path, "rb") as fp: content = fp.read() - key = crc32(content + module_name + absolute_path) + key = crc32(content + module_name + absolute_path + salt.encode()) return key