diff --git a/chat_savant/cli.py b/chat_savant/cli.py
index f1ed496..db6d9a8 100644
--- a/chat_savant/cli.py
+++ b/chat_savant/cli.py
@@ -1,6 +1,26 @@
-def main():
-    print("Hey this is the cli application")
+import click
+
+from . import run_stable_vicuna
+
+
+@click.group()
+def cli():
+    pass
+
+
+@cli.command()
+@click.option(
+    "--model-dir",
+    type=click.Path(exists=True, file_okay=False, dir_okay=True),
+    envvar="SAVANT_MODEL_DIR",
+    default="~/models",
+    show_default=True,
+    help="The path to a directory containing a Hugging Face Stable-Vicuna model.",
+)
+def stable_vicuna(model_dir):
+    """Runs a Stable Vicuna CLI prompt."""
+    run_stable_vicuna.main(model_dir=model_dir)
 
 
 if __name__ == "__main__":
-    main()
+    cli()
diff --git a/chat_savant/run_stable_vicuna.py b/chat_savant/run_stable_vicuna.py
new file mode 100644
index 0000000..8ab6dc3
--- /dev/null
+++ b/chat_savant/run_stable_vicuna.py
@@ -0,0 +1,74 @@
+import textwrap
+
+import colorama
+from transformers import LlamaForCausalLM, LlamaTokenizer
+from transformers import logging as t_logging
+from transformers import pipeline
+
+# Configure logging level for the transformers library
+t_logging.set_verbosity_info()
+
+
+# Utility functions
+def get_prompt(human_prompt):
+    prompt_template = f"### Human: {human_prompt} \n### Assistant:"
+    return prompt_template
+
+
+def remove_human_text(text):
+    return text.split("### Human:", 1)[0]
+
+
+def parse_text(data):
+    for item in data:
+        text = item["generated_text"]
+        assistant_text_index = text.find("### Assistant:")
+        if assistant_text_index != -1:
+            assistant_text = text[
+                assistant_text_index + len("### Assistant:") :
+            ].strip()
+            assistant_text = remove_human_text(assistant_text)
+            wrapped_text = textwrap.fill(assistant_text, width=100)
+            print(wrapped_text)
+
+
+# Reasoning question
+EXAMPLE_REASONING = "Answer the following question by reasoning step by step. \
+    The cafeteria had 22 apples. If they used 20 for lunch, and bought 6 more, \
+    how many apples do they have?"
+
+
+# User interface
+def main(model_dir):
+    # Model loading for inference
+    tokenizer = LlamaTokenizer.from_pretrained(model_dir)
+
+    base_model = LlamaForCausalLM.from_pretrained(
+        model_dir,
+        load_in_8bit=True,
+        device_map="auto",
+    )
+
+    pipe = pipeline(
+        "text-generation",
+        model=base_model,
+        tokenizer=tokenizer,
+        max_length=512,
+        temperature=0.7,
+        top_p=0.95,
+        repetition_penalty=1.15,
+    )
+
+    print("Ready for inference!")
+    while True:
+        input_prompt = ""
+        input_prompt = input("USER:")
+        print(colorama.Style.DIM + f"You are submitting: {input_prompt}")
+        print(colorama.Style.RESET_ALL)
+        raw_output = pipe(get_prompt(input_prompt))
+        parse_text(raw_output)
+
+
+if __name__ == "__main__":
+    print("Warming up the engines...")
+    main()
diff --git a/pyinvoke/autocomplete.sh b/pyinvoke/autocomplete.sh
new file mode 100644
index 0000000..d37d964
--- /dev/null
+++ b/pyinvoke/autocomplete.sh
@@ -0,0 +1,32 @@
+# Invoke tab-completion script to be sourced with Bash shell.
+# Known to work on Bash 3.x, untested on 4.x.
+
+_complete_invoke() {
+    local candidates
+
+    # COMP_WORDS contains the entire command string up til now (including
+    # program name).
+    # We hand it to Invoke so it can figure out the current context: spit back
+    # core options, task names, the current task's options, or some combo.
+    candidates=`invoke --complete -- ${COMP_WORDS[*]}`
+
+    # `compgen -W` takes list of valid options & a partial word & spits back
+    # possible matches. Necessary for any partial word completions (vs
+    # completions performed when no partial words are present).
+    #
+    # $2 is the current word or token being tabbed on, either empty string or a
+    # partial word, and thus wants to be compgen'd to arrive at some subset of
+    # our candidate list which actually matches.
+    #
+    # COMPREPLY is the list of valid completions handed back to `complete`.
+    COMPREPLY=( $(compgen -W "${candidates}" -- $2) )
+}
+
+
+# Tell shell builtin to use the above for completing our invocations.
+# * -F: use given function name to generate completions.
+# * -o default: when function generates no results, use filenames.
+# * positional args: program names to complete for.
+complete -F _complete_invoke -o default invoke inv
+
+# vim: set ft=sh :
diff --git a/pyproject.toml b/pyproject.toml
index bdcd7ea..f957f8d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -18,7 +18,7 @@ classifiers = [
 dynamic = ["version", "dependencies"]
 
 [project.scripts]
-savant-cli = "chat_savant.cli:main"
+savant = "chat_savant.cli:cli"
 
 [project.optional-dependencies]
 whisper = ["openai-whisper"]
diff --git a/requirements.in b/requirements.in
index e69de29..8955bf9 100644
--- a/requirements.in
+++ b/requirements.in
@@ -0,0 +1,9 @@
+transformers @ git+https://github.com/huggingface/transformers@849367ccf741d8c58aa88ccfe1d52d8636eaf2b7
+bitsandbytes
+datasets
+loralib
+sentencepiece
+bitsandbytes
+accelerate
+langchain
+colorama
diff --git a/requirements.txt b/requirements.txt
index ea51277..295d33d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,227 @@
 #
-# This file is autogenerated by pip-compile with Python 3.11
+# This file is autogenerated by pip-compile with Python 3.10
 # by the following command:
 #
-#    pip-compile --output-file=requirements.txt requirements.in
+#    pip-compile requirements.in
 #
+accelerate==0.18.0
+    # via -r requirements.in
+aiohttp==3.8.4
+    # via
+    #   datasets
+    #   fsspec
+    #   langchain
+aiosignal==1.3.1
+    # via aiohttp
+async-timeout==4.0.2
+    # via
+    #   aiohttp
+    #   langchain
+attrs==23.1.0
+    # via aiohttp
+bitsandbytes==0.38.1
+    # via -r requirements.in
+certifi==2022.12.7
+    # via requests
+charset-normalizer==3.1.0
+    # via
+    #   aiohttp
+    #   requests
+cmake==3.26.3
+    # via triton
+colorama==0.4.6
+    # via -r requirements.in
+dataclasses-json==0.5.7
+    # via langchain
+datasets==2.12.0
+    # via -r requirements.in
+dill==0.3.6
+    # via
+    #   datasets
+    #   multiprocess
+filelock==3.12.0
+    # via
+    #   huggingface-hub
+    #   torch
+    #   transformers
+    #   triton
+frozenlist==1.3.3
+    # via
+    #   aiohttp
+    #   aiosignal
+fsspec[http]==2023.4.0
+    # via
+    #   datasets
+    #   huggingface-hub
+greenlet==2.0.2
+    # via sqlalchemy
+huggingface-hub==0.14.1
+    # via
+    #   datasets
+    #   transformers
+idna==3.4
+    # via
+    #   requests
+    #   yarl
+jinja2==3.1.2
+    # via torch
+langchain==0.0.160
+    # via -r requirements.in
+lit==16.0.3
+    # via triton
+loralib==0.1.1
+    # via -r requirements.in
+markupsafe==2.1.2
+    # via jinja2
+marshmallow==3.19.0
+    # via
+    #   dataclasses-json
+    #   marshmallow-enum
+marshmallow-enum==1.5.1
+    # via dataclasses-json
+mpmath==1.3.0
+    # via sympy
+multidict==6.0.4
+    # via
+    #   aiohttp
+    #   yarl
+multiprocess==0.70.14
+    # via datasets
+mypy-extensions==1.0.0
+    # via typing-inspect
+networkx==3.1
+    # via torch
+numexpr==2.8.4
+    # via langchain
+numpy==1.24.3
+    # via
+    #   accelerate
+    #   datasets
+    #   langchain
+    #   numexpr
+    #   pandas
+    #   pyarrow
+    #   transformers
+nvidia-cublas-cu11==11.10.3.66
+    # via
+    #   nvidia-cudnn-cu11
+    #   nvidia-cusolver-cu11
+    #   torch
+nvidia-cuda-cupti-cu11==11.7.101
+    # via torch
+nvidia-cuda-nvrtc-cu11==11.7.99
+    # via torch
+nvidia-cuda-runtime-cu11==11.7.99
+    # via torch
+nvidia-cudnn-cu11==8.5.0.96
+    # via torch
+nvidia-cufft-cu11==10.9.0.58
+    # via torch
+nvidia-curand-cu11==10.2.10.91
+    # via torch
+nvidia-cusolver-cu11==11.4.0.1
+    # via torch
+nvidia-cusparse-cu11==11.7.4.91
+    # via torch
+nvidia-nccl-cu11==2.14.3
+    # via torch
+nvidia-nvtx-cu11==11.7.91
+    # via torch
+openapi-schema-pydantic==1.2.4
+    # via langchain
+packaging==23.1
+    # via
+    #   accelerate
+    #   datasets
+    #   huggingface-hub
+    #   marshmallow
+    #   transformers
+pandas==2.0.1
+    # via datasets
+psutil==5.9.5
+    # via accelerate
+pyarrow==12.0.0
+    # via datasets
+pydantic==1.10.7
+    # via
+    #   langchain
+    #   openapi-schema-pydantic
+python-dateutil==2.8.2
+    # via pandas
+pytz==2023.3
+    # via pandas
+pyyaml==6.0
+    # via
+    #   accelerate
+    #   datasets
+    #   huggingface-hub
+    #   langchain
+    #   transformers
+regex==2023.5.5
+    # via transformers
+requests==2.30.0
+    # via
+    #   datasets
+    #   fsspec
+    #   huggingface-hub
+    #   langchain
+    #   responses
+    #   transformers
+responses==0.18.0
+    # via datasets
+sentencepiece==0.1.99
+    # via -r requirements.in
+six==1.16.0
+    # via python-dateutil
+sqlalchemy==2.0.12
+    # via langchain
+sympy==1.11.1
+    # via torch
+tenacity==8.2.2
+    # via langchain
+tokenizers==0.13.3
+    # via transformers
+torch==2.0.0
+    # via
+    #   accelerate
+    #   triton
+tqdm==4.65.0
+    # via
+    #   datasets
+    #   huggingface-hub
+    #   langchain
+    #   transformers
+transformers @ git+https://github.com/huggingface/transformers@849367ccf741d8c58aa88ccfe1d52d8636eaf2b7
+    # via -r requirements.in
+triton==2.0.0
+    # via torch
+typing-extensions==4.5.0
+    # via
+    #   huggingface-hub
+    #   pydantic
+    #   sqlalchemy
+    #   torch
+    #   typing-inspect
+typing-inspect==0.8.0
+    # via dataclasses-json
+tzdata==2023.3
+    # via pandas
+urllib3==2.0.2
+    # via
+    #   requests
+    #   responses
+wheel==0.40.0
+    # via
+    #   nvidia-cublas-cu11
+    #   nvidia-cuda-cupti-cu11
+    #   nvidia-cuda-runtime-cu11
+    #   nvidia-curand-cu11
+    #   nvidia-cusparse-cu11
+    #   nvidia-nvtx-cu11
+xxhash==3.2.0
+    # via datasets
+yarl==1.9.2
+    # via aiohttp
+
+# The following packages are considered to be unsafe in a requirements file:
+# setuptools
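
Not part of the diff: a minimal usage sketch of the new entry point, assuming the package has been installed (e.g. with `pip install -e .`) and that click derives the sub-command name `stable-vicuna` from the function name; the model path below is illustrative, not a path from this repository.

    # Point the CLI at a local Hugging Face Stable-Vicuna checkout and start the prompt loop.
    savant stable-vicuna --model-dir /path/to/stable-vicuna-13b

    # Equivalently, via the environment variable read by --model-dir:
    SAVANT_MODEL_DIR=/path/to/stable-vicuna-13b savant stable-vicuna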