📗 feat: prompt logging

Author: Florian Hönicke
Date: 2023-05-22 16:53:42 +02:00
parent 83719bf380
commit 70b6a9b7c7
3 changed files with 41 additions and 6 deletions


@@ -16,6 +16,7 @@ from urllib3.exceptions import InvalidChunkLength
 from dev_gpt.constants import PRICING_GPT4_PROMPT, PRICING_GPT4_GENERATION, PRICING_GPT3_5_TURBO_PROMPT, \
     PRICING_GPT3_5_TURBO_GENERATION, CHARS_PER_TOKEN
+from dev_gpt.options.generate.conversation_logger import ConversationLogger
 from dev_gpt.options.generate.templates_system import template_system_message_base
 from dev_gpt.utils.string_tools import print_colored, get_template_parameters
@@ -41,9 +42,10 @@ class GPTSession:
             cls._instance = super(GPTSession, cls).__new__(cls)
         return cls._instance

-    def __init__(self, model: str = 'gpt-4', ):
+    def __init__(self, log_file_path: str, model: str = 'gpt-4', ):
         if GPTSession._initialized:
             return
+        self.conversation_logger = ConversationLogger(log_file_path)
         if model == 'gpt-4' and self.is_gpt4_available():
             self.pricing_prompt = PRICING_GPT4_PROMPT
             self.pricing_generation = PRICING_GPT4_GENERATION
@@ -58,10 +60,13 @@ class GPTSession:
         self.chars_generation_so_far = 0
         GPTSession._initialized = True

     def get_conversation(self, messages: List[BaseMessage] = [], print_stream: bool = True, print_costs: bool = True):
         messages = deepcopy(messages)
         return _GPTConversation(
-            self.model_name, self.cost_callback, messages, print_stream, print_costs
+            self.model_name, self.cost_callback, messages, print_stream, print_costs, self.conversation_logger
         )

     @staticmethod
@@ -107,7 +112,7 @@ class AssistantStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):


 class _GPTConversation:
-    def __init__(self, model: str, cost_callback, messages: List[BaseMessage], print_stream, print_costs):
+    def __init__(self, model: str, cost_callback, messages: List[BaseMessage], print_stream, print_costs, conversation_logger: ConversationLogger = None):
         self._chat = ChatOpenAI(
             model_name=model,
             streaming=True,
@@ -119,6 +124,7 @@ class _GPTConversation:
         self.messages = messages
         self.print_stream = print_stream
         self.print_costs = print_costs
+        self.conversation_logger = conversation_logger

     def print_messages(self, messages):
         for i, message in enumerate(messages):
@@ -141,6 +147,7 @@ class _GPTConversation:
         for i in range(10):
             try:
                 response = self._chat(self.messages)
+                self.conversation_logger.log(self.messages, response)
                 break
             except (ConnectionError, InvalidChunkLength, ChunkedEncodingError) as e:
                 print('There was a connection error. Retrying...')
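
The hunks above show only the call sites of the logger: it is imported from dev_gpt.options.generate.conversation_logger, constructed with a log file path, and handed the full message list plus the model response on every chat call. Its implementation lives in one of the other changed files and is not shown here; the sketch below is merely an assumed-compatible shape (the JSON-lines format and attribute access are guesses, not the actual dev-gpt code):

# Assumed-compatible sketch only; the real conversation_logger.py may differ.
import json

class ConversationLogger:
    def __init__(self, log_file_path: str):
        self.log_file_path = log_file_path

    def log(self, messages, response):
        # Append the prompt messages and the model response as one JSON line.
        entry = {
            'messages': [{'role': type(m).__name__, 'content': m.content} for m in messages],
            'response': response.content,
        }
        with open(self.log_file_path, 'a', encoding='utf-8') as f:
            f.write(json.dumps(entry) + '\n')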
@@ -173,7 +180,7 @@ def ask_gpt(prompt_template, parser, **kwargs):
         if isinstance(value, dict):
             kwargs[key] = json.dumps(value, indent=4)
     prompt = prompt_template.format(**kwargs)
-    conversation = GPTSession().get_conversation(
+    conversation = GPTSession._instance.get_conversation(
         [],
         print_stream=os.environ['VERBOSE'].lower() == 'true',
         print_costs=False
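
Since __init__ now requires log_file_path and ask_gpt fetches the conversation through GPTSession._instance, the singleton has to be constructed once, with a log file path, before ask_gpt runs; otherwise _instance is still unset. A rough usage sketch (the path, template, and parser names are placeholders, not values from this commit):

# Illustrative only; the real wiring happens elsewhere in dev-gpt (e.g. the CLI entry point).
GPTSession('conversation_log.json', model='gpt-4')            # creates the singleton and its ConversationLogger
result = ask_gpt(my_prompt_template, my_parser, task='...')   # internally uses GPTSession._instance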