From 70b6a9b7c79e89ca8e35533d952d69a3e9a8f0df Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Florian=20Ho=CC=88nicke?=
Date: Mon, 22 May 2023 16:53:42 +0200
Subject: [PATCH] =?UTF-8?q?=F0=9F=93=97=20feat:=20prompt=20logging?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Log every prompt/response pair exchanged with the model to
<microservice_root>/log.json via a new ConversationLogger owned by the
GPTSession singleton. The logger is optional on _GPTConversation so
direct construction without a logger keeps working, and the output
directory is created before the first model call so the log file can be
written.
---
 dev_gpt/apis/gpt.py                             | 13 ++++++++-----
 dev_gpt/options/generate/conversation_logger.py | 28 ++++++++++++++++++++++++
 dev_gpt/options/generate/generator.py           |  4 ++--
 3 files changed, 39 insertions(+), 6 deletions(-)
 create mode 100644 dev_gpt/options/generate/conversation_logger.py

diff --git a/dev_gpt/apis/gpt.py b/dev_gpt/apis/gpt.py
index 387649c..335eab0 100644
--- a/dev_gpt/apis/gpt.py
+++ b/dev_gpt/apis/gpt.py
@@ -16,6 +16,7 @@ from urllib3.exceptions import InvalidChunkLength
 
 from dev_gpt.constants import PRICING_GPT4_PROMPT, PRICING_GPT4_GENERATION, PRICING_GPT3_5_TURBO_PROMPT, \
     PRICING_GPT3_5_TURBO_GENERATION, CHARS_PER_TOKEN
+from dev_gpt.options.generate.conversation_logger import ConversationLogger
 from dev_gpt.options.generate.templates_system import template_system_message_base
 from dev_gpt.utils.string_tools import print_colored, get_template_parameters
 
@@ -41,9 +42,10 @@ class GPTSession:
             cls._instance = super(GPTSession, cls).__new__(cls)
         return cls._instance
 
-    def __init__(self, model: str = 'gpt-4', ):
+    def __init__(self, log_file_path: str, model: str = 'gpt-4', ):
         if GPTSession._initialized:
             return
+        self.conversation_logger = ConversationLogger(log_file_path)
         if model == 'gpt-4' and self.is_gpt4_available():
             self.pricing_prompt = PRICING_GPT4_PROMPT
             self.pricing_generation = PRICING_GPT4_GENERATION
@@ -58,10 +60,10 @@ class GPTSession:
         self.chars_generation_so_far = 0
         GPTSession._initialized = True
 
     def get_conversation(self, messages: List[BaseMessage] = [], print_stream: bool = True, print_costs: bool = True):
         messages = deepcopy(messages)
         return _GPTConversation(
-            self.model_name, self.cost_callback, messages, print_stream, print_costs
+            self.model_name, self.cost_callback, messages, print_stream, print_costs, self.conversation_logger
         )
 
     @staticmethod
@@ -107,7 +109,7 @@ class AssistantStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
 
 
 class _GPTConversation:
-    def __init__(self, model: str, cost_callback, messages: List[BaseMessage], print_stream, print_costs):
+    def __init__(self, model: str, cost_callback, messages: List[BaseMessage], print_stream, print_costs, conversation_logger: ConversationLogger = None):
         self._chat = ChatOpenAI(
             model_name=model,
             streaming=True,
@@ -119,6 +121,7 @@ class _GPTConversation:
         self.messages = messages
         self.print_stream = print_stream
         self.print_costs = print_costs
+        self.conversation_logger = conversation_logger
 
     def print_messages(self, messages):
         for i, message in enumerate(messages):
@@ -141,6 +144,8 @@ class _GPTConversation:
         for i in range(10):
             try:
                 response = self._chat(self.messages)
+                if self.conversation_logger is not None:
+                    self.conversation_logger.log(self.messages, response)
                 break
             except (ConnectionError, InvalidChunkLength, ChunkedEncodingError) as e:
                 print('There was a connection error. Retrying...')
@@ -173,7 +178,7 @@ def ask_gpt(prompt_template, parser, **kwargs):
         if isinstance(value, dict):
             kwargs[key] = json.dumps(value, indent=4)
     prompt = prompt_template.format(**kwargs)
-    conversation = GPTSession().get_conversation(
+    conversation = GPTSession._instance.get_conversation(
         [],
         print_stream=os.environ['VERBOSE'].lower() == 'true',
         print_costs=False
diff --git a/dev_gpt/options/generate/conversation_logger.py b/dev_gpt/options/generate/conversation_logger.py
new file mode 100644
index 0000000..cbb3577
--- /dev/null
+++ b/dev_gpt/options/generate/conversation_logger.py
@@ -0,0 +1,28 @@
+import json
+from typing import List
+
+from langchain.schema import BaseMessage
+
+
+class ConversationLogger:
+    """Collects prompt/response pairs and persists them as a JSON file."""
+
+    def __init__(self, log_file_path):
+        self.log_file_path = log_file_path
+        self.entries = []
+
+    def log(self, prompt_message_list: List[BaseMessage], response: BaseMessage):
+        """Append one exchange and rewrite the log file on disk."""
+        prompt_list_json = [
+            {
+                'role': message.type,
+                'content': message.content
+            }
+            for message in prompt_message_list
+        ]
+        self.entries.append({
+            'prompt': prompt_list_json,
+            'response': response.content
+        })
+        with open(self.log_file_path, 'w') as f:
+            f.write(json.dumps(self.entries, indent=2))
diff --git a/dev_gpt/options/generate/generator.py b/dev_gpt/options/generate/generator.py
index 5bd0f69..e10fe68 100644
--- a/dev_gpt/options/generate/generator.py
+++ b/dev_gpt/options/generate/generator.py
@@ -42,7 +42,7 @@ class TaskSpecification:
 
 class Generator:
     def __init__(self, task_description, path, model='gpt-4', self_healing=True):
-        self.gpt_session = gpt.GPTSession(model=model)
+        self.gpt_session = gpt.GPTSession(os.path.join(path, 'log.json'), model=model)
         self.microservice_specification = TaskSpecification(task=task_description, test=None)
         self.self_healing = self_healing
         self.microservice_root_path = path
@@ -540,8 +540,8 @@ pytest
 # '/private/var/folders/f5/whmffl4d7q79s29jpyb6719m0000gn/T/pytest-of-florianhonicke/pytest-128/test_generation_level_0_mock_i0'
 # '/private/var/folders/f5/whmffl4d7q79s29jpyb6719m0000gn/T/pytest-of-florianhonicke/pytest-129/test_generation_level_0_mock_i0'
     def generate(self):
-        self.microservice_specification.task, self.microservice_specification.test = PM().refine_specification(self.microservice_specification.task)
         os.makedirs(self.microservice_root_path)
+        self.microservice_specification.task, self.microservice_specification.test = PM().refine_specification(self.microservice_specification.task)
         generated_name = self.generate_microservice_name(self.microservice_specification.task)
         self.microservice_name = f'{generated_name}{random.randint(0, 10_000_000)}'
         packages_list = self.get_possible_packages()