You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
local-whisper/local_whisper/webservice.py

108 lines
3.0 KiB
Python

import os
from threading import Lock
from typing import BinaryIO

import ffmpeg
import numpy as np
import torch
import whisper
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.responses import RedirectResponse

from . import __version__
# TODO use pydantic config
# Name of the Whisper checkpoint to load; overridable via the ASR_MODEL
# environment variable.
model_name = os.getenv("ASR_MODEL", "base")

# Probe for CUDA first, load the model once, then move it to the GPU if one
# is available.
_use_cuda = torch.cuda.is_available()
model = whisper.load_model(model_name)
if _use_cuda:
    model = model.cuda()

# Guards every use of the single shared model instance.
model_lock = Lock()

# TODO use pydantic config
# Sample rate (Hz) that uploaded audio is resampled to by default.
SAMPLE_RATE = 16000
# TODO move transcribe to a modeling worker
# TODO accept optional `task`, `language` and `initial_prompt` settings and
#      forward them to `model.transcribe` as keyword options.
def transcribe(audio):
    """
    Run the shared Whisper model on ``audio`` and return its result.

    The module-level ``model_lock`` is held for the whole call, so only one
    transcription uses the shared model at a time.

    Parameters
    ----------
    audio
        The audio input, passed unchanged to ``model.transcribe``.

    Returns
    -------
    Whatever ``model.transcribe`` returns; the HTTP handler in this module
    reads its ``"text"`` entry.
    """
    with model_lock:
        return model.transcribe(audio)
# TODO probably can offload this on a worker queue too
def load_audio(file: BinaryIO, encode=True, sr: int = SAMPLE_RATE):
    """
    Read an audio file object into a mono float32 waveform.

    Modified from https://github.com/openai/whisper/blob/main/whisper/audio.py
    to accept a file object.

    Parameters
    ----------
    file: BinaryIO
        The audio file like object; it is read to the end.
    encode: Boolean
        If true, pipe the bytes through ffmpeg, which down-mixes to one
        channel, resamples to ``sr`` and emits raw signed 16-bit PCM.
        If false, the bytes are used as read and interpreted directly as
        16-bit integer samples.
    sr: int
        The sample rate requested from ffmpeg (only used when ``encode`` is true)

    Returns
    -------
    A NumPy array containing the audio waveform, in float32 dtype, scaled by
    1/32768.

    Raises
    ------
    RuntimeError
        If ffmpeg reports an error; the message carries ffmpeg's stderr.
    """
    pcm_bytes = file.read()

    if encode:
        # Build the ffmpeg command: read from stdin, write raw PCM to stdout.
        # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
        stream = ffmpeg.input("pipe:", threads=0)
        stream = stream.output(
            "-", format="s16le", acodec="pcm_s16le", ac=1, ar=sr
        )
        try:
            # Launches the ffmpeg subprocess and feeds it the uploaded bytes.
            pcm_bytes, _ = stream.run(
                cmd="ffmpeg",
                capture_stdout=True,
                capture_stderr=True,
                input=pcm_bytes,
            )
        except ffmpeg.Error as e:
            raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e

    # frombuffer already yields a 1-D array; astype makes the float32 copy.
    samples = np.frombuffer(pcm_bytes, dtype=np.int16)
    return samples.astype(np.float32) / 32768.0
# The FastAPI application object; the metadata below is what the generated
# API documentation displays.
app = FastAPI(
    title="Local Whisper",
    # Whisper is published by OpenAI (see the upstream URL referenced in
    # load_audio); the previous text said "OpenAPI's".
    description="Making OpenAI's Open Whisper available via ReST API locally.",
    version=__version__,
    # Per the Swagger UI configuration docs, a depth of -1 hides the
    # "Models" section of the docs page.
    swagger_ui_parameters={"defaultModelsExpandDepth": -1},
    license_info={
        "name": "MIT License",
    },
)
@app.get("/", response_class=RedirectResponse, include_in_schema=False)
async def index():
    """Send visitors of the root URL to ``/docs``."""
    # With response_class=RedirectResponse the returned string is the
    # redirect target.
    docs_url = "/docs"
    return docs_url
@app.post("/audio/transcriptions")
async def asr(file: UploadFile = File(...)):
    """
    Transcribe an uploaded audio file and return the recognised text.

    Parameters
    ----------
    file: UploadFile
        The uploaded file. Its declared content type must start with
        ``audio/``.

    Returns
    -------
    A dict of the form ``{"text": <transcribed text>}``.

    Raises
    ------
    HTTPException
        With status 415 when the upload declares no content type or one
        that is not ``audio/*``.
    """
    # The content type may be absent; treat that the same as a non-audio
    # upload instead of failing on ``None.startswith``.
    content_type = file.content_type or ""
    if not content_type.startswith("audio/"):
        raise HTTPException(
            status_code=415,
            detail=f"Unsupported content type {file.content_type!r}; expected audio/*",
        )

    # NOTE: decoding and inference are blocking calls made from an async
    # handler (see the TODOs about moving them to a worker).
    transcription = transcribe(load_audio(file.file))
    return {"text": transcription["text"]}