Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat:chroma store refactor #1508

Merged
merged 6 commits into from
May 11, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dbgpt/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
version = "0.5.5"
version = "0.5.6"
12 changes: 7 additions & 5 deletions dbgpt/app/knowledge/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,11 +247,13 @@ def sync_knowledge_document(self, space_name, sync_request: DocumentSyncRequest)
doc_ids = sync_request.doc_ids
self.model_name = sync_request.model_name or CFG.LLM_MODEL
for doc_id in doc_ids:
query = KnowledgeDocumentEntity(
id=doc_id,
space=space_name,
)
doc = knowledge_document_dao.get_knowledge_documents(query)[0]
query = KnowledgeDocumentEntity(id=doc_id)
docs = knowledge_document_dao.get_documents(query)
if len(docs) == 0:
raise Exception(
f"there are document called, doc_id: {sync_request.doc_id}"
)
doc = docs[0]
if (
doc.status == SyncStatus.RUNNING.name
or doc.status == SyncStatus.FINISHED.name
Expand Down
30 changes: 30 additions & 0 deletions dbgpt/storage/vector_store/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,36 @@ def load_document_with_limit(
)
return ids

def filter_by_score_threshold(
self, chunks: List[Chunk], score_threshold: float
) -> List[Chunk]:
"""Filter chunks by score threshold.

Args:
chunks(List[Chunks]): The chunks to filter.
score_threshold(float): The score threshold.
Return:
List[Chunks]: The filtered chunks.
"""
candidates_chunks = chunks
if score_threshold is not None:
candidates_chunks = [
Chunk(
metadata=chunk.metadata,
content=chunk.content,
score=chunk.score,
chunk_id=str(id),
)
for chunk in chunks
if chunk.score >= score_threshold
]
if len(candidates_chunks) == 0:
logger.warning(
"No relevant docs were retrieved using the relevance score"
f" threshold {score_threshold}"
)
return candidates_chunks

@abstractmethod
def similar_search(
self, text: str, topk: int, filters: Optional[MetadataFilters] = None
Expand Down
148 changes: 117 additions & 31 deletions dbgpt/storage/vector_store/chroma_store.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Chroma vector store."""
import logging
import os
from typing import List, Optional
from typing import Any, Dict, Iterable, List, Mapping, Optional, Union

from chromadb import PersistentClient
from chromadb.config import Settings
Expand Down Expand Up @@ -55,9 +55,11 @@ class ChromaStore(VectorStoreBase):
"""Chroma vector store."""

def __init__(self, vector_store_config: ChromaVectorConfig) -> None:
"""Create a ChromaStore instance."""
from langchain.vectorstores import Chroma
"""Create a ChromaStore instance.

Args:
vector_store_config(ChromaVectorConfig): vector store config.
"""
chroma_vector_config = vector_store_config.to_dict(exclude_none=True)
chroma_path = chroma_vector_config.get(
"persist_path", os.path.join(PILOT_PATH, "data")
Expand All @@ -71,31 +73,35 @@ def __init__(self, vector_store_config: ChromaVectorConfig) -> None:
persist_directory=self.persist_dir,
anonymized_telemetry=False,
)
client = PersistentClient(path=self.persist_dir, settings=chroma_settings)
self._chroma_client = PersistentClient(
path=self.persist_dir, settings=chroma_settings
)

collection_metadata = chroma_vector_config.get("collection_metadata") or {
"hnsw:space": "cosine"
}
self.vector_store_client = Chroma(
persist_directory=self.persist_dir,
embedding_function=self.embeddings,
# client_settings=chroma_settings,
client=client,
collection_metadata=collection_metadata,
) # type: ignore
self._collection = self._chroma_client.get_or_create_collection(
name=vector_store_config.name,
embedding_function=None,
metadata=collection_metadata,
)

def similar_search(
self, text, topk, filters: Optional[MetadataFilters] = None
) -> List[Chunk]:
"""Search similar documents."""
logger.info("ChromaStore similar search")
where_filters = self.convert_metadata_filters(filters) if filters else None
lc_documents = self.vector_store_client.similarity_search(
text, topk, filter=where_filters
chroma_results = self._query(
text=text,
topk=topk,
filters=filters,
)
return [
Chunk(content=doc.page_content, metadata=doc.metadata)
for doc in lc_documents
Chunk(content=chroma_result[0], metadata=chroma_result[1] or {}, score=0.0)
for chroma_result in zip(
chroma_results["documents"][0],
chroma_results["metadatas"][0],
)
]

def similar_search_with_scores(
Expand All @@ -114,19 +120,26 @@ def similar_search_with_scores(
filters(MetadataFilters): metadata filters, defaults to None
"""
logger.info("ChromaStore similar search with scores")
where_filters = self.convert_metadata_filters(filters) if filters else None
docs_and_scores = (
self.vector_store_client.similarity_search_with_relevance_scores(
query=text,
k=topk,
score_threshold=score_threshold,
filter=where_filters,
)
chroma_results = self._query(
text=text,
topk=topk,
filters=filters,
)
return [
Chunk(content=doc.page_content, metadata=doc.metadata, score=score)
for doc, score in docs_and_scores
chunks = [
(
Chunk(
content=chroma_result[0],
metadata=chroma_result[1] or {},
score=chroma_result[2],
)
)
for chroma_result in zip(
chroma_results["documents"][0],
chroma_results["metadatas"][0],
chroma_results["distances"][0],
)
]
return self.filter_by_score_threshold(chunks, score_threshold)

def vector_name_exists(self) -> bool:
"""Whether vector name exists."""
Expand All @@ -144,13 +157,17 @@ def load_document(self, chunks: List[Chunk]) -> List[str]:
texts = [chunk.content for chunk in chunks]
metadatas = [chunk.metadata for chunk in chunks]
ids = [chunk.chunk_id for chunk in chunks]
self.vector_store_client.add_texts(texts=texts, metadatas=metadatas, ids=ids)
chroma_metadatas = [
_transform_chroma_metadata(metadata) for metadata in metadatas
]
self._add_texts(texts=texts, metadatas=chroma_metadatas, ids=ids)
return ids

def delete_vector_name(self, vector_name: str):
"""Delete vector name."""
logger.info(f"chroma vector_name:{vector_name} begin delete...")
self.vector_store_client.delete_collection()
# self.vector_store_client.delete_collection()
self._chroma_client.delete_collection(self._collection.name)
self._clean_persist_folder()
return True

Expand All @@ -159,8 +176,7 @@ def delete_by_ids(self, ids):
logger.info(f"begin delete chroma ids: {ids}")
ids = ids.split(",")
if len(ids) > 0:
collection = self.vector_store_client._collection
collection.delete(ids=ids)
self._collection.delete(ids=ids)

def convert_metadata_filters(
self,
Expand Down Expand Up @@ -198,6 +214,65 @@ def convert_metadata_filters(
where_filters[chroma_condition] = filters_list
return where_filters

def _add_texts(
self,
texts: Iterable[str],
ids: List[str],
metadatas: Optional[List[Mapping[str, Union[str, int, float, bool]]]] = None,
) -> List[str]:
"""Add texts to Chroma collection.

Args:
texts(Iterable[str]): texts.
metadatas(Optional[List[dict]]): metadatas.
ids(Optional[List[str]]): ids.
Returns:
List[str]: ids.
"""
embeddings = None
texts = list(texts)
if self.embeddings is not None:
embeddings = self.embeddings.embed_documents(texts)
if metadatas:
try:
self._collection.upsert(
metadatas=metadatas,
embeddings=embeddings, # type: ignore
documents=texts,
ids=ids,
)
except ValueError as e:
logger.error(f"Error upsert chromadb with metadata: {e}")
else:
self._collection.upsert(
embeddings=embeddings, # type: ignore
documents=texts,
ids=ids,
)
return ids

def _query(self, text: str, topk: int, filters: Optional[MetadataFilters] = None):
"""Query Chroma collection.

Args:
text(str): query text.
topk(int): topk.
filters(MetadataFilters): metadata filters.
Returns:
dict: query result.
"""
if not text:
return {}
where_filters = self.convert_metadata_filters(filters) if filters else None
if self.embeddings is None:
raise ValueError("Chroma Embeddings is None")
query_embedding = self.embeddings.embed_query(text)
return self._collection.query(
query_embeddings=query_embedding,
n_results=topk,
where=where_filters,
)

def _clean_persist_folder(self):
"""Clean persist folder."""
for root, dirs, files in os.walk(self.persist_dir, topdown=False):
Expand Down Expand Up @@ -230,3 +305,14 @@ def _convert_chroma_filter_operator(operator: str) -> str:
return "$lte"
else:
raise ValueError(f"Chroma Where operator {operator} not supported")


def _transform_chroma_metadata(
metadata: Dict[str, Any]
) -> Mapping[str, str | int | float | bool]:
"""Transform metadata to Chroma metadata."""
transformed = {}
for key, value in metadata.items():
if isinstance(value, (str, int, float, bool)):
transformed[key] = value
return transformed
2 changes: 1 addition & 1 deletion dbgpt/storage/vector_store/pgvector_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __init__(self, vector_store_config: PGVectorConfig) -> None:
embedding_function=self.embeddings,
collection_name=self.collection_name,
connection_string=self.connection_string,
)
) # mypy: ignore

def similar_search(
self, text: str, topk: int, filters: Optional[MetadataFilters] = None
Expand Down
3 changes: 1 addition & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
IS_DEV_MODE = os.getenv("IS_DEV_MODE", "true").lower() == "true"
# If you modify the version, please modify the version in the following files:
# dbgpt/_version.py
DB_GPT_VERSION = os.getenv("DB_GPT_VERSION", "0.5.5")
DB_GPT_VERSION = os.getenv("DB_GPT_VERSION", "0.5.6")

BUILD_NO_CACHE = os.getenv("BUILD_NO_CACHE", "true").lower() == "true"
LLAMA_CPP_GPU_ACCELERATION = (
Expand Down Expand Up @@ -499,7 +499,6 @@ def knowledge_requires():
pip install "dbgpt[rag]"
"""
setup_spec.extras["rag"] = setup_spec.extras["vstore"] + [
"langchain>=0.0.286",
"spacy>=3.7",
"markdown",
"bs4",
Expand Down
Loading