Skip to content

Commit

Permalink
[ENH] Add allowed_ids and disallowed_ids to HNSW bindings (#2174)
Browse files Browse the repository at this point in the history
## Description of changes

*Summarize the changes made by this PR.*
 - Improvements & Bug fixes
	 - This PR passes `allowed_ids` and `disallowed_ids` through to the HNSW bindings.
 - New functionality
	 - ...

## Test plan
*How are these changes tested?*

- [ ] Tests pass locally with `pytest` for python, `yarn test` for js,
`cargo test` for rust

## Documentation Changes
*Are all docstrings for user-facing APIs updated if required? Do we need
to make documentation changes in the [docs
repository](https://github.com/chroma-core/docs)?*
  • Loading branch information
Ishiihara committed May 13, 2024
1 parent 6203deb commit 9b3ec89
Show file tree
Hide file tree
Showing 5 changed files with 171 additions and 21 deletions.
49 changes: 45 additions & 4 deletions rust/worker/bindings.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,28 @@
// Assumes that chroma-hnswlib is checked out at the same level as chroma
#include "../../../hnswlib/hnswlib/hnswlib.h"

#include <unordered_set>
#include <utility>

// Filter functor handed to hnswlib's searchKnn to restrict which labels may
// appear in the result set.
//
// A label is accepted when both of the following hold:
//   * allow_list is empty (meaning "no allow restriction") or contains the label;
//   * disallow_list does not contain the label.
class AllowAndDisallowListFilterFunctor : public hnswlib::BaseFilterFunctor
{
public:
    std::unordered_set<hnswlib::labeltype> allow_list;
    std::unordered_set<hnswlib::labeltype> disallow_list;

    // The sets are taken by value and moved into the members, so a caller
    // passing lvalues pays for one copy instead of two, and a caller passing
    // temporaries pays for none.
    AllowAndDisallowListFilterFunctor(std::unordered_set<hnswlib::labeltype> allow_list, std::unordered_set<hnswlib::labeltype> disallow_list)
        : allow_list(std::move(allow_list)), disallow_list(std::move(disallow_list)) {}

    // Returns true when `id` may be returned by the search, false otherwise.
    bool operator()(hnswlib::labeltype id) override
    {
        // An empty allow list places no restriction on the label.
        if (!allow_list.empty() && allow_list.count(id) == 0)
        {
            return false;
        }
        // Looking up in an empty disallow list never matches, so no
        // separate emptiness check is needed here.
        if (disallow_list.count(id) != 0)
        {
            return false;
        }
        return true;
    }
};

template <typename dist_t, typename data_t = float>
class Index
{
Expand Down Expand Up @@ -108,13 +130,31 @@ class Index
return 0;
}

void knn_query(const data_t *query_vector, const size_t k, hnswlib::labeltype *ids, data_t *distance)
size_t knn_query(const data_t *query_vector, const size_t k, hnswlib::labeltype *ids, data_t *distance, const hnswlib::labeltype *allowed_ids, const size_t allowed_id_length, const hnswlib::labeltype *disallowed_ids, const size_t disallowed_id_length)
{
if (!index_inited)
{
std::runtime_error("Index not inited");
}
std::priority_queue<std::pair<dist_t, hnswlib::labeltype>> res = appr_alg->searchKnn(query_vector, k);

std::unordered_set<hnswlib::labeltype> allow_list;
std::unordered_set<hnswlib::labeltype> disallow_list;
if (allowed_ids != NULL)
{
for (int i = 0; i < allowed_id_length; i++)
{
allow_list.insert(allowed_ids[i]);
}
}
if (disallowed_ids != NULL)
{
for (int i = 0; i < disallowed_id_length; i++)
{
disallow_list.insert(disallowed_ids[i]);
}
}
AllowAndDisallowListFilterFunctor filter = AllowAndDisallowListFilterFunctor(allow_list, disallow_list);
std::priority_queue<std::pair<dist_t, hnswlib::labeltype>> res = appr_alg->searchKnn(query_vector, k, &filter);
if (res.size() < k)
{
// TODO: This is ok and we should return < K results, but for maintaining compatibility with the old API we throw an error for now
Expand All @@ -128,6 +168,7 @@ class Index
distance[i] = res_i.first;
res.pop();
}
return total_results;
}

int get_ef()
Expand Down Expand Up @@ -186,9 +227,9 @@ extern "C"
return index->mark_deleted(id);
}

void knn_query(Index<float> *index, const float *query_vector, const size_t k, hnswlib::labeltype *ids, float *distance)
size_t knn_query(Index<float> *index, const float *query_vector, const size_t k, hnswlib::labeltype *ids, float *distance, const hnswlib::labeltype *allowed_ids, const size_t allowed_id_length, const hnswlib::labeltype *disallowed_ids, const size_t disallowed_id_length)
{
index->knn_query(query_vector, k, ids, distance);
return index->knn_query(query_vector, k, ids, distance, allowed_ids, allowed_id_length, disallowed_ids, disallowed_id_length);
}

int get_ef(Index<float> *index)
Expand Down
36 changes: 34 additions & 2 deletions rust/worker/src/execution/operators/hnsw_knn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,20 @@ impl HnswKnnOperator {
}
Ok(disallowed_ids)
}

/// Validates that no id appears in both the allowed and the disallowed list.
///
/// Returns `Ok(())` when the two lists are disjoint. If any allowed id is
/// also present in `disallowed_ids`, returns
/// `Err(HnswKnnOperatorError::RecordSegmentError)`.
fn validate_allowed_and_disallowed_ids(
    &self,
    allowed_ids: &[u32],
    disallowed_ids: &[u32],
) -> Result<(), Box<dyn ChromaError>> {
    // Nothing can overlap when either side is empty; skip building the set.
    if allowed_ids.is_empty() || disallowed_ids.is_empty() {
        return Ok(());
    }

    // Build the lookup set once so the check is O(n + m) instead of
    // re-scanning `disallowed_ids` for every allowed id.
    let disallowed: std::collections::HashSet<u32> = disallowed_ids.iter().copied().collect();
    if allowed_ids.iter().any(|id| disallowed.contains(id)) {
        return Err(Box::new(HnswKnnOperatorError::RecordSegmentError));
    }
    Ok(())
}
}

#[async_trait]
Expand Down Expand Up @@ -118,8 +132,26 @@ impl Operator<HnswKnnOperatorInput, HnswKnnOperatorOutput> for HnswKnnOperator {
}
};

// TODO: pass in the updated + deleted ids from log and the result from the metadata filtering
let (offset_ids, distances) = input.segment.query(&input.query, input.k);
match self.validate_allowed_and_disallowed_ids(&allowed_offset_ids, &disallowed_offset_ids)
{
Ok(_) => {}
Err(e) => {
return Err(e);
}
};

// Convert to usize
let allowed_offset_ids: Vec<usize> =
allowed_offset_ids.iter().map(|&x| x as usize).collect();
let disallowed_offset_ids: Vec<usize> =
disallowed_offset_ids.iter().map(|&x| x as usize).collect();

let (offset_ids, distances) = input.segment.query(
&input.query,
input.k,
&allowed_offset_ids,
&disallowed_offset_ids,
);
Ok(HnswKnnOperatorOutput {
offset_ids,
distances,
Expand Down
84 changes: 77 additions & 7 deletions rust/worker/src/index/hnsw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -208,18 +208,33 @@ impl Index<HnswIndexConfig> for HnswIndex {
unsafe { mark_deleted(self.ffi_ptr, id) }
}

fn query(&self, vector: &[f32], k: usize) -> (Vec<usize>, Vec<f32>) {
fn query(
&self,
vector: &[f32],
k: usize,
allowed_ids: &[usize],
disallowed_ids: &[usize],
) -> (Vec<usize>, Vec<f32>) {
let actual_k = std::cmp::min(k, self.len());
let mut ids = vec![0usize; actual_k];
let mut distance = vec![0.0f32; actual_k];
let mut total_result = actual_k;
unsafe {
knn_query(
total_result = knn_query(
self.ffi_ptr,
vector.as_ptr(),
k,
ids.as_mut_ptr(),
distance.as_mut_ptr(),
);
allowed_ids.as_ptr(),
allowed_ids.len(),
disallowed_ids.as_ptr(),
disallowed_ids.len(),
) as usize;
}
if total_result < actual_k {
ids.truncate(total_result);
distance.truncate(total_result);
}
return (ids, distance);
}
Expand Down Expand Up @@ -317,7 +332,11 @@ extern "C" {
k: usize,
ids: *mut usize,
distance: *mut f32,
);
allowed_ids: *const usize,
allowed_ids_length: usize,
disallowed_ids: *const usize,
disallowed_ids_length: usize,
) -> c_int;

fn get_ef(index: *const IndexPtrFFI) -> c_int;
fn set_ef(index: *const IndexPtrFFI, ef: c_int);
Expand Down Expand Up @@ -499,7 +518,9 @@ pub mod test {

// Query the data
let query = &data[0..d];
let (ids, distances) = index.query(query, 1);
let allow_ids = &[];
let disallow_ids = &[];
let (ids, distances) = index.query(query, 1, allow_ids, disallow_ids);
assert_eq!(ids.len(), 1);
assert_eq!(distances.len(), 1);
assert_eq!(ids[0], 0);
Expand Down Expand Up @@ -551,10 +572,12 @@ pub mod test {
index.delete(*id);
}

let allow_ids = &[];
let disallow_ids = &[];
// Query for the deleted ids and ensure they are not found
for deleted_id in &delete_ids {
let target_vector = &data[*deleted_id * d..(*deleted_id + 1) * d];
let (ids, _) = index.query(target_vector, 10);
let (ids, _) = index.query(target_vector, 10, allow_ids, disallow_ids);
for check_deleted_id in &delete_ids {
assert!(!ids.contains(check_deleted_id));
}
Expand Down Expand Up @@ -625,7 +648,9 @@ pub mod test {

// Query the data
let query = &data[0..d];
let (ids, distances) = index.query(query, 1);
let allow_ids = &[];
let disallow_ids = &[];
let (ids, distances) = index.query(query, 1, allow_ids, disallow_ids);
assert_eq!(ids.len(), 1);
assert_eq!(distances.len(), 1);
assert_eq!(ids[0], 0);
Expand All @@ -647,4 +672,49 @@ pub mod test {
i += 1;
}
}

/// Builds an index of `n` random vectors, then queries it with both an allow
/// list and a disallow list and checks that the filter is honoured: only
/// allowed ids come back and no disallowed id is present.
#[test]
fn it_can_add_and_query_with_allowed_and_disallowed_ids() {
    let n = 1000;
    let d: usize = 960;
    let distance_function = DistanceFunction::Euclidean;
    let tmp_dir = tempdir().unwrap();
    let persist_path = tmp_dir.path().to_str().unwrap().to_string();
    let index = HnswIndex::init(
        &IndexConfig {
            dimensionality: d as i32,
            distance_function,
        },
        Some(&HnswIndexConfig {
            max_elements: n,
            m: 16,
            ef_construction: 100,
            ef_search: 100,
            random_seed: 0,
            persist_path,
        }),
        Uuid::new_v4(),
    );

    let index = match index {
        Err(e) => panic!("Error initializing index: {}", e),
        Ok(index) => index,
    };

    let data: Vec<f32> = utils::generate_random_data(n, d);

    // Vector `i` occupies data[i * d..(i + 1) * d] and is stored under id `i`.
    for i in 0..n {
        index.add(i, &data[i * d..(i + 1) * d]);
    }

    // Query the data
    let query = &data[0..d];
    let allow_ids: &[usize] = &[0, 2];
    let disallow_ids: &[usize] = &[3];
    let (ids, distances) = index.query(query, 10, allow_ids, disallow_ids);

    // k is 10, but only the two allowed ids may be returned.
    assert_eq!(ids.len(), 2);
    assert_eq!(distances.len(), 2);

    // Lengths alone do not prove the filter worked; check membership too.
    assert!(ids.iter().all(|id| allow_ids.contains(id)));
    assert!(!ids.iter().any(|id| disallow_ids.contains(id)));
}
}
8 changes: 7 additions & 1 deletion rust/worker/src/index/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,13 @@ pub(crate) trait Index<C> {
Self: Sized;
fn add(&self, id: usize, vector: &[f32]);
fn delete(&self, id: usize);
fn query(&self, vector: &[f32], k: usize) -> (Vec<usize>, Vec<f32>);
fn query(
&self,
vector: &[f32],
k: usize,
allowed_ids: &[usize],
disallow_ids: &[usize],
) -> (Vec<usize>, Vec<f32>);
fn get(&self, id: usize) -> Option<Vec<f32>>;
}

Expand Down
15 changes: 8 additions & 7 deletions rust/worker/src/segment/distributed_hnsw_segment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -170,11 +170,6 @@ impl DistributedHNSWSegmentWriter {
)))
}
}

/// Queries the segment's index for the `k` nearest neighbours of `vector`.
///
/// Obtains the index through `self.index.read()` and forwards `vector` and
/// `k` to its `query` method, returning that call's
/// `(Vec<usize>, Vec<f32>)` result unchanged.
pub(crate) fn query(&self, vector: &[f32], k: usize) -> (Vec<usize>, Vec<f32>) {
let index = self.index.read();
index.query(vector, k)
}
}

impl SegmentWriter for DistributedHNSWSegmentWriter {
Expand Down Expand Up @@ -353,8 +348,14 @@ impl DistributedHNSWSegmentReader {
}
}

pub(crate) fn query(&self, vector: &[f32], k: usize) -> (Vec<usize>, Vec<f32>) {
pub(crate) fn query(
&self,
vector: &[f32],
k: usize,
allowed_ids: &[usize],
disallowd_ids: &[usize],
) -> (Vec<usize>, Vec<f32>) {
let index = self.index.read();
index.query(vector, k)
index.query(vector, k, allowed_ids, disallowd_ids)
}
}

0 comments on commit 9b3ec89

Please sign in to comment.