Skip to content

Commit 695b5a5

Browse files
authored
added "async" support for set_handlers (#176)
1 parent c74798e commit 695b5a5

File tree

5 files changed

+105
-36
lines changed

5 files changed

+105
-36
lines changed

.github/workflows/analysis-coverage.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -725,7 +725,7 @@ jobs:
725725
run: |
726726
php occ app:enable app_api
727727
cd nc_py_api
728-
coverage run --data-file=.coverage.ci_install tests/_install.py &
728+
coverage run --data-file=.coverage.ci_install tests/_install_async.py &
729729
echo $! > /tmp/_install.pid
730730
python3 tests/_install_wait.py http://127.0.0.1:$APP_PORT/heartbeat "\"status\":\"ok\"" 15 0.5
731731
python3 tests/_app_security_checks.py http://127.0.0.1:$APP_PORT

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,12 @@
22

33
All notable changes to this project will be documented in this file.
44

5+
## [0.6.1 - 202x-xx-xx]
6+
7+
### Added
8+
9+
- set_handlers: `enabled_handler`, `heartbeat_handler` now can be async(Coroutines). #175
10+
511
## [0.6.0 - 2023-12-06]
612

713
### Added

nc_py_api/ex_app/integration_fastapi.py

Lines changed: 50 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,8 @@ def talk_bot_app(request: Request) -> TalkBotMessage:
5454

5555
def set_handlers(
5656
fast_api_app: FastAPI,
57-
enabled_handler: typing.Callable[[bool, NextcloudApp], str],
58-
heartbeat_handler: typing.Optional[typing.Callable[[], str]] = None,
57+
enabled_handler: typing.Callable[[bool, NextcloudApp], typing.Union[str, typing.Awaitable[str]]],
58+
heartbeat_handler: typing.Optional[typing.Callable[[], typing.Union[str, typing.Awaitable[str]]]] = None,
5959
init_handler: typing.Optional[typing.Callable[[NextcloudApp], None]] = None,
6060
models_to_fetch: typing.Optional[list[str]] = None,
6161
models_download_params: typing.Optional[dict] = None,
@@ -81,50 +81,40 @@ def set_handlers(
8181
.. note:: First, presence of these directories in the current working dir is checked, then one directory higher.
8282
"""
8383

84-
def fetch_models_task(nc: NextcloudApp, models: list[str]) -> None:
85-
if models:
86-
from huggingface_hub import snapshot_download # noqa isort:skip pylint: disable=C0415 disable=E0401
87-
from tqdm import tqdm # noqa isort:skip pylint: disable=C0415 disable=E0401
88-
89-
class TqdmProgress(tqdm):
90-
def display(self, msg=None, pos=None):
91-
if init_handler is None:
92-
nc.set_init_status(min(int((self.n * 100 / self.total) / len(models)), 100))
93-
return super().display(msg, pos)
94-
95-
params = models_download_params if models_download_params else {}
96-
if "max_workers" not in params:
97-
params["max_workers"] = 2
98-
if "cache_dir" not in params:
99-
params["cache_dir"] = persistent_storage()
100-
for model in models:
101-
snapshot_download(model, tqdm_class=TqdmProgress, **params) # noqa
102-
if init_handler is None:
103-
nc.set_init_status(100)
104-
else:
105-
init_handler(nc)
106-
10784
@fast_api_app.put("/enabled")
108-
def enabled_callback(
85+
async def enabled_callback(
10986
enabled: bool,
11087
nc: typing.Annotated[NextcloudApp, Depends(nc_app)],
11188
):
112-
r = enabled_handler(enabled, nc)
89+
if asyncio.iscoroutinefunction(heartbeat_handler):
90+
r = await enabled_handler(enabled, nc) # type: ignore
91+
else:
92+
r = enabled_handler(enabled, nc)
11393
return responses.JSONResponse(content={"error": r}, status_code=200)
11494

11595
@fast_api_app.get("/heartbeat")
116-
def heartbeat_callback():
117-
return_status = "ok"
96+
async def heartbeat_callback():
11897
if heartbeat_handler is not None:
119-
return_status = heartbeat_handler()
98+
if asyncio.iscoroutinefunction(heartbeat_handler):
99+
return_status = await heartbeat_handler()
100+
else:
101+
return_status = heartbeat_handler()
102+
else:
103+
return_status = "ok"
120104
return responses.JSONResponse(content={"status": return_status}, status_code=200)
121105

122106
@fast_api_app.post("/init")
123-
def init_callback(
107+
async def init_callback(
124108
background_tasks: BackgroundTasks,
125109
nc: typing.Annotated[NextcloudApp, Depends(nc_app)],
126110
):
127-
background_tasks.add_task(fetch_models_task, nc, models_to_fetch if models_to_fetch else [])
111+
background_tasks.add_task(
112+
__fetch_models_task,
113+
nc,
114+
init_handler,
115+
models_to_fetch if models_to_fetch else [],
116+
models_download_params if models_download_params else {},
117+
)
128118
return responses.JSONResponse(content={}, status_code=200)
129119

130120
if map_app_static:
@@ -139,3 +129,31 @@ def __map_app_static_folders(fast_api_app: FastAPI):
139129
mnt_dir_path = os.path.join(os.path.dirname(os.getcwd()), mnt_dir)
140130
if os.path.exists(mnt_dir_path):
141131
fast_api_app.mount(f"/{mnt_dir}", staticfiles.StaticFiles(directory=mnt_dir_path), name=mnt_dir)
132+
133+
134+
def __fetch_models_task(
135+
nc: NextcloudApp,
136+
init_handler: typing.Optional[typing.Callable[[NextcloudApp], None]],
137+
models: list[str],
138+
params: dict[str, typing.Any],
139+
) -> None:
140+
if models:
141+
from huggingface_hub import snapshot_download # noqa isort:skip pylint: disable=C0415 disable=E0401
142+
from tqdm import tqdm # noqa isort:skip pylint: disable=C0415 disable=E0401
143+
144+
class TqdmProgress(tqdm):
145+
def display(self, msg=None, pos=None):
146+
if init_handler is None:
147+
nc.set_init_status(min(int((self.n * 100 / self.total) / len(models)), 100))
148+
return super().display(msg, pos)
149+
150+
if "max_workers" not in params:
151+
params["max_workers"] = 2
152+
if "cache_dir" not in params:
153+
params["cache_dir"] = persistent_storage()
154+
for model in models:
155+
snapshot_download(model, tqdm_class=TqdmProgress, **params) # noqa
156+
if init_handler is None:
157+
nc.set_init_status(100)
158+
else:
159+
init_handler(nc)

tests/_install.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,9 @@ async def lifespan(_app: FastAPI):
1919
@APP.put("/sec_check")
2020
def sec_check(
2121
value: int,
22-
nc: Annotated[NextcloudApp, Depends(ex_app.nc_app)],
22+
_nc: Annotated[NextcloudApp, Depends(ex_app.nc_app)],
2323
):
24-
print(value)
25-
_ = nc
24+
print(value, flush=True)
2625
return JSONResponse(content={"error": ""}, status_code=200)
2726

2827

tests/_install_async.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
from contextlib import asynccontextmanager
2+
from typing import Annotated
3+
4+
from fastapi import Depends, FastAPI
5+
from fastapi.responses import JSONResponse
6+
7+
from nc_py_api import NextcloudApp, ex_app
8+
9+
10+
@asynccontextmanager
11+
async def lifespan(_app: FastAPI):
12+
ex_app.set_handlers(APP, enabled_handler, heartbeat_callback, init_handler=init_handler)
13+
yield
14+
15+
16+
APP = FastAPI(lifespan=lifespan)
17+
18+
19+
@APP.put("/sec_check")
20+
async def sec_check(
21+
value: int,
22+
_nc: Annotated[NextcloudApp, Depends(ex_app.nc_app)],
23+
):
24+
print(value, flush=True)
25+
return JSONResponse(content={"error": ""}, status_code=200)
26+
27+
28+
async def enabled_handler(enabled: bool, nc: NextcloudApp) -> str:
29+
print(f"enabled_handler: enabled={enabled}", flush=True)
30+
if enabled:
31+
nc.log(ex_app.LogLvl.WARNING, f"Hello from {nc.app_cfg.app_name} :)")
32+
else:
33+
nc.log(ex_app.LogLvl.WARNING, f"Bye bye from {nc.app_cfg.app_name} :(")
34+
return ""
35+
36+
37+
def init_handler(nc: NextcloudApp):
38+
nc.set_init_status(100)
39+
40+
41+
async def heartbeat_callback():
42+
return "ok"
43+
44+
45+
if __name__ == "__main__":
46+
ex_app.run_app("_install_async:APP", log_level="trace")

0 commit comments

Comments
 (0)