Skip to content

Commit

Permalink
integrate into rag
Browse files Browse the repository at this point in the history
  • Loading branch information
fanzhidongyzby committed May 14, 2024
1 parent f7f39a1 commit 12da5b3
Show file tree
Hide file tree
Showing 23 changed files with 560 additions and 157 deletions.
15 changes: 13 additions & 2 deletions .env.template
Original file line number Diff line number Diff line change
Expand Up @@ -145,10 +145,12 @@ DENYLISTED_PLUGINS=


#*******************************************************************#
#** VECTOR STORE SETTINGS **#
#** VECTOR STORE / KNOWLEDGE GRAPH SETTINGS **#
#*******************************************************************#
### Chroma vector db config
VECTOR_STORE_TYPE=Chroma
GRAPH_STORE_TYPE=TuGraph

### Chroma vector db config
#CHROMA_PERSIST_PATH=/root/DB-GPT/pilot/data

### Milvus vector db config
Expand All @@ -163,6 +165,15 @@ VECTOR_STORE_TYPE=Chroma
#VECTOR_STORE_TYPE=Weaviate
#WEAVIATE_URL=https://kt-region-m8hcy0wc.weaviate.network

### TuGraph config
#TUGRAPH_HOST=127.0.0.1
#TUGRAPH_PORT=7070
#TUGRAPH_USERNAME=admin
#TUGRAPH_PASSWORD=73@TuGraph
#TUGRAPH_VERTEX_TYPE=entity
#TUGRAPH_EDGE_TYPE=relation
#TUGRAPH_EDGE_NAME_KEY=label

#*******************************************************************#
#** WebServer Language Support **#
#*******************************************************************#
Expand Down
10 changes: 5 additions & 5 deletions dbgpt/app/knowledge/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def space_delete(request: KnowledgeSpaceRequest):
try:
return Result.succ(knowledge_space_service.delete_space(request.name))
except Exception as e:
return Result.failed(code="E000X", msg=f"space list error {e}")
return Result.failed(code="E000X", msg=f"space delete error {e}")


@router.post("/knowledge/{space_name}/arguments")
Expand All @@ -84,7 +84,7 @@ def arguments(space_name: str):
try:
return Result.succ(knowledge_space_service.arguments(space_name))
except Exception as e:
return Result.failed(code="E000X", msg=f"space list error {e}")
return Result.failed(code="E000X", msg=f"space arguments error {e}")


@router.post("/knowledge/{space_name}/argument/save")
Expand All @@ -95,7 +95,7 @@ def arguments_save(space_name: str, argument_request: SpaceArgumentRequest):
knowledge_space_service.argument_save(space_name, argument_request)
)
except Exception as e:
return Result.failed(code="E000X", msg=f"space list error {e}")
return Result.failed(code="E000X", msg=f"space save error {e}")


@router.post("/knowledge/{space_name}/document/add")
Expand Down Expand Up @@ -164,7 +164,7 @@ def document_delete(space_name: str, query_request: DocumentQueryRequest):
knowledge_space_service.delete_document(space_name, query_request.doc_name)
)
except Exception as e:
return Result.failed(code="E000X", msg=f"document list error {e}")
return Result.failed(code="E000X", msg=f"document delete error {e}")


@router.post("/knowledge/{space_name}/document/upload")
Expand Down Expand Up @@ -248,7 +248,7 @@ def batch_document_sync(
# )
return Result.succ({"tasks": doc_ids})
except Exception as e:
return Result.failed(code="E000X", msg=f"document sync error {e}")
return Result.failed(code="E000X", msg=f"document sync batch error {e}")


@router.post("/knowledge/{space_name}/chunk/list")
Expand Down
61 changes: 45 additions & 16 deletions dbgpt/app/knowledge/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@
)
from dbgpt.component import ComponentType
from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG
from dbgpt.core import Chunk
from dbgpt.core import Chunk, LLMClient
from dbgpt.model import DefaultLLMClient
from dbgpt.model.cluster import WorkerManagerFactory
from dbgpt.rag.assembler.embedding import EmbeddingAssembler
from dbgpt.rag.assembler.summary import SummaryAssembler
from dbgpt.rag.chunk_manager import ChunkParameters
Expand All @@ -38,8 +39,9 @@
SpacyTextSplitter,
)
from dbgpt.serve.rag.api.schemas import KnowledgeSyncRequest
from dbgpt.serve.rag.models.models import KnowledgeSpaceDao, KnowledgeSpaceEntity
from dbgpt.serve.rag.service.service import Service, SyncStatus
from dbgpt.serve.rag.models.models import KnowledgeSpaceDao, \
KnowledgeSpaceEntity
from dbgpt.serve.rag.service.service import SyncStatus
from dbgpt.storage.vector_store.base import VectorStoreConfig
from dbgpt.storage.vector_store.connector import VectorStoreConnector
from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async
Expand Down Expand Up @@ -70,6 +72,13 @@ class KnowledgeService:
def __init__(self):
pass

@property
def llm_client(self) -> LLMClient:
worker_manager = CFG.SYSTEM_APP.get_component(
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create()
return DefaultLLMClient(worker_manager, True)

def create_knowledge_space(self, request: KnowledgeSpaceRequest):
"""create knowledge space
Args:
Expand Down Expand Up @@ -332,16 +341,23 @@ def _sync_knowledge_document(
embedding_fn = embedding_factory.create(
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
)

spaces = self.get_knowledge_space(KnowledgeSpaceRequest(name=space_name))
if len(spaces) != 1:
raise Exception(f"invalid space name:{space_name}")
space = spaces[0]

from dbgpt.storage.vector_store.base import VectorStoreConfig

config = VectorStoreConfig(
name=space_name,
name=space.name,
embedding_fn=embedding_fn,
max_chunks_once_load=CFG.KNOWLEDGE_MAX_CHUNKS_ONCE_LOAD,
llm_client=self.llm_client
)
vector_store_connector = VectorStoreConnector(
vector_store_type=CFG.VECTOR_STORE_TYPE,
vector_store_config=config,
vector_store_type=space.vector_type,
vector_store_config=config
)
knowledge = KnowledgeFactory.create(
datasource=doc.content,
Expand Down Expand Up @@ -442,21 +458,25 @@ def delete_space(self, space_name: str):
Args:
- space_name: knowledge space name
"""
query = KnowledgeSpaceEntity(name=space_name)
spaces = knowledge_space_dao.get_knowledge_space(query)
if len(spaces) == 0:
raise Exception(f"delete error, no space name:{space_name} in database")
spaces = knowledge_space_dao.get_knowledge_space(KnowledgeSpaceEntity(name=space_name))
if len(spaces) != 1:
raise Exception(f"invalid space name:{space_name}")
space = spaces[0]

embedding_factory = CFG.SYSTEM_APP.get_component(
"embedding_factory", EmbeddingFactory
)
embedding_fn = embedding_factory.create(
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
)
config = VectorStoreConfig(name=space.name, embedding_fn=embedding_fn)
config = VectorStoreConfig(
name=space.name,
embedding_fn=embedding_fn,
llm_client=self.llm_client
)
vector_store_connector = VectorStoreConnector(
vector_store_type=CFG.VECTOR_STORE_TYPE,
vector_store_config=config,
vector_store_type=space.vector_type,
vector_store_config=config
)
# delete vectors
vector_store_connector.delete_vector_name(space.name)
Expand All @@ -480,12 +500,21 @@ def delete_document(self, space_name: str, doc_name: str):
documents = knowledge_document_dao.get_documents(document_query)
if len(documents) != 1:
raise Exception(f"there are no or more than one document called {doc_name}")

spaces = self.get_knowledge_space(KnowledgeSpaceRequest(name=space_name))
if len(spaces) != 1:
raise Exception(f"invalid space name:{space_name}")
space = spaces[0]

vector_ids = documents[0].vector_ids
if vector_ids is not None:
config = VectorStoreConfig(name=space_name)
config = VectorStoreConfig(
name=space.name,
llm_client=self.llm_client
)
vector_store_connector = VectorStoreConnector(
vector_store_type=CFG.VECTOR_STORE_TYPE,
vector_store_config=config,
vector_store_type=space.vector_type,
vector_store_config=config
)
# delete vector by ids
vector_store_connector.delete_by_ids(vector_ids)
Expand Down
20 changes: 16 additions & 4 deletions dbgpt/app/scene/chat_knowledge/v1/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,23 @@ def __init__(self, chat_param: Dict):
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
)
from dbgpt.storage.vector_store.base import VectorStoreConfig

config = VectorStoreConfig(name=self.knowledge_space, embedding_fn=embedding_fn)
from dbgpt.serve.rag.models.models import KnowledgeSpaceDao
from dbgpt.serve.rag.models.models import KnowledgeSpaceEntity

spaces = KnowledgeSpaceDao().get_knowledge_space(KnowledgeSpaceEntity(name=self.knowledge_space))
if len(spaces) != 1:
raise Exception(f"invalid space name:{self.knowledge_space}")
space = spaces[0]

config = VectorStoreConfig(
name=space.name,
embedding_fn=embedding_fn,
llm_client=self.llm_client,
llm_model=self.llm_model
)
vector_store_connector = VectorStoreConnector(
vector_store_type=CFG.VECTOR_STORE_TYPE,
vector_store_config=config,
vector_store_type=space.vector_type,
vector_store_config=config
)
query_rewrite = None
if CFG.KNOWLEDGE_SEARCH_REWRITE:
Expand Down
52 changes: 50 additions & 2 deletions dbgpt/rag/index/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,46 @@
import time
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from typing import List, Optional
from typing import List, Optional, Dict, Any

from dbgpt.core import Chunk
from dbgpt._private.pydantic import BaseModel, ConfigDict, Field, model_to_dict
from dbgpt.core import Chunk, Embeddings
from dbgpt.storage.vector_store.filters import MetadataFilters

logger = logging.getLogger(__name__)


class IndexStoreConfig(BaseModel):
"""Index store config."""

model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")

name: str = Field(
default="dbgpt_collection",
description="The name of index store, if not set, will use the default name.",
)
embedding_fn: Optional[Embeddings] = Field(
default=None,
description="The embedding function of vector store, if not set, will use the "
"default embedding function.",
)
max_chunks_once_load: int = Field(
default=10,
description="The max number of chunks to load at once. If your document is "
"large, you can set this value to a larger number to speed up the loading "
"process. Default is 10.",
)
max_threads: int = Field(
default=1,
description="The max number of threads to use. Default is 1. If you set this "
"bigger than 1, please make sure your vector store is thread-safe.",
)

def to_dict(self, **kwargs) -> Dict[str, Any]:
"""Convert to dict."""
return model_to_dict(self, **kwargs)


class IndexStoreBase(ABC):
"""Index store base class."""

Expand Down Expand Up @@ -45,6 +77,22 @@ def similar_search_with_scores(
List[Chunk]: The similar documents.
"""

@abstractmethod
def delete_by_ids(self, ids: str):
"""Delete docs.
Args:
ids(str): The vector ids to delete, separated by comma.
"""

@abstractmethod
def delete_vector_name(self, index_name: str):
"""Delete index by name.
Args:
index_name(str): The name of index to delete.
"""

def load_document_with_limit(
self, chunks: List[Chunk], max_chunks_once_load: int = 10, max_threads: int = 1
) -> List[str]:
Expand Down
5 changes: 0 additions & 5 deletions dbgpt/rag/transformer/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,6 @@
class TransformerBase(ABC):
"""Transformer base class."""

@abstractmethod
def transform(self, data):
"""Transform the input data and return the result."""
pass


class EmbedderBase(TransformerBase, ABC):
"""Embedder base class."""
Expand Down
1 change: 0 additions & 1 deletion dbgpt/rag/transformer/keyword_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ class KeywordExtractor(LLMExtractor):
"""KeywordExtractor class."""

def __init__(self, llm_client: LLMClient, model_name: str):
"""Initialize the KeywordExtractor with a LLM client and a specific model."""
super().__init__(llm_client, model_name, KEYWORD_EXTRACT_PT)

def _parse_response(self, text: str, limit: int) -> List[str]:
Expand Down
26 changes: 20 additions & 6 deletions dbgpt/rag/transformer/llm_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,41 @@
from abc import ABC, abstractmethod
from typing import List

from dbgpt.core import HumanPromptTemplate, LLMClient, ModelMessage, ModelRequest
from dbgpt.core import HumanPromptTemplate, LLMClient, ModelMessage, \
ModelRequest
from dbgpt.rag.transformer.base import ExtractorBase

logger = logging.getLogger(__name__)
limit_num = 10


class LLMExtractor(ExtractorBase, ABC):
"""LLMExtractor class."""

def __init__(self, llm_client: LLMClient, model_name: str, prompt_template: str):
"""Initialize the LLMExtractor with a LLM client and a specific model."""
def __init__(
self,
llm_client: LLMClient,
model_name: str,
prompt_template: str
):
self._llm_client = llm_client
self._model_name = model_name
self._prompt_template = prompt_template

async def extract(self, text: str, limit: int = limit_num) -> List:
"""Extract keywords from text using the configured model."""
async def extract(
self,
text: str,
limit: int = None
) -> List:
template = HumanPromptTemplate.from_template(self._prompt_template)
messages = template.format_messages(text=text)

# use default model if needed
if not self._model_name:
models = await self._llm_client.models()
if not models:
raise Exception("No models available")
self._model_name = models[0].model

model_messages = ModelMessage.from_base_messages(messages)
request = ModelRequest(model=self._model_name, messages=model_messages)
response = await self._llm_client.generate(request=request)
Expand Down
1 change: 0 additions & 1 deletion dbgpt/rag/transformer/triplet_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ class TripletExtractor(LLMExtractor):
"""TripletExtractor class."""

def __init__(self, llm_client: LLMClient, model_name: str):
"""Initialize the TripletExtractor with a LLM client and a specific model."""
super().__init__(llm_client, model_name, TRIPLET_EXTRACT_PT)

def _parse_response(self, text: str, limit: int) -> List[Tuple[Any, ...]]:
Expand Down

0 comments on commit 12da5b3

Please sign in to comment.