added support for llama3
PromtEngineer committed May 3, 2024
1 parent e997a8a commit e7311a2
Showing 5 changed files with 56 additions and 11 deletions.
README.md: 4 changes (2 additions, 2 deletions)
@@ -71,14 +71,14 @@ For `NVIDIA` GPUs support, use `cuBLAS`

```shell
# Example: cuBLAS
CMAKE_ARGS="-DLLAMA_CUBLAS=on" FORCE_CMAKE=1 pip install llama-cpp-python==0.1.83 --no-cache-dir
CMAKE_ARGS="-DLLAMA_CUBLAS=on" FORCE_CMAKE=1 pip install llama-cpp-python --no-cache-dir
```

For Apple Metal (`M1/M2`) support, use

```shell
# Example: METAL
CMAKE_ARGS="-DLLAMA_METAL=on" FORCE_CMAKE=1 pip install llama-cpp-python==0.1.83 --no-cache-dir
CMAKE_ARGS="-DLLAMA_METAL=on" FORCE_CMAKE=1 pip install llama-cpp-python --no-cache-dir
```
For more details, please refer to [llama-cpp](https://github.com/abetlen/llama-cpp-python#installation-with-openblas--cublas--clblast--metal)

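
If you want to verify that the accelerated `llama-cpp-python` build is actually in use, a quick check is to load a GGUF model directly and watch the startup log for offloaded layers. This is a minimal sketch, not part of the repository; the model path and layer count are placeholder assumptions.

```python
# Minimal sanity check for a GPU-enabled llama-cpp-python build (paths are placeholders).
from llama_cpp import Llama

llm = Llama(
    model_path="./models/Meta-Llama-3-8B-Instruct.Q4_K_M.gguf",  # placeholder GGUF path
    n_gpu_layers=-1,  # offload all layers; the load log should mention cuBLAS/Metal
    n_ctx=8192,       # Llama 3 context window
)

out = llm("Q: What does localGPT do? A:", max_tokens=64)
print(out["choices"][0]["text"])
```
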
constants.py: 17 changes (14 additions, 3 deletions)
@@ -29,7 +29,7 @@
)

# Context Window and Max New Tokens
-CONTEXT_WINDOW_SIZE = 4096
+CONTEXT_WINDOW_SIZE = 8096
MAX_NEW_TOKENS = CONTEXT_WINDOW_SIZE # int(CONTEXT_WINDOW_SIZE/4)

#### If you get a "not enough space in the buffer" error, you should reduce the values below, start with half of the original values and keep halving the value until the error stops appearing
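
For orientation, these two constants are typically handed to the llama.cpp-backed LLM as its context size and generation budget. The sketch below assumes LangChain's `LlamaCpp` wrapper and a placeholder model path; it is illustrative only and may not match how localGPT wires the values internally.

```python
# Illustrative only: where CONTEXT_WINDOW_SIZE / MAX_NEW_TOKENS usually end up.
from langchain.llms import LlamaCpp

CONTEXT_WINDOW_SIZE = 8096
MAX_NEW_TOKENS = CONTEXT_WINDOW_SIZE

llm = LlamaCpp(
    model_path="./models/model.gguf",  # placeholder path
    n_ctx=CONTEXT_WINDOW_SIZE,         # context window; halve it if you hit buffer errors
    max_tokens=MAX_NEW_TOKENS,         # cap on newly generated tokens
    n_batch=512,                       # prompt batch size, another value worth reducing
)
```
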
@@ -100,8 +100,19 @@
# MODEL_ID = "TheBloke/Llama-2-13b-Chat-GGUF"
# MODEL_BASENAME = "llama-2-13b-chat.Q4_K_M.gguf"

MODEL_ID = "TheBloke/Llama-2-7b-Chat-GGUF"
MODEL_BASENAME = "llama-2-7b-chat.Q4_K_M.gguf"
# MODEL_ID = "TheBloke/Llama-2-7b-Chat-GGUF"
# MODEL_BASENAME = "llama-2-7b-chat.Q4_K_M.gguf"

# MODEL_ID = "QuantFactory/Meta-Llama-3-8B-Instruct-GGUF"
# MODEL_BASENAME = "Meta-Llama-3-8B-Instruct.Q4_K_M.gguf"

# LLAMA 3 # use for Apple Silicon
# MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
# MODEL_BASENAME = None

# LLAMA 3 # use for NVIDIA GPUs
# MODEL_ID = "unsloth/llama-3-8b-bnb-4bit"
# MODEL_BASENAME = None

# MODEL_ID = "TheBloke/Mistral-7B-Instruct-v0.1-GGUF"
# MODEL_BASENAME = "mistral-7b-instruct-v0.1.Q8_0.gguf"
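
The `MODEL_ID` / `MODEL_BASENAME` pair follows a convention: a GGUF basename points at a single quantized file for llama.cpp, while `MODEL_BASENAME = None` (as in the Llama 3 entries above) means the full Hugging Face checkpoint is loaded instead. A rough, hypothetical sketch of that branching using `huggingface_hub`; this is not localGPT's actual loader.

```python
# Hypothetical resolution of a MODEL_ID / MODEL_BASENAME pair.
from huggingface_hub import hf_hub_download

MODEL_ID = "QuantFactory/Meta-Llama-3-8B-Instruct-GGUF"
MODEL_BASENAME = "Meta-Llama-3-8B-Instruct.Q4_K_M.gguf"

if MODEL_BASENAME is not None:
    # Quantized GGUF: fetch the single file and hand its path to llama.cpp.
    model_path = hf_hub_download(
        repo_id=MODEL_ID,
        filename=MODEL_BASENAME,
        cache_dir="./models/",
    )
    print(f"GGUF model available at {model_path}")
else:
    # No basename: the full (or bitsandbytes-quantized) checkpoint is loaded via transformers.
    print(f"Would load {MODEL_ID} with AutoModelForCausalLM")
```
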
load_models.py: 16 changes (13 additions, 3 deletions)
@@ -136,9 +136,19 @@ def load_full_model(model_id, model_basename, device_type, logging):
"""

if device_type.lower() in ["mps", "cpu"]:
logging.info("Using LlamaTokenizer")
tokenizer = LlamaTokenizer.from_pretrained(model_id, cache_dir="./models/")
model = LlamaForCausalLM.from_pretrained(model_id, cache_dir="./models/")
logging.info("Using AutoModelForCausalLM")
# tokenizer = LlamaTokenizer.from_pretrained(model_id, cache_dir="./models/")
# model = LlamaForCausalLM.from_pretrained(model_id, cache_dir="./models/")

model = AutoModelForCausalLM.from_pretrained(model_id,
# quantization_config=quantization_config,
# low_cpu_mem_usage=True,
# torch_dtype="auto",
torch_dtype=torch.bfloat16,
device_map="auto",
cache_dir="./models/")

tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir="./models/")
else:
logging.info("Using AutoModelForCausalLM for full models")
tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir="./models/")
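
Once the model and tokenizer are loaded this way, Llama 3 Instruct expects its own chat template rather than the Llama 2 `[INST]` wrapping. The snippet below is a hypothetical usage sketch built on standard `transformers` APIs, not code from this commit; the model id assumes access to the gated Meta repository.

```python
# Hypothetical usage of the full model/tokenizer pair for Llama 3 Instruct.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"  # gated repo, access assumed
tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir="./models/")
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto", cache_dir="./models/"
)

messages = [
    {"role": "system", "content": "Answer only from the provided context."},
    {"role": "user", "content": "Context: ...\nQuestion: ..."},
]
# apply_chat_template renders the <|start_header_id|>/<|eot_id|> format that
# prompt_template_utils.py builds by hand further down.
input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)

output = model.generate(input_ids, max_new_tokens=256)
print(tokenizer.decode(output[0][input_ids.shape[-1]:], skip_special_tokens=True))
```
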
prompt_template_utils.py: 24 changes (24 additions, 0 deletions)
@@ -33,6 +33,28 @@ def get_prompt_template(system_prompt=system_prompt, promptTemplate_type=None, h

            prompt_template = B_INST + SYSTEM_PROMPT + instruction + E_INST
            prompt = PromptTemplate(input_variables=["context", "question"], template=prompt_template)
+
+    elif promptTemplate_type == "llama3":
+
+        B_INST, E_INST = "<|start_header_id|>user<|end_header_id|>", "<|eot_id|>"
+        B_SYS, E_SYS = "<|begin_of_text|><|start_header_id|>system<|end_header_id|> ", "<|eot_id|>"
+        ASSISTANT_INST = "<|start_header_id|>assistant<|end_header_id|>"
+        SYSTEM_PROMPT = B_SYS + system_prompt + E_SYS
+        if history:
+            instruction = """
+            Context: {history} \n {context}
+            User: {question}"""
+
+            prompt_template = SYSTEM_PROMPT + B_INST + instruction + ASSISTANT_INST
+            prompt = PromptTemplate(input_variables=["history", "context", "question"], template=prompt_template)
+        else:
+            instruction = """
+            Context: {context}
+            User: {question}"""
+
+            prompt_template = SYSTEM_PROMPT + B_INST + instruction + ASSISTANT_INST
+            prompt = PromptTemplate(input_variables=["context", "question"], template=prompt_template)
+
    elif promptTemplate_type == "mistral":
        B_INST, E_INST = "<s>[INST] ", " [/INST]"
        if history:
@@ -82,6 +104,8 @@ def get_prompt_template(system_prompt=system_prompt, promptTemplate_type=None, h

    memory = ConversationBufferMemory(input_key="question", memory_key="history")
+
+    print(f"Here is the prompt used: {prompt}")

    return (
        prompt,
        memory,
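
To see what the new `llama3` branch renders, the template can be formatted directly. A small, hypothetical check that mirrors the strings above with LangChain's `PromptTemplate`; the system prompt and inputs are placeholders, not values from the repository.

```python
# Hypothetical: render the llama3-style prompt to inspect the special tokens.
from langchain.prompts import PromptTemplate

B_INST = "<|start_header_id|>user<|end_header_id|>"
B_SYS, E_SYS = "<|begin_of_text|><|start_header_id|>system<|end_header_id|> ", "<|eot_id|>"
ASSISTANT_INST = "<|start_header_id|>assistant<|end_header_id|>"

system_prompt = "Answer only from the supplied context."  # placeholder system prompt
instruction = """
Context: {context}
User: {question}"""

prompt_template = B_SYS + system_prompt + E_SYS + B_INST + instruction + ASSISTANT_INST
prompt = PromptTemplate(input_variables=["context", "question"], template=prompt_template)

print(prompt.format(context="localGPT indexes documents on-device.", question="Where does indexing happen?"))
```
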
run_localGPT.py: 6 changes (3 additions, 3 deletions)
@@ -209,11 +209,11 @@ def retrieval_qa_pipline(device_type, use_history, promptTemplate_type="llama"):
)
@click.option(
    "--model_type",
-    default="llama",
+    default="llama3",
    type=click.Choice(
-        ["llama", "mistral", "non_llama"],
+        ["llama3", "llama", "mistral", "non_llama"],
    ),
-    help="model type, llama, mistral or non_llama",
+    help="model type, llama3, llama, mistral or non_llama",
)
@click.option(
    "--save_qa",
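
The new `--model_type llama3` default is what ultimately selects the `llama3` prompt branch. A stripped-down, hypothetical sketch of that CLI wiring with `click`; the real script accepts many more options and builds the full retrieval pipeline.

```python
# Hypothetical, minimal version of the CLI wiring for the new llama3 choice.
import click

@click.command()
@click.option(
    "--model_type",
    default="llama3",
    type=click.Choice(["llama3", "llama", "mistral", "non_llama"]),
    help="model type, llama3, llama, mistral or non_llama",
)
def main(model_type):
    # In the real script this value is forwarded as promptTemplate_type and
    # selects the llama3 branch added in prompt_template_utils.py.
    click.echo(f"Building QA pipeline with promptTemplate_type={model_type}")

if __name__ == "__main__":
    main()
```

Against the actual repository this corresponds to running `python run_localGPT.py --model_type llama3`.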
