From d878a782d527d60d0e140c6f698cb03486619f95 Mon Sep 17 00:00:00 2001
From: Joschka Braun
Date: Wed, 19 Apr 2023 13:39:59 +0200
Subject: [PATCH] fix: use old callback for cost calculation

---
 src/apis/gpt.py  | 35 ++++++++++++++++++++++++++++++-----
 src/constants.py |  6 ++++++
 2 files changed, 36 insertions(+), 5 deletions(-)

diff --git a/src/apis/gpt.py b/src/apis/gpt.py
index 7a86524..31b5801 100644
--- a/src/apis/gpt.py
+++ b/src/apis/gpt.py
@@ -8,9 +8,11 @@ from langchain import PromptTemplate
 from langchain.callbacks import CallbackManager
 from langchain.chat_models import ChatOpenAI
 from openai.error import RateLimitError
-from langchain.schema import AIMessage, HumanMessage, SystemMessage, BaseMessage
+from langchain.schema import HumanMessage, SystemMessage, BaseMessage
 from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 
+from src.constants import PRICING_GPT4_PROMPT, PRICING_GPT4_GENERATION, PRICING_GPT3_5_TURBO_PROMPT, \
+    PRICING_GPT3_5_TURBO_GENERATION
 from src.options.generate.templates_system import template_system_message_base, executor_example, docarray_example, client_example
 from src.utils.string_tools import print_colored
 
@@ -20,7 +22,18 @@ class GPTSession:
         self.task_description = task_description
         self.test_description = test_description
         self.configure_openai_api_key()
-        self.model_name = 'gpt-4' if model == 'gpt-4' and self.is_gpt4_available() else 'gpt-3.5-turbo'
+        if model == 'gpt-4' and self.is_gpt4_available():
+            self.pricing_prompt = PRICING_GPT4_PROMPT
+            self.pricing_generation = PRICING_GPT4_GENERATION
+        else:
+            if model == 'gpt-4':
+                print_colored('GPT version info', 'GPT-4 is not available. Using GPT-3.5-turbo instead.', 'yellow')
+            model = 'gpt-3.5-turbo'
+            self.pricing_prompt = PRICING_GPT3_5_TURBO_PROMPT
+            self.pricing_generation = PRICING_GPT3_5_TURBO_GENERATION
+        self.model_name = model
+        self.chars_prompt_so_far = 0
+        self.chars_generation_so_far = 0
 
     @staticmethod
     def configure_openai_api_key():
@@ -50,11 +63,21 @@ If you have updated it already, please restart your terminal.
                 continue
             return True
         except openai.error.InvalidRequestError:
-            print_colored('GPT version info', 'GPT-4 is not available. Using GPT-3.5-turbo instead.', 'yellow')
             return False
 
+    def cost_callback(self, chars_prompt, chars_generation):
+        self.chars_prompt_so_far += chars_prompt
+        self.chars_generation_so_far += chars_generation
+        print('\n')
+        money_prompt = self.calculate_money_spent(self.chars_prompt_so_far, self.pricing_prompt)
+        money_generation = self.calculate_money_spent(self.chars_generation_so_far, self.pricing_generation)
+        print('Total money spent so far on openai.com:', f'${money_prompt + money_generation}')
+        print('\n')
+
     def get_conversation(self, system_definition_examples: List[str] = ['executor', 'docarray', 'client']):
-        return _GPTConversation(self.model_name, self.task_description, self.test_description, system_definition_examples)
+        return _GPTConversation(
+            self.model_name, self.cost_callback, self.task_description, self.test_description, system_definition_examples
+        )
 
 
 class AssistantStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
@@ -64,7 +87,7 @@ class AssistantStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
 
 
 class _GPTConversation:
-    def __init__(self, model: str, task_description, test_description, system_definition_examples: List[str] = ['executor', 'docarray', 'client']):
+    def __init__(self, model: str, cost_callback, task_description, test_description, system_definition_examples: List[str] = ['executor', 'docarray', 'client']):
         self._chat = ChatOpenAI(
             model_name=model,
             streaming=True,
@@ -72,6 +95,7 @@ class _GPTConversation:
             verbose=True,
             temperature=0,
         )
+        self.cost_callback = cost_callback
         self.messages: List[BaseMessage] = []
         self.system_message = self._create_system_message(task_description, test_description, system_definition_examples)
         if os.environ['VERBOSE'].lower() == 'true':
@@ -86,6 +110,7 @@
         response = self._chat([self.system_message] + self.messages)
         if os.environ['VERBOSE'].lower() == 'true':
             print()
+        self.cost_callback(sum(len(m.content) for m in self.messages), len(response.content))
         self.messages.append(response)
         return response.content
 
diff --git a/src/constants.py b/src/constants.py
index c1b3687..0d0dae0 100644
--- a/src/constants.py
+++ b/src/constants.py
@@ -23,6 +23,12 @@ FILE_AND_TAG_PAIRS = [
 
 FLOW_URL_PLACEHOLDER = 'jcloud.jina.ai'
 
+PRICING_GPT4_PROMPT = 0.03
+PRICING_GPT4_GENERATION = 0.06
+PRICING_GPT3_5_TURBO_PROMPT = 0.002
+PRICING_GPT3_5_TURBO_GENERATION = 0.002
+
+CHARS_PER_TOKEN = 3.4
 
 NUM_IMPLEMENTATION_STRATEGIES = 5
 MAX_DEBUGGING_ITERATIONS = 10
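Note on the cost math: calculate_money_spent is called by the new cost_callback but is not part of this diff, so the snippet below is only a rough, self-contained sketch of how the accumulated character counts and the new constants plausibly combine into a dollar figure (characters / CHARS_PER_TOKEN gives approximate tokens, then tokens / 1000 times the per-1K-token price). The helper's body and the example character counts are assumptions, not the repository's actual implementation.

# Sketch only: calculate_money_spent is not shown in this patch; this body is an
# assumption based on CHARS_PER_TOKEN and the per-1K-token prices in src/constants.py.
PRICING_GPT4_PROMPT = 0.03        # $ per 1K prompt tokens
PRICING_GPT4_GENERATION = 0.06    # $ per 1K generated tokens
CHARS_PER_TOKEN = 3.4             # rough characters-per-token ratio

def calculate_money_spent(num_chars, price_per_1k_tokens):
    # characters -> approximate tokens -> dollars
    return num_chars / CHARS_PER_TOKEN / 1000 * price_per_1k_tokens

# Usage mirroring GPTSession.cost_callback: keep running totals of prompt and
# generation characters, then report the combined spend after each call.
chars_prompt_so_far = 0
chars_generation_so_far = 0
for chars_prompt, chars_generation in [(1200, 400), (2600, 900)]:  # made-up character counts
    chars_prompt_so_far += chars_prompt
    chars_generation_so_far += chars_generation
    money_prompt = calculate_money_spent(chars_prompt_so_far, PRICING_GPT4_PROMPT)
    money_generation = calculate_money_spent(chars_generation_so_far, PRICING_GPT4_GENERATION)
    print('Total money spent so far on openai.com:', f'${money_prompt + money_generation:.4f}')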