From a4a09b37f088d4b6b787ddc0de2111ee11fa2b1b Mon Sep 17 00:00:00 2001 From: Drew Bednar Date: Mon, 29 May 2023 13:29:38 -0400 Subject: [PATCH 1/2] Adding test for settings --- local_whisper/inference.py | 5 ++++- local_whisper/settings.py | 18 ++++++++++++++++++ tests/unit/test_inference.py | 8 ++++++-- tests/unit/test_settings.py | 25 +++++++++++++++++++++++++ 4 files changed, 53 insertions(+), 3 deletions(-) create mode 100644 local_whisper/settings.py create mode 100644 tests/unit/test_settings.py diff --git a/local_whisper/inference.py b/local_whisper/inference.py index 7d5490b..6cc8584 100644 --- a/local_whisper/inference.py +++ b/local_whisper/inference.py @@ -4,14 +4,17 @@ from threading import Lock import torch import whisper +from .settings import whisper_settings + # TODO use pydantic config -model_name = os.getenv("ASR_MODEL", "base") +model_name = whisper_settings.base_asr_model if torch.cuda.is_available(): model = whisper.load_model(model_name).cuda() else: model = whisper.load_model(model_name) model_lock = Lock() + # TODO move transcribe to a modeling worker def transcribe(audio): # options_dict = {"task" : task} diff --git a/local_whisper/settings.py b/local_whisper/settings.py new file mode 100644 index 0000000..d24dc8c --- /dev/null +++ b/local_whisper/settings.py @@ -0,0 +1,18 @@ +from pydantic import BaseSettings, Field + + +class WhisperSettings(BaseSettings): + """Whisper Application Settings. + + All environment varaibles supplied should be prefixed with "WHISPER_". + """ + + base_asr_model: str = Field( + default="medium.en", description="The base whisper model to host." + ) + + class Config: + env_prefix = "WHISPER_" + + +whisper_settings = WhisperSettings() diff --git a/tests/unit/test_inference.py b/tests/unit/test_inference.py index 7db4422..ce1f5ae 100644 --- a/tests/unit/test_inference.py +++ b/tests/unit/test_inference.py @@ -1,9 +1,13 @@ -from local_whisper.inference import transcribe from local_whisper.audio import load_audio +from local_whisper.inference import transcribe def test_transcribe(sample_audio): with open(sample_audio, mode="rb") as af: audio = load_audio(af) result = transcribe(audio) - assert result["text"].strip() == "Let's see, right now I'm playing Horizon Zero Dawn. I also had just recently finished BioShock Infinite." + assert ( + result["text"].strip().lower() + == "Let's see, right now I'm playing Horizon Zero Dawn. \ + I also had just recently finished BioShock Infinite.".lower() + ) diff --git a/tests/unit/test_settings.py b/tests/unit/test_settings.py new file mode 100644 index 0000000..9939974 --- /dev/null +++ b/tests/unit/test_settings.py @@ -0,0 +1,25 @@ +import os +from unittest.mock import patch + +from pydantic.types import SecretStr + +from local_whisper.settings import WhisperSettings + +SETTING_DEFAULTS = { + "BASE_ASR_MODEL": "medium.en", +} + + +def test_setting_defaults(): + """Regression test for settings schema.""" + with patch.dict(os.environ, {}, clear=True): + savant_settings = WhisperSettings() + assert len(savant_settings.dict()) == len(SETTING_DEFAULTS) + for k, v in SETTING_DEFAULTS.items(): + _setting_value = getattr(savant_settings, k.lower()) + unmasked_setting = ( + _setting_value.get_secret_value() + if isinstance(_setting_value, SecretStr) + else _setting_value + ) + unmasked_setting == v -- 2.38.4 From 0ff912be1bc7140e7e3d8d9fbe666736a656abe8 Mon Sep 17 00:00:00 2001 From: Drew Bednar Date: Mon, 29 May 2023 13:44:54 -0400 Subject: [PATCH 2/2] More settings tests --- local_whisper/audio.py | 4 +++- local_whisper/settings.py | 4 ++++ tests/unit/test_inference.py | 4 ++-- tests/unit/test_settings.py | 10 +++++++--- 4 files changed, 16 insertions(+), 6 deletions(-) diff --git a/local_whisper/audio.py b/local_whisper/audio.py index cc33676..0de8197 100644 --- a/local_whisper/audio.py +++ b/local_whisper/audio.py @@ -3,7 +3,9 @@ from typing import BinaryIO import ffmpeg import numpy as np -DEFAULT_SAMPLE_RATE = 16000 +from .settings import whisper_settings + +DEFAULT_SAMPLE_RATE = whisper_settings.default_sample_rate # TODO probably can offload this on a worker queue too diff --git a/local_whisper/settings.py b/local_whisper/settings.py index d24dc8c..fd5d3b7 100644 --- a/local_whisper/settings.py +++ b/local_whisper/settings.py @@ -10,6 +10,10 @@ class WhisperSettings(BaseSettings): base_asr_model: str = Field( default="medium.en", description="The base whisper model to host." ) + default_sample_rate: int = Field( + default=16000, + description="The default sample rate used to resample the audio if necessary", + ) class Config: env_prefix = "WHISPER_" diff --git a/tests/unit/test_inference.py b/tests/unit/test_inference.py index ce1f5ae..b58d320 100644 --- a/tests/unit/test_inference.py +++ b/tests/unit/test_inference.py @@ -8,6 +8,6 @@ def test_transcribe(sample_audio): result = transcribe(audio) assert ( result["text"].strip().lower() - == "Let's see, right now I'm playing Horizon Zero Dawn. \ - I also had just recently finished BioShock Infinite.".lower() + == "Let's see, right now I'm playing Horizon Zero Dawn." + " I also had just recently finished BioShock Infinite.".lower() ) diff --git a/tests/unit/test_settings.py b/tests/unit/test_settings.py index 9939974..a53cbe7 100644 --- a/tests/unit/test_settings.py +++ b/tests/unit/test_settings.py @@ -5,9 +5,7 @@ from pydantic.types import SecretStr from local_whisper.settings import WhisperSettings -SETTING_DEFAULTS = { - "BASE_ASR_MODEL": "medium.en", -} +SETTING_DEFAULTS = {"BASE_ASR_MODEL": "medium.en", "DEFAULT_SAMPLE_RATE": 16000} def test_setting_defaults(): @@ -23,3 +21,9 @@ def test_setting_defaults(): else _setting_value ) unmasked_setting == v + + +def test_with_envvar_prefix(): + with patch.dict(os.environ, {"WHISPER_DEFAULT_SAMPLE_RATE": "22500"}, clear=True): + whisper_settings = WhisperSettings() + assert whisper_settings.default_sample_rate == 22500 -- 2.38.4