From c7e7f8b10a0beb7b6ff606a677afb999eef3d11a Mon Sep 17 00:00:00 2001 From: Stepan Bagritsevich Date: Wed, 1 May 2024 13:25:27 +0000 Subject: [PATCH] fix(zset): fix random in ZRANDMEMBER command fixes dragonflydb#2850 Signed-off-by: Stepan Bagritsevich --- src/server/zset_family.cc | 121 ++++++++++++++++++++++++---- src/server/zset_family_test.cc | 141 +++++++++++++++++++++++++-------- 2 files changed, 215 insertions(+), 47 deletions(-) diff --git a/src/server/zset_family.cc b/src/server/zset_family.cc index 22dd8e0452c9..d1433db18593 100644 --- a/src/server/zset_family.cc +++ b/src/server/zset_family.cc @@ -222,6 +222,56 @@ OpResult FindZEntry(const ZParams& zparams, const OpArgs& return DbSlice::ItAndUpdater{add_res.it, add_res.exp_it, std::move(add_res.post_updater)}; } +using RandomPick = std::size_t; +using PicksArray = std::vector; + +/* + * Generates an array of non-unique indexes in O(picks_count). + * picks_count specifies the number of random indexes. + * */ +PicksArray GenerateRandomPicks(std::size_t picks_count, std::size_t total_size) { + CHECK_GT(total_size, std::size_t(0)); + + PicksArray picks; + picks.resize(picks_count); + + absl::BitGen bitgen; + + for (std::size_t i = 0; i < picks_count; i++) { + picks[i] = absl::Uniform(bitgen, 0u, total_size); + } + return picks; +} + +/* + * Generates an array of unique indexes in O(picks_count). + * picks_count specifies the number of random indexes. + * + * The function uses Robert Floyd's sampling algorithm + * https://dl.acm.org/doi/pdf/10.1145/30401.315746 + * */ +PicksArray GenerateUniqueRandomPicks(std::size_t picks_count, std::size_t total_size) { + CHECK_GE(total_size, picks_count); + + PicksArray picks; + std::unordered_set picked_indexes{picks_count}; + + absl::BitGen bitgen; + + for (std::size_t i = total_size - picks_count; i < total_size; ++i) { + std::size_t random_index = absl::Uniform(bitgen, 0u, i + 1u); + if (!picked_indexes.contains(random_index)) { + picks.push_back(random_index); + picked_indexes.insert(random_index); + } else { + picks.push_back(i); + picked_indexes.insert(i); + } + } + DCHECK_EQ(picks.size(), picks_count); + return picks; +} + bool ScoreToLongLat(const std::optional& val, double* xy) { if (!val.has_value()) return false; @@ -1702,6 +1752,56 @@ OpResult OpScan(const OpArgs& op_args, std::string_view key, uint64_t return res; } +OpResult OpRandMember(int count, const ZSetFamily::RangeParams& params, + const OpArgs& op_args, string_view key) { + auto it = op_args.shard->db_slice().FindReadOnly(op_args.db_cntx, key, OBJ_ZSET); + if (!it) + return it.status(); + + // Action::RANGE is a read-only operation, but requires const_cast + PrimeValue& pv = const_cast(it.value()->second); + + const std::size_t size = pv.Size(); + if (!size) { + return ScoredArray(); + } + + const std::size_t picks_count = + count >= 0 ? std::min(static_cast(count), size) : std::abs(count); + + ScoredArray result{picks_count}; + PicksArray picks = count >= 0 ? GenerateUniqueRandomPicks(picks_count, size) + : GenerateRandomPicks(picks_count, size); + + /* CASE 1: + * The number of requested elements (count) is significantly less than the total size. + * In this case, we generate random indexes, and search for the elements at this index (each + * search for O(log(size)). In total O(picks_count * log(size)). */ + if (picks_count * static_cast(std::log2(size)) < + size) { // convert to std::uint64_t to prevent overflow + for (std::size_t i = 0; i < picks_count; ++i) { + IntervalVisitor iv{Action::RANGE, params, &pv}; + iv(ZSetFamily::IndexInterval{picks[i], picks[i]}); + result[i] = iv.PopResult().front(); + } + } else { + /* CASE 2: + * The number of requested elements (count) does not differ much from the total size. + * In this case, we just traverse all elements and randomly add them to the result. + * In total O(size). */ + IntervalVisitor iv{Action::RANGE, params, &pv}; + iv(ZSetFamily::IndexInterval{0, -1}); + + ScoredArray all_elements = iv.PopResult(); + + for (std::size_t i = 0; i < picks_count; ++i) { + result[i] = all_elements[picks[i]]; + } + } + + return result; +} + void ZAddGeneric(string_view key, const ZParams& zparams, ScoredMemberSpan memb_sp, ConnectionContext* cntx) { auto cb = [&](Transaction* t, EngineShard* shard) { @@ -2323,16 +2423,14 @@ void ZSetFamily::ZRandMember(CmdArgList args, ConnectionContext* cntx) { if (args.size() > 3) return cntx->SendError(WrongNumArgsError("ZRANDMEMBER")); - ZRangeSpec range_spec; - range_spec.interval = IndexInterval(0, -1); - CmdArgParser parser{args}; string_view key = parser.Next(); bool is_count = parser.HasNext(); int count = is_count ? parser.Next() : 1; - range_spec.params.with_scores = static_cast(parser.Check("WITHSCORES").IgnoreCase()); + ZSetFamily::RangeParams params; + params.with_scores = static_cast(parser.Check("WITHSCORES").IgnoreCase()); if (parser.HasNext()) return cntx->SendError(absl::StrCat("Unsupported option:", string_view(parser.Next()))); @@ -2340,26 +2438,17 @@ void ZSetFamily::ZRandMember(CmdArgList args, ConnectionContext* cntx) { if (auto err = parser.Error(); err) return cntx->SendError(err->MakeReply()); - bool sign = count < 0; - range_spec.params.limit = std::abs(count); - const auto cb = [&](Transaction* t, EngineShard* shard) { - return OpRange(range_spec, t->GetOpArgs(shard), key); + return OpRandMember(count, params, t->GetOpArgs(shard), key); }; OpResult result = cntx->transaction->ScheduleSingleHopT(cb); auto* rb = static_cast(cntx->reply_builder()); if (result) { - if (sign && !result->empty()) { - for (auto i = result->size(); i < range_spec.params.limit; ++i) { - // we can return duplicate elements, so first is OK - result->push_back(result->front()); - } - } - rb->SendScoredArray(result.value(), range_spec.params.with_scores); + rb->SendScoredArray(result.value(), params.with_scores); } else if (result.status() == OpStatus::KEY_NOTFOUND) { if (is_count) { - rb->SendScoredArray(ScoredArray(), range_spec.params.with_scores); + rb->SendScoredArray(ScoredArray(), params.with_scores); } else { rb->SendNull(); } diff --git a/src/server/zset_family_test.cc b/src/server/zset_family_test.cc index 093c3e647a62..9354b9a6552a 100644 --- a/src/server/zset_family_test.cc +++ b/src/server/zset_family_test.cc @@ -77,53 +77,132 @@ TEST_F(ZSetFamilyTest, ZRem) { } TEST_F(ZSetFamilyTest, ZRandMember) { - auto resp = Run({ - "zadd", - "x", - "1", - "a", - "2", - "b", - "3", - "c", - }); + auto resp = Run({"ZAdd", "x", "1", "a", "2", "b", "3", "c"}); + EXPECT_THAT(resp, IntArg(3)); + + // Test if count > 0 resp = Run({"ZRandMember", "x"}); ASSERT_THAT(resp, ArgType(RespExpr::STRING)); - EXPECT_THAT(resp, "a"); + EXPECT_THAT(resp, AnyOf("a", "b", "c")); + + resp = Run({"ZRandMember", "x", "1"}); + ASSERT_THAT(resp, ArgType(RespExpr::STRING)); + EXPECT_THAT(resp, AnyOf("a", "b", "c")); resp = Run({"ZRandMember", "x", "2"}); - ASSERT_THAT(resp, ArgType(RespExpr::ARRAY)); - EXPECT_THAT(resp.GetVec(), UnorderedElementsAre("a", "b")); + ASSERT_THAT(resp, ArrLen(2)); + EXPECT_THAT(resp.GetVec(), IsSubsetOf({"a", "b", "c"})); - resp = Run({"ZRandMember", "x", "0"}); - ASSERT_THAT(resp, ArgType(RespExpr::ARRAY)); - EXPECT_EQ(resp.GetVec().size(), 0); + resp = Run({"ZRandMember", "x", "3"}); + ASSERT_THAT(resp, ArrLen(3)); + EXPECT_THAT(resp.GetVec(), UnorderedElementsAre("a", "b", "c")); - resp = Run({"ZRandMember", "k"}); - ASSERT_THAT(resp, ArgType(RespExpr::NIL)); + // Test if count < 0 + std::unordered_set expected_entries({"a", "b", "c"}); - resp = Run({"ZRandMember", "k", "2"}); - ASSERT_THAT(resp, ArgType(RespExpr::ARRAY)); - EXPECT_EQ(resp.GetVec().size(), 0); + auto expect_elements = [](const auto& expected_elements, const auto& actual_elements) { + for (const auto& x : actual_elements) { + if (!expected_elements.contains(x)) { + return false; + } + } + return true; + }; + + auto parse_response = [](const auto& resp) { + auto vec = resp.GetVec(); + + std::vector entries; + std::transform(vec.begin(), vec.end(), std::back_inserter(entries), + [](auto& x) { return x.GetString(); }); + return entries; + }; + + resp = Run({"ZRandMember", "x", "-1"}); + ASSERT_THAT(resp, ArgType(RespExpr::STRING)); + EXPECT_THAT(resp, AnyOf("a", "b", "c")); - resp = Run({"ZRandMember", "x", "-5"}); - ASSERT_THAT(resp, ArrLen(5)); - EXPECT_THAT(resp.GetVec(), ElementsAre("a", "b", "c", "a", "a")); + resp = Run({"ZRandMember", "x", "-2"}); + ASSERT_THAT(resp, ArrLen(2)); + EXPECT_TRUE(expect_elements(expected_entries, parse_response(resp))); - resp = Run({"ZRandMember", "x", "5"}); + resp = Run({"ZRandMember", "x", "-3"}); + ASSERT_THAT(resp, ArrLen(3)); + EXPECT_TRUE(expect_elements(expected_entries, parse_response(resp))); + + // Test if count < 0, but the absolute value is larger than the size of the sorted set + resp = Run({"ZRandMember", "x", "-15"}); + ASSERT_THAT(resp, ArrLen(15)); + EXPECT_TRUE(expect_elements(expected_entries, parse_response(resp))); + + // Test if count is 0 + ASSERT_THAT(Run({"ZRandMember", "x", "0"}), ArrLen(0)); + + // Test if count is larger than the size of the sorted set + resp = Run({"ZRandMember", "x", "15"}); ASSERT_THAT(resp, ArrLen(3)); EXPECT_THAT(resp.GetVec(), UnorderedElementsAre("a", "b", "c")); - resp = Run({"ZRandMember", "x", "-5", "WITHSCORES"}); - ASSERT_THAT(resp, ArrLen(10)); - EXPECT_THAT(resp.GetVec(), ElementsAre("a", "1", "b", "2", "c", "3", "a", "1", "a", "1")); + // Test if sorted set is empty + EXPECT_THAT(Run({"ZAdd", "empty::zset", "1", "one"}), IntArg(1)); + EXPECT_THAT(Run({"ZRem", "empty::zset", "one"}), IntArg(1)); + ASSERT_THAT(Run({"ZRandMember", "empty::zset", "0"}), ArrLen(0)); + ASSERT_THAT(Run({"ZRandMember", "empty::zset", "3"}), ArrLen(0)); + ASSERT_THAT(Run({"ZRandMember", "empty::zset", "-4"}), ArrLen(0)); + + // Test if key does not exist + ASSERT_THAT(Run({"ZRandMember", "y"}), ArgType(RespExpr::NIL)); + ASSERT_THAT(Run({"ZRandMember", "y", "0"}), ArrLen(0)); + + // Test WITHSCORES + using ZSetEntry = std::pair; + std::set expected_entries_with_scores{{"a", "1"}, {"b", "2"}, {"c", "3"}}; + + auto parse_response_with_scores = [](const auto& resp) { + auto vec = resp.GetVec(); + + std::vector entries; + for (std::size_t i = 1; i < vec.size(); i += 2) { + entries.emplace_back(vec[i - 1].GetString(), vec[i].GetString()); + } + return entries; + }; + + resp = Run({"ZRandMember", "x", "1", "WITHSCORES"}); + ASSERT_THAT(resp, ArrLen(2)); + EXPECT_THAT(parse_response_with_scores(resp), IsSubsetOf(expected_entries_with_scores)); + + resp = Run({"ZRandMember", "x", "2", "WITHSCORES"}); + ASSERT_THAT(resp, ArrLen(4)); + EXPECT_THAT(parse_response_with_scores(resp), IsSubsetOf(expected_entries_with_scores)); resp = Run({"ZRandMember", "x", "3", "WITHSCORES"}); ASSERT_THAT(resp, ArrLen(6)); - EXPECT_THAT(resp.GetVec(), UnorderedElementsAre("a", "1", "b", "2", "c", "3")); + EXPECT_THAT(parse_response_with_scores(resp), + UnorderedElementsAre(std::make_pair("a", "1"), std::make_pair("b", "2"), + std::make_pair("c", "3"))); - resp = Run({"ZRandMember", "x", "3", "WITHSCORES", "test"}); - EXPECT_THAT(resp, ErrArg("wrong number of arguments")); + resp = Run({"ZRandMember", "x", "15", "WITHSCORES"}); + ASSERT_THAT(resp, ArrLen(6)); + EXPECT_THAT(parse_response_with_scores(resp), + UnorderedElementsAre(std::make_pair("a", "1"), std::make_pair("b", "2"), + std::make_pair("c", "3"))); + + resp = Run({"ZRandMember", "x", "-1", "WITHSCORES"}); + ASSERT_THAT(resp, ArrLen(2)); + EXPECT_TRUE(expect_elements(expected_entries_with_scores, parse_response_with_scores(resp))); + + resp = Run({"ZRandMember", "x", "-2", "WITHSCORES"}); + ASSERT_THAT(resp, ArrLen(4)); + EXPECT_TRUE(expect_elements(expected_entries_with_scores, parse_response_with_scores(resp))); + + resp = Run({"ZRandMember", "x", "-3", "WITHSCORES"}); + ASSERT_THAT(resp, ArrLen(6)); + EXPECT_TRUE(expect_elements(expected_entries_with_scores, parse_response_with_scores(resp))); + + resp = Run({"ZRandMember", "x", "-15", "WITHSCORES"}); + ASSERT_THAT(resp, ArrLen(30)); + EXPECT_TRUE(expect_elements(expected_entries_with_scores, parse_response_with_scores(resp))); } TEST_F(ZSetFamilyTest, ZMScore) {