You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
75 lines
2.0 KiB
Python
75 lines
2.0 KiB
Python
2 years ago
|
import textwrap
|
||
|
|
||
|
import colorama
|
||
|
from transformers import LlamaForCausalLM, LlamaTokenizer
|
||
|
from transformers import logging as t_logging
|
||
|
from transformers import pipeline
|
||
|
|
||
|
# Configure logging level for transformers library.
# The verbosity helpers are functions of transformers' own logging module
# (imported above as t_logging); reaching through its `logging` attribute
# does not provide set_verbosity_info and fails at import time.
t_logging.set_verbosity_info()
|
||
|
|
||
|
|
||
|
# Utility Functions
|
||
|
def get_prompt(human_prompt):
    """Wrap a user message in the Human/Assistant prompt format.

    The result is a "### Human:" turn holding ``human_prompt`` followed, on
    the next line, by an open "### Assistant:" turn.
    """
    human_turn = f"### Human: {human_prompt} "
    assistant_turn = "### Assistant:"
    return "\n".join((human_turn, assistant_turn))
|
||
|
|
||
|
|
||
|
def remove_human_text(text):
    """Return the part of ``text`` that precedes the first "### Human:" marker.

    When the marker does not occur, ``text`` is returned in full.
    """
    before_marker, _, _ = text.partition("### Human:")
    return before_marker
|
||
|
|
||
|
|
||
|
def parse_text(data):
    """Print the assistant reply contained in each generated item.

    Every element of ``data`` is indexed with "generated_text". The text
    after the first "### Assistant:" marker is stripped, cut off at any
    following "### Human:" marker, wrapped to 100 columns and printed.
    Items without the assistant marker are skipped without output.
    """
    marker = "### Assistant:"
    for entry in data:
        _, found, tail = entry["generated_text"].partition(marker)
        if not found:
            # No assistant turn in this item, so there is nothing to show.
            continue
        reply = remove_human_text(tail.strip())
        print(textwrap.fill(reply, width=100))
|
||
|
|
||
|
|
||
|
# Reasoning question
|
||
|
# Sample step-by-step reasoning prompt; the wording is kept verbatim.
EXAMPLE_REASONING = (
    "Answer the following question by reasoning step by step. "
    "The cafeteria had 22 apples. If they used 20 for lunch, and bought 6 more, "
    "how many apple do they have?"
)
|
||
|
|
||
|
|
||
|
# User interface
|
||
|
def main(model_dir):
    """Load a LLaMA model and run an interactive prompt loop.

    Args:
        model_dir: Value handed to ``from_pretrained`` for both the
            tokenizer and the causal language model.

    Each loop iteration reads a prompt from stdin, echoes it dimmed, runs it
    through the text-generation pipeline and prints the parsed reply. The
    loop ends when stdin is closed (EOF) or the user presses Ctrl-C at the
    prompt.
    """
    # Model loading for inference
    tokenizer = LlamaTokenizer.from_pretrained(model_dir)

    base_model = LlamaForCausalLM.from_pretrained(
        model_dir,
        load_in_8bit=True,
        device_map="auto",
    )

    # NOTE(review): temperature/top_p presumably only take effect when
    # sampling is enabled for generation -- confirm whether do_sample=True
    # is intended here.
    pipe = pipeline(
        "text-generation",
        model=base_model,
        tokenizer=tokenizer,
        max_length=512,
        temperature=0.7,
        top_p=0.95,
        repetition_penalty=1.15,
    )

    print("Ready for inference!")
    while True:
        try:
            input_prompt = input("USER:")
        except (EOFError, KeyboardInterrupt):
            # Closed stdin or Ctrl-C ends the session instead of a traceback.
            print()
            break
        print(colorama.Style.DIM + f"You are submitting: {input_prompt}")
        print(colorama.Style.RESET_ALL)
        raw_output = pipe(get_prompt(input_prompt))
        parse_text(raw_output)
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
    import argparse

    # main() requires the model directory, so read it from the command line
    # rather than calling main() with no arguments.
    parser = argparse.ArgumentParser(
        description="Interactive text generation with a local LLaMA model."
    )
    parser.add_argument(
        "model_dir",
        help="directory (or identifier) of the model and tokenizer to load",
    )
    args = parser.parse_args()

    print("Warming up the engines...")
    main(args.model_dir)
|