Compare commits


1 commit

@@ -1,6 +1,26 @@
-def main():
-    print("Hey this is the cli application")
-
-
+import click
+
+from . import run_stable_vicuna
+
+
+@click.group()
+def cli():
+    pass
+
+
+@cli.command()
+@click.option(
+    "--model-dir",
+    type=click.Path(exists=True, file_okay=False, dir_okay=True),
+    envvar="SAVANT_MODEL_DIR",
+    default="~/models",
+    show_default=True,
+    help="The path to a directory containing a Hugging Face Stable-Vicuna model.",
+)
+def stable_vicuna(model_dir):
+    """Runs a Stable Vicuna CLI prompt."""
+    run_stable_vicuna.main(model_dir=model_dir)
+
+
 if __name__ == "__main__":
-    main()
+    cli()
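
A quick, hedged way to sanity-check the new entry point without loading a model is click's built-in test runner; this sketch is not part of the change:

# Minimal sketch: click exposes `stable_vicuna` as the `stable-vicuna`
# subcommand, deriving the command name from the function name.
from click.testing import CliRunner

from chat_savant.cli import cli

runner = CliRunner()
result = runner.invoke(cli, ["stable-vicuna", "--help"])
# The help text lists --model-dir along with its SAVANT_MODEL_DIR envvar
# and ~/models default.
print(result.output)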

@@ -1,40 +0,0 @@
-"""
-Uses the Eleven Labs Python library and API to stream audio.
-"""
-from elevenlabs import Voice, generate, set_api_key, stream, voices
-
-from .settings import savant_settings
-
-set_api_key(savant_settings.eleven_labs_api_key.get_secret_value())
-
-voices = voices()
-drew_voice = voices[-1]
-drew_voice.settings.stability = 0.3
-drew_voice.settings.similarity_boost = 0.9
-
-
-def get_voice_model(model_name: str, voices: list[Voice]) -> Voice:
-    target_voice = None
-    for v in voices:
-        if v.name == model_name:
-            target_voice = v
-    if target_voice is None:
-        raise ValueError(f"Voice Model: {model_name} not found.")
-    return target_voice
-
-
-def generate_audio(input_text: str = "") -> None:
-    audio_stream = generate(text=input_text, voice=voices[-1], stream=True)
-    stream(audio_stream)
-
-
-def main():
-    while True:
-        input_text = input("Say something in Drew's voice: ")
-        generate_audio(input_text=input_text)
-
-
-if __name__ == "__main__":
-    main()

@@ -0,0 +1,74 @@
+import textwrap
+
+import colorama
+from transformers import LlamaForCausalLM, LlamaTokenizer
+from transformers import logging as t_logging
+from transformers import pipeline
+
+# Configure logging level for transformers library
+t_logging.set_verbosity_info()
+
+
+# Utility Functions
+def get_prompt(human_prompt):
+    prompt_template = f"### Human: {human_prompt} \n### Assistant:"
+    return prompt_template
+
+
+def remove_human_text(text):
+    return text.split("### Human:", 1)[0]
+
+
+def parse_text(data):
+    for item in data:
+        text = item["generated_text"]
+        assistant_text_index = text.find("### Assistant:")
+        if assistant_text_index != -1:
+            assistant_text = text[
+                assistant_text_index + len("### Assistant:") :
+            ].strip()
+            assistant_text = remove_human_text(assistant_text)
+            wrapped_text = textwrap.fill(assistant_text, width=100)
+            print(wrapped_text)
+
+
+# Reasoning question
+EXAMPLE_REASONING = "Answer the following question by reasoning step by step. \
+The cafeteria had 22 apples. If they used 20 for lunch, and bought 6 more, \
+how many apples do they have?"
+
+
+# User interface
+def main(model_dir):
+    # Model loading for inference
+    tokenizer = LlamaTokenizer.from_pretrained(model_dir)
+    base_model = LlamaForCausalLM.from_pretrained(
+        model_dir,
+        load_in_8bit=True,
+        device_map="auto",
+    )
+    pipe = pipeline(
+        "text-generation",
+        model=base_model,
+        tokenizer=tokenizer,
+        max_length=512,
+        temperature=0.7,
+        top_p=0.95,
+        repetition_penalty=1.15,
+    )
+    print("Ready for inference!")
+    while True:
+        input_prompt = input("USER:")
+        print(colorama.Style.DIM + f"You are submitting: {input_prompt}")
+        print(colorama.Style.RESET_ALL)
+        raw_output = pipe(get_prompt(input_prompt))
+        parse_text(raw_output)
+
+
+if __name__ == "__main__":
+    print("Warming up the engines...")
+    main(model_dir="~/models")  # matches the CLI's --model-dir default
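
To make the parsing above concrete, here is a small illustration; the generated text is fabricated, in the same "### Human: ... ### Assistant:" format the template produces:

from chat_savant.run_stable_vicuna import parse_text

# Fabricated pipeline output for illustration only.
sample = [
    {
        "generated_text": (
            "### Human: How many apples do they have? \n"
            "### Assistant: They have 8 apples.\n"
            "### Human: Are you sure?"
        )
    }
]

# Prints "They have 8 apples.": everything after "### Assistant:", with the
# follow-on "### Human:" turn stripped by remove_human_text().
parse_text(sample)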

@@ -1,32 +0,0 @@
-from pydantic import BaseSettings, Field, SecretStr
-
-
-class SavantSettings(BaseSettings):
-    """Savant Application Settings.
-
-    All environment variables supplied should be prefixed with "SAVANT_".
-    """
-
-    eleven_labs_api_key: SecretStr = Field(
-        default="", description="An optional Eleven Labs API key for text to speech."
-    )
-    eleven_labs_model: str = Field(
-        default="Arnold", description="The text-to-speech model name used in Eleven Labs audio generation."
-    )
-    llm_model_name: str = Field(
-        default="eachadea_vicuna-7b-1.1", description="The large language model name used in API requests."
-    )
-    openai_api_key: SecretStr = Field(
-        default="EMPTY", description="An OPENAI_API_KEY or an empty value if using the FastChat replacement server."
-    )
-    openai_api_base: str = Field(
-        default="http://localhost:8000/v1",
-        description="The base url to an OpenAI API compliant endpoint. \
-        Defaulted to FastChat replacement server defaults.",
-    )
-
-    class Config:
-        env_prefix = "SAVANT_"
-
-
-savant_settings = SavantSettings()
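
Since the deleted tests further below exercise this module, a short sketch of its pre-change behavior may help: the SAVANT_ prefix comes from Config.env_prefix, and SecretStr masks values until explicitly unwrapped (the key below is a placeholder).

import os

# Pre-change import; this module is removed by the diff above.
from chat_savant.settings import SavantSettings

os.environ["SAVANT_ELEVEN_LABS_API_KEY"] = "not-a-real-key"  # placeholder

settings = SavantSettings()
print(settings.eleven_labs_api_key)                     # **********
print(settings.eleven_labs_api_key.get_secret_value())  # not-a-real-key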

@@ -0,0 +1,32 @@
+# Invoke tab-completion script to be sourced with Bash shell.
+# Known to work on Bash 3.x, untested on 4.x.
+
+_complete_invoke() {
+    local candidates
+
+    # COMP_WORDS contains the entire command string up til now (including
+    # program name).
+    # We hand it to Invoke so it can figure out the current context: spit back
+    # core options, task names, the current task's options, or some combo.
+    candidates=`invoke --complete -- ${COMP_WORDS[*]}`
+
+    # `compgen -W` takes list of valid options & a partial word & spits back
+    # possible matches. Necessary for any partial word completions (vs
+    # completions performed when no partial words are present).
+    #
+    # $2 is the current word or token being tabbed on, either empty string or a
+    # partial word, and thus wants to be compgen'd to arrive at some subset of
+    # our candidate list which actually matches.
+    #
+    # COMPREPLY is the list of valid completions handed back to `complete`.
+    COMPREPLY=( $(compgen -W "${candidates}" -- $2) )
+}
+
+
+# Tell shell builtin to use the above for completing our invocations.
+#  * -F: use given function name to generate completions.
+#  * -o default: when function generates no results, use filenames.
+#  * positional args: program names to complete for.
+complete -F _complete_invoke -o default invoke inv
+
+# vim: set ft=sh :
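
For reference, the candidate list fed to compgen comes from invoke's --complete flag; a hedged Python equivalent of what the function runs (assumes invoke is installed and a tasks.py is discoverable):

import subprocess

# Roughly what the `invoke --complete -- ${COMP_WORDS[*]}` backtick call
# does: invoke prints one completion candidate per line on stdout.
result = subprocess.run(
    ["invoke", "--complete", "--", "inv"],
    capture_output=True,
    text=True,
)
print(result.stdout)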

@@ -18,11 +18,10 @@ classifiers = [
 dynamic = ["version", "dependencies"]
 
 [project.scripts]
-savant-cli = "chat_savant.cli:main"
+savant = "chat_savant.cli:cli"
 
 [project.optional-dependencies]
 whisper = ["openai-whisper"]
-elevenlabs = ["elevenlabs>=0.2.16,<=0.3.0", "mpv"]
 
 [tool.setuptools]
 packages = ["chat_savant"]

@@ -30,10 +29,3 @@ packages = ["chat_savant"]
 [tool.setuptools.dynamic]
 version = {attr = "chat_savant.__version__"}
 dependencies = {file = ["requirements.txt"]}
-
-[tool.black]
-line-length = 120
-skip-string-normalization = true
-
-[tool.ruff]
-line-length = 120
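
The console script now points at the click group instead of the removed main(); a hedged sketch for verifying that wiring after an install (the entry_points(group=...) keyword needs Python 3.10+):

from importlib.metadata import entry_points

# Assumes the package is installed, e.g. via `pip install -e .`.
(savant_ep,) = [
    ep for ep in entry_points(group="console_scripts") if ep.name == "savant"
]
print(savant_ep.value)  # chat_savant.cli:cli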

@@ -1 +1,9 @@
-pydantic>=1.6.2,<2.0.0
+transformers @ git+https://github.com/huggingface/transformers@849367ccf741d8c58aa88ccfe1d52d8636eaf2b7
+bitsandbytes
+datasets
+loralib
+sentencepiece
+bitsandbytes
+accelerate
+langchain
+colorama

@@ -4,7 +4,224 @@
 #
 # pip-compile requirements.in
 #
-pydantic==1.10.8
-    # via -r requirements.in
-typing-extensions==4.6.2
-    # via pydantic
+accelerate==0.18.0
+    # via -r requirements.in
+aiohttp==3.8.4
+    # via
+    #   datasets
+    #   fsspec
+    #   langchain
+aiosignal==1.3.1
+    # via aiohttp
+async-timeout==4.0.2
+    # via
+    #   aiohttp
+    #   langchain
+attrs==23.1.0
+    # via aiohttp
+bitsandbytes==0.38.1
+    # via -r requirements.in
+certifi==2022.12.7
+    # via requests
+charset-normalizer==3.1.0
+    # via
+    #   aiohttp
+    #   requests
+cmake==3.26.3
+    # via triton
+colorama==0.4.6
+    # via -r requirements.in
+dataclasses-json==0.5.7
+    # via langchain
+datasets==2.12.0
+    # via -r requirements.in
+dill==0.3.6
+    # via
+    #   datasets
+    #   multiprocess
+filelock==3.12.0
+    # via
+    #   huggingface-hub
+    #   torch
+    #   transformers
+    #   triton
+frozenlist==1.3.3
+    # via
+    #   aiohttp
+    #   aiosignal
+fsspec[http]==2023.4.0
+    # via
+    #   datasets
+    #   huggingface-hub
+greenlet==2.0.2
+    # via sqlalchemy
+huggingface-hub==0.14.1
+    # via
+    #   datasets
+    #   transformers
+idna==3.4
+    # via
+    #   requests
+    #   yarl
+jinja2==3.1.2
+    # via torch
+langchain==0.0.160
+    # via -r requirements.in
+lit==16.0.3
+    # via triton
+loralib==0.1.1
+    # via -r requirements.in
+markupsafe==2.1.2
+    # via jinja2
+marshmallow==3.19.0
+    # via
+    #   dataclasses-json
+    #   marshmallow-enum
+marshmallow-enum==1.5.1
+    # via dataclasses-json
+mpmath==1.3.0
+    # via sympy
+multidict==6.0.4
+    # via
+    #   aiohttp
+    #   yarl
+multiprocess==0.70.14
+    # via datasets
+mypy-extensions==1.0.0
+    # via typing-inspect
+networkx==3.1
+    # via torch
+numexpr==2.8.4
+    # via langchain
+numpy==1.24.3
+    # via
+    #   accelerate
+    #   datasets
+    #   langchain
+    #   numexpr
+    #   pandas
+    #   pyarrow
+    #   transformers
+nvidia-cublas-cu11==11.10.3.66
+    # via
+    #   nvidia-cudnn-cu11
+    #   nvidia-cusolver-cu11
+    #   torch
+nvidia-cuda-cupti-cu11==11.7.101
+    # via torch
+nvidia-cuda-nvrtc-cu11==11.7.99
+    # via torch
+nvidia-cuda-runtime-cu11==11.7.99
+    # via torch
+nvidia-cudnn-cu11==8.5.0.96
+    # via torch
+nvidia-cufft-cu11==10.9.0.58
+    # via torch
+nvidia-curand-cu11==10.2.10.91
+    # via torch
+nvidia-cusolver-cu11==11.4.0.1
+    # via torch
+nvidia-cusparse-cu11==11.7.4.91
+    # via torch
+nvidia-nccl-cu11==2.14.3
+    # via torch
+nvidia-nvtx-cu11==11.7.91
+    # via torch
+openapi-schema-pydantic==1.2.4
+    # via langchain
+packaging==23.1
+    # via
+    #   accelerate
+    #   datasets
+    #   huggingface-hub
+    #   marshmallow
+    #   transformers
+pandas==2.0.1
+    # via datasets
+psutil==5.9.5
+    # via accelerate
+pyarrow==12.0.0
+    # via datasets
+pydantic==1.10.7
+    # via
+    #   langchain
+    #   openapi-schema-pydantic
+python-dateutil==2.8.2
+    # via pandas
+pytz==2023.3
+    # via pandas
+pyyaml==6.0
+    # via
+    #   accelerate
+    #   datasets
+    #   huggingface-hub
+    #   langchain
+    #   transformers
+regex==2023.5.5
+    # via transformers
+requests==2.30.0
+    # via
+    #   datasets
+    #   fsspec
+    #   huggingface-hub
+    #   langchain
+    #   responses
+    #   transformers
+responses==0.18.0
+    # via datasets
+sentencepiece==0.1.99
+    # via -r requirements.in
+six==1.16.0
+    # via python-dateutil
+sqlalchemy==2.0.12
+    # via langchain
+sympy==1.11.1
+    # via torch
+tenacity==8.2.2
+    # via langchain
+tokenizers==0.13.3
+    # via transformers
+torch==2.0.0
+    # via
+    #   accelerate
+    #   triton
+tqdm==4.65.0
+    # via
+    #   datasets
+    #   huggingface-hub
+    #   langchain
+    #   transformers
+transformers @ git+https://github.com/huggingface/transformers@849367ccf741d8c58aa88ccfe1d52d8636eaf2b7
+    # via -r requirements.in
+triton==2.0.0
+    # via torch
+typing-extensions==4.5.0
+    # via
+    #   huggingface-hub
+    #   pydantic
+    #   sqlalchemy
+    #   torch
+    #   typing-inspect
+typing-inspect==0.8.0
+    # via dataclasses-json
+tzdata==2023.3
+    # via pandas
+urllib3==2.0.2
+    # via
+    #   requests
+    #   responses
+wheel==0.40.0
+    # via
+    #   nvidia-cublas-cu11
+    #   nvidia-cuda-cupti-cu11
+    #   nvidia-cuda-runtime-cu11
+    #   nvidia-curand-cu11
+    #   nvidia-cusparse-cu11
+    #   nvidia-nvtx-cu11
+xxhash==3.2.0
+    # via datasets
+yarl==1.9.2
+    # via aiohttp
+
+# The following packages are considered to be unsafe in a requirements file:
+# setuptools

@@ -1,33 +0,0 @@
-import os
-from unittest.mock import patch
-
-from pydantic.types import SecretStr
-
-from chat_savant.settings import SavantSettings
-
-SETTING_DEFAULTS = {
-    "OPENAI_API_KEY": "EMPTY",
-    "OPENAI_API_BASE": "http://localhost:8000/v1",
-    "LLM_MODEL_NAME": "eachadea_vicuna-7b-1.1",
-    "ELEVEN_LABS_API_KEY": "",
-    "ELEVEN_LABS_MODEL": "",
-}
-
-
-def test_setting_defaults():
-    """Regression test for settings schema."""
-    with patch.dict(os.environ, {}, clear=True):
-        savant_settings = SavantSettings()
-        assert len(savant_settings.dict()) == len(SETTING_DEFAULTS)
-        for k, v in SETTING_DEFAULTS.items():
-            _setting_value = getattr(savant_settings, k.lower())
-            unmasked_setting = (
-                _setting_value.get_secret_value() if isinstance(_setting_value, SecretStr) else _setting_value
-            )
-            assert unmasked_setting == v
-
-
-def test_with_envvar_prefix():
-    with patch.dict(os.environ, {"SAVANT_ELEVEN_LABS_API_KEY": "thisisnotreal"}, clear=True):
-        savant_settings = SavantSettings()
-        assert savant_settings.eleven_labs_api_key.get_secret_value() == "thisisnotreal"