Skip to content

Commit

Permalink
FileSimilarityRetriever - Add readme, fix mypy linter issues (#6821)
Browse files: browse the repository at this point in the history
* Add readme, fix mypy linter issues

* fix black formatting

* fix more mypy issues

* fix black formatting

* fix mypy again

* try to fix mypy linter again

* try fixing mypy again
  • Loading branch information
augchan42 committed Jan 24, 2024
1 parent a146165 commit 59b3266
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 22 deletions.
7 changes: 5 additions & 2 deletions haystack/modeling/data_handler/data_silo.py
Original file line number Diff line number Diff line change
Expand Up @@ -816,13 +816,16 @@ def _get_dataset(self, filename: Optional[Union[str, Path]], dicts: Optional[Lis
batch.append(x)
corresponding_chunks.append(i)
if len(batch) == self.teacher_batch_size:
# Convert the list of tuples to a list of lists
data_as_list_of_lists = [list(item) for item in batch]
self._pass_batches(
batch, corresponding_chunks, teacher_outputs, tensor_names
data_as_list_of_lists, corresponding_chunks, teacher_outputs, tensor_names
) # doing forward pass on teacher model
batch = []
corresponding_chunks = []
if batch:
self._pass_batches(batch, corresponding_chunks, teacher_outputs, tensor_names)
data_as_list_of_lists = [list(item) for item in batch]
self._pass_batches(data_as_list_of_lists, corresponding_chunks, teacher_outputs, tensor_names)

# appending teacher outputs to original dataset
for dataset, teacher_output in zip(concat_datasets.datasets, teacher_outputs):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
import os
from typing import Optional, Dict, Union, List, Any, Callable
from typing import cast, Generator, Optional, Dict, Union, List, Any, Callable
import logging

import requests
Expand Down Expand Up @@ -191,7 +191,8 @@ def _process_streaming_response(
:param stream_handler: The handler to invoke on each token.
:param stop_words: The stop words to ignore.
"""
client = sseclient.SSEClient(response)
byte_stream_generator = cast(Generator[bytes, None, None], response.iter_content(chunk_size=1))
client = sseclient.SSEClient(byte_stream_generator)
tokens: List[str] = []
try:
for event in client.events():
Expand Down
78 changes: 60 additions & 18 deletions haystack/nodes/retriever/filesim.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@

import math

from typing import List, Optional, Dict, Union, Tuple
from typing import Any, List, Optional, Dict, Union, Tuple
import logging
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed

from haystack.nodes import BaseComponent, BaseRetriever, EmbeddingRetriever
from haystack.document_stores import KeywordDocumentStore
from haystack.schema import Document
from haystack.schema import Document, MultiLabel


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -77,15 +77,18 @@ def __init__(
self.use_existing_embedding = use_existing_embedding
self.executor = ThreadPoolExecutor(max_workers=len(self.retrievers))

# pylint: disable=arguments-renamed
def run(
self,
query: str,
top_k: Optional[int] = None,
indices: Optional[Union[str, List[Union[str, None]]]] = None,
filters: Optional[Dict] = None,
file_index: Optional[str] = None,
) -> Tuple[Dict, str]:
query: Union[str, List[str], None] = None,
file_paths: Optional[List[str]] = None,
labels: Optional[MultiLabel] = None,
documents: Optional[List[Document]] = None,
meta: Optional[Dict[Any, Any]] = None,
top_k: Optional[int] = None, # Additional parameter with a default value
indices: Optional[Union[str, List[Union[str, None]]]] = None, # Additional parameter with a default value
filters: Optional[Dict[Any, Any]] = None, # Additional parameter with a default value
file_index: Optional[str] = None, # Additional parameter with a default value
) -> Tuple[Dict[Any, Any], str]:
"""
Performs file similarity retrieval using all retrievers that this node was initialized with.
The query should be the file aggregator value that will be used to get all relevant documents from the
Expand All @@ -97,23 +100,48 @@ def run(
:param file_index: The index that the query file should be retrieved from.
:param filters: Filters that should be applied for each retriever.
"""
retrieved_docs = self.retrieve(
query=query, top_k=top_k, indices=indices, file_index=file_index, filters=filters
)
if isinstance(query, list):
# Handle the case where query is a list
# For example, take the first element or concatenate
query = query[0] # or your own logic to handle a list

if isinstance(indices, list):
# Convert indices to the simpler type if necessary
indices = [index for index in indices if index is not None] # Remove None values

if query is None:
raise ValueError("Query cannot be None.")

if query is not None:
retrieved_docs = self.retrieve(
query=query, top_k=top_k, indices=indices, file_index=file_index, filters=filters
)

return {"documents": retrieved_docs}, "output_1"

# pylint: disable=arguments-differ
def run_batch(
self,
queries: List[str],
queries: Union[str, List[str], None] = None,
file_paths: Optional[List[str]] = None,
labels: Optional[Union[MultiLabel, List[MultiLabel]]] = None,
documents: Optional[Union[List[Document], List[List[Document]]]] = None,
meta: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
params: Optional[Dict[Any, Any]] = None,
debug: Optional[bool] = None,
# Additional parameters with default values
top_k: Optional[int] = None,
indices: Optional[Union[str, List[Union[str, None]]]] = None,
filters: Optional[Dict] = None,
filters: Optional[Dict[Any, Any]] = None,
file_index: Optional[str] = None,
) -> Tuple[Dict, str]:
) -> Any:
# Convert complex types to simpler types
if queries is not None:
simple_queries = [
q[0] if isinstance(q, list) else q for q in queries
] # Assuming you want the first query if it's a list of lists

results = []
for query in queries:
for query in simple_queries:
results.append(
self.retrieve(query=query, top_k=top_k, indices=indices, filters=filters, file_index=file_index)
)
Expand Down Expand Up @@ -210,8 +238,22 @@ def _retrieve_for_documents_by_embedding(
filters: Optional[Dict] = None,
) -> List[List[Document]]:
doc_store = retriever.document_store
if doc_store is None:
raise ValueError("Document store cannot be None")

top_k = retriever.top_k
query_embs = [document.embedding for document in documents]

# Filter out documents where the embedding is None
valid_documents = [doc for doc in documents if doc.embedding is not None]
if not valid_documents:
raise ValueError("No valid document embeddings found for query.")

# Filter out documents where the embedding is None and create the query_embs list
# redundant check to pass mypy linter
query_embs = [doc.embedding for doc in valid_documents if doc.embedding is not None]
if not query_embs:
raise ValueError("All document embeddings are None")

results: List[List[Document]] = doc_store.query_by_embedding_batch(
query_embs=query_embs, filters=filters, index=index, top_k=top_k
)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
features:
- |
Add FileSimilarityRetriever to Haystack.

0 comments on commit 59b3266

Please sign in to comment.