Skip to content

Commit

Permalink
async generate diskann cache (#191)
Browse files Browse the repository at this point in the history
Signed-off-by: cqy123456 <[email protected]>
  • Loading branch information
cqy123456 committed Nov 22, 2023
1 parent 26abda0 commit d63c403
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 38 deletions.
33 changes: 16 additions & 17 deletions knowhere/index/vector_index/IndexDiskANN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -290,9 +290,9 @@ IndexDiskANN<T>::Prepare(const Config& config) {
KNOWHERE_THROW_MSG("Failed to generate cache, num_nodes_to_cache is larger than 1/3 of the total data number.");
}
if (num_nodes_to_cache > 0) {
std::vector<uint32_t> node_list;
LOG_KNOWHERE_INFO_ << "Caching " << num_nodes_to_cache << " sample nodes around medoid(s).";
if (prep_conf.use_bfs_cache) {
std::vector<uint32_t> node_list;
auto gen_cache_successful = TryDiskANNCall<bool>([&]() -> bool {
pq_flash_index_->cache_bfs_levels(num_nodes_to_cache, node_list);
return true;
Expand All @@ -302,29 +302,28 @@ IndexDiskANN<T>::Prepare(const Config& config) {
LOG_KNOWHERE_ERROR_ << "Failed to generate bfs cache for DiskANN.";
return false;
}
} else {
auto gen_cache_successful = TryDiskANNCall<bool>([&]() -> bool {
pq_flash_index_->generate_cache_list_from_sample_queries(warmup_query_file, 15, 6, num_nodes_to_cache,
prep_conf.num_threads, node_list);
auto load_cache_successful = TryDiskANNCall<bool>([&]() -> bool {
pq_flash_index_->load_cache_list(node_list);
return true;
});

if (!gen_cache_successful.has_value()) {
LOG_KNOWHERE_ERROR_ << "Failed to generate cache from sample queries for DiskANN.";
if (!load_cache_successful.has_value()) {
LOG_KNOWHERE_ERROR_ << "Failed to load cache for DiskANN.";
return false;
}
}
auto load_cache_successful = TryDiskANNCall<bool>([&]() -> bool {
pq_flash_index_->load_cache_list(node_list);
return true;
});

if (!load_cache_successful.has_value()) {
LOG_KNOWHERE_ERROR_ << "Failed to load cache for DiskANN.";
return false;
} else {
pq_flash_index_->set_async_cache_flag(true);
pool_->push([&, cache_num = num_nodes_to_cache,
sample_nodes_file = warmup_query_file]() {
try {
pq_flash_index_->generate_cache_list_from_sample_queries(
sample_nodes_file, 15, 6, cache_num);
} catch (const std::exception& e) {
LOG_KNOWHERE_ERROR_ << "DiskANN Exception: " << e.what();
}
});
}
}

// warmup
if (prep_conf.warm_up) {
LOG_KNOWHERE_DEBUG_ << "Warming up.";
Expand Down
17 changes: 11 additions & 6 deletions thirdparty/DiskANN/include/pq_flash_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "percentile_stats.h"
#include "pq_table.h"
#include "utils.h"
#include "semaphore.h"
#include "windows_customizations.h"

#define MAX_GRAPH_DEGREE 512
Expand Down Expand Up @@ -80,16 +81,15 @@ namespace diskann {

DISKANN_DLLEXPORT void load_cache_list(std::vector<uint32_t> &node_list);

// asynchronously collect the access frequency of each node in the graph
#ifdef EXEC_ENV_OLS
DISKANN_DLLEXPORT void generate_cache_list_from_sample_queries(
MemoryMappedFiles &files, std::string sample_bin, _u64 l_search,
_u64 beamwidth, _u64 num_nodes_to_cache, uint32_t nthreads,
std::vector<uint32_t> &node_list);
MemoryMappedFiles files, std::string sample_bin, _u64 l_search,
_u64 beamwidth, _u64 num_nodes_to_cache);
#else
DISKANN_DLLEXPORT void generate_cache_list_from_sample_queries(
std::string sample_bin, _u64 l_search, _u64 beamwidth,
_u64 num_nodes_to_cache, uint32_t num_threads,
std::vector<uint32_t> &node_list);
_u64 num_nodes_to_cache);
#endif

DISKANN_DLLEXPORT void cache_bfs_levels(_u64 num_nodes_to_cache,
Expand Down Expand Up @@ -128,6 +128,8 @@ namespace diskann {

DISKANN_DLLEXPORT diskann::Metric get_metric() const noexcept;

DISKANN_DLLEXPORT void set_async_cache_flag(const bool flag);

protected:
DISKANN_DLLEXPORT void use_medoids_data_as_centroids();
DISKANN_DLLEXPORT void setup_thread_data(_u64 nthreads);
Expand Down Expand Up @@ -195,6 +197,7 @@ namespace diskann {

std::string disk_index_file;
std::vector<std::pair<_u32, _u32>> node_visit_counter;
std::atomic<_u32> search_counter = 0;

// PQ data
// n_chunks = # of chunks ndims is split into
Expand Down Expand Up @@ -233,12 +236,14 @@ namespace diskann {
// coord_cache
T * coord_cache_buf = nullptr;
tsl::robin_map<_u32, T *> coord_cache;
Semaphore semaph;
std::atomic<bool> async_generate_cache = false;

// thread-specific scratch
ConcurrentQueue<ThreadData<T>> thread_data;
_u64 max_nthreads;
bool load_flag = false;
bool count_visited_nodes = false;
std::atomic<bool> count_visited_nodes = false;
bool reorder_data_exists = false;
_u64 reoreder_data_offset = 0;

Expand Down
35 changes: 35 additions & 0 deletions thirdparty/DiskANN/include/semaphore.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#pragma once
#include <mutex>
#include <condition_variable>

namespace diskann {
class Semaphore {
public:
Semaphore(long count = 0) : count(count) {}
void Signal()
{
std::unique_lock<std::mutex> unique(mt);
++count;
if (count <= 0) {
cond.notify_one();
}
}
void Wait()
{
std::unique_lock<std::mutex> unique(mt);
--count;
if (count < 0) {
cond.wait(unique);
}
}
bool IsWaitting() {
std::unique_lock<std::mutex> unique(mt);
return count < 0;
}

private:
std::mutex mt;
std::condition_variable cond;
long count;
};
} // namespace diskann
59 changes: 44 additions & 15 deletions thirdparty/DiskANN/src/pq_flash_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <iterator>
#include <random>
#include <thread>
#include <mutex>
#include "distance.h"
#include "exceptions.h"
#include "parameters.h"
Expand Down Expand Up @@ -98,7 +99,7 @@ namespace diskann {
template<typename T>
PQFlashIndex<T>::PQFlashIndex(
std::shared_ptr<AlignedFileReader> fileReader, diskann::Metric m)
: reader(fileReader), metric(m) {
: reader(fileReader), metric(m), semaph(0) {
if (m == diskann::Metric::COSINE || m == diskann::Metric::INNER_PRODUCT) {
if (std::is_floating_point<T>::value) {
LOG(INFO) << "Cosine metric chosen for (normalized) float data."
Expand All @@ -119,6 +120,9 @@ namespace diskann {

template<typename T>
PQFlashIndex<T>::~PQFlashIndex() {
if (this->async_generate_cache) {
this->semaph.Wait();
}
#ifndef EXEC_ENV_OLS
if (data != nullptr) {
delete[] data;
Expand Down Expand Up @@ -216,6 +220,8 @@ namespace diskann {
template<typename T>
void PQFlashIndex<T>::load_cache_list(std::vector<uint32_t> &node_list) {
LOG(DEBUG) << "Loading the cache list into memory...";
assert(this->nhood_cache_buf == nullptr && "nhoodc_cache_buf is not null");
assert(this->coord_cache_buf == nullptr && "coord_cache_buf is not null");
_u64 num_cached_nodes = node_list.size();

// borrow thread data
Expand Down Expand Up @@ -293,20 +299,20 @@ namespace diskann {
#ifdef EXEC_ENV_OLS
template<typename T>
void PQFlashIndex<T>::generate_cache_list_from_sample_queries(
MemoryMappedFiles &files, std::string sample_bin, _u64 l_search,
_u64 beamwidth, _u64 num_nodes_to_cache, uint32_t nthreads,
std::vector<uint32_t> &node_list) {
MemoryMappedFiles files, std::string sample_bin, _u64 l_search,
_u64 beamwidth, _u64 num_nodes_to_cache) {
#else
template<typename T>
void PQFlashIndex<T>::generate_cache_list_from_sample_queries(
std::string sample_bin, _u64 l_search, _u64 beamwidth,
_u64 num_nodes_to_cache, uint32_t nthreads,
std::vector<uint32_t> &node_list) {
_u64 num_nodes_to_cache) {
#endif
auto s = std::chrono::high_resolution_clock::now();
this->count_visited_nodes = true;
this->search_counter.store(0);
this->node_visit_counter.clear();
this->node_visit_counter.resize(this->num_points);
this->count_visited_nodes.store(true);

for (_u32 i = 0; i < node_visit_counter.size(); i++) {
this->node_visit_counter[i].first = i;
this->node_visit_counter[i].second = 0;
Expand All @@ -332,32 +338,47 @@ namespace diskann {
return;
}

std::vector<int64_t> tmp_result_ids_64(sample_num, 0);
std::vector<float> tmp_result_dists(sample_num, 0);
int64_t tmp_result_ids_64;
float tmp_result_dists;

#pragma omp parallel for schedule(dynamic, 1) num_threads(nthreads)
for (_s64 i = 0; i < (int64_t) sample_num; i++) {
cached_beam_search(samples + (i * sample_aligned_dim), 1, l_search,
tmp_result_ids_64.data() + (i * 1),
tmp_result_dists.data() + (i * 1), beamwidth);
auto id = 0;
while (this->search_counter.load() < sample_num && id < sample_num &&
!this->semaph.IsWaitting()) {
cached_beam_search(samples + (id * sample_aligned_dim), 1, l_search,
&tmp_result_ids_64, &tmp_result_dists, beamwidth);
id++;
}

if (this->semaph.IsWaitting()) {
this->semaph.Signal();
return;
}

this->count_visited_nodes.store(false);
std::sort(this->node_visit_counter.begin(), node_visit_counter.end(),
[](std::pair<_u32, _u32> &left, std::pair<_u32, _u32> &right) {
return left.second > right.second;
});

std::vector<uint32_t> node_list;
node_list.clear();
node_list.shrink_to_fit();
node_list.reserve(num_nodes_to_cache);
for (_u64 i = 0; i < num_nodes_to_cache; i++) {
node_list.push_back(this->node_visit_counter[i].first);
}
this->count_visited_nodes = false;
this->node_visit_counter.clear();
this->node_visit_counter.shrink_to_fit();
this->search_counter.store(0);

diskann::aligned_free(samples);
this->load_cache_list(node_list);
auto e = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> diff = e - s;
LOG(INFO) << "Using sample queries to generate cache, cost: " << diff.count() << "s";

this->semaph.Signal();
return;
}

template<typename T>
Expand Down Expand Up @@ -1464,6 +1485,9 @@ namespace diskann {
if (stats != nullptr) {
stats->total_us = (double) query_timer.elapsed();
}
if (this->count_visited_nodes) {
this->search_counter.fetch_add(1);
}
}

// range search returns results of all neighbors within distance of range.
Expand Down Expand Up @@ -1548,6 +1572,11 @@ namespace diskann {
diskann::Metric PQFlashIndex<T>::get_metric() const noexcept {
return metric;
}

template<typename T>
void PQFlashIndex<T>::set_async_cache_flag(const bool flag) {
this->async_generate_cache.exchange(flag);
}

#ifdef EXEC_ENV_OLS
template<typename T>
Expand Down

0 comments on commit d63c403

Please sign in to comment.