-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[BUG]: The batch, the sync and the missing vector #2062
base: main
Are you sure you want to change the base?
Changes from all commits
9ca85f3
de5f87a
7cc68db
5685050
805bc79
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -298,8 +298,12 @@ def collections( | |
metadata.update(test_hnsw_config) | ||
if with_persistent_hnsw_params: | ||
metadata["hnsw:batch_size"] = draw(st.integers(min_value=3, max_value=2000)) | ||
# batch_size > sync_threshold doesn't make sense | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nice |
||
metadata["hnsw:sync_threshold"] = draw( | ||
st.integers(min_value=3, max_value=2000) | ||
st.integers( | ||
min_value=metadata["hnsw:batch_size"], | ||
max_value=metadata["hnsw:batch_size"] + 2000, | ||
) | ||
) | ||
# Sometimes, select a space at random | ||
if draw(st.booleans()): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,13 @@ | ||
import uuid | ||
|
||
import pytest | ||
import logging | ||
import hypothesis.strategies as st | ||
from hypothesis import given | ||
from hypothesis import given, settings | ||
from typing import Dict, Set, cast, Union, DefaultDict, Any, List | ||
from dataclasses import dataclass | ||
|
||
from chromadb.api.fastapi import FastAPI | ||
from chromadb.api.types import ID, Include, IDs, validate_embeddings | ||
import chromadb.errors as errors | ||
from chromadb.api import ServerAPI | ||
|
@@ -25,7 +29,6 @@ | |
import chromadb.test.property.invariants as invariants | ||
import numpy as np | ||
|
||
|
||
traces: DefaultDict[str, int] = defaultdict(lambda: 0) | ||
|
||
|
||
|
@@ -58,7 +61,10 @@ class EmbeddingStateMachineStates: | |
upsert_embeddings = "upsert_embeddings" | ||
|
||
|
||
collection_st = st.shared(strategies.collections(with_hnsw_params=True), key="coll") | ||
collection_st = st.shared( | ||
strategies.collections(with_hnsw_params=True, with_persistent_hnsw_params=True), | ||
key="coll", | ||
) | ||
|
||
|
||
class EmbeddingStateMachine(RuleBasedStateMachine): | ||
|
@@ -73,13 +79,30 @@ def __init__(self, api: ServerAPI): | |
@initialize(collection=collection_st) # type: ignore | ||
def initialize(self, collection: strategies.Collection): | ||
self.api.reset() | ||
self.collection = self.api.create_collection( | ||
name=collection.name, | ||
metadata=collection.metadata, | ||
embedding_function=collection.embedding_function, | ||
) | ||
try: | ||
self.collection = self.api.create_collection( | ||
name=collection.name, | ||
metadata=collection.metadata, | ||
embedding_function=collection.embedding_function, | ||
) | ||
except Exception as e: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Whats this doing and why - seems comment-worthy There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
self.api.reset() | ||
if "hnsw:batch_size" in str(e): | ||
del collection.metadata["hnsw:batch_size"] | ||
del collection.metadata["hnsw:sync_threshold"] | ||
try: | ||
self.collection = self.api.create_collection( | ||
name=collection.name, | ||
metadata=collection.metadata, | ||
embedding_function=collection.embedding_function, | ||
) | ||
except Exception as e: | ||
raise e | ||
else: | ||
raise e | ||
self.embedding_function = collection.embedding_function | ||
trace("init") | ||
self._metadata = collection.metadata | ||
self.on_state_change(EmbeddingStateMachineStates.initialize) | ||
|
||
self.record_set_state = strategies.StateMachineRecordSet( | ||
|
@@ -462,3 +485,64 @@ def test_0dim_embedding_validation() -> None: | |
with pytest.raises(ValueError) as e: | ||
validate_embeddings(embds) | ||
assert "Expected each embedding in the embeddings to be a non-empty list" in str(e) | ||
|
||
|
||
@dataclass | ||
class BatchParams: | ||
batch_size: int | ||
sync_threshold: int | ||
item_size: int | ||
|
||
|
||
@st.composite | ||
def batching_params(draw: st.DrawFn) -> BatchParams: | ||
batch_size = draw(st.integers(min_value=3, max_value=100)) | ||
sync_threshold = draw(st.integers(min_value=batch_size, max_value=batch_size * 2)) | ||
item_size = draw( | ||
st.integers(min_value=batch_size + 1, max_value=(batch_size * 2) + 1) | ||
) | ||
return BatchParams( | ||
batch_size=batch_size, sync_threshold=sync_threshold, item_size=item_size | ||
) | ||
|
||
|
||
@settings(max_examples=1, deadline=None) | ||
@given(batching_params=batching_params()) | ||
def test_batching(batching_params: BatchParams, api: ServerAPI) -> None: | ||
error_distribution = {"IndexError": 0, "TypeError": 0, "NoError": 0} | ||
rounds = 100 | ||
if isinstance(api, FastAPI) or not api.get_settings().is_persistent: | ||
pytest.skip("FastAPI does not support this test") | ||
for _ in range( | ||
rounds | ||
): # we do a few rounds to ensure that key or lists arrangements (due to UUID randomness) do not affect the test | ||
api.reset() | ||
collection = api.get_or_create_collection( | ||
"test", | ||
metadata={ | ||
"hnsw:batch_size": batching_params.batch_size, | ||
"hnsw:sync_threshold": batching_params.sync_threshold, | ||
}, | ||
) | ||
items = [ | ||
(f"{uuid.uuid4()}", i, [0.1] * 2) for i in range(batching_params.item_size) | ||
] # we want to exceed the batch size by at least 1 | ||
ids = [item[0] for item in items] | ||
embeddings = [item[2] for item in items] | ||
collection.add(ids=ids, embeddings=embeddings) | ||
collection.delete(ids=[ids[0]]) | ||
collection.add(ids=[ids[0]], embeddings=[[1] * 2]) | ||
try: | ||
collection.get(include=["embeddings"]) | ||
error_distribution["NoError"] += 1 | ||
except IndexError as e: | ||
if "list assignment index out of range" in str(e): | ||
error_distribution["IndexError"] += 1 | ||
except TypeError as e: | ||
if "'NoneType' object is not subscriptable" in str(e): | ||
error_distribution["TypeError"] += 1 | ||
invariants.segments_len_match(api, collection) | ||
|
||
assert error_distribution["NoError"] == rounds | ||
assert error_distribution["IndexError"] == 0 | ||
assert error_distribution["TypeError"] == 0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nice