diff --git a/requirements.txt b/requirements.txt
index bc0630d..b8c59f4 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,4 +4,5 @@ openai
 psutil
 jina
 jcloud
-jina-hubble-sdk
\ No newline at end of file
+jina-hubble-sdk
+langchain
\ No newline at end of file
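Review note: `langchain` is added because `src/apis/gpt.py` (below) moves from raw `openai.ChatCompletion` calls to langchain's chat-model wrapper. A minimal sketch of the streaming setup this PR relies on, mirroring the imports used below (assumes `OPENAI_API_KEY` is set and an early-2023 langchain release that still exposes `langchain.callbacks.CallbackManager`):

```python
# Sketch only: the same streaming wiring src/apis/gpt.py sets up below.
from langchain.callbacks import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage

chat = ChatOpenAI(
    model_name='gpt-3.5-turbo',
    streaming=True,  # tokens are pushed to the callback handler as they arrive
    callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]),
    temperature=0,
)
response = chat([HumanMessage(content='Say hello.')])  # blocks until done; returns an AIMessage
print(response.content)
```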
diff --git a/src/apis/gpt.py b/src/apis/gpt.py
index d27983f..eb0a582 100644
--- a/src/apis/gpt.py
+++ b/src/apis/gpt.py
@@ -1,15 +1,21 @@
 import os
 from time import sleep
-from typing import List, Tuple, Optional
+from typing import List, Any
 
 import openai
-from openai.error import RateLimitError, Timeout, APIConnectionError
+from langchain.callbacks import CallbackManager
+from langchain.chat_models import ChatOpenAI
+from openai.error import RateLimitError
+from langchain.schema import (
+    AIMessage,
+    HumanMessage,
+    SystemMessage,
+    BaseMessage
+)
+from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 
 
-from src.constants import PRICING_GPT4_PROMPT, PRICING_GPT4_GENERATION, PRICING_GPT3_5_TURBO_PROMPT, \
-    PRICING_GPT3_5_TURBO_GENERATION, CHARS_PER_TOKEN
-from src.options.generate.prompt_system import system_base_definition, executor_example, docarray_example, client_example
-from src.utils.io import timeout_generator_wrapper, GenerationTimeoutError
+from src.options.generate.prompt_system import system_message_base, executor_example, docarray_example, client_example
 from src.utils.string_tools import print_colored
 
 
@@ -18,19 +24,14 @@ class GPTSession:
         self.configure_openai_api_key()
         if model == 'gpt-4' and self.is_gpt4_available():
             self.supported_model = 'gpt-4'
-            self.pricing_prompt = PRICING_GPT4_PROMPT
-            self.pricing_generation = PRICING_GPT4_GENERATION
         else:
             if model == 'gpt-4':
                 print_colored('GPT-4 is not available. Using GPT-3.5-turbo instead.', 'yellow')
             model = 'gpt-3.5-turbo'
             self.supported_model = model
-            self.pricing_prompt = PRICING_GPT3_5_TURBO_PROMPT
-            self.pricing_generation = PRICING_GPT3_5_TURBO_GENERATION
-        self.chars_prompt_so_far = 0
-        self.chars_generation_so_far = 0
 
-    def configure_openai_api_key(self):
+    @staticmethod
+    def configure_openai_api_key():
         if 'OPENAI_API_KEY' not in os.environ:
             raise Exception('''
You need to set OPENAI_API_KEY in your environment.
@@ -39,7 +40,8 @@ If you have updated it already, please restart your terminal.
             )
         openai.api_key = os.environ['OPENAI_API_KEY']
 
-    def is_gpt4_available(self):
+    @staticmethod
+    def is_gpt4_available():
         try:
             for i in range(5):
                 try:
@@ -58,87 +60,47 @@ If you have updated it already, please restart your terminal.
             except openai.error.InvalidRequestError:
                 return False
 
-    def cost_callback(self, chars_prompt, chars_generation):
-        self.chars_prompt_so_far += chars_prompt
-        self.chars_generation_so_far += chars_generation
-        print('\n')
-        money_prompt = self.calculate_money_spent(self.chars_prompt_so_far, self.pricing_prompt)
-        money_generation = self.calculate_money_spent(self.chars_generation_so_far, self.pricing_generation)
-        print('Total money spent so far on openai.com:', f'${money_prompt + money_generation}')
-        print('\n')
-
     def get_conversation(self, system_definition_examples: List[str] = ['executor', 'docarray', 'client']):
-        return _GPTConversation(self.supported_model, self.cost_callback, system_definition_examples)
+        return _GPTConversation(self.supported_model, system_definition_examples)
 
-    def calculate_money_spent(self, num_chars, price):
-        return round(num_chars / CHARS_PER_TOKEN * price / 1000, 3)
+
+class AssistantStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
+    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
+        """Run on new LLM token. Only available when streaming is enabled."""
+        if os.environ['VERBOSE'].lower() == 'true':
+            print_colored('', token, 'green', end='')
 
 
 class _GPTConversation:
-    def __init__(self, model: str, cost_callback, system_definition_examples: List[str] = ['executor', 'docarray', 'client']):
-        self.model = model
-        self.cost_callback = cost_callback
-        self.prompt_list: List[Optional[Tuple]] = [None]
-        self.set_system_definition(system_definition_examples)
+    def __init__(self, model: str, system_definition_examples: List[str] = ['executor', 'docarray', 'client']):
+        self.chat_model = ChatOpenAI(
+            model_name=model,
+            streaming=True,
+            callback_manager=CallbackManager([AssistantStreamingStdOutCallbackHandler()]),
+            temperature=0
+        )
+        self.messages: List[BaseMessage] = []
+        self.system_message = self._create_system_message(system_definition_examples)
         if os.environ['VERBOSE'].lower() == 'true':
-            print_colored('system', self.prompt_list[0][1], 'magenta')
+            print_colored('system', self.system_message.content, 'magenta')
 
-    def query(self, prompt: str):
+    def chat(self, prompt: str):
+        chat_message = HumanMessage(content=prompt)
+        self.messages.append(chat_message)
         if os.environ['VERBOSE'].lower() == 'true':
             print_colored('user', prompt, 'blue')
-        self.prompt_list.append(('user', prompt))
-        response = self.get_response(self.prompt_list)
-        self.prompt_list.append(('assistant', response))
+            print_colored('assistant', '', 'green', end='')
+        response = self.chat_model([self.system_message] + self.messages).content
+        self.messages.append(AIMessage(content=response))
         return response
 
-    def set_system_definition(self, system_definition_examples: List[str] = []):
-        system_message = system_base_definition
+    @staticmethod
+    def _create_system_message(system_definition_examples: List[str] = []) -> SystemMessage:
+        system_message = system_message_base
         if 'executor' in system_definition_examples:
             system_message += f'\n{executor_example}'
         if 'docarray' in system_definition_examples:
             system_message += f'\n{docarray_example}'
         if 'client' in system_definition_examples:
             system_message += f'\n{client_example}'
-        self.prompt_list[0] = ('system', system_message)
-
-    def get_response_from_stream(self, response_generator):
-        response_generator_with_timeout = timeout_generator_wrapper(response_generator, 10)
-        complete_string = ''
-        for chunk in response_generator_with_timeout:
-            delta = chunk['choices'][0]['delta']
-            if 'content' in delta:
-                content = delta['content']
-                print_colored('' if complete_string else 'assistant', content, 'green', end='')
-                complete_string += content
-        return complete_string
-
-    def get_response(self, prompt_list: List[Tuple[str, str]]):
-        for i in range(10):
-            try:
-                response_generator = openai.ChatCompletion.create(
-                    temperature=0,
-                    max_tokens=None,
-                    model=self.model,
-                    stream=True,
-                    messages=[
-                        {
-                            "role": prompt[0],
-                            "content": prompt[1]
-                        }
-                        for prompt in prompt_list
-                    ]
-                )
-
-                complete_string = self.get_response_from_stream(response_generator)
-
-            except (RateLimitError, Timeout, ConnectionError, APIConnectionError, GenerationTimeoutError) as e:
-                print('/n', e)
-                print('retrying...')
-                sleep(3)
-                continue
-            chars_prompt = sum(len(prompt[1]) for prompt in prompt_list)
-            chars_generation = len(complete_string)
-            self.cost_callback(chars_prompt, chars_generation)
-            return complete_string
-        raise Exception('Failed to get response')
-
+        return SystemMessage(content=system_message)
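Review note: with the refactor above, a round trip through the new API looks roughly like this (a usage sketch, not a test; it assumes `OPENAI_API_KEY` is set and relies on the `VERBOSE` gating implemented in `GPTSession` and the callback handler):

```python
# Hedged usage sketch of the refactored session/conversation API from src/apis/gpt.py.
import os
from src.apis.gpt import GPTSession

os.environ['VERBOSE'] = 'true'  # makes the callback handler stream assistant tokens in green
session = GPTSession(model='gpt-4')  # falls back to gpt-3.5-turbo if gpt-4 is unavailable
conversation = session.get_conversation(['executor'])  # system message includes only the executor example
answer = conversation.chat('Name one best practice for writing a Jina Executor.')  # returns the assistant text
```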
diff --git a/src/cli.py b/src/cli.py
index db24889..11f9ac4 100644
--- a/src/cli.py
+++ b/src/cli.py
@@ -2,6 +2,7 @@ import functools
 import os
 
 import click
+from langchain.callbacks import get_openai_callback
 
 from src.apis.jina_cloud import jina_auth_login
 from src.options.configure.key_handling import set_api_key
@@ -64,7 +65,11 @@ def generate(
     from src.options.generate.generator import Generator
     generator = Generator(model=model)
-    generator.generate(description, test, path)
+    with get_openai_callback() as cb:
+        generator.generate(description, test, path)
+        print(f"Prompt/Completion/Total Tokens: {cb.prompt_tokens}/{cb.completion_tokens}/{cb.total_tokens}")
+        print(f"Total Cost on OpenAI (USD): ${cb.total_cost}")
+
 
 @main.command()
 @path_param
diff --git a/src/constants.py b/src/constants.py
index 6eef6ae..1bd9746 100644
--- a/src/constants.py
+++ b/src/constants.py
@@ -23,12 +23,6 @@ FILE_AND_TAG_PAIRS = [
 
 FLOW_URL_PLACEHOLDER = 'jcloud.jina.ai'
 
-PRICING_GPT4_PROMPT = 0.03
-PRICING_GPT4_GENERATION = 0.06
-PRICING_GPT3_5_TURBO_PROMPT = 0.002
-PRICING_GPT3_5_TURBO_GENERATION = 0.002
-
-CHARS_PER_TOKEN = 3.4
 
 NUM_IMPLEMENTATION_STRATEGIES = 5
 MAX_DEBUGGING_ITERATIONS = 10
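Review note: the hand-rolled pricing table and the `CHARS_PER_TOKEN` estimate are removed because `get_openai_callback` reports exact token counts and cost from the API responses. A standalone sketch of how the callback aggregates usage (non-streaming here; depending on the langchain version, streamed completions may report zero usage):

```python
# Sketch of the token/cost accounting that replaces the removed constants.
from langchain.callbacks import get_openai_callback
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage

chat = ChatOpenAI(model_name='gpt-3.5-turbo', temperature=0)
with get_openai_callback() as cb:
    chat([HumanMessage(content='ping')])  # usage from every call inside the block is summed
    chat([HumanMessage(content='pong')])
print(f'Prompt/Completion/Total Tokens: {cb.prompt_tokens}/{cb.completion_tokens}/{cb.total_tokens}')
print(f'Total Cost on OpenAI (USD): ${cb.total_cost}')
```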
diff --git a/src/options/generate/generator.py b/src/options/generate/generator.py
index 875b417..21cb508 100644
--- a/src/options/generate/generator.py
+++ b/src/options/generate/generator.py
@@ -74,15 +74,15 @@ class Generator:
             + executor_file_task(microservice_name, description, test, package)
         )
         conversation = self.gpt_session.get_conversation()
-        microservice_content_raw = conversation.query(user_query)
+        microservice_content_raw = conversation.chat(user_query)
         if is_chain_of_thought:
-            microservice_content_raw = conversation.query(
+            microservice_content_raw = conversation.chat(
                 f"General rules: " + not_allowed_executor() + chain_of_thought_optimization('python', 'microservice.py'))
             microservice_content = self.extract_content_from_result(microservice_content_raw, 'microservice.py',
                                                                     match_single_block=True)
         if microservice_content == '':
-            microservice_content_raw = conversation.query('You must add the executor code.')
+            microservice_content_raw = conversation.chat('You must add the executor code.')
             microservice_content = self.extract_content_from_result(
                 microservice_content_raw, 'microservice.py', match_single_block=True
             )
 
@@ -95,9 +95,9 @@ class Generator:
             + test_executor_file_task(microservice_name, test)
         )
         conversation = self.gpt_session.get_conversation()
-        test_microservice_content_raw = conversation.query(user_query)
+        test_microservice_content_raw = conversation.chat(user_query)
         if is_chain_of_thought:
-            test_microservice_content_raw = conversation.query(
+            test_microservice_content_raw = conversation.chat(
                 f"General rules: " + not_allowed_executor()
                 + chain_of_thought_optimization('python', 'test_microservice.py')
                 + "Don't add any additional tests. "
@@ -116,9 +116,9 @@ class Generator:
             + requirements_file_task()
         )
         conversation = self.gpt_session.get_conversation()
-        requirements_content_raw = conversation.query(user_query)
+        requirements_content_raw = conversation.chat(user_query)
         if is_chain_of_thought:
-            requirements_content_raw = conversation.query(
+            requirements_content_raw = conversation.chat(
                 chain_of_thought_optimization('', requirements_path) + "Keep the same version of jina ")
         requirements_content = self.extract_content_from_result(requirements_content_raw, 'requirements.txt',
@@ -134,9 +134,9 @@ class Generator:
             + docker_file_task()
         )
         conversation = self.gpt_session.get_conversation()
-        dockerfile_content_raw = conversation.query(user_query)
+        dockerfile_content_raw = conversation.chat(user_query)
         if is_chain_of_thought:
-            dockerfile_content_raw = conversation.query(
+            dockerfile_content_raw = conversation.chat(
                 f"General rules: " + not_allowed_executor() + chain_of_thought_optimization('dockerfile', 'Dockerfile'))
         dockerfile_content = self.extract_content_from_result(dockerfile_content_raw, 'Dockerfile',
                                                               match_single_block=True)
@@ -172,8 +172,8 @@ The playground (app.py) must not let the user configure the host on the ui.
 '''
         )
         conversation = self.gpt_session.get_conversation([])
-        conversation.query(user_query)
-        playground_content_raw = conversation.query(chain_of_thought_optimization('python', 'app.py', 'the playground'))
+        conversation.chat(user_query)
+        playground_content_raw = conversation.chat(chain_of_thought_optimization('python', 'app.py', 'the playground'))
         playground_content = self.extract_content_from_result(playground_content_raw, 'app.py', match_single_block=True)
         persist_file(playground_content, os.path.join(microservice_path, 'app.py'))
 
@@ -213,7 +213,7 @@ The playground (app.py) must not let the user configure the host on the ui.
         user_query = self.get_user_query_code_issue(description, error, file_name_to_content, test)
 
         conversation = self.gpt_session.get_conversation()
-        returned_files_raw = conversation.query(user_query)
+        returned_files_raw = conversation.chat(user_query)
         for file_name, tag in FILE_AND_TAG_PAIRS:
             updated_file = self.extract_content_from_result(returned_files_raw, file_name)
             if updated_file and (not is_dependency_issue or file_name in ['requirements.txt', 'Dockerfile']):
@@ -280,7 +280,7 @@ complete file. Use the exact same syntax to wrap the code:
 
         print_colored('', 'Is it a dependency issue?', 'blue')
         conversation = self.gpt_session.get_conversation([])
-        answer = conversation.query(
+        answer = conversation.chat(
             f'Your task is to assist in identifying the root cause of a Docker build error for a python application. '
             f'The error message is as follows::\n\n{error}\n\n'
             f'The docker file is as follows:\n\n{docker_file}\n\n'
@@ -305,7 +305,7 @@ The output is a the raw string wrapped into ``` and starting with **name.txt** l
 PDFParserExecutor
 ```
 '''
-        name_raw = conversation.query(user_query)
+        name_raw = conversation.chat(user_query)
         name = self.extract_content_from_result(name_raw, 'name.txt')
         return name
 
@@ -341,7 +341,7 @@ package5
 ```
 '''
         conversation = self.gpt_session.get_conversation()
-        packages_raw = conversation.query(user_query)
+        packages_raw = conversation.chat(user_query)
         packages_csv_string = self.extract_content_from_result(packages_raw, 'packages.csv')
         packages = [package.split(',') for package in packages_csv_string.split('\n')]
         packages = packages[:NUM_IMPLEMENTATION_STRATEGIES]
@@ -351,13 +351,17 @@ package5
         generated_name = self.generate_microservice_name(description)
         microservice_name = f'{generated_name}{random.randint(0, 10_000_000)}'
         packages_list = self.get_possible_packages(description)
-        packages_list = [packages for packages in packages_list if len(set(packages).intersection(set(PROBLEMATIC_PACKAGES))) == 0]
+        packages_list = [
+            packages for packages in packages_list if len(set(packages).intersection(set(PROBLEMATIC_PACKAGES))) == 0
+        ]
         for num_approach, packages in enumerate(packages_list):
             try:
-                self.generate_microservice(description, test, microservice_path, microservice_name, packages,
-                                           num_approach)
-                final_version_path = self.debug_microservice(microservice_path, microservice_name, num_approach,
-                                                             packages, description, test)
+                self.generate_microservice(
+                    description, test, microservice_path, microservice_name, packages, num_approach
+                )
+                final_version_path = self.debug_microservice(
+                    microservice_path, microservice_name, num_approach, packages, description, test
+                )
                 self.generate_playground(microservice_name, final_version_path)
             except self.MaxDebugTimeReachedException:
                 print('Could not debug the Microservice with the approach:', packages)
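Review note: the reformatted filter above drops every implementation strategy that depends on a known-problematic package. The same logic in isolation, with illustrative data (the real `PROBLEMATIC_PACKAGES` list lives in `src/constants.py`):

```python
# Illustrative values only; PROBLEMATIC_PACKAGES is defined in src/constants.py.
PROBLEMATIC_PACKAGES = ['example-bad-package']
packages_list = [['openai', 'pillow'], ['example-bad-package', 'numpy']]

packages_list = [
    packages for packages in packages_list
    if len(set(packages).intersection(set(PROBLEMATIC_PACKAGES))) == 0
]
assert packages_list == [['openai', 'pillow']]
```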
diff --git a/src/options/generate/prompt_system.py b/src/options/generate/prompt_system.py
index 8b0c987..64ad82e 100644
--- a/src/options/generate/prompt_system.py
+++ b/src/options/generate/prompt_system.py
@@ -71,7 +71,7 @@ print(response[0].text)
 ```'''
 
 
-system_base_definition = f'''
+system_message_base = f'''
 It is the year 2021.
 You are a principal engineer working at Jina - an open source company.
 You accurately satisfy all of the user's requirements.
diff --git a/src/utils/io.py b/src/utils/io.py
index c5cf727..272c878 100644
--- a/src/utils/io.py
+++ b/src/utils/io.py
@@ -31,28 +31,6 @@ def get_all_microservice_files_with_content(folder_path):
     return file_name_to_content
 
 
-class GenerationTimeoutError(Exception):
-    pass
-
-def timeout_generator_wrapper(generator, timeout):
-    def generator_func():
-        for item in generator:
-            yield item
-
-    def wrapper() -> Generator:
-        gen = generator_func()
-        while True:
-            try:
-                with concurrent.futures.ThreadPoolExecutor() as executor:
-                    future = executor.submit(next, gen)
-                    yield future.result(timeout=timeout)
-            except StopIteration:
-                break
-            except concurrent.futures.TimeoutError:
-                raise GenerationTimeoutError(f"Generation took too long")
-
-    return wrapper()
-
 @contextmanager
 def suppress_stdout():
     original_stdout = sys.stdout
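Review note: the deleted `timeout_generator_wrapper` existed to bound each streamed chunk from `openai.ChatCompletion` with a timeout; streaming now goes through langchain, so the helper has no remaining callers. For reference, the generic pattern it implemented was a per-item timeout on a generator (a self-contained sketch, not part of this PR):

```python
import concurrent.futures

def iter_with_timeout(gen, seconds):
    """Yield items from gen; raise concurrent.futures.TimeoutError if one item takes too long."""
    sentinel = object()
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
        while True:
            future = executor.submit(next, gen, sentinel)  # runs next(gen, sentinel) in a worker thread
            item = future.result(timeout=seconds)  # raises TimeoutError when the deadline expires
            if item is sentinel:
                return  # generator exhausted
            yield item
```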