
Support loading .pt weights #420

Open
shripadk opened this issue Apr 17, 2024 · 1 comment
@shripadk
Feature request

Need support for loading models that only contain .pt weights

Motivation

I quantized the Mixtral 8x7B model using HQQ, which produces a qmodel.pt file. But I am unable to load the weights in LoRAX, as it expects either .safetensors or .bin weights.

Your contribution

I haven't studied the source enough to submit a PR, but from a cursory understanding of the code, changes would need to be made in the hub.py file, specifically:

```python
try:
    filenames = weight_hub_files(model_id, revision, extension, api_token)
except EntryNotFoundError as e:
    if extension != ".safetensors":
        raise e
    # Try to see if there are pytorch weights
    pt_filenames = weight_hub_files(model_id, revision, extension=".bin", api_token=api_token)
    # Change pytorch extension to safetensors extension
    # It is possible that we have safetensors weights locally even though they are not on the
    # hub if we converted weights locally without pushing them
    filenames = [f"{Path(f).stem.lstrip('pytorch_')}.safetensors" for f in pt_filenames]
```
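One possible shape for the change, as a sketch only (not LoRAX's actual API; `weight_hub_files` is assumed to behave as in the snippet above): try `.pt` as a second fallback after `.bin`, and map the resulting names to the `.safetensors` names expected after local conversion. Note also that `str.lstrip('pytorch_')` strips any of the characters `p y t o r c h _` from the left rather than removing the literal prefix; `str.removeprefix` (Python 3.9+) does what the comment intends. The name-mapping part can be isolated as a pure helper:

```python
from pathlib import Path


def pt_to_safetensors_names(pt_filenames):
    """Map PyTorch weight filenames (.bin or .pt) to the .safetensors
    filenames expected after a local conversion.

    Uses str.removeprefix instead of str.lstrip: lstrip('pytorch_')
    strips *characters* from the set {p, y, t, o, r, c, h, _}, not the
    literal "pytorch_" prefix, which can mangle other filenames.
    """
    names = []
    for f in pt_filenames:
        stem = Path(f).stem.removeprefix("pytorch_")
        names.append(f"{stem}.safetensors")
    return names
```

With this helper, the `except` branch could loop over candidate extensions (`".bin"`, then `".pt"`) and call `pt_to_safetensors_names` on whichever lookup succeeds, instead of hard-coding `".bin"`.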

I would also like to be able to load the base model from local storage rather than from the Hub (as explained in issue #347).

@magdyksaleh magdyksaleh self-assigned this Apr 18, 2024
@magdyksaleh
Collaborator

I will work on a fix for this alongside #347
