Basic inference testing

pull/6/head
Drew Bednar 2 years ago
parent 1c5ee1fc27
commit c4b3738a24

@@ -0,0 +1,26 @@
import os
from threading import Lock

import torch
import whisper

# TODO use pydantic config
model_name = os.getenv("ASR_MODEL", "base")

if torch.cuda.is_available():
    model = whisper.load_model(model_name).cuda()
else:
    model = whisper.load_model(model_name)

model_lock = Lock()


# TODO move transcribe to a modeling worker
def transcribe(audio):
    # options_dict = {"task" : task}
    # if language:
    #     options_dict["language"] = language
    # if initial_prompt:
    #     options_dict["initial_prompt"] = initial_prompt
    with model_lock:
        # result = model.transcribe(audio, **options_dict)
        result = model.transcribe(audio)
    return result
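
Note: the "# TODO use pydantic config" comment above is left unimplemented in this commit. A minimal sketch of what it might look like, assuming pydantic v1's BaseSettings (the ASRSettings class and field name are illustrative, not part of the diff):

# Hypothetical sketch of the pydantic-config TODO; not part of this commit.
from pydantic import BaseSettings  # pydantic v1; v2 moves this to pydantic_settings


class ASRSettings(BaseSettings):
    # Matched case-insensitively against the ASR_MODEL environment variable.
    asr_model: str = "base"


model_name = ASRSettings().asr_model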

@@ -1,46 +1,14 @@
import os
from threading import Lock
import torch
import whisper
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import RedirectResponse
from . import __version__
from .audio import load_audio
from .inference import transcribe
# TODO use pydantic config
model_name = os.getenv("ASR_MODEL", "base")
if torch.cuda.is_available():
    model = whisper.load_model(model_name).cuda()
else:
    model = whisper.load_model(model_name)
model_lock = Lock()
# TODO use pydantic config
SAMPLE_RATE = 16000
# TODO move transcribe to a modeling worker
def transcribe(
    audio,
    # task: Union[str, None],
    # language: Union[str, None],
    # initial_prompt: Union[str, None],
):
    # options_dict = {"task" : task}
    # if language:
    #     options_dict["language"] = language
    # if initial_prompt:
    #     options_dict["initial_prompt"] = initial_prompt
    with model_lock:
        # result = model.transcribe(audio, **options_dict)
        result = model.transcribe(audio)
    return result
app = FastAPI(
    title="Local Whisper",
    description="Making OpenAI's Whisper available via REST API locally.",
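
The route that actually calls the refactored transcribe helper is outside this hunk. As a rough, hypothetical sketch only, an endpoint wiring load_audio and transcribe together could look like this (the /asr path, parameter name, and response shape are assumptions, not taken from the diff):

# Hypothetical endpoint sketch; the real route in this service may differ.
@app.post("/asr")
async def asr(audio_file: UploadFile = File(...)):
    # load_audio reads the uploaded file object into a waveform array,
    # then transcribe runs Whisper under the module-level lock.
    audio = load_audio(audio_file.file)
    result = transcribe(audio)
    return {"text": result["text"]}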

@@ -10,7 +10,6 @@ def test_audio(sample_audio):
    # Test that the file length is an appropriate size
    assert DEFAULT_SAMPLE_RATE * 8 < audio.shape[0] < DEFAULT_SAMPLE_RATE * 12
    # Taking the standard deviation of audio data can be used to check
    # amplitude variability, noise, or normalization. Here we just want
    # to make certain it does not contain a lot of noise.
    assert 0 < audio.std() < 1

@@ -0,0 +1,9 @@
from local_whisper.inference import transcribe
from local_whisper.audio import load_audio


def test_transcribe(sample_audio):
    with open(sample_audio, mode="rb") as af:
        audio = load_audio(af)
    result = transcribe(audio)
    assert result["text"].strip() == "Let's see, right now I'm playing Horizon Zero Dawn. I also had just recently finished BioShock Infinite."
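
The sample_audio fixture used by test_transcribe and test_audio is not part of this diff. A minimal sketch of such a fixture, assuming a short speech clip stored alongside the tests (the conftest.py location and data/sample.wav path are illustrative):

# Hypothetical conftest.py fixture; the real fixture and audio file are not shown in this diff.
import pathlib

import pytest


@pytest.fixture
def sample_audio():
    # Path to a short spoken-word clip; test_transcribe opens it in binary mode.
    return pathlib.Path(__file__).parent / "data" / "sample.wav"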