Skip to content

Commit

Permalink
fix: Downloading private adapters from HF (#443)
Browse files Browse the repository at this point in the history
  • Loading branch information
tgaddair committed Apr 28, 2024
1 parent 1c9b528 commit 75cd88a
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 1 deletion.
2 changes: 1 addition & 1 deletion server/lorax_server/utils/sources/hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def download_model_assets(self):

def download_file(self, filename: str, ignore_errors: bool = False) -> Optional[Path]:
try:
return Path(hf_hub_download(self.model_id, revision=None, filename=filename))
return Path(hf_hub_download(self.model_id, revision=None, filename=filename, token=self.api_token))
except Exception as e:
if ignore_errors:
return None
Expand Down
23 changes: 23 additions & 0 deletions server/tests/adapters/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import os

import pytest
from huggingface_hub.utils import RepositoryNotFoundError

from lorax_server.adapters.utils import download_adapter
from lorax_server.utils.sources import HUB


def test_download_private_adapter_hf():
# store and unset HUGGING_FACE_HUB_TOKEN from the environment
token = os.environ.pop("HUGGING_FACE_HUB_TOKEN", None)
assert token is not None, "HUGGING_FACE_HUB_TOKEN must be set in the environment to run this test"

# verify download fails without the token set
with pytest.raises(RepositoryNotFoundError):
download_adapter("predibase/test-private-lora", HUB, api_token=None)

# pass in the token and verify download succeeds
download_adapter("predibase/test-private-lora", HUB, api_token=token)

# set the token back in the environment
os.environ["HUGGING_FACE_HUB_TOKEN"] = token

0 comments on commit 75cd88a

Please sign in to comment.