refactor: simplify gpt session

Author: Joschka Braun
Date:   2023-04-17 15:05:58 +02:00
Parent: deaea68f4f
Commit: 03de77b58e


@@ -7,12 +7,7 @@ import openai
 from langchain.callbacks import CallbackManager
 from langchain.chat_models import ChatOpenAI
 from openai.error import RateLimitError
-from langchain.schema import (
-    AIMessage,
-    HumanMessage,
-    SystemMessage,
-    BaseMessage
-)
+from langchain.schema import AIMessage, HumanMessage, SystemMessage, BaseMessage
 from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 from src.options.generate.prompt_system import system_message_base, executor_example, docarray_example, client_example
@@ -22,13 +17,7 @@ from src.utils.string_tools import print_colored
 class GPTSession:
     def __init__(self, model: str = 'gpt-4'):
         self.configure_openai_api_key()
-        if model == 'gpt-4' and self.is_gpt4_available():
-            self.supported_model = 'gpt-4'
-        else:
-            if model == 'gpt-4':
-                print_colored('GPT-4 is not available. Using GPT-3.5-turbo instead.', 'yellow')
-            model = 'gpt-3.5-turbo'
-            self.supported_model = model
+        self.model_name = 'gpt-4' if model == 'gpt-4' and self.is_gpt4_available() else 'gpt-3.5-turbo'
 
     @staticmethod
     def configure_openai_api_key():
@@ -58,10 +47,11 @@ If you have updated it already, please restart your terminal.
                 continue
             return True
         except openai.error.InvalidRequestError:
+            print_colored('GPT-4 is not available. Using GPT-3.5-turbo instead.', 'yellow')
             return False
 
     def get_conversation(self, system_definition_examples: List[str] = ['executor', 'docarray', 'client']):
-        return _GPTConversation(self.supported_model, system_definition_examples)
+        return _GPTConversation(self.model_name, system_definition_examples)
 
 
 class AssistantStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
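
For reference, a minimal standalone sketch of the model-selection fallback this commit condenses into a single conditional expression. The helpers below are simplified stand-ins, not the project's actual implementations: the real is_gpt4_available probes the OpenAI API and the real print_colored lives in src.utils.string_tools.

# Sketch of the fallback pattern introduced by this commit (stand-in helpers, hypothetical names).

def print_colored(text: str, color: str) -> None:
    # Stand-in for src.utils.string_tools.print_colored.
    print(f'[{color}] {text}')


def is_gpt4_available() -> bool:
    # Stand-in: the real check calls the API and returns False on InvalidRequestError,
    # printing the warning inside the except branch (as added in the last hunk).
    return False


def select_model(requested: str = 'gpt-4') -> str:
    gpt4_ok = requested == 'gpt-4' and is_gpt4_available()
    if requested == 'gpt-4' and not gpt4_ok:
        print_colored('GPT-4 is not available. Using GPT-3.5-turbo instead.', 'yellow')
    # The old if/else block collapses into this one conditional expression.
    return 'gpt-4' if gpt4_ok else 'gpt-3.5-turbo'


print(select_model('gpt-4'))  # falls back to 'gpt-3.5-turbo' when GPT-4 is unavailable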