Skip to content

Commit 82273f9

Browse files
authored
Fix: Resolving of import paths for some torch functions not working (#535)
1 parent 2bcbd48 commit 82273f9

File tree

3 files changed

+29
-4
lines changed

3 files changed

+29
-4
lines changed

CHANGELOG.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,11 @@ paths are considered internals and can change in minor and patch releases.
1515
v4.31.0 (2024-06-??)
1616
--------------------
1717

18+
Fixed
19+
^^^^^
20+
- Resolving of import paths for some ``torch`` functions not working (`#535
21+
<https://github.com/omni-us/jsonargparse/pull/535>`__).
22+
1823
Changed
1924
^^^^^^^
2025
- Now ``--*.help`` output shows options without ``init_args`` (`#533

jsonargparse/_util.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -233,15 +233,18 @@ def get_import_path(value: Any) -> Optional[str]:
233233
if not path:
234234
raise ValueError(f"Not possible to determine the import path for object {value}.")
235235

236-
if qualname and module_path and "." in module_path:
236+
if qualname and module_path and ("." in qualname or "." in module_path):
237237
module_parts = module_path.split(".")
238238
for num in range(len(module_parts)):
239239
module_path = ".".join(module_parts[: num + 1])
240240
module = import_module(module_path)
241241
if "." in qualname:
242242
obj_name, attr = qualname.rsplit(".", 1)
243243
obj = getattr(module, obj_name, None)
244-
if getattr(obj, attr, None) is value:
244+
if getattr(module, attr, None) is value:
245+
path = module_path + "." + attr
246+
break
247+
elif getattr(obj, attr, None) is value:
245248
path = module_path + "." + qualname
246249
break
247250
elif getattr(module, qualname, None) is value:

jsonargparse_tests/test_util.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -533,8 +533,12 @@ def test_logger_jsonargparse_debug():
533533

534534

535535
def test_import_object_invalid():
536-
pytest.raises(ValueError, lambda: import_object(True))
537-
pytest.raises(ValueError, lambda: import_object("jsonargparse-tests.os"))
536+
with pytest.raises(ValueError) as ctx:
537+
import_object(True)
538+
ctx.match("Expected a dot import path string")
539+
with pytest.raises(ValueError) as ctx:
540+
import_object("jsonargparse-tests.os")
541+
ctx.match("Unexpected import path format")
538542

539543

540544
def test_get_import_path():
@@ -548,6 +552,19 @@ def test_get_import_path():
548552
assert get_import_path(MISSING) == "dataclasses.MISSING"
549553

550554

555+
class _StaticMethods:
556+
@staticmethod
557+
def static_method():
558+
pass
559+
560+
561+
static_method = _StaticMethods.static_method
562+
563+
564+
def test_get_import_path_static_method_shorthand():
565+
assert get_import_path(static_method) == f"{__name__}.static_method"
566+
567+
551568
def unresolvable_import():
552569
pass
553570

0 commit comments

Comments
 (0)