Skip to content

Commit

Permalink
Merge pull request emrgnt-cmplxty#25 from EmergentAGI/feature/revive-…
Browse files Browse the repository at this point in the history
…symbol-search-tool

revive symbol search tool
  • Loading branch information
emrgnt-cmplxty committed Jun 16, 2023
2 parents fcafb7a + 1643d7e commit 0db821b
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 240 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
from typing import Callable, List, Optional, Union

from automata.core.base.tool import Tool
from automata.core.tools.search.symbol_searcher import (
from automata.core.symbol.search.symbol_search import (
ExactSearchResult,
SourceCodeResult,
SymbolRankResult,
SymbolReferencesResult,
SymbolSearcher,
SymbolSearch,
)


Expand All @@ -18,14 +18,14 @@ class SearchTool(Enum):
EXACT_SEARCH = "exact-search"


class SymbolSearcherToolManager:
class SymbolSearchToolManager:
def __init__(
self,
symbol_searcher: SymbolSearcher,
symbol_search: SymbolSearch,
search_tools: Optional[List[SearchTool]] = None,
post_processing: Optional[Callable] = None,
):
self.symbol_searcher = symbol_searcher
self.symbol_search = symbol_search
self.search_tools = search_tools or list(SearchTool)
self.post_processing = post_processing

Expand Down Expand Up @@ -65,21 +65,21 @@ def process_query(
# TODO - Cleanup these processors to ensure they behave well.
# -- Right now these are just simplest implementations I can rattle off
def _symbol_rank_search_processor(self, query: str) -> str:
query_result = self.symbol_searcher.symbol_rank_search(query)
query_result = self.symbol_search.symbol_rank_search(query)
return "\n".join([symbol.uri for symbol, _rank in query_result])

def _symbol_symbol_references_processor(self, query: str) -> str:
query_result = self.symbol_searcher.symbol_references(query)
query_result = self.symbol_search.symbol_references(query)
return "\n".join(
[f"{symbol}:{str(reference)}" for symbol, reference in query_result.items()]
)

def _retrieve_source_code_by_symbol_processor(self, query: str) -> str:
query_result = self.symbol_searcher.retrieve_source_code_by_symbol(query)
query_result = self.symbol_search.retrieve_source_code_by_symbol(query)
return query_result or "No Result Found"

def _exact_search_processor(self, query: str) -> str:
query_result = self.symbol_searcher.exact_search(query)
query_result = self.symbol_search.exact_search(query)
processed_result = "\n".join(
[f"{symbol}:{str(references)}" for symbol, references in query_result.items()]
)
Expand Down
44 changes: 34 additions & 10 deletions automata/core/agent/tool_management/tool_management_utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,21 @@
import importlib
import logging
import os
from typing import Dict, List

# from automata.core.search.symbol_factory import SymbolSearcherFactory
from automata.config.config_types import ConfigCategory
from automata.core.agent.tool_management.base_tool_manager import BaseToolManager
from automata.core.base.tool import Tool, Toolkit, ToolkitType
from automata.core.coding.py_coding.retriever import PyCodeRetriever
from automata.core.coding.py_coding.writer import PyCodeWriter
from automata.core.database.vector import JSONVectorDatabase
from automata.core.embedding.code_embedding import SymbolCodeEmbeddingHandler
from automata.core.embedding.embedding_types import OpenAIEmbedding
from automata.core.embedding.symbol_similarity import SymbolSimilarity
from automata.core.symbol.graph import SymbolGraph
from automata.core.symbol.search.rank import SymbolRankConfig
from automata.core.symbol.search.symbol_search import SymbolSearch
from automata.core.utils import config_fpath

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -40,15 +49,30 @@ def create_tool_manager(toolkit_type: ToolkitType) -> BaseToolManager:
return PyCodeWriterToolManager(
python_writer=PyCodeWriter(ToolManagerFactory._retriever_instance)
)
# elif toolkit_type == ToolkitType.SYMBOL_SEARCHER:
# SymbolSearcherToolManager = importlib.import_module(
# "automata.core.agent.tool_management.symbol_searcher_tool_manager"
# ).SymbolSearcherToolManager
# return SymbolSearcherToolManager(
# symbol_searcher=SymbolSearcherFactory().create(
# index_name="index.scip", symbol_embedding_name="symbol_embedding.json"
# )
# )
elif toolkit_type == ToolkitType.SYMBOL_SEARCHER:
SymbolSearchToolManager = importlib.import_module(
"automata.core.agent.tool_management.symbol_search_manager"
).SymbolSearchToolManager

graph = SymbolGraph()
subgraph = graph.get_rankable_symbol_subgraph()

code_embedding_fpath = os.path.join(
config_fpath(), ConfigCategory.SYMBOL.value, "symbol_code_embedding.json"
)
code_embedding_db = JSONVectorDatabase(code_embedding_fpath)
code_embedding_handler = SymbolCodeEmbeddingHandler(
code_embedding_db, OpenAIEmbedding()
)

symbol_similarity = SymbolSimilarity(code_embedding_handler)
symbol_search = SymbolSearch(
graph,
symbol_similarity,
symbol_rank_config=SymbolRankConfig(),
code_subgraph=subgraph,
)
return SymbolSearchToolManager(symbol_search=symbol_search)
else:
raise ValueError("Unknown toolkit type: %s" % toolkit_type)

Expand Down
Empty file removed automata/core/tools/__init__.py
Empty file.
142 changes: 0 additions & 142 deletions automata/core/tools/search/symbol_searcher.py

This file was deleted.

79 changes: 0 additions & 79 deletions automata/tests/unit/test_symbol_searcher_tool_manager.py

This file was deleted.

0 comments on commit 0db821b

Please sign in to comment.