@@ -75,8 +75,7 @@ def set_handlers(
7575 enabled_handler : typing .Callable [[bool , AsyncNextcloudApp | NextcloudApp ], typing .Awaitable [str ] | str ],
7676 heartbeat_handler : typing .Callable [[], typing .Awaitable [str ] | str ] | None = None ,
7777 init_handler : typing .Callable [[AsyncNextcloudApp | NextcloudApp ], typing .Awaitable [None ] | None ] | None = None ,
78- models_to_fetch : list [str ] | None = None ,
79- models_download_params : dict | None = None ,
78+ models_to_fetch : dict [str , dict ] | None = None ,
8079 map_app_static : bool = True ,
8180):
8281 """Defines handlers for the application.
@@ -92,7 +91,6 @@ def set_handlers(
9291
9392 .. note:: ```huggingface_hub`` package should be present for automatic models fetching.
9493
95- :param models_download_params: Parameters to pass to ``snapshot_download`` function from **huggingface_hub**.
9694 :param map_app_static: Should be folders ``js``, ``css``, ``l10n``, ``img`` automatically mounted in FastAPI or not.
9795
9896 .. note:: First, presence of these directories in the current working dir is checked, then one directory higher.
@@ -140,8 +138,7 @@ async def init_callback(
140138 background_tasks .add_task (
141139 __fetch_models_task ,
142140 nc ,
143- models_to_fetch if models_to_fetch else [],
144- models_download_params if models_download_params else {},
141+ models_to_fetch if models_to_fetch else {},
145142 )
146143 return responses .JSONResponse (content = {}, status_code = 200 )
147144
@@ -181,8 +178,7 @@ def __map_app_static_folders(fast_api_app: FastAPI):
181178
182179def __fetch_models_task (
183180 nc : NextcloudApp ,
184- models : list [str ],
185- params : dict [str , typing .Any ],
181+ models : dict [str , dict ],
186182) -> None :
187183 if models :
188184 from huggingface_hub import snapshot_download # noqa isort:skip pylint: disable=C0415 disable=E0401
@@ -193,10 +189,8 @@ def display(self, msg=None, pos=None):
193189 nc .set_init_status (min (int ((self .n * 100 / self .total ) / len (models )), 100 ))
194190 return super ().display (msg , pos )
195191
196- if "max_workers" not in params :
197- params ["max_workers" ] = 2
198- if "cache_dir" not in params :
199- params ["cache_dir" ] = persistent_storage ()
200192 for model in models :
201- snapshot_download (model , tqdm_class = TqdmProgress , ** params ) # noqa
193+ workers = models [model ].pop ("max_workers" , 2 )
194+ cache = models [model ].pop ("cache_dir" , persistent_storage ())
195+ snapshot_download (model , tqdm_class = TqdmProgress , ** models [model ], max_workers = workers , cache_dir = cache )
202196 nc .set_init_status (100 )
0 commit comments