Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add new splitter to process QA type file(now only support JSON) and add Toggle button in knowledge_base page #3298

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 6 additions & 0 deletions configs/kb_config.py.example
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,10 @@ text_splitter_dict = {
"source": "huggingface", # 选择tiktoken则使用openai的方法
"tokenizer_name_or_path": "",
},
"QATextSplitter": {
"source": "huggingface", # 选择tiktoken则使用openai的方法
"tokenizer_name_or_path": "",
},
"SpacyTextSplitter": {
"source": "huggingface",
"tokenizer_name_or_path": "gpt2",
Expand All @@ -141,6 +145,8 @@ text_splitter_dict = {

# TEXT_SPLITTER 名称
TEXT_SPLITTER_NAME = "ChineseRecursiveTextSplitter"
# QA_SPLITTER 名称
QA_SPLITTER_NAME = "QATextSplitter"

# Embedding模型定制词语的词表文件
EMBEDDING_KEYWORD_FILE = "embedding_keywords.txt"
10 changes: 8 additions & 2 deletions server/knowledge_base/kb_doc_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from fastapi import File, Form, Body, Query, UploadFile
from configs import (DEFAULT_VS_TYPE, EMBEDDING_MODEL,
VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE,
CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE, QA_SPLITTER_NAME,
logger, log_verbose, )
from server.utils import BaseResponse, ListResponse, run_in_thread_pool
from server.knowledge_base.utils import (validate_kb_name, list_files_from_folder, get_file_path,
Expand Down Expand Up @@ -142,6 +142,7 @@ def upload_docs(
docs: Json = Form({}, description="自定义的docs,需要转为json字符串",
examples=[{"test.txt": [Document(page_content="custom doc")]}]),
not_refresh_vs_cache: bool = Form(False, description="暂不保存向量库(用于FAISS)"),
is_QA: bool = Form(False, description="是否是问答类型知识库,如果是,会启用自定义的QA分词器"),
) -> BaseResponse:
"""
API接口:上传文件,并/或向量化
Expand Down Expand Up @@ -176,6 +177,7 @@ def upload_docs(
zh_title_enhance=zh_title_enhance,
docs=docs,
not_refresh_vs_cache=True,
is_QA=is_QA
)
failed_files.update(result.data["failed_files"])
if not not_refresh_vs_cache:
Expand Down Expand Up @@ -244,6 +246,7 @@ def update_docs(
docs: Json = Body({}, description="自定义的docs,需要转为json字符串",
examples=[{"test.txt": [Document(page_content="custom doc")]}]),
not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"),
is_QA: bool = Body(False, description="是否是问答类型知识库,如果是,会启用自定义的QA分词器")
) -> BaseResponse:
"""
更新知识库文档
Expand All @@ -266,7 +269,10 @@ def update_docs(
continue
if file_name not in docs:
try:
kb_files.append(KnowledgeFile(filename=file_name, knowledge_base_name=knowledge_base_name))
if is_QA:
kb_files.append(KnowledgeFile(filename=file_name, knowledge_base_name=knowledge_base_name, text_splitter_name=QA_SPLITTER_NAME))
else:
kb_files.append(KnowledgeFile(filename=file_name, knowledge_base_name=knowledge_base_name))
except Exception as e:
msg = f"加载文档 {file_name} 时出错:{e}"
logger.error(f'{e.__class__.__name__}: {msg}',
Expand Down
3 changes: 2 additions & 1 deletion server/knowledge_base/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,7 @@ def __init__(
filename: str,
knowledge_base_name: str,
loader_kwargs: Dict = {},
text_splitter_name: str = TEXT_SPLITTER_NAME,
):
'''
对应知识库目录中的文件,必须是磁盘上存在的才能进行向量化等操作。
Expand All @@ -288,7 +289,7 @@ def __init__(
self.docs = None
self.splited_docs = None
self.document_loader_name = get_LoaderClass(self.ext)
self.text_splitter_name = TEXT_SPLITTER_NAME
self.text_splitter_name = text_splitter_name

def file2docs(self, refresh: bool = False):
if self.docs is None or refresh:
Expand Down
3 changes: 2 additions & 1 deletion text_splitter/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .chinese_text_splitter import ChineseTextSplitter
from .ali_text_splitter import AliTextSplitter
from .zh_title_enhance import zh_title_enhance
from .chinese_recursive_text_splitter import ChineseRecursiveTextSplitter
from .chinese_recursive_text_splitter import ChineseRecursiveTextSplitter
from .qa_text_splitter import QATextSplitter
25 changes: 25 additions & 0 deletions text_splitter/qa_text_splitter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from typing import List, Optional, Any
from langchain.text_splitter import TextSplitter

class QATextSplitter(TextSplitter):
"""Splitting QA file. Temporary only support json file."""
def __init__(
self,
keep_separator: bool = True,
**kwargs: Any,
) -> None:
"""Create a new TextSplitter."""
super().__init__(keep_separator=keep_separator, **kwargs)


def split_text(self, text: str) -> List[str]:
json_text = eval(text)

splits = []
for qa in json_text:
question = qa["问题"]
answer = qa["答案"]

splits.append(f"问题:{question}\n答案:{answer}")

return splits
47 changes: 35 additions & 12 deletions webui_pages/knowledge_base/knowledge_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,25 @@ def format_selected_kb(kb_name: str) -> str:
[i for ls in LOADER_DICT.values() for i in ls],
accept_multiple_files=True,
)

is_QA = st.toggle("是否是问答类型文档",
value=False,
help='''
问答类型文档会启用问答分词器,文件暂时仅支持json,具体请参考如下格式:
```json
[
{
"问题": "问题1",
"答案": "答案1"
},
{
"问题": "问题2",
"答案": "答案2"
}
]
```
''')

kb_info = st.text_area("请输入知识库介绍:", value=st.session_state["selected_kb_info"], max_chars=None,
key=None,
help=None, on_change=None, args=None, kwargs=None)
Expand All @@ -156,17 +175,20 @@ def format_selected_kb(kb_name: str) -> str:
st.session_state["selected_kb_info"] = kb_info
api.update_kb_info(kb, kb_info)

# with st.sidebar:
with st.expander(
"文件处理配置",
expanded=True,
):
cols = st.columns(3)
chunk_size = cols[0].number_input("单段文本最大长度:", 1, 1000, CHUNK_SIZE)
chunk_overlap = cols[1].number_input("相邻文本重合长度:", 0, chunk_size, OVERLAP_SIZE)
cols[2].write("")
cols[2].write("")
zh_title_enhance = cols[2].checkbox("开启中文标题加强", ZH_TITLE_ENHANCE)
chunk_size = CHUNK_SIZE
chunk_overlap = OVERLAP_SIZE
zh_title_enhance = ZH_TITLE_ENHANCE
if not is_QA:
with st.expander(
"文件处理配置",
expanded=True,
):
cols = st.columns(3)
chunk_size = cols[0].number_input("单段文本最大长度:", 1, 1000, CHUNK_SIZE)
chunk_overlap = cols[1].number_input("相邻文本重合长度:", 0, chunk_size, OVERLAP_SIZE)
cols[2].write("")
cols[2].write("")
zh_title_enhance = cols[2].checkbox("开启中文标题加强", ZH_TITLE_ENHANCE)

if st.button(
"添加文件到知识库",
Expand All @@ -178,7 +200,8 @@ def format_selected_kb(kb_name: str) -> str:
override=True,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
zh_title_enhance=zh_title_enhance)
zh_title_enhance=zh_title_enhance,
is_QA=is_QA)
if msg := check_success_msg(ret):
st.toast(msg, icon="✔")
elif msg := check_error_msg(ret):
Expand Down
2 changes: 2 additions & 0 deletions webui_pages/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,7 @@ def upload_kb_docs(
zh_title_enhance=ZH_TITLE_ENHANCE,
docs: Dict = {},
not_refresh_vs_cache: bool = False,
is_QA:bool = False
):
'''
对应api.py/knowledge_base/upload_docs接口
Expand All @@ -631,6 +632,7 @@ def convert_file(file, filename=None):
"zh_title_enhance": zh_title_enhance,
"docs": docs,
"not_refresh_vs_cache": not_refresh_vs_cache,
"is_QA":is_QA
}

if isinstance(data["docs"], dict):
Expand Down