diff --git a/local_whisper/inference.py b/local_whisper/inference.py
new file mode 100644
index 0000000..7d5490b
--- /dev/null
+++ b/local_whisper/inference.py
@@ -0,0 +1,26 @@
+import os
+from threading import Lock
+
+import torch
+import whisper
+
+# TODO use pydantic config
+model_name = os.getenv("ASR_MODEL", "base")
+if torch.cuda.is_available():
+    model = whisper.load_model(model_name).cuda()
+else:
+    model = whisper.load_model(model_name)
+model_lock = Lock()
+
+# TODO move transcribe to a modeling worker
+def transcribe(audio):
+    # options_dict = {"task" : task}
+    # if language:
+    #     options_dict["language"] = language
+    # if initial_prompt:
+    #     options_dict["initial_prompt"] = initial_prompt
+    with model_lock:
+        # result = model.transcribe(audio, **options_dict)
+        result = model.transcribe(audio)
+
+    return result
diff --git a/local_whisper/webservice.py b/local_whisper/webservice.py
index c914ce1..e369538 100644
--- a/local_whisper/webservice.py
+++ b/local_whisper/webservice.py
@@ -1,46 +1,14 @@
-import os
-from threading import Lock
-
-import torch
-import whisper
 from fastapi import FastAPI, File, UploadFile
 from fastapi.responses import RedirectResponse
 
 from . import __version__
 from .audio import load_audio
-
-# TODO use pydantic config
-model_name = os.getenv("ASR_MODEL", "base")
-if torch.cuda.is_available():
-    model = whisper.load_model(model_name).cuda()
-else:
-    model = whisper.load_model(model_name)
-model_lock = Lock()
-
+from .inference import transcribe
 
 # TODO use pydantic config
 SAMPLE_RATE = 16000
 
 
-# TODO move transcribe to a modeling worker
-def transcribe(
-    audio,
-    # task: Union[str, None],
-    # language: Union[str, None],
-    # initial_prompt: Union[str, None],
-):
-    # options_dict = {"task" : task}
-    # if language:
-    #     options_dict["language"] = language
-    # if initial_prompt:
-    #     options_dict["initial_prompt"] = initial_prompt
-    with model_lock:
-        # result = model.transcribe(audio, **options_dict)
-        result = model.transcribe(audio)
-
-    return result
-
-
 app = FastAPI(
     title="Local Whisper",
     description="Making OpenAPI's Open Whisper available via ReST API locally.",
diff --git a/tests/unit/test_audio.py b/tests/unit/test_audio.py
index b1050bf..7c89f5d 100644
--- a/tests/unit/test_audio.py
+++ b/tests/unit/test_audio.py
@@ -10,7 +10,6 @@ def test_audio(sample_audio):
     # Test the file length is appropriate size
     assert DEFAULT_SAMPLE_RATE * 8 < audio.shape[0] < DEFAULT_SAMPLE_RATE * 12
     # Taking the standard diviation of audio data can be used to check
-    # Amplitude Variability, Noise Detection, or Normalization. Hear we just want
+    # Amplitude Variability, Noise Detection, or Normalization. Here we just want
     # to make certain it does not contain a lot of noise.
     assert 0 < audio.std() < 1
-
\ No newline at end of file
diff --git a/tests/unit/test_inference.py b/tests/unit/test_inference.py
new file mode 100644
index 0000000..7db4422
--- /dev/null
+++ b/tests/unit/test_inference.py
@@ -0,0 +1,9 @@
+from local_whisper.inference import transcribe
+from local_whisper.audio import load_audio
+
+
+def test_transcribe(sample_audio):
+    with open(sample_audio, mode="rb") as af:
+        audio = load_audio(af)
+    result = transcribe(audio)
+    assert result["text"].strip() == "Let's see, right now I'm playing Horizon Zero Dawn. I also had just recently finished BioShock Infinite."
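
Both the new `inference.py` and the retained `webservice.py` code carry a `# TODO use pydantic config` marker. For reference, a minimal sketch of the settings object that TODO points at, assuming pydantic v1's `BaseSettings` (moved to the separate `pydantic-settings` package in v2); the `Settings` class and `settings` instance names are illustrative, not part of this diff:

```python
from pydantic import BaseSettings  # v2: from pydantic_settings import BaseSettings


class Settings(BaseSettings):
    """Hypothetical central config replacing os.getenv and module constants."""

    asr_model: str = "base"    # populated from the ASR_MODEL env var (case-insensitive)
    sample_rate: int = 16000   # would supersede SAMPLE_RATE in webservice.py


settings = Settings()  # e.g. ASR_MODEL=small -> settings.asr_model == "small"
```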
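
The commented-out `options_dict` block in the new `inference.py` hints at restoring the `task`, `language`, and `initial_prompt` decoding options later. A sketch of what uncommenting that could look like, written against the same module's `model` and `model_lock` globals (a draft under those assumptions, not part of this diff; Whisper's `transcribe` does accept these keywords):

```python
from typing import Optional


# Inside local_whisper/inference.py, next to the existing model/model_lock globals.
def transcribe(
    audio,
    task: Optional[str] = None,
    language: Optional[str] = None,
    initial_prompt: Optional[str] = None,
):
    # Forward only the options the caller set, so Whisper's defaults apply otherwise.
    options_dict = {}
    if task:
        options_dict["task"] = task
    if language:
        options_dict["language"] = language
    if initial_prompt:
        options_dict["initial_prompt"] = initial_prompt
    with model_lock:  # the loaded model is not thread-safe; serialize access
        result = model.transcribe(audio, **options_dict)
    return result
```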