Skip to content

Commit 890cd45

Browse files
authored
Fix: List of union of classes not accepted by add_subclass_arguments in python>=3.11 (#522)
1 parent 1dd2664 commit 890cd45

File tree

6 files changed

+75
-15
lines changed

6 files changed

+75
-15
lines changed

CHANGELOG.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ Fixed
3434
(`#523 <https://github.com/omni-us/jsonargparse/pull/523>`__).
3535
- Failing to parse list of dataclasses with nested optional dataclass (`#527
3636
<https://github.com/omni-us/jsonargparse/pull/527>`__).
37+
- List of union of classes not accepted by ``add_subclass_arguments`` in
38+
``python>=3.11`` (`#522
39+
<https://github.com/omni-us/jsonargparse/pull/522>`__).
3740

3841

3942
v4.29.0 (2024-05-24)

jsonargparse/_actions.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -349,14 +349,9 @@ def __init__(self, baseclass=None, **kwargs):
349349
super().__init__(**kwargs)
350350

351351
def update_init_kwargs(self, kwargs):
352-
if get_typehint_origin(self._baseclass) == Union:
353-
from ._typehints import ActionTypeHint
352+
from ._typehints import get_subclasses_from_type
354353

355-
self._basename = iter_to_set_str(
356-
c.__name__ for c in self._baseclass.__args__ if ActionTypeHint.is_subclass_typehint(c)
357-
)
358-
else:
359-
self._basename = self._baseclass.__name__
354+
self._basename = iter_to_set_str(get_subclasses_from_type(self._baseclass))
360355
kwargs.update(
361356
{
362357
"metavar": "CLASS_PATH_OR_NAME",

jsonargparse/_parameter_resolvers.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from ._util import (
2929
ClassFromFunctionBase,
3030
get_import_path,
31+
get_typehint_origin,
3132
iter_to_set_str,
3233
unique,
3334
)
@@ -74,6 +75,10 @@ def __init__(self, resolver: str, data: Any) -> None:
7475

7576

7677
def get_parameter_origins(component, parent) -> Optional[str]:
78+
from ._typehints import get_subclasses_from_type, sequence_origin_types
79+
80+
if get_typehint_origin(component) in sequence_origin_types:
81+
component = get_subclasses_from_type(component, names=False)
7782
if isinstance(component, tuple):
7883
assert parent is None or len(component) == len(parent)
7984
return iter_to_set_str(get_parameter_origins(c, parent[n] if parent else None) for n, c in enumerate(component))

jsonargparse/_signatures.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,10 @@
2626
ActionTypeHint,
2727
LazyInitBaseClass,
2828
callable_instances,
29+
get_subclasses_from_type,
2930
is_optional,
3031
)
31-
from ._util import get_import_path, get_private_kwargs, iter_to_set_str
32+
from ._util import get_private_kwargs, iter_to_set_str
3233
from .typing import register_pydantic_type
3334

3435
__all__ = [
@@ -510,8 +511,8 @@ def add_subclass_arguments(
510511
raise ValueError("Not allowed for dataclass-like classes.")
511512
if type(baseclass) is not tuple:
512513
baseclass = (baseclass,) # type: ignore[assignment]
513-
if not baseclass or not all(inspect.isclass(c) for c in baseclass):
514-
raise ValueError(f"Expected 'baseclass' argument to be a class or a tuple of classes: {baseclass}")
514+
if not baseclass or not all(ActionTypeHint.is_subclass_typehint(c, also_lists=True) for c in baseclass):
515+
raise ValueError(f"Expected 'baseclass' to be a subclass type or a tuple of subclass types: {baseclass}")
515516

516517
doc_group = None
517518
if len(baseclass) == 1: # type: ignore[arg-type]
@@ -530,7 +531,7 @@ def add_subclass_arguments(
530531
if skip is not None:
531532
skip = {f"{nested_key}.init_args." + s for s in skip}
532533
param = ParamData(name=nested_key, annotation=Union[baseclass], component=baseclass)
533-
str_baseclass = iter_to_set_str(get_import_path(x) for x in baseclass)
534+
str_baseclass = iter_to_set_str(get_subclasses_from_type(param.annotation))
534535
kwargs.update(
535536
{
536537
"metavar": metavar,

jsonargparse/_typehints.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -668,6 +668,22 @@ def is_pathlike(typehint) -> bool:
668668
return is_subclass(typehint, os.PathLike)
669669

670670

671+
def get_subclasses_from_type(typehint, names=True, subclasses=None) -> tuple:
672+
if subclasses is None:
673+
subclasses = []
674+
origin = get_typehint_origin(typehint)
675+
if origin == Union or origin in sequence_origin_types:
676+
for subtype in typehint.__args__:
677+
get_subclasses_from_type(subtype, names, subclasses)
678+
elif ActionTypeHint.is_subclass_typehint(typehint, all_subtypes=False):
679+
if names:
680+
if typehint.__name__ not in subclasses:
681+
subclasses.append(typehint.__name__)
682+
elif typehint not in subclasses:
683+
subclasses.append(typehint)
684+
return tuple(subclasses)
685+
686+
671687
def raise_unexpected_value(message: str, val: Any = inspect._empty, exception: Optional[Exception] = None) -> NoReturn:
672688
if val is not inspect._empty:
673689
message += f". Got value: {val}"
@@ -1105,14 +1121,20 @@ def is_private(class_path):
11051121
return "._" in class_path
11061122

11071123
def add_subclasses(cl):
1124+
if hasattr(cl, "__args__") and get_typehint_origin(cl) in {List, list, Union}:
1125+
for arg in cl.__args__:
1126+
add_subclasses(arg)
1127+
return
11081128
try:
11091129
class_path = get_import_path(cl)
11101130
except (ImportError, AttributeError) as err: # Attribute is added in case of dot notation imports
11111131
warning(f"Hit failing import with following error: {err}")
11121132
return
1113-
if is_local(cl) or issubclass(cl, LazyInitBaseClass):
1133+
if is_local(cl) or is_subclass(cl, LazyInitBaseClass):
11141134
return
11151135
if not (inspect.isabstract(cl) or is_private(class_path)):
1136+
if class_path in subclass_list:
1137+
return
11161138
subclass_list.append(class_path)
11171139
for subclass in cl.__subclasses__() if hasattr(cl, "__subclasses__") else []:
11181140
add_subclasses(subclass)
@@ -1124,7 +1146,7 @@ def add_subclasses(cl):
11241146

11251147
if get_typehint_origin(cls) in {Union, Type, type}:
11261148
for arg in cls.__args__:
1127-
if ActionTypeHint.is_subclass_typehint(arg) and arg not in {object, type}:
1149+
if ActionTypeHint.is_subclass_typehint(arg, also_lists=True) and arg not in {object, type}:
11281150
add_subclasses(arg)
11291151
else:
11301152
add_subclasses(cls)

jsonargparse_tests/test_subclasses.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1267,13 +1267,13 @@ def test_subclass_unresolved_parameters_name_clash(parser):
12671267
def test_add_subclass_failure_not_a_class(parser):
12681268
with pytest.raises(ValueError) as ctx:
12691269
parser.add_subclass_arguments(NAMESPACE_OID, "oid")
1270-
ctx.match("Expected 'baseclass' argument to be a class or a tuple of classes")
1270+
ctx.match("Expected 'baseclass' to be a subclass type or a tuple of subclass types")
12711271

12721272

12731273
def test_add_subclass_failure_empty_tuple(parser):
12741274
with pytest.raises(ValueError) as ctx:
12751275
parser.add_subclass_arguments((), "cls")
1276-
ctx.match("Expected 'baseclass' argument to be a class or a tuple of classes")
1276+
ctx.match("Expected 'baseclass' to be a subclass type or a tuple of subclass types")
12771277

12781278

12791279
def test_add_subclass_lazy_default(parser):
@@ -1337,6 +1337,40 @@ def test_add_subclass_not_required_group(parser):
13371337
assert init == Namespace()
13381338

13391339

1340+
class ListUnionA:
1341+
def __init__(self, pa1: int):
1342+
self.pa1 = pa1
1343+
1344+
1345+
class ListUnionB:
1346+
def __init__(self, pb1: str, pb2: float):
1347+
self.pb1 = pb1
1348+
self.pb2 = pb2
1349+
1350+
1351+
def test_add_subclass_list_of_union(parser):
1352+
parser.add_argument("--config", action="config")
1353+
parser.add_subclass_arguments(
1354+
baseclass=(ListUnionA, ListUnionB, List[Union[ListUnionA, ListUnionB]]),
1355+
nested_key="subclass",
1356+
)
1357+
config = {
1358+
"subclass": [
1359+
{
1360+
"class_path": f"{__name__}.ListUnionB",
1361+
"init_args": {
1362+
"pb1": "x",
1363+
"pb2": 0.5,
1364+
},
1365+
}
1366+
]
1367+
}
1368+
cfg = parser.parse_args([f"--config={config}"])
1369+
assert cfg.as_dict()["subclass"] == config["subclass"]
1370+
help_str = get_parser_help(parser)
1371+
assert "Show the help for the given subclass of {ListUnionA,ListUnionB}" in help_str
1372+
1373+
13401374
# instance defaults tests
13411375

13421376

0 commit comments

Comments
 (0)