Retrieval Metrics: Updating HitRate and MRR for Evaluation@K documents retrieved. Also adding RR as a separate metric #12997
Merged
Commits (6):
62af5fd  Updating metrics: MRR renamed to RR, HitRate updated for multi-doc ev… (AgenP)
f302d83  Merge branch 'main' of https://github.com/AgenP/llama_index_fork (AgenP)
8ae408a  Updated MRR and HitRate with requested changes (AgenP)
f50c07a  Merge branch 'run-llama:main' into main (AgenP)
a54e2dd  Merge branch 'main' of https://github.com/AgenP/llama_index_fork into… (AgenP)
eff2806  Iteration w/ class attribute implementation for the calculation optio… (AgenP)
@@ -12,7 +12,13 @@


 class HitRate(BaseRetrievalMetric):
-    """Hit rate metric."""
+    """Hit rate metric: Compute hit rate with two calculation options.
+
+    - The default method checks for a single match between any of the retrieved docs and expected docs.
+    - The more granular method checks for all potential matches between retrieved docs and expected docs.
+
+    The granular compute method can be selected by inputting the 'use_granular_hit_rate' kwarg as True.
+    """

     metric_name: str = "hit_rate"

@@ -25,17 +31,55 @@ def compute(
         retrieved_texts: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> RetrievalMetricResult:
-        """Compute metric."""
-        if retrieved_ids is None or expected_ids is None:
+        """Compute metric based on the provided inputs.
+
+        Parameters:
+            query (Optional[str]): The query string (not used in the current implementation).
+            expected_ids (Optional[List[str]]): Expected document IDs.
+            retrieved_ids (Optional[List[str]]): Retrieved document IDs.
+            expected_texts (Optional[List[str]]): Expected texts (not used in the current implementation).
+            retrieved_texts (Optional[List[str]]): Retrieved texts (not used in the current implementation).
+            use_granular_hit_rate (bool): If True, use the granular hit rate calculation.
+
+        Raises:
+            ValueError: If the necessary IDs are not provided.
+
+        Returns:
+            RetrievalMetricResult: The result with the computed hit rate score.
+        """
+        # Checking for the required arguments
+        if (
+            retrieved_ids is None
+            or expected_ids is None
+            or not retrieved_ids
+            or not expected_ids
+        ):
             raise ValueError("Retrieved ids and expected ids must be provided")
-        is_hit = any(id in expected_ids for id in retrieved_ids)
-        return RetrievalMetricResult(
-            score=1.0 if is_hit else 0.0,
-        )
+
+        # Determining which implementation to use based on `use_granular_hit_rate` kwarg
+        use_granular = kwargs.get("use_granular_hit_rate", False)
+
+        if use_granular:
+            # Granular HitRate calculation: Calculate all hits and divide by the number of expected docs
+            expected_set = set(expected_ids)
+            hits = sum(1 for doc_id in retrieved_ids if doc_id in expected_set)
+            score = hits / len(expected_ids) if expected_ids else 0.0
+        else:
+            # Default HitRate calculation: Check if there is a single hit
+            is_hit = any(id in expected_ids for id in retrieved_ids)
+            score = 1.0 if is_hit else 0.0
+
+        return RetrievalMetricResult(score=score)


 class MRR(BaseRetrievalMetric):
-    """MRR metric."""
+    """MRR (Mean Reciprocal Rank) metric with two calculation options.
+
+    - The default method calculates the reciprocal rank of the first relevant (a.k.a expected) retrieved document.
+    - The more granular method sums the reciprocal ranks of all relevant retrieved documents and divides by the count of relevant retrieved documents.
+
+    The granular compute method can be selected by inputting the 'use_granular_mrr' kwarg as True.
+    """

     metric_name: str = "mrr"

@@ -48,17 +92,58 @@ def compute(
         retrieved_texts: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> RetrievalMetricResult:
-        """Compute metric."""
-        if retrieved_ids is None or expected_ids is None:
+        """Compute MRR based on the provided inputs and selected method.
+
+        Parameters:
+            query (Optional[str]): The query string (not used in the current implementation).
+            expected_ids (Optional[List[str]]): Expected document IDs.
+            retrieved_ids (Optional[List[str]]): Retrieved document IDs.
+            expected_texts (Optional[List[str]]): Expected texts (not used in the current implementation).
+            retrieved_texts (Optional[List[str]]): Retrieved texts (not used in the current implementation).
+            use_granular_mrr (bool): If True, use the granular MRR calculation.
+
+        Raises:
+            ValueError: If the necessary IDs are not provided.
+
+        Returns:
+            RetrievalMetricResult: The result with the computed MRR score.
+        """
+        # Checking for the required arguments
+        if (
+            retrieved_ids is None
+            or expected_ids is None
+            or not retrieved_ids
+            or not expected_ids
+        ):
             raise ValueError("Retrieved ids and expected ids must be provided")
-        for i, id in enumerate(retrieved_ids):
-            if id in expected_ids:
-                return RetrievalMetricResult(
-                    score=1.0 / (i + 1),
-                )
-        return RetrievalMetricResult(
-            score=0.0,
-        )
+
+        # Determining which implementation to use based on `use_granular_mrr` kwarg
+        use_granular_mrr = kwargs.get("use_granular_mrr", False)
+
+        if use_granular_mrr:
+            # Granular MRR calculation: All relevant retrieved docs have their reciprocal ranks summed and averaged
+            expected_set = set(expected_ids)
+            reciprocal_rank_sum = 0.0
+            relevant_docs_count = 0
+
+            for index, doc_id in enumerate(retrieved_ids):
+                if doc_id in expected_set:
+                    relevant_docs_count += 1
+                    reciprocal_rank_sum += 1.0 / (index + 1)
+
+            mrr_score = (
+                reciprocal_rank_sum / relevant_docs_count
+                if relevant_docs_count > 0
+                else 0.0
+            )
+        else:
+            # Default MRR calculation: Reciprocal rank of the first relevant document retrieved
+            for i, id in enumerate(retrieved_ids):
+                if id in expected_ids:
+                    return RetrievalMetricResult(score=1.0 / (i + 1))
+            mrr_score = 0.0
+
+        return RetrievalMetricResult(score=mrr_score)


 class CohereRerankRelevancyMetric(BaseRetrievalMetric):

Review comments on the line `use_granular_mrr = kwargs.get("use_granular_mrr", False)`: "same comment as above, maybe we define this as a class/instance attribute?" and "that way we still satisfy the superclass method signature".
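The two HitRate calculation modes in the diff above can be sketched as standalone functions, independent of the `BaseRetrievalMetric` machinery. This is a minimal re-implementation for illustration; the function names are mine, not from the PR.

```python
from typing import List


def hit_rate_default(expected_ids: List[str], retrieved_ids: List[str]) -> float:
    """Default mode: 1.0 if any retrieved doc matches an expected doc, else 0.0."""
    return 1.0 if any(doc_id in expected_ids for doc_id in retrieved_ids) else 0.0


def hit_rate_granular(expected_ids: List[str], retrieved_ids: List[str]) -> float:
    """Granular mode: count retrieved docs that are expected, divide by expected count."""
    expected_set = set(expected_ids)
    hits = sum(1 for doc_id in retrieved_ids if doc_id in expected_set)
    return hits / len(expected_ids) if expected_ids else 0.0


print(hit_rate_default(["id1", "id2"], ["id3", "id2"]))  # 1.0: at least one match
print(hit_rate_granular(["id1", "id2", "id3", "id4"], ["id1", "id5", "id2"]))  # 0.5
```

Note the asymmetry: the default mode is all-or-nothing, while the granular mode rewards retrieving more of the expected set, which is the point of extending HitRate to Evaluation@K with multiple expected documents.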
@@ -0,0 +1,88 @@
+from llama_index.core.evaluation.retrieval.metrics import HitRate, MRR
+import pytest
+
+
+# Test cases for the updated HitRate class
+@pytest.mark.parametrize(
+    ("expected_ids", "retrieved_ids", "use_granular", "expected_result"),
+    [
+        (["id1", "id2", "id3"], ["id3", "id1", "id2", "id4"], False, 1.0),
+        (["id1", "id2", "id3", "id4"], ["id1", "id5", "id2"], True, 2 / 4),
+        (["id1", "id2"], ["id3", "id4"], False, 0.0),
+        (["id1", "id2"], ["id2", "id1", "id7"], True, 2 / 2),
+    ],
+)
+def test_hit_rate(expected_ids, retrieved_ids, use_granular, expected_result):
+    hr = HitRate()
+    result = hr.compute(
+        expected_ids=expected_ids,
+        retrieved_ids=retrieved_ids,
+        use_granular_hit_rate=use_granular,
+    )
+    assert result.score == pytest.approx(expected_result)
+
+
+# Test cases for the updated MRR class
+@pytest.mark.parametrize(
+    ("expected_ids", "retrieved_ids", "use_granular", "expected_result"),
+    [
+        (["id1", "id2", "id3"], ["id3", "id1", "id2", "id4"], False, 1 / 1),
+        (["id1", "id2", "id3", "id4"], ["id5", "id1"], False, 1 / 2),
+        (["id1", "id2"], ["id3", "id4"], False, 0.0),
+        (["id1", "id2"], ["id2", "id1", "id7"], False, 1 / 1),
+        (
+            ["id1", "id2", "id3"],
+            ["id3", "id1", "id2", "id4"],
+            True,
+            (1 / 1 + 1 / 2 + 1 / 3) / 3,
+        ),
+        (
+            ["id1", "id2", "id3", "id4"],
+            ["id1", "id2", "id5"],
+            True,
+            (1 / 1 + 1 / 2) / 2,
+        ),
+        (["id1", "id2"], ["id1", "id7", "id15", "id2"], True, (1 / 1 + 1 / 4) / 2),
+    ],
+)
+def test_mrr(expected_ids, retrieved_ids, use_granular, expected_result):
+    mrr = MRR()
+    result = mrr.compute(
+        expected_ids=expected_ids,
+        retrieved_ids=retrieved_ids,
+        use_granular_mrr=use_granular,
+    )
+    assert result.score == pytest.approx(expected_result)
+
+
+# Test cases for exceptions handling for both HitRate and MRR
+@pytest.mark.parametrize(
+    ("expected_ids", "retrieved_ids", "use_granular"),
+    [
+        (
+            None,
+            ["id3", "id1", "id2", "id4"],
+            False,
+        ),  # None expected_ids should trigger ValueError
+        (
+            ["id1", "id2", "id3"],
+            None,
+            True,
+        ),  # None retrieved_ids should trigger ValueError
+        ([], [], False),  # Empty IDs should trigger ValueError
+    ],
+)
+def test_exceptions(expected_ids, retrieved_ids, use_granular):
+    with pytest.raises(ValueError):
+        hr = HitRate()
+        hr.compute(
+            expected_ids=expected_ids,
+            retrieved_ids=retrieved_ids,
+            use_granular_hit_rate=use_granular,
+        )
+        mrr = MRR()
+        mrr.compute(
+            expected_ids=expected_ids,
+            retrieved_ids=retrieved_ids,
+            use_granular_mrr=use_granular,
+        )
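The granular MRR expectations in the parametrized tests above come from averaging the reciprocal ranks of every relevant hit, not just the first. A quick standalone check of the arithmetic (the helper name is mine, for illustration only):

```python
from typing import List


def mrr_granular(expected_ids: List[str], retrieved_ids: List[str]) -> float:
    """Average the reciprocal ranks of all relevant retrieved docs."""
    expected_set = set(expected_ids)
    reciprocal_ranks = [
        1.0 / (rank + 1)
        for rank, doc_id in enumerate(retrieved_ids)
        if doc_id in expected_set
    ]
    return sum(reciprocal_ranks) / len(reciprocal_ranks) if reciprocal_ranks else 0.0


# id3 hits at rank 1, id1 at rank 2, id2 at rank 3 -> (1/1 + 1/2 + 1/3) / 3
score = mrr_granular(["id1", "id2", "id3"], ["id3", "id1", "id2", "id4"])
print(round(score, 4))  # 0.6111
```

This matches the `(1 / 1 + 1 / 2 + 1 / 3) / 3` value in the test table, so the expected results encode the "sum then divide by relevant-hit count" definition from the docstring.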
Review comment: instead of using `kwargs`, maybe we should just create a class or instance attribute `use_granular`?
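The attribute-based variant the reviewer suggests could look roughly like the sketch below. This is not the code merged in this PR (commit eff2806 mentions a later class-attribute iteration whose diff is not shown here); `BaseRetrievalMetric` and `RetrievalMetricResult` are simplified away to a bare class returning a float, so only the control flow is illustrative.

```python
from typing import Any, List, Optional


class MRR:  # sketch only; the real class derives from BaseRetrievalMetric
    metric_name: str = "mrr"
    use_granular_mrr: bool = False  # flag lives on the class/instance, not in **kwargs

    def compute(
        self,
        query: Optional[str] = None,
        expected_ids: Optional[List[str]] = None,
        retrieved_ids: Optional[List[str]] = None,
        expected_texts: Optional[List[str]] = None,
        retrieved_texts: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> float:
        if not retrieved_ids or not expected_ids:
            raise ValueError("Retrieved ids and expected ids must be provided")
        if self.use_granular_mrr:
            # Average the reciprocal ranks of all relevant retrieved docs
            expected_set = set(expected_ids)
            ranks = [
                1.0 / (i + 1)
                for i, doc_id in enumerate(retrieved_ids)
                if doc_id in expected_set
            ]
            return sum(ranks) / len(ranks) if ranks else 0.0
        # Default: reciprocal rank of the first relevant retrieved doc
        for i, doc_id in enumerate(retrieved_ids):
            if doc_id in expected_ids:
                return 1.0 / (i + 1)
        return 0.0
```

Because the mode flag is an attribute, `compute` keeps the superclass signature unchanged: callers select the mode once (`mrr = MRR(); mrr.use_granular_mrr = True`) instead of threading a kwarg through every call, which is the reviewer's stated motivation.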