diff --git a/src/cli.py b/src/cli.py index 2c0bc37..2d7cc93 100644 --- a/src/cli.py +++ b/src/cli.py @@ -16,13 +16,15 @@ def main(): @click.option('--num_approaches', default=3, type=int, help='Number of num_approaches to use to fulfill the task (default: 3).') @click.option('--output_path', default='executor', help='Path to the output folder (must be empty). ') +@click.option('--model', default='gpt-4', help='GPT model to use (default: gpt-4).') def create( description, test, num_approaches=3, output_path='executor', + model='gpt-4' ): - executor_factory = ExecutorFactory() + executor_factory = ExecutorFactory(model=model) executor_factory.create(description, num_approaches, output_path, test) diff --git a/src/executor_factory.py b/src/executor_factory.py index c7b0a41..ca52766 100644 --- a/src/executor_factory.py +++ b/src/executor_factory.py @@ -12,8 +12,8 @@ from src.utils.string_tools import print_colored class ExecutorFactory: - def __init__(self): - self.gpt_session = gpt.GPTSession() + def __init__(self, model='gpt-4'): + self.gpt_session = gpt.GPTSession(model=model) def extract_content_from_result(self, plain_text, file_name): pattern = fr"^\*\*{file_name}\*\*\n```(?:\w+\n)?([\s\S]*?)```" @@ -21,7 +21,13 @@ class ExecutorFactory: if match: return match.group(1).strip() else: - return '' + # Check for a single code block + single_code_block_pattern = r"^```(?:\w+\n)?([\s\S]*?)```" + single_code_block_match = re.findall(single_code_block_pattern, plain_text, re.MULTILINE) + if len(single_code_block_match) == 1: + return single_code_block_match[0].strip() + else: + return '' def write_config_yml(self, executor_name, dest_folder): config_content = f''' diff --git a/src/gpt.py b/src/gpt.py index 7af6da8..c3d0fff 100644 --- a/src/gpt.py +++ b/src/gpt.py @@ -11,15 +11,19 @@ from src.prompt_system import system_base_definition from src.utils.io import timeout_generator_wrapper, GenerationTimeoutError from src.utils.string_tools import print_colored + class GPTSession: - def __init__(self): + def __init__(self, model: str = 'gpt-4'): self.get_openai_api_key() - if self.is_gpt4_available(): + if model == 'gpt-4' and self.is_gpt4_available(): self.supported_model = 'gpt-4' self.pricing_prompt = PRICING_GPT4_PROMPT self.pricing_generation = PRICING_GPT4_GENERATION - else: - self.supported_model = 'gpt-3.5-turbo' + elif (model == 'gpt-4' and not self.is_gpt4_available()) or model == 'gpt-3.5-turbo': + if model == 'gpt-4': + print_colored('GPT-4 is not available. Using GPT-3.5-turbo instead.', 'yellow') + model = 'gpt-3.5-turbo' + self.supported_model = model self.pricing_prompt = PRICING_GPT3_5_TURBO_PROMPT self.pricing_generation = PRICING_GPT3_5_TURBO_GENERATION self.chars_prompt_so_far = 0 @@ -52,8 +56,8 @@ class GPTSession: self.chars_prompt_so_far += chars_prompt self.chars_generation_so_far += chars_generation print('\n') - money_prompt = round(self.chars_prompt_so_far / 3.4 * self.pricing_prompt / 1000, 2) - money_generation = round(self.chars_generation_so_far / 3.4 * self.pricing_generation / 1000, 2) + money_prompt = round(self.chars_prompt_so_far / 3.4 * self.pricing_prompt / 1000, 3) + money_generation = round(self.chars_generation_so_far / 3.4 * self.pricing_generation / 1000, 3) print('Estimated costs on openai.com:') # print('money prompt:', f'${money_prompt}') # print('money generation:', f'${money_generation}')