Skip to content

Commit

Permalink
Temporary solution to fix wrong ivf_flat search results when the metric is cosine
Browse files Browse the repository at this point in the history
Signed-off-by: Yudong Cai <[email protected]>
  • Loading branch information
cydrain authored and liliu-z committed Aug 22, 2023
1 parent da457e9 commit 3bae0ad
Showing 1 changed file with 27 additions and 0 deletions.
27 changes: 27 additions & 0 deletions src/index/ivf/ivf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,10 @@ class IvfIndexNode : public IndexNode {
private:
std::unique_ptr<T> index_;
std::shared_ptr<ThreadPool> search_pool_;

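// Temporary solution to fix IVF_FLAT cosine: normalized_ records whether the
// stored vectors have been normalized; normalize_mtx_ guards the one-time
// normalization performed at search time.
// NOTE(review): normalized_ is a plain bool that is also read outside the
// lock; consider std::atomic<bool> or std::call_once to avoid a data race.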
mutable bool normalized_ = false;
mutable std::mutex normalize_mtx_;
};

} // namespace knowhere
Expand Down Expand Up @@ -249,6 +253,7 @@ IvfIndexNode<T>::Train(const DataSet& dataset, const Config& cfg) {
if (IsMetricType(base_cfg.metric_type.value(), knowhere::metric::COSINE)) {
if constexpr (!(std::is_same_v<faiss::IndexIVFFlatCC, T>)&&!(std::is_same_v<faiss::IndexScaNN, T>)) {
Normalize(dataset);
normalized_ = true;
}
}

Expand Down Expand Up @@ -426,6 +431,17 @@ IvfIndexNode<T>::Search(const DataSet& dataset, const Config& cfg, const BitsetV
if (is_cosine) {
copied_query = CopyAndNormalizeFloatVec(cur_query, dim);
cur_query = copied_query.get();

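// Temporary solution to fix IVF_FLAT cosine: if the stored vectors have not
// been normalized yet (normalized_ is false), normalize them in place once,
// under normalize_mtx_, before searching.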
if (!normalized_) {
std::lock_guard<std::mutex> lock(normalize_mtx_);
if (!normalized_) {
faiss::IndexIVFFlat* ivf_index = static_cast<faiss::IndexIVFFlat*>(index_.get());
size_t nb = ivf_index->arranged_codes.size() / ivf_index->code_size;
NormalizeVecs((float*)(ivf_index->arranged_codes.data()), nb, dim);
normalized_ = true;
}
}
}
index_->search_without_codes_thread_safe(1, cur_query, k, distances + offset, ids + offset, nprobe,
0, bitset);
Expand Down Expand Up @@ -510,6 +526,17 @@ IvfIndexNode<T>::RangeSearch(const DataSet& dataset, const Config& cfg, const Bi
if (is_cosine) {
copied_query = CopyAndNormalizeFloatVec(cur_query, dim);
cur_query = copied_query.get();

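// Temporary solution to fix IVF_FLAT cosine: if the stored vectors have not
// been normalized yet (normalized_ is false), normalize them in place once,
// under normalize_mtx_, before the range search.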
if (!normalized_) {
std::lock_guard<std::mutex> lock(normalize_mtx_);
if (!normalized_) {
faiss::IndexIVFFlat* ivf_index = static_cast<faiss::IndexIVFFlat*>(index_.get());
size_t nb = ivf_index->arranged_codes.size() / ivf_index->code_size;
NormalizeVecs((float*)(ivf_index->arranged_codes.data()), nb, dim);
normalized_ = true;
}
}
}
index_->range_search_without_codes_thread_safe(1, cur_query, radius, &res, index_->nlist, 0,
bitset);
Expand Down

0 comments on commit 3bae0ad

Please sign in to comment.