fix: multiple things and 3d

This commit is contained in:
Florian Hönicke
2023-04-18 02:25:28 +02:00
parent badf295f71
commit 7dcd690245
12 changed files with 176 additions and 129 deletions

View File

@@ -14,7 +14,9 @@ from src.utils.string_tools import print_colored
class GPTSession:
def __init__(self, model: str = 'gpt-4'):
def __init__(self, task_description, test_description, model: str = 'gpt-4', ):
self.task_description = task_description
self.test_description = test_description
self.configure_openai_api_key()
if model == 'gpt-4' and self.is_gpt4_available():
self.supported_model = 'gpt-4'
@@ -68,18 +70,18 @@ If you have updated it already, please restart your terminal.
print('\n')
def get_conversation(self, system_definition_examples: List[str] = ['executor', 'docarray', 'client']):
return _GPTConversation(self.supported_model, self.cost_callback, system_definition_examples)
return _GPTConversation(self.supported_model, self.cost_callback, self.task_description, self.test_description, system_definition_examples)
def calculate_money_spent(self, num_chars, price):
return round(num_chars / CHARS_PER_TOKEN * price / 1000, 3)
class _GPTConversation:
def __init__(self, model: str, cost_callback, system_definition_examples: List[str] = ['executor', 'docarray', 'client']):
def __init__(self, model: str, cost_callback, task_description, test_description, system_definition_examples: List[str] = ['executor', 'docarray', 'client']):
self.model = model
self.cost_callback = cost_callback
self.prompt_list: List[Optional[Tuple]] = [None]
self.set_system_definition(system_definition_examples)
self.set_system_definition(task_description, test_description, system_definition_examples)
if os.environ['VERBOSE'].lower() == 'true':
print_colored('system', self.prompt_list[0][1], 'magenta')
@@ -91,8 +93,8 @@ class _GPTConversation:
self.prompt_list.append(('assistant', response))
return response
def set_system_definition(self, system_definition_examples: List[str] = []):
system_message = system_base_definition
def set_system_definition(self, task_description, test_description, system_definition_examples: List[str] = []):
system_message = system_base_definition(task_description, test_description)
if 'executor' in system_definition_examples:
system_message += f'\n{executor_example}'
if 'docarray' in system_definition_examples: