Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

1、优化多账号负载均衡时的选择 2、添加一系列功能 #230

Open
wants to merge 2 commits into
base: browser-version
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
33 changes: 30 additions & 3 deletions bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import chatbot
from config import Config
from utils.text_to_img import to_image
import conversation_manager
from manager.ratelimit import RateLimitManager
import time
from revChatGPT.V1 import Error as V1Error
Expand Down Expand Up @@ -71,11 +72,32 @@ async def create_timeout_task(target: Union[Friend, Group], source: Source):


async def handle_message(target: Union[Friend, Group], session_id: str, message: str, source: Source) -> str:
number = session_id.split('-')[1]
if not message.strip():
return config.response.placeholder

timeout_task = None

# 如果消息包含help命令(config.trigger.help_command所定义的内容),则回滚会话
if message.strip() in config.trigger.help_command:
return config.response.help_command.format(max_sessions=config.max_record.max_sessions)

# 如果消息包含 会话列表 命令(config.trigger.talk_list_command所定义的内容),则输出会话列表
if message.strip() in config.trigger.talk_list_command:
return "会话列表如下:\n" + '\n'.join(f"{i+1}. {x}" for i, x in enumerate(conversation_manager.get_user_sessions(number)))

# 如果消息包含会话上限x命令,则进入会话x
max_record_search = re.search(config.trigger.max_talk_sessions, message)
if max_record_search:
conversation_manager.update_max_sessions(number, int(max_record_search.group(1)))
if int(max_record_search.group(1)) > 10:
max_record = 10
elif int(max_record_search.group(1)) < 1:
max_record = 1
else:
max_record = int(max_record_search.group(1))
return f"会话上限已设置为{max_record}!"

session, is_new_session = chatbot.get_chat_session(session_id)

# 回滚
Expand Down Expand Up @@ -170,14 +192,19 @@ async def friend_message_listener(app: Ariadne, friend: Friend, source: Source,
if rate_usage >= 1:
response = config.ratelimit.exceed
else:
response = await handle_message(friend, f"friend-{friend.id}", chain.display, source)
response = await handle_message(friend, f"fd-{friend.id}-", chain.display, source)
if rate_usage >= config.ratelimit.warning_rate:
limit = rateLimitManager.get_limit('好友', friend.id)
usage = rateLimitManager.get_usage('好友', friend.id)
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
response = response + '\n' + config.ratelimit.warning_msg.format(usage=usage['count'], limit=limit['rate'],
current_time=current_time)
await respond(friend, source, response)
if len(response) <= 3000:
await respond(friend, source, response)
else:
chunks = [response[i:i+3000] for i in range(0, len(response), 3000)]
for chunk in chunks:
await respond(friend, source, chunk)


GroupTrigger = Annotated[MessageChain, MentionMe(config.trigger.require_mention != "at"), DetectPrefix(
Expand All @@ -193,7 +220,7 @@ async def group_message_listener(group: Group, source: Source, chain: GroupTrigg
if rate_usage >= 1:
return config.ratelimit.exceed
else:
response = await handle_message(group, f"group-{group.id}", chain.display, source)
response = await handle_message(group, f"gp-{group.id}-", chain.display, source)
if rate_usage >= config.ratelimit.warning_rate:
limit = rateLimitManager.get_limit('群组', group.id)
usage = rateLimitManager.get_usage('群组', group.id)
Expand Down
168 changes: 157 additions & 11 deletions chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
from manager.bot import BotManager, BotInfo
import atexit
from loguru import logger
import conversation_manager
import re
from tinydb import TinyDB, Query
import revChatGPT.V1 as V1

config = Config.load_config()
Expand All @@ -24,10 +27,12 @@ def setup():
class ChatSession:
chatbot: BotInfo = None
session_id: str
number: str

def __init__(self, session_id):
self.session_id = session_id
self.prev_conversation_id = []
self.number = session_id.split('-')[1]
self.prev_conversation_id = None
self.prev_parent_id = []
self.parent_id = None
self.conversation_id = None
Expand Down Expand Up @@ -56,47 +61,188 @@ def reset_conversation(self):
self.chatbot.bot.delete_conversation(self.conversation_id)
self.conversation_id = None
self.parent_id = None
self.prev_conversation_id = []
self.prev_conversation_id = None
self.prev_parent_id = []
self.chatbot = botManager.pick()

def rollback_conversation(self) -> bool:
if len(self.prev_parent_id) <= 0:
return False
self.conversation_id = self.prev_conversation_id.pop()
# self.conversation_id = self.prev_conversation_id.pop()
self.parent_id = self.prev_parent_id.pop()
# 回滚一次对话
conversation_manager.rollback_last_parent_id(self.number,
self.conversation_id)
return True

# 解析读取历史会话得到的字符串
def extract_conversations(self, result):

output = ""
message_mapping = result['mapping']
current_node_id = result['current_node']
current_node = message_mapping[current_node_id]
conversation = []

while current_node['parent'] is not None:
current_message = current_node['message']
if current_message is not None:
author_role = current_message['author']['role']
message_content = current_message['content']['parts'][0]
if author_role == 'user':
conversation.append("you:" + message_content)
elif author_role == 'assistant':
conversation.append("bot:" + message_content)

current_node_id = current_node['parent']
current_node = message_mapping[current_node_id]

for line in reversed(conversation):
output += line + "\n"

return output

async def get_chat_response(self, message) -> str:
self.prev_conversation_id.append(self.conversation_id)

self.prev_conversation_id = self.conversation_id
self.prev_parent_id.append(self.parent_id)
logger.info(f"当前id:{self.conversation_id}, 父节点id:{self.parent_id}")

# 如果消息包含进入会话x命令,则进入会话x
into_talk_search = re.search(config.trigger.goto_talk, message)
if into_talk_search:
conversation_parents = conversation_manager.get_last_parent_id(self.number, int(into_talk_search.group(1)))
if conversation_parents:
self.conversation_id = conversation_parents[0]
self.parent_id = conversation_parents[1][-1]
self.prev_conversation_id = conversation_parents[0]
self.prev_parent_id = conversation_parents[1][0:-1]
self.chatbot = botManager.pick_id(conversation_parents[2])
return f"进入会话{into_talk_search.group(1)}成功!"
else:
return f"进入会话{into_talk_search.group(1)}失败!停留在当前会话"

# 如果消息包含删除会话x命令,则删除对应会话
delete_talk_search = re.search(config.trigger.delete_talk, message)
if delete_talk_search:
result = conversation_manager.delete_session_record(self.number, int(delete_talk_search.group(1)))
if result == 2: # 如果删除的恰巧是最后一个,也就是当前会话
self.reset_conversation()
if result:
return f"删除会话{delete_talk_search.group(1)}成功!"
else:
return f"删除会话{delete_talk_search.group(1)}失败!会话不存在"

# 如果消息包含读取会话x命令,则读取对应会话
read_talk_search = re.search(config.trigger.read_talk, message)
if read_talk_search:
conversation_parents = conversation_manager.get_last_parent_id(self.number, int(read_talk_search.group(1)))
if conversation_parents:
chatbot_save = self.chatbot
self.chatbot = botManager.pick_id(conversation_parents[2])
result = self.chatbot.bot.get_msg_history(conversation_parents[0], encoding='utf-8')
result = self.extract_conversations(result)
self.chatbot = chatbot_save
logger.info(
f"会话{read_talk_search.group(1)}的聊天记录如下:\n{result}")
return f"读取会话成功!\n会话{read_talk_search.group(1)}的聊天记录如下:\n{result}"
else:
return f"读取会话{read_talk_search.group(1)}失败!停留在当前会话"

# 如果消息包含清空会话命令(config.trigger.rollback_command所定义的内容),则清空会话
if message.strip() in config.trigger.clear_talk_command:
self.reset_conversation()
if conversation_manager.clear_user_sessions(self.number):
return "清空会话成功!"
else:
return "清空会话失败!你是不是还没对话过。"

# 如果消息包含 会话名:*** 命令,则删除对应会话
rename_talk_search = re.search(config.trigger.rename_talk, message)
if rename_talk_search:
if self.conversation_id:
conversation_manager.update_session(self.number, self.conversation_id, rename_talk_search.group(1))
self.chatbot.bot.change_title(self.conversation_id, (
self.session_id.encode('unicode-escape') + rename_talk_search.group(1).encode(
'unicode-escape')).decode('utf-8'))
return f"会话名已改为{rename_talk_search.group(1)}"
else:
return "会话还没开始,不能设置会话名!"

bot = self.chatbot.bot
botManager.update_bot_time(self.chatbot.id)
bot.conversation_id = self.conversation_id
bot.parent_id = self.parent_id
logger.info(
f"当前id:{self.conversation_id}, 父节点id:{self.parent_id}")

loop = asyncio.get_event_loop()
resp = await loop.run_in_executor(None, self.chatbot.ask, message, self.conversation_id, self.parent_id)

if self.conversation_id is None and self.chatbot.account.title_pattern:
self.chatbot.bot.change_title(resp["conversation_id"],
self.chatbot.account.title_pattern.format(session_id=self.session_id))
flag = False
if self.conversation_id is None:
flag = True

self.conversation_id = resp["conversation_id"]
self.parent_id = resp["parent_id"]

# 添加对话记录
conversation_manager.add_session_record(self.number, self.conversation_id, message[:20],
self.chatbot.account_id,
self.parent_id)

if flag:
self.chatbot.bot.change_title(self.conversation_id, f"{self.session_id}{message[:20].encode('utf-8')}")
# self.chatbot.bot.change_title(resp["conversation_id"],self.chatbot.account.title_pattern.format(session_id=f"{self.session_id}"+str(message[:20].encode("utf-8"))))##########################

logger.info(
f"当前id:{self.conversation_id}, 父节点id:{self.parent_id}")

return resp["message"]


__sessions = {}


def get_chat_session(session_id: str) -> Tuple[ChatSession, bool]:
number = session_id.split('-')[1]
new_session = False
if session_id not in __sessions:
__sessions[session_id] = ChatSession(session_id)
new_session = True
return __sessions[session_id], new_session

if number not in __sessions: #有可能是重启容器了(读取旧聊天最后一个),也有可能是新的用户(创建新聊天)

# 创建一个新的聊天会话
__sessions[number] = ChatSession(session_id)

# 创建一个新的聊天会话
__sessions[number] = ChatSession(session_id)

# 打开数据库
db = TinyDB('data/session_records.json')

# 获取数据表
table = db.table('session_records')

# 读取数据
session_records = table.all()
if len(session_records) > 0:
session_records = session_records[0]
else:
session_records = {}

# 读取session_records字典
# with open('session_records.json', 'r') as f:
# session_records = json.loads(f.read())
if number in session_records: #重启容器了
if len(session_records[number]["sessions"]) > 0:
__sessions[number].conversation_id = session_records[number]["sessions"][-1]["conversation_id"]
__sessions[number].parent_id = session_records[number]["sessions"][-1]["parent_ids"][-1]
__sessions[number].prev_conversation_id = session_records[number]["sessions"][-1]["conversation_id"]
__sessions[number].prev_parent_id = session_records[number]["sessions"][-1]["parent_ids"][0:-1]
__sessions[number].chatbot = botManager.pick_id(session_records[number]["sessions"][-1]["account_id"])
else: # 新用户开始聊天
new_session = True

return __sessions[number], new_session


def conversation_remover():
Expand Down
40 changes: 39 additions & 1 deletion config.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,15 @@ class OpenAIAuthBase(BaseModel):
auto_remove_old_conversations: bool = False
"""自动删除旧的对话"""


class Config(BaseConfig):
extra = Extra.allow


class OpenAIEmailAuth(OpenAIAuthBase):
email: str
email: str = ''
"""OpenAI 注册邮箱"""

password: str
"""OpenAI 密码"""
isMicrosoftLogin: bool = False
Expand Down Expand Up @@ -90,6 +92,9 @@ class TextToImage(BaseModel):
"""纵坐标"""
wkhtmltoimage: Union[str, None] = None

class MaxRecord(BaseModel):
max_sessions: int = 5
"""会话数量上限"""

class Trigger(BaseModel):
prefix: List[str] = [""]
Expand All @@ -100,6 +105,23 @@ class Trigger(BaseModel):
"""重置会话的命令"""
rollback_command: List[str] = ["回滚会话"]
"""回滚会话的命令"""
help_command: List[str] = ["help"]
"""请求帮助的命令"""
talk_list_command: List[str] = ["会话列表"]
"""会话列表的命令"""
clear_talk_command: List[str] = ["清空会话"]
"""清空会话的命令"""
max_talk_sessions: str = r"会话上限\s*(\d+)"

goto_talk: str = r"进入会话\s*(\d+)"

delete_talk: str = r"删除会话\s*(\d+)"

read_talk: str = r"读取会话\s*(\d+)"

rename_talk: str = r"会话名\s*[::]\s*([^::\s]+)"




class Response(BaseModel):
Expand Down Expand Up @@ -153,6 +175,21 @@ class Response(BaseModel):
queued_notice: str = "消息已收到!当前我还有{queue_size}条消息要回复,请您稍等。"
"""新消息进入队列时,发送的通知。 queue_size 是当前排队的消息数"""

help_command: str = (
"功能列表:\n"
"样例:指令 - 指令的功能\n"
"帮助 - 显示功能列表\n"
"重置会话 - 离开当前会话(保留),开启新的会话\n"
"回滚会话 - 回滚一次对话\n"
"会话名:*** - 设置当前会话名(开始会话之后才能设置)\n"
"会话列表 - 显示自己现有的会话列表\n"
"进入会话x - 进入会话列表中第x个会话\n"
"读取会话x - 读取会话列表中第x的会话的所有会话记录\n"
"删除会话x - 删除会话列表中第x个会话\n"
"清空会话 - 清空自己所有的会话\n"
"会话上限x - 设置自己会话数量上限x,最大为{max_sessions}"
)


class System(BaseModel):
accept_group_invite: bool = False
Expand Down Expand Up @@ -186,6 +223,7 @@ class Config(BaseModel):
response: Response = Response()
system: System = System()
presets: Preset = Preset()
max_record: MaxRecord = MaxRecord()
ratelimit: Ratelimit = Ratelimit()

def scan_presets(self):
Expand Down