diff --git a/local_whisper/inference.py b/local_whisper/inference.py index 7d5490b..6cc8584 100644 --- a/local_whisper/inference.py +++ b/local_whisper/inference.py @@ -4,14 +4,17 @@ from threading import Lock import torch import whisper +from .settings import whisper_settings + # TODO use pydantic config -model_name = os.getenv("ASR_MODEL", "base") +model_name = whisper_settings.base_asr_model if torch.cuda.is_available(): model = whisper.load_model(model_name).cuda() else: model = whisper.load_model(model_name) model_lock = Lock() + # TODO move transcribe to a modeling worker def transcribe(audio): # options_dict = {"task" : task} diff --git a/local_whisper/settings.py b/local_whisper/settings.py new file mode 100644 index 0000000..d24dc8c --- /dev/null +++ b/local_whisper/settings.py @@ -0,0 +1,18 @@ +from pydantic import BaseSettings, Field + + +class WhisperSettings(BaseSettings): + """Whisper Application Settings. + + All environment varaibles supplied should be prefixed with "WHISPER_". + """ + + base_asr_model: str = Field( + default="medium.en", description="The base whisper model to host." + ) + + class Config: + env_prefix = "WHISPER_" + + +whisper_settings = WhisperSettings() diff --git a/tests/unit/test_inference.py b/tests/unit/test_inference.py index 7db4422..ce1f5ae 100644 --- a/tests/unit/test_inference.py +++ b/tests/unit/test_inference.py @@ -1,9 +1,13 @@ -from local_whisper.inference import transcribe from local_whisper.audio import load_audio +from local_whisper.inference import transcribe def test_transcribe(sample_audio): with open(sample_audio, mode="rb") as af: audio = load_audio(af) result = transcribe(audio) - assert result["text"].strip() == "Let's see, right now I'm playing Horizon Zero Dawn. I also had just recently finished BioShock Infinite." + assert ( + result["text"].strip().lower() + == "Let's see, right now I'm playing Horizon Zero Dawn. \ + I also had just recently finished BioShock Infinite.".lower() + ) diff --git a/tests/unit/test_settings.py b/tests/unit/test_settings.py new file mode 100644 index 0000000..9939974 --- /dev/null +++ b/tests/unit/test_settings.py @@ -0,0 +1,25 @@ +import os +from unittest.mock import patch + +from pydantic.types import SecretStr + +from local_whisper.settings import WhisperSettings + +SETTING_DEFAULTS = { + "BASE_ASR_MODEL": "medium.en", +} + + +def test_setting_defaults(): + """Regression test for settings schema.""" + with patch.dict(os.environ, {}, clear=True): + savant_settings = WhisperSettings() + assert len(savant_settings.dict()) == len(SETTING_DEFAULTS) + for k, v in SETTING_DEFAULTS.items(): + _setting_value = getattr(savant_settings, k.lower()) + unmasked_setting = ( + _setting_value.get_secret_value() + if isinstance(_setting_value, SecretStr) + else _setting_value + ) + unmasked_setting == v