mirror of
https://github.com/aljazceru/dev-gpt.git
synced 2025-12-21 15:44:19 +01:00
refactor: simplify gpt session
This commit is contained in:
@@ -7,12 +7,7 @@ import openai
|
|||||||
from langchain.callbacks import CallbackManager
|
from langchain.callbacks import CallbackManager
|
||||||
from langchain.chat_models import ChatOpenAI
|
from langchain.chat_models import ChatOpenAI
|
||||||
from openai.error import RateLimitError
|
from openai.error import RateLimitError
|
||||||
from langchain.schema import (
|
from langchain.schema import AIMessage, HumanMessage, SystemMessage, BaseMessage
|
||||||
AIMessage,
|
|
||||||
HumanMessage,
|
|
||||||
SystemMessage,
|
|
||||||
BaseMessage
|
|
||||||
)
|
|
||||||
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
||||||
|
|
||||||
from src.options.generate.prompt_system import system_message_base, executor_example, docarray_example, client_example
|
from src.options.generate.prompt_system import system_message_base, executor_example, docarray_example, client_example
|
||||||
@@ -22,13 +17,7 @@ from src.utils.string_tools import print_colored
|
|||||||
class GPTSession:
|
class GPTSession:
|
||||||
def __init__(self, model: str = 'gpt-4'):
|
def __init__(self, model: str = 'gpt-4'):
|
||||||
self.configure_openai_api_key()
|
self.configure_openai_api_key()
|
||||||
if model == 'gpt-4' and self.is_gpt4_available():
|
self.model_name = 'gpt-4' if model == 'gpt-4' and self.is_gpt4_available() else 'gpt-3.5-turbo'
|
||||||
self.supported_model = 'gpt-4'
|
|
||||||
else:
|
|
||||||
if model == 'gpt-4':
|
|
||||||
print_colored('GPT-4 is not available. Using GPT-3.5-turbo instead.', 'yellow')
|
|
||||||
model = 'gpt-3.5-turbo'
|
|
||||||
self.supported_model = model
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def configure_openai_api_key():
|
def configure_openai_api_key():
|
||||||
@@ -58,10 +47,11 @@ If you have updated it already, please restart your terminal.
|
|||||||
continue
|
continue
|
||||||
return True
|
return True
|
||||||
except openai.error.InvalidRequestError:
|
except openai.error.InvalidRequestError:
|
||||||
|
print_colored('GPT-4 is not available. Using GPT-3.5-turbo instead.', 'yellow')
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def get_conversation(self, system_definition_examples: List[str] = ['executor', 'docarray', 'client']):
|
def get_conversation(self, system_definition_examples: List[str] = ['executor', 'docarray', 'client']):
|
||||||
return _GPTConversation(self.supported_model, system_definition_examples)
|
return _GPTConversation(self.model_name, system_definition_examples)
|
||||||
|
|
||||||
|
|
||||||
class AssistantStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
|
class AssistantStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
|
||||||
|
|||||||
Reference in New Issue
Block a user