diff --git a/src/apis/gpt.py b/src/apis/gpt.py index 0122af5..a049db8 100644 --- a/src/apis/gpt.py +++ b/src/apis/gpt.py @@ -14,7 +14,8 @@ from requests.exceptions import ConnectionError from src.constants import PRICING_GPT4_PROMPT, PRICING_GPT4_GENERATION, PRICING_GPT3_5_TURBO_PROMPT, \ PRICING_GPT3_5_TURBO_GENERATION, CHARS_PER_TOKEN -from src.options.generate.templates_system import template_system_message_base, executor_example, docarray_example, client_example +from src.options.generate.templates_system import template_system_message_base, executor_example, docarray_example, \ + client_example, gpt_example from src.utils.string_tools import print_colored @@ -36,7 +37,7 @@ class GPTSession: self.chars_prompt_so_far = 0 self.chars_generation_so_far = 0 - def get_conversation(self, system_definition_examples: List[str] = ['executor', 'docarray', 'client']): + def get_conversation(self, system_definition_examples: List[str] = ['gpt', 'executor', 'docarray', 'client']): return _GPTConversation( self.model_name, self.cost_callback, self.task_description, self.test_description, system_definition_examples ) @@ -134,6 +135,8 @@ class _GPTConversation: task_description=task_description, test_description=test_description, ) + if 'gpt' in system_definition_examples: + system_message += f'\n{gpt_example}' if 'executor' in system_definition_examples: system_message += f'\n{executor_example}' if 'docarray' in system_definition_examples: