Skip to content

Commit 0a4de31

Browse files
authored
Support for Protocol types only accepting exact matching signature of public methods (#526)
1 parent 0525fd2 commit 0a4de31

File tree

5 files changed

+143
-10
lines changed

5 files changed

+143
-10
lines changed

CHANGELOG.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,12 @@ paths are considered internals and can change in minor and patch releases.
1515
v4.31.0 (2024-06-??)
1616
--------------------
1717

18+
Added
19+
^^^^^
20+
- Support for ``Protocol`` types only accepting exact matching signature of
21+
public methods (`#526
22+
<https://github.com/omni-us/jsonargparse/pull/526>`__).
23+
1824
Fixed
1925
^^^^^
2026
- Resolving of import paths for some ``torch`` functions not working (`#535

DOCUMENTATION.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -421,6 +421,10 @@ Some notes about this support are:
421421
:py:meth:`.ArgumentParser.instantiate_classes` can be used to instantiate all
422422
classes in a config object. For more details see :ref:`sub-classes`.
423423

424+
- ``Protocol`` types are also supported the same as sub-classes. The protocols
425+
are not required to be ``runtime_checkable``. But the accepted classes must
426+
match exactly the signature of the protocol's public methods.
427+
424428
- ``dataclasses`` are supported even when nested. Final classes, attrs'
425429
``define`` decorator, and pydantic's ``dataclass`` decorator and ``BaseModel``
426430
classes are supported and behave like standard dataclasses. For more details

jsonargparse/_postponed_annotations.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,7 @@ def evaluate_postponed_annotations(params, component, parent, logger):
334334
param.annotation = param_type
335335

336336

337-
def get_return_type(component, logger):
337+
def get_return_type(component, logger=None):
338338
return_type = inspect.signature(component).return_annotation
339339
if type_requires_eval(return_type):
340340
global_vars = vars(import_module(component.__module__))
@@ -343,6 +343,7 @@ def get_return_type(component, logger):
343343
if isinstance(return_type, ForwardRef):
344344
return_type = resolve_forward_refs(return_type.__forward_arg__, global_vars, logger)
345345
except Exception as ex:
346-
logger.debug(f"Unable to evaluate types for {component}", exc_info=ex)
346+
if logger:
347+
logger.debug(f"Unable to evaluate types for {component}", exc_info=ex)
347348
return None
348349
return return_type

jsonargparse/_typehints.py

Lines changed: 48 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -977,17 +977,17 @@ def adapt_typehints(
977977

978978
# Subclass
979979
elif not hasattr(typehint, "__origin__") and inspect.isclass(typehint):
980-
if isinstance(val, typehint):
980+
if is_instance_or_supports_protocol(val, typehint):
981981
if serialize:
982982
val = serialize_class_instance(val)
983983
return val
984984
if serialize and isinstance(val, str):
985985
return val
986986

987987
val_input = val
988-
if prev_val is None and not inspect.isabstract(typehint):
988+
if prev_val is None and not inspect.isabstract(typehint) and not is_protocol(typehint):
989989
with suppress(ValueError):
990-
prev_val = Namespace(class_path=get_import_path(typehint))
990+
prev_val = Namespace(class_path=get_import_path(typehint)) # implicit class_path
991991
val = subclass_spec_as_namespace(val, prev_val)
992992
if not is_subclass_spec(val):
993993
raise_unexpected_value(
@@ -1000,20 +1000,20 @@ def adapt_typehints(
10001000

10011001
try:
10021002
val_class = import_object(resolve_class_path_by_name(typehint, val["class_path"]))
1003-
if isinstance(val_class, typehint):
1004-
return val_class
1003+
if is_instance_or_supports_protocol(val_class, typehint):
1004+
return val_class # importable instance
10051005
not_subclass = False
1006-
if not is_subclass(val_class, typehint):
1006+
if not is_subclass_or_implements_protocol(val_class, typehint):
10071007
not_subclass = True
10081008
if not inspect.isclass(val_class) and callable(val_class):
10091009
from ._postponed_annotations import get_return_type
10101010

10111011
return_type = get_return_type(val_class, logger)
1012-
if is_subclass(return_type, typehint):
1012+
if is_subclass_or_implements_protocol(return_type, typehint):
10131013
not_subclass = False
10141014
if not_subclass:
10151015
raise_unexpected_value(
1016-
f'Import path {val["class_path"]} does not correspond to a subclass of {typehint}'
1016+
f"Import path {val['class_path']} does not correspond to a subclass of {typehint.__name__}"
10171017
)
10181018
val["class_path"] = get_import_path(val_class)
10191019
val = adapt_class_type(val, serialize, instantiate_classes, sub_add_kwargs, prev_val=prev_val)
@@ -1029,6 +1029,46 @@ def adapt_typehints(
10291029
return val
10301030

10311031

1032+
def implements_protocol(value, protocol) -> bool:
1033+
from jsonargparse._parameter_resolvers import get_signature_parameters
1034+
from jsonargparse._postponed_annotations import get_return_type
1035+
1036+
if not inspect.isclass(value):
1037+
return False
1038+
members = 0
1039+
for name, _ in inspect.getmembers(protocol, predicate=inspect.isfunction):
1040+
if name.startswith("_"):
1041+
continue
1042+
if not hasattr(value, name):
1043+
return False
1044+
members += 1
1045+
proto_params = get_signature_parameters(protocol, name)
1046+
value_params = get_signature_parameters(value, name)
1047+
if [(p.name, p.annotation) for p in proto_params] != [(p.name, p.annotation) for p in value_params]:
1048+
return False
1049+
proto_return = get_return_type(inspect.getattr_static(protocol, name))
1050+
value_return = get_return_type(inspect.getattr_static(value, name))
1051+
if proto_return != value_return:
1052+
return False
1053+
return True if members else False
1054+
1055+
1056+
def is_protocol(class_type) -> bool:
1057+
return getattr(class_type, "_is_protocol", False)
1058+
1059+
1060+
def is_subclass_or_implements_protocol(value, class_type) -> bool:
1061+
if is_protocol(class_type):
1062+
return implements_protocol(value, class_type)
1063+
return is_subclass(value, class_type)
1064+
1065+
1066+
def is_instance_or_supports_protocol(value, class_type):
1067+
if is_protocol(class_type):
1068+
return is_subclass_or_implements_protocol(value.__class__, class_type)
1069+
return isinstance(value, class_type)
1070+
1071+
10321072
def is_subclass_spec(val):
10331073
is_class = isinstance(val, (dict, Namespace)) and "class_path" in val
10341074
if is_class:

jsonargparse_tests/test_subclasses.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
Namespace,
2424
lazy_instance,
2525
)
26+
from jsonargparse._optionals import typing_extensions_import
27+
from jsonargparse._typehints import implements_protocol, is_instance_or_supports_protocol
2628
from jsonargparse.typing import final
2729
from jsonargparse_tests.conftest import (
2830
capture_logs,
@@ -32,6 +34,8 @@
3234
source_unavailable,
3335
)
3436

37+
Protocol = typing_extensions_import("Protocol")
38+
3539

3640
@pytest.mark.parametrize("type", [Calendar, Optional[Calendar]])
3741
def test_subclass_basics(parser, type):
@@ -1407,6 +1411,84 @@ def test_subclass_signature_instance_default(parser):
14071411
assert "cal: Unable to serialize instance <calendar.Calendar " in dump
14081412

14091413

1414+
# protocol tests
1415+
1416+
1417+
class Interface(Protocol): # type: ignore[valid-type,misc]
1418+
def predict(self, items: List[float]) -> List[float]: ... # type: ignore[empty-body]
1419+
1420+
1421+
class ImplementsInterface:
1422+
def __init__(self, batch_size: int):
1423+
self.batch_size = batch_size
1424+
1425+
def predict(self, items: List[float]) -> List[float]:
1426+
return items
1427+
1428+
1429+
class NotImplementsInterface1:
1430+
def predict(self, items: str) -> List[float]:
1431+
return []
1432+
1433+
1434+
class NotImplementsInterface2:
1435+
def predict(self, items: List[float], extra: int) -> List[float]:
1436+
return items
1437+
1438+
1439+
class NotImplementsInterface3:
1440+
def predict(self, items: List[float]) -> None:
1441+
return
1442+
1443+
1444+
@pytest.mark.parametrize(
1445+
"expected, value",
1446+
[
1447+
(True, ImplementsInterface),
1448+
(False, ImplementsInterface(1)),
1449+
(False, NotImplementsInterface1),
1450+
(False, NotImplementsInterface2),
1451+
(False, NotImplementsInterface3),
1452+
(False, object),
1453+
],
1454+
)
1455+
@pytest.mark.skipif(not Protocol, reason="Requires Python 3.8+ or typing_extensions")
1456+
def test_implements_protocol(expected, value):
1457+
assert implements_protocol(value, Interface) is expected
1458+
1459+
1460+
@pytest.mark.parametrize(
1461+
"expected, value",
1462+
[
1463+
(False, ImplementsInterface),
1464+
(True, ImplementsInterface(1)),
1465+
(False, NotImplementsInterface1()),
1466+
(False, object),
1467+
],
1468+
)
1469+
@pytest.mark.skipif(not Protocol, reason="Requires Python 3.8+ or typing_extensions")
1470+
def test_is_instance_or_supports_protocol(expected, value):
1471+
assert is_instance_or_supports_protocol(value, Interface) is expected
1472+
1473+
1474+
@pytest.mark.skipif(not Protocol, reason="Requires Python 3.8+ or typing_extensions")
1475+
def test_parse_implements_protocol(parser):
1476+
parser.add_argument("--cls", type=Interface)
1477+
cfg = parser.parse_args([f"--cls={__name__}.ImplementsInterface", "--cls.batch_size=5"])
1478+
assert cfg.cls.class_path == f"{__name__}.ImplementsInterface"
1479+
assert cfg.cls.init_args == Namespace(batch_size=5)
1480+
init = parser.instantiate_classes(cfg)
1481+
assert isinstance(init.cls, ImplementsInterface)
1482+
assert init.cls.batch_size == 5
1483+
assert init.cls.predict([1.0, 2.0]) == [1.0, 2.0]
1484+
with pytest.raises(ArgumentError) as ctx:
1485+
parser.parse_args([f"--cls={__name__}.NotImplementsInterface1"])
1486+
ctx.match("does not correspond to a subclass of")
1487+
with pytest.raises(ArgumentError) as ctx:
1488+
parser.parse_args(['--cls={"batch_size": 5}'])
1489+
ctx.match("Not a valid subclass of Interface")
1490+
1491+
14101492
# parameter skip tests
14111493

14121494

0 commit comments

Comments
 (0)