You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
65 lines
1.6 KiB
Python
65 lines
1.6 KiB
Python
import os
from threading import Lock

import torch
import whisper
from fastapi import FastAPI, File, UploadFile
from fastapi import HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import RedirectResponse

from . import __version__
from .audio import load_audio
|
|
|
|
# TODO use pydantic config
# Whisper checkpoint name; overridable through the ASR_MODEL environment
# variable, falling back to "base".
model_name = os.getenv("ASR_MODEL", "base")

# Load the checkpoint once at import time, then move it onto the GPU when
# CUDA is available (nn.Module.cuda() returns the module itself).
model = whisper.load_model(model_name)
if torch.cuda.is_available():
    model = model.cuda()

# Shared lock that transcribe() holds while running inference on `model`,
# so only one request uses the shared model at a time.
model_lock = Lock()


# TODO use pydantic config
# Audio sample rate, presumably in Hz (16 kHz matches Whisper's expected
# input). NOTE(review): not referenced by the code visible in this module.
SAMPLE_RATE = 16000
|
|
|
|
|
|
# TODO move transcribe to a modeling worker
def transcribe(
    audio,
    *,
    task=None,
    language=None,
    initial_prompt=None,
):
    """Run the shared Whisper model on ``audio`` and return its raw result.

    Inference is serialized through ``model_lock`` so concurrent callers never
    use the shared ``model`` at the same time.

    Args:
        audio: Input handed straight to ``model.transcribe``. Callers in this
            module pass the output of ``load_audio`` (presumably a decoded
            waveform -- TODO confirm against ``.audio``).
        task: Optional Whisper decoding task (e.g. ``"transcribe"`` or
            ``"translate"``). Not forwarded when falsy, leaving Whisper's
            default in effect.
        language: Optional language code. Not forwarded when falsy, leaving
            the choice to Whisper.
        initial_prompt: Optional text used to condition decoding. Not
            forwarded when falsy.

    Returns:
        Whatever ``model.transcribe`` returns; callers here read its
        ``"text"`` key.
    """
    # Only forward options that were actually supplied: passing e.g.
    # task=None explicitly would override Whisper's own defaults.
    options = {}
    if task:
        options["task"] = task
    if language:
        options["language"] = language
    if initial_prompt:
        options["initial_prompt"] = initial_prompt

    with model_lock:
        result = model.transcribe(audio, **options)

    return result
|
|
|
|
|
|
# FastAPI application exposing the Whisper model over HTTP.
app = FastAPI(
    title="Local Whisper",
    # Fixed: the model is OpenAI's Whisper, not "OpenAPI's Open Whisper".
    description="Making OpenAI's Whisper available locally via a REST API.",
    version=__version__,
    # A negative expand depth hides the "Schemas" section in Swagger UI.
    swagger_ui_parameters={"defaultModelsExpandDepth": -1},
    license_info={
        "name": "MIT License",
    },
)
|
|
|
|
|
|
@app.get("/", response_class=RedirectResponse, include_in_schema=False)
async def index():
    """Send visitors of the root URL to the interactive API docs at /docs.

    Hidden from the OpenAPI schema; responds with a redirect (307 by default).
    """
    return RedirectResponse(url="/docs")
|
|
|
|
|
|
def _transcribe_upload(fileobj):
    """Decode an uploaded audio file object and transcribe it (blocking)."""
    return transcribe(load_audio(fileobj))


@app.post("/audio/transcriptions")
async def asr(file: UploadFile = File(...)):
    """Transcribe an uploaded audio file.

    Args:
        file: Multipart upload whose content type must be ``audio/*``.

    Returns:
        ``{"text": <transcribed text>}``.

    Raises:
        HTTPException: 415 when the upload has no content type or one that
            is not ``audio/*`` (previously such requests fell through and
            returned a 200 with a ``null`` body).
    """
    # content_type can be None when the client sends no Content-Type for the
    # part; treat that as unsupported instead of crashing on .startswith.
    content_type = file.content_type or ""
    if not content_type.startswith("audio/"):
        raise HTTPException(
            status_code=415,
            detail="Unsupported media type: expected an audio/* upload.",
        )

    # transcribe() is synchronous and blocks on a threading.Lock while the
    # model runs (load_audio is presumably blocking too); run both in a worker
    # thread so the event loop keeps serving other requests meanwhile.
    transcription = await run_in_threadpool(_transcribe_upload, file.file)
    return {"text": transcription["text"]}
|