Automatically select the correct langchain library
Automatically select correct langchain class based on the embedding model name.
BBC-Esq committed Jan 3, 2024
1 parent 0b4fa33 commit 747a9b4
Showing 2 changed files with 57 additions and 16 deletions.
43 changes: 30 additions & 13 deletions ingest.py
@@ -153,26 +153,43 @@ def main(device_type):
logging.info(f"Loaded {len(documents)} documents from {SOURCE_DIRECTORY}")
logging.info(f"Split into {len(texts)} chunks of text")

-    # Create embeddings
-    embeddings = HuggingFaceInstructEmbeddings(
-        model_name=EMBEDDING_MODEL_NAME,
-        model_kwargs={"device": device_type},
-    )
-    # change the embedding type here if you are running into issues.
-    # These are much smaller embeddings and will work for most applications.
-    # If you use HuggingFaceEmbeddings, make sure to also use the same in the
-    # run_localGPT.py file.
-
-    # embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
"""
(1) Chooses an appropriate langchain library based on the enbedding model name. Matching code is contained within fun_localGPT.py.
(2) Provides additional arguments for instructor and BGE models to improve results, pursuant to the instructions contained on
their respective huggingface repository, project page or github repository.
"""

if "instructor" in EMBEDDING_MODEL_NAME:
return HuggingFaceInstructEmbeddings(
model_name=EMBEDDING_MODEL_NAME,
model_kwargs={"device": compute_device},
embed_instruction='Represent the document for retrieval:',
query_instruction='Represent the question for retrieving supporting documents:'
)

elif "bge" in EMBEDDING_MODEL_NAME:
query_instruction = 'Represent this sentence for searching relevant passages:'

return HuggingFaceBgeEmbeddings(
model_name=EMBEDDING_MODEL_NAME,
model_kwargs={"device": compute_device},
query_instruction='Represent this sentence for searching relevant passages:'
)

else:

return HuggingFaceEmbeddings(
model_name=EMBEDDING_MODEL_NAME,
model_kwargs={"device": compute_device},
)

    db = Chroma.from_documents(
        texts,
        embeddings,
        persist_directory=PERSIST_DIRECTORY,
        client_settings=CHROMA_SETTINGS,
    )



if __name__ == "__main__":
    logging.basicConfig(
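The hunk above does not show the import changes, but the new branches need all three classes in scope in both files. A minimal sketch, assuming the langchain.embeddings module path current at the time of this commit (these classes later moved to langchain_community, so newer releases may need a different path):

    # assumed imports for the selection logic in ingest.py and run_localGPT.py
    from langchain.embeddings import (
        HuggingFaceEmbeddings,
        HuggingFaceInstructEmbeddings,
        HuggingFaceBgeEmbeddings,
    )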
30 changes: 27 additions & 3 deletions run_localGPT.py
@@ -119,9 +119,33 @@ def retrieval_qa_pipline(device_type, use_history, promptTemplate_type="llama"):
    - The QA system retrieves relevant documents using the retriever and then answers questions based on those documents.
    """

-    embeddings = HuggingFaceInstructEmbeddings(model_name=EMBEDDING_MODEL_NAME, model_kwargs={"device": device_type})
-    # uncomment the following line if you used HuggingFaceEmbeddings in ingest.py
-    # embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
"""
(1) Chooses an appropriate langchain library based on the enbedding model name. Matching code is contained within ingest.py.
(2) Provides additional arguments for instructor and BGE models to improve results, pursuant to the instructions contained on
their respective huggingface repository, project page or github repository.
"""

if "instructor" in EMBEDDING_MODEL_NAME:
return HuggingFaceInstructEmbeddings(
model_name=EMBEDDING_MODEL_NAME,
model_kwargs={"device": compute_device},
embed_instruction='Represent the document for retrieval:',
query_instruction='Represent the question for retrieving supporting documents:'
)

elif "bge" in EMBEDDING_MODEL_NAME:
return HuggingFaceBgeEmbeddings(
model_name=EMBEDDING_MODEL_NAME,
model_kwargs={"device": compute_device},
query_instruction='Represent this sentence for searching relevant passages:'
)

else:
return HuggingFaceEmbeddings(
model_name=EMBEDDING_MODEL_NAME,
model_kwargs={"device": compute_device},
)

    # load the vectorstore
    db = Chroma(
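The truncated hunk leads into re-opening the store that ingest.py persisted. A rough sketch of that step (not the commit's exact code; argument names follow langchain's Chroma wrapper):

    # sketch: re-open the persisted store with the embeddings object selected above;
    # the class must match the one ingest.py used to build the vectors
    db = Chroma(
        persist_directory=PERSIST_DIRECTORY,
        embedding_function=embeddings,
        client_settings=CHROMA_SETTINGS,
    )
    retriever = db.as_retriever()

Querying a store with a different embeddings class than the one that built it silently degrades retrieval, which is presumably why the commit duplicates the same name-based if/elif in both files.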

