Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initialize ATOM with NER #3656

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,5 @@ httpx==0.26.0
httpx_sse==0.4.0
watchdog==3.0.0
pyjwt==2.8.0

modelscope==1.13.3
1 change: 1 addition & 0 deletions requirements_api.txt
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ httpx==0.26.0
httpx_sse==0.4.0
llama-index==0.9.35
pyjwt==2.8.0
modelscope==1.13.3

# jq==1.6.0
# beautifulsoup4~=4.12.2
Expand Down
1 change: 1 addition & 0 deletions requirements_webui.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ streamlit-aggrid==0.3.4.post3
httpx==0.26.0
httpx_sse==0.4.0
watchdog==3.0.0
modelscope==1.13.3
10 changes: 7 additions & 3 deletions server/chat/knowledge_base_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
RERANKER_MODEL,
RERANKER_MAX_LENGTH,
MODEL_PATH)
from server.utils import wrap_done, get_ChatOpenAI
from server.utils import wrap_done, get_ChatOpenAI, recognize_with_ner, get_search_query
from server.utils import BaseResponse, get_prompt_template
from langchain.chains import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
Expand Down Expand Up @@ -41,6 +41,7 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
{"role": "assistant",
"content": "虎头虎脑"}]]
),
enable_ner: bool = Body(False, description="是否开启命名实体识别"),
stream: bool = Body(False, description="流式输出"),
model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
Expand All @@ -64,6 +65,7 @@ async def knowledge_base_chat_iterator(
query: str,
top_k: int,
history: Optional[List[History]],
enable_ner: bool,
model_name: str = model_name,
prompt_name: str = prompt_name,
) -> AsyncIterable[str]:
Expand All @@ -78,8 +80,10 @@ async def knowledge_base_chat_iterator(
max_tokens=max_tokens,
callbacks=[callback],
)
if enable_ner is True:
search_query = get_search_query(query)
docs = await run_in_threadpool(search_docs,
query=query,
query=search_query,
knowledge_base_name=knowledge_base_name,
top_k=top_k,
score_threshold=score_threshold)
Expand Down Expand Up @@ -143,5 +147,5 @@ async def knowledge_base_chat_iterator(
ensure_ascii=False)
await task

return EventSourceResponse(knowledge_base_chat_iterator(query, top_k, history,model_name,prompt_name))
return EventSourceResponse(knowledge_base_chat_iterator(query, top_k, history, enable_ner, model_name, prompt_name))

26 changes: 26 additions & 0 deletions server/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,3 +683,29 @@ def get_temp_dir(id: str = None) -> Tuple[str, str]:

path = tempfile.mkdtemp(dir=BASE_TEMP_DIR)
return path, os.path.basename(path)


def get_search_query(query):
    """Build a structured knowledge-base search query from the named
    entities recognized in *query*.

    Args:
        query: Raw user question text.

    Returns:
        A string of the form ``"疾病:<d1,d2>;症状:<s1,s2>"`` built from the
        diseases ('dis') and symptoms ('sym') returned by
        :func:`recognize_with_ner`.  If the NER model finds neither, the
        original *query* is returned unchanged — an all-empty template
        ("疾病:;症状:") would only degrade retrieval compared to the raw
        question.
    """
    result = recognize_with_ner(query)
    diseases, symptoms = result['dis'], result['sym']
    # Fall back to the raw query when NER produced nothing useful.
    if not diseases and not symptoms:
        return query
    return "疾病:{};症状:{}".format(','.join(diseases), ','.join(symptoms))


def recognize_with_ner(query):
    """Run Chinese medical NER over *query* and collect disease/symptom spans.

    Uses the modelscope RaNER model trained on the CMeEE corpus.  The
    pipeline is created lazily on first use and cached on the function
    object, because building it loads the full model (disk/network) and
    doing so per request would dominate latency.

    Args:
        query: Text to analyze.

    Returns:
        Dict with keys 'dis' (diseases) and 'sym' (symptoms), each a list
        of unique entity span strings (order not guaranteed).
    """
    # Imported lazily so the heavy modelscope dependency is only required
    # when NER is actually enabled.
    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    ner_pipeline = getattr(recognize_with_ner, "_pipeline", None)
    if ner_pipeline is None:
        ner_pipeline = pipeline(
            Tasks.named_entity_recognition,
            'damo/nlp_raner_named-entity-recognition_chinese-base-cmeee'
        )
        # Cache for subsequent calls; avoids reloading the model each time.
        recognize_with_ner._pipeline = ner_pipeline

    result = ner_pipeline(query)
    # NOTE(review): modelscope NER pipelines typically wrap entities as
    # {'output': [...]}; accept a bare list as well — TODO confirm against
    # modelscope 1.13.3.
    entities = result.get('output', []) if isinstance(result, dict) else result

    diseases, symptoms = set(), set()
    for entity in entities:
        if entity['type'] == 'dis':
            diseases.add(entity['span'])
        elif entity['type'] == 'sym':
            symptoms.add(entity['span'])
    return {'dis': list(diseases), 'sym': list(symptoms)}
16 changes: 16 additions & 0 deletions text_splitter/chinese_recursive_text_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,22 @@ def _split_text(self, text: str, separators: List[str]) -> List[str]:
return [re.sub(r"\n{2,}", "\n", chunk.strip()) for chunk in final_chunks if chunk.strip()!=""]


class ATOMChineseRecursiveTextSplitter(ChineseRecursiveTextSplitter):
    """Recursive splitter for ATOM corpora.

    By default documents are chunked only at triple-newline boundaries
    ("\n\n\n"), treated as a regex separator.
    """

    def __init__(
        self,
        separators: Optional[List[str]] = None,
        keep_separator: bool = True,
        is_separator_regex: bool = True,
        **kwargs: Any,
    ) -> None:
        # Honor a caller-supplied separator list; the previous version
        # accepted the argument but silently discarded it in favor of the
        # hard-coded default.
        super().__init__(
            separators=separators if separators is not None else ["\n\n\n"],
            keep_separator=keep_separator,
            is_separator_regex=is_separator_regex,
            **kwargs,
        )


if __name__ == "__main__":
text_splitter = ChineseRecursiveTextSplitter(
keep_separator=True,
Expand Down
2 changes: 2 additions & 0 deletions webui_pages/dialogue/dialogue.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ def prompt_change():
prompt_template_name = st.session_state.prompt_template_select
temperature = st.slider("Temperature:", 0.0, 2.0, TEMPERATURE, 0.05)
history_len = st.number_input("历史对话轮数:", 0, 20, HISTORY_LEN)
enable_ner = st.checkbox("开启命名实体识别", False)

def on_kb_change():
st.toast(f"已加载知识库: {st.session_state.selected_kb}")
Expand Down Expand Up @@ -375,6 +376,7 @@ def on_feedback(
top_k=kb_top_k,
score_threshold=score_threshold,
history=history,
enable_ner=enable_ner,
model=llm_model,
prompt_name=prompt_template_name,
temperature=temperature):
Expand Down
2 changes: 2 additions & 0 deletions webui_pages/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,7 @@ def knowledge_base_chat(
top_k: int = VECTOR_SEARCH_TOP_K,
score_threshold: float = SCORE_THRESHOLD,
history: List[Dict] = [],
enable_ner: bool = False,
stream: bool = True,
model: str = LLM_MODELS[0],
temperature: float = TEMPERATURE,
Expand All @@ -354,6 +355,7 @@ def knowledge_base_chat(
"top_k": top_k,
"score_threshold": score_threshold,
"history": history,
"enable_ner": enable_ner,
"stream": stream,
"model_name": model,
"temperature": temperature,
Expand Down