You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
local-whisper/local_whisper/webservice.py

65 lines
1.6 KiB
Python

import os
from threading import Lock
import torch
import whisper
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import RedirectResponse
from . import __version__
2 years ago
from .audio import load_audio
# TODO use pydantic config
model_name = os.getenv("ASR_MODEL", "base")

# The Whisper checkpoint is loaded once at import time and shared by every
# request; it is moved to the GPU only when CUDA is available.
model = whisper.load_model(model_name)
if torch.cuda.is_available():
    model = model.cuda()

# Serializes access to the shared `model` (acquired around every transcription).
model_lock = Lock()

# TODO use pydantic config
SAMPLE_RATE = 16000
# TODO move transcribe to a modeling worker
def transcribe(audio, *, task=None, language=None, initial_prompt=None):
    """Run the shared Whisper model on *audio* and return its raw result.

    Args:
        audio: Audio input passed unchanged to ``model.transcribe``
            (presumably the output of ``load_audio`` -- confirm).
        task: Optional Whisper task; forwarded only when truthy.
        language: Optional language hint; forwarded only when truthy.
        initial_prompt: Optional prompt text; forwarded only when truthy.

    Returns:
        The object returned by ``model.transcribe``; the HTTP endpoint reads
        its ``"text"`` key.
    """
    # Only forward options the caller actually set, so the default call is
    # exactly ``model.transcribe(audio)`` and Whisper's own defaults apply.
    options = {
        name: value
        for name, value in (
            ("task", task),
            ("language", language),
            ("initial_prompt", initial_prompt),
        )
        if value
    }
    # `model` is shared module state; hold the lock so concurrent callers
    # never run it at the same time.
    with model_lock:
        return model.transcribe(audio, **options)
# FastAPI application; the metadata below is shown in the generated OpenAPI
# schema and the interactive docs.
app = FastAPI(
    title="Local Whisper",
    # Fixed wording: the model is OpenAI's Whisper (not "OpenAPI's").
    description="Making OpenAI's Whisper available locally via a REST API.",
    version=__version__,
    # Swagger UI setting; -1 is meant to hide the schemas/models panel in /docs.
    swagger_ui_parameters={"defaultModelsExpandDepth": -1},
    license_info={
        "name": "MIT License",
    },
)
@app.get("/", response_class=RedirectResponse, include_in_schema=False)
async def index():
    """Redirect the bare root URL to the interactive API docs page.

    The route is excluded from the OpenAPI schema; the returned path is
    turned into a redirect by ``RedirectResponse``.
    """
    docs_path = "/docs"
    return docs_path
@app.post("/audio/transcriptions")
async def asr(file: UploadFile = File(...)):
    """Transcribe an uploaded audio file and return ``{"text": ...}``.

    Raises:
        HTTPException: 415 when the upload is not declared as ``audio/*``.
            Previously such uploads fell through and returned ``null`` with
            HTTP 200.
    """
    from fastapi import HTTPException
    from fastapi.concurrency import run_in_threadpool

    # content_type may be absent; treat that as "not audio" instead of
    # crashing with AttributeError (HTTP 500).
    content_type = file.content_type or ""
    if not content_type.startswith("audio/"):
        raise HTTPException(
            status_code=415,
            detail=f"Expected an audio/* upload, got {content_type!r}",
        )

    # Decoding and transcription are blocking, compute-heavy calls; run them
    # in the threadpool so the event loop keeps serving other requests.
    transcription = await run_in_threadpool(
        lambda: transcribe(load_audio(file.file))
    )
    return {"text": transcription["text"]}