refactor: cleanup

This commit is contained in:
Florian Hönicke
2023-05-11 00:05:33 +02:00
parent bcb1804997
commit 69ebebf4a4
5 changed files with 42 additions and 39 deletions

View File

@@ -2,16 +2,15 @@ import json
import os
from copy import deepcopy
from time import sleep
from typing import List, Any
import openai
from langchain import PromptTemplate
from langchain.callbacks import CallbackManager
from langchain.chat_models import ChatOpenAI
from openai.error import RateLimitError
from langchain.schema import HumanMessage, SystemMessage, BaseMessage, AIMessage
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage, SystemMessage, BaseMessage, AIMessage
from openai.error import RateLimitError
from requests.exceptions import ConnectionError, ChunkedEncodingError
from urllib3.exceptions import InvalidChunkLength
@@ -32,6 +31,7 @@ If you have updated it already, please restart your terminal.
exit(1)
openai.api_key = os.environ['OPENAI_API_KEY']
class GPTSession:
_instance = None
_initialized = False
@@ -64,7 +64,6 @@ class GPTSession:
self.model_name, self.cost_callback, messages, print_stream, print_costs
)
@staticmethod
def is_gpt4_available():
try:
@@ -120,7 +119,6 @@ class _GPTConversation:
self.print_stream = print_stream
self.print_costs = print_costs
def print_messages(self, messages):
for i, message in enumerate(messages):
if os.environ['VERBOSE'].lower() == 'true':
@@ -167,8 +165,9 @@ class _GPTConversation:
def ask_gpt(prompt_template, parser, **kwargs):
template_parameters = get_template_parameters(prompt_template)
if set(template_parameters) != set(kwargs.keys()):
raise ValueError(f'Prompt template parameters {get_template_parameters(prompt_template)} do not match '
f'provided parameters {kwargs.keys()}')
raise ValueError(
f'Prompt template parameters {set(template_parameters)} do not match provided parameters {set(kwargs.keys())}'
)
for key, value in kwargs.items():
if isinstance(value, dict):
kwargs[key] = json.dumps(value, indent=4)