mirror of
https://github.com/aljazceru/dev-gpt.git
synced 2025-12-21 23:54:19 +01:00
Merge branch 'main' of https://github.com/jina-ai/gptdeploy into refactor-langchain
# Conflicts: # src/apis/gpt.py # src/cli.py # src/options/generate/generator.py # src/options/generate/prompt_system.py # src/options/generate/prompt_tasks.py
This commit is contained in:
@@ -15,7 +15,9 @@ from src.utils.string_tools import print_colored
|
||||
|
||||
|
||||
class GPTSession:
|
||||
def __init__(self, model: str = 'gpt-4'):
|
||||
def __init__(self, task_description, test_description, model: str = 'gpt-4', ):
|
||||
self.task_description = task_description
|
||||
self.test_description = test_description
|
||||
self.configure_openai_api_key()
|
||||
self.model_name = 'gpt-4' if model == 'gpt-4' and self.is_gpt4_available() else 'gpt-3.5-turbo'
|
||||
|
||||
@@ -51,7 +53,7 @@ If you have updated it already, please restart your terminal.
|
||||
return False
|
||||
|
||||
def get_conversation(self, system_definition_examples: List[str] = ['executor', 'docarray', 'client']):
|
||||
return _GPTConversation(self.model_name, system_definition_examples)
|
||||
return _GPTConversation(self.model_name, self.task_description, self.test_description, system_definition_examples)
|
||||
|
||||
|
||||
class AssistantStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
|
||||
@@ -62,7 +64,7 @@ class AssistantStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
|
||||
|
||||
|
||||
class _GPTConversation:
|
||||
def __init__(self, model: str, system_definition_examples: List[str] = ['executor', 'docarray', 'client']):
|
||||
def __init__(self, model: str, task_description, test_description, system_definition_examples: List[str] = ['executor', 'docarray', 'client']):
|
||||
self.chat = ChatOpenAI(
|
||||
model_name=model,
|
||||
streaming=True,
|
||||
@@ -85,7 +87,7 @@ class _GPTConversation:
|
||||
return response
|
||||
|
||||
@staticmethod
|
||||
def _create_system_message(system_definition_examples: List[str] = []) -> SystemMessage:
|
||||
def _create_system_message(task_description, test_description, system_definition_examples: List[str] = []) -> SystemMessage:
|
||||
system_message = system_message_base
|
||||
if 'executor' in system_definition_examples:
|
||||
system_message += f'\n{executor_example}'
|
||||
@@ -93,4 +95,5 @@ class _GPTConversation:
|
||||
system_message += f'\n{docarray_example}'
|
||||
if 'client' in system_definition_examples:
|
||||
system_message += f'\n{client_example}'
|
||||
# create from template
|
||||
return SystemMessage(content=system_message)
|
||||
|
||||
Reference in New Issue
Block a user