diff --git a/pyproject.toml b/pyproject.toml index ccaa9d2..ccb3fb8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ name = "scyjava" version = "1.12.2.dev0" description = "Supercharged Java access from Python" license = "Unlicense" -authors = [{name = "SciJava developers", email = "ctrueden@wisc.edu"}] +authors = [{ name = "SciJava developers", email = "ctrueden@wisc.edu" }] readme = "README.md" keywords = ["java", "maven", "cross-language"] classifiers = [ @@ -35,6 +35,7 @@ dependencies = [ "jpype1 >= 1.3.0", "jgo", "cjdk", + "stubgenj", ] [dependency-groups] @@ -50,6 +51,9 @@ dev = [ "validate-pyproject[all]", ] +[project.scripts] +scyjava-stubgen = "scyjava._stubs._cli:main" + [project.urls] homepage = "https://github.com/scijava/scyjava" documentation = "https://github.com/scijava/scyjava/blob/main/README.md" @@ -58,7 +62,7 @@ download = "https://pypi.org/project/scyjava/" tracker = "https://github.com/scijava/scyjava/issues" [tool.setuptools] -package-dir = {"" = "src"} +package-dir = { "" = "src" } include-package-data = false [tool.setuptools.packages.find] diff --git a/src/scyjava/_jvm.py b/src/scyjava/_jvm.py index 224ac61..2f0e4a8 100644 --- a/src/scyjava/_jvm.py +++ b/src/scyjava/_jvm.py @@ -363,7 +363,7 @@ def is_awt_initialized() -> bool: return False Thread = scyjava.jimport("java.lang.Thread") threads = Thread.getAllStackTraces().keySet() - return any(t.getName().startsWith("AWT-") for t in threads) + return any(str(t.getName()).startswith("AWT-") for t in threads) def when_jvm_starts(f) -> None: diff --git a/src/scyjava/_stubs/__init__.py b/src/scyjava/_stubs/__init__.py new file mode 100644 index 0000000..d6a5e7c --- /dev/null +++ b/src/scyjava/_stubs/__init__.py @@ -0,0 +1,4 @@ +from ._dynamic_import import setup_java_imports +from ._genstubs import generate_stubs + +__all__ = ["setup_java_imports", "generate_stubs"] diff --git a/src/scyjava/_stubs/_cli.py b/src/scyjava/_stubs/_cli.py new file mode 100644 index 0000000..7936529 --- /dev/null +++ b/src/scyjava/_stubs/_cli.py @@ -0,0 +1,162 @@ +"""The scyjava-stubs executable. + +Provides cli access to the `scyjava._stubs.generate_stubs` function. + +The only interesting additional things going on here is the choice of *where* the stubs +go by default. When using the CLI, they land in `scyjava.types` by default; see the +`_get_ouput_dir` helper function for details on how the output directory is resolved +from the CLI arguments. +""" + +from __future__ import annotations + +import argparse +import importlib +import importlib.util +import logging +import sys +from pathlib import Path + +from ._genstubs import generate_stubs + + +def main() -> None: + """The main entry point for the scyjava-stubs executable.""" + logging.basicConfig(level="INFO") + parser = argparse.ArgumentParser( + description="Generate Python Type Stubs for Java classes." + ) + parser.add_argument( + "endpoints", + type=str, + nargs="+", + help="Maven endpoints to install and use (e.g. org.myproject:myproject:1.0.0)", + ) + parser.add_argument( + "--prefix", + type=str, + help="package prefixes to generate stubs for (e.g. org.myproject), " + "may be used multiple times. If not specified, prefixes are gleaned from the " + "downloaded artifacts.", + action="append", + default=[], + metavar="PREFIX", + dest="prefix", + ) + path_group = parser.add_mutually_exclusive_group() + path_group.add_argument( + "--output-dir", + type=str, + default=None, + help="Filesystem path to write stubs to.", + ) + path_group.add_argument( + "--output-python-path", + type=str, + default=None, + help="Python path to write stubs to (e.g. 'scyjava.types').", + ) + parser.add_argument( + "--convert-strings", + dest="convert_strings", + action="store_true", + default=False, + help="convert java.lang.String to python str in return types. " + "consult the JPype documentation on the convertStrings flag for details", + ) + parser.add_argument( + "--no-javadoc", + dest="with_javadoc", + action="store_false", + default=True, + help="do not generate docstrings from JavaDoc where available", + ) + + rt_group = parser.add_mutually_exclusive_group() + rt_group.add_argument( + "--runtime-imports", + dest="runtime_imports", + action="store_true", + default=True, + help="Add runtime imports to the generated stubs. ", + ) + rt_group.add_argument( + "--no-runtime-imports", dest="runtime_imports", action="store_false" + ) + + parser.add_argument( + "--remove-namespace-only-stubs", + dest="remove_namespace_only_stubs", + action="store_true", + default=False, + help="Remove stubs that export no names beyond a single __module_protocol__. " + "This leaves some folders as PEP420 implicit namespace folders.", + ) + + if len(sys.argv) == 1: + parser.print_help() + sys.exit(1) + + args = parser.parse_args() + output_dir = _get_ouput_dir(args.output_dir, args.output_python_path) + if not output_dir.exists(): + output_dir.mkdir(parents=True, exist_ok=True) + + generate_stubs( + endpoints=args.endpoints, + prefixes=args.prefix, + output_dir=output_dir, + convert_strings=args.convert_strings, + include_javadoc=args.with_javadoc, + add_runtime_imports=args.runtime_imports, + remove_namespace_only_stubs=args.remove_namespace_only_stubs, + ) + + +def _get_ouput_dir(output_dir: str | None, python_path: str | None) -> Path: + if out_dir := output_dir: + return Path(out_dir) + if pp := python_path: + return _glean_path(pp) + try: + import scyjava + + return Path(scyjava.__file__).parent / "types" + except ImportError: + return Path("stubs") + + +def _glean_path(pp: str) -> Path: + try: + importlib.import_module(pp.split(".")[0]) + except ModuleNotFoundError: + # the top level module doesn't exist: + raise ValueError(f"Module {pp} does not exist. Cannot install stubs there.") + + try: + spec = importlib.util.find_spec(pp) + except ModuleNotFoundError as e: + # at least one of the middle levels doesn't exist: + raise NotImplementedError(f"Cannot install stubs to {pp}: {e}") + + new_ns = None + if not spec: + # if we get here, it means everything but the last level exists: + parent, new_ns = pp.rsplit(".", 1) + spec = importlib.util.find_spec(parent) + + if not spec: + # if we get here, it means the last level doesn't exist: + raise ValueError(f"Module {pp} does not exist. Cannot install stubs there.") + + search_locations = spec.submodule_search_locations + if not spec.loader and search_locations: + # namespace package with submodules + return Path(search_locations[0]) + if spec.origin: + return Path(spec.origin).parent + if new_ns and search_locations: + # namespace package with submodules + return Path(search_locations[0]) / new_ns + + raise ValueError(f"Error finding module {pp}. Cannot install stubs there.") diff --git a/src/scyjava/_stubs/_dynamic_import.py b/src/scyjava/_stubs/_dynamic_import.py new file mode 100644 index 0000000..16e27d4 --- /dev/null +++ b/src/scyjava/_stubs/_dynamic_import.py @@ -0,0 +1,131 @@ +"""Logic for using generated type stubs as runtime importable, with lazy JVM startup. + +Most often, the functionality here will be used as follows: + +``` +from scyjava._stubs import setup_java_imports + +__all__, __getattr__ = setup_java_imports( + __name__, + __file__, + endpoints=["org.scijava:parsington:3.1.0"], + base_prefix="org" +) +``` + +...and that little snippet is written into the generated stubs modules by the +`scyjava._stubs.generate_stubs` function. + +See docstring of `setup_java_imports` for details on how it works. +""" + +import ast +from logging import warning +from pathlib import Path +from typing import Any, Callable, Sequence + + +def setup_java_imports( + module_name: str, + module_file: str, + endpoints: Sequence[str] = (), + base_prefix: str = "", +) -> tuple[list[str], Callable[[str], Any]]: + """Setup a module to dynamically import Java class names. + + This function creates a `__getattr__` function that, when called, will dynamically + import the requested class from the Java namespace corresponding to the calling + module. + + :param module_name: The dotted name/identifier of the module that is calling this + function (usually `__name__` in the calling module). + :param module_file: The path to the module file (usually `__file__` in the calling + module). + :param endpoints: A list of Java endpoints to add to the scyjava configuration. + (Note that `scyjava._stubs.generate_stubs` will automatically add the necessary + endpoints for the generated stubs.) + :param base_prefix: The base prefix for the Java package name. This is used when + determining the Java class path for the requested class. The java class path + will be truncated to only the part including the base_prefix and after. This + makes it possible to embed a module in a subpackage (like `scyjava.types`) and + still have the correct Java class path. + + :return: A 2-tuple containing: + - A list of all classes in the module (as defined in the stub file), to be + assigned to `__all__`. + - A callable that takes a class name and returns a proxy for the Java class. + This callable should be assigned to `__getattr__` in the calling module. + The proxy object, when called, will start the JVM, import the Java class, + and return an instance of the class. The JVM will *only* be started when + the object is called. + + Example: + If the module calling this function is named `scyjava.types.org.scijava.parsington`, + then it should invoke this function as: + + .. code-block:: python + + from scyjava._stubs import setup_java_imports + + __all__, __getattr__ = setup_java_imports( + __name__, + __file__, + endpoints=["org.scijava:parsington:3.1.0"], + base_prefix="org" + ) + """ + import scyjava + import scyjava.config + + for ep in endpoints: + if ep not in scyjava.config.endpoints: + scyjava.config.endpoints.append(ep) + + # list intended to be assigned to `__all__` in the generated module. + module_all = [] + try: + my_stub = Path(module_file).with_suffix(".pyi") + stub_ast = ast.parse(my_stub.read_text()) + module_all = sorted( + { + node.name + for node in stub_ast.body + if isinstance(node, ast.ClassDef) and not node.name.startswith("__") + } + ) + except (OSError, SyntaxError): + warning( + f"Failed to read stub file {my_stub!r}. Falling back to empty __all__.", + stacklevel=3, + ) + + def module_getattr(name: str, mod_name: str = module_name) -> Any: + """Function intended to be assigned to __getattr__ in the generate module.""" + if module_all and name not in module_all: + raise AttributeError(f"module {module_name!r} has no attribute {name!r}") + + # cut the mod_name to only the part including the base_prefix and after + if base_prefix in mod_name: + mod_name = mod_name[mod_name.index(base_prefix) :] + + class_path = f"{mod_name}.{name}" + + # Generate a proxy type (with a nice repr) that + # delays the call to `jimport` until the last moment when type.__new__ is called + + class ProxyMeta(type): + def __repr__(self) -> str: + return f"" + + class Proxy(metaclass=ProxyMeta): + def __new__(_cls_, *args: Any, **kwargs: Any) -> Any: + cls = scyjava.jimport(class_path) + return cls(*args, **kwargs) + + Proxy.__name__ = name + Proxy.__qualname__ = name + Proxy.__module__ = module_name + Proxy.__doc__ = f"Proxy for {class_path}" + return Proxy + + return module_all, module_getattr diff --git a/src/scyjava/_stubs/_genstubs.py b/src/scyjava/_stubs/_genstubs.py new file mode 100644 index 0000000..3df24cb --- /dev/null +++ b/src/scyjava/_stubs/_genstubs.py @@ -0,0 +1,225 @@ +"""Type stub generation utilities using stubgen. + +This module provides utilities for generating type stubs for Java classes +using the stubgenj library. `stubgenj` must be installed for this to work +(it, in turn, only depends on JPype). + +See `generate_stubs` for most functionality. For the command-line tool, +see `scyjava._stubs.cli`, which provides a CLI interface for the `generate_stubs` +function. +""" + +from __future__ import annotations + +import ast +import logging +import os +import shutil +import subprocess +import sys +from importlib import import_module +from itertools import chain +from pathlib import Path, PurePath +from typing import TYPE_CHECKING, Any +from unittest.mock import patch +from zipfile import ZipFile + +import scyjava +import scyjava.config + +if TYPE_CHECKING: + from collections.abc import Sequence + +logger = logging.getLogger(__name__) + + +def generate_stubs( + endpoints: Sequence[str], + prefixes: Sequence[str] = (), + output_dir: str | Path = "stubs", + convert_strings: bool = True, + include_javadoc: bool = True, + add_runtime_imports: bool = True, + remove_namespace_only_stubs: bool = False, +) -> None: + """Generate stubs for the given maven endpoints. + + Parameters + ---------- + endpoints : Sequence[str] + The maven endpoints to generate stubs for. This should be a list of GAV + coordinates, e.g. ["org.apache.commons:commons-lang3:3.12.0"]. + prefixes : Sequence[str], optional + The prefixes to generate stubs for. This should be a list of Java class + prefixes that you expect to find in the endpoints. For example, + ["org.apache.commons"]. If not provided, the prefixes will be + automatically determined from the jar files provided by endpoints (see the + `_list_top_level_packages` helper function). + output_dir : str | Path, optional + The directory to write the generated stubs to. Defaults to "stubs" in the + current working directory. + convert_strings : bool, optional + Whether to cast Java strings to Python strings in the stubs. Defaults to True. + NOTE: This leads to type stubs that may not be strictly accurate at runtime. + The actual runtime type of strings is determined by whether jpype.startJVM is + called with the `convertStrings` argument set to True or False. By setting + this `convert_strings` argument to true, the type stubs will be generated as if + `convertStrings` is set to True: that is, all string types will be listed as + `str` rather than `java.lang.String | str`. This is a safer default (as `str`) + is a subtype of `java.lang.String`), but may lead to type errors in some cases. + include_javadoc : bool, optional + Whether to include Javadoc in the generated stubs. Defaults to True. + add_runtime_imports : bool, optional + Whether to add runtime imports to the generated stubs. Defaults to True. + This is useful if you want to actually import the stubs as a runtime package + with type safety. The runtime import "magic" depends on the + `scyjava._stubs.setup_java_imports` function. See its documentation for + more details. + remove_namespace_only_stubs : bool, optional + Whether to remove stubs that export no names beyond a single + `__module_protocol__`. This leaves some folders as PEP420 implicit namespace + folders. Defaults to False. Setting this to `True` is useful if you want to + merge the generated stubs with other stubs in the same namespace. Without this, + the `__init__.pyi` for any given module will be whatever whatever the *last* + stub generator wrote to it (and therefore inaccurate). + """ + try: + import stubgenj + except ImportError as e: + raise ImportError( + "stubgenj is not installed, but is required to generate java stubs. " + "Please install it with `pip/conda install stubgenj`." + ) from e + + import jpype + + startJVM = jpype.startJVM + + scyjava.config.endpoints.extend(endpoints) + + def _patched_start(*args: Any, **kwargs: Any) -> None: + kwargs.setdefault("convertStrings", convert_strings) + startJVM(*args, **kwargs) + + with patch.object(jpype, "startJVM", new=_patched_start): + scyjava.start_jvm() + + _prefixes = set(prefixes) + if not _prefixes: + cp = jpype.getClassPath(env=False) + ep_artifacts = tuple(ep.split(":")[1] for ep in endpoints) + for j in cp.split(os.pathsep): + if Path(j).name.startswith(ep_artifacts): + _prefixes.update(_list_top_level_packages(j)) + + prefixes = sorted(_prefixes) + logger.info(f"Using endpoints: {scyjava.config.endpoints!r}") + logger.info(f"Generating stubs for: {prefixes}") + logger.info(f"Writing stubs to: {output_dir}") + + metapath = sys.meta_path + try: + import jpype.imports + + jmodules = [import_module(prefix) for prefix in prefixes] + finally: + # remove the jpype.imports magic from the import system + # if it wasn't there to begin with + sys.meta_path = metapath + + stubgenj.generateJavaStubs( + jmodules, + useStubsSuffix=False, + outputDir=str(output_dir), + jpypeJPackageStubs=False, + includeJavadoc=include_javadoc, + ) + + output_dir = Path(output_dir) + if add_runtime_imports: + logger.info("Adding runtime imports to generated stubs") + + for stub in output_dir.rglob("*.pyi"): + stub_ast = ast.parse(stub.read_text()) + members = {node.name for node in stub_ast.body if hasattr(node, "name")} + if members == {"__module_protocol__"}: + # this is simply a module stub... no exports + if remove_namespace_only_stubs: + logger.info("Removing namespace only stub %s", stub) + stub.unlink() + continue + if add_runtime_imports: + real_import = stub.with_suffix(".py") + base_prefix = stub.relative_to(output_dir).parts[0] + real_import.write_text( + INIT_TEMPLATE.format( + endpoints=repr(endpoints), + base_prefix=repr(base_prefix), + ) + ) + + ruff_check(output_dir.absolute()) + + +# the "real" init file that goes into the stub package +INIT_TEMPLATE = """\ +# this file was autogenerated by scyjava-stubgen +# it creates a __getattr__ function that will dynamically import +# the requested class from the Java namespace corresponding to this module. +# see scyjava._stubs for implementation details. +from scyjava._stubs import setup_java_imports + +__all__, __getattr__ = setup_java_imports( + __name__, + __file__, + endpoints={endpoints}, + base_prefix={base_prefix}, +) +""" + + +def ruff_check(output: Path, select: str = "E,W,F,I,UP,C4,B,RUF,TC,TID") -> None: + """Run ruff check and format on the generated stubs.""" + if not shutil.which("ruff"): + return + + py_files = [str(x) for x in chain(output.rglob("*.py"), output.rglob("*.pyi"))] + logger.info( + "Running ruff check on %d generated stubs in % s", + len(py_files), + str(output), + ) + subprocess.run( + [ + "ruff", + "check", + *py_files, + "--quiet", + "--fix-only", + "--unsafe-fixes", + f"--select={select}", + ] + ) + logger.info("Running ruff format") + subprocess.run(["ruff", "format", *py_files, "--quiet"]) + + +def _list_top_level_packages(jar_path: str) -> set[str]: + """Inspect a JAR file and return the set of top-level Java package names.""" + packages: set[str] = set() + with ZipFile(jar_path, "r") as jar: + # find all classes + class_dirs = { + entry.parent + for x in jar.namelist() + if (entry := PurePath(x)).suffix == ".class" + } + + roots: set[PurePath] = set() + for p in sorted(class_dirs, key=lambda p: len(p.parts)): + # If none of the already accepted roots is a parent of p, keep p + if not any(root in p.parents for root in roots): + roots.add(p) + packages.update({str(p).replace(os.sep, ".") for p in roots}) + + return packages diff --git a/src/scyjava/types/.gitignore b/src/scyjava/types/.gitignore new file mode 100644 index 0000000..5e7d273 --- /dev/null +++ b/src/scyjava/types/.gitignore @@ -0,0 +1,4 @@ +# Ignore everything in this directory +* +# Except this file +!.gitignore diff --git a/tests/test_stubgen.py b/tests/test_stubgen.py new file mode 100644 index 0000000..ad4a74a --- /dev/null +++ b/tests/test_stubgen.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +import sys +from typing import TYPE_CHECKING +from unittest.mock import patch + +import jpype +import pytest + +import scyjava +from scyjava._stubs import _cli + +if TYPE_CHECKING: + from pathlib import Path + + +@pytest.mark.skipif( + scyjava.config.mode != scyjava.config.Mode.JPYPE, + reason="Stubgen not supported in JEP", +) +def test_stubgen(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: + # run the stubgen command as if it was run from the command line + monkeypatch.setattr( + sys, + "argv", + [ + "scyjava-stubgen", + "org.scijava:parsington:3.1.0", + "--output-dir", + str(tmp_path), + ], + ) + _cli.main() + + # remove the `jpype.imports` magic from the import system if present + mp = [x for x in sys.meta_path if not isinstance(x, jpype.imports._JImportLoader)] + monkeypatch.setattr(sys, "meta_path", mp) + + # add tmp_path to the import path + monkeypatch.setattr(sys, "path", [str(tmp_path)]) + + # first cleanup to make sure we are not importing from the cache + sys.modules.pop("org", None) + sys.modules.pop("org.scijava", None) + sys.modules.pop("org.scijava.parsington", None) + # make sure the stubgen command works and that we can now impmort stuff + + with patch.object(scyjava._jvm, "start_jvm") as mock_start_jvm: + from org.scijava.parsington import Function + + assert Function is not None + # ensure that no calls to start_jvm were made + mock_start_jvm.assert_not_called() + + # only after instantiating the class should we have a call to start_jvm + func = Function(1) + mock_start_jvm.assert_called_once() + assert isinstance(func, jpype.JObject)