Compare commits
1 Commits
master
...
drew/stabl
Author | SHA1 | Date |
---|---|---|
Drew Bednar | 1bed5efdf7 | 2 years ago |
@ -1,6 +1,26 @@
|
|||||||
def main():
|
import click
|
||||||
print("Hey this is the cli application")
|
|
||||||
|
from . import run_stable_vicuna
|
||||||
|
|
||||||
|
|
||||||
|
@click.group()
|
||||||
|
def cli():
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@cli.command()
|
||||||
|
@click.option(
|
||||||
|
"--model-dir",
|
||||||
|
type=click.Path(exists=True, file_okay=False, dir_okay=True),
|
||||||
|
envvar="SAVANT_MODEL_DIR",
|
||||||
|
default="~/models",
|
||||||
|
show_default=True,
|
||||||
|
help="The path to a directory containing a hugging face Stable-Vicuna model.",
|
||||||
|
)
|
||||||
|
def stable_vicuna(model_dir):
|
||||||
|
"""Runs a Stable Vicuna CLI prompt."""
|
||||||
|
run_stable_vicuna.main(model_dir=model_dir)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
cli()
|
||||||
|
@ -0,0 +1,74 @@
|
|||||||
|
import textwrap
|
||||||
|
|
||||||
|
import colorama
|
||||||
|
from transformers import LlamaForCausalLM, LlamaTokenizer
|
||||||
|
from transformers import logging as t_logging
|
||||||
|
from transformers import pipeline
|
||||||
|
|
||||||
|
# Configure logging level for transformers library
|
||||||
|
t_logging.logging.set_verbosity_info()
|
||||||
|
|
||||||
|
|
||||||
|
# Utility Functions
|
||||||
|
def get_prompt(human_prompt):
|
||||||
|
prompt_template = f"### Human: {human_prompt} \n### Assistant:"
|
||||||
|
return prompt_template
|
||||||
|
|
||||||
|
|
||||||
|
def remove_human_text(text):
|
||||||
|
return text.split("### Human:", 1)[0]
|
||||||
|
|
||||||
|
|
||||||
|
def parse_text(data):
|
||||||
|
for item in data:
|
||||||
|
text = item["generated_text"]
|
||||||
|
assistant_text_index = text.find("### Assistant:")
|
||||||
|
if assistant_text_index != -1:
|
||||||
|
assistant_text = text[
|
||||||
|
assistant_text_index + len("### Assistant:") :
|
||||||
|
].strip()
|
||||||
|
assistant_text = remove_human_text(assistant_text)
|
||||||
|
wrapped_text = textwrap.fill(assistant_text, width=100)
|
||||||
|
print(wrapped_text)
|
||||||
|
|
||||||
|
|
||||||
|
# Reasoning question
|
||||||
|
EXAMPLE_REASONING = "Answer the following question by reasoning step by step. \
|
||||||
|
The cafeteria had 22 apples. If they used 20 for lunch, and bought 6 more, \
|
||||||
|
how many apple do they have?"
|
||||||
|
|
||||||
|
|
||||||
|
# User interface
|
||||||
|
def main(model_dir):
|
||||||
|
# Model loading for inference
|
||||||
|
tokenizer = LlamaTokenizer.from_pretrained(model_dir)
|
||||||
|
|
||||||
|
base_model = LlamaForCausalLM.from_pretrained(
|
||||||
|
model_dir,
|
||||||
|
load_in_8bit=True,
|
||||||
|
device_map="auto",
|
||||||
|
)
|
||||||
|
|
||||||
|
pipe = pipeline(
|
||||||
|
"text-generation",
|
||||||
|
model=base_model,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
max_length=512,
|
||||||
|
temperature=0.7,
|
||||||
|
top_p=0.95,
|
||||||
|
repetition_penalty=1.15,
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Reading for inference!")
|
||||||
|
while True:
|
||||||
|
input_prompt = ""
|
||||||
|
input_prompt = input("USER:")
|
||||||
|
print(colorama.Style.DIM + f"You are submitting: {input_prompt}")
|
||||||
|
print(colorama.Style.RESET_ALL)
|
||||||
|
raw_output = pipe(get_prompt(input_prompt))
|
||||||
|
parse_text(raw_output)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
print("Warming up the engines...")
|
||||||
|
main()
|
@ -0,0 +1,32 @@
|
|||||||
|
# Invoke tab-completion script to be sourced with Bash shell.
|
||||||
|
# Known to work on Bash 3.x, untested on 4.x.
|
||||||
|
|
||||||
|
_complete_invoke() {
|
||||||
|
local candidates
|
||||||
|
|
||||||
|
# COMP_WORDS contains the entire command string up til now (including
|
||||||
|
# program name).
|
||||||
|
# We hand it to Invoke so it can figure out the current context: spit back
|
||||||
|
# core options, task names, the current task's options, or some combo.
|
||||||
|
candidates=`invoke --complete -- ${COMP_WORDS[*]}`
|
||||||
|
|
||||||
|
# `compgen -W` takes list of valid options & a partial word & spits back
|
||||||
|
# possible matches. Necessary for any partial word completions (vs
|
||||||
|
# completions performed when no partial words are present).
|
||||||
|
#
|
||||||
|
# $2 is the current word or token being tabbed on, either empty string or a
|
||||||
|
# partial word, and thus wants to be compgen'd to arrive at some subset of
|
||||||
|
# our candidate list which actually matches.
|
||||||
|
#
|
||||||
|
# COMPREPLY is the list of valid completions handed back to `complete`.
|
||||||
|
COMPREPLY=( $(compgen -W "${candidates}" -- $2) )
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Tell shell builtin to use the above for completing our invocations.
|
||||||
|
# * -F: use given function name to generate completions.
|
||||||
|
# * -o default: when function generates no results, use filenames.
|
||||||
|
# * positional args: program names to complete for.
|
||||||
|
complete -F _complete_invoke -o default invoke inv
|
||||||
|
|
||||||
|
# vim: set ft=sh :
|
@ -0,0 +1,9 @@
|
|||||||
|
transformers @ git+https://github.com/huggingface/transformers@849367ccf741d8c58aa88ccfe1d52d8636eaf2b7
|
||||||
|
bitsandbytes
|
||||||
|
datasets
|
||||||
|
loralib
|
||||||
|
sentencepiece
|
||||||
|
bitsandbytes
|
||||||
|
accelerate
|
||||||
|
langchain
|
||||||
|
colorama
|
@ -1,6 +1,227 @@
|
|||||||
#
|
#
|
||||||
# This file is autogenerated by pip-compile with Python 3.11
|
# This file is autogenerated by pip-compile with Python 3.10
|
||||||
# by the following command:
|
# by the following command:
|
||||||
#
|
#
|
||||||
# pip-compile --output-file=requirements.txt requirements.in
|
# pip-compile requirements.in
|
||||||
#
|
#
|
||||||
|
accelerate==0.18.0
|
||||||
|
# via -r requirements.in
|
||||||
|
aiohttp==3.8.4
|
||||||
|
# via
|
||||||
|
# datasets
|
||||||
|
# fsspec
|
||||||
|
# langchain
|
||||||
|
aiosignal==1.3.1
|
||||||
|
# via aiohttp
|
||||||
|
async-timeout==4.0.2
|
||||||
|
# via
|
||||||
|
# aiohttp
|
||||||
|
# langchain
|
||||||
|
attrs==23.1.0
|
||||||
|
# via aiohttp
|
||||||
|
bitsandbytes==0.38.1
|
||||||
|
# via -r requirements.in
|
||||||
|
certifi==2022.12.7
|
||||||
|
# via requests
|
||||||
|
charset-normalizer==3.1.0
|
||||||
|
# via
|
||||||
|
# aiohttp
|
||||||
|
# requests
|
||||||
|
cmake==3.26.3
|
||||||
|
# via triton
|
||||||
|
colorama==0.4.6
|
||||||
|
# via -r requirements.in
|
||||||
|
dataclasses-json==0.5.7
|
||||||
|
# via langchain
|
||||||
|
datasets==2.12.0
|
||||||
|
# via -r requirements.in
|
||||||
|
dill==0.3.6
|
||||||
|
# via
|
||||||
|
# datasets
|
||||||
|
# multiprocess
|
||||||
|
filelock==3.12.0
|
||||||
|
# via
|
||||||
|
# huggingface-hub
|
||||||
|
# torch
|
||||||
|
# transformers
|
||||||
|
# triton
|
||||||
|
frozenlist==1.3.3
|
||||||
|
# via
|
||||||
|
# aiohttp
|
||||||
|
# aiosignal
|
||||||
|
fsspec[http]==2023.4.0
|
||||||
|
# via
|
||||||
|
# datasets
|
||||||
|
# huggingface-hub
|
||||||
|
greenlet==2.0.2
|
||||||
|
# via sqlalchemy
|
||||||
|
huggingface-hub==0.14.1
|
||||||
|
# via
|
||||||
|
# datasets
|
||||||
|
# transformers
|
||||||
|
idna==3.4
|
||||||
|
# via
|
||||||
|
# requests
|
||||||
|
# yarl
|
||||||
|
jinja2==3.1.2
|
||||||
|
# via torch
|
||||||
|
langchain==0.0.160
|
||||||
|
# via -r requirements.in
|
||||||
|
lit==16.0.3
|
||||||
|
# via triton
|
||||||
|
loralib==0.1.1
|
||||||
|
# via -r requirements.in
|
||||||
|
markupsafe==2.1.2
|
||||||
|
# via jinja2
|
||||||
|
marshmallow==3.19.0
|
||||||
|
# via
|
||||||
|
# dataclasses-json
|
||||||
|
# marshmallow-enum
|
||||||
|
marshmallow-enum==1.5.1
|
||||||
|
# via dataclasses-json
|
||||||
|
mpmath==1.3.0
|
||||||
|
# via sympy
|
||||||
|
multidict==6.0.4
|
||||||
|
# via
|
||||||
|
# aiohttp
|
||||||
|
# yarl
|
||||||
|
multiprocess==0.70.14
|
||||||
|
# via datasets
|
||||||
|
mypy-extensions==1.0.0
|
||||||
|
# via typing-inspect
|
||||||
|
networkx==3.1
|
||||||
|
# via torch
|
||||||
|
numexpr==2.8.4
|
||||||
|
# via langchain
|
||||||
|
numpy==1.24.3
|
||||||
|
# via
|
||||||
|
# accelerate
|
||||||
|
# datasets
|
||||||
|
# langchain
|
||||||
|
# numexpr
|
||||||
|
# pandas
|
||||||
|
# pyarrow
|
||||||
|
# transformers
|
||||||
|
nvidia-cublas-cu11==11.10.3.66
|
||||||
|
# via
|
||||||
|
# nvidia-cudnn-cu11
|
||||||
|
# nvidia-cusolver-cu11
|
||||||
|
# torch
|
||||||
|
nvidia-cuda-cupti-cu11==11.7.101
|
||||||
|
# via torch
|
||||||
|
nvidia-cuda-nvrtc-cu11==11.7.99
|
||||||
|
# via torch
|
||||||
|
nvidia-cuda-runtime-cu11==11.7.99
|
||||||
|
# via torch
|
||||||
|
nvidia-cudnn-cu11==8.5.0.96
|
||||||
|
# via torch
|
||||||
|
nvidia-cufft-cu11==10.9.0.58
|
||||||
|
# via torch
|
||||||
|
nvidia-curand-cu11==10.2.10.91
|
||||||
|
# via torch
|
||||||
|
nvidia-cusolver-cu11==11.4.0.1
|
||||||
|
# via torch
|
||||||
|
nvidia-cusparse-cu11==11.7.4.91
|
||||||
|
# via torch
|
||||||
|
nvidia-nccl-cu11==2.14.3
|
||||||
|
# via torch
|
||||||
|
nvidia-nvtx-cu11==11.7.91
|
||||||
|
# via torch
|
||||||
|
openapi-schema-pydantic==1.2.4
|
||||||
|
# via langchain
|
||||||
|
packaging==23.1
|
||||||
|
# via
|
||||||
|
# accelerate
|
||||||
|
# datasets
|
||||||
|
# huggingface-hub
|
||||||
|
# marshmallow
|
||||||
|
# transformers
|
||||||
|
pandas==2.0.1
|
||||||
|
# via datasets
|
||||||
|
psutil==5.9.5
|
||||||
|
# via accelerate
|
||||||
|
pyarrow==12.0.0
|
||||||
|
# via datasets
|
||||||
|
pydantic==1.10.7
|
||||||
|
# via
|
||||||
|
# langchain
|
||||||
|
# openapi-schema-pydantic
|
||||||
|
python-dateutil==2.8.2
|
||||||
|
# via pandas
|
||||||
|
pytz==2023.3
|
||||||
|
# via pandas
|
||||||
|
pyyaml==6.0
|
||||||
|
# via
|
||||||
|
# accelerate
|
||||||
|
# datasets
|
||||||
|
# huggingface-hub
|
||||||
|
# langchain
|
||||||
|
# transformers
|
||||||
|
regex==2023.5.5
|
||||||
|
# via transformers
|
||||||
|
requests==2.30.0
|
||||||
|
# via
|
||||||
|
# datasets
|
||||||
|
# fsspec
|
||||||
|
# huggingface-hub
|
||||||
|
# langchain
|
||||||
|
# responses
|
||||||
|
# transformers
|
||||||
|
responses==0.18.0
|
||||||
|
# via datasets
|
||||||
|
sentencepiece==0.1.99
|
||||||
|
# via -r requirements.in
|
||||||
|
six==1.16.0
|
||||||
|
# via python-dateutil
|
||||||
|
sqlalchemy==2.0.12
|
||||||
|
# via langchain
|
||||||
|
sympy==1.11.1
|
||||||
|
# via torch
|
||||||
|
tenacity==8.2.2
|
||||||
|
# via langchain
|
||||||
|
tokenizers==0.13.3
|
||||||
|
# via transformers
|
||||||
|
torch==2.0.0
|
||||||
|
# via
|
||||||
|
# accelerate
|
||||||
|
# triton
|
||||||
|
tqdm==4.65.0
|
||||||
|
# via
|
||||||
|
# datasets
|
||||||
|
# huggingface-hub
|
||||||
|
# langchain
|
||||||
|
# transformers
|
||||||
|
transformers @ git+https://github.com/huggingface/transformers@849367ccf741d8c58aa88ccfe1d52d8636eaf2b7
|
||||||
|
# via -r requirements.in
|
||||||
|
triton==2.0.0
|
||||||
|
# via torch
|
||||||
|
typing-extensions==4.5.0
|
||||||
|
# via
|
||||||
|
# huggingface-hub
|
||||||
|
# pydantic
|
||||||
|
# sqlalchemy
|
||||||
|
# torch
|
||||||
|
# typing-inspect
|
||||||
|
typing-inspect==0.8.0
|
||||||
|
# via dataclasses-json
|
||||||
|
tzdata==2023.3
|
||||||
|
# via pandas
|
||||||
|
urllib3==2.0.2
|
||||||
|
# via
|
||||||
|
# requests
|
||||||
|
# responses
|
||||||
|
wheel==0.40.0
|
||||||
|
# via
|
||||||
|
# nvidia-cublas-cu11
|
||||||
|
# nvidia-cuda-cupti-cu11
|
||||||
|
# nvidia-cuda-runtime-cu11
|
||||||
|
# nvidia-curand-cu11
|
||||||
|
# nvidia-cusparse-cu11
|
||||||
|
# nvidia-nvtx-cu11
|
||||||
|
xxhash==3.2.0
|
||||||
|
# via datasets
|
||||||
|
yarl==1.9.2
|
||||||
|
# via aiohttp
|
||||||
|
|
||||||
|
# The following packages are considered to be unsafe in a requirements file:
|
||||||
|
# setuptools
|
||||||
|
Loading…
Reference in New Issue