diff --git a/local_whisper/audio.py b/local_whisper/audio.py index cc33676..0de8197 100644 --- a/local_whisper/audio.py +++ b/local_whisper/audio.py @@ -3,7 +3,9 @@ from typing import BinaryIO import ffmpeg import numpy as np -DEFAULT_SAMPLE_RATE = 16000 +from .settings import whisper_settings + +DEFAULT_SAMPLE_RATE = whisper_settings.default_sample_rate # TODO probably can offload this on a worker queue too diff --git a/local_whisper/inference.py b/local_whisper/inference.py index 7d5490b..6cc8584 100644 --- a/local_whisper/inference.py +++ b/local_whisper/inference.py @@ -4,14 +4,17 @@ from threading import Lock import torch import whisper +from .settings import whisper_settings + # TODO use pydantic config -model_name = os.getenv("ASR_MODEL", "base") +model_name = whisper_settings.base_asr_model if torch.cuda.is_available(): model = whisper.load_model(model_name).cuda() else: model = whisper.load_model(model_name) model_lock = Lock() + # TODO move transcribe to a modeling worker def transcribe(audio): # options_dict = {"task" : task} diff --git a/local_whisper/settings.py b/local_whisper/settings.py new file mode 100644 index 0000000..fd5d3b7 --- /dev/null +++ b/local_whisper/settings.py @@ -0,0 +1,22 @@ +from pydantic import BaseSettings, Field + + +class WhisperSettings(BaseSettings): + """Whisper Application Settings. + + All environment variables supplied should be prefixed with "WHISPER_". + """ + + base_asr_model: str = Field( + default="medium.en", description="The base whisper model to host." 
+ ) + default_sample_rate: int = Field( + default=16000, + description="The default sample rate used to resample the audio if necessary", + ) + + class Config: + env_prefix = "WHISPER_" + + +whisper_settings = WhisperSettings() diff --git a/tests/unit/test_inference.py b/tests/unit/test_inference.py index 7db4422..b58d320 100644 --- a/tests/unit/test_inference.py +++ b/tests/unit/test_inference.py @@ -1,9 +1,13 @@ -from local_whisper.inference import transcribe from local_whisper.audio import load_audio +from local_whisper.inference import transcribe def test_transcribe(sample_audio): with open(sample_audio, mode="rb") as af: audio = load_audio(af) result = transcribe(audio) - assert result["text"].strip() == "Let's see, right now I'm playing Horizon Zero Dawn. I also had just recently finished BioShock Infinite." + assert ( + result["text"].strip().lower() + == "Let's see, right now I'm playing Horizon Zero Dawn." + " I also had just recently finished BioShock Infinite.".lower() + ) diff --git a/tests/unit/test_settings.py b/tests/unit/test_settings.py new file mode 100644 index 0000000..a53cbe7 --- /dev/null +++ b/tests/unit/test_settings.py @@ -0,0 +1,29 @@ +import os +from unittest.mock import patch + +from pydantic.types import SecretStr + +from local_whisper.settings import WhisperSettings + +SETTING_DEFAULTS = {"BASE_ASR_MODEL": "medium.en", "DEFAULT_SAMPLE_RATE": 16000} + + +def test_setting_defaults(): + """Regression test for settings schema.""" + with patch.dict(os.environ, {}, clear=True): + savant_settings = WhisperSettings() + assert len(savant_settings.dict()) == len(SETTING_DEFAULTS) + for k, v in SETTING_DEFAULTS.items(): + _setting_value = getattr(savant_settings, k.lower()) + unmasked_setting = ( + _setting_value.get_secret_value() + if isinstance(_setting_value, SecretStr) + else _setting_value + ) + assert unmasked_setting == v + + +def test_with_envvar_prefix(): + with patch.dict(os.environ, {"WHISPER_DEFAULT_SAMPLE_RATE": "22500"}, clear=True): 
whisper_settings = WhisperSettings() + assert whisper_settings.default_sample_rate == 22500