|
1 | 1 | """FastAPI directly related stuff.""" |
2 | 2 |
|
3 | 3 | import asyncio |
| 4 | +import builtins |
| 5 | +import hashlib |
4 | 6 | import json |
5 | 7 | import os |
6 | 8 | import typing |
| 9 | +from urllib.parse import urlparse |
7 | 10 |
|
| 11 | +import httpx |
8 | 12 | from fastapi import ( |
9 | 13 | BackgroundTasks, |
10 | 14 | Depends, |
|
20 | 24 | from .._misc import get_username_secret_from_headers |
21 | 25 | from ..nextcloud import AsyncNextcloudApp, NextcloudApp |
22 | 26 | from ..talk_bot import TalkBotMessage |
| 27 | +from .defs import LogLvl |
23 | 28 | from .misc import persistent_storage |
24 | 29 |
|
25 | 30 |
|
@@ -163,26 +168,79 @@ def __map_app_static_folders(fast_api_app: FastAPI): |
163 | 168 | fast_api_app.mount(f"/{mnt_dir}", staticfiles.StaticFiles(directory=mnt_dir_path), name=mnt_dir) |
164 | 169 |
|
165 | 170 |
|
166 | | -def __fetch_models_task( |
167 | | - nc: NextcloudApp, |
168 | | - models: dict[str, dict], |
169 | | -) -> None: |
| 171 | +def __fetch_models_task(nc: NextcloudApp, models: dict[str, dict]) -> None: |
170 | 172 | if models: |
171 | | - from huggingface_hub import snapshot_download # noqa isort:skip pylint: disable=C0415 disable=E0401 |
172 | | - from tqdm import tqdm # noqa isort:skip pylint: disable=C0415 disable=E0401 |
173 | | - |
174 | | - class TqdmProgress(tqdm): |
175 | | - def display(self, msg=None, pos=None): |
176 | | - nc.set_init_status(min(int((self.n * 100 / self.total) / len(models)), 100)) |
177 | | - return super().display(msg, pos) |
178 | | - |
| 173 | + current_progress = 0 |
| 174 | + percent_for_each = min(int(100 / len(models)), 99) |
179 | 175 | for model in models: |
180 | | - workers = models[model].pop("max_workers", 2) |
181 | | - cache = models[model].pop("cache_dir", persistent_storage()) |
182 | | - snapshot_download(model, tqdm_class=TqdmProgress, **models[model], max_workers=workers, cache_dir=cache) |
| 176 | + if model.startswith(("http://", "https://")): |
| 177 | + __fetch_model_as_file(current_progress, percent_for_each, nc, model, models[model]) |
| 178 | + else: |
| 179 | + __fetch_model_as_snapshot(current_progress, percent_for_each, nc, model, models[model]) |
| 180 | + current_progress += percent_for_each |
183 | 181 | nc.set_init_status(100) |
184 | 182 |
|
185 | 183 |
|
| 184 | +def __fetch_model_as_file( |
| 185 | + current_progress: int, progress_for_task: int, nc: NextcloudApp, model_path: str, download_options: dict |
| 186 | +) -> None: |
| 187 | + result_path = download_options.pop("save_path", urlparse(model_path).path.split("/")[-1]) |
| 188 | + try: |
| 189 | + with httpx.stream("GET", model_path, follow_redirects=True) as response: |
| 190 | + if not response.is_success: |
| 191 | + nc.log(LogLvl.ERROR, f"Downloading of '{model_path}' returned {response.status_code} status.") |
| 192 | + return |
| 193 | + downloaded_size = 0 |
| 194 | + linked_etag = "" |
| 195 | + for each_history in response.history: |
| 196 | + linked_etag = each_history.headers.get("X-Linked-ETag", "") |
| 197 | + if linked_etag: |
| 198 | + break |
| 199 | + if not linked_etag: |
| 200 | + linked_etag = response.headers.get("X-Linked-ETag", response.headers.get("ETag", "")) |
| 201 | + total_size = int(response.headers.get("Content-Length")) |
| 202 | + try: |
| 203 | + existing_size = os.path.getsize(result_path) |
| 204 | + except OSError: |
| 205 | + existing_size = 0 |
| 206 | + if linked_etag and total_size == existing_size: |
| 207 | + with builtins.open(result_path, "rb") as file: |
| 208 | + sha256_hash = hashlib.sha256() |
| 209 | + for byte_block in iter(lambda: file.read(4096), b""): |
| 210 | + sha256_hash.update(byte_block) |
| 211 | + if f'"{sha256_hash.hexdigest()}"' == linked_etag: |
| 212 | + nc.set_init_status(min(current_progress + progress_for_task, 99)) |
| 213 | + return |
| 214 | + |
| 215 | + with builtins.open(result_path, "wb") as file: |
| 216 | + last_progress = current_progress |
| 217 | + for chunk in response.iter_bytes(5 * 1024 * 1024): |
| 218 | + downloaded_size += file.write(chunk) |
| 219 | + if total_size: |
| 220 | + new_progress = min(current_progress + int(progress_for_task * downloaded_size / total_size), 99) |
| 221 | + if new_progress != last_progress: |
| 222 | + nc.set_init_status(new_progress) |
| 223 | + last_progress = new_progress |
| 224 | + except Exception as e: # noqa pylint: disable=broad-exception-caught |
| 225 | + nc.log(LogLvl.ERROR, f"Downloading of '{model_path}' raised an exception: {e}") |
| 226 | + |
| 227 | + |
| 228 | +def __fetch_model_as_snapshot( |
| 229 | + current_progress: int, progress_for_task, nc: NextcloudApp, mode_name: str, download_options: dict |
| 230 | +) -> None: |
| 231 | + from huggingface_hub import snapshot_download # noqa isort:skip pylint: disable=C0415 disable=E0401 |
| 232 | + from tqdm import tqdm # noqa isort:skip pylint: disable=C0415 disable=E0401 |
| 233 | + |
| 234 | + class TqdmProgress(tqdm): |
| 235 | + def display(self, msg=None, pos=None): |
| 236 | + nc.set_init_status(min(current_progress + int(progress_for_task * self.n / self.total), 99)) |
| 237 | + return super().display(msg, pos) |
| 238 | + |
| 239 | + workers = download_options.pop("max_workers", 2) |
| 240 | + cache = download_options.pop("cache_dir", persistent_storage()) |
| 241 | + snapshot_download(mode_name, tqdm_class=TqdmProgress, **download_options, max_workers=workers, cache_dir=cache) |
| 242 | + |
| 243 | + |
186 | 244 | class AppAPIAuthMiddleware: |
187 | 245 | """Pure ASGI AppAPIAuth Middleware.""" |
188 | 246 |
|
|
0 commit comments