Skip to content

Commit

Permalink
feat(search): Multishard cutoffs
Browse files Browse the repository at this point in the history
Signed-off-by: Vladislav Oleshko <[email protected]>
  • Loading branch information
dranikpg committed Oct 20, 2023
1 parent 1d02e12 commit 30ea6b8
Show file tree
Hide file tree
Showing 13 changed files with 571 additions and 173 deletions.
5 changes: 5 additions & 0 deletions src/core/search/ast_expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,11 @@ struct AstNode : public NodeVariants {
const NodeVariants& Variant() const& {
return *this;
}

// Aggregations reduce and re-order result sets.
bool IsAggregation() const {
  return std::get_if<AstKnnNode>(&Variant()) != nullptr;
}
};

using AstExpr = AstNode;
Expand Down
4 changes: 4 additions & 0 deletions src/core/search/search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -599,4 +599,8 @@ void SearchAlgorithm::EnableProfiling() {
profiling_enabled_ = true;
}

// Returns whether EnableProfiling() was previously called on this instance.
bool SearchAlgorithm::IsProfilingEnabled() const {
  return profiling_enabled_;
}

} // namespace dfly::search
1 change: 1 addition & 0 deletions src/core/search/search.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ class SearchAlgorithm {
std::optional<AggregationInfo> HasAggregation() const;

void EnableProfiling();
bool IsProfilingEnabled() const;

private:
bool profiling_enabled_ = false;
Expand Down
11 changes: 11 additions & 0 deletions src/facade/reply_capture.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "facade/reply_capture.h"

#include "base/logging.h"
#include "facade/conn_context.h"
#include "reply_capture.h"

#define SKIP_LESS(needed) \
Expand Down Expand Up @@ -150,6 +151,16 @@ void CapturingReplyBuilder::CollapseFilledCollections() {
}
}

CapturingReplyBuilder::ScopeCapture::ScopeCapture(CapturingReplyBuilder* crb,
ConnectionContext* cntx)
: cntx_{cntx} {
old_ = cntx->Inject(crb);
}

// Restore the reply builder that was active before the capture began.
CapturingReplyBuilder::ScopeCapture::~ScopeCapture() {
  cntx_->Inject(old_);
}

CapturingReplyBuilder::CollectionPayload::CollectionPayload(unsigned len, CollectionType type)
: len{len}, type{type}, arr{} {
arr.reserve(type == MAP ? len * 2 : len);
Expand Down
11 changes: 11 additions & 0 deletions src/facade/reply_capture.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

namespace facade {

class ConnectionContext;
struct CaptureVisitor;

// CapturingReplyBuilder allows capturing replies and retrieving them with Take().
Expand Down Expand Up @@ -66,6 +67,16 @@ class CapturingReplyBuilder : public RedisReplyBuilder {
bool with_scores;
};

public:
// RAII helper: installs a CapturingReplyBuilder on the connection context for
// the lifetime of this object and restores the previous builder on destruction.
struct ScopeCapture {
  ScopeCapture(CapturingReplyBuilder* crb, ConnectionContext* cntx);
  ~ScopeCapture();

 private:
  SinkReplyBuilder* old_;    // builder that was active before the capture
  ConnectionContext* cntx_;  // connection whose builder is temporarily swapped
};

public:
CapturingReplyBuilder(ReplyMode mode = ReplyMode::FULL)
: RedisReplyBuilder{nullptr}, reply_mode_{mode}, stack_{}, current_{} {
Expand Down
8 changes: 5 additions & 3 deletions src/server/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -214,11 +214,13 @@ template <typename RandGen> std::string GetRandomHex(RandGen& gen, size_t len) {
// truthy value;
template <typename T> struct AggregateValue {
bool operator=(T val) {
if (!bool(val))
return false;

std::lock_guard l{mu_};
if (!bool(current_) && bool(val)) {
if (!bool(current_))
current_ = val;
}
return bool(val);
return true;
}

T operator*() {
Expand Down
12 changes: 3 additions & 9 deletions src/server/main_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1363,13 +1363,6 @@ void Service::Unwatch(CmdArgList args, ConnectionContext* cntx) {
return (*cntx)->SendOk();
}

template <typename F> void WithReplies(CapturingReplyBuilder* crb, ConnectionContext* cntx, F&& f) {
SinkReplyBuilder* old_rrb = nullptr;
old_rrb = cntx->Inject(crb);
f();
cntx->Inject(old_rrb);
}

optional<CapturingReplyBuilder::Payload> Service::FlushEvalAsyncCmds(ConnectionContext* cntx,
bool force) {
auto& info = cntx->conn_state.script_info;
Expand All @@ -1385,9 +1378,10 @@ optional<CapturingReplyBuilder::Payload> Service::FlushEvalAsyncCmds(ConnectionC
cntx->transaction->MultiSwitchCmd(eval_cid);

CapturingReplyBuilder crb{ReplyMode::ONLY_ERR};
WithReplies(&crb, cntx, [&] {
{
CapturingReplyBuilder::ScopeCapture capture{&crb, cntx};
MultiCommandSquasher::Execute(absl::MakeSpan(info->async_cmds), cntx, this, true, true);
});
}

info->async_cmds_heap_mem = 0;
info->async_cmds.clear();
Expand Down
106 changes: 84 additions & 22 deletions src/server/search/doc_index.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,33 @@ const absl::flat_hash_map<string_view, search::SchemaField::FieldType> kSchemaTy
{"NUMERIC"sv, search::SchemaField::NUMERIC},
{"VECTOR"sv, search::SchemaField::VECTOR}};

// Estimate how many of the `hits` documents each of the `shards` shards needs
// to serialize eagerly so that, with high probability, the coordinator can
// assemble the `requested` top documents without a second round trip.
// Aggregations re-order the combined result set, so no per-shard cutoff is
// safe for them and the bound falls back to min(hits, requested).
size_t GetProbabilisticBound(size_t shards, size_t hits, size_t requested, bool is_aggregation) {
  // Integer floor(log2(x)); returns 0 for x == 0.
  auto intlog2 = [](size_t x) {
    size_t l = 0;
    while (x >>= 1)
      ++l;
    return l;
  };

  // Heuristic lower estimate of hits landing on a single shard.
  size_t avg_shard_min = hits * intlog2(hits) / (12 + shards / 10);
  // Subtract a small safety margin, clamped so the subtraction never underflows.
  avg_shard_min -= std::min(avg_shard_min, std::min(hits, size_t(5)));

  // If even the pessimistic per-shard estimate covers the request, every shard
  // only needs to contribute its proportional share (plus one for rounding).
  if (!is_aggregation && avg_shard_min * shards >= requested)
    return requested / shards + 1;

  return std::min(hits, requested);
}

} // namespace

bool SerializedSearchDoc::operator<(const SerializedSearchDoc& other) const {
// Order strictly by score; the key/value payload does not participate.
bool DocResult::operator<(const DocResult& other) const {
  return this->score < other.score;
}

bool SerializedSearchDoc::operator>=(const SerializedSearchDoc& other) const {
// Complement of operator<: compares only by score.
bool DocResult::operator>=(const DocResult& other) const {
  return this->score >= other.score;
}

Expand Down Expand Up @@ -162,10 +182,11 @@ bool DocIndex::Matches(string_view key, unsigned obj_code) const {
}

ShardDocIndex::ShardDocIndex(shared_ptr<DocIndex> index)
: base_{std::move(index)}, indices_{{}, nullptr}, key_index_{} {
: base_{std::move(index)}, write_epoch_{0}, indices_{{}, nullptr}, key_index_{} {
}

void ShardDocIndex::Rebuild(const OpArgs& op_args, PMR_NS::memory_resource* mr) {
write_epoch_++;
key_index_ = DocKeyIndex{};
indices_ = search::FieldIndices{base_->schema, mr};

Expand All @@ -174,11 +195,13 @@ void ShardDocIndex::Rebuild(const OpArgs& op_args, PMR_NS::memory_resource* mr)
}

void ShardDocIndex::AddDoc(string_view key, const DbContext& db_cntx, const PrimeValue& pv) {
  // Bump the write epoch so outstanding SearchResults referencing old doc ids
  // are detected as stale (the epoch is compared in Refill).
  write_epoch_++;
  auto accessor = GetAccessor(db_cntx, pv);
  indices_.Add(key_index_.Add(key), accessor.get());
}

void ShardDocIndex::RemoveDoc(string_view key, const DbContext& db_cntx, const PrimeValue& pv) {
write_epoch_++;
auto accessor = GetAccessor(db_cntx, pv);
DocId id = key_index_.Remove(key);
indices_.Remove(id, accessor.get());
Expand All @@ -188,38 +211,77 @@ bool ShardDocIndex::Matches(string_view key, unsigned obj_code) const {
return base_->Matches(key, obj_code);
}

SearchResult ShardDocIndex::Search(const OpArgs& op_args, const SearchParams& params,
search::SearchAlgorithm* search_algo) const {
// Run `search_algo` over this shard's indices and return up to
// limit_offset + limit_total documents. Documents below the probabilistic
// cutoff bound are serialized eagerly; the rest are kept as references that
// can be resolved later via Refill if the coordinator needs them.
io::Result<SearchResult, facade::ErrorReply> ShardDocIndex::Search(
    const OpArgs& op_args, const SearchParams& params, search::SearchAlgorithm* search_algo) const {
  auto search_results = search_algo->Search(&indices_);
  if (!search_results.error.empty())
    return nonstd::make_unexpected(facade::ErrorReply(std::move(search_results.error)));

  size_t requested_count = params.limit_offset + params.limit_total;
  size_t serialize_count = min(requested_count, search_results.ids.size());

  size_t cutoff_bound = serialize_count;
  if (params.enable_cutoff && !params.IdsOnly())
    cutoff_bound = GetProbabilisticBound(params.num_shards, search_results.ids.size(),
                                         requested_count, search_algo->HasAggregation().has_value());

  VLOG(1) << "Requested " << requested_count << " got " << search_results.ids.size() << " cutoff "
          << cutoff_bound;

  vector<DocResult> out(serialize_count);
  auto shard_id = EngineShard::tlocal()->shard_id();
  for (size_t i = 0; i < out.size(); i++) {
    // Entries past the cutoff stay as unrequested references; Serialize below
    // resolves only the requested ones.
    out[i].value = DocResult::DocReference{shard_id, search_results.ids[i], i < cutoff_bound};
    out[i].score =
        search_results.scores.empty() ? search::ResultScore{} : std::move(search_results.scores[i]);
  }

  Serialize(op_args, params, absl::MakeSpan(out));

  return SearchResult{write_epoch_, search_results.ids.size(), std::move(out),
                      std::move(search_results.profile)};
}

// Resolve remaining document references of a previous Search result.
// Returns true if the stored references were still valid (fast path), false
// if the index changed and the search had to be re-run from scratch.
bool ShardDocIndex::Refill(const OpArgs& op_args, const SearchParams& params,
                           search::SearchAlgorithm* search_algo, SearchResult* result) const {
  // Fast path: no writes happened since the result was produced, so the
  // stored doc ids are still valid and can be serialized directly.
  if (result->write_epoch == write_epoch_) {
    Serialize(op_args, params, absl::MakeSpan(result->docs));
    return true;
  }

  // Index changed: retry the full search. Cutoffs must be disabled on this
  // path so all requested documents are serialized eagerly.
  DCHECK(!params.enable_cutoff);
  auto new_result = Search(op_args, params, search_algo);
  CHECK(new_result.has_value());  // presumably the same query cannot fail on retry — verify
  *result = std::move(new_result.value());
  return false;
}

void ShardDocIndex::Serialize(const OpArgs& op_args, const SearchParams& params,
absl::Span<DocResult> docs) const {
auto& db_slice = op_args.shard->db_slice();
auto search_results = search_algo->Search(&indices_, params.limit_offset + params.limit_total);

if (!search_results.error.empty())
return SearchResult{facade::ErrorReply{std::move(search_results.error)}};
for (auto& doc : docs) {
if (!holds_alternative<DocResult::DocReference>(doc.value))
continue;

vector<SerializedSearchDoc> out;
out.reserve(search_results.ids.size());
auto ref = get<DocResult::DocReference>(doc.value);
if (!ref.requested)
return;

size_t expired_count = 0;
for (size_t i = 0; i < search_results.ids.size(); i++) {
auto key = key_index_.Get(search_results.ids[i]);
auto it = db_slice.Find(op_args.db_cntx, key, base_->GetObjCode());
string key{key_index_.Get(ref.doc_id)};

auto it = db_slice.Find(op_args.db_cntx, key, base_->GetObjCode());
if (!it || !IsValid(*it)) { // Item must have expired
expired_count++;
doc.value = DocResult::SerializedValue{std::move(key), {}};
continue;
}

auto accessor = GetAccessor(op_args.db_cntx, (*it)->second);
auto doc_data = params.return_fields ? accessor->Serialize(base_->schema, *params.return_fields)
: accessor->Serialize(base_->schema);

auto score =
search_results.scores.empty() ? std::monostate{} : std::move(search_results.scores[i]);
out.push_back(SerializedSearchDoc{string{key}, std::move(doc_data), std::move(score)});
doc.value = DocResult::SerializedValue{std::move(key), std::move(doc_data)};
}

return SearchResult{search_results.total - expired_count, std::move(out),
std::move(search_results.profile)};
}

DocIndexInfo ShardDocIndex::GetInfo() const {
Expand Down
56 changes: 38 additions & 18 deletions src/server/search/doc_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,31 +25,39 @@ using SearchDocData = absl::flat_hash_map<std::string /*field*/, std::string /*v
std::optional<search::SchemaField::FieldType> ParseSearchFieldType(std::string_view name);
std::string_view SearchFieldTypeToString(search::SchemaField::FieldType);

struct SerializedSearchDoc {
std::string key;
SearchDocData values;
struct DocResult {
struct SerializedValue {
std::string key;
SearchDocData values;
};

struct DocReference {
ShardId shard_id;
search::DocId doc_id;
bool requested;
};

std::variant<SerializedValue, DocReference> value;
search::ResultScore score;

bool operator<(const SerializedSearchDoc& other) const;
bool operator>=(const SerializedSearchDoc& other) const;
bool operator<(const DocResult& other) const;
bool operator>=(const DocResult& other) const;
};

struct SearchResult {
SearchResult() = default;
size_t write_epoch = 0; // Write epoch of the index during on the result was created

SearchResult(size_t total_hits, std::vector<SerializedSearchDoc> docs,
std::optional<search::AlgorithmProfile> profile)
: total_hits{total_hits}, docs{std::move(docs)}, profile{std::move(profile)} {
}
size_t total_hits = 0; // total number of hits in shard
std::vector<DocResult> docs; // serialized documents of first hits

SearchResult(facade::ErrorReply error) : error{std::move(error)} {
}
// After combining results from multiple shards and accumulating more documents than initially
// requested, only a subset of all documents will be sent back to the client,
// so it doesn't make sense to serialize strictly all documents in every shard ahead.
// Instead, only documents up to a probablistic bound are serialized, the
// leftover ids and scores are stored in the cutoff tail for use in the "unlikely" scenario.
// size_t num_cutoff = 0;

size_t total_hits;
std::vector<SerializedSearchDoc> docs;
std::optional<search::AlgorithmProfile> profile;

std::optional<facade::ErrorReply> error;
};

struct SearchParams {
Expand All @@ -61,6 +69,10 @@ struct SearchParams {
size_t limit_offset = 0;
size_t limit_total = 10;

// Total number of shards, used in probabilistic queries
size_t num_shards = 0;
bool enable_cutoff = false;

// Set but empty means no fields should be returned
std::optional<FieldReturnList> return_fields;
std::optional<search::SortOption> sort_option;
Expand Down Expand Up @@ -123,8 +135,12 @@ class ShardDocIndex {
ShardDocIndex(std::shared_ptr<DocIndex> index);

// Perform search on all indexed documents and return results.
SearchResult Search(const OpArgs& op_args, const SearchParams& params,
search::SearchAlgorithm* search_algo) const;
io::Result<SearchResult, facade::ErrorReply> Search(const OpArgs& op_args,
const SearchParams& params,
search::SearchAlgorithm* search_algo) const;

bool Refill(const OpArgs& op_args, const SearchParams& params,
search::SearchAlgorithm* search_algo, SearchResult* result) const;

// Return whether base index matches
bool Matches(std::string_view key, unsigned obj_code) const;
Expand All @@ -138,8 +154,12 @@ class ShardDocIndex {
// Clears internal data. Traverses all matching documents and assigns ids.
void Rebuild(const OpArgs& op_args, PMR_NS::memory_resource* mr);

void Serialize(const OpArgs& op_args, const SearchParams& params,
absl::Span<DocResult> docs) const;

private:
std::shared_ptr<const DocIndex> base_;
size_t write_epoch_;
search::FieldIndices indices_;
DocKeyIndex key_index_;
};
Expand Down

0 comments on commit 30ea6b8

Please sign in to comment.