import os
from threading import Lock

import torch
import whisper
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.responses import RedirectResponse

from . import __version__
from .audio import load_audio

# TODO use pydantic config
model_name = os.getenv("ASR_MODEL", "base")

# Load the Whisper model once at import time, on GPU when available.
if torch.cuda.is_available():
    model = whisper.load_model(model_name).cuda()
else:
    model = whisper.load_model(model_name)

# Whisper model instances are not thread-safe; serialize access across requests.
model_lock = Lock()

# TODO use pydantic config
SAMPLE_RATE = 16000


# TODO move transcribe to a modeling worker
def transcribe(
    audio,
    # task: Union[str, None],
    # language: Union[str, None],
    # initial_prompt: Union[str, None],
):
    # options_dict = {"task": task}
    # if language:
    #     options_dict["language"] = language
    # if initial_prompt:
    #     options_dict["initial_prompt"] = initial_prompt
    with model_lock:
        # result = model.transcribe(audio, **options_dict)
        result = model.transcribe(audio)
    return result


app = FastAPI(
    title="Local Whisper",
    description="Making OpenAI's Whisper available via a REST API locally.",
    version=__version__,
    swagger_ui_parameters={"defaultModelsExpandDepth": -1},
    license_info={
        "name": "MIT License",
    },
)


@app.get("/", response_class=RedirectResponse, include_in_schema=False)
async def index():
    return "/docs"


@app.post("/audio/transcriptions")
async def asr(file: UploadFile = File(...)):
    # Reject uploads that do not declare an audio content type, instead of
    # silently falling through and returning an empty 200 response.
    if not (file.content_type or "").startswith("audio/"):
        raise HTTPException(status_code=415, detail="File must have an audio/* content type")
    transcription = transcribe(load_audio(file.file))
    return {"text": transcription["text"]}
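
# Example client -- a minimal sketch, not part of the service itself. It assumes
# the server was started separately (e.g. `uvicorn app.main:app --port 8000`);
# the module path, port, and sample file name are all assumptions:
#
#   import requests
#
#   with open("sample.wav", "rb") as f:
#       response = requests.post(
#           "http://localhost:8000/audio/transcriptions",
#           files={"file": ("sample.wav", f, "audio/wav")},
#       )
#   print(response.json()["text"])
#
# Passing an explicit "audio/wav" content type in the multipart field matters:
# the endpoint rejects uploads whose declared type is not audio/*.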