From a14616593e212cf3a711a8e61d696c5b8e6dff8a Mon Sep 17 00:00:00 2001
From: Elena Lundaeva
Date: Tue, 29 Aug 2023 14:46:48 +0200
Subject: [PATCH] custom node code

---
 haystack/nodes/retriever/filesim.py | 295 ++++++++++++++++++++++++++++
 1 file changed, 295 insertions(+)
 create mode 100644 haystack/nodes/retriever/filesim.py

diff --git a/haystack/nodes/retriever/filesim.py b/haystack/nodes/retriever/filesim.py
new file mode 100644
index 00000000000..ba4fcb27326
--- /dev/null
+++ b/haystack/nodes/retriever/filesim.py
@@ -0,0 +1,295 @@
+# pylint: disable=too-many-instance-attributes
+
+import math
+
+from typing import List, Optional, Dict, Union, Tuple
+import logging
+from collections import defaultdict
+from concurrent.futures import ThreadPoolExecutor, as_completed
+
+from haystack.nodes import BaseComponent, BaseRetriever, EmbeddingRetriever
+from haystack.document_stores import KeywordDocumentStore
+from haystack.schema import Document
+
+
+logger = logging.getLogger(__name__)
+
+
+class FileSimilarityRetriever(BaseComponent):
+    """
+    This retriever performs file similarity retrieval. It is a self-referential retriever: it uses the documents
+    of an existing file as queries and returns the most similar files, each represented by its best-matching
+    document and ordered by similarity to the query file. File similarity is determined with reciprocal rank
+    fusion: the retriever runs one retrieval per document of the query file and then aggregates the results of
+    all document queries. A similar approach is described in https://arxiv.org/pdf/2201.01614.pdf
+    (Althammer et al. 2022).
+    """
+
+    outgoing_edges = 1
+
+    # pylint: disable=too-many-arguments
+    def __init__(
+        self,
+        document_store: KeywordDocumentStore,
+        primary_retriever: Optional[BaseRetriever] = None,
+        secondary_retriever: Optional[BaseRetriever] = None,
+        file_aggregation_key: str = "file_id",
+        keep_original_score: Optional[str] = None,
+        top_k: int = 10,
+        max_query_len: int = 6000,
+        max_num_queries: Optional[int] = None,
+        use_existing_embedding: bool = True,
+    ) -> None:
+        """
+        Initialize an instance of FileSimilarityRetriever.
+
+        :param document_store: The document store that the retriever should retrieve from.
+        :param file_aggregation_key: The metadata key that should be used to aggregate documents to the file level.
+        :param primary_retriever: A crutch until Haystack supports passing a list of retrievers.
+        :param secondary_retriever: A crutch until Haystack supports passing a list of retrievers.
+        :param keep_original_score: The name of the meta field in which to store a returned document's original
+            score. If set, the document's score property is replaced with the reciprocal rank fusion score.
+        :param top_k: How many documents to return.
+        :param max_query_len: The maximum number of characters in a query document. Longer documents are cut off.
+            This is needed because overly long queries can break the BM25Retriever: an error is thrown if the
+            query exceeds the `max_clause_count` search setting
+            (https://www.elastic.co/guide/en/elasticsearch/reference/7.17/search-settings.html).
+        :param max_num_queries: The maximum number of queries to run for a single file. If the number of query
+            documents exceeds this limit, only every nth query document is kept, with n chosen so that at most
+            `max_num_queries` documents remain.
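+            For example (illustrative values): with 25 query documents and `max_num_queries=10`,
+            n = ceil(25 / 10) = 3, so every 3rd document is kept and 9 queries are run.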
+        :param use_existing_embedding: Whether to reuse the existing embeddings from the index. Set this
+            parameter to `True` to optimize the speed of file similarity retrieval; the FileSimilarityRetriever
+            can then run on the CPU.
+        """
+        super().__init__()
+        self.retrievers = []
+        if primary_retriever:
+            self.retrievers.append(primary_retriever)
+
+        if secondary_retriever:
+            self.retrievers.append(secondary_retriever)
+
+        self.file_aggregation_key = file_aggregation_key
+
+        self.document_store = document_store
+        self.keep_original_score = keep_original_score
+        self.top_k = top_k
+        self.max_query_len = max_query_len
+        self.max_num_queries = max_num_queries
+        self.use_existing_embedding = use_existing_embedding
+        # Use at least one worker so that initialization doesn't fail if no retriever was passed.
+        self.executor = ThreadPoolExecutor(max_workers=max(1, len(self.retrievers)))
+
+    # pylint: disable=arguments-renamed
+    def run(
+        self,
+        query: str,
+        top_k: Optional[int] = None,
+        indices: Optional[Union[str, List[Optional[str]]]] = None,
+        filters: Optional[Dict] = None,
+        file_index: Optional[str] = None,
+    ) -> Tuple[Dict, str]:
+        """
+        Performs file similarity retrieval using all retrievers that this node was initialized with.
+        The query should be the file aggregation key value that is used to fetch all documents of the query file
+        from the document_store.
+
+        :param query: Will be used to filter for all documents that belong to the file you want to use as the
+            query file.
+        :param top_k: The maximum number of documents to return.
+        :param indices: The document_store index or indices that the retrievers should retrieve from.
+        :param file_index: The index that the query file should be retrieved from.
+        :param filters: Filters that should be applied for each retriever.
+        """
+        retrieved_docs = self.retrieve(
+            query=query, top_k=top_k, indices=indices, file_index=file_index, filters=filters
+        )
+
+        return {"documents": retrieved_docs}, "output_1"
+
+    # pylint: disable=arguments-differ
+    def run_batch(
+        self,
+        queries: List[str],
+        top_k: Optional[int] = None,
+        indices: Optional[Union[str, List[Optional[str]]]] = None,
+        filters: Optional[Dict] = None,
+        file_index: Optional[str] = None,
+    ) -> Tuple[Dict, str]:
+        """Performs file similarity retrieval for multiple query files. See `retrieve` for the parameters."""
+        results = []
+        for query in queries:
+            results.append(
+                self.retrieve(query=query, top_k=top_k, indices=indices, filters=filters, file_index=file_index)
+            )
+
+        return {"documents": results}, "output_1"
+
+    def retrieve(
+        self,
+        query: str,
+        top_k: Optional[int] = None,
+        indices: Optional[Union[str, List[Optional[str]]]] = None,
+        file_index: Optional[str] = None,
+        filters: Optional[Dict] = None,
+    ) -> List[Document]:
+        """
+        Performs file similarity retrieval using all retrievers that this node was initialized with.
+        The query should be the file aggregation key value that is used to fetch all documents of the query file
+        from the document_store.
+
+        :param query: Will be used to filter for all documents that belong to the file you want to use as the
+            query file.
+        :param top_k: The maximum number of documents to return.
+        :param indices: The document_store index or indices that the retrievers should retrieve from.
+        :param file_index: The index that the query file should be retrieved from.
+        :param filters: Filters that should be applied for each retriever.
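+
+        Example (an illustrative sketch; assumes documents were indexed with a `file_id` meta field and that a
+        retriever such as a BM25Retriever was passed at init time):
+
+        ```python
+        retriever = FileSimilarityRetriever(document_store=document_store, primary_retriever=bm25_retriever)
+        similar_docs = retriever.retrieve(query="annual_report_2022.pdf", top_k=5)
+        ```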
+ """ + if isinstance(indices, str) or indices is None: + retriever_indices = [indices] * len(self.retrievers) + else: + retriever_indices = indices + + if not top_k: + top_k = self.top_k + + query_file_documents = self._get_documents(file_filter=query, index=file_index) + + if self.max_num_queries is not None and len(query_file_documents) > self.max_num_queries: + logger.warning( + "Query file %s has %s documents. " + "It exceeds the maximum number of query documents and was reduced to %s query documents.", + query, + len(query_file_documents), + self.max_num_queries, + ) + num_splits = math.ceil(len(query_file_documents) / self.max_num_queries) + query_file_documents = query_file_documents[::num_splits] + + if len(query_file_documents): + retrieved_docs = [] + futures = [] + for idx, retriever in zip(retriever_indices, self.retrievers): + if isinstance(retriever, EmbeddingRetriever) and all( + doc.embedding is not None for doc in query_file_documents + ): + future = self.executor.submit( + self._retrieve_for_documents_by_embedding, + retriever=retriever, + documents=query_file_documents, + index=idx, + filters=filters, + ) + else: + future = self.executor.submit( + self._retrieve_for_documents, + retriever=retriever, + documents=query_file_documents, + index=idx, + filters=filters, + ) + futures.append(future) + + for future in as_completed(futures): + retrieved_docs.extend(future.result()) + + aggregated_results = self._aggregate_results(results=retrieved_docs, top_k=top_k) + else: + logger.info("Could not find any indexed documents for query: %s.", query) + aggregated_results = [] + + return aggregated_results + + def _get_documents(self, file_filter: str, index: Optional[str]) -> List[Document]: + docs: List[Document] = self.document_store.get_all_documents( + index=index, + filters={self.file_aggregation_key: [file_filter]}, + return_embedding=self.use_existing_embedding, + ) + return docs + + def _retrieve_for_documents_by_embedding( + self, + retriever: EmbeddingRetriever, + documents: List[Document], + index: Optional[str] = None, + filters: Optional[Dict] = None, + ) -> List[List[Document]]: + doc_store = retriever.document_store + top_k = retriever.top_k + query_embs = [document.embedding for document in documents] + results: List[List[Document]] = doc_store.query_by_embedding_batch( + query_embs=query_embs, filters=filters, index=index, top_k=top_k + ) + return results + + def _retrieve_for_documents( + self, + retriever: BaseRetriever, + documents: List[Document], + index: Optional[str] = None, + filters: Optional[Dict] = None, + ) -> List[List[Document]]: + queries = [] + for doc in documents: + content = doc.content[: self.max_query_len] + queries.append(content) + if len(content) != len(doc.content): + logger.warning( + "Document %s retrieved with aggregation key '%s' exceeds max_query_len of %s and was cut off.", + doc, + self.file_aggregation_key, + self.max_query_len, + ) + + docs: List[List[Document]] = retriever.retrieve_batch( + queries=queries, index=index, filters=filters, scale_score=False + ) + + return docs + + def _aggregate_results(self, results: List[List[Document]], top_k: int) -> List[Document]: + # We iterate over each result list that contains for each query document and each retriever a list of documents + # ranked by their similarity to the query document. + # We group the result documents by the same file aggregation key and calculate the reciprocal rank fusion score. 
+        aggregator_doc_lookup = defaultdict(list)
+        aggregated_scores: Dict = defaultdict(int)
+        for result in results:
+            for idx, doc in enumerate(result):
+                aggregator = doc.meta.get(self.file_aggregation_key)
+                if aggregator is None:
+                    logger.warning(
+                        "Document %s cannot be aggregated. Missing aggregation key '%s' in meta.",
+                        doc,
+                        self.file_aggregation_key,
+                    )
+                else:
+                    score = self._calculate_reciprocal_rank_fusion(idx)
+                    aggregated_scores[aggregator] += score
+                    aggregator_doc_lookup[aggregator].append(doc)
+
+        # For each aggregated file, sort the retrieved documents by their score so that we can return the most
+        # relevant document for each aggregation later.
+        for aggregator in aggregator_doc_lookup:
+            aggregator_doc_lookup[aggregator] = sorted(
+                aggregator_doc_lookup[aggregator], key=lambda doc: doc.score, reverse=True  # type: ignore
+            )
+
+        sorted_aggregator_scores = sorted(aggregated_scores.items(), key=lambda d: d[1], reverse=True)  # type: ignore
+        result_docs = []
+        for aggregator_id, rrf_score in sorted_aggregator_scores[:top_k]:
+            doc = aggregator_doc_lookup[aggregator_id][0]
+            if self.keep_original_score:
+                doc.meta[self.keep_original_score] = doc.score
+            doc.score = rrf_score
+            result_docs.append(doc)
+
+        return result_docs
+
+    @staticmethod
+    def _calculate_reciprocal_rank_fusion(current_idx: int) -> float:
+        """
+        Calculates the reciprocal rank fusion score for a Document instance at the current rank.
+        See https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf
+
+        :param current_idx: The rank position of a document in the result list.
+        """
+        # The paper above uses a constant of 60 in the denominator, which is the sum of this constant and the
+        # rank of the retrieved document. We use 61 because ranks are passed in starting from 0 here, while the
+        # paper assumes ranks starting from 1.
+        reciprocal_rank_constant = 61
+        return 1 / (reciprocal_rank_constant + current_idx)
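+
+
+# Standalone usage sketch (illustrative only, not exercised anywhere in this patch; assumes an
+# Elasticsearch-backed store whose documents carry a "file_id" meta field):
+#
+#     from haystack import Pipeline
+#     from haystack.document_stores import ElasticsearchDocumentStore
+#     from haystack.nodes import BM25Retriever
+#
+#     document_store = ElasticsearchDocumentStore()
+#     file_sim = FileSimilarityRetriever(
+#         document_store=document_store,
+#         primary_retriever=BM25Retriever(document_store=document_store),
+#         keep_original_score="original_score",
+#     )
+#     pipeline = Pipeline()
+#     pipeline.add_node(component=file_sim, name="FileSimilarityRetriever", inputs=["Query"])
+#     result = pipeline.run(query="annual_report_2022.pdf")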