| 1 | +"""Use the simplest model to just test speech recognition. |
| 2 | +
|
| 3 | +Example is not production ready, as probably in production app we want running requests in subprocesses with timeout or |
| 4 | +run multiply workers to process requests simultaneously. |
| 5 | +""" |

import os
import tempfile
import typing
from contextlib import asynccontextmanager

import torch
from fastapi import Depends, FastAPI, UploadFile, responses
from huggingface_hub import snapshot_download
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

from nc_py_api import NextcloudApp
from nc_py_api.ex_app import nc_app, persistent_storage, run_app, set_handlers

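# distil-whisper/distil-small.en is a distilled, English-only Whisper variant small enough to
# run on CPU, which is why everything below uses torch.float32 and device="cpu".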
MODEL_NAME = "distil-whisper/distil-small.en"


@asynccontextmanager
async def lifespan(_app: FastAPI):
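    # set_handlers() registers the enabled/disabled callback and lists the models that should be
    # fetched into the app's persistent storage during initialization; "*.bin" and ONNX files are
    # skipped so only the safetensors weights are downloaded.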
    set_handlers(APP, enabled_handler, models_to_fetch={MODEL_NAME: {"ignore_patterns": ["*.bin", "*onnx*"]}})
    yield


APP = FastAPI(lifespan=lifespan)


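# This endpoint is not meant to be called directly by users: Nextcloud invokes it through AppAPI
# when a speech-to-text task is scheduled, and the `nc_app` dependency rejects requests that do
# not carry valid AppAPI authentication.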
@APP.post("/distil_whisper_small")
async def distil_whisper_small(
    _nc: typing.Annotated[NextcloudApp, Depends(nc_app)],
    data: UploadFile,
    max_execution_time: float = 0,
):
    print(max_execution_time)  # debug output only; the limit is not enforced in this simple example
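    # Load the model from the local snapshot that was downloaded into persistent storage during
    # app initialization; local_files_only=True guarantees nothing is fetched over the network here.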
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        snapshot_download(
            MODEL_NAME,
            local_files_only=True,
            cache_dir=persistent_storage(),
        ),
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
        use_safetensors=True,
    ).to("cpu")

    processor = AutoProcessor.from_pretrained(MODEL_NAME)
    pipe = pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        max_new_tokens=128,
        torch_dtype=torch.float32,
        device="cpu",
    )
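    # The pipeline takes a file path, so spool the upload into a named temporary file, keeping the
    # original extension so the audio format can still be detected (decoding goes through ffmpeg).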
    _, file_extension = os.path.splitext(data.filename)
    with tempfile.NamedTemporaryFile(mode="w+b", suffix=file_extension) as tmp:
        tmp.write(await data.read())
        tmp.flush()  # ensure the audio is fully on disk before the pipeline re-opens it by name
        result = pipe(tmp.name)
    return responses.Response(content=result["text"])


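# Called by AppAPI when the app is enabled or disabled: register (or unregister) this app as a
# speech-to-text provider. An empty return value reports success; a non-empty string is treated
# as an error message.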
# The handler may also be declared as `async def`.
def enabled_handler(enabled: bool, nc: NextcloudApp) -> str:
    print(f"enabled={enabled}")
    if enabled:
        nc.providers.speech_to_text.register("distil_whisper_small", "DistilWhisperSmall", "/distil_whisper_small")
    else:
        nc.providers.speech_to_text.unregister("distil_whisper_small")
    return ""


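# run_app() starts a uvicorn server for this module; the host and port are taken from the
# environment variables that AppAPI provides to external apps.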
if __name__ == "__main__":
    run_app("main:APP", log_level="trace")