diff --git a/CHANGELOG.rst b/CHANGELOG.rst index e6244477..4933583c 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -15,6 +15,11 @@ paths are considered internals and can change in minor and patch releases. v4.43.0 (unreleased) -------------------- +Added +^^^^^ +- Non-public experimental mixin class to override init defaults and instantiate + from config (`#800 `__). + Fixed ^^^^^ - Prevent extra environment variables in helptext when default_env=True, for diff --git a/jsonargparse/__init__.py b/jsonargparse/__init__.py index 53ad4f05..e3d3031a 100644 --- a/jsonargparse/__init__.py +++ b/jsonargparse/__init__.py @@ -14,6 +14,7 @@ from ._core import * # noqa: F403 from ._deprecated import * # noqa: F403 from ._formatters import * # noqa: F403 +from ._from_config import * # noqa: F403 from ._jsonnet import * # noqa: F403 from ._jsonschema import * # noqa: F403 from ._link_arguments import * # noqa: F403 diff --git a/jsonargparse/_from_config.py b/jsonargparse/_from_config.py new file mode 100644 index 00000000..3c8d62cf --- /dev/null +++ b/jsonargparse/_from_config.py @@ -0,0 +1,142 @@ +import inspect +from functools import wraps +from os import PathLike +from pathlib import Path +from typing import Optional, Type, TypeVar, Union + +from ._core import ArgumentParser + +__all__ = ["FromConfigMixin"] + +T = TypeVar("T") + + +class FromConfigMixin: + """Mixin class that adds from config support to classes. + + This mixin does two things: + + 1. Adds support for overriding ``__init__`` defaults by defining a + ``__from_config_init_defaults__`` class attribute pointing to a + config file path. The overriding of defaults happens on subclass + creation time. Inspecting the signature will give the overridden + defaults. + + 2. Adds a ``from_config`` ``@classmethod``, that instantiates the class + based on a config file or dict. + + Attributes: + ``__from_config_init_defaults__``: Optional path to a config file for + overriding ``__init__`` defaults. + ``__from_config_parser_kwargs__``: Additional kwargs to pass to the + ArgumentParser used for parsing configs. + """ + + __from_config_init_defaults__: Optional[Union[str, PathLike]] = None + __from_config_parser_kwargs__: dict = {} + + def __init_subclass__(cls, **kwargs) -> None: + """Override ``__init__`` defaults for the subclass based on a config file.""" + super().__init_subclass__(**kwargs) + _override_init_defaults(cls, cls.__from_config_parser_kwargs__) + + @classmethod + def from_config(cls: Type[T], config: Union[str, PathLike, dict]) -> T: + """Instantiate current class based on a config file or dict. + + Args: + config: Path to a config file or a dict with config values. + """ + kwargs = _parse_class_kwargs_from_config(cls, config, **cls.__from_config_parser_kwargs__) # type: ignore[attr-defined] + return cls(**kwargs) + + +def _parse_class_kwargs_from_config(cls: Type[T], config: Union[str, PathLike, dict], **kwargs) -> dict: + """Parse the init kwargs for `cls` from a config file or dict.""" + parser = ArgumentParser(exit_on_error=False, **kwargs) + parser.add_class_arguments(cls) + if isinstance(config, dict): + cfg = parser.parse_object(config, defaults=False) + else: + cfg = parser.parse_path(config, defaults=False) + return parser.instantiate_classes(cfg).as_dict() + + +def _override_init_defaults(cls: Type[T], parser_kwargs: dict) -> None: + """Override ``__init__`` defaults for ``cls`` based on ``__from_config_init_defaults__``.""" + config = getattr(cls, "__from_config_init_defaults__", None) + if not isinstance(config, (str, PathLike, type(None))): + raise TypeError("__from_config_init_defaults__ must be str, PathLike, or None") + if not (isinstance(config, (str, PathLike)) and Path(config).is_file()): + return + + defaults = _parse_class_kwargs_from_config(cls, config, **parser_kwargs) + _override_init_defaults_this_class(cls, defaults) + _override_init_defaults_parent_classes(cls, defaults) + + +def _override_init_defaults_this_class(cls: Type[T], defaults: dict) -> None: + params = inspect.signature(cls.__init__).parameters + for name, default in defaults.copy().items(): + param = params.get(name) + if param and param.kind in {inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD}: + defaults.pop(name) + if param.kind == inspect.Parameter.KEYWORD_ONLY: + cls.__init__.__kwdefaults__[name] = default # type: ignore[index] + else: + index = list(params).index(name) - 1 + aux = cls.__init__.__defaults__ or () + cls.__init__.__defaults__ = aux[:index] + (default,) + aux[index + 1 :] + + +def _override_init_defaults_parent_classes(cls: Type[T], defaults: dict) -> None: + # Gather defaults for parameters in parent classes' __init__ + override_parent_params = [] + for base in inspect.getmro(cls)[1:]: + if not defaults: + break + + params = inspect.signature(base.__init__).parameters # type: ignore[misc] + names = [name for name in defaults if name in params] + for name in names: + new_param = inspect.Parameter( + name=name, + kind=inspect.Parameter.KEYWORD_ONLY, + default=defaults.pop(name), + annotation=params[name].annotation, + ) + override_parent_params.append(new_param) + + if not override_parent_params: + return + + # Override defaults for parameters in parent classes' __init__ via a wrapper + original_init = cls.__init__ + original_sig = inspect.signature(cls.__init__) + parameters = list(original_sig.parameters.values()) + + # Find and pop the **kwargs parameter, if it exists + kwargs_param = None + if parameters and parameters[-1].kind == inspect.Parameter.VAR_KEYWORD: + kwargs_param = parameters.pop() + + # Add new parameters + for param in reversed(override_parent_params): + parameters.append(param) + + # Add **kwargs back at the end + if kwargs_param: + parameters.append(kwargs_param) + + # Create and set __init__ wrapper with new signature + parent_defaults = {p.name: p.default for p in override_parent_params} + + @wraps(original_init) + def wrapper(*args, **kwargs): + for name, default in parent_defaults.items(): + if name not in kwargs: + kwargs[name] = default + return original_init(*args, **kwargs) + + wrapper.__signature__ = original_sig.replace(parameters=parameters) # type: ignore[attr-defined] + cls.__init__ = wrapper # type: ignore[method-assign] diff --git a/jsonargparse/_optionals.py b/jsonargparse/_optionals.py index 56a2ac2a..567672b4 100644 --- a/jsonargparse/_optionals.py +++ b/jsonargparse/_optionals.py @@ -282,12 +282,9 @@ def get_omegaconf_loader(mode): if mode == "omegaconf+": from ._common import get_parsing_setting - if not get_parsing_setting("omegaconf_absolute_to_relative_paths"): - return yaml_load - def omegaconf_plus_load(value): value = yaml_load(value) - if isinstance(value, dict): + if isinstance(value, dict) and get_parsing_setting("omegaconf_absolute_to_relative_paths"): value = omegaconf_absolute_to_relative_paths(value) return value diff --git a/jsonargparse_tests/conftest.py b/jsonargparse_tests/conftest.py index 276e5ed3..80fc5277 100644 --- a/jsonargparse_tests/conftest.py +++ b/jsonargparse_tests/conftest.py @@ -83,6 +83,12 @@ reason="responses package is required", ) +skip_if_omegaconf_unavailable = pytest.mark.skipif( + not omegaconf_support, + reason="omegaconf package is required", +) + + skip_if_running_as_root = pytest.mark.skipif( is_posix and os.geteuid() == 0, reason="User is root, permission tests will not work", diff --git a/jsonargparse_tests/test_from_config.py b/jsonargparse_tests/test_from_config.py new file mode 100644 index 00000000..b324c7ec --- /dev/null +++ b/jsonargparse_tests/test_from_config.py @@ -0,0 +1,223 @@ +import inspect + +import pytest + +from jsonargparse import FromConfigMixin +from jsonargparse_tests.conftest import json_or_yaml_dump, skip_if_omegaconf_unavailable + +# __init__ defaults override tests + + +class DefaultsOverrideParent(FromConfigMixin): + def __init__(self, parent2: int = 1, parent1: str = "parent_default_value"): + self.parent1 = parent1 + self.parent2 = parent2 + + +def test_init_defaults_override_subclass(tmp_cwd, subtests): + config_path = tmp_cwd / "config.yaml" + config_path.write_text(json_or_yaml_dump({"parent1": "overridden_parent", "child1": "overridden_child"})) + + class DefaultsOverrideChild(DefaultsOverrideParent): + __from_config_init_defaults__ = config_path + + def __init__(self, child2: int = 2, child1: str = "child_default_value", **kwargs): + super().__init__(**kwargs) + self.child1 = child1 + self.child2 = child2 + + with subtests.test("overridden subclass defaults"): + params = inspect.signature(DefaultsOverrideChild.__init__).parameters + assert params["parent1"].default == "overridden_parent" + assert params["child1"].default == "overridden_child" + + instance = DefaultsOverrideChild() + assert instance.parent1 == "overridden_parent" + assert instance.parent2 == 1 + assert instance.child1 == "overridden_child" + assert instance.child2 == 2 + + with subtests.test("shadow override"): + instance = DefaultsOverrideChild(child1="shadowed_child", parent1="shadowed_parent", child2=3) + assert instance.parent1 == "shadowed_parent" + assert instance.parent2 == 1 + assert instance.child1 == "shadowed_child" + assert instance.child2 == 3 + + with subtests.test("parent class unaffected"): + params = inspect.signature(DefaultsOverrideParent.__init__).parameters + assert params["parent1"].default == "parent_default_value" + + parent = DefaultsOverrideParent() + assert parent.parent1 == "parent_default_value" + + +def test_init_defaults_override_keyword_only_parameters(tmp_cwd): + config_path = tmp_cwd / "config.yaml" + config_path.write_text(json_or_yaml_dump({"parent1": "overridden_parent", "child1": "overridden_child"})) + + class DefaultsOverrideKeywordOnlyParent(FromConfigMixin): + def __init__(self, *, parent1: str = "parent_default_value", parent2: int = 1): + self.parent1 = parent1 + self.parent2 = parent2 + + class DefaultsOverrideKeywordOnlyChild(DefaultsOverrideKeywordOnlyParent): + __from_config_init_defaults__ = config_path + + def __init__(self, *, child2: int = 2, child1: str = "child_default_value", **kwargs): + super().__init__(**kwargs) + self.child1 = child1 + self.child2 = child2 + + params = inspect.signature(DefaultsOverrideKeywordOnlyChild.__init__).parameters + assert params["parent1"].default == "overridden_parent" + assert params["child1"].default == "overridden_child" + + instance = DefaultsOverrideKeywordOnlyChild() + assert instance.parent1 == "overridden_parent" + assert instance.parent2 == 1 + assert instance.child1 == "overridden_child" + assert instance.child2 == 2 + + +def test_init_defaults_override_class_with_init_subclass(tmp_cwd): + config_path = tmp_cwd / "config.yaml" + config_path.write_text(json_or_yaml_dump({"parent": "overridden_parent", "child": "overridden_child"})) + + class DefaultsOverrideBase(FromConfigMixin): + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + cls.original_init_subclass_ran = True + + def __init__(self, parent: str = "default_value"): + self.parent = parent + + class DefaultsOverrideDerived(DefaultsOverrideBase): + __from_config_init_defaults__ = config_path + + def __init__(self, child: str = "default_value", **kwargs): + super().__init__(**kwargs) + self.child = child + + params = inspect.signature(DefaultsOverrideDerived.__init__).parameters + assert params["parent"].default == "overridden_parent" + assert params["child"].default == "overridden_child" + + instance = DefaultsOverrideDerived() + assert instance.parent == "overridden_parent" + assert instance.child == "overridden_child" + assert instance.original_init_subclass_ran + + +@skip_if_omegaconf_unavailable +def test_init_defaults_override_parser_kwargs(tmp_cwd): + config_path = tmp_cwd / "config.yaml" + config_path.write_text(json_or_yaml_dump({"param1": "there", "param2": "hi ${.param1}"})) + + class DefaultsOverrideParserKwargs(FromConfigMixin): + __from_config_init_defaults__ = config_path + __from_config_parser_kwargs__ = {"parser_mode": "omegaconf+"} + + def __init__(self, param1: str = "", param2: str = ""): + self.param1 = param1 + self.param2 = param2 + + instance = DefaultsOverrideParserKwargs() + assert instance.param1 == "there" + assert instance.param2 == "hi there" + + +def test_init_defaults_override_file_not_found(): + class DefaultsOverrideFileNotFound(DefaultsOverrideParent): + __from_config_init_defaults__ = "non_existent_file.yaml" + + instance = DefaultsOverrideFileNotFound() + assert instance.parent1 == "parent_default_value" + assert instance.parent2 == 1 + + +def test_init_defaults_override_invalid(): + with pytest.raises(TypeError, match="__from_config_init_defaults__ must be str, PathLike, or None"): + + class DefaultsOverrideInvalid(DefaultsOverrideParent): + __from_config_init_defaults__ = 123 # Invalid type + + +# from_config method tests + + +def test_from_config_method_path(tmp_cwd): + config_path = tmp_cwd / "config.yaml" + config_path.write_text(json_or_yaml_dump({"param": "value_from_file"})) + + class FromConfigMethodPath(FromConfigMixin): + def __init__(self, param: str = "default_value"): + self.param = param + + instance = FromConfigMethodPath.from_config(config_path) + assert instance.param == "value_from_file" + + +def test_from_config_method_dict(): + class FromConfigMethodDict(FromConfigMixin): + def __init__(self, param: str = "default_value"): + self.param = param + + instance = FromConfigMethodDict.from_config({"param": "value_from_dict"}) + assert instance.param == "value_from_dict" + + +def test_from_config_method_default(): + from os import PathLike + from typing import Literal, Type, TypeVar, Union + + T = TypeVar("T") + default_config = {"param1": "method_default_value"} + + class FromConfigMethodDefault(FromConfigMixin): + + @classmethod + def from_config(cls: Type[T], config: Union[str, PathLike, dict, Literal["default"]] = "default") -> T: + if config == "default": + config = default_config + return super().from_config(config) + + def __init__(self, param1: str = "default_value", param2: int = 1): + self.param1 = param1 + self.param2 = param2 + + instance = FromConfigMethodDefault.from_config() + assert instance.param1 == "method_default_value" + assert instance.param2 == 1 + + +def test_from_config_method_subclass(): + class FromConfigMethodParent(FromConfigMixin): + def __init__(self, parent_param: str = "parent_default"): + self.parent_param = parent_param + + class FromConfigMethodChild(FromConfigMethodParent): + def __init__(self, child_param: str = "child_default", **kwargs): + super().__init__(**kwargs) + self.child_param = child_param + + instance = FromConfigMethodChild.from_config( + {"parent_param": "overridden_parent", "child_param": "overridden_child"} + ) + assert isinstance(instance, FromConfigMethodChild) + assert instance.parent_param == "overridden_parent" + assert instance.child_param == "overridden_child" + + +@skip_if_omegaconf_unavailable +def test_from_config_method_parser_kwargs(): + class FromConfigMethodParserKwargs(FromConfigMixin): + __from_config_parser_kwargs__ = {"parser_mode": "omegaconf+"} + + def __init__(self, param1: str, param2: str): + self.param1 = param1 + self.param2 = param2 + + instance = FromConfigMethodParserKwargs.from_config({"param1": "there", "param2": "hi ${.param1}"}) + assert instance.param1 == "there" + assert instance.param2 == "hi there" diff --git a/jsonargparse_tests/test_omegaconf.py b/jsonargparse_tests/test_omegaconf.py index 9cc27cf1..f2c32d61 100644 --- a/jsonargparse_tests/test_omegaconf.py +++ b/jsonargparse_tests/test_omegaconf.py @@ -15,16 +15,11 @@ from jsonargparse._loaders_dumpers import loaders, yaml_dump from jsonargparse._optionals import omegaconf_absolute_to_relative_paths, omegaconf_support from jsonargparse.typing import Path_fr -from jsonargparse_tests.conftest import get_parser_help +from jsonargparse_tests.conftest import get_parser_help, skip_if_omegaconf_unavailable if omegaconf_support: from omegaconf import OmegaConf -skip_if_omegaconf_unavailable = pytest.mark.skipif( - not omegaconf_support, - reason="omegaconf package is required", -) - @pytest.fixture(autouse=True) def patch_loaders():