Merge pull request #737 from ivanvaccarics/main
Update code and requirements to use Mixtral 8x7B and Falcon-40b
PromtEngineer committed Feb 17, 2024
2 parents e6c318f + 4c89272 commit 8dc5a72
Showing 2 changed files with 14 additions and 6 deletions.
18 changes: 13 additions & 5 deletions load_models.py
@@ -7,7 +7,7 @@
 
 from huggingface_hub import hf_hub_download
 from langchain.llms import LlamaCpp
-from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM, LlamaTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM, LlamaTokenizer, BitsAndBytesConfig
 
 from constants import CONTEXT_WINDOW_SIZE, MAX_NEW_TOKENS, MODELS_PATH, N_BATCH, N_GPU_LAYERS

@@ -143,18 +143,26 @@ def load_full_model(model_id, model_basename, device_type, logging):
         logging.info("Using AutoModelForCausalLM for full models")
         tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir="./models/")
         logging.info("Tokenizer loaded")
+        bnb_config = BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_use_double_quant=True,
+            bnb_4bit_quant_type="nf4",
+            bnb_4bit_compute_dtype=torch.float16,
+        )
+
         model = AutoModelForCausalLM.from_pretrained(
             model_id,
             device_map="auto",
             torch_dtype=torch.float16,
             low_cpu_mem_usage=True,
             cache_dir=MODELS_PATH,
             trust_remote_code=True,  # set these if you are using an NVIDIA GPU
-            load_in_4bit=True,
-            bnb_4bit_quant_type="nf4",
-            bnb_4bit_compute_dtype=torch.float16,
-            max_memory={0: "15GB"},  # Uncomment this line if you encounter CUDA out of memory errors
+            quantization_config=bnb_config,
+            # load_in_4bit=True,
+            # bnb_4bit_quant_type="nf4",
+            # bnb_4bit_compute_dtype=torch.float16,
+            # max_memory={0: "15GB"},  # Uncomment this line if you encounter CUDA out of memory errors
         )
 
         model.tie_weights()
         return model, tokenizer

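For reference, the new loading path can be exercised standalone. A minimal sketch, assuming a CUDA GPU with transformers, accelerate, and bitsandbytes installed; the model id and prompt are illustrative and not taken from this commit:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

    # Same NF4 4-bit config that load_models.py now builds before from_pretrained()
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,   # double quantization saves a bit more memory
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
    )

    model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"  # illustrative id, not from the diff
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",                # shard layers across available GPUs
        quantization_config=bnb_config,
    )

    inputs = tokenizer("Hello, world", return_tensors="pt").to(model.device)
    output = model.generate(**inputs, max_new_tokens=20)
    print(tokenizer.decode(output[0], skip_special_tokens=True))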
2 changes: 1 addition & 1 deletion requirements.txt
@@ -11,7 +11,7 @@ autoawq; sys_platform != 'darwin'
 protobuf==3.20.2; sys_platform != 'darwin'
 protobuf==3.20.2; sys_platform == 'darwin' and platform_machine != 'arm64'
 protobuf==3.20.3; sys_platform == 'darwin' and platform_machine == 'arm64'
-auto-gptq==0.2.2; sys_platform != 'darwin'
+auto-gptq==0.6.0; sys_platform != 'darwin'
 docx2txt
 unstructured
 unstructured[pdf]
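After reinstalling requirements, the bumped pin can be checked from Python. A minimal sketch; importlib.metadata ships with Python 3.8+, and note the environment marker above means auto-gptq is only installed off macOS:

    from importlib.metadata import version

    # Confirms the requirements.txt pin took effect in the active environment
    assert version("auto-gptq") == "0.6.0", "expected auto-gptq 0.6.0"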
