From 1741f15e8f77f1345a57db6216897fc8961c466b Mon Sep 17 00:00:00 2001 From: Stepan Bagritsevich Date: Thu, 2 May 2024 14:23:08 +0000 Subject: [PATCH] fix(zset): fix random in ZRANDMEMBER command fixes dragonflydb#2850 Signed-off-by: Stepan Bagritsevich --- src/server/zset_family.cc | 130 ++++++++++++++++++++++++++---- src/server/zset_family_test.cc | 141 +++++++++++++++++++++++++-------- 2 files changed, 224 insertions(+), 47 deletions(-) diff --git a/src/server/zset_family.cc b/src/server/zset_family.cc index 22dd8e0452c9..38810237a24d 100644 --- a/src/server/zset_family.cc +++ b/src/server/zset_family.cc @@ -222,6 +222,70 @@ OpResult FindZEntry(const ZParams& zparams, const OpArgs& return DbSlice::ItAndUpdater{add_res.it, add_res.exp_it, std::move(add_res.post_updater)}; } +using RandomPick = std::size_t; + +class PicksGenerator { + public: + virtual RandomPick Generate() = 0; + virtual ~PicksGenerator() = default; +}; + +class NonUniquePicksGenerator : public PicksGenerator { + public: + NonUniquePicksGenerator(std::size_t total_size) : total_size_(total_size) { + CHECK_GT(total_size, std::size_t(0)); + } + + RandomPick Generate() override { + return absl::Uniform(bitgen_, 0u, total_size_); + } + + private: + const std::size_t total_size_; + absl::BitGen bitgen_{}; +}; + +/* + * Generates unique index in O(1). + * + * picks_count specifies the number of random indexes to be generated. + * In other words, this is the number of times the Generate() function is called. + * + * The class uses Robert Floyd's sampling algorithm + * https://dl.acm.org/doi/pdf/10.1145/30401.315746 + * */ +class UniquePicksGenerator : public PicksGenerator { + public: + UniquePicksGenerator(std::size_t picks_count, std::size_t total_size) + : picked_indexes_(picks_count) { + CHECK_GE(total_size, picks_count); + current_random_limit_ = total_size - picks_count; + } + + RandomPick Generate() override { + const std::size_t max_index = current_random_limit_++; + const RandomPick random_index = absl::Uniform(bitgen_, 0u, max_index + 1u); + + if (!IndexWasPicked(random_index)) { + picked_indexes_.insert(random_index); + return random_index; + } else { + picked_indexes_.insert(max_index); + return max_index; + } + } + + private: + bool IndexWasPicked(RandomPick pick) { + return picked_indexes_.find(pick) != picked_indexes_.end(); + } + + private: + std::size_t current_random_limit_; + std::unordered_set picked_indexes_; + absl::BitGen bitgen_{}; +}; + bool ScoreToLongLat(const std::optional& val, double* xy) { if (!val.has_value()) return false; @@ -1702,6 +1766,51 @@ OpResult OpScan(const OpArgs& op_args, std::string_view key, uint64_t return res; } +OpResult OpRandMember(int count, const ZSetFamily::RangeParams& params, + const OpArgs& op_args, string_view key) { + auto it = op_args.shard->db_slice().FindReadOnly(op_args.db_cntx, key, OBJ_ZSET); + if (!it) + return it.status(); + + // Action::RANGE is a read-only operation, but requires const_cast + PrimeValue& pv = const_cast(it.value()->second); + + const std::size_t size = pv.Size(); + const std::size_t picks_count = + count >= 0 ? std::min(static_cast(count), size) : std::abs(count); + + ScoredArray result{picks_count}; + auto generator = [count, picks_count, size]() -> std::unique_ptr { + if (count >= 0) { + return std::make_unique(picks_count, size); + } else { + return std::make_unique(size); + } + }(); + + if (picks_count * static_cast(std::log2(size)) < size) { + for (std::size_t i = 0; i < picks_count; i++) { + const std::size_t picked_index = generator->Generate(); + + IntervalVisitor iv{Action::RANGE, params, &pv}; + iv(ZSetFamily::IndexInterval{picked_index, picked_index}); + + result[i] = iv.PopResult().front(); + } + } else { + IntervalVisitor iv{Action::RANGE, params, &pv}; + iv(ZSetFamily::IndexInterval{0, -1}); + + ScoredArray all_elements = iv.PopResult(); + + for (std::size_t i = 0; i < picks_count; i++) { + result[i] = all_elements[generator->Generate()]; + } + } + + return result; +} + void ZAddGeneric(string_view key, const ZParams& zparams, ScoredMemberSpan memb_sp, ConnectionContext* cntx) { auto cb = [&](Transaction* t, EngineShard* shard) { @@ -2323,16 +2432,14 @@ void ZSetFamily::ZRandMember(CmdArgList args, ConnectionContext* cntx) { if (args.size() > 3) return cntx->SendError(WrongNumArgsError("ZRANDMEMBER")); - ZRangeSpec range_spec; - range_spec.interval = IndexInterval(0, -1); - CmdArgParser parser{args}; string_view key = parser.Next(); bool is_count = parser.HasNext(); int count = is_count ? parser.Next() : 1; - range_spec.params.with_scores = static_cast(parser.Check("WITHSCORES").IgnoreCase()); + ZSetFamily::RangeParams params; + params.with_scores = static_cast(parser.Check("WITHSCORES").IgnoreCase()); if (parser.HasNext()) return cntx->SendError(absl::StrCat("Unsupported option:", string_view(parser.Next()))); @@ -2340,26 +2447,17 @@ void ZSetFamily::ZRandMember(CmdArgList args, ConnectionContext* cntx) { if (auto err = parser.Error(); err) return cntx->SendError(err->MakeReply()); - bool sign = count < 0; - range_spec.params.limit = std::abs(count); - const auto cb = [&](Transaction* t, EngineShard* shard) { - return OpRange(range_spec, t->GetOpArgs(shard), key); + return OpRandMember(count, params, t->GetOpArgs(shard), key); }; OpResult result = cntx->transaction->ScheduleSingleHopT(cb); auto* rb = static_cast(cntx->reply_builder()); if (result) { - if (sign && !result->empty()) { - for (auto i = result->size(); i < range_spec.params.limit; ++i) { - // we can return duplicate elements, so first is OK - result->push_back(result->front()); - } - } - rb->SendScoredArray(result.value(), range_spec.params.with_scores); + rb->SendScoredArray(result.value(), params.with_scores); } else if (result.status() == OpStatus::KEY_NOTFOUND) { if (is_count) { - rb->SendScoredArray(ScoredArray(), range_spec.params.with_scores); + rb->SendScoredArray(ScoredArray(), params.with_scores); } else { rb->SendNull(); } diff --git a/src/server/zset_family_test.cc b/src/server/zset_family_test.cc index 093c3e647a62..e3150b5e99c2 100644 --- a/src/server/zset_family_test.cc +++ b/src/server/zset_family_test.cc @@ -77,53 +77,132 @@ TEST_F(ZSetFamilyTest, ZRem) { } TEST_F(ZSetFamilyTest, ZRandMember) { - auto resp = Run({ - "zadd", - "x", - "1", - "a", - "2", - "b", - "3", - "c", - }); + auto resp = Run({"ZAdd", "x", "1", "a", "2", "b", "3", "c"}); + EXPECT_THAT(resp, IntArg(3)); + + // Test if count > 0 resp = Run({"ZRandMember", "x"}); ASSERT_THAT(resp, ArgType(RespExpr::STRING)); - EXPECT_THAT(resp, "a"); + EXPECT_THAT(resp, AnyOf("a", "b", "c")); + + resp = Run({"ZRandMember", "x", "1"}); + ASSERT_THAT(resp, ArgType(RespExpr::STRING)); + EXPECT_THAT(resp, AnyOf("a", "b", "c")); resp = Run({"ZRandMember", "x", "2"}); - ASSERT_THAT(resp, ArgType(RespExpr::ARRAY)); - EXPECT_THAT(resp.GetVec(), UnorderedElementsAre("a", "b")); + ASSERT_THAT(resp, ArrLen(2)); + EXPECT_THAT(resp.GetVec(), IsSubsetOf({"a", "b", "c"})); - resp = Run({"ZRandMember", "x", "0"}); - ASSERT_THAT(resp, ArgType(RespExpr::ARRAY)); - EXPECT_EQ(resp.GetVec().size(), 0); + resp = Run({"ZRandMember", "x", "3"}); + ASSERT_THAT(resp, ArrLen(3)); + EXPECT_THAT(resp.GetVec(), UnorderedElementsAre("a", "b", "c")); - resp = Run({"ZRandMember", "k"}); - ASSERT_THAT(resp, ArgType(RespExpr::NIL)); + // Test if count < 0 + std::unordered_set expected_entries({"a", "b", "c"}); - resp = Run({"ZRandMember", "k", "2"}); - ASSERT_THAT(resp, ArgType(RespExpr::ARRAY)); - EXPECT_EQ(resp.GetVec().size(), 0); + auto expect_elements = [](const auto& expected_elements, const auto& actual_elements) { + for (const auto& x : actual_elements) { + if (expected_elements.find(x) == expected_elements.end()) { + return false; + } + } + return true; + }; + + auto parse_response = [](const auto& resp) { + auto vec = resp.GetVec(); + + std::vector entries; + std::transform(vec.begin(), vec.end(), std::back_inserter(entries), + [](auto& x) { return x.GetString(); }); + return entries; + }; + + resp = Run({"ZRandMember", "x", "-1"}); + ASSERT_THAT(resp, ArgType(RespExpr::STRING)); + EXPECT_THAT(resp, AnyOf("a", "b", "c")); - resp = Run({"ZRandMember", "x", "-5"}); - ASSERT_THAT(resp, ArrLen(5)); - EXPECT_THAT(resp.GetVec(), ElementsAre("a", "b", "c", "a", "a")); + resp = Run({"ZRandMember", "x", "-2"}); + ASSERT_THAT(resp, ArrLen(2)); + EXPECT_TRUE(expect_elements(expected_entries, parse_response(resp))); - resp = Run({"ZRandMember", "x", "5"}); + resp = Run({"ZRandMember", "x", "-3"}); + ASSERT_THAT(resp, ArrLen(3)); + EXPECT_TRUE(expect_elements(expected_entries, parse_response(resp))); + + // Test if count < 0, but the absolute value is larger than the size of the sorted set + resp = Run({"ZRandMember", "x", "-15"}); + ASSERT_THAT(resp, ArrLen(15)); + EXPECT_TRUE(expect_elements(expected_entries, parse_response(resp))); + + // Test if count is 0 + ASSERT_THAT(Run({"ZRandMember", "x", "0"}), ArrLen(0)); + + // Test if count is larger than the size of the sorted set + resp = Run({"ZRandMember", "x", "15"}); ASSERT_THAT(resp, ArrLen(3)); EXPECT_THAT(resp.GetVec(), UnorderedElementsAre("a", "b", "c")); - resp = Run({"ZRandMember", "x", "-5", "WITHSCORES"}); - ASSERT_THAT(resp, ArrLen(10)); - EXPECT_THAT(resp.GetVec(), ElementsAre("a", "1", "b", "2", "c", "3", "a", "1", "a", "1")); + // Test if sorted set is empty + EXPECT_THAT(Run({"ZAdd", "empty::zset", "1", "one"}), IntArg(1)); + EXPECT_THAT(Run({"ZRem", "empty::zset", "one"}), IntArg(1)); + ASSERT_THAT(Run({"ZRandMember", "empty::zset", "0"}), ArrLen(0)); + ASSERT_THAT(Run({"ZRandMember", "empty::zset", "3"}), ArrLen(0)); + ASSERT_THAT(Run({"ZRandMember", "empty::zset", "-4"}), ArrLen(0)); + + // Test if key does not exist + ASSERT_THAT(Run({"ZRandMember", "y"}), ArgType(RespExpr::NIL)); + ASSERT_THAT(Run({"ZRandMember", "y", "0"}), ArrLen(0)); + + // Test WITHSCORES + using ZSetEntry = std::pair; + std::set expected_entries_with_scores{{"a", "1"}, {"b", "2"}, {"c", "3"}}; + + auto parse_response_with_scores = [](const auto& resp) { + auto vec = resp.GetVec(); + + std::vector entries; + for (std::size_t i = 1; i < vec.size(); i += 2) { + entries.emplace_back(vec[i - 1].GetString(), vec[i].GetString()); + } + return entries; + }; + + resp = Run({"ZRandMember", "x", "1", "WITHSCORES"}); + ASSERT_THAT(resp, ArrLen(2)); + EXPECT_THAT(parse_response_with_scores(resp), IsSubsetOf(expected_entries_with_scores)); + + resp = Run({"ZRandMember", "x", "2", "WITHSCORES"}); + ASSERT_THAT(resp, ArrLen(4)); + EXPECT_THAT(parse_response_with_scores(resp), IsSubsetOf(expected_entries_with_scores)); resp = Run({"ZRandMember", "x", "3", "WITHSCORES"}); ASSERT_THAT(resp, ArrLen(6)); - EXPECT_THAT(resp.GetVec(), UnorderedElementsAre("a", "1", "b", "2", "c", "3")); + EXPECT_THAT(parse_response_with_scores(resp), + UnorderedElementsAre(std::make_pair("a", "1"), std::make_pair("b", "2"), + std::make_pair("c", "3"))); - resp = Run({"ZRandMember", "x", "3", "WITHSCORES", "test"}); - EXPECT_THAT(resp, ErrArg("wrong number of arguments")); + resp = Run({"ZRandMember", "x", "15", "WITHSCORES"}); + ASSERT_THAT(resp, ArrLen(6)); + EXPECT_THAT(parse_response_with_scores(resp), + UnorderedElementsAre(std::make_pair("a", "1"), std::make_pair("b", "2"), + std::make_pair("c", "3"))); + + resp = Run({"ZRandMember", "x", "-1", "WITHSCORES"}); + ASSERT_THAT(resp, ArrLen(2)); + EXPECT_TRUE(expect_elements(expected_entries_with_scores, parse_response_with_scores(resp))); + + resp = Run({"ZRandMember", "x", "-2", "WITHSCORES"}); + ASSERT_THAT(resp, ArrLen(4)); + EXPECT_TRUE(expect_elements(expected_entries_with_scores, parse_response_with_scores(resp))); + + resp = Run({"ZRandMember", "x", "-3", "WITHSCORES"}); + ASSERT_THAT(resp, ArrLen(6)); + EXPECT_TRUE(expect_elements(expected_entries_with_scores, parse_response_with_scores(resp))); + + resp = Run({"ZRandMember", "x", "-15", "WITHSCORES"}); + ASSERT_THAT(resp, ArrLen(30)); + EXPECT_TRUE(expect_elements(expected_entries_with_scores, parse_response_with_scores(resp))); } TEST_F(ZSetFamilyTest, ZMScore) {