diff --git a/dev_gpt/constants.py b/dev_gpt/constants.py index 4326065..20181c8 100644 --- a/dev_gpt/constants.py +++ b/dev_gpt/constants.py @@ -26,6 +26,10 @@ FILE_AND_TAG_PAIRS = [ (STREAMLIT_FILE_NAME, STREAMLIT_FILE_TAG) ] +INDICATOR_TO_IMPORT_STATEMENT = { + 'BytesIO': 'from io import BytesIO', +} + FLOW_URL_PLACEHOLDER = 'jcloud.jina.ai' PRICING_GPT4_PROMPT = 0.03 diff --git a/dev_gpt/options/generate/generator.py b/dev_gpt/options/generate/generator.py index fd0ab99..52e0dc0 100644 --- a/dev_gpt/options/generate/generator.py +++ b/dev_gpt/options/generate/generator.py @@ -17,7 +17,8 @@ from dev_gpt.apis.pypi import is_package_on_pypi, clean_requirements_txt from dev_gpt.constants import FILE_AND_TAG_PAIRS, NUM_IMPLEMENTATION_STRATEGIES, MAX_DEBUGGING_ITERATIONS, \ BLACKLISTED_PACKAGES, EXECUTOR_FILE_NAME, TEST_EXECUTOR_FILE_NAME, TEST_EXECUTOR_FILE_TAG, \ REQUIREMENTS_FILE_NAME, REQUIREMENTS_FILE_TAG, DOCKER_FILE_NAME, IMPLEMENTATION_FILE_NAME, \ - IMPLEMENTATION_FILE_TAG, LANGUAGE_PACKAGES, UNNECESSARY_PACKAGES, DOCKER_BASE_IMAGE_VERSION, SEARCH_PACKAGES + IMPLEMENTATION_FILE_TAG, LANGUAGE_PACKAGES, UNNECESSARY_PACKAGES, DOCKER_BASE_IMAGE_VERSION, SEARCH_PACKAGES, \ + INDICATOR_TO_IMPORT_STATEMENT from dev_gpt.options.generate.pm.pm import PM from dev_gpt.options.generate.templates_user import template_generate_microservice_name, \ template_generate_possible_packages, \ @@ -103,6 +104,7 @@ metas: parse_result_fn: Callable = None, use_custom_system_message: bool = True, response_format_example: str = None, + post_process_fn: Callable = None, **template_kwargs ): """This function generates file(s) using the given template and persists it/them in the given destination folder. @@ -146,6 +148,8 @@ metas: ) ) content = parse_result_fn(content_raw) + if post_process_fn is not None: + content = post_process_fn(content) if content == {}: conversation = self.gpt_session.get_conversation( messages=[SystemMessage(content='You are a helpful assistant.'), AIMessage(content=content_raw)] @@ -209,6 +213,7 @@ metas: file_name_purpose=IMPLEMENTATION_FILE_NAME, tag_name=IMPLEMENTATION_FILE_TAG, file_name_s=[IMPLEMENTATION_FILE_NAME], + post_process_fn=self.add_missing_imports_post_process_fn, )[IMPLEMENTATION_FILE_NAME] test_microservice_content = self.generate_and_persist_file( @@ -221,6 +226,7 @@ metas: file_name_purpose=TEST_EXECUTOR_FILE_NAME, tag_name=TEST_EXECUTOR_FILE_TAG, file_name_s=[TEST_EXECUTOR_FILE_NAME], + post_process_fn=self.add_missing_imports_post_process_fn, )[TEST_EXECUTOR_FILE_NAME] self.generate_and_persist_file( @@ -250,6 +256,13 @@ metas: print('\nFirst version of the microservice generated. Start iterating on it to make the tests pass...') + + def add_missing_imports_post_process_fn(self, content_raw: str): + for indicator, import_statement in INDICATOR_TO_IMPORT_STATEMENT.items(): + if indicator in content_raw and import_statement not in content_raw: + content_raw = f'{import_statement}\n{content_raw}' + + @staticmethod def read_docker_template(): with open(os.path.join(os.path.dirname(__file__), 'static_files', 'microservice', 'Dockerfile'), 'r', encoding='utf-8') as f: