diff --git a/cli_demo.py b/cli_demo.py index 3559840c..f92d3d86 100644 --- a/cli_demo.py +++ b/cli_demo.py @@ -29,6 +29,7 @@ def signal_handler(signal, frame): def main(): history = [] global stop_stream + signal.signal(signal.SIGINT, signal_handler) print("欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序") while True: query = input("\n用户:") @@ -39,20 +40,16 @@ def main(): os.system(clear_command) print("欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序") continue - count = 0 + prev_response = "" + print("ChatGLM-6B:", end="", flush=True) for response, history in model.stream_chat(tokenizer, query, history=history): if stop_stream: stop_stream = False break else: - count += 1 - if count % 8 == 0: - os.system(clear_command) - print(build_prompt(history), flush=True) - signal.signal(signal.SIGINT, signal_handler) - os.system(clear_command) - print(build_prompt(history), flush=True) - + print(response[len(prev_response):], end="", flush=True) + prev_response = response + print("\n", end="", flush=True) if __name__ == "__main__": main()