diff --git a/scripts/print_chat.py b/scripts/print_chat.py index a972be9..2ff96f6 100644 --- a/scripts/print_chat.py +++ b/scripts/print_chat.py @@ -16,18 +16,20 @@ def pretty_print_conversation(messages): } formatted_messages = [] for message in messages: - if message["role"] == "system": - formatted_messages.append(f"system: {message['content']}\n") - elif message["role"] == "user": - formatted_messages.append(f"user: {message['content']}\n") - elif message["role"] == "assistant" and message.get("function_call"): - formatted_messages.append(f"assistant: {message['function_call']}\n") - elif message["role"] == "assistant" and not message.get("function_call"): - formatted_messages.append(f"assistant: {message['content']}\n") - elif message["role"] == "function": - formatted_messages.append( - f"function ({message['name']}): {message['content']}\n" - ) + assistant_content = ( + message["function_call"] + if message.get("function_call") + else message["content"] + ) + role_to_message = { + "system": f"system: {message['content']}\n", + "user": f"user: {message['content']}\n", + "assistant": f"assistant: {assistant_content}\n", + "function": f"function ({message['name']}): {message['content']}\n", + } + + formatted_messages.append(role_to_message[message["role"]]) + for formatted_message in formatted_messages: print( colored(