issue #188 bug fix
The default model was set to TheBloke/WizardLM-7B-uncensored-GPTQ, which causes issues when running on CPU.

Changed the default to TheBloke/vicuna-7B-1.1-HF.
When --device_type is cpu or mps, model_basename is set to None and the model is loaded with LlamaForCausalLM instead. This is a temporary fix; a permanent solution for M1/M2 is still needed.
PromtEngineer committed Jun 27, 2023
1 parent 9aab283 commit 506d0ee
Showing 2 changed files with 7 additions and 97 deletions.
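For context, the following is a minimal sketch of the fallback path this commit relies on, not the actual load_model from run_localGPT.py: with model_basename forced to None on cpu/mps, the GPTQ branch is skipped and the plain HF checkpoint is loaded with LlamaForCausalLM. The pipeline wrapping and the max_new_tokens value are illustrative assumptions, and the real function also handles the CUDA/GPTQ case via auto-gptq.

import logging

from langchain.llms import HuggingFacePipeline
from transformers import LlamaForCausalLM, LlamaTokenizer, pipeline


def load_model(device_type, model_id, model_basename=None):
    # Temporary fix from this commit: GPTQ weights need CUDA, so drop the
    # basename on cpu/mps and load the unquantized HF checkpoint instead.
    if device_type.lower() in ["cpu", "mps"]:
        model_basename = None

    logging.info(f"Loading Model: {model_id}, on: {device_type}")

    if model_basename is not None:
        # The CUDA-only GPTQ branch (auto-gptq) is omitted from this sketch.
        raise NotImplementedError("GPTQ loading is not covered by this sketch.")

    tokenizer = LlamaTokenizer.from_pretrained(model_id)
    model = LlamaForCausalLM.from_pretrained(model_id)
    # max_new_tokens is an illustrative value, not taken from the repository.
    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=512)
    return HuggingFacePipeline(pipeline=pipe)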
10 changes: 7 additions & 3 deletions run_localGPT.py
@@ -39,6 +39,8 @@ def load_model(device_type, model_id, model_basename=None):
Raises:
ValueError: If an unsupported model or device type is provided.
"""
+if device_type.lower() in ["cpu", "mps"]:
+    model_basename=None

logging.info(f"Loading Model: {model_id}, on: {device_type}")
logging.info("This action can take a few minutes!")
@@ -177,7 +179,8 @@ def main(device_type, show_sources):
# load the LLM for generating Natural Language responses

# for HF models
# model_id = "TheBloke/vicuna-7B-1.1-HF"
model_id = "TheBloke/vicuna-7B-1.1-HF"
model_basename=None
# model_id = "TheBloke/Wizard-Vicuna-7B-Uncensored-HF"
# model_id = "TheBloke/guanaco-7B-HF"
# model_id = 'NousResearch/Nous-Hermes-13b' # Requires ~ 23GB VRAM. Using STransformers
@@ -192,8 +195,9 @@ def main(device_type, show_sources):
# ~21GB VRAM. Using STransformers alongside can potentially create OOM on 24GB cards.
# model_id = "TheBloke/wizardLM-7B-GPTQ"
# model_basename = "wizardLM-7B-GPTQ-4bit.compat.no-act-order.safetensors"
model_id = "TheBloke/WizardLM-7B-uncensored-GPTQ"
model_basename = "WizardLM-7B-uncensored-GPTQ-4bit-128g.compat.no-act-order.safetensors"
# model_id = "TheBloke/WizardLM-7B-uncensored-GPTQ"
# model_basename = "WizardLM-7B-uncensored-GPTQ-4bit-128g.compat.no-act-order.safetensors"

llm = load_model(device_type, model_id=model_id, model_basename=model_basename)

qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever, return_source_documents=True)
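With these defaults, no GPTQ weights are required, so the script can be started on a CPU-only machine or an M1/M2 Mac using the device flag from the commit message, e.g. python run_localGPT.py --device_type cpu (or --device_type mps).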
94 changes: 0 additions & 94 deletions xlxs_loader.py

This file was deleted.
