From e4383d047862444aa0f42da36746bf058450a797 Mon Sep 17 00:00:00 2001
From: tedcochran
Date: Sat, 3 Feb 2024 19:51:24 -0700
Subject: [PATCH] Created a function for embeddings options

---
 ingest.py       | 46 +++++++++++++++++++++++-----------------------
 run_localGPT.py | 43 +++++++++++++++++++++++--------------------
 2 files changed, 46 insertions(+), 43 deletions(-)

diff --git a/ingest.py b/ingest.py
index bad8cd1c..42587b47 100644
--- a/ingest.py
+++ b/ingest.py
@@ -160,29 +160,29 @@ def main(device_type):
     their respective huggingface repository, project page or github repository.
     """

-    if "instructor" in EMBEDDING_MODEL_NAME:
-        return HuggingFaceInstructEmbeddings(
-            model_name=EMBEDDING_MODEL_NAME,
-            model_kwargs={"device": compute_device},
-            embed_instruction='Represent the document for retrieval:',
-            query_instruction='Represent the question for retrieving supporting documents:'
-        )
-
-    elif "bge" in EMBEDDING_MODEL_NAME:
-        query_instruction = 'Represent this sentence for searching relevant passages:'
-
-        return HuggingFaceBgeEmbeddings(
-            model_name=EMBEDDING_MODEL_NAME,
-            model_kwargs={"device": compute_device},
-            query_instruction='Represent this sentence for searching relevant passages:'
-        )
-
-    else:
-
-        return HuggingFaceEmbeddings(
-            model_name=EMBEDDING_MODEL_NAME,
-            model_kwargs={"device": compute_device},
-        )
+    def get_embeddings():
+        if "instructor" in EMBEDDING_MODEL_NAME:
+            return HuggingFaceInstructEmbeddings(
+                model_name=EMBEDDING_MODEL_NAME,
+                model_kwargs={"device": compute_device},
+                embed_instruction='Represent the document for retrieval:',
+                query_instruction='Represent the question for retrieving supporting documents:'
+            )
+
+        elif "bge" in EMBEDDING_MODEL_NAME:
+            return HuggingFaceBgeEmbeddings(
+                model_name=EMBEDDING_MODEL_NAME,
+                model_kwargs={"device": compute_device},
+                query_instruction='Represent this sentence for searching relevant passages:'
+            )
+
+        else:
+            return HuggingFaceEmbeddings(
+                model_name=EMBEDDING_MODEL_NAME,
+                model_kwargs={"device": compute_device},
+            )
+    embeddings = get_embeddings()
+    logging.info(f"Loaded embeddings from {EMBEDDING_MODEL_NAME}")

     db = Chroma.from_documents(
         texts,

diff --git a/run_localGPT.py b/run_localGPT.py
index eedb3c35..5f01a4e7 100644
--- a/run_localGPT.py
+++ b/run_localGPT.py
@@ -126,27 +126,30 @@ def retrieval_qa_pipline(device_type, use_history, promptTemplate_type="llama"):
     their respective huggingface repository, project page or github repository.
     """

-    if "instructor" in EMBEDDING_MODEL_NAME:
-        return HuggingFaceInstructEmbeddings(
-            model_name=EMBEDDING_MODEL_NAME,
-            model_kwargs={"device": compute_device},
-            embed_instruction='Represent the document for retrieval:',
-            query_instruction='Represent the question for retrieving supporting documents:'
-        )
-
-    elif "bge" in EMBEDDING_MODEL_NAME:
-        return HuggingFaceBgeEmbeddings(
-            model_name=EMBEDDING_MODEL_NAME,
-            model_kwargs={"device": compute_device},
-            query_instruction='Represent this sentence for searching relevant passages:'
-        )
-
-    else:
-        return HuggingFaceEmbeddings(
-            model_name=EMBEDDING_MODEL_NAME,
-            model_kwargs={"device": compute_device},
-        )
+    def get_embeddings():
+        if "instructor" in EMBEDDING_MODEL_NAME:
+            return HuggingFaceInstructEmbeddings(
+                model_name=EMBEDDING_MODEL_NAME,
+                model_kwargs={"device": compute_device},
+                embed_instruction='Represent the document for retrieval:',
+                query_instruction='Represent the question for retrieving supporting documents:'
+            )
+
+        elif "bge" in EMBEDDING_MODEL_NAME:
+            return HuggingFaceBgeEmbeddings(
+                model_name=EMBEDDING_MODEL_NAME,
+                model_kwargs={"device": compute_device},
+                query_instruction='Represent this sentence for searching relevant passages:'
+            )
+        else:
+            return HuggingFaceEmbeddings(
+                model_name=EMBEDDING_MODEL_NAME,
+                model_kwargs={"device": compute_device},
+            )
+    embeddings = get_embeddings()
+    logging.info(f"Loaded embeddings from {EMBEDDING_MODEL_NAME}")
+    # load the vectorstore
     db = Chroma(
         persist_directory=PERSIST_DIRECTORY,