From c55d0d6a195d9016db9e05791de5a2544070ca3f Mon Sep 17 00:00:00 2001 From: Gao Date: Sun, 4 Feb 2024 20:13:59 +0800 Subject: [PATCH] Shuffle ids for building hnsw index (#322) (#381) Signed-off-by: chasingegg --- include/knowhere/config.h | 5 ++ src/index/diskann/diskann.cc | 3 +- src/index/hnsw/hnsw.cc | 55 ++++++++++++++----- .../DiskANN/include/diskann/aux_utils.h | 4 +- thirdparty/DiskANN/src/aux_utils.cpp | 11 ++-- thirdparty/DiskANN/src/index.cpp | 15 ++++- 6 files changed, 71 insertions(+), 22 deletions(-) diff --git a/include/knowhere/config.h b/include/knowhere/config.h index 30850c85..318f40f5 100644 --- a/include/knowhere/config.h +++ b/include/knowhere/config.h @@ -509,6 +509,7 @@ class BaseConfig : public Config { CFG_BOOL trace_visit; CFG_BOOL enable_mmap; CFG_BOOL for_tuning; + CFG_BOOL shuffle_build; KNOHWERE_DECLARE_CONFIG(BaseConfig) { KNOWHERE_CONFIG_DECLARE_FIELD(metric_type) .set_default("L2") @@ -557,6 +558,10 @@ class BaseConfig : public Config { .for_deserialize() .for_deserialize_from_file(); KNOWHERE_CONFIG_DECLARE_FIELD(for_tuning).set_default(false).description("for tuning").for_search(); + KNOWHERE_CONFIG_DECLARE_FIELD(shuffle_build) + .set_default(true) + .description("shuffle ids before index building") + .for_train(); } virtual Status diff --git a/src/index/diskann/diskann.cc b/src/index/diskann/diskann.cc index a80ac5df..7a2844d4 100644 --- a/src/index/diskann/diskann.cc +++ b/src/index/diskann/diskann.cc @@ -308,7 +308,8 @@ DiskANNIndexNode::Build(const DataSet& dataset, const Config& cfg) { static_cast(build_conf.disk_pq_dims.value()), false, build_conf.accelerate_build.value(), - static_cast(num_nodes_to_cache)}; + static_cast(num_nodes_to_cache), + build_conf.shuffle_build.value()}; RETURN_IF_ERROR(TryDiskANNCall([&]() { int res = diskann::build_disk_index(diskann_internal_build_config); if (res != 0) diff --git a/src/index/hnsw/hnsw.cc b/src/index/hnsw/hnsw.cc index 4f24309a..fd800244 100644 --- a/src/index/hnsw/hnsw.cc +++ b/src/index/hnsw/hnsw.cc @@ -12,6 +12,7 @@ #include "knowhere/feder/HNSW.h" #include +#include #include "common/range_util.h" #include "hnswlib/hnswalg.h" @@ -76,29 +77,57 @@ class HnswIndexNode : public IndexNode { knowhere::TimeRecorder build_time("Building HNSW cost"); auto rows = dataset.GetRows(); + if (rows <= 0) { + LOG_KNOWHERE_ERROR_ << "Can not add empty data to HNSW index."; + return Status::empty_index; + } auto tensor = dataset.GetTensor(); auto hnsw_cfg = static_cast(cfg); - index_->addPoint(tensor, 0); - auto build_pool = ThreadPool::GetGlobalBuildThreadPool(); - std::vector> futures; - futures.reserve(rows); + bool shuffle_build = hnsw_cfg.shuffle_build.value(); std::atomic counter{0}; uint64_t one_tenth_row = rows / 10; - for (int i = 1; i < rows; ++i) { - futures.emplace_back(build_pool->push([&, idx = i]() { - index_->addPoint(((const char*)tensor + index_->data_size_ * idx), idx); - uint64_t added = counter.fetch_add(1); - if (added % one_tenth_row == 0) { - LOG_KNOWHERE_INFO_ << "HNSW build progress: " << (added / one_tenth_row) << "0%"; - } - })); + + std::vector shuffle_batch_ids; + constexpr int64_t batch_size = 8192; // same with diskann + int64_t round_num = std::ceil(float(rows - 1) / batch_size); + auto build_pool = ThreadPool::GetGlobalBuildThreadPool(); + std::vector> futures; + + if (shuffle_build) { + shuffle_batch_ids.reserve(round_num); + for (int i = 0; i < round_num; ++i) { + shuffle_batch_ids.emplace_back(i); + } + std::random_device rng; + std::mt19937 urng(rng()); + std::shuffle(shuffle_batch_ids.begin(), shuffle_batch_ids.end(), urng); } - knowhere::WaitAllSuccess(futures); + index_->addPoint(tensor, 0); + + futures.reserve(batch_size); + for (int64_t round_id = 0; round_id < round_num; round_id++) { + int64_t start_id = (shuffle_build ? shuffle_batch_ids[round_id] : round_id) * batch_size; + int64_t end_id = + std::min(rows - 1, ((shuffle_build ? shuffle_batch_ids[round_id] : round_id) + 1) * batch_size); + for (int64_t i = start_id; i < end_id; ++i) { + futures.emplace_back(build_pool->push([&, idx = i + 1]() { + index_->addPoint(((const char*)tensor + index_->data_size_ * idx), idx); + uint64_t added = counter.fetch_add(1); + if (added % one_tenth_row == 0) { + LOG_KNOWHERE_INFO_ << "HNSW build progress: " << (added / one_tenth_row) << "0%"; + } + })); + } + WaitAllSuccess(futures); + futures.clear(); + } + build_time.RecordSection(""); LOG_KNOWHERE_INFO_ << "HNSW built with #points num:" << index_->max_elements_ << " #M:" << index_->M_ << " #max level:" << index_->maxlevel_ << " #ef_construction:" << index_->ef_construction_ << " #dim:" << *(size_t*)(index_->space_->get_dist_func_param()); + return Status::success; } diff --git a/thirdparty/DiskANN/include/diskann/aux_utils.h b/thirdparty/DiskANN/include/diskann/aux_utils.h index c517c001..68e5f949 100644 --- a/thirdparty/DiskANN/include/diskann/aux_utils.h +++ b/thirdparty/DiskANN/include/diskann/aux_utils.h @@ -101,7 +101,7 @@ namespace diskann { template DISKANN_DLLEXPORT std::unique_ptr> build_merged_vamana_index( std::string base_file, diskann::Metric _compareMetric, unsigned L, - unsigned R, bool accelerate_build, double sampling_rate, + unsigned R, bool accelerate_build, bool shuffle_build, double sampling_rate, double ram_budget, std::string mem_index_path, std::string medoids_file, std::string centroids_file); @@ -141,6 +141,8 @@ namespace diskann { bool accelerate_build = false; // the cached nodes number uint32_t num_nodes_to_cache = 0; + // shuffle id to build index + bool shuffle_build = false; }; template diff --git a/thirdparty/DiskANN/src/aux_utils.cpp b/thirdparty/DiskANN/src/aux_utils.cpp index 525a01eb..72cbe3b8 100644 --- a/thirdparty/DiskANN/src/aux_utils.cpp +++ b/thirdparty/DiskANN/src/aux_utils.cpp @@ -510,7 +510,7 @@ namespace diskann { template std::unique_ptr> build_merged_vamana_index( std::string base_file, bool ip_prepared, diskann::Metric compareMetric, - unsigned L, unsigned R, bool accelerate_build, double sampling_rate, + unsigned L, unsigned R, bool accelerate_build, bool shuffle_build, double sampling_rate, double ram_budget, std::string mem_index_path, std::string medoids_file, std::string centroids_file) { size_t base_num, base_dim; @@ -532,6 +532,7 @@ namespace diskann { paras.Set("saturate_graph", 1); paras.Set("save_path", mem_index_path); paras.Set("accelerate_build", accelerate_build); + paras.Set("shuffle_build", shuffle_build); std::unique_ptr> _pvamanaIndex = std::unique_ptr>(new diskann::Index( @@ -1270,7 +1271,7 @@ namespace diskann { auto graph_s = std::chrono::high_resolution_clock::now(); auto vamana_index = diskann::build_merged_vamana_index( data_file_to_use.c_str(), ip_prepared, diskann::Metric::L2, L, R, - config.accelerate_build, p_val, indexing_ram_budget, mem_index_path, + config.accelerate_build, config.shuffle_build, p_val, indexing_ram_budget, mem_index_path, medoids_path, centroids_path); auto graph_e = std::chrono::high_resolution_clock::now(); std::chrono::duration graph_diff = graph_e - graph_s; @@ -1386,7 +1387,7 @@ namespace diskann { template DISKANN_DLLEXPORT std::unique_ptr> build_merged_vamana_index(std::string base_file, bool ip_prepared, diskann::Metric compareMetric, unsigned L, - unsigned R, bool accelerate_build, + unsigned R, bool accelerate_build, bool shuffle_build, double sampling_rate, double ram_budget, std::string mem_index_path, std::string medoids_path, @@ -1394,7 +1395,7 @@ namespace diskann { template DISKANN_DLLEXPORT std::unique_ptr> build_merged_vamana_index(std::string base_file, bool ip_prepared, diskann::Metric compareMetric, unsigned L, - unsigned R, bool accelerate_build, + unsigned R, bool accelerate_build, bool shuffle_build, double sampling_rate, double ram_budget, std::string mem_index_path, std::string medoids_path, @@ -1402,7 +1403,7 @@ namespace diskann { template DISKANN_DLLEXPORT std::unique_ptr> build_merged_vamana_index(std::string base_file, bool ip_prepared, diskann::Metric compareMetric, unsigned L, - unsigned R, bool accelerate_build, + unsigned R, bool accelerate_build, bool shuffle_build, double sampling_rate, double ram_budget, std::string mem_index_path, std::string medoids_path, diff --git a/thirdparty/DiskANN/src/index.cpp b/thirdparty/DiskANN/src/index.cpp index 4a8d65d2..86157932 100644 --- a/thirdparty/DiskANN/src/index.cpp +++ b/thirdparty/DiskANN/src/index.cpp @@ -1429,6 +1429,7 @@ namespace diskann { _indexingRange = parameters.Get("R"); _indexingMaxC = parameters.Get("C"); const bool accelerate_build = parameters.Get("accelerate_build"); + const bool shuffle_build = parameters.Get("shuffle_build"); const float last_round_alpha = parameters.Get("alpha"); unsigned L = _indexingQueueSize; @@ -1518,10 +1519,20 @@ namespace diskann { } futures.reserve(round_size); + std::vector shuffle_batch_ids; + if (shuffle_build) { + shuffle_batch_ids.reserve(round_num_syncs); + for (unsigned i = 0; i < (unsigned) round_num_syncs; i++) { + shuffle_batch_ids.emplace_back(i); + } + std::random_device rng; + std::mt19937 urng(rng()); + std::shuffle(shuffle_batch_ids.begin(), shuffle_batch_ids.end(), urng); + } for (uint32_t sync_num = 0; sync_num < round_num_syncs; sync_num++) { - size_t start_id = sync_num * round_size; + size_t start_id = (shuffle_build ? shuffle_batch_ids[sync_num] : sync_num) * round_size; size_t end_id = - (std::min)(_nd + _num_frozen_pts, (sync_num + 1) * round_size); + (std::min)(_nd + _num_frozen_pts, ((shuffle_build ? shuffle_batch_ids[sync_num] : sync_num) + 1) * round_size); auto s = std::chrono::high_resolution_clock::now(); std::chrono::duration diff;