diff --git a/ingest.py b/ingest.py index 42587b47..d2df88ab 100644 --- a/ingest.py +++ b/ingest.py @@ -164,7 +164,7 @@ def get_embeddings(): if "instructor" in EMBEDDING_MODEL_NAME: return HuggingFaceInstructEmbeddings( model_name=EMBEDDING_MODEL_NAME, - model_kwargs={"device": compute_device}, + model_kwargs={"device": device_type}, embed_instruction='Represent the document for retrieval:', query_instruction='Represent the question for retrieving supporting documents:' ) @@ -172,14 +172,14 @@ def get_embeddings(): elif "bge" in EMBEDDING_MODEL_NAME: return HuggingFaceBgeEmbeddings( model_name=EMBEDDING_MODEL_NAME, - model_kwargs={"device": compute_device}, + model_kwargs={"device": device_type}, query_instruction='Represent this sentence for searching relevant passages:' ) else: return HuggingFaceEmbeddings( model_name=EMBEDDING_MODEL_NAME, - model_kwargs={"device": compute_device}, + model_kwargs={"device": device_type}, ) embeddings = get_embeddings() logging.info(f"Loaded embeddings from {EMBEDDING_MODEL_NAME}")