mirror of
https://github.com/aljazceru/dev-gpt.git
synced 2025-12-20 07:04:20 +01:00
📗 feat: prompt logging
This commit is contained in:
@@ -16,6 +16,7 @@ from urllib3.exceptions import InvalidChunkLength
|
||||
|
||||
from dev_gpt.constants import PRICING_GPT4_PROMPT, PRICING_GPT4_GENERATION, PRICING_GPT3_5_TURBO_PROMPT, \
|
||||
PRICING_GPT3_5_TURBO_GENERATION, CHARS_PER_TOKEN
|
||||
from dev_gpt.options.generate.conversation_logger import ConversationLogger
|
||||
from dev_gpt.options.generate.templates_system import template_system_message_base
|
||||
from dev_gpt.utils.string_tools import print_colored, get_template_parameters
|
||||
|
||||
@@ -41,9 +42,10 @@ class GPTSession:
|
||||
cls._instance = super(GPTSession, cls).__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self, model: str = 'gpt-4', ):
|
||||
def __init__(self, log_file_path: str, model: str = 'gpt-4', ):
|
||||
if GPTSession._initialized:
|
||||
return
|
||||
self.conversation_logger = ConversationLogger(log_file_path)
|
||||
if model == 'gpt-4' and self.is_gpt4_available():
|
||||
self.pricing_prompt = PRICING_GPT4_PROMPT
|
||||
self.pricing_generation = PRICING_GPT4_GENERATION
|
||||
@@ -58,10 +60,13 @@ class GPTSession:
|
||||
self.chars_generation_so_far = 0
|
||||
GPTSession._initialized = True
|
||||
|
||||
|
||||
|
||||
|
||||
def get_conversation(self, messages: List[BaseMessage] = [], print_stream: bool = True, print_costs: bool = True):
|
||||
messages = deepcopy(messages)
|
||||
return _GPTConversation(
|
||||
self.model_name, self.cost_callback, messages, print_stream, print_costs
|
||||
self.model_name, self.cost_callback, messages, print_stream, print_costs, self.conversation_logger
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -107,7 +112,7 @@ class AssistantStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
|
||||
|
||||
|
||||
class _GPTConversation:
|
||||
def __init__(self, model: str, cost_callback, messages: List[BaseMessage], print_stream, print_costs):
|
||||
def __init__(self, model: str, cost_callback, messages: List[BaseMessage], print_stream, print_costs, conversation_logger: ConversationLogger = None):
|
||||
self._chat = ChatOpenAI(
|
||||
model_name=model,
|
||||
streaming=True,
|
||||
@@ -119,6 +124,7 @@ class _GPTConversation:
|
||||
self.messages = messages
|
||||
self.print_stream = print_stream
|
||||
self.print_costs = print_costs
|
||||
self.conversation_logger = conversation_logger
|
||||
|
||||
def print_messages(self, messages):
|
||||
for i, message in enumerate(messages):
|
||||
@@ -141,6 +147,7 @@ class _GPTConversation:
|
||||
for i in range(10):
|
||||
try:
|
||||
response = self._chat(self.messages)
|
||||
self.conversation_logger.log(self.messages, response)
|
||||
break
|
||||
except (ConnectionError, InvalidChunkLength, ChunkedEncodingError) as e:
|
||||
print('There was a connection error. Retrying...')
|
||||
@@ -173,7 +180,7 @@ def ask_gpt(prompt_template, parser, **kwargs):
|
||||
if isinstance(value, dict):
|
||||
kwargs[key] = json.dumps(value, indent=4)
|
||||
prompt = prompt_template.format(**kwargs)
|
||||
conversation = GPTSession().get_conversation(
|
||||
conversation = GPTSession._instance.get_conversation(
|
||||
[],
|
||||
print_stream=os.environ['VERBOSE'].lower() == 'true',
|
||||
print_costs=False
|
||||
|
||||
28
dev_gpt/options/generate/conversation_logger.py
Normal file
28
dev_gpt/options/generate/conversation_logger.py
Normal file
@@ -0,0 +1,28 @@
|
||||
import json
|
||||
from typing import List
|
||||
|
||||
from langchain.schema import BaseMessage
|
||||
|
||||
|
||||
class ConversationLogger:
|
||||
def __init__(self, log_file_path):
|
||||
self.log_file_path = log_file_path
|
||||
self.log_file = []
|
||||
|
||||
def log(self, prompt_message_list: List[BaseMessage], response: str):
|
||||
prompt_list_json = [
|
||||
{
|
||||
'role': f'{message.type}',
|
||||
'content': f'{message.content}'
|
||||
}
|
||||
for message in prompt_message_list
|
||||
]
|
||||
self.log_file.append({
|
||||
'prompt': prompt_list_json,
|
||||
'response': f'{response}'
|
||||
})
|
||||
with open(self.log_file_path, 'w') as f:
|
||||
f.write(json.dumps(self.log_file, indent=2))
|
||||
|
||||
|
||||
|
||||
@@ -42,7 +42,7 @@ class TaskSpecification:
|
||||
|
||||
class Generator:
|
||||
def __init__(self, task_description, path, model='gpt-4', self_healing=True):
|
||||
self.gpt_session = gpt.GPTSession(model=model)
|
||||
self.gpt_session = gpt.GPTSession(os.path.join(path, 'log.json'), model=model)
|
||||
self.microservice_specification = TaskSpecification(task=task_description, test=None)
|
||||
self.self_healing = self_healing
|
||||
self.microservice_root_path = path
|
||||
@@ -540,8 +540,8 @@ pytest
|
||||
# '/private/var/folders/f5/whmffl4d7q79s29jpyb6719m0000gn/T/pytest-of-florianhonicke/pytest-128/test_generation_level_0_mock_i0'
|
||||
# '/private/var/folders/f5/whmffl4d7q79s29jpyb6719m0000gn/T/pytest-of-florianhonicke/pytest-129/test_generation_level_0_mock_i0'
|
||||
def generate(self):
|
||||
self.microservice_specification.task, self.microservice_specification.test = PM().refine_specification(self.microservice_specification.task)
|
||||
os.makedirs(self.microservice_root_path)
|
||||
self.microservice_specification.task, self.microservice_specification.test = PM().refine_specification(self.microservice_specification.task)
|
||||
generated_name = self.generate_microservice_name(self.microservice_specification.task)
|
||||
self.microservice_name = f'{generated_name}{random.randint(0, 10_000_000)}'
|
||||
packages_list = self.get_possible_packages()
|
||||
|
||||
Reference in New Issue
Block a user