Basic inference testing

pull/6/head
Drew Bednar 2 years ago
parent 1c5ee1fc27
commit c4b3738a24

@ -0,0 +1,26 @@
import os
from threading import Lock
import torch
import whisper
# TODO use pydantic config
model_name = os.getenv("ASR_MODEL", "base")
if torch.cuda.is_available():
model = whisper.load_model(model_name).cuda()
else:
model = whisper.load_model(model_name)
model_lock = Lock()
# TODO move transcribe to a modeling worker
def transcribe(audio):
# options_dict = {"task" : task}
# if language:
# options_dict["language"] = language
# if initial_prompt:
# options_dict["initial_prompt"] = initial_prompt
with model_lock:
# result = model.transcribe(audio, **options_dict)
result = model.transcribe(audio)
return result

@ -1,46 +1,14 @@
import os
from threading import Lock
import torch
import whisper
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import RedirectResponse
from . import __version__
from .audio import load_audio
# TODO use pydantic config
model_name = os.getenv("ASR_MODEL", "base")
if torch.cuda.is_available():
model = whisper.load_model(model_name).cuda()
else:
model = whisper.load_model(model_name)
model_lock = Lock()
from .inference import transcribe
# TODO use pydantic config
SAMPLE_RATE = 16000
# TODO move transcribe to a modeling worker
def transcribe(
audio,
# task: Union[str, None],
# language: Union[str, None],
# initial_prompt: Union[str, None],
):
# options_dict = {"task" : task}
# if language:
# options_dict["language"] = language
# if initial_prompt:
# options_dict["initial_prompt"] = initial_prompt
with model_lock:
# result = model.transcribe(audio, **options_dict)
result = model.transcribe(audio)
return result
app = FastAPI(
title="Local Whisper",
description="Making OpenAPI's Open Whisper available via ReST API locally.",

@ -10,7 +10,6 @@ def test_audio(sample_audio):
# Test the file length is appropriate size
assert DEFAULT_SAMPLE_RATE * 8 < audio.shape[0] < DEFAULT_SAMPLE_RATE * 12
# Taking the standard diviation of audio data can be used to check
# Amplitude Variability, Noise Detection, or Normalization. Hear we just want
# Amplitude Variability, Noise Detection, or Normalization. Hear we just want
# to make certain it does not contain a lot of noise.
assert 0 < audio.std() < 1

@ -0,0 +1,9 @@
from local_whisper.inference import transcribe
from local_whisper.audio import load_audio
def test_transcribe(sample_audio):
with open(sample_audio, mode="rb") as af:
audio = load_audio(af)
result = transcribe(audio)
assert result["text"].strip() == "Let's see, right now I'm playing Horizon Zero Dawn. I also had just recently finished BioShock Infinite."
Loading…
Cancel
Save