Skip to content

Commit

Permalink
Shuffle ids for building hnsw index (#322) (#381)
Browse files (browse the repository at this point in the history)
Signed-off-by: chasingegg <[email protected]>
  • Loading branch information
chasingegg committed Feb 4, 2024
1 parent 861ee91 commit c55d0d6
Show file tree
Hide file tree
Showing 6 changed files with 71 additions and 22 deletions.
5 changes: 5 additions & 0 deletions include/knowhere/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,7 @@ class BaseConfig : public Config {
CFG_BOOL trace_visit;
CFG_BOOL enable_mmap;
CFG_BOOL for_tuning;
CFG_BOOL shuffle_build;
KNOHWERE_DECLARE_CONFIG(BaseConfig) {
KNOWHERE_CONFIG_DECLARE_FIELD(metric_type)
.set_default("L2")
Expand Down Expand Up @@ -557,6 +558,10 @@ class BaseConfig : public Config {
.for_deserialize()
.for_deserialize_from_file();
KNOWHERE_CONFIG_DECLARE_FIELD(for_tuning).set_default(false).description("for tuning").for_search();
KNOWHERE_CONFIG_DECLARE_FIELD(shuffle_build)
.set_default(true)
.description("shuffle ids before index building")
.for_train();
}

virtual Status
Expand Down
3 changes: 2 additions & 1 deletion src/index/diskann/diskann.cc
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,8 @@ DiskANNIndexNode<T>::Build(const DataSet& dataset, const Config& cfg) {
static_cast<uint32_t>(build_conf.disk_pq_dims.value()),
false,
build_conf.accelerate_build.value(),
static_cast<uint32_t>(num_nodes_to_cache)};
static_cast<uint32_t>(num_nodes_to_cache),
build_conf.shuffle_build.value()};
RETURN_IF_ERROR(TryDiskANNCall([&]() {
int res = diskann::build_disk_index<T>(diskann_internal_build_config);
if (res != 0)
Expand Down
55 changes: 42 additions & 13 deletions src/index/hnsw/hnsw.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "knowhere/feder/HNSW.h"

#include <new>
#include <numeric>

#include "common/range_util.h"
#include "hnswlib/hnswalg.h"
Expand Down Expand Up @@ -76,29 +77,57 @@ class HnswIndexNode : public IndexNode {

knowhere::TimeRecorder build_time("Building HNSW cost");
auto rows = dataset.GetRows();
if (rows <= 0) {
LOG_KNOWHERE_ERROR_ << "Can not add empty data to HNSW index.";
return Status::empty_index;
}
auto tensor = dataset.GetTensor();
auto hnsw_cfg = static_cast<const HnswConfig&>(cfg);
index_->addPoint(tensor, 0);
auto build_pool = ThreadPool::GetGlobalBuildThreadPool();
std::vector<folly::Future<folly::Unit>> futures;
futures.reserve(rows);
bool shuffle_build = hnsw_cfg.shuffle_build.value();

std::atomic<uint64_t> counter{0};
uint64_t one_tenth_row = rows / 10;
for (int i = 1; i < rows; ++i) {
futures.emplace_back(build_pool->push([&, idx = i]() {
index_->addPoint(((const char*)tensor + index_->data_size_ * idx), idx);
uint64_t added = counter.fetch_add(1);
if (added % one_tenth_row == 0) {
LOG_KNOWHERE_INFO_ << "HNSW build progress: " << (added / one_tenth_row) << "0%";
}
}));

std::vector<int> shuffle_batch_ids;
constexpr int64_t batch_size = 8192; // same with diskann
int64_t round_num = std::ceil(float(rows - 1) / batch_size);
auto build_pool = ThreadPool::GetGlobalBuildThreadPool();
std::vector<folly::Future<folly::Unit>> futures;

if (shuffle_build) {
shuffle_batch_ids.reserve(round_num);
for (int i = 0; i < round_num; ++i) {
shuffle_batch_ids.emplace_back(i);
}
std::random_device rng;
std::mt19937 urng(rng());
std::shuffle(shuffle_batch_ids.begin(), shuffle_batch_ids.end(), urng);
}
knowhere::WaitAllSuccess(futures);
index_->addPoint(tensor, 0);

futures.reserve(batch_size);
for (int64_t round_id = 0; round_id < round_num; round_id++) {
int64_t start_id = (shuffle_build ? shuffle_batch_ids[round_id] : round_id) * batch_size;
int64_t end_id =
std::min(rows - 1, ((shuffle_build ? shuffle_batch_ids[round_id] : round_id) + 1) * batch_size);
for (int64_t i = start_id; i < end_id; ++i) {
futures.emplace_back(build_pool->push([&, idx = i + 1]() {
index_->addPoint(((const char*)tensor + index_->data_size_ * idx), idx);
uint64_t added = counter.fetch_add(1);
if (added % one_tenth_row == 0) {
LOG_KNOWHERE_INFO_ << "HNSW build progress: " << (added / one_tenth_row) << "0%";
}
}));
}
WaitAllSuccess(futures);
futures.clear();
}

build_time.RecordSection("");
LOG_KNOWHERE_INFO_ << "HNSW built with #points num:" << index_->max_elements_ << " #M:" << index_->M_
<< " #max level:" << index_->maxlevel_ << " #ef_construction:" << index_->ef_construction_
<< " #dim:" << *(size_t*)(index_->space_->get_dist_func_param());

return Status::success;
}

Expand Down
4 changes: 3 additions & 1 deletion thirdparty/DiskANN/include/diskann/aux_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ namespace diskann {
template<typename T>
DISKANN_DLLEXPORT std::unique_ptr<diskann::Index<T>> build_merged_vamana_index(
std::string base_file, diskann::Metric _compareMetric, unsigned L,
unsigned R, bool accelerate_build, double sampling_rate,
unsigned R, bool accelerate_build, bool shuffle_build, double sampling_rate,
double ram_budget, std::string mem_index_path, std::string medoids_file,
std::string centroids_file);

Expand Down Expand Up @@ -141,6 +141,8 @@ namespace diskann {
bool accelerate_build = false;
// the cached nodes number
uint32_t num_nodes_to_cache = 0;
// shuffle id to build index
bool shuffle_build = false;
};

template<typename T>
Expand Down
11 changes: 6 additions & 5 deletions thirdparty/DiskANN/src/aux_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -510,7 +510,7 @@ namespace diskann {
template<typename T>
std::unique_ptr<diskann::Index<T>> build_merged_vamana_index(
std::string base_file, bool ip_prepared, diskann::Metric compareMetric,
unsigned L, unsigned R, bool accelerate_build, double sampling_rate,
unsigned L, unsigned R, bool accelerate_build, bool shuffle_build, double sampling_rate,
double ram_budget, std::string mem_index_path, std::string medoids_file,
std::string centroids_file) {
size_t base_num, base_dim;
Expand All @@ -532,6 +532,7 @@ namespace diskann {
paras.Set<bool>("saturate_graph", 1);
paras.Set<std::string>("save_path", mem_index_path);
paras.Set<bool>("accelerate_build", accelerate_build);
paras.Set<bool>("shuffle_build", shuffle_build);

std::unique_ptr<diskann::Index<T>> _pvamanaIndex =
std::unique_ptr<diskann::Index<T>>(new diskann::Index<T>(
Expand Down Expand Up @@ -1270,7 +1271,7 @@ namespace diskann {
auto graph_s = std::chrono::high_resolution_clock::now();
auto vamana_index = diskann::build_merged_vamana_index<T>(
data_file_to_use.c_str(), ip_prepared, diskann::Metric::L2, L, R,
config.accelerate_build, p_val, indexing_ram_budget, mem_index_path,
config.accelerate_build, config.shuffle_build, p_val, indexing_ram_budget, mem_index_path,
medoids_path, centroids_path);
auto graph_e = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> graph_diff = graph_e - graph_s;
Expand Down Expand Up @@ -1386,23 +1387,23 @@ namespace diskann {
template DISKANN_DLLEXPORT std::unique_ptr<diskann::Index<int8_t>>
build_merged_vamana_index<int8_t>(std::string base_file, bool ip_prepared,
diskann::Metric compareMetric, unsigned L,
unsigned R, bool accelerate_build,
unsigned R, bool accelerate_build, bool shuffle_build,
double sampling_rate, double ram_budget,
std::string mem_index_path,
std::string medoids_path,
std::string centroids_file);
template DISKANN_DLLEXPORT std::unique_ptr<diskann::Index<float>>
build_merged_vamana_index<float>(std::string base_file, bool ip_prepared,
diskann::Metric compareMetric, unsigned L,
unsigned R, bool accelerate_build,
unsigned R, bool accelerate_build, bool shuffle_build,
double sampling_rate, double ram_budget,
std::string mem_index_path,
std::string medoids_path,
std::string centroids_file);
template DISKANN_DLLEXPORT std::unique_ptr<diskann::Index<uint8_t>>
build_merged_vamana_index<uint8_t>(std::string base_file, bool ip_prepared,
diskann::Metric compareMetric, unsigned L,
unsigned R, bool accelerate_build,
unsigned R, bool accelerate_build, bool shuffle_build,
double sampling_rate, double ram_budget,
std::string mem_index_path,
std::string medoids_path,
Expand Down
15 changes: 13 additions & 2 deletions thirdparty/DiskANN/src/index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1429,6 +1429,7 @@ namespace diskann {
_indexingRange = parameters.Get<unsigned>("R");
_indexingMaxC = parameters.Get<unsigned>("C");
const bool accelerate_build = parameters.Get<bool>("accelerate_build");
const bool shuffle_build = parameters.Get<bool>("shuffle_build");
const float last_round_alpha = parameters.Get<float>("alpha");
unsigned L = _indexingQueueSize;

Expand Down Expand Up @@ -1518,10 +1519,20 @@ namespace diskann {
}

futures.reserve(round_size);
std::vector<unsigned> shuffle_batch_ids;
if (shuffle_build) {
shuffle_batch_ids.reserve(round_num_syncs);
for (unsigned i = 0; i < (unsigned) round_num_syncs; i++) {
shuffle_batch_ids.emplace_back(i);
}
std::random_device rng;
std::mt19937 urng(rng());
std::shuffle(shuffle_batch_ids.begin(), shuffle_batch_ids.end(), urng);
}
for (uint32_t sync_num = 0; sync_num < round_num_syncs; sync_num++) {
size_t start_id = sync_num * round_size;
size_t start_id = (shuffle_build ? shuffle_batch_ids[sync_num] : sync_num) * round_size;
size_t end_id =
(std::min)(_nd + _num_frozen_pts, (sync_num + 1) * round_size);
(std::min)(_nd + _num_frozen_pts, ((shuffle_build ? shuffle_batch_ids[sync_num] : sync_num) + 1) * round_size);

auto s = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> diff;
Expand Down

0 comments on commit c55d0d6

Please sign in to comment.