feat: support gpt turbo

This commit is contained in:
Joschka Braun
2023-04-11 15:06:47 +02:00
parent 4efb7333ef
commit db9ae28828
3 changed files with 22 additions and 10 deletions

View File

@@ -16,13 +16,15 @@ def main():
@click.option('--num_approaches', default=3, type=int,
help='Number of num_approaches to use to fulfill the task (default: 3).')
@click.option('--output_path', default='executor', help='Path to the output folder (must be empty). ')
@click.option('--model', default='gpt-4', help='GPT model to use (default: gpt-4).')
def create(
description,
test,
num_approaches=3,
output_path='executor',
model='gpt-4'
):
executor_factory = ExecutorFactory()
executor_factory = ExecutorFactory(model=model)
executor_factory.create(description, num_approaches, output_path, test)

View File

@@ -12,8 +12,8 @@ from src.utils.string_tools import print_colored
class ExecutorFactory:
def __init__(self):
self.gpt_session = gpt.GPTSession()
def __init__(self, model='gpt-4'):
self.gpt_session = gpt.GPTSession(model=model)
def extract_content_from_result(self, plain_text, file_name):
pattern = fr"^\*\*{file_name}\*\*\n```(?:\w+\n)?([\s\S]*?)```"
@@ -21,7 +21,13 @@ class ExecutorFactory:
if match:
return match.group(1).strip()
else:
return ''
# Check for a single code block
single_code_block_pattern = r"^```(?:\w+\n)?([\s\S]*?)```"
single_code_block_match = re.findall(single_code_block_pattern, plain_text, re.MULTILINE)
if len(single_code_block_match) == 1:
return single_code_block_match[0].strip()
else:
return ''
def write_config_yml(self, executor_name, dest_folder):
config_content = f'''

View File

@@ -11,15 +11,19 @@ from src.prompt_system import system_base_definition
from src.utils.io import timeout_generator_wrapper, GenerationTimeoutError
from src.utils.string_tools import print_colored
class GPTSession:
def __init__(self):
def __init__(self, model: str = 'gpt-4'):
self.get_openai_api_key()
if self.is_gpt4_available():
if model == 'gpt-4' and self.is_gpt4_available():
self.supported_model = 'gpt-4'
self.pricing_prompt = PRICING_GPT4_PROMPT
self.pricing_generation = PRICING_GPT4_GENERATION
else:
self.supported_model = 'gpt-3.5-turbo'
elif (model == 'gpt-4' and not self.is_gpt4_available()) or model == 'gpt-3.5-turbo':
if model == 'gpt-4':
print_colored('GPT-4 is not available. Using GPT-3.5-turbo instead.', 'yellow')
model = 'gpt-3.5-turbo'
self.supported_model = model
self.pricing_prompt = PRICING_GPT3_5_TURBO_PROMPT
self.pricing_generation = PRICING_GPT3_5_TURBO_GENERATION
self.chars_prompt_so_far = 0
@@ -52,8 +56,8 @@ class GPTSession:
self.chars_prompt_so_far += chars_prompt
self.chars_generation_so_far += chars_generation
print('\n')
money_prompt = round(self.chars_prompt_so_far / 3.4 * self.pricing_prompt / 1000, 2)
money_generation = round(self.chars_generation_so_far / 3.4 * self.pricing_generation / 1000, 2)
money_prompt = round(self.chars_prompt_so_far / 3.4 * self.pricing_prompt / 1000, 3)
money_generation = round(self.chars_generation_so_far / 3.4 * self.pricing_generation / 1000, 3)
print('Estimated costs on openai.com:')
# print('money prompt:', f'${money_prompt}')
# print('money generation:', f'${money_generation}')