Skip to content

Commit

Permalink
community: Chroma Adding create_collection_if_not_exists flag to Chro…
Browse files Browse the repository at this point in the history
…ma constructor (langchain-ai#21420)

- **Description:** Adds the ability to either `get_or_create` or simply
`get_collection`. This is useful when dealing with read-only Chroma
instances where users are constraint to using `get_collection`. Targeted
at Http/CloudClients mostly.
- **Issue:** chroma-core/chroma#2163
- **Dependencies:** N/A
- **Twitter handle:** `@t_azarov`




| Collection Exists | create_collection_if_not_exists | Outcome | test |

|-------------------|---------------------------------|----------------------------------------------------------------|----------------------------------------------------------|
| True | False | No errors, collection state unchanged |
`test_create_collection_if_not_exist_false_existing` |
| True | True | No errors, collection state unchanged |
`test_create_collection_if_not_exist_true_existing` |
| False | False | Error, `get_collection()` fails |
`test_create_collection_if_not_exist_false_non_existing` |
| False | True | No errors, `get_or_create_collection()` creates the
collection | `test_create_collection_if_not_exist_true_non_existing` |
  • Loading branch information
tazarov authored and kyle-cassidy committed May 10, 2024
1 parent 6ae784b commit c0147d6
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 11 deletions.
20 changes: 13 additions & 7 deletions libs/partners/chroma/langchain_chroma/vectorstores.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
Y = np.array(Y)
if X.shape[1] != Y.shape[1]:
raise ValueError(
f"Number of columns in X and Y must be the same. X has shape {X.shape} "
"Number of columns in X and Y must be the same. X has shape"
f"{X.shape} "
f"and Y has shape {Y.shape}."
)

Expand Down Expand Up @@ -133,6 +134,7 @@ def __init__(
collection_metadata: Optional[Dict] = None,
client: Optional[chromadb.ClientAPI] = None,
relevance_score_fn: Optional[Callable[[float], float]] = None,
create_collection_if_not_exists: Optional[bool] = True,
) -> None:
"""Initialize with a Chroma client."""

Expand Down Expand Up @@ -161,11 +163,14 @@ def __init__(
)

self._embedding_function = embedding_function
self._collection = self._client.get_or_create_collection(
name=collection_name,
embedding_function=None,
metadata=collection_metadata,
)
if create_collection_if_not_exists:
self._collection = self._client.get_or_create_collection(
name=collection_name,
embedding_function=None,
metadata=collection_metadata,
)
else:
self._collection = self._client.get_collection(name=collection_name)
self.override_relevance_score_fn = relevance_score_fn

@property
Expand Down Expand Up @@ -650,7 +655,8 @@ def update_document(self, document_id: str, document: Document) -> None:
"""
return self.update_documents([document_id], [document])

def update_documents(self, ids: List[str], documents: List[Document]) -> None: # type: ignore
# type: ignore
def update_documents(self, ids: List[str], documents: List[Document]) -> None:
"""Update a document in the collection.
Args:
Expand Down
83 changes: 79 additions & 4 deletions libs/partners/chroma/tests/integration_tests/test_vectorstores.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""Test Chroma functionality."""

import uuid
from typing import Generator

import chromadb
import pytest
import requests
from chromadb.api.client import SharedSystemClient
from langchain_core.documents import Document
from langchain_core.embeddings.fake import FakeEmbeddings as Fak

Expand All @@ -15,6 +17,13 @@
)


@pytest.fixture()
def client() -> Generator[chromadb.ClientAPI, None, None]:
SharedSystemClient.clear_system_cache()
client = chromadb.Client(chromadb.config.Settings())
yield client


def test_chroma() -> None:
"""Test end to end construction and search."""
texts = ["foo", "bar", "baz"]
Expand Down Expand Up @@ -271,10 +280,7 @@ def test_chroma_with_relevance_score_custom_normalization_fn() -> None:
]


def test_init_from_client() -> None:
import chromadb

client = chromadb.Client(chromadb.config.Settings())
def test_init_from_client(client: chromadb.ClientAPI) -> None:
Chroma(client=client)


Expand Down Expand Up @@ -414,3 +420,72 @@ def test_chroma_legacy_batching() -> None:
)

db.delete_collection()


def test_create_collection_if_not_exist_default() -> None:
"""Tests existing behaviour without the new create_collection_if_not_exists flag."""
texts = ["foo", "bar", "baz"]
docsearch = Chroma.from_texts(
collection_name="test_collection", texts=texts, embedding=FakeEmbeddings()
)
assert docsearch._client.get_collection("test_collection") is not None
docsearch.delete_collection()


def test_create_collection_if_not_exist_true_existing(
client: chromadb.ClientAPI,
) -> None:
"""Tests create_collection_if_not_exists=True and collection already existing."""
client.create_collection("test_collection")
vectorstore = Chroma(
client=client,
collection_name="test_collection",
embedding_function=FakeEmbeddings(),
create_collection_if_not_exists=True,
)
assert vectorstore._client.get_collection("test_collection") is not None
vectorstore.delete_collection()


def test_create_collection_if_not_exist_false_existing(
client: chromadb.ClientAPI,
) -> None:
"""Tests create_collection_if_not_exists=False and collection already existing."""
client.create_collection("test_collection")
vectorstore = Chroma(
client=client,
collection_name="test_collection",
embedding_function=FakeEmbeddings(),
create_collection_if_not_exists=False,
)
assert vectorstore._client.get_collection("test_collection") is not None
vectorstore.delete_collection()


def test_create_collection_if_not_exist_false_non_existing(
client: chromadb.ClientAPI,
) -> None:
"""Tests create_collection_if_not_exists=False and collection not-existing,
should raise."""
with pytest.raises(Exception, match="does not exist"):
Chroma(
client=client,
collection_name="test_collection",
embedding_function=FakeEmbeddings(),
create_collection_if_not_exists=False,
)


def test_create_collection_if_not_exist_true_non_existing(
client: chromadb.ClientAPI,
) -> None:
"""Tests create_collection_if_not_exists=True and collection non-existing. ."""
vectorstore = Chroma(
client=client,
collection_name="test_collection",
embedding_function=FakeEmbeddings(),
create_collection_if_not_exists=True,
)

assert vectorstore._client.get_collection("test_collection") is not None
vectorstore.delete_collection()

0 comments on commit c0147d6

Please sign in to comment.