Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@ paths are considered internals and can change in minor and patch releases.
v4.43.0 (unreleased)
--------------------

Added
^^^^^
- Non-public experimental mixin class to override init defaults and instantiate
from config (`#800 <https://github.com/omni-us/jsonargparse/pull/800>`__).

Fixed
^^^^^
- Prevent extra environment variables in helptext when default_env=True, for
Expand Down
1 change: 1 addition & 0 deletions jsonargparse/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from ._core import * # noqa: F403
from ._deprecated import * # noqa: F403
from ._formatters import * # noqa: F403
from ._from_config import * # noqa: F403
from ._jsonnet import * # noqa: F403
from ._jsonschema import * # noqa: F403
from ._link_arguments import * # noqa: F403
Expand Down
142 changes: 142 additions & 0 deletions jsonargparse/_from_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
import inspect
from functools import wraps
from os import PathLike
from pathlib import Path
from typing import Optional, Type, TypeVar, Union

from ._core import ArgumentParser

__all__ = ["FromConfigMixin"]

T = TypeVar("T")


class FromConfigMixin:
"""Mixin class that adds from config support to classes.

This mixin does two things:

1. Adds support for overriding ``__init__`` defaults by defining a
``__from_config_init_defaults__`` class attribute pointing to a
config file path. The overriding of defaults happens on subclass
creation time. Inspecting the signature will give the overridden
defaults.

2. Adds a ``from_config`` ``@classmethod``, that instantiates the class
based on a config file or dict.

Attributes:
``__from_config_init_defaults__``: Optional path to a config file for
overriding ``__init__`` defaults.
``__from_config_parser_kwargs__``: Additional kwargs to pass to the
ArgumentParser used for parsing configs.
"""

__from_config_init_defaults__: Optional[Union[str, PathLike]] = None
__from_config_parser_kwargs__: dict = {}

def __init_subclass__(cls, **kwargs) -> None:
"""Override ``__init__`` defaults for the subclass based on a config file."""
super().__init_subclass__(**kwargs)
_override_init_defaults(cls, cls.__from_config_parser_kwargs__)

@classmethod
def from_config(cls: Type[T], config: Union[str, PathLike, dict]) -> T:
"""Instantiate current class based on a config file or dict.

Args:
config: Path to a config file or a dict with config values.
"""
kwargs = _parse_class_kwargs_from_config(cls, config, **cls.__from_config_parser_kwargs__) # type: ignore[attr-defined]
return cls(**kwargs)


def _parse_class_kwargs_from_config(cls: Type[T], config: Union[str, PathLike, dict], **kwargs) -> dict:
"""Parse the init kwargs for `cls` from a config file or dict."""
parser = ArgumentParser(exit_on_error=False, **kwargs)
parser.add_class_arguments(cls)
if isinstance(config, dict):
cfg = parser.parse_object(config, defaults=False)
else:
cfg = parser.parse_path(config, defaults=False)
return parser.instantiate_classes(cfg).as_dict()


def _override_init_defaults(cls: Type[T], parser_kwargs: dict) -> None:
"""Override ``__init__`` defaults for ``cls`` based on ``__from_config_init_defaults__``."""
config = getattr(cls, "__from_config_init_defaults__", None)
if not isinstance(config, (str, PathLike, type(None))):
raise TypeError("__from_config_init_defaults__ must be str, PathLike, or None")
if not (isinstance(config, (str, PathLike)) and Path(config).is_file()):
return

defaults = _parse_class_kwargs_from_config(cls, config, **parser_kwargs)
_override_init_defaults_this_class(cls, defaults)
_override_init_defaults_parent_classes(cls, defaults)


def _override_init_defaults_this_class(cls: Type[T], defaults: dict) -> None:
params = inspect.signature(cls.__init__).parameters
for name, default in defaults.copy().items():
param = params.get(name)
if param and param.kind in {inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD}:
defaults.pop(name)
if param.kind == inspect.Parameter.KEYWORD_ONLY:
cls.__init__.__kwdefaults__[name] = default # type: ignore[index]
else:
index = list(params).index(name) - 1
aux = cls.__init__.__defaults__ or ()
cls.__init__.__defaults__ = aux[:index] + (default,) + aux[index + 1 :]


def _override_init_defaults_parent_classes(cls: Type[T], defaults: dict) -> None:
# Gather defaults for parameters in parent classes' __init__
override_parent_params = []
for base in inspect.getmro(cls)[1:]:
if not defaults:
break

params = inspect.signature(base.__init__).parameters # type: ignore[misc]
names = [name for name in defaults if name in params]
for name in names:
new_param = inspect.Parameter(
name=name,
kind=inspect.Parameter.KEYWORD_ONLY,
default=defaults.pop(name),
annotation=params[name].annotation,
)
override_parent_params.append(new_param)

if not override_parent_params:
return

# Override defaults for parameters in parent classes' __init__ via a wrapper
original_init = cls.__init__
original_sig = inspect.signature(cls.__init__)
parameters = list(original_sig.parameters.values())

# Find and pop the **kwargs parameter, if it exists
kwargs_param = None
if parameters and parameters[-1].kind == inspect.Parameter.VAR_KEYWORD:
kwargs_param = parameters.pop()

# Add new parameters
for param in reversed(override_parent_params):
parameters.append(param)

# Add **kwargs back at the end
if kwargs_param:
parameters.append(kwargs_param)

# Create and set __init__ wrapper with new signature
parent_defaults = {p.name: p.default for p in override_parent_params}

@wraps(original_init)
def wrapper(*args, **kwargs):
for name, default in parent_defaults.items():
if name not in kwargs:
kwargs[name] = default
return original_init(*args, **kwargs)

wrapper.__signature__ = original_sig.replace(parameters=parameters) # type: ignore[attr-defined]
cls.__init__ = wrapper # type: ignore[method-assign]
5 changes: 1 addition & 4 deletions jsonargparse/_optionals.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,12 +282,9 @@ def get_omegaconf_loader(mode):
if mode == "omegaconf+":
from ._common import get_parsing_setting

if not get_parsing_setting("omegaconf_absolute_to_relative_paths"):
return yaml_load

def omegaconf_plus_load(value):
value = yaml_load(value)
if isinstance(value, dict):
if isinstance(value, dict) and get_parsing_setting("omegaconf_absolute_to_relative_paths"):
value = omegaconf_absolute_to_relative_paths(value)
return value

Expand Down
6 changes: 6 additions & 0 deletions jsonargparse_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,12 @@
reason="responses package is required",
)

skip_if_omegaconf_unavailable = pytest.mark.skipif(
not omegaconf_support,
reason="omegaconf package is required",
)


skip_if_running_as_root = pytest.mark.skipif(
is_posix and os.geteuid() == 0,
reason="User is root, permission tests will not work",
Expand Down
Loading