feat: support gpt turbo

This commit is contained in:
Joschka Braun
2023-04-11 15:06:47 +02:00
parent 4efb7333ef
commit db9ae28828
3 changed files with 22 additions and 10 deletions

View File

@@ -16,13 +16,15 @@ def main():
 @click.option('--num_approaches', default=3, type=int,
               help='Number of num_approaches to use to fulfill the task (default: 3).')
 @click.option('--output_path', default='executor', help='Path to the output folder (must be empty). ')
+@click.option('--model', default='gpt-4', help='GPT model to use (default: gpt-4).')
 def create(
         description,
         test,
         num_approaches=3,
         output_path='executor',
+        model='gpt-4'
 ):
-    executor_factory = ExecutorFactory()
+    executor_factory = ExecutorFactory(model=model)
     executor_factory.create(description, num_approaches, output_path, test)

View File

@@ -12,14 +12,20 @@ from src.utils.string_tools import print_colored
 class ExecutorFactory:
-    def __init__(self):
-        self.gpt_session = gpt.GPTSession()
+    def __init__(self, model='gpt-4'):
+        self.gpt_session = gpt.GPTSession(model=model)
     def extract_content_from_result(self, plain_text, file_name):
         pattern = fr"^\*\*{file_name}\*\*\n```(?:\w+\n)?([\s\S]*?)```"
         match = re.search(pattern, plain_text, re.MULTILINE)
         if match:
             return match.group(1).strip()
+        else:
+            # Check for a single code block
+            single_code_block_pattern = r"^```(?:\w+\n)?([\s\S]*?)```"
+            single_code_block_match = re.findall(single_code_block_pattern, plain_text, re.MULTILINE)
+            if len(single_code_block_match) == 1:
+                return single_code_block_match[0].strip()
             else:
                 return ''

View File

@@ -11,15 +11,19 @@ from src.prompt_system import system_base_definition
 from src.utils.io import timeout_generator_wrapper, GenerationTimeoutError
 from src.utils.string_tools import print_colored
 class GPTSession:
-    def __init__(self):
+    def __init__(self, model: str = 'gpt-4'):
         self.get_openai_api_key()
-        if self.is_gpt4_available():
+        if model == 'gpt-4' and self.is_gpt4_available():
             self.supported_model = 'gpt-4'
             self.pricing_prompt = PRICING_GPT4_PROMPT
             self.pricing_generation = PRICING_GPT4_GENERATION
-        else:
-            self.supported_model = 'gpt-3.5-turbo'
+        elif (model == 'gpt-4' and not self.is_gpt4_available()) or model == 'gpt-3.5-turbo':
+            if model == 'gpt-4':
+                print_colored('GPT-4 is not available. Using GPT-3.5-turbo instead.', 'yellow')
+                model = 'gpt-3.5-turbo'
+            self.supported_model = model
             self.pricing_prompt = PRICING_GPT3_5_TURBO_PROMPT
             self.pricing_generation = PRICING_GPT3_5_TURBO_GENERATION
         self.chars_prompt_so_far = 0
@@ -52,8 +56,8 @@ class GPTSession:
         self.chars_prompt_so_far += chars_prompt
         self.chars_generation_so_far += chars_generation
         print('\n')
-        money_prompt = round(self.chars_prompt_so_far / 3.4 * self.pricing_prompt / 1000, 2)
-        money_generation = round(self.chars_generation_so_far / 3.4 * self.pricing_generation / 1000, 2)
+        money_prompt = round(self.chars_prompt_so_far / 3.4 * self.pricing_prompt / 1000, 3)
+        money_generation = round(self.chars_generation_so_far / 3.4 * self.pricing_generation / 1000, 3)
         print('Estimated costs on openai.com:')
         # print('money prompt:', f'${money_prompt}')
         # print('money generation:', f'${money_generation}')