diff --git a/pilot/utils/custom_print.py b/pilot/utils/custom_print.py index 9bc3800cb..095f69831 100644 --- a/pilot/utils/custom_print.py +++ b/pilot/utils/custom_print.py @@ -1,13 +1,12 @@ import builtins from helpers.ipc import IPCClient from const.ipc import MESSAGE_TYPE, LOCAL_IGNORE_MESSAGE_TYPES +from typing import Callable, List, Tuple, Optional - -def get_custom_print(args): +def get_custom_print(args) -> Tuple[Callable, Optional[IPCClient]]: built_in_print = builtins.print def print_to_external_process(*args, **kwargs): - # message = " ".join(map(str, args)) message = args[0] if 'type' not in kwargs: diff --git a/pilot/utils/dot_gpt_pilot.py b/pilot/utils/dot_gpt_pilot.py index c6a10571e..e02945c86 100644 --- a/pilot/utils/dot_gpt_pilot.py +++ b/pilot/utils/dot_gpt_pilot.py @@ -41,24 +41,36 @@ def chat_log_folder(self, task): if task is not None: chat_log_path = os.path.join(chat_log_path, 'task_' + str(task)) - os.makedirs(chat_log_path, exist_ok=True) + try: + os.makedirs(chat_log_path) + except OSError as e: + print(f"Error creating folder: {e}") + raise + self.chat_log_path = chat_log_path return chat_log_path + def log_chat_completion(self, endpoint: str, model: str, req_type: str, messages: list[dict], response: str): if not USE_GPTPILOT_FOLDER: return if self.log_chat_completions: time = datetime.now().strftime('%Y-%m-%d_%H_%M_%S') - with open(os.path.join(self.chat_log_path, f'{time}-{req_type}.yaml'), 'w', encoding="utf-8") as file: - data = { - 'endpoint': endpoint, - 'model': model, - 'messages': messages, - 'response': response, - } - - yaml.safe_dump(data, file, width=120, indent=2, default_flow_style=False, sort_keys=False) + try: + with open(os.path.join(self.chat_log_path, f'{time}-{req_type}.yaml'), 'w', encoding="utf-8") as file: + data = { + 'endpoint': endpoint, + 'model': model, + 'messages': messages, + 'response': response, + } + + try: + yaml.safe_dump(data, file) + except yaml.YAMLError as e: + print(f"Error serializing YAML: {e}") + except Exception as e: + print(f"Error logging chat completion: {e}") def log_chat_completion_json(self, endpoint: str, model: str, req_type: str, functions: dict, json_response: str): if not USE_GPTPILOT_FOLDER: diff --git a/pilot/utils/questionary.py b/pilot/utils/questionary.py index 8a0a82ac0..0a8ab56ce 100644 --- a/pilot/utils/questionary.py +++ b/pilot/utils/questionary.py @@ -6,9 +6,10 @@ from utils.style import color_yellow_bold, style_config -def remove_ansi_codes(s: str) -> str: - ansi_escape = re.compile(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])') - return ansi_escape.sub('', s) +ANSI_ESCAPE_REGEX = re.compile(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])') + +def remove_ansi_codes(s): + return ANSI_ESCAPE_REGEX.sub('', s) def styled_select(*args, **kwargs):