From 0ff912be1bc7140e7e3d8d9fbe666736a656abe8 Mon Sep 17 00:00:00 2001 From: Drew Bednar Date: Mon, 29 May 2023 13:44:54 -0400 Subject: [PATCH] More settings tests --- local_whisper/audio.py | 4 +++- local_whisper/settings.py | 4 ++++ tests/unit/test_inference.py | 4 ++-- tests/unit/test_settings.py | 10 +++++++--- 4 files changed, 16 insertions(+), 6 deletions(-) diff --git a/local_whisper/audio.py b/local_whisper/audio.py index cc33676..0de8197 100644 --- a/local_whisper/audio.py +++ b/local_whisper/audio.py @@ -3,7 +3,9 @@ from typing import BinaryIO import ffmpeg import numpy as np -DEFAULT_SAMPLE_RATE = 16000 +from .settings import whisper_settings + +DEFAULT_SAMPLE_RATE = whisper_settings.default_sample_rate # TODO probably can offload this on a worker queue too diff --git a/local_whisper/settings.py b/local_whisper/settings.py index d24dc8c..fd5d3b7 100644 --- a/local_whisper/settings.py +++ b/local_whisper/settings.py @@ -10,6 +10,10 @@ class WhisperSettings(BaseSettings): base_asr_model: str = Field( default="medium.en", description="The base whisper model to host." ) + default_sample_rate: int = Field( + default=16000, + description="The default sample rate used to resample the audio if necessary", + ) class Config: env_prefix = "WHISPER_" diff --git a/tests/unit/test_inference.py b/tests/unit/test_inference.py index ce1f5ae..b58d320 100644 --- a/tests/unit/test_inference.py +++ b/tests/unit/test_inference.py @@ -8,6 +8,6 @@ def test_transcribe(sample_audio): result = transcribe(audio) assert ( result["text"].strip().lower() - == "Let's see, right now I'm playing Horizon Zero Dawn. \ - I also had just recently finished BioShock Infinite.".lower() + == "Let's see, right now I'm playing Horizon Zero Dawn." + " I also had just recently finished BioShock Infinite.".lower() ) diff --git a/tests/unit/test_settings.py b/tests/unit/test_settings.py index 9939974..a53cbe7 100644 --- a/tests/unit/test_settings.py +++ b/tests/unit/test_settings.py @@ -5,9 +5,7 @@ from pydantic.types import SecretStr from local_whisper.settings import WhisperSettings -SETTING_DEFAULTS = { - "BASE_ASR_MODEL": "medium.en", -} +SETTING_DEFAULTS = {"BASE_ASR_MODEL": "medium.en", "DEFAULT_SAMPLE_RATE": 16000} def test_setting_defaults(): @@ -23,3 +21,9 @@ def test_setting_defaults(): else _setting_value ) unmasked_setting == v + + +def test_with_envvar_prefix(): + with patch.dict(os.environ, {"WHISPER_DEFAULT_SAMPLE_RATE": "22500"}, clear=True): + whisper_settings = WhisperSettings() + assert whisper_settings.default_sample_rate == 22500