diff --git a/scripts/print_chat.py b/scripts/print_chat.py index eef244f..ab1c09a 100644 --- a/scripts/print_chat.py +++ b/scripts/print_chat.py @@ -34,14 +34,9 @@ def pretty_print_conversation(messages): formatted_messages.append(role_to_message[message["role"]]) for formatted_message in formatted_messages: - print( - colored( - formatted_message, - role_to_color[ - messages[formatted_messages.index(formatted_message)]["role"] - ], - ) - ) + role = messages[formatted_messages.index(formatted_message)]["role"] + color = role_to_color[role] + print(colored(formatted_message, color)) @app.command()