Skip to content

Commit

Permalink
Merge branch 'master' into duckduckgo
Browse files Browse the repository at this point in the history
  • Loading branch information
Appointat committed May 8, 2024
2 parents d966760 + e2af8d9 commit c779c42
Show file tree
Hide file tree
Showing 16 changed files with 453 additions and 132 deletions.
1 change: 1 addition & 0 deletions .github/workflows/build_package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,5 @@ jobs:
SEARCH_ENGINE_ID: "${{ secrets.SEARCH_ENGINE_ID }}"
OPENWEATHERMAP_API_KEY: "${{ secrets.OPENWEATHERMAP_API_KEY }}"
ANTHROPIC_API_KEY: "${{ secrets.ANTHROPIC_API_KEY }}"
COHERE_API_KEY: "${{ secrets.COHERE_API_KEY }}"
run: pytest --fast-test-mode ./test
2 changes: 2 additions & 0 deletions .github/workflows/pytest_apps.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ jobs:
OPENAI_API_KEY: "${{ secrets.OPENAI_API_KEY }}"
GOOGLE_API_KEY: "${{ secrets.GOOGLE_API_KEY }}"
SEARCH_ENGINE_ID: "${{ secrets.SEARCH_ENGINE_ID }}"
COHERE_API_KEY: "${{ secrets.COHERE_API_KEY }}"
run: poetry run pytest -v apps/

pytest_examples:
Expand All @@ -45,4 +46,5 @@ jobs:
OPENAI_API_KEY: "${{ secrets.OPENAI_API_KEY }}"
GOOGLE_API_KEY: "${{ secrets.GOOGLE_API_KEY }}"
SEARCH_ENGINE_ID: "${{ secrets.SEARCH_ENGINE_ID }}"
COHERE_API_KEY: "${{ secrets.COHERE_API_KEY }}"
run: poetry run pytest -v examples/
3 changes: 3 additions & 0 deletions .github/workflows/pytest_package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ jobs:
SEARCH_ENGINE_ID: "${{ secrets.SEARCH_ENGINE_ID }}"
OPENWEATHERMAP_API_KEY: "${{ secrets.OPENWEATHERMAP_API_KEY }}"
ANTHROPIC_API_KEY: "${{ secrets.ANTHROPIC_API_KEY }}"
COHERE_API_KEY: "${{ secrets.COHERE_API_KEY }}"
run: poetry run pytest --fast-test-mode test/

pytest_package_llm_test:
Expand All @@ -45,6 +46,7 @@ jobs:
SEARCH_ENGINE_ID: "${{ secrets.SEARCH_ENGINE_ID }}"
OPENWEATHERMAP_API_KEY: "${{ secrets.OPENWEATHERMAP_API_KEY }}"
ANTHROPIC_API_KEY: "${{ secrets.ANTHROPIC_API_KEY }}"
COHERE_API_KEY: "${{ secrets.COHERE_API_KEY }}"
run: poetry run pytest --llm-test-only test/

pytest_package_very_slow_test:
Expand All @@ -62,4 +64,5 @@ jobs:
SEARCH_ENGINE_ID: "${{ secrets.SEARCH_ENGINE_ID }}"
OPENWEATHERMAP_API_KEY: "${{ secrets.OPENWEATHERMAP_API_KEY }}"
ANTHROPIC_API_KEY: "${{ secrets.ANTHROPIC_API_KEY }}"
COHERE_API_KEY: "${{ secrets.COHERE_API_KEY }}"
run: poetry run pytest --very-slow-test-only test/
2 changes: 2 additions & 0 deletions camel/retrievers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@
from .auto_retriever import AutoRetriever
from .base import BaseRetriever
from .bm25_retriever import BM25Retriever
from .cohere_rerank_retriever import CohereRerankRetriever
from .vector_retriever import VectorRetriever

__all__ = [
'BaseRetriever',
'VectorRetriever',
'AutoRetriever',
'BM25Retriever',
'CohereRerankRetriever',
]
15 changes: 11 additions & 4 deletions camel/retrievers/auto_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,11 +278,18 @@ def run_vector_retriever(
# Clear the vector storage
vector_storage_instance.clear()
# Process and store the content to the vector storage
vr.process(content_input_path, vector_storage_instance)
vr = VectorRetriever(
storage=vector_storage_instance,
similarity_threshold=similarity_threshold,
)
vr.process(content_input_path)
else:
vr = VectorRetriever(
storage=vector_storage_instance,
similarity_threshold=similarity_threshold,
)
# Retrieve info by given query from the vector storage
retrieved_info = vr.query(
query, vector_storage_instance, top_k, similarity_threshold
)
retrieved_info = vr.query(query, top_k)
# Reorganize the retrieved info with original query
for info in retrieved_info:
retrieved_infos += "\n" + str(info)
Expand Down
79 changes: 42 additions & 37 deletions camel/retrievers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,53 +12,58 @@
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
from abc import ABC, abstractmethod
from typing import Any, Dict, List
from typing import Any, Callable

DEFAULT_TOP_K_RESULTS = 1


class BaseRetriever(ABC):
r"""Abstract base class for implementing various types of information
retrievers.
def _query_unimplemented(self, *input: Any) -> None:
r"""Defines the query behavior performed at every call.
Query the results. Subclasses should implement this
method according to their specific needs.
It should be overridden by all subclasses.
.. note::
Although the recipe for forward pass needs to be defined within
this function, one should call the :class:`BaseRetriever` instance
afterwards instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
"""
raise NotImplementedError(
f"Retriever [{type(self).__name__}] is missing the required \"query\" function"
)

@abstractmethod
def __init__(self) -> None:
pass

@abstractmethod
def process(
self,
content_input_path: str,
chunk_type: str = "chunk_by_title",
**kwargs: Any,
) -> None:
r"""Processes content from a file or URL, divides it into chunks by
def _process_unimplemented(self, *input: Any) -> None:
r"""Defines the process behavior performed at every call.
Processes content from a file or URL, divides it into chunks by
using `Unstructured IO`,then stored internally. This method must be
called before executing queries with the retriever.
Args:
content_input_path (str): File path or URL of the content to be
processed.
chunk_type (str): Type of chunking going to apply. Defaults to
"chunk_by_title".
**kwargs (Any): Additional keyword arguments for content parsing.
"""
pass
Should be overridden by all subclasses.
@abstractmethod
def query(
self, query: str, top_k: int = DEFAULT_TOP_K_RESULTS, **kwargs: Any
) -> List[Dict[str, Any]]:
r"""Query the results. Subclasses should implement this
method according to their specific needs.
.. note::
Although the recipe for forward pass needs to be defined within
this function, one should call the :class:`BaseRetriever` instance
afterwards instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
"""
raise NotImplementedError(
f"Retriever [{type(self).__name__}] is missing the required \"process\" function"
)


class BaseRetriever(ABC):
r"""Abstract base class for implementing various types of information
retrievers.
"""

Args:
query (str): Query string for information retriever.
top_k (int, optional): The number of top results to return during
retriever. Must be a positive integer. Defaults to
`DEFAULT_TOP_K_RESULTS`.
**kwargs (Any): Flexible keyword arguments for additional
parameters, like `similarity_threshold`.
"""
@abstractmethod
def __init__(self) -> None:
pass

process: Callable[..., Any] = _process_unimplemented
query: Callable[..., Any] = _query_unimplemented
29 changes: 10 additions & 19 deletions camel/retrievers/bm25_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ class BM25Retriever(BaseRetriever):
calculating document scores.
content_input_path (str): The path to the content that has been
processed and stored.
chunks (List[Any]): A list of document chunks processed from the
input content.
unstructured_modules (UnstructuredIO): A module for parsing files and
URLs and chunking content based on specified parameters.
References:
https://github.com/dorianbrown/rank_bm25
Expand All @@ -47,13 +47,12 @@ def __init__(self) -> None:
from rank_bm25 import BM25Okapi
except ImportError as e:
raise ImportError(
"Package `rank_bm25` not installed, install by running"
" 'pip install rank_bm25'"
"Package `rank_bm25` not installed, install by running 'pip install rank_bm25'"
) from e

self.bm25: BM25Okapi = None
self.content_input_path: str = ""
self.chunks: List[Any] = []
self.unstructured_modules: UnstructuredIO = UnstructuredIO()

def process(
self,
Expand All @@ -76,19 +75,18 @@ def process(

# Load and preprocess documents
self.content_input_path = content_input_path
unstructured_modules = UnstructuredIO()
elements = unstructured_modules.parse_file_or_url(
elements = self.unstructured_modules.parse_file_or_url(
content_input_path, **kwargs
)
self.chunks = unstructured_modules.chunk_elements(
self.chunks = self.unstructured_modules.chunk_elements(
chunk_type=chunk_type, elements=elements
)

# Convert chunks to a list of strings for tokenization
tokenized_corpus = [str(chunk).split(" ") for chunk in self.chunks]
self.bm25 = BM25Okapi(tokenized_corpus)

def query( # type: ignore[override]
def query(
self,
query: str,
top_k: int = DEFAULT_TOP_K_RESULTS,
Expand All @@ -106,22 +104,15 @@ def query( # type: ignore[override]
Raises:
ValueError: If `top_k` is less than or equal to 0, if the BM25
model has not been initialized by calling `process_and_store`
model has not been initialized by calling `process`
first.
Note:
`storage` and `kwargs` parameters are included to maintain
compatibility with the `BaseRetriever` interface but are not used
in this implementation.
"""

if top_k <= 0:
raise ValueError("top_k must be a positive integer.")

if self.bm25 is None:
if self.bm25 is None or not self.chunks:
raise ValueError(
"BM25 model is not initialized. Call `process_and_store`"
" first."
"BM25 model is not initialized. Call `process` first."
)

# Preprocess query similarly to how documents were processed
Expand Down
108 changes: 108 additions & 0 deletions camel/retrievers/cohere_rerank_retriever.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the “License”);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an “AS IS” BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
import os
from typing import Any, Dict, List, Optional

from camel.retrievers import BaseRetriever

DEFAULT_TOP_K_RESULTS = 1


class CohereRerankRetriever(BaseRetriever):
    r"""An implementation of the `BaseRetriever` using the `Cohere Re-ranking`
    model.

    Attributes:
        model_name (str): The model name to use for re-ranking.
        api_key (Optional[str]): The API key for authenticating with the
            Cohere service.
        co: The `cohere.Client` instance built from `api_key`.

    References:
        https://txt.cohere.com/rerank/
    """

    def __init__(
        self,
        model_name: str = "rerank-multilingual-v2.0",
        api_key: Optional[str] = None,
    ) -> None:
        r"""Initializes an instance of the CohereRerankRetriever. This
        constructor sets up a client for interacting with the Cohere API using
        the specified model name and API key. If the API key is not provided,
        it attempts to retrieve it from the COHERE_API_KEY environment
        variable.

        Args:
            model_name (str): The name of the model to be used for re-ranking.
                Defaults to 'rerank-multilingual-v2.0'.
            api_key (Optional[str]): The API key for authenticating requests
                to the Cohere API. If not provided, the method will attempt to
                retrieve the key from the environment variable
                'COHERE_API_KEY'.

        Raises:
            ImportError: If the 'cohere' package is not installed.
            ValueError: If the API key is neither passed as an argument nor
                set in the environment variable.
        """
        try:
            import cohere
        except ImportError as e:
            raise ImportError("Package 'cohere' is not installed") from e

        # A missing key in `os.environ[...]` raises KeyError, so the former
        # `except ValueError` could never fire. Look the key up explicitly
        # and raise the documented ValueError ourselves.
        resolved_key = api_key or os.environ.get("COHERE_API_KEY")
        if not resolved_key:
            raise ValueError(
                "Must pass in cohere api key or specify via "
                "COHERE_API_KEY environment variable."
            )

        self.api_key = resolved_key
        self.co = cohere.Client(self.api_key)
        self.model_name = model_name

    def query(
        self,
        query: str,
        retrieved_result: List[Dict[str, Any]],
        top_k: int = DEFAULT_TOP_K_RESULTS,
    ) -> List[Dict[str, Any]]:
        r"""Queries and compiles results using the Cohere re-ranking model.

        Args:
            query (str): Query string for information retriever.
            retrieved_result (List[Dict[str, Any]]): The content to be
                re-ranked, should be the output from `BaseRetriever` like
                `VectorRetriever`.
            top_k (int, optional): The number of top results to return during
                retriever. Must be a positive integer. Defaults to
                `DEFAULT_TOP_K_RESULTS`.

        Returns:
            List[Dict[str, Any]]: The re-ranked chunks, in the order returned
                by the re-ranking model. Each entry is a shallow copy of the
                matching input chunk with its 'similarity score' key set to
                the relevance score reported by Cohere.

        Raises:
            ValueError: If `top_k` is less than or equal to 0.
        """
        if top_k <= 0:
            raise ValueError("top_k must be a positive integer.")

        # Nothing to re-rank: skip the remote call entirely.
        if not retrieved_result:
            return []

        rerank_results = self.co.rerank(
            query=query,
            documents=retrieved_result,
            top_n=top_k,
            model=self.model_name,
        )

        formatted_results: List[Dict[str, Any]] = []
        # Iterate the `results` list itself; the response object is not
        # guaranteed to support integer indexing.
        for ranked in rerank_results.results:
            # Copy so the caller's input dictionaries are left untouched.
            selected_chunk = dict(retrieved_result[ranked.index])
            selected_chunk['similarity score'] = ranked.relevance_score
            formatted_results.append(selected_chunk)
        return formatted_results

0 comments on commit c779c42

Please sign in to comment.