diff --git a/haystack/modeling/data_handler/data_silo.py b/haystack/modeling/data_handler/data_silo.py index c9db096ade2..88151c8c10b 100644 --- a/haystack/modeling/data_handler/data_silo.py +++ b/haystack/modeling/data_handler/data_silo.py @@ -816,13 +816,16 @@ def _get_dataset(self, filename: Optional[Union[str, Path]], dicts: Optional[Lis batch.append(x) corresponding_chunks.append(i) if len(batch) == self.teacher_batch_size: + # Convert the list of tuples to a list of lists + data_as_list_of_lists = [list(item) for item in batch] self._pass_batches( - batch, corresponding_chunks, teacher_outputs, tensor_names + data_as_list_of_lists, corresponding_chunks, teacher_outputs, tensor_names ) # doing forward pass on teacher model batch = [] corresponding_chunks = [] if batch: - self._pass_batches(batch, corresponding_chunks, teacher_outputs, tensor_names) + data_as_list_of_lists = [list(item) for item in batch] + self._pass_batches(data_as_list_of_lists, corresponding_chunks, teacher_outputs, tensor_names) # appending teacher outputs to original dataset for dataset, teacher_output in zip(concat_datasets.datasets, teacher_outputs): diff --git a/haystack/nodes/prompt/invocation_layer/hugging_face_inference.py b/haystack/nodes/prompt/invocation_layer/hugging_face_inference.py index 239f5265268..9f793080f69 100644 --- a/haystack/nodes/prompt/invocation_layer/hugging_face_inference.py +++ b/haystack/nodes/prompt/invocation_layer/hugging_face_inference.py @@ -1,6 +1,6 @@ import json import os -from typing import Optional, Dict, Union, List, Any, Callable +from typing import cast, Generator, Optional, Dict, Union, List, Any, Callable import logging import requests @@ -191,7 +191,8 @@ def _process_streaming_response( :param stream_handler: The handler to invoke on each token. :param stop_words: The stop words to ignore. 
""" - client = sseclient.SSEClient(response) + byte_stream_generator = cast(Generator[bytes, None, None], response.iter_content(chunk_size=1)) + client = sseclient.SSEClient(byte_stream_generator) tokens: List[str] = [] try: for event in client.events(): diff --git a/haystack/nodes/retriever/filesim.py b/haystack/nodes/retriever/filesim.py index ba4fcb27326..5a0682f3cce 100644 --- a/haystack/nodes/retriever/filesim.py +++ b/haystack/nodes/retriever/filesim.py @@ -2,14 +2,14 @@ import math -from typing import List, Optional, Dict, Union, Tuple +from typing import Any, List, Optional, Dict, Union, Tuple import logging from collections import defaultdict from concurrent.futures import ThreadPoolExecutor, as_completed from haystack.nodes import BaseComponent, BaseRetriever, EmbeddingRetriever from haystack.document_stores import KeywordDocumentStore -from haystack.schema import Document +from haystack.schema import Document, MultiLabel logger = logging.getLogger(__name__) @@ -77,15 +77,18 @@ def __init__( self.use_existing_embedding = use_existing_embedding self.executor = ThreadPoolExecutor(max_workers=len(self.retrievers)) - # pylint: disable=arguments-renamed def run( self, - query: str, - top_k: Optional[int] = None, - indices: Optional[Union[str, List[Union[str, None]]]] = None, - filters: Optional[Dict] = None, - file_index: Optional[str] = None, - ) -> Tuple[Dict, str]: + query: Union[str, List[str], None] = None, + file_paths: Optional[List[str]] = None, + labels: Optional[MultiLabel] = None, + documents: Optional[List[Document]] = None, + meta: Optional[Dict[Any, Any]] = None, + top_k: Optional[int] = None, # Additional parameter with a default value + indices: Optional[Union[str, List[Union[str, None]]]] = None, # Additional parameter with a default value + filters: Optional[Dict[Any, Any]] = None, # Additional parameter with a default value + file_index: Optional[str] = None, # Additional parameter with a default value + ) -> Tuple[Dict[Any, Any], str]: """ 
Performs file similarity retrieval using all retrievers that this node was initialized with. The query should be the file aggregator value that will be used to get all relevant documents from the @@ -97,23 +100,48 @@ def run( :param file_index: The index that the query file should be retrieved from. :param filters: Filters that should be applied for each retriever. """ - retrieved_docs = self.retrieve( - query=query, top_k=top_k, indices=indices, file_index=file_index, filters=filters - ) + if isinstance(query, list): + # Handle the case where query is a list + # For example, take the first element or concatenate + query = query[0] # or your own logic to handle a list + + if isinstance(indices, list): + # Convert indices to the simpler type if necessary + indices = [index for index in indices if index is not None] # Remove None values + + if query is None: + raise ValueError("Query cannot be None.") + + if query is not None: + retrieved_docs = self.retrieve( + query=query, top_k=top_k, indices=indices, file_index=file_index, filters=filters + ) return {"documents": retrieved_docs}, "output_1" - # pylint: disable=arguments-differ def run_batch( self, - queries: List[str], + queries: Union[str, List[str], None] = None, + file_paths: Optional[List[str]] = None, + labels: Optional[Union[MultiLabel, List[MultiLabel]]] = None, + documents: Optional[Union[List[Document], List[List[Document]]]] = None, + meta: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None, + params: Optional[Dict[Any, Any]] = None, + debug: Optional[bool] = None, + # Additional parameters with default values top_k: Optional[int] = None, indices: Optional[Union[str, List[Union[str, None]]]] = None, - filters: Optional[Dict] = None, + filters: Optional[Dict[Any, Any]] = None, file_index: Optional[str] = None, - ) -> Tuple[Dict, str]: + ) -> Any: + # Convert complex types to simpler types + if queries is not None: + simple_queries = [ + q[0] if isinstance(q, list) else q for q in queries + ] # 
If an element is itself a list, keep only its first entry + + results = [] - for query in queries: + for query in simple_queries: results.append( self.retrieve(query=query, top_k=top_k, indices=indices, filters=filters, file_index=file_index) ) @@ -210,8 +238,22 @@ def _retrieve_for_documents_by_embedding( filters: Optional[Dict] = None, ) -> List[List[Document]]: doc_store = retriever.document_store + if doc_store is None: + raise ValueError("Document store cannot be None") + top_k = retriever.top_k - query_embs = [document.embedding for document in documents] + + # Filter out documents where the embedding is None + valid_documents = [doc for doc in documents if doc.embedding is not None] + if not valid_documents: + raise ValueError("No valid document embeddings found for query.") + + # Build query_embs from the valid documents + # the inner None-check is redundant at runtime but satisfies mypy's type narrowing + query_embs = [doc.embedding for doc in valid_documents if doc.embedding is not None] + if not query_embs: + raise ValueError("All document embeddings are None") + results: List[List[Document]] = doc_store.query_by_embedding_batch( query_embs=query_embs, filters=filters, index=index, top_k=top_k ) diff --git a/releasenotes/notes/file-similarity-retriever-394427b2241e1396.yaml b/releasenotes/notes/file-similarity-retriever-394427b2241e1396.yaml new file mode 100644 index 00000000000..ec0999513aa --- /dev/null +++ b/releasenotes/notes/file-similarity-retriever-394427b2241e1396.yaml @@ -0,0 +1,4 @@ +--- +features: + - | + Added the FileSimilarityRetriever node to Haystack. \ No newline at end of file