Compare commits

...

2 Commits

@ -3,7 +3,9 @@ from typing import BinaryIO
import ffmpeg
import numpy as np
DEFAULT_SAMPLE_RATE = 16000
from .settings import whisper_settings
DEFAULT_SAMPLE_RATE = whisper_settings.default_sample_rate
# TODO probably can offload this on a worker queue too

@ -4,14 +4,17 @@ from threading import Lock
import torch
import whisper
from .settings import whisper_settings
# TODO use pydantic config
model_name = os.getenv("ASR_MODEL", "base")
model_name = whisper_settings.base_asr_model
if torch.cuda.is_available():
model = whisper.load_model(model_name).cuda()
else:
model = whisper.load_model(model_name)
model_lock = Lock()
# TODO move transcribe to a modeling worker
def transcribe(audio):
# options_dict = {"task" : task}

@ -0,0 +1,22 @@
from pydantic import BaseSettings, Field
class WhisperSettings(BaseSettings):
"""Whisper Application Settings.
All environment varaibles supplied should be prefixed with "WHISPER_".
"""
base_asr_model: str = Field(
default="medium.en", description="The base whisper model to host."
)
default_sample_rate: int = Field(
default=16000,
description="The default sample rate used to resample the audio if necessary",
)
class Config:
env_prefix = "WHISPER_"
whisper_settings = WhisperSettings()

@ -1,9 +1,13 @@
from local_whisper.inference import transcribe
from local_whisper.audio import load_audio
from local_whisper.inference import transcribe
def test_transcribe(sample_audio):
with open(sample_audio, mode="rb") as af:
audio = load_audio(af)
result = transcribe(audio)
assert result["text"].strip() == "Let's see, right now I'm playing Horizon Zero Dawn. I also had just recently finished BioShock Infinite."
assert (
result["text"].strip().lower()
== "Let's see, right now I'm playing Horizon Zero Dawn."
" I also had just recently finished BioShock Infinite.".lower()
)

@ -0,0 +1,29 @@
import os
from unittest.mock import patch
from pydantic.types import SecretStr
from local_whisper.settings import WhisperSettings
SETTING_DEFAULTS = {"BASE_ASR_MODEL": "medium.en", "DEFAULT_SAMPLE_RATE": 16000}
def test_setting_defaults():
"""Regression test for settings schema."""
with patch.dict(os.environ, {}, clear=True):
savant_settings = WhisperSettings()
assert len(savant_settings.dict()) == len(SETTING_DEFAULTS)
for k, v in SETTING_DEFAULTS.items():
_setting_value = getattr(savant_settings, k.lower())
unmasked_setting = (
_setting_value.get_secret_value()
if isinstance(_setting_value, SecretStr)
else _setting_value
)
unmasked_setting == v
def test_with_envvar_prefix():
with patch.dict(os.environ, {"WHISPER_DEFAULT_SAMPLE_RATE": "22500"}, clear=True):
whisper_settings = WhisperSettings()
assert whisper_settings.default_sample_rate == 22500
Loading…
Cancel
Save