feat(search): Multishard cutoffs #1924

Draft: wants to merge 4 commits into main
5 changes: 5 additions & 0 deletions src/core/search/ast_expr.h
@@ -108,6 +108,11 @@ struct AstNode : public NodeVariants {
const NodeVariants& Variant() const& {
return *this;
}

// Aggregations: KNN, SORTBY. They reorder result sets and optionally reduce them.
bool IsAggregation() const {
return std::holds_alternative<AstKnnNode>(Variant());
}
};

using AstExpr = AstNode;
6 changes: 6 additions & 0 deletions src/core/search/search.cc
@@ -415,6 +415,7 @@ struct BasicSearch {
profile_builder_ ? make_optional(profile_builder_->Take()) : nullopt;

size_t total = result.Size();

return SearchResult{total,
max(total, preagg_total_),
result.Take(limit_),
@@ -423,6 +424,7 @@
std::move(error_)};
}

private:
const FieldIndices* indices_;
size_t limit_;

@@ -599,4 +601,8 @@ void SearchAlgorithm::EnableProfiling() {
profiling_enabled_ = true;
}

bool SearchAlgorithm::IsProfilingEnabled() const {
return profiling_enabled_;
}

} // namespace dfly::search
1 change: 1 addition & 0 deletions src/core/search/search.h
@@ -133,6 +133,7 @@ class SearchAlgorithm {
std::optional<AggregationInfo> HasAggregation() const;

void EnableProfiling();
bool IsProfilingEnabled() const;

private:
bool profiling_enabled_ = false;
11 changes: 11 additions & 0 deletions src/facade/reply_capture.cc
@@ -4,6 +4,7 @@
#include "facade/reply_capture.h"

#include "base/logging.h"
#include "facade/conn_context.h"
#include "reply_capture.h"

#define SKIP_LESS(needed) \
@@ -152,6 +153,16 @@ void CapturingReplyBuilder::CollapseFilledCollections() {
}
}

CapturingReplyBuilder::ScopeCapture::ScopeCapture(CapturingReplyBuilder* crb,
ConnectionContext* cntx)
: cntx_{cntx} {
old_ = cntx->Inject(crb);
}

CapturingReplyBuilder::ScopeCapture::~ScopeCapture() {
cntx_->Inject(old_);
}

CapturingReplyBuilder::CollectionPayload::CollectionPayload(unsigned len, CollectionType type)
: len{len}, type{type}, arr{} {
arr.reserve(type == MAP ? len * 2 : len);
11 changes: 11 additions & 0 deletions src/facade/reply_capture.h
@@ -14,6 +14,7 @@

namespace facade {

class ConnectionContext;
struct CaptureVisitor;

// CapturingReplyBuilder allows capturing replies and retrieving them with Take().
@@ -66,6 +67,16 @@ class CapturingReplyBuilder : public RedisReplyBuilder {
bool with_scores;
};

public:
struct ScopeCapture {
ScopeCapture(CapturingReplyBuilder* crb, ConnectionContext* cntx);
~ScopeCapture();

private:
SinkReplyBuilder* old_;
ConnectionContext* cntx_;
};

public:
CapturingReplyBuilder(ReplyMode mode = ReplyMode::FULL)
: RedisReplyBuilder{nullptr}, reply_mode_{mode}, stack_{}, current_{} {
8 changes: 5 additions & 3 deletions src/server/common.h
@@ -215,11 +215,13 @@ template <typename RandGen> std::string GetRandomHex(RandGen& gen, size_t len) {
// truthy value;
template <typename T> struct AggregateValue {
bool operator=(T val) {
if (!bool(val))
return false;

std::lock_guard l{mu_};
if (!bool(current_) && bool(val)) {
if (!bool(current_))
current_ = val;
}
return bool(val);
return true;
}

T operator*() {
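The common.h hunk above changes AggregateValue::operator= to avoid taking the lock when the assigned value is falsy, while keeping the same semantics: only the first truthy value is stored, and the return value reflects the truthiness of the value being assigned. A minimal usage sketch of that post-change behavior (std::error_code is used here only as a convenient truthy/falsy type for illustration, not necessarily how the server instantiates the template):

#include <iostream>
#include <mutex>
#include <system_error>

// Same shape as the post-change AggregateValue from the diff above.
template <typename T> struct AggregateValue {
  bool operator=(T val) {
    if (!bool(val))
      return false;  // falsy values are ignored without touching the lock
    std::lock_guard l{mu_};
    if (!bool(current_))
      current_ = val;  // only the first truthy value is kept
    return true;       // the caller still learns that *this* value was truthy
  }
  T operator*() {
    std::lock_guard l{mu_};
    return current_;
  }
 private:
  std::mutex mu_;
  T current_{};
};

int main() {
  AggregateValue<std::error_code> first_error;
  first_error = std::error_code{};                          // falsy: ignored
  first_error = std::make_error_code(std::errc::timed_out); // first truthy value: kept
  first_error = std::make_error_code(std::errc::io_error);  // dropped, first one wins
  std::cout << (*first_error).message() << '\n';            // reports the timeout
}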
12 changes: 3 additions & 9 deletions src/server/main_service.cc
@@ -1393,13 +1393,6 @@ void Service::Unwatch(CmdArgList args, ConnectionContext* cntx) {
return (*cntx)->SendOk();
}

template <typename F> void WithReplies(CapturingReplyBuilder* crb, ConnectionContext* cntx, F&& f) {
SinkReplyBuilder* old_rrb = nullptr;
old_rrb = cntx->Inject(crb);
f();
cntx->Inject(old_rrb);
}

optional<CapturingReplyBuilder::Payload> Service::FlushEvalAsyncCmds(ConnectionContext* cntx,
bool force) {
auto& info = cntx->conn_state.script_info;
@@ -1415,9 +1408,10 @@ optional<CapturingReplyBuilder::Payload> Service::FlushEvalAsyncCmds(ConnectionC
cntx->transaction->MultiSwitchCmd(eval_cid);

CapturingReplyBuilder crb{ReplyMode::ONLY_ERR};
WithReplies(&crb, cntx, [&] {
{
CapturingReplyBuilder::ScopeCapture capture{&crb, cntx};
MultiCommandSquasher::Execute(absl::MakeSpan(info->async_cmds), cntx, this, true, true);
});
}

info->async_cmds_heap_mem = 0;
info->async_cmds.clear();
123 changes: 101 additions & 22 deletions src/server/search/doc_index.cc
@@ -56,13 +56,44 @@ const absl::flat_hash_map<string_view, search::SchemaField::FieldType> kSchemaTy
{"NUMERIC"sv, search::SchemaField::NUMERIC},
{"VECTOR"sv, search::SchemaField::VECTOR}};

size_t GetProbabilisticBound(size_t hits, size_t requested, optional<search::AggregationInfo> agg) {
auto intlog2 = [](size_t x) {
size_t l = 0;
while (x >>= 1)
++l;
return l;
};

if (hits == 0 || requested == 0)
return 0;

size_t shards = shard_set->size();

// Estimate how many hits every shard has with at least 99% probability
size_t avg_shard_min = hits * intlog2(hits) / (12 + shard_set->size() / 10);
avg_shard_min -= min(avg_shard_min, min(hits, size_t(5)));
Comment on lines +73 to +74

Collaborator: Can you explain the rationale here?

Contributor Author: Experimentally from #1892 😄 Those formulas might not be final

Collaborator: I'd link to that, at least until you change it :)

// If it turns out that we might not have enough results to cover the request, don't skip any
if (avg_shard_min * shards < requested)
return requested;

// If all shards have at least avg min, keep the bare minimum needed to cover the request
size_t limit = requested / shards + 1;

// Aggregations like SORTBY and KNN reorder the result and thus introduce some variance
Collaborator: Don't we have to fetch all matching for SORTBY?

Contributor Author (dranikpg, Oct 30, 2023): No/Yes 🤓 We have to fetch all, but not serialize all. That is what the doc reference is for. If we find out when building the reply that some entries should have been included but are not serialized, we'll refill (see BuildSortedOrder())

if (agg.has_value())
limit += max(requested / 4 + 1, 3UL);

return limit;
}
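
To put some numbers on the formulas discussed in the thread above, here is the same bound lifted into a standalone snippet. This is purely illustrative: the shard count is an explicit parameter here (in the server it comes from shard_set->size()), and the constants may still change as noted in the review.

#include <algorithm>
#include <cstddef>
#include <iostream>

std::size_t ProbabilisticBound(std::size_t hits, std::size_t requested, std::size_t shards,
                               bool has_agg) {
  auto intlog2 = [](std::size_t x) {
    std::size_t l = 0;
    while (x >>= 1)
      ++l;
    return l;
  };

  if (hits == 0 || requested == 0)
    return 0;

  std::size_t avg_shard_min = hits * intlog2(hits) / (12 + shards / 10);
  avg_shard_min -= std::min(avg_shard_min, std::min(hits, std::size_t(5)));

  if (avg_shard_min * shards < requested)
    return requested;  // possibly not enough hits overall: don't cut anything off

  std::size_t limit = requested / shards + 1;
  if (has_agg)
    limit += std::max(requested / 4 + 1, std::size_t(3));  // slack for reordering aggregations
  return limit;
}

int main() {
  std::cout << ProbabilisticBound(10'000, 30, 8, false) << '\n';  // 4: serialize 4 of the 30
  std::cout << ProbabilisticBound(10'000, 30, 8, true) << '\n';   // 12: extra slack for KNN/SORTBY
  std::cout << ProbabilisticBound(20, 30, 8, false) << '\n';      // 30: too few hits, no cutoff
}

So for a large per-shard result set and a request for 30 documents across 8 shards, each shard serializes only 4 hits up front (12 when a KNN/SORTBY aggregation adds variance) and keeps the rest as cheap references.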

} // namespace

bool SerializedSearchDoc::operator<(const SerializedSearchDoc& other) const {
bool DocResult::operator<(const DocResult& other) const {
return this->score < other.score;
}

bool SerializedSearchDoc::operator>=(const SerializedSearchDoc& other) const {
bool DocResult::operator>=(const DocResult& other) const {
return this->score >= other.score;
}

@@ -171,10 +202,11 @@ bool DocIndex::Matches(string_view key, unsigned obj_code) const {
}

ShardDocIndex::ShardDocIndex(shared_ptr<DocIndex> index)
: base_{std::move(index)}, indices_{{}, nullptr}, key_index_{} {
: base_{std::move(index)}, indices_{{}, nullptr}, key_index_{}, write_epoch_{0} {
}

void ShardDocIndex::Rebuild(const OpArgs& op_args, PMR_NS::memory_resource* mr) {
write_epoch_++;
key_index_ = DocKeyIndex{};
indices_ = search::FieldIndices{base_->schema, mr};

@@ -183,11 +215,13 @@ void ShardDocIndex::Rebuild(const OpArgs& op_args, PMR_NS::memory_resource* mr)
}

void ShardDocIndex::AddDoc(string_view key, const DbContext& db_cntx, const PrimeValue& pv) {
write_epoch_++;
auto accessor = GetAccessor(db_cntx, pv);
indices_.Add(key_index_.Add(key), accessor.get());
}

void ShardDocIndex::RemoveDoc(string_view key, const DbContext& db_cntx, const PrimeValue& pv) {
write_epoch_++;
auto accessor = GetAccessor(db_cntx, pv);
DocId id = key_index_.Remove(key);
indices_.Remove(id, accessor.get());
@@ -197,38 +231,83 @@ bool ShardDocIndex::Matches(string_view key, unsigned obj_code) const {
return base_->Matches(key, obj_code);
}

SearchResult ShardDocIndex::Search(const OpArgs& op_args, const SearchParams& params,
search::SearchAlgorithm* search_algo) const {
io::Result<SearchResult, facade::ErrorReply> ShardDocIndex::Search(
const OpArgs& op_args, const SearchParams& params, search::SearchAlgorithm* search_algo) const {
size_t requested_count = params.limit_offset + params.limit_total;
auto search_results = search_algo->Search(&indices_, requested_count);
if (!search_results.error.empty())
return nonstd::make_unexpected(facade::ErrorReply(std::move(search_results.error)));

size_t return_count = min(requested_count, search_results.ids.size());

// Probabilistic optimization: If we are about 99% sure that all shards in total fetch more
// results than needed to satisfy the search request, we can avoid serializing some of the last
// result hits as they likely won't be needed. The `cutoff_bound` indicates how many entries it's
// reasonable to serialize directly; for the rest, only ids are stored. In the 1% case they are
// either serialized on another hop or the query is fully repeated without this optimization.
size_t cuttoff_bound = requested_count;
if (params.enable_cutoff && !params.IdsOnly()) {
cuttoff_bound = GetProbabilisticBound(search_results.pre_aggregation_total, requested_count,
search_algo->HasAggregation());
}

vector<DocResult> out(return_count);
auto shard_id = EngineShard::tlocal()->shard_id();
auto& scores = search_results.scores;
for (size_t i = 0; i < out.size(); i++) {
out[i].value = DocResult::DocReference{shard_id, search_results.ids[i], i < cuttoff_bound};
out[i].score = scores.empty() ? search::ResultScore{} : std::move(scores[i]);
}

Serialize(op_args, params, absl::MakeSpan(out));

return SearchResult{write_epoch_, search_results.total, std::move(out),
std::move(search_results.profile)};
}
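
The matching header changes (doc_index.h) are not part of this excerpt. Judging from how Search(), Refill() and Serialize() use them, the result types presumably look roughly like the sketch below; all field and type spellings here are inferred from the diff rather than copied from the header.

#include <cstddef>
#include <cstdint>
#include <map>
#include <optional>
#include <string>
#include <variant>
#include <vector>

struct DocResult {
  // A hit that was cut off: only its location is known. `requested` tells
  // Serialize() whether this entry should be materialized on the current hop.
  struct DocReference {
    uint32_t shard_id;
    uint32_t doc_id;
    bool requested;
  };
  // A fully serialized hit: key plus the returned field/value pairs.
  struct SerializedValue {
    std::string key;
    std::map<std::string, std::string> values;
  };

  std::variant<DocReference, SerializedValue> value;
  std::variant<std::monostate, float, double, std::string> score;  // stand-in for search::ResultScore
};

struct SearchResult {
  uint64_t write_epoch;  // index epoch observed at search time, checked by Refill()
  std::size_t total;
  std::vector<DocResult> docs;
  std::optional<std::string> profile;  // stand-in for the profiling payload
};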

bool ShardDocIndex::Refill(const OpArgs& op_args, const SearchParams& params,
search::SearchAlgorithm* search_algo, SearchResult* result) const {
// If no writes occurred, serialize remaining entries without breaking correctness
if (result->write_epoch == write_epoch_) {
Serialize(op_args, params, absl::MakeSpan(result->docs));
return true;
}

// We're already on the cold path and we don't wanna gamble any more
DCHECK(!params.enable_cutoff);

auto new_result = Search(op_args, params, search_algo);
CHECK(new_result.has_value()); // Query should be valid since it passed first step

*result = std::move(new_result.value());
return false;
}
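
To make the hot/cold split in Refill() concrete, here is a toy model of the two hops implied above; ToyShard and the surrounding control flow are invented for illustration and are not the coordinator code from this PR.

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Each shard bumps its epoch on every index write (Rebuild/AddDoc/RemoveDoc above).
struct ToyShard {
  uint64_t write_epoch = 0;
  uint64_t SearchWithCutoff() const { return write_epoch; }            // hop 1: remember the epoch
  bool CanRefill(uint64_t seen) const { return seen == write_epoch; }  // hop 2: fast path check
};

int main() {
  std::vector<ToyShard> shards(3);

  std::vector<uint64_t> seen_epochs;
  for (const ToyShard& s : shards)
    seen_epochs.push_back(s.SearchWithCutoff());  // hop 1: cutoff search on every shard

  shards[1].write_epoch++;  // a write lands on shard 1 between the two hops

  for (std::size_t i = 0; i < shards.size(); i++) {
    if (shards[i].CanRefill(seen_epochs[i]))
      std::cout << "shard " << i << ": serialize the remaining references\n";
    else
      std::cout << "shard " << i << ": repeat the search without cutoffs\n";
  }
}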

void ShardDocIndex::Serialize(const OpArgs& op_args, const SearchParams& params,
absl::Span<DocResult> docs) const {
auto& db_slice = op_args.shard->db_slice();
auto search_results = search_algo->Search(&indices_, params.limit_offset + params.limit_total);

if (!search_results.error.empty())
return SearchResult{facade::ErrorReply{std::move(search_results.error)}};
for (auto& doc : docs) {
if (!holds_alternative<DocResult::DocReference>(doc.value))
continue;

vector<SerializedSearchDoc> out;
out.reserve(search_results.ids.size());
auto ref = get<DocResult::DocReference>(doc.value);
if (!ref.requested)
return;

size_t expired_count = 0;
for (size_t i = 0; i < search_results.ids.size(); i++) {
auto key = key_index_.Get(search_results.ids[i]);
auto it = db_slice.Find(op_args.db_cntx, key, base_->GetObjCode());
string key{key_index_.Get(ref.doc_id)};

auto it = db_slice.Find(op_args.db_cntx, key, base_->GetObjCode());
if (!it || !IsValid(*it)) { // Item must have expired
expired_count++;
doc.value = DocResult::SerializedValue{std::move(key), {}};
continue;
}

auto accessor = GetAccessor(op_args.db_cntx, (*it)->second);
auto doc_data = params.return_fields ? accessor->Serialize(base_->schema, *params.return_fields)
: accessor->Serialize(base_->schema);

auto score =
search_results.scores.empty() ? std::monostate{} : std::move(search_results.scores[i]);
out.push_back(SerializedSearchDoc{string{key}, std::move(doc_data), std::move(score)});
doc.value = DocResult::SerializedValue{std::move(key), std::move(doc_data)};
}

return SearchResult{search_results.total - expired_count, std::move(out),
std::move(search_results.profile)};
}

DocIndexInfo ShardDocIndex::GetInfo() const {