Merge branch 'main' of https://github.com/jina-ai/gptdeploy into refactor-langchain

# Conflicts:
#	src/apis/gpt.py
#	src/cli.py
#	src/options/generate/generator.py
#	src/options/generate/prompt_system.py
#	src/options/generate/prompt_tasks.py
This commit is contained in:
Joschka Braun
2023-04-18 12:17:52 +02:00
15 changed files with 112 additions and 85 deletions

View File

@@ -15,7 +15,9 @@ from src.utils.string_tools import print_colored
class GPTSession:
def __init__(self, model: str = 'gpt-4'):
def __init__(self, task_description, test_description, model: str = 'gpt-4', ):
self.task_description = task_description
self.test_description = test_description
self.configure_openai_api_key()
self.model_name = 'gpt-4' if model == 'gpt-4' and self.is_gpt4_available() else 'gpt-3.5-turbo'
@@ -51,7 +53,7 @@ If you have updated it already, please restart your terminal.
return False
def get_conversation(self, system_definition_examples: List[str] = ['executor', 'docarray', 'client']):
return _GPTConversation(self.model_name, system_definition_examples)
return _GPTConversation(self.model_name, self.task_description, self.test_description, system_definition_examples)
class AssistantStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
@@ -62,7 +64,7 @@ class AssistantStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
class _GPTConversation:
def __init__(self, model: str, system_definition_examples: List[str] = ['executor', 'docarray', 'client']):
def __init__(self, model: str, task_description, test_description, system_definition_examples: List[str] = ['executor', 'docarray', 'client']):
self.chat = ChatOpenAI(
model_name=model,
streaming=True,
@@ -85,7 +87,7 @@ class _GPTConversation:
return response
@staticmethod
def _create_system_message(system_definition_examples: List[str] = []) -> SystemMessage:
def _create_system_message(task_description, test_description, system_definition_examples: List[str] = []) -> SystemMessage:
system_message = system_message_base
if 'executor' in system_definition_examples:
system_message += f'\n{executor_example}'
@@ -93,4 +95,5 @@ class _GPTConversation:
system_message += f'\n{docarray_example}'
if 'client' in system_definition_examples:
system_message += f'\n{client_example}'
# create from template
return SystemMessage(content=system_message)