Skip to content

Commit

Permalink
feat(search): Multishard cutoffs
Browse files Browse the repository at this point in the history
Signed-off-by: Vladislav Oleshko <[email protected]>
  • Loading branch information
dranikpg committed Sep 25, 2023
1 parent d8b99dc commit fe7559a
Show file tree
Hide file tree
Showing 13 changed files with 496 additions and 165 deletions.
5 changes: 5 additions & 0 deletions src/core/search/ast_expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,11 @@ struct AstNode : public NodeVariants {
const NodeVariants& Variant() const& {
return *this;
}

// Aggregations reduce and re-order result sets.
bool IsAggregation() const {
  // KNN is currently the only node type that aggregates results.
  return std::get_if<AstKnnNode>(&Variant()) != nullptr;
}
};

using AstExpr = AstNode;
Expand Down
4 changes: 4 additions & 0 deletions src/core/search/search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -499,4 +499,8 @@ void SearchAlgorithm::EnableProfiling() {
profiling_enabled_ = true;
}

// Returns whether profiling was turned on via EnableProfiling().
bool SearchAlgorithm::IsProfilingEnabled() const {
return profiling_enabled_;
}

} // namespace dfly::search
1 change: 1 addition & 0 deletions src/core/search/search.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ class SearchAlgorithm {
std::optional<size_t> HasKnn() const;

void EnableProfiling();
bool IsProfilingEnabled() const;

private:
bool profiling_enabled_ = false;
Expand Down
11 changes: 11 additions & 0 deletions src/facade/reply_capture.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "facade/reply_capture.h"

#include "base/logging.h"
#include "facade/conn_context.h"
#include "reply_capture.h"

#define SKIP_LESS(needed) \
Expand Down Expand Up @@ -150,6 +151,16 @@ void CapturingReplyBuilder::CollapseFilledCollections() {
}
}

CapturingReplyBuilder::ScopeCapture::ScopeCapture(CapturingReplyBuilder* crb,
ConnectionContext* cntx)
: cntx_{cntx} {
old_ = cntx->Inject(crb);
}

// Restore the reply builder that was installed before the capture started.
CapturingReplyBuilder::ScopeCapture::~ScopeCapture() {
cntx_->Inject(old_);
}

CapturingReplyBuilder::CollectionPayload::CollectionPayload(unsigned len, CollectionType type)
: len{len}, type{type}, arr{} {
arr.reserve(type == MAP ? len * 2 : len);
Expand Down
11 changes: 11 additions & 0 deletions src/facade/reply_capture.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

namespace facade {

class ConnectionContext;
struct CaptureVisitor;

// CapturingReplyBuilder allows capturing replies and retrieving them with Take().
Expand Down Expand Up @@ -66,6 +67,16 @@ class CapturingReplyBuilder : public RedisReplyBuilder {
bool with_scores;
};

public:
// Scoped RAII helper: the constructor injects `crb` as the connection's reply
// builder; the destructor restores whatever builder was active before.
struct ScopeCapture {
ScopeCapture(CapturingReplyBuilder* crb, ConnectionContext* cntx);
~ScopeCapture();

private:
SinkReplyBuilder* old_;  // builder active before capture; restored on destruction
ConnectionContext* cntx_;  // connection whose reply builder is swapped
};

public:
CapturingReplyBuilder(ReplyMode mode = ReplyMode::FULL)
: RedisReplyBuilder{nullptr}, reply_mode_{mode}, stack_{}, current_{} {
Expand Down
8 changes: 5 additions & 3 deletions src/server/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -206,11 +206,13 @@ template <typename RandGen> std::string GetRandomHex(RandGen& gen, size_t len) {
// truthy value;
template <typename T> struct AggregateValue {
bool operator=(T val) {
if (!bool(val))
return false;

std::lock_guard l{mu_};
if (!bool(current_) && bool(val)) {
if (!bool(current_))
current_ = val;
}
return bool(val);
return true;
}

T operator*() {
Expand Down
12 changes: 3 additions & 9 deletions src/server/main_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1309,13 +1309,6 @@ void Service::Unwatch(CmdArgList args, ConnectionContext* cntx) {
return (*cntx)->SendOk();
}

template <typename F> void WithReplies(CapturingReplyBuilder* crb, ConnectionContext* cntx, F&& f) {
SinkReplyBuilder* old_rrb = nullptr;
old_rrb = cntx->Inject(crb);
f();
cntx->Inject(old_rrb);
}

optional<CapturingReplyBuilder::Payload> Service::FlushEvalAsyncCmds(ConnectionContext* cntx,
bool force) {
auto& info = cntx->conn_state.script_info;
Expand All @@ -1329,9 +1322,10 @@ optional<CapturingReplyBuilder::Payload> Service::FlushEvalAsyncCmds(ConnectionC
cntx->transaction->MultiSwitchCmd(eval_cid);

CapturingReplyBuilder crb{ReplyMode::ONLY_ERR};
WithReplies(&crb, cntx, [&] {
{
CapturingReplyBuilder::ScopeCapture capture{&crb, cntx};
MultiCommandSquasher::Execute(absl::MakeSpan(info->async_cmds), cntx, this, true, true);
});
}

info->async_cmds_heap_mem = 0;
info->async_cmds.clear();
Expand Down
94 changes: 75 additions & 19 deletions src/server/search/doc_index.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,25 @@ const absl::flat_hash_map<string_view, search::SchemaField::FieldType> kSchemaTy
{"NUMERIC"sv, search::SchemaField::NUMERIC},
{"VECTOR"sv, search::SchemaField::VECTOR}};

// Estimate how many documents a single shard should serialize eagerly when the
// results of `shards` shards will be merged. `hits` is the number of matching
// ids on this shard, `requested` is offset + limit for the whole query.
//
// If the heuristic suggests every shard will contribute at least its
// proportional share, serializing about requested / shards documents per shard
// suffices; otherwise fall back to serializing all requested hits.
size_t GetProbabilisticBound(size_t shards, size_t hits, size_t requested, bool is_aggregation) {
  // Integer floor(log2(x)); yields 0 for x <= 1.
  auto intlog2 = [](size_t x) {
    size_t l = 0;
    while (x >>= 1)
      ++l;
    return l;
  };
  // Heuristic lower bound on per-shard contribution.
  size_t avg_shard_min = hits * intlog2(hits) / (12 + shards / 10);
  // Subtract a small safety margin; the min() clamp prevents unsigned underflow.
  avg_shard_min -= min(avg_shard_min, min(hits, size_t(5)));

  // VLOG(0) is emitted at the default verbosity level, so this debugging aid
  // would log on every search; use level 1 so it is silent unless enabled.
  VLOG(1) << "PROB BOUND " << hits << " " << shards << " " << requested << " => " << avg_shard_min
          << " diffb " << requested / shards + 1 << " & " << requested;

  // Aggregations (e.g. KNN) re-order the combined result set, so a
  // proportional per-shard bound is not safe for them.
  if (!is_aggregation && avg_shard_min * shards >= requested)
    return requested / shards + 1;

  return min(hits, requested);
}

} // namespace

optional<search::SchemaField::FieldType> ParseSearchFieldType(string_view name) {
Expand Down Expand Up @@ -149,10 +168,11 @@ bool DocIndex::Matches(string_view key, unsigned obj_code) const {
}

ShardDocIndex::ShardDocIndex(shared_ptr<DocIndex> index)
: base_{std::move(index)}, indices_{{}}, key_index_{} {
: base_{std::move(index)}, write_epoch_{0}, indices_{{}}, key_index_{} {
}

void ShardDocIndex::Rebuild(const OpArgs& op_args) {
write_epoch_++;
key_index_ = DocKeyIndex{};
indices_ = search::FieldIndices{base_->schema};

Expand All @@ -161,11 +181,13 @@ void ShardDocIndex::Rebuild(const OpArgs& op_args) {
}

// Index the document stored under `key`. Bumps the write epoch so that search
// results captured before this write are recognized as stale.
void ShardDocIndex::AddDoc(string_view key, const DbContext& db_cntx, const PrimeValue& pv) {
  ++write_epoch_;
  auto doc_accessor = GetAccessor(db_cntx, pv);
  DocId id = key_index_.Add(key);
  indices_.Add(id, doc_accessor.get());
}

void ShardDocIndex::RemoveDoc(string_view key, const DbContext& db_cntx, const PrimeValue& pv) {
write_epoch_++;
auto accessor = GetAccessor(db_cntx, pv);
DocId id = key_index_.Remove(key);
indices_.Remove(id, accessor.get());
Expand All @@ -175,38 +197,72 @@ bool ShardDocIndex::Matches(string_view key, unsigned obj_code) const {
return base_->Matches(key, obj_code);
}

SearchResult ShardDocIndex::Search(const OpArgs& op_args, const SearchParams& params,
search::SearchAlgorithm* search_algo) const {
auto& db_slice = op_args.shard->db_slice();
io::Result<SearchResult, facade::ErrorReply> ShardDocIndex::Search(
const OpArgs& op_args, const SearchParams& params, search::SearchAlgorithm* search_algo) const {
auto search_results = search_algo->Search(&indices_);

if (!search_results.error.empty())
return SearchResult{facade::ErrorReply{std::move(search_results.error)}};
return nonstd::make_unexpected(facade::ErrorReply(std::move(search_results.error)));

size_t serialize_count = min(search_results.ids.size(), params.limit_offset + params.limit_total);
vector<SerializedSearchDoc> out;
out.reserve(serialize_count);
size_t requested_count = params.limit_offset + params.limit_total;
size_t serialize_count = min(requested_count, search_results.ids.size());

size_t expired_count = 0;
for (size_t i = 0; i < search_results.ids.size() && out.size() < serialize_count; i++) {
auto key = key_index_.Get(search_results.ids[i]);
auto it = db_slice.Find(op_args.db_cntx, key, base_->GetObjCode());
size_t cuttoff_bound = serialize_count;
if (params.enable_cutoff && !params.IdsOnly())
cuttoff_bound = GetProbabilisticBound(params.num_shards, search_results.ids.size(),
requested_count, search_algo->HasKnn().has_value());

vector<DocResult> out(serialize_count);
auto shard_id = EngineShard::tlocal()->shard_id();
for (size_t i = 0; i < out.size(); i++) {
out[i].value = DocResult::DocReference{shard_id, search_results.ids[i], i < cuttoff_bound};
out[i].score = search_results.knn_distances.empty() ? 0 : search_results.knn_distances[i];
}

Serialize(op_args, params, absl::MakeSpan(out));

return SearchResult{write_epoch_, search_results.ids.size(), std::move(out),
std::move(search_results.profile)};
}

// Serialize the cached documents of `result` if the index is unchanged since
// the result was produced (write epochs match) and return true. Otherwise
// re-run the search, overwrite `result` with the fresh outcome and return
// false.
bool ShardDocIndex::Refill(const OpArgs& op_args, const SearchParams& params,
                           search::SearchAlgorithm* search_algo, SearchResult* result) const {
  // Index was modified in the meantime: the stored doc ids are stale, redo the
  // search from scratch.
  if (result->write_epoch != write_epoch_) {
    DCHECK(!params.enable_cutoff);
    auto fresh = Search(op_args, params, search_algo);
    CHECK(fresh.has_value());
    *result = std::move(fresh.value());
    return false;
  }

  // Fast path: cached references are still valid, just serialize them.
  Serialize(op_args, params, absl::MakeSpan(result->docs));
  return true;
}

void ShardDocIndex::Serialize(const OpArgs& op_args, const SearchParams& params,
absl::Span<DocResult> docs) const {
auto& db_slice = op_args.shard->db_slice();

for (auto& doc : docs) {
if (!holds_alternative<DocResult::DocReference>(doc.value))
continue;

auto ref = get<DocResult::DocReference>(doc.value);
if (!ref.requested)
return;

auto key = key_index_.Get(ref.doc_id);
auto it = db_slice.Find(op_args.db_cntx, key, base_->GetObjCode());
if (!it || !IsValid(*it)) { // Item must have expired
expired_count++;
doc.value = DocResult::SerializedValue{string{key}, {}};
continue;
}

auto accessor = GetAccessor(op_args.db_cntx, (*it)->second);
auto doc_data = params.return_fields ? accessor->Serialize(base_->schema, *params.return_fields)
: accessor->Serialize(base_->schema);

float score = search_results.knn_distances.empty() ? 0 : search_results.knn_distances[i];
out.push_back(SerializedSearchDoc{string{key}, std::move(doc_data), score});
doc.value = DocResult::SerializedValue{string{key}, std::move(doc_data)};
}

return SearchResult{std::move(out), search_results.ids.size() - expired_count,
std::move(search_results.profile)};
}

DocIndexInfo ShardDocIndex::GetInfo() const {
Expand Down
55 changes: 38 additions & 17 deletions src/server/search/doc_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,28 +23,36 @@ using SearchDocData = absl::flat_hash_map<std::string /*field*/, std::string /*v
std::optional<search::SchemaField::FieldType> ParseSearchFieldType(std::string_view name);
std::string_view SearchFieldTypeToString(search::SchemaField::FieldType);

struct SerializedSearchDoc {
std::string key;
SearchDocData values;
float knn_distance;
// A single search hit: either an already serialized document or a reference
// that lets the owning shard serialize it later on demand.
struct DocResult {
// Fully materialized document: its key and the returned field values.
struct SerializedValue {
std::string key;
SearchDocData values;
};

// Lightweight handle to a not-yet-serialized document on a specific shard.
struct DocReference {
ShardId shard_id;
search::DocId doc_id;
bool requested;  // whether this document should be serialized
};

std::variant<SerializedValue, DocReference> value;
float score;  // knn distance when present, 0 otherwise
};

struct SearchResult {
SearchResult() = default;
size_t write_epoch = 0; // Write epoch of the index at the time the result was created

SearchResult(std::vector<SerializedSearchDoc> docs, size_t total_hits,
std::optional<search::AlgorithmProfile> profile)
: docs{std::move(docs)}, total_hits{total_hits}, profile{std::move(profile)} {
}
size_t total_hits = 0; // total number of hits in shard
std::vector<DocResult> docs; // serialized documents of first hits

SearchResult(facade::ErrorReply error) : error{std::move(error)} {
}
// After combining results from multiple shards and accumulating more documents than initially
// requested, only a subset of all documents will be sent back to the client,
// so it doesn't make sense to serialize strictly all documents in every shard ahead.
// Instead, only documents up to a probabilistic bound are serialized, the
// leftover ids and scores are stored in the cutoff tail for use in the "unlikely" scenario.
// size_t num_cutoff = 0;

std::vector<SerializedSearchDoc> docs;
size_t total_hits;
std::optional<search::AlgorithmProfile> profile;

std::optional<facade::ErrorReply> error;
};

struct SearchParams {
Expand All @@ -56,6 +64,10 @@ struct SearchParams {
size_t limit_offset;
size_t limit_total;

// Total number of shards, used in probabilistic queries
size_t num_shards;
bool enable_cutoff;

// Set but empty means no fields should be returned
std::optional<FieldReturnList> return_fields;
search::QueryParams query_params;
Expand Down Expand Up @@ -112,8 +124,12 @@ class ShardDocIndex {
ShardDocIndex(std::shared_ptr<DocIndex> index);

// Perform search on all indexed documents and return results.
SearchResult Search(const OpArgs& op_args, const SearchParams& params,
search::SearchAlgorithm* search_algo) const;
io::Result<SearchResult, facade::ErrorReply> Search(const OpArgs& op_args,
const SearchParams& params,
search::SearchAlgorithm* search_algo) const;

bool Refill(const OpArgs& op_args, const SearchParams& params,
search::SearchAlgorithm* search_algo, SearchResult* result) const;

// Clears internal data. Traverses all matching documents and assigns ids.
void Rebuild(const OpArgs& op_args);
Expand All @@ -126,8 +142,13 @@ class ShardDocIndex {

DocIndexInfo GetInfo() const;

private:
void Serialize(const OpArgs& op_args, const SearchParams& params,
absl::Span<DocResult> docs) const;

private:
std::shared_ptr<const DocIndex> base_;
size_t write_epoch_;
search::FieldIndices indices_;
DocKeyIndex key_index_;
};
Expand Down

0 comments on commit fe7559a

Please sign in to comment.