"""FastAPI application exposing a locally loaded Whisper model over HTTP."""

import os
from threading import Lock
from typing import BinaryIO

import ffmpeg
import numpy as np
import torch
import whisper
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.responses import RedirectResponse

from . import __version__

# TODO use pydantic config
model_name = os.getenv("ASR_MODEL", "base")
if torch.cuda.is_available():
    model = whisper.load_model(model_name).cuda()
else:
    model = whisper.load_model(model_name)
# Guards the single shared ``model`` instance; see ``transcribe``.
model_lock = Lock()

# TODO use pydantic config
SAMPLE_RATE = 16000


# TODO move transcribe to a modeling worker
def transcribe(
    audio,
    # task: Union[str, None],
    # language: Union[str, None],
    # initial_prompt: Union[str, None],
):
    """
    Run the shared Whisper model on ``audio`` while holding ``model_lock``.

    Parameters
    ----------
    audio:
        The input handed unchanged to ``model.transcribe``; ``asr`` passes the
        float32 waveform produced by ``load_audio``.

    Returns
    -------
    The value returned by ``model.transcribe``; ``asr`` reads its ``"text"`` key.
    """
    # options_dict = {"task" : task}
    # if language:
    #     options_dict["language"] = language
    # if initial_prompt:
    #     options_dict["initial_prompt"] = initial_prompt
    with model_lock:
        # result = model.transcribe(audio, **options_dict)
        result = model.transcribe(audio)

    return result


# TODO probably can offload this on a worker queue too
def load_audio(file: BinaryIO, encode: bool = True, sr: int = SAMPLE_RATE) -> np.ndarray:
    """
    Open an audio file object and read as mono waveform, resampling as necessary.
    Modified from https://github.com/openai/whisper/blob/main/whisper/audio.py
    to accept a file object.

    Parameters
    ----------
    file: BinaryIO
        The audio file like object; it is read to the end.
    encode: bool
        If true, pipe the bytes through ffmpeg to obtain raw 16-bit PCM, mono,
        at ``sr`` Hz. If false, the bytes are used as read and interpreted
        directly as int16 samples (no decoding, down-mixing or resampling).
    sr: int
        The sample rate to resample the audio to when ``encode`` is true.

    Returns
    -------
    A NumPy array containing the audio waveform, in float32 dtype, scaled by
    1 / 32768.

    Raises
    ------
    RuntimeError
        If the ffmpeg subprocess fails; the message carries ffmpeg's stderr.
    """
    if encode:
        try:
            # This launches a subprocess to decode audio while down-mixing and
            # resampling as necessary. Requires the ffmpeg CLI and
            # `ffmpeg-python` package to be installed.
            out, _ = (
                ffmpeg.input("pipe:", threads=0)
                .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sr)
                .run(
                    cmd="ffmpeg",
                    capture_stdout=True,
                    capture_stderr=True,
                    input=file.read(),
                )
            )
        except ffmpeg.Error as e:
            raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
    else:
        out = file.read()

    # int16 samples -> float32 by dividing by 32768.
    return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0


app = FastAPI(
    title="Local Whisper",
    description="Making OpenAPI's Open Whisper available via ReST API locally.",
    version=__version__,
    swagger_ui_parameters={"defaultModelsExpandDepth": -1},
    license_info={
        "name": "MIT License",
    },
)


@app.get("/", response_class=RedirectResponse, include_in_schema=False)
async def index():
    """Redirect the root path to the interactive API docs at ``/docs``."""
    return "/docs"


@app.post("/audio/transcriptions")
async def asr(file: UploadFile = File(...)):
    """
    Transcribe an uploaded audio file and return ``{"text": ...}``.

    Raises
    ------
    HTTPException
        415 when the upload's content type is missing or is not ``audio/*``.
    """
    # The content type may be absent (None) when the multipart part carries no
    # Content-Type header; treat that the same as a non-audio upload.
    content_type = file.content_type or ""
    if not content_type.startswith("audio/"):
        raise HTTPException(
            status_code=415,
            detail=f"Unsupported content type {content_type!r}; expected audio/*.",
        )

    # NOTE: load_audio and transcribe are synchronous calls made directly from
    # this coroutine (see the worker TODOs above).
    transcription = transcribe(load_audio(file.file))
    return {"text": transcription["text"]}