Skip to content

Commit 1158f1b

Browse files
authored
Support async functions and methods in CLI (#531, #517)
1 parent ac4ce0b commit 1158f1b

File tree

3 files changed

+56
-11
lines changed

3 files changed

+56
-11
lines changed

CHANGELOG.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ v4.31.0 (2024-06-??)
1717

1818
Added
1919
^^^^^
20+
- Support async functions and methods in ``CLI`` (`#531
21+
<https://github.com/omni-us/jsonargparse/pull/531>`__).
2022
- Support for ``Protocol`` types only accepting exact matching signature of
2123
public methods (`#526
2224
<https://github.com/omni-us/jsonargparse/pull/526>`__).

jsonargparse/_cli.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,8 @@ def CLI(
9292
deprecation_warning_cli_return_parser()
9393
return parser
9494
cfg = parser.parse_args(args)
95-
cfg_init = parser.instantiate_classes(cfg)
96-
return _run_component(components, cfg_init)
95+
init = parser.instantiate_classes(cfg)
96+
return _run_component(components, init)
9797

9898
elif isinstance(components, list):
9999
components = {c.__name__: c for c in components}
@@ -192,12 +192,13 @@ def _add_component_to_parser(
192192

193193
def _run_component(component, cfg):
194194
cfg.pop("config", None)
195-
if not inspect.isclass(component):
196-
return component(**cfg)
197195
subcommand = cfg.pop("subcommand")
198-
if not subcommand:
199-
return component(**cfg)
200-
subcommand_cfg = cfg.pop(subcommand, {})
201-
subcommand_cfg.pop("config", None)
202-
component_obj = component(**cfg)
203-
return getattr(component_obj, subcommand)(**subcommand_cfg)
196+
if inspect.isclass(component) and subcommand:
197+
subcommand_cfg = cfg.pop(subcommand, {})
198+
subcommand_cfg.pop("config", None)
199+
component_obj = component(**cfg)
200+
component = getattr(component_obj, subcommand)
201+
cfg = subcommand_cfg
202+
if inspect.iscoroutinefunction(component):
203+
return __import__("asyncio").run(component(**cfg))
204+
return component(**cfg)

jsonargparse_tests/test_cli.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
from __future__ import annotations
22

3+
import asyncio
34
import os
45
import sys
56
from contextlib import redirect_stderr, redirect_stdout, suppress
67
from dataclasses import asdict, dataclass
78
from io import StringIO
89
from pathlib import Path
9-
from typing import Optional
10+
from typing import Callable, Optional
1011
from unittest.mock import patch
1112

1213
import pytest
@@ -535,3 +536,44 @@ def test_final_and_subclass_type_config_file(tmp_cwd):
535536

536537
out = CLI(run_bf, args=["--config=config.yaml"])
537538
assert "a yaml" == out
539+
540+
541+
# async tests
542+
543+
544+
async def run_async(time: float = 0.1):
545+
await asyncio.sleep(time)
546+
return "done"
547+
548+
549+
def test_async_function():
550+
assert "done" == CLI(run_async, args=["--time=0.0"])
551+
552+
553+
class AsyncMethod:
554+
def __init__(self, time: float = 0.1, require_async: bool = False):
555+
self.time = time
556+
if require_async:
557+
self.loop = asyncio.get_event_loop()
558+
559+
async def run(self):
560+
await asyncio.sleep(self.time)
561+
return "done"
562+
563+
564+
def test_async_method():
565+
assert "done" == CLI(AsyncMethod, args=["--time=0.0", "run"])
566+
567+
568+
async def run_async_instance(cls: Callable[[], AsyncMethod]):
569+
return await cls().run()
570+
571+
572+
def test_async_instance():
573+
config = {
574+
"cls": {
575+
"class_path": f"{__name__}.AsyncMethod",
576+
"init_args": {"time": 0.0, "require_async": True},
577+
}
578+
}
579+
assert "done" == CLI(run_async_instance, args=[f"--config={config}"])

0 commit comments

Comments
 (0)