From 77d7737d0dd3b1857dae2d7d568c5e1f12bcde4f Mon Sep 17 00:00:00 2001 From: Mauricio Villegas <5780272+mauvilsa@users.noreply.github.com> Date: Fri, 7 Nov 2025 20:18:03 -0500 Subject: [PATCH 1/4] Non-public experimental decorator to override class defaults and instantiate from config --- CHANGELOG.rst | 5 + jsonargparse/__init__.py | 1 + jsonargparse/_from_config.py | 193 +++++++++++++++++++ jsonargparse_tests/test_from_config.py | 251 +++++++++++++++++++++++++ 4 files changed, 450 insertions(+) create mode 100644 jsonargparse/_from_config.py create mode 100644 jsonargparse_tests/test_from_config.py diff --git a/CHANGELOG.rst b/CHANGELOG.rst index e6244477..e03229b6 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -15,6 +15,11 @@ paths are considered internals and can change in minor and patch releases. v4.43.0 (unreleased) -------------------- +Added +^^^^^ +- Non-public experimental decorator to override class defaults and instantiate + from config (`#800 `__). + Fixed ^^^^^ - Prevent extra environment variables in helptext when default_env=True, for diff --git a/jsonargparse/__init__.py b/jsonargparse/__init__.py index 53ad4f05..e3d3031a 100644 --- a/jsonargparse/__init__.py +++ b/jsonargparse/__init__.py @@ -14,6 +14,7 @@ from ._core import * # noqa: F403 from ._deprecated import * # noqa: F403 from ._formatters import * # noqa: F403 +from ._from_config import * # noqa: F403 from ._jsonnet import * # noqa: F403 from ._jsonschema import * # noqa: F403 from ._link_arguments import * # noqa: F403 diff --git a/jsonargparse/_from_config.py b/jsonargparse/_from_config.py new file mode 100644 index 00000000..7b68d94c --- /dev/null +++ b/jsonargparse/_from_config.py @@ -0,0 +1,193 @@ +import inspect +from functools import wraps +from os import PathLike +from pathlib import Path +from typing import Optional, Type, TypeVar, Union + +from ._core import ArgumentParser + +__all__ = ["from_config_support"] + +T = TypeVar("T") + + +def _parse_class_kwargs_from_config(cls: Type[T], config: Union[str, PathLike, dict]) -> dict: + """Parse the init kwargs for `cls` from a config file or dict.""" + parser = ArgumentParser(exit_on_error=False) + parser.add_class_arguments(cls) + if isinstance(config, dict): + cfg = parser.parse_object(config, defaults=False) + else: + cfg = parser.parse_path(config, defaults=False) + return parser.instantiate_classes(cfg).as_dict() + + +def _override_init_defaults_from_config(cls: Type[T]) -> None: + """Override __init__ defaults for `cls` based on a config file.""" + config = getattr(cls, "__from_config_defaults__", None) + if not isinstance(config, (str, PathLike, type(None))): + raise TypeError("__from_config_defaults__ must be str, PathLike, or None") + if not (isinstance(config, (str, PathLike)) and Path(config).is_file()): + return + + defaults = _parse_class_kwargs_from_config(cls, config) + + # Override defaults for parameters in cls.__init__ + params = inspect.signature(cls.__init__).parameters + for name, default in defaults.copy().items(): + param = params.get(name) + if param and param.kind in {inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD}: + defaults.pop(name) + if param.kind == inspect.Parameter.KEYWORD_ONLY: + cls.__init__.__kwdefaults__[name] = default # type: ignore[index] + else: + index = list(params).index(name) - 1 + aux = cls.__init__.__defaults__ or () + cls.__init__.__defaults__ = aux[:index] + (default,) + aux[index + 1 :] + + # Gather defaults for parameters in parent classes' __init__ + override_parent_params = [] + for base in inspect.getmro(cls)[1:]: + if not defaults: + break + + params = inspect.signature(base.__init__).parameters # type: ignore[misc] + for name, default in defaults.copy().items(): + if name in params: + defaults.pop(name) + new_param = inspect.Parameter( + name=name, + kind=inspect.Parameter.KEYWORD_ONLY, + default=default, + annotation=params[name].annotation, + ) + override_parent_params.append(new_param) + + # Override defaults for parameters in parent classes' __init__ via a wrapper + if override_parent_params: + original_init = cls.__init__ + original_sig = inspect.signature(cls.__init__) + parameters = list(original_sig.parameters.values()) + + # Find and pop the **kwargs parameter, if it exists + kwargs_param = None + if parameters and parameters[-1].kind == inspect.Parameter.VAR_KEYWORD: + kwargs_param = parameters.pop() + + # Add new parameters + for param in reversed(override_parent_params): + parameters.append(param) + + # Add **kwargs back at the end + if kwargs_param: + parameters.append(kwargs_param) + + # Create and set __init__ wrapper with new signature + parent_defaults = {p.name: p.default for p in override_parent_params} + + @wraps(original_init) + def wrapper(*args, **kwargs): + for name, default in parent_defaults.items(): + if name not in kwargs: + kwargs[name] = default + return original_init(*args, **kwargs) + + wrapper.__signature__ = original_sig.replace(parameters=parameters) # type: ignore[attr-defined] + cls.__init__ = wrapper # type: ignore[method-assign] + + +def from_config_support( + *args, + from_config_method: bool = True, + from_config_method_name: str = "from_config", + from_config_method_default: Optional[Union[str, PathLike, dict]] = None, +): + """Class decorator that adds config support to a base class. + + This decorator does two things: + + 1. Adds support for overriding __init__ defaults by defining a + `__from_config_defaults__` class attribute pointing to a config file + path. The overriding of defaults happens on decorator application or + on class creation for subclasses. Inspecting the signature will + give the overridden defaults. + + 2. Adds a @classmethod, by default named `from_config`, that + instantiates the class based on a config file or dict. + + The decorator can be used without parentheses, e.g. + + @from_config_support + class MyClass: + ... + + Use parentheses to customize the behavior, e.g. + + @from_config_support(from_config_method=False) + class MyClass: + ... + + Args: + from_config_method: Whether to add the `from_config` classmethod. + from_config_method_name: Name of the `from_config` classmethod. + from_config_method_default: Default value for the `config` parameter. + """ + + def decorator(cls: Type[T]) -> Type[T]: + if not inspect.isclass(cls): + raise TypeError("from_config_support can only be applied to classes") + + # 1. Add the from_config classmethod to the base class + if from_config_method: + + def from_config(cls: Type[T], config: Union[str, PathLike, dict]) -> T: + """Instantiate current class based on a config file or dict. + + Args: + config: Path to a config file or a dict with config values. + """ + kwargs = _parse_class_kwargs_from_config(cls, config) + return cls(**kwargs) + + if from_config_method_default is not None: + from_config.__defaults__ = (from_config_method_default,) + from_config.__name__ = from_config_method_name + + from_config.__module__ = cls.__module__ + from_config.__qualname__ = f"{cls.__name__}.{from_config_method_name}" + setattr(cls, from_config_method_name, classmethod(from_config)) + + # 2. Override defaults for the decorated class itself + _override_init_defaults_from_config(cls) + + # 3. Get the original __init_subclass__ defined on `cls`, if any. + # Check __dict__ so that parent's method isn't grabbed. + original_init_subclass = cls.__dict__.get("__init_subclass__") + + # 4. Create the new __init_subclass__ + def new_init_subclass(cls_sub, **kwargs): + """This method will be called on every subclass of `cls`.""" + # A. Override defaults for the subclass + _override_init_defaults_from_config(cls_sub) + + # B. Call the original __init_subclass__ if this class defined one + if original_init_subclass: + # Call the original function (it's a classmethod object) + original_init_subclass.__func__(cls_sub, **kwargs) + else: + # This class (cls) didn't have one, so just call up the MRO to the *next* class. + # super(cls, cls_sub) finds the next __init_subclass__ in the MRO *after* `cls`. + super(cls, cls_sub).__init_subclass__(**kwargs) + + # 5. Attach the new method to the class + cls.__init_subclass__ = classmethod(new_init_subclass) # type: ignore[assignment] + + return cls + + # Handle decorator usage without parentheses + if len(args) > 0: + if len(args) == 1: + return decorator(args[0]) + raise TypeError("from_config_support can only receive a single positional argument") + + return decorator diff --git a/jsonargparse_tests/test_from_config.py b/jsonargparse_tests/test_from_config.py new file mode 100644 index 00000000..d39121d0 --- /dev/null +++ b/jsonargparse_tests/test_from_config.py @@ -0,0 +1,251 @@ +import inspect +import sys +from unittest.mock import patch + +import pytest + +from jsonargparse import ArgumentParser, from_config_support +from jsonargparse_tests.conftest import json_or_yaml_dump + +# decorator usage tests + + +def test_decorator_multiple_positional_arguments(): + class Class: + pass + + with pytest.raises(TypeError, match="from_config_support can only receive a single positional argument"): + from_config_support(Class, 123) + + +def test_decorator_non_class_argument(): + with pytest.raises(TypeError, match="from_config_support can only be applied to classes"): + from_config_support(123) + + +# defaults override tests + + +class DefaultsOverrideSelf: + def __init__(self, param1: str = "default_value", param2: int = 1): + self.param1 = param1 + self.param2 = param2 + + +@pytest.mark.skipif(sys.version_info < (3, 11), reason="patch.object doesn't work correctly") +def test_defaults_override_self(tmp_cwd): + config_path = tmp_cwd / "config.yaml" + config_path.write_text(json_or_yaml_dump({"param1": "overridden_from_path"})) + DefaultsOverrideSelf.__from_config_defaults__ = config_path + + with patch.object(ArgumentParser, "parse_path", wraps=ArgumentParser.parse_path, autospec=True) as mock: + from_config_support(DefaultsOverrideSelf) + assert mock.call_count == 1 + assert mock.mock_calls[0].kwargs["defaults"] is False + + params = inspect.signature(DefaultsOverrideSelf.__init__).parameters + assert params["param1"].default == "overridden_from_path" + + instance = DefaultsOverrideSelf() + assert instance.param1 == "overridden_from_path" + assert instance.param2 == 1 + + +@from_config_support +class DefaultsOverrideParent: + def __init__(self, parent2: int = 1, parent1: str = "parent_default_value"): + self.parent1 = parent1 + self.parent2 = parent2 + + +def test_defaults_override_subclass(tmp_cwd, subtests): + config_path = tmp_cwd / "config.yaml" + config_path.write_text(json_or_yaml_dump({"parent1": "overridden_parent", "child1": "overridden_child"})) + + class DefaultsOverrideChild(DefaultsOverrideParent): + __from_config_defaults__ = config_path + + def __init__(self, child2: int = 2, child1: str = "child_default_value", **kwargs): + super().__init__(**kwargs) + self.child1 = child1 + self.child2 = child2 + + with subtests.test("overridden subclass defaults"): + params = inspect.signature(DefaultsOverrideChild.__init__).parameters + assert params["parent1"].default == "overridden_parent" + assert params["child1"].default == "overridden_child" + + instance = DefaultsOverrideChild() + assert instance.parent1 == "overridden_parent" + assert instance.parent2 == 1 + assert instance.child1 == "overridden_child" + assert instance.child2 == 2 + + with subtests.test("shadow override"): + instance = DefaultsOverrideChild(child1="shadowed_child", parent1="shadowed_parent", child2=3) + assert instance.parent1 == "shadowed_parent" + assert instance.parent2 == 1 + assert instance.child1 == "shadowed_child" + assert instance.child2 == 3 + + with subtests.test("parent class unaffected"): + params = inspect.signature(DefaultsOverrideParent.__init__).parameters + assert params["parent1"].default == "parent_default_value" + + parent = DefaultsOverrideParent() + assert parent.parent1 == "parent_default_value" + + +def test_defaults_override_keyword_only_parameters(tmp_cwd): + config_path = tmp_cwd / "config.yaml" + config_path.write_text(json_or_yaml_dump({"parent1": "overridden_parent", "child1": "overridden_child"})) + + @from_config_support + class DefaultsOverrideKeywordOnlyParent: + def __init__(self, *, parent1: str = "parent_default_value", parent2: int = 1): + self.parent1 = parent1 + self.parent2 = parent2 + + class DefaultsOverrideKeywordOnlyChild(DefaultsOverrideKeywordOnlyParent): + __from_config_defaults__ = config_path + + def __init__(self, *, child2: int = 2, child1: str = "child_default_value", **kwargs): + super().__init__(**kwargs) + self.child1 = child1 + self.child2 = child2 + + params = inspect.signature(DefaultsOverrideKeywordOnlyChild.__init__).parameters + assert params["parent1"].default == "overridden_parent" + assert params["child1"].default == "overridden_child" + + instance = DefaultsOverrideKeywordOnlyChild() + assert instance.parent1 == "overridden_parent" + assert instance.parent2 == 1 + assert instance.child1 == "overridden_child" + assert instance.child2 == 2 + + +def test_defaults_override_class_with_init_subclass(tmp_cwd): + config_path = tmp_cwd / "config.yaml" + config_path.write_text(json_or_yaml_dump({"parent": "overridden_parent", "child": "overridden_child"})) + + @from_config_support + class DefaultsOverrideBase: + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + cls.original_init_subclass_ran = True + + def __init__(self, parent: str = "default_value"): + self.parent = parent + + class DefaultsOverrideDerived(DefaultsOverrideBase): + __from_config_defaults__ = config_path + + def __init__(self, child: str = "default_value", **kwargs): + super().__init__(**kwargs) + self.child = child + + params = inspect.signature(DefaultsOverrideDerived.__init__).parameters + assert params["parent"].default == "overridden_parent" + assert params["child"].default == "overridden_child" + + instance = DefaultsOverrideDerived() + assert instance.parent == "overridden_parent" + assert instance.child == "overridden_child" + assert instance.original_init_subclass_ran + + +def test_defaults_override_file_not_found(): + class DefaultsOverrideFileNotFound(DefaultsOverrideParent): + __from_config_defaults__ = "non_existent_file.yaml" + + assert hasattr(DefaultsOverrideFileNotFound, "from_config") # make sure decorator applied to subclass + instance = DefaultsOverrideFileNotFound() + assert instance.parent1 == "parent_default_value" + assert instance.parent2 == 1 + + +def test_defaults_override_invalid(): + with pytest.raises(TypeError, match="__from_config_defaults__ must be str, PathLike, or None"): + + @from_config_support + class DefaultsOverrideInvalid: + __from_config_defaults__ = 123 # Invalid type + + +# from_config method tests + + +def test_without_from_config_method(): + @from_config_support(from_config_method=False) + class WithoutFromConfigMethod: + pass + + assert not hasattr(WithoutFromConfigMethod, "from_config") + + +def test_from_config_method_path(tmp_cwd): + config_path = tmp_cwd / "config.yaml" + config_path.write_text(json_or_yaml_dump({"param": "value_from_file"})) + + @from_config_support + class FromConfigMethodPath: + def __init__(self, param: str = "default_value"): + self.param = param + + instance = FromConfigMethodPath.from_config(config_path) + assert instance.param == "value_from_file" + assert FromConfigMethodPath.from_config.__func__.__qualname__ == "FromConfigMethodPath.from_config" + + +def test_from_config_method_dict(): + @from_config_support + class FromConfigMethodDict: + def __init__(self, param: str = "default_value"): + self.param = param + + instance = FromConfigMethodDict.from_config({"param": "value_from_dict"}) + assert instance.param == "value_from_dict" + + +def test_from_config_method_default(): + @from_config_support(from_config_method_default={"param1": "method_default_value"}) + class FromConfigMethodDefault: + def __init__(self, param1: str = "default_value", param2: int = 1): + self.param1 = param1 + self.param2 = param2 + + instance = FromConfigMethodDefault.from_config() + assert instance.param1 == "method_default_value" + assert instance.param2 == 1 + + +def test_from_config_method_subclass(): + @from_config_support + class FromConfigMethodParent: + def __init__(self, parent_param: str = "parent_default"): + self.parent_param = parent_param + + class FromConfigMethodChild(FromConfigMethodParent): + def __init__(self, child_param: str = "child_default", **kwargs): + super().__init__(**kwargs) + self.child_param = child_param + + instance = FromConfigMethodChild.from_config( + {"parent_param": "overridden_parent", "child_param": "overridden_child"} + ) + assert isinstance(instance, FromConfigMethodChild) + assert instance.parent_param == "overridden_parent" + assert instance.child_param == "overridden_child" + + +def test_from_config_method_custom_name(): + @from_config_support(from_config_method_name="custom_name") + class FromConfigMethodCustomName: + def __init__(self, param: str = "default_value"): + self.param = param + + assert hasattr(FromConfigMethodCustomName, "custom_name") + instance = FromConfigMethodCustomName.custom_name({"param": "custom_name_value"}) + assert instance.param == "custom_name_value" + assert FromConfigMethodCustomName.custom_name.__func__.__qualname__ == "FromConfigMethodCustomName.custom_name" From 6d47d3aca4af95fa0933127f56e8805435d1a319 Mon Sep 17 00:00:00 2001 From: Mauricio Villegas <5780272+mauvilsa@users.noreply.github.com> Date: Sun, 9 Nov 2025 06:59:19 -0500 Subject: [PATCH 2/4] - Switch from decorator to mixin class. - Allow providing parser kwargs. --- CHANGELOG.rst | 2 +- jsonargparse/_from_config.py | 221 ++++++++++--------------- jsonargparse/_optionals.py | 5 +- jsonargparse_tests/conftest.py | 6 + jsonargparse_tests/test_from_config.py | 153 +++++++---------- jsonargparse_tests/test_omegaconf.py | 7 +- 6 files changed, 155 insertions(+), 239 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index e03229b6..0e9690b2 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -17,7 +17,7 @@ v4.43.0 (unreleased) Added ^^^^^ -- Non-public experimental decorator to override class defaults and instantiate +- Non-public experimental mixin class to override class defaults and instantiate from config (`#800 `__). Fixed diff --git a/jsonargparse/_from_config.py b/jsonargparse/_from_config.py index 7b68d94c..ef1c3a35 100644 --- a/jsonargparse/_from_config.py +++ b/jsonargparse/_from_config.py @@ -6,14 +6,53 @@ from ._core import ArgumentParser -__all__ = ["from_config_support"] +__all__ = ["FromConfigMixin"] T = TypeVar("T") -def _parse_class_kwargs_from_config(cls: Type[T], config: Union[str, PathLike, dict]) -> dict: +class FromConfigMixin: + """Mixin class that adds config support to a base class. + + This mixin does two things: + + 1. Adds support for overriding __init__ defaults by defining a + `__from_config_init_defaults__` class attribute pointing to a config + file path. The overriding of defaults happens on subclass creation + time. Inspecting the signature will give the overridden defaults. + + 2. Adds a `from_config` @classmethod, that instantiates the class based + on a config file or dict. + + Attributes: + __from_config_init_defaults__: Optional path to a config file for + overriding __init__ defaults. + __from_config_parser_kwargs__: Additional kwargs to pass to the + ArgumentParser used for parsing configs. + """ + + __from_config_init_defaults__: Optional[Union[str, PathLike]] = None + __from_config_parser_kwargs__: dict = {} + + def __init_subclass__(cls, **kwargs) -> None: + """Override __init__ defaults for the subclass based on a config file.""" + super().__init_subclass__(**kwargs) + _override_init_defaults(cls, cls.__from_config_parser_kwargs__) + + @classmethod + def from_config(cls: Type[T], config: Union[str, PathLike, dict]) -> T: + """Instantiate current class based on a config file or dict. + + Args: + config: Path to a config file or a dict with config values. + """ + kwargs = _parse_class_kwargs_from_config(cls, config, **cls.__from_config_parser_kwargs__) # type: ignore[attr-defined] + return cls(**kwargs) + + +def _parse_class_kwargs_from_config(cls: Type[T], config: Union[str, PathLike, dict], **kwargs) -> dict: """Parse the init kwargs for `cls` from a config file or dict.""" - parser = ArgumentParser(exit_on_error=False) + parser = ArgumentParser(exit_on_error=False, **kwargs) parser.add_class_arguments(cls) if isinstance(config, dict): cfg = parser.parse_object(config, defaults=False) @@ -22,17 +61,20 @@ def _parse_class_kwargs_from_config(cls: Type[T], config: Union[str, PathLike, d return parser.instantiate_classes(cfg).as_dict() -def _override_init_defaults_from_config(cls: Type[T]) -> None: - """Override __init__ defaults for `cls` based on a config file.""" - config = getattr(cls, "__from_config_defaults__", None) +def _override_init_defaults(cls: Type[T], parser_kwargs: dict) -> None: + """Override __init__ defaults for `cls` based on __from_config_init_defaults__.""" + config = getattr(cls, "__from_config_init_defaults__", None) if not isinstance(config, (str, PathLike, type(None))): - raise TypeError("__from_config_defaults__ must be str, PathLike, or None") + raise TypeError("__from_config_init_defaults__ must be str, PathLike, or None") if not (isinstance(config, (str, PathLike)) and Path(config).is_file()): return - defaults = _parse_class_kwargs_from_config(cls, config) + defaults = _parse_class_kwargs_from_config(cls, config, **parser_kwargs) + _override_init_defaults_this_class(cls, defaults) + _override_init_defaults_parent_classes(cls, defaults) + - # Override defaults for parameters in cls.__init__ +def _override_init_defaults_this_class(cls: Type[T], defaults: dict) -> None: params = inspect.signature(cls.__init__).parameters for name, default in defaults.copy().items(): param = params.get(name) @@ -45,6 +87,8 @@ def _override_init_defaults_from_config(cls: Type[T]) -> None: aux = cls.__init__.__defaults__ or () cls.__init__.__defaults__ = aux[:index] + (default,) + aux[index + 1 :] + +def _override_init_defaults_parent_classes(cls: Type[T], defaults: dict) -> None: # Gather defaults for parameters in parent classes' __init__ override_parent_params = [] for base in inspect.getmro(cls)[1:]: @@ -63,131 +107,36 @@ def _override_init_defaults_from_config(cls: Type[T]) -> None: ) override_parent_params.append(new_param) - # Override defaults for parameters in parent classes' __init__ via a wrapper - if override_parent_params: - original_init = cls.__init__ - original_sig = inspect.signature(cls.__init__) - parameters = list(original_sig.parameters.values()) - - # Find and pop the **kwargs parameter, if it exists - kwargs_param = None - if parameters and parameters[-1].kind == inspect.Parameter.VAR_KEYWORD: - kwargs_param = parameters.pop() - - # Add new parameters - for param in reversed(override_parent_params): - parameters.append(param) - - # Add **kwargs back at the end - if kwargs_param: - parameters.append(kwargs_param) - - # Create and set __init__ wrapper with new signature - parent_defaults = {p.name: p.default for p in override_parent_params} - - @wraps(original_init) - def wrapper(*args, **kwargs): - for name, default in parent_defaults.items(): - if name not in kwargs: - kwargs[name] = default - return original_init(*args, **kwargs) - - wrapper.__signature__ = original_sig.replace(parameters=parameters) # type: ignore[attr-defined] - cls.__init__ = wrapper # type: ignore[method-assign] - - -def from_config_support( - *args, - from_config_method: bool = True, - from_config_method_name: str = "from_config", - from_config_method_default: Optional[Union[str, PathLike, dict]] = None, -): - """Class decorator that adds config support to a base class. - - This decorator does two things: - - 1. Adds support for overriding __init__ defaults by defining a - `__from_config_defaults__` class attribute pointing to a config file - path. The overriding of defaults happens on decorator application or - on class creation for subclasses. Inspecting the signature will - give the overridden defaults. - - 2. Adds a @classmethod, by default named `from_config`, that - instantiates the class based on a config file or dict. - - The decorator can be used without parentheses, e.g. - - @from_config_support - class MyClass: - ... - - Use parentheses to customize the behavior, e.g. - - @from_config_support(from_config_method=False) - class MyClass: - ... - - Args: - from_config_method: Whether to add the `from_config` classmethod. - from_config_method_name: Name of the `from_config` classmethod. - from_config_method_default: Default value for the `config` parameter. - """ - - def decorator(cls: Type[T]) -> Type[T]: - if not inspect.isclass(cls): - raise TypeError("from_config_support can only be applied to classes") - - # 1. Add the from_config classmethod to the base class - if from_config_method: - - def from_config(cls: Type[T], config: Union[str, PathLike, dict]) -> T: - """Instantiate current class based on a config file or dict. - - Args: - config: Path to a config file or a dict with config values. - """ - kwargs = _parse_class_kwargs_from_config(cls, config) - return cls(**kwargs) - - if from_config_method_default is not None: - from_config.__defaults__ = (from_config_method_default,) - from_config.__name__ = from_config_method_name - - from_config.__module__ = cls.__module__ - from_config.__qualname__ = f"{cls.__name__}.{from_config_method_name}" - setattr(cls, from_config_method_name, classmethod(from_config)) - - # 2. Override defaults for the decorated class itself - _override_init_defaults_from_config(cls) - - # 3. Get the original __init_subclass__ defined on `cls`, if any. - # Check __dict__ so that parent's method isn't grabbed. - original_init_subclass = cls.__dict__.get("__init_subclass__") - - # 4. Create the new __init_subclass__ - def new_init_subclass(cls_sub, **kwargs): - """This method will be called on every subclass of `cls`.""" - # A. Override defaults for the subclass - _override_init_defaults_from_config(cls_sub) - - # B. Call the original __init_subclass__ if this class defined one - if original_init_subclass: - # Call the original function (it's a classmethod object) - original_init_subclass.__func__(cls_sub, **kwargs) - else: - # This class (cls) didn't have one, so just call up the MRO to the *next* class. - # super(cls, cls_sub) finds the next __init_subclass__ in the MRO *after* `cls`. - super(cls, cls_sub).__init_subclass__(**kwargs) - - # 5. Attach the new method to the class - cls.__init_subclass__ = classmethod(new_init_subclass) # type: ignore[assignment] - - return cls - - # Handle decorator usage without parentheses - if len(args) > 0: - if len(args) == 1: - return decorator(args[0]) - raise TypeError("from_config_support can only receive a single positional argument") + if not override_parent_params: + return - return decorator + # Override defaults for parameters in parent classes' __init__ via a wrapper + original_init = cls.__init__ + original_sig = inspect.signature(cls.__init__) + parameters = list(original_sig.parameters.values()) + + # Find and pop the **kwargs parameter, if it exists + kwargs_param = None + if parameters and parameters[-1].kind == inspect.Parameter.VAR_KEYWORD: + kwargs_param = parameters.pop() + + # Add new parameters + for param in reversed(override_parent_params): + parameters.append(param) + + # Add **kwargs back at the end + if kwargs_param: + parameters.append(kwargs_param) + + # Create and set __init__ wrapper with new signature + parent_defaults = {p.name: p.default for p in override_parent_params} + + @wraps(original_init) + def wrapper(*args, **kwargs): + for name, default in parent_defaults.items(): + if name not in kwargs: + kwargs[name] = default + return original_init(*args, **kwargs) + + wrapper.__signature__ = original_sig.replace(parameters=parameters) # type: ignore[attr-defined] + cls.__init__ = wrapper # type: ignore[method-assign] diff --git a/jsonargparse/_optionals.py b/jsonargparse/_optionals.py index 56a2ac2a..567672b4 100644 --- a/jsonargparse/_optionals.py +++ b/jsonargparse/_optionals.py @@ -282,12 +282,9 @@ def get_omegaconf_loader(mode): if mode == "omegaconf+": from ._common import get_parsing_setting - if not get_parsing_setting("omegaconf_absolute_to_relative_paths"): - return yaml_load - def omegaconf_plus_load(value): value = yaml_load(value) - if isinstance(value, dict): + if isinstance(value, dict) and get_parsing_setting("omegaconf_absolute_to_relative_paths"): value = omegaconf_absolute_to_relative_paths(value) return value diff --git a/jsonargparse_tests/conftest.py b/jsonargparse_tests/conftest.py index 276e5ed3..80fc5277 100644 --- a/jsonargparse_tests/conftest.py +++ b/jsonargparse_tests/conftest.py @@ -83,6 +83,12 @@ reason="responses package is required", ) +skip_if_omegaconf_unavailable = pytest.mark.skipif( + not omegaconf_support, + reason="omegaconf package is required", +) + + skip_if_running_as_root = pytest.mark.skipif( is_posix and os.geteuid() == 0, reason="User is root, permission tests will not work", diff --git a/jsonargparse_tests/test_from_config.py b/jsonargparse_tests/test_from_config.py index d39121d0..9d059072 100644 --- a/jsonargparse_tests/test_from_config.py +++ b/jsonargparse_tests/test_from_config.py @@ -1,69 +1,25 @@ import inspect -import sys -from unittest.mock import patch import pytest -from jsonargparse import ArgumentParser, from_config_support -from jsonargparse_tests.conftest import json_or_yaml_dump +from jsonargparse import FromConfigMixin +from jsonargparse_tests.conftest import json_or_yaml_dump, skip_if_omegaconf_unavailable -# decorator usage tests +# __init__ defaults override tests -def test_decorator_multiple_positional_arguments(): - class Class: - pass - - with pytest.raises(TypeError, match="from_config_support can only receive a single positional argument"): - from_config_support(Class, 123) - - -def test_decorator_non_class_argument(): - with pytest.raises(TypeError, match="from_config_support can only be applied to classes"): - from_config_support(123) - - -# defaults override tests - - -class DefaultsOverrideSelf: - def __init__(self, param1: str = "default_value", param2: int = 1): - self.param1 = param1 - self.param2 = param2 - - -@pytest.mark.skipif(sys.version_info < (3, 11), reason="patch.object doesn't work correctly") -def test_defaults_override_self(tmp_cwd): - config_path = tmp_cwd / "config.yaml" - config_path.write_text(json_or_yaml_dump({"param1": "overridden_from_path"})) - DefaultsOverrideSelf.__from_config_defaults__ = config_path - - with patch.object(ArgumentParser, "parse_path", wraps=ArgumentParser.parse_path, autospec=True) as mock: - from_config_support(DefaultsOverrideSelf) - assert mock.call_count == 1 - assert mock.mock_calls[0].kwargs["defaults"] is False - - params = inspect.signature(DefaultsOverrideSelf.__init__).parameters - assert params["param1"].default == "overridden_from_path" - - instance = DefaultsOverrideSelf() - assert instance.param1 == "overridden_from_path" - assert instance.param2 == 1 - - -@from_config_support -class DefaultsOverrideParent: +class DefaultsOverrideParent(FromConfigMixin): def __init__(self, parent2: int = 1, parent1: str = "parent_default_value"): self.parent1 = parent1 self.parent2 = parent2 -def test_defaults_override_subclass(tmp_cwd, subtests): +def test_init_defaults_override_subclass(tmp_cwd, subtests): config_path = tmp_cwd / "config.yaml" config_path.write_text(json_or_yaml_dump({"parent1": "overridden_parent", "child1": "overridden_child"})) class DefaultsOverrideChild(DefaultsOverrideParent): - __from_config_defaults__ = config_path + __from_config_init_defaults__ = config_path def __init__(self, child2: int = 2, child1: str = "child_default_value", **kwargs): super().__init__(**kwargs) @@ -96,18 +52,17 @@ def __init__(self, child2: int = 2, child1: str = "child_default_value", **kwarg assert parent.parent1 == "parent_default_value" -def test_defaults_override_keyword_only_parameters(tmp_cwd): +def test_init_defaults_override_keyword_only_parameters(tmp_cwd): config_path = tmp_cwd / "config.yaml" config_path.write_text(json_or_yaml_dump({"parent1": "overridden_parent", "child1": "overridden_child"})) - @from_config_support - class DefaultsOverrideKeywordOnlyParent: + class DefaultsOverrideKeywordOnlyParent(FromConfigMixin): def __init__(self, *, parent1: str = "parent_default_value", parent2: int = 1): self.parent1 = parent1 self.parent2 = parent2 class DefaultsOverrideKeywordOnlyChild(DefaultsOverrideKeywordOnlyParent): - __from_config_defaults__ = config_path + __from_config_init_defaults__ = config_path def __init__(self, *, child2: int = 2, child1: str = "child_default_value", **kwargs): super().__init__(**kwargs) @@ -125,12 +80,11 @@ def __init__(self, *, child2: int = 2, child1: str = "child_default_value", **kw assert instance.child2 == 2 -def test_defaults_override_class_with_init_subclass(tmp_cwd): +def test_init_defaults_override_class_with_init_subclass(tmp_cwd): config_path = tmp_cwd / "config.yaml" config_path.write_text(json_or_yaml_dump({"parent": "overridden_parent", "child": "overridden_child"})) - @from_config_support - class DefaultsOverrideBase: + class DefaultsOverrideBase(FromConfigMixin): def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) cls.original_init_subclass_ran = True @@ -139,7 +93,7 @@ def __init__(self, parent: str = "default_value"): self.parent = parent class DefaultsOverrideDerived(DefaultsOverrideBase): - __from_config_defaults__ = config_path + __from_config_init_defaults__ = config_path def __init__(self, child: str = "default_value", **kwargs): super().__init__(**kwargs) @@ -155,52 +109,57 @@ def __init__(self, child: str = "default_value", **kwargs): assert instance.original_init_subclass_ran -def test_defaults_override_file_not_found(): +@skip_if_omegaconf_unavailable +def test_init_defaults_override_parser_kwargs(tmp_cwd): + config_path = tmp_cwd / "config.yaml" + config_path.write_text(json_or_yaml_dump({"param1": "there", "param2": "hi ${.param1}"})) + + class DefaultsOverrideParserKwargs(FromConfigMixin): + __from_config_init_defaults__ = config_path + __from_config_parser_kwargs__ = {"parser_mode": "omegaconf+"} + + def __init__(self, param1: str, param2: str): + self.param1 = param1 + self.param2 = param2 + + instance = DefaultsOverrideParserKwargs() + assert instance.param1 == "there" + assert instance.param2 == "hi there" + + +def test_init_defaults_override_file_not_found(): class DefaultsOverrideFileNotFound(DefaultsOverrideParent): - __from_config_defaults__ = "non_existent_file.yaml" + __from_config_init_defaults__ = "non_existent_file.yaml" - assert hasattr(DefaultsOverrideFileNotFound, "from_config") # make sure decorator applied to subclass instance = DefaultsOverrideFileNotFound() assert instance.parent1 == "parent_default_value" assert instance.parent2 == 1 -def test_defaults_override_invalid(): - with pytest.raises(TypeError, match="__from_config_defaults__ must be str, PathLike, or None"): +def test_init_defaults_override_invalid(): + with pytest.raises(TypeError, match="__from_config_init_defaults__ must be str, PathLike, or None"): - @from_config_support - class DefaultsOverrideInvalid: - __from_config_defaults__ = 123 # Invalid type + class DefaultsOverrideInvalid(DefaultsOverrideParent): + __from_config_init_defaults__ = 123 # Invalid type # from_config method tests -def test_without_from_config_method(): - @from_config_support(from_config_method=False) - class WithoutFromConfigMethod: - pass - - assert not hasattr(WithoutFromConfigMethod, "from_config") - - def test_from_config_method_path(tmp_cwd): config_path = tmp_cwd / "config.yaml" config_path.write_text(json_or_yaml_dump({"param": "value_from_file"})) - @from_config_support - class FromConfigMethodPath: + class FromConfigMethodPath(FromConfigMixin): def __init__(self, param: str = "default_value"): self.param = param instance = FromConfigMethodPath.from_config(config_path) assert instance.param == "value_from_file" - assert FromConfigMethodPath.from_config.__func__.__qualname__ == "FromConfigMethodPath.from_config" def test_from_config_method_dict(): - @from_config_support - class FromConfigMethodDict: + class FromConfigMethodDict(FromConfigMixin): def __init__(self, param: str = "default_value"): self.param = param @@ -209,8 +168,17 @@ def __init__(self, param: str = "default_value"): def test_from_config_method_default(): - @from_config_support(from_config_method_default={"param1": "method_default_value"}) - class FromConfigMethodDefault: + from os import PathLike + from typing import Type, TypeVar, Union + + T = TypeVar("T") + + class FromConfigMethodDefault(FromConfigMixin): + + @classmethod + def from_config(cls: Type[T], config: Union[str, PathLike, dict] = {"param1": "method_default_value"}) -> T: + return super().from_config(config) + def __init__(self, param1: str = "default_value", param2: int = 1): self.param1 = param1 self.param2 = param2 @@ -221,8 +189,7 @@ def __init__(self, param1: str = "default_value", param2: int = 1): def test_from_config_method_subclass(): - @from_config_support - class FromConfigMethodParent: + class FromConfigMethodParent(FromConfigMixin): def __init__(self, parent_param: str = "parent_default"): self.parent_param = parent_param @@ -239,13 +206,15 @@ def __init__(self, child_param: str = "child_default", **kwargs): assert instance.child_param == "overridden_child" -def test_from_config_method_custom_name(): - @from_config_support(from_config_method_name="custom_name") - class FromConfigMethodCustomName: - def __init__(self, param: str = "default_value"): - self.param = param +@skip_if_omegaconf_unavailable +def test_from_config_method_parser_kwargs(): + class FromConfigMethodParserKwargs(FromConfigMixin): + __from_config_parser_kwargs__ = {"parser_mode": "omegaconf+"} + + def __init__(self, param1: str, param2: str): + self.param1 = param1 + self.param2 = param2 - assert hasattr(FromConfigMethodCustomName, "custom_name") - instance = FromConfigMethodCustomName.custom_name({"param": "custom_name_value"}) - assert instance.param == "custom_name_value" - assert FromConfigMethodCustomName.custom_name.__func__.__qualname__ == "FromConfigMethodCustomName.custom_name" + instance = FromConfigMethodParserKwargs.from_config({"param1": "there", "param2": "hi ${.param1}"}) + assert instance.param1 == "there" + assert instance.param2 == "hi there" diff --git a/jsonargparse_tests/test_omegaconf.py b/jsonargparse_tests/test_omegaconf.py index 9cc27cf1..f2c32d61 100644 --- a/jsonargparse_tests/test_omegaconf.py +++ b/jsonargparse_tests/test_omegaconf.py @@ -15,16 +15,11 @@ from jsonargparse._loaders_dumpers import loaders, yaml_dump from jsonargparse._optionals import omegaconf_absolute_to_relative_paths, omegaconf_support from jsonargparse.typing import Path_fr -from jsonargparse_tests.conftest import get_parser_help +from jsonargparse_tests.conftest import get_parser_help, skip_if_omegaconf_unavailable if omegaconf_support: from omegaconf import OmegaConf -skip_if_omegaconf_unavailable = pytest.mark.skipif( - not omegaconf_support, - reason="omegaconf package is required", -) - @pytest.fixture(autouse=True) def patch_loaders(): From e832da470d8c2e0bd45e81a419b0133cf0e18711 Mon Sep 17 00:00:00 2001 From: Mauricio Villegas <5780272+mauvilsa@users.noreply.github.com> Date: Sun, 9 Nov 2025 07:27:29 -0500 Subject: [PATCH 3/4] Minor fixes --- CHANGELOG.rst | 2 +- jsonargparse/_from_config.py | 20 ++++++++++---------- jsonargparse_tests/test_from_config.py | 2 +- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 0e9690b2..4933583c 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -17,7 +17,7 @@ v4.43.0 (unreleased) Added ^^^^^ -- Non-public experimental mixin class to override class defaults and instantiate +- Non-public experimental mixin class to override init defaults and instantiate from config (`#800 `__). Fixed diff --git a/jsonargparse/_from_config.py b/jsonargparse/_from_config.py index ef1c3a35..81e04fa3 100644 --- a/jsonargparse/_from_config.py +++ b/jsonargparse/_from_config.py @@ -96,16 +96,16 @@ def _override_init_defaults_parent_classes(cls: Type[T], defaults: dict) -> None break params = inspect.signature(base.__init__).parameters # type: ignore[misc] - for name, default in defaults.copy().items(): - if name in params: - defaults.pop(name) - new_param = inspect.Parameter( - name=name, - kind=inspect.Parameter.KEYWORD_ONLY, - default=default, - annotation=params[name].annotation, - ) - override_parent_params.append(new_param) + names = [name for name in defaults if name in params] + for name in names: + default = defaults.pop(name) + new_param = inspect.Parameter( + name=name, + kind=inspect.Parameter.KEYWORD_ONLY, + default=default, + annotation=params[name].annotation, + ) + override_parent_params.append(new_param) if not override_parent_params: return diff --git a/jsonargparse_tests/test_from_config.py b/jsonargparse_tests/test_from_config.py index 9d059072..15c1b7df 100644 --- a/jsonargparse_tests/test_from_config.py +++ b/jsonargparse_tests/test_from_config.py @@ -118,7 +118,7 @@ class DefaultsOverrideParserKwargs(FromConfigMixin): __from_config_init_defaults__ = config_path __from_config_parser_kwargs__ = {"parser_mode": "omegaconf+"} - def __init__(self, param1: str, param2: str): + def __init__(self, param1: str = "", param2: str = ""): self.param1 = param1 self.param2 = param2 From 831f5fa3bf60da5da87a6a89ac600231e2b0946a Mon Sep 17 00:00:00 2001 From: Mauricio Villegas <5780272+mauvilsa@users.noreply.github.com> Date: Sun, 9 Nov 2025 19:53:14 -0500 Subject: [PATCH 4/4] More fixes --- jsonargparse/_from_config.py | 28 +++++++++++++------------- jsonargparse_tests/test_from_config.py | 7 +++++-- 2 files changed, 19 insertions(+), 16 deletions(-) diff --git a/jsonargparse/_from_config.py b/jsonargparse/_from_config.py index 81e04fa3..3c8d62cf 100644 --- a/jsonargparse/_from_config.py +++ b/jsonargparse/_from_config.py @@ -12,22 +12,23 @@ class FromConfigMixin: - """Mixin class that adds config support to a base class. + """Mixin class that adds from config support to classes. This mixin does two things: - 1. Adds support for overriding __init__ defaults by defining a - `__from_config_init_defaults__` class attribute pointing to a config - file path. The overriding of defaults happens on subclass creation - time. Inspecting the signature will give the overridden defaults. + 1. Adds support for overriding ``__init__`` defaults by defining a + ``__from_config_init_defaults__`` class attribute pointing to a + config file path. The overriding of defaults happens on subclass + creation time. Inspecting the signature will give the overridden + defaults. - 2. Adds a `from_config` @classmethod, that instantiates the class based - on a config file or dict. + 2. Adds a ``from_config`` ``@classmethod``, that instantiates the class + based on a config file or dict. Attributes: - __from_config_init_defaults__: Optional path to a config file for - overriding __init__ defaults. - __from_config_parser_kwargs__: Additional kwargs to pass to the + ``__from_config_init_defaults__``: Optional path to a config file for + overriding ``__init__`` defaults. + ``__from_config_parser_kwargs__``: Additional kwargs to pass to the ArgumentParser used for parsing configs. """ @@ -35,7 +36,7 @@ class FromConfigMixin: __from_config_parser_kwargs__: dict = {} def __init_subclass__(cls, **kwargs) -> None: - """Override __init__ defaults for the subclass based on a config file.""" + """Override ``__init__`` defaults for the subclass based on a config file.""" super().__init_subclass__(**kwargs) _override_init_defaults(cls, cls.__from_config_parser_kwargs__) @@ -62,7 +63,7 @@ def _parse_class_kwargs_from_config(cls: Type[T], config: Union[str, PathLike, d def _override_init_defaults(cls: Type[T], parser_kwargs: dict) -> None: - """Override __init__ defaults for `cls` based on __from_config_init_defaults__.""" + """Override ``__init__`` defaults for ``cls`` based on ``__from_config_init_defaults__``.""" config = getattr(cls, "__from_config_init_defaults__", None) if not isinstance(config, (str, PathLike, type(None))): raise TypeError("__from_config_init_defaults__ must be str, PathLike, or None") @@ -98,11 +99,10 @@ def _override_init_defaults_parent_classes(cls: Type[T], defaults: dict) -> None params = inspect.signature(base.__init__).parameters # type: ignore[misc] names = [name for name in defaults if name in params] for name in names: - default = defaults.pop(name) new_param = inspect.Parameter( name=name, kind=inspect.Parameter.KEYWORD_ONLY, - default=default, + default=defaults.pop(name), annotation=params[name].annotation, ) override_parent_params.append(new_param) diff --git a/jsonargparse_tests/test_from_config.py b/jsonargparse_tests/test_from_config.py index 15c1b7df..b324c7ec 100644 --- a/jsonargparse_tests/test_from_config.py +++ b/jsonargparse_tests/test_from_config.py @@ -169,14 +169,17 @@ def __init__(self, param: str = "default_value"): def test_from_config_method_default(): from os import PathLike - from typing import Type, TypeVar, Union + from typing import Literal, Type, TypeVar, Union T = TypeVar("T") + default_config = {"param1": "method_default_value"} class FromConfigMethodDefault(FromConfigMixin): @classmethod - def from_config(cls: Type[T], config: Union[str, PathLike, dict] = {"param1": "method_default_value"}) -> T: + def from_config(cls: Type[T], config: Union[str, PathLike, dict, Literal["default"]] = "default") -> T: + if config == "default": + config = default_config return super().from_config(config) def __init__(self, param1: str = "default_value", param2: int = 1):