@@ -54,8 +54,8 @@ def talk_bot_app(request: Request) -> TalkBotMessage:
5454
5555def set_handlers (
5656 fast_api_app : FastAPI ,
57- enabled_handler : typing .Callable [[bool , NextcloudApp ], str ],
58- heartbeat_handler : typing .Optional [typing .Callable [[], str ]] = None ,
57+ enabled_handler : typing .Callable [[bool , NextcloudApp ], typing . Union [ str , typing . Awaitable [ str ]] ],
58+ heartbeat_handler : typing .Optional [typing .Callable [[], typing . Union [ str , typing . Awaitable [ str ]] ]] = None ,
5959 init_handler : typing .Optional [typing .Callable [[NextcloudApp ], None ]] = None ,
6060 models_to_fetch : typing .Optional [list [str ]] = None ,
6161 models_download_params : typing .Optional [dict ] = None ,
@@ -81,50 +81,40 @@ def set_handlers(
8181 .. note:: First, presence of these directories in the current working dir is checked, then one directory higher.
8282 """
8383
84- def fetch_models_task (nc : NextcloudApp , models : list [str ]) -> None :
85- if models :
86- from huggingface_hub import snapshot_download # noqa isort:skip pylint: disable=C0415 disable=E0401
87- from tqdm import tqdm # noqa isort:skip pylint: disable=C0415 disable=E0401
88-
89- class TqdmProgress (tqdm ):
90- def display (self , msg = None , pos = None ):
91- if init_handler is None :
92- nc .set_init_status (min (int ((self .n * 100 / self .total ) / len (models )), 100 ))
93- return super ().display (msg , pos )
94-
95- params = models_download_params if models_download_params else {}
96- if "max_workers" not in params :
97- params ["max_workers" ] = 2
98- if "cache_dir" not in params :
99- params ["cache_dir" ] = persistent_storage ()
100- for model in models :
101- snapshot_download (model , tqdm_class = TqdmProgress , ** params ) # noqa
102- if init_handler is None :
103- nc .set_init_status (100 )
104- else :
105- init_handler (nc )
106-
10784 @fast_api_app .put ("/enabled" )
108- def enabled_callback (
85+ async def enabled_callback (
10986 enabled : bool ,
11087 nc : typing .Annotated [NextcloudApp , Depends (nc_app )],
11188 ):
112- r = enabled_handler (enabled , nc )
89+ if asyncio .iscoroutinefunction (heartbeat_handler ):
90+ r = await enabled_handler (enabled , nc ) # type: ignore
91+ else :
92+ r = enabled_handler (enabled , nc )
11393 return responses .JSONResponse (content = {"error" : r }, status_code = 200 )
11494
11595 @fast_api_app .get ("/heartbeat" )
116- def heartbeat_callback ():
117- return_status = "ok"
96+ async def heartbeat_callback ():
11897 if heartbeat_handler is not None :
119- return_status = heartbeat_handler ()
98+ if asyncio .iscoroutinefunction (heartbeat_handler ):
99+ return_status = await heartbeat_handler ()
100+ else :
101+ return_status = heartbeat_handler ()
102+ else :
103+ return_status = "ok"
120104 return responses .JSONResponse (content = {"status" : return_status }, status_code = 200 )
121105
122106 @fast_api_app .post ("/init" )
123- def init_callback (
107+ async def init_callback (
124108 background_tasks : BackgroundTasks ,
125109 nc : typing .Annotated [NextcloudApp , Depends (nc_app )],
126110 ):
127- background_tasks .add_task (fetch_models_task , nc , models_to_fetch if models_to_fetch else [])
111+ background_tasks .add_task (
112+ __fetch_models_task ,
113+ nc ,
114+ init_handler ,
115+ models_to_fetch if models_to_fetch else [],
116+ models_download_params if models_download_params else {},
117+ )
128118 return responses .JSONResponse (content = {}, status_code = 200 )
129119
130120 if map_app_static :
@@ -139,3 +129,31 @@ def __map_app_static_folders(fast_api_app: FastAPI):
139129 mnt_dir_path = os .path .join (os .path .dirname (os .getcwd ()), mnt_dir )
140130 if os .path .exists (mnt_dir_path ):
141131 fast_api_app .mount (f"/{ mnt_dir } " , staticfiles .StaticFiles (directory = mnt_dir_path ), name = mnt_dir )
132+
133+
134+ def __fetch_models_task (
135+ nc : NextcloudApp ,
136+ init_handler : typing .Optional [typing .Callable [[NextcloudApp ], None ]],
137+ models : list [str ],
138+ params : dict [str , typing .Any ],
139+ ) -> None :
140+ if models :
141+ from huggingface_hub import snapshot_download # noqa isort:skip pylint: disable=C0415 disable=E0401
142+ from tqdm import tqdm # noqa isort:skip pylint: disable=C0415 disable=E0401
143+
144+ class TqdmProgress (tqdm ):
145+ def display (self , msg = None , pos = None ):
146+ if init_handler is None :
147+ nc .set_init_status (min (int ((self .n * 100 / self .total ) / len (models )), 100 ))
148+ return super ().display (msg , pos )
149+
150+ if "max_workers" not in params :
151+ params ["max_workers" ] = 2
152+ if "cache_dir" not in params :
153+ params ["cache_dir" ] = persistent_storage ()
154+ for model in models :
155+ snapshot_download (model , tqdm_class = TqdmProgress , ** params ) # noqa
156+ if init_handler is None :
157+ nc .set_init_status (100 )
158+ else :
159+ init_handler (nc )
0 commit comments