custom node code
elundaeva committed Aug 29, 2023
1 parent 1f7c7b7 commit a146165
Showing 1 changed file with 295 additions and 0 deletions.
295 changes: 295 additions & 0 deletions haystack/nodes/retriever/filesim.py
# pylint: disable=too-many-instance-attributes

import math

from typing import List, Optional, Dict, Union, Tuple
import logging
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed

from haystack.nodes import BaseComponent, BaseRetriever, EmbeddingRetriever
from haystack.document_stores import KeywordDocumentStore
from haystack.schema import Document


logger = logging.getLogger(__name__)


class FileSimilarityRetriever(BaseComponent):
"""
    This retriever performs file similarity retrieval. It is a self-referential retriever: it uses the documents of
    an already indexed file as queries and returns the most similar documents, one per file, ordered by their files'
    similarity to the query file. File similarity is determined with reciprocal rank fusion: the retriever runs one
    retrieval per document of the query file and then aggregates the per-document results into file-level scores. A
    similar approach is described in Althammer et al. (2022): https://arxiv.org/pdf/2201.01614.pdf
"""

outgoing_edges = 1

# pylint: disable=too-many-arguments
def __init__(
self,
document_store: KeywordDocumentStore,
primary_retriever: Optional[BaseRetriever] = None,
secondary_retriever: Optional[BaseRetriever] = None,
file_aggregation_key: str = "file_id",
keep_original_score: Optional[str] = None,
top_k: int = 10,
max_query_len: int = 6000,
max_num_queries: Optional[int] = None,
use_existing_embedding: bool = True,
) -> None:
"""
Initialize an instance of FileSimilarityRetriever.
:param document_store: The document store that the retriever should retrieve from.
        :param file_aggregation_key: The metadata key that should be used to aggregate documents to the file level.
        :param primary_retriever: The first retriever to run (a crutch until haystack supports passing a list of
            retrievers).
        :param secondary_retriever: An optional second retriever to run (a crutch until haystack supports passing a
            list of retrievers).
        :param keep_original_score: Set this to a meta field name to store the original score of each returned
            document in that field. The document's score property will be replaced with the reciprocal rank fusion
            score.
:param top_k: How many documents to return.
        :param max_query_len: The maximum number of characters in a query document. Longer documents are cut off.
            This is needed because overly long queries are a problem for the BM25Retriever: an error is thrown if a
            query exceeds the `max_clause_count` search setting
            (https://www.elastic.co/guide/en/elasticsearch/reference/7.17/search-settings.html).
        :param max_num_queries: The maximum number of queries to run for a single file. If the number of query
            documents exceeds this limit, only every nth query document is kept, with n chosen so that at most
            max_num_queries query documents remain.
        :param use_existing_embedding: Whether to reuse the existing embeddings from the index.
            To optimize the speed of file similarity retrieval, set this parameter to `True`.
            This way the FileSimilarityRetriever can run on the CPU.
"""
super().__init__()
self.retrievers = []
if primary_retriever:
self.retrievers.append(primary_retriever)

if secondary_retriever:
self.retrievers.append(secondary_retriever)

self.file_aggregation_key = file_aggregation_key

self.document_store = document_store
self.keep_original_score = keep_original_score
self.top_k = top_k
self.max_query_len = max_query_len
self.max_num_queries = max_num_queries
self.use_existing_embedding = use_existing_embedding
        # ThreadPoolExecutor requires max_workers > 0, so guard against being initialized without any retrievers.
        self.executor = ThreadPoolExecutor(max_workers=max(1, len(self.retrievers)))

# pylint: disable=arguments-renamed
def run(
self,
query: str,
top_k: Optional[int] = None,
indices: Optional[Union[str, List[Union[str, None]]]] = None,
filters: Optional[Dict] = None,
file_index: Optional[str] = None,
) -> Tuple[Dict, str]:
"""
Performs file similarity retrieval using all retrievers that this node was initialized with.
        The query should be a value of the file aggregation key. It is used to fetch all documents that belong to
        the query file from the document_store.
        :param query: Used to filter for all documents that belong to the file you want to use as the query file.
:param top_k: The maximum number of documents to return.
:param indices: The document_store index or indices that the retrievers should retrieve from.
:param file_index: The index that the query file should be retrieved from.
:param filters: Filters that should be applied for each retriever.
"""
retrieved_docs = self.retrieve(
query=query, top_k=top_k, indices=indices, file_index=file_index, filters=filters
)

return {"documents": retrieved_docs}, "output_1"

# pylint: disable=arguments-differ
def run_batch(
self,
queries: List[str],
top_k: Optional[int] = None,
indices: Optional[Union[str, List[Union[str, None]]]] = None,
filters: Optional[Dict] = None,
file_index: Optional[str] = None,
) -> Tuple[Dict, str]:
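        """
        Performs file similarity retrieval for a batch of query files by calling `retrieve` once per query.
        See `retrieve` for a description of the parameters.
        """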
results = []
for query in queries:
results.append(
self.retrieve(query=query, top_k=top_k, indices=indices, filters=filters, file_index=file_index)
)

return {"documents": results}, "output_1"

def retrieve(
self,
query: str,
top_k: Optional[int] = None,
indices: Optional[Union[str, List[Optional[str]]]] = None,
file_index: Optional[str] = None,
filters: Optional[Dict] = None,
) -> List[Document]:
"""
Performs file similarity retrieval using all retrievers that this node was initialized with.
        The query should be a value of the file aggregation key. It is used to fetch all documents that belong to
        the query file from the document_store.
        :param query: Used to filter for all documents that belong to the file you want to use as the query file.
:param top_k: The maximum number of documents to return.
:param indices: The document_store index or indices that the retrievers should retrieve from.
:param file_index: The index that the query file should be retrieved from.
:param filters: Filters that should be applied for each retriever.
"""
if isinstance(indices, str) or indices is None:
retriever_indices = [indices] * len(self.retrievers)
else:
retriever_indices = indices

if not top_k:
top_k = self.top_k

query_file_documents = self._get_documents(file_filter=query, index=file_index)

if self.max_num_queries is not None and len(query_file_documents) > self.max_num_queries:
logger.warning(
"Query file %s has %s documents. "
"It exceeds the maximum number of query documents and was reduced to %s query documents.",
query,
len(query_file_documents),
self.max_num_queries,
)
num_splits = math.ceil(len(query_file_documents) / self.max_num_queries)
query_file_documents = query_file_documents[::num_splits]
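            # Example: 25 query documents with max_num_queries=10 give num_splits=3,
            # so every 3rd document is kept and 9 queries are run.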

if len(query_file_documents):
retrieved_docs = []
futures = []
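            # Fast path: if the retriever is embedding-based and all query documents already carry embeddings,
            # query the document store by embedding directly instead of re-embedding the query documents.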
for idx, retriever in zip(retriever_indices, self.retrievers):
if isinstance(retriever, EmbeddingRetriever) and all(
doc.embedding is not None for doc in query_file_documents
):
future = self.executor.submit(
self._retrieve_for_documents_by_embedding,
retriever=retriever,
documents=query_file_documents,
index=idx,
filters=filters,
)
else:
future = self.executor.submit(
self._retrieve_for_documents,
retriever=retriever,
documents=query_file_documents,
index=idx,
filters=filters,
)
futures.append(future)

for future in as_completed(futures):
retrieved_docs.extend(future.result())

aggregated_results = self._aggregate_results(results=retrieved_docs, top_k=top_k)
else:
logger.info("Could not find any indexed documents for query: %s.", query)
aggregated_results = []

return aggregated_results

def _get_documents(self, file_filter: str, index: Optional[str]) -> List[Document]:
docs: List[Document] = self.document_store.get_all_documents(
index=index,
filters={self.file_aggregation_key: [file_filter]},
return_embedding=self.use_existing_embedding,
)
return docs

def _retrieve_for_documents_by_embedding(
self,
retriever: EmbeddingRetriever,
documents: List[Document],
index: Optional[str] = None,
filters: Optional[Dict] = None,
) -> List[List[Document]]:
doc_store = retriever.document_store
top_k = retriever.top_k
query_embs = [document.embedding for document in documents]
results: List[List[Document]] = doc_store.query_by_embedding_batch(
query_embs=query_embs, filters=filters, index=index, top_k=top_k
)
return results

def _retrieve_for_documents(
self,
retriever: BaseRetriever,
documents: List[Document],
index: Optional[str] = None,
filters: Optional[Dict] = None,
) -> List[List[Document]]:
queries = []
for doc in documents:
content = doc.content[: self.max_query_len]
queries.append(content)
if len(content) != len(doc.content):
logger.warning(
"Document %s retrieved with aggregation key '%s' exceeds max_query_len of %s and was cut off.",
doc,
self.file_aggregation_key,
self.max_query_len,
)

docs: List[List[Document]] = retriever.retrieve_batch(
queries=queries, index=index, filters=filters, scale_score=False
)

return docs

def _aggregate_results(self, results: List[List[Document]], top_k: int) -> List[Document]:
        # Each result list contains, for one query document and one retriever, the retrieved documents ranked by
        # their similarity to that query document.
        # We group the retrieved documents by the file aggregation key and sum the reciprocal rank fusion scores.
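        # Worked example: if documents of file A appear at rank 0 in one result list and rank 2 in another,
        # file A's fused score is 1/61 + 1/63.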
aggregator_doc_lookup = defaultdict(list)
aggregated_scores: Dict = defaultdict(int)
for result in results:
for idx, doc in enumerate(result):
aggregator = doc.meta.get(self.file_aggregation_key)
if aggregator is None:
logger.warning(
"Document %s can not be aggregated. Missing aggregation key '%s' in meta.",
doc,
self.file_aggregation_key,
)
else:
score = self._calculate_reciprocal_rank_fusion(idx)
aggregated_scores[doc.meta[self.file_aggregation_key]] += score
aggregator_doc_lookup[doc.meta[self.file_aggregation_key]].append(doc)

# For each aggregated file we want to sort the retrieved documents by their score so that we
# can return the most relevant document for each aggregation later.
for aggregator in aggregator_doc_lookup:
aggregator_doc_lookup[aggregator] = sorted(
aggregator_doc_lookup[aggregator], key=lambda doc: doc.score, reverse=True # type: ignore
)

sorted_aggregator_scores = sorted(aggregated_scores.items(), key=lambda d: d[1], reverse=True) # type: ignore
result_docs = []
for aggregator_id, rrf_score in sorted_aggregator_scores[:top_k]:
doc = aggregator_doc_lookup[aggregator_id][0]
if self.keep_original_score:
doc.meta[self.keep_original_score] = doc.score
doc.score = rrf_score
result_docs.append(doc)

return result_docs

@staticmethod
def _calculate_reciprocal_rank_fusion(current_idx: int) -> float:
"""
Calculates the reciprocal rank score for a Document instance at the current rank.
See https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf
:param current_idx: The rank position of a document in the result list.
"""
        # The paper above uses a constant of 60 in the denominator, added to the (1-based) rank of the
        # retrieved document. We use 61 because the rank passed in here starts at 0 instead of 1.
reciprocal_rank_constant = 61
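        # Worked example: the top-ranked result (current_idx=0) contributes 1 / 61 ≈ 0.0164 to a file's
        # fused score, the second (current_idx=1) contributes 1 / 62 ≈ 0.0161, and so on.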
return 1 / (reciprocal_rank_constant + current_idx)
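

# A minimal usage sketch (hypothetical setup: the host, index name, file id, and embedding model below are
# illustrative assumptions, not part of this module). It assumes documents were indexed with a shared
# "file_id" meta field and with embeddings computed by the same EmbeddingRetriever.
if __name__ == "__main__":
    from haystack.document_stores import ElasticsearchDocumentStore

    document_store = ElasticsearchDocumentStore(host="localhost", index="documents")
    embedding_retriever = EmbeddingRetriever(
        document_store=document_store,
        embedding_model="sentence-transformers/all-MiniLM-L6-v2",
    )
    file_sim_retriever = FileSimilarityRetriever(
        document_store=document_store,
        primary_retriever=embedding_retriever,
        file_aggregation_key="file_id",
        keep_original_score="original_score",
        top_k=5,
    )
    # The query is a value of the aggregation key, i.e. the id of the file to find similar files for.
    similar_docs = file_sim_retriever.retrieve(query="some-file-id")
    for doc in similar_docs:
        print(doc.meta.get("file_id"), doc.score)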
