@@ -54,7 +54,7 @@ def set_handlers(
5454 fast_api_app : FastAPI ,
5555 enabled_handler : typing .Callable [[bool , NextcloudApp ], str ],
5656 heartbeat_handler : typing .Optional [typing .Callable [[], str ]] = None ,
57- init_handler : typing .Optional [typing .Callable [[], None ]] = None ,
57+ init_handler : typing .Optional [typing .Callable [[NextcloudApp ], None ]] = None ,
5858 models_to_fetch : typing .Optional [list [str ]] = None ,
5959 models_download_params : typing .Optional [dict ] = None ,
6060):
@@ -75,15 +75,15 @@ def set_handlers(
7575 :param models_download_params: Parameters to pass to ``snapshot_download`` function from **huggingface_hub**.
7676 """
7777
78- def fetch_models_task (models : list [str ]) -> None :
78+ def fetch_models_task (nc : NextcloudApp , models : list [str ]) -> None :
7979 if models :
8080 from huggingface_hub import snapshot_download # noqa isort:skip pylint: disable=C0415 disable=E0401
8181 from tqdm import tqdm # noqa isort:skip pylint: disable=C0415 disable=E0401
8282
8383 class TqdmProgress (tqdm ):
8484 def display (self , msg = None , pos = None ):
8585 if init_handler is None :
86- NextcloudApp () .set_init_status (min (int ((self .n * 100 / self .total ) / len (models )), 100 ))
86+ nc .set_init_status (min (int ((self .n * 100 / self .total ) / len (models )), 100 ))
8787 return super ().display (msg , pos )
8888
8989 params = models_download_params if models_download_params else {}
@@ -94,9 +94,9 @@ def display(self, msg=None, pos=None):
9494 for model in models :
9595 snapshot_download (model , tqdm_class = TqdmProgress , ** params ) # noqa
9696 if init_handler is None :
97- NextcloudApp () .set_init_status (100 )
97+ nc .set_init_status (100 )
9898 else :
99- init_handler ()
99+ init_handler (nc )
100100
101101 @fast_api_app .put ("/enabled" )
102102 def enabled_callback (
@@ -114,6 +114,9 @@ def heartbeat_callback():
114114 return responses .JSONResponse (content = {"status" : return_status }, status_code = 200 )
115115
116116 @fast_api_app .post ("/init" )
117- def init_callback (background_tasks : BackgroundTasks ):
118- background_tasks .add_task (fetch_models_task , models_to_fetch if models_to_fetch else [])
117+ def init_callback (
118+ background_tasks : BackgroundTasks ,
119+ nc : typing .Annotated [NextcloudApp , Depends (nc_app )],
120+ ):
121+ background_tasks .add_task (fetch_models_task , nc , models_to_fetch if models_to_fetch else [])
119122 return responses .JSONResponse (content = {}, status_code = 200 )
0 commit comments