Merge pull request #17 from jina-ai/fix-add-cost-estimation

fix: use old callback for cost calculation
This commit is contained in:
Florian Hönicke
2023-04-19 13:45:20 +02:00
committed by GitHub
2 changed files with 41 additions and 6 deletions

View File

@@ -8,9 +8,11 @@ from langchain import PromptTemplate
from langchain.callbacks import CallbackManager
from langchain.chat_models import ChatOpenAI
from openai.error import RateLimitError
from langchain.schema import AIMessage, HumanMessage, SystemMessage, BaseMessage
from langchain.schema import HumanMessage, SystemMessage, BaseMessage
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from src.constants import PRICING_GPT4_PROMPT, PRICING_GPT4_GENERATION, PRICING_GPT3_5_TURBO_PROMPT, \
PRICING_GPT3_5_TURBO_GENERATION, CHARS_PER_TOKEN
from src.options.generate.templates_system import template_system_message_base, executor_example, docarray_example, client_example
from src.utils.string_tools import print_colored
@@ -20,7 +22,23 @@ class GPTSession:
self.task_description = task_description
self.test_description = test_description
self.configure_openai_api_key()
self.model_name = 'gpt-4' if model == 'gpt-4' and self.is_gpt4_available() else 'gpt-3.5-turbo'
if model == 'gpt-4' and self.is_gpt4_available():
self.pricing_prompt = PRICING_GPT4_PROMPT
self.pricing_generation = PRICING_GPT4_GENERATION
else:
if model == 'gpt-4':
print_colored('GPT version info', 'GPT-4 is not available. Using GPT-3.5-turbo instead.', 'yellow')
model = 'gpt-3.5-turbo'
self.pricing_prompt = PRICING_GPT3_5_TURBO_PROMPT
self.pricing_generation = PRICING_GPT3_5_TURBO_GENERATION
self.model_name = model
self.chars_prompt_so_far = 0
self.chars_generation_so_far = 0
def get_conversation(self, system_definition_examples: List[str] = ['executor', 'docarray', 'client']):
return _GPTConversation(
self.model_name, self.cost_callback, self.task_description, self.test_description, system_definition_examples
)
@staticmethod
def configure_openai_api_key():
@@ -50,11 +68,20 @@ If you have updated it already, please restart your terminal.
continue
return True
except openai.error.InvalidRequestError:
print_colored('GPT version info', 'GPT-4 is not available. Using GPT-3.5-turbo instead.', 'yellow')
return False
def get_conversation(self, system_definition_examples: List[str] = ['executor', 'docarray', 'client']):
return _GPTConversation(self.model_name, self.task_description, self.test_description, system_definition_examples)
def cost_callback(self, chars_prompt, chars_generation):
self.chars_prompt_so_far += chars_prompt
self.chars_generation_so_far += chars_generation
print('\n')
money_prompt = self._calculate_money_spent(self.chars_prompt_so_far, self.pricing_prompt)
money_generation = self._calculate_money_spent(self.chars_generation_so_far, self.pricing_generation)
print('Total money spent so far on openai.com:', f'${money_prompt + money_generation}')
print('\n')
@staticmethod
def _calculate_money_spent(num_chars, price):
return round(num_chars / CHARS_PER_TOKEN * price / 1000, 3)
class AssistantStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
@@ -64,7 +91,7 @@ class AssistantStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
class _GPTConversation:
def __init__(self, model: str, task_description, test_description, system_definition_examples: List[str] = ['executor', 'docarray', 'client']):
def __init__(self, model: str, cost_callback, task_description, test_description, system_definition_examples: List[str] = ['executor', 'docarray', 'client']):
self._chat = ChatOpenAI(
model_name=model,
streaming=True,
@@ -72,6 +99,7 @@ class _GPTConversation:
verbose=True,
temperature=0,
)
self.cost_callback = cost_callback
self.messages: List[BaseMessage] = []
self.system_message = self._create_system_message(task_description, test_description, system_definition_examples)
if os.environ['VERBOSE'].lower() == 'true':
@@ -86,6 +114,7 @@ class _GPTConversation:
response = self._chat([self.system_message] + self.messages)
if os.environ['VERBOSE'].lower() == 'true':
print()
self.cost_callback(sum([len(m.content) for m in self.messages]), len(response.content))
self.messages.append(response)
return response.content

View File

@@ -23,6 +23,12 @@ FILE_AND_TAG_PAIRS = [
FLOW_URL_PLACEHOLDER = 'jcloud.jina.ai'
PRICING_GPT4_PROMPT = 0.03
PRICING_GPT4_GENERATION = 0.06
PRICING_GPT3_5_TURBO_PROMPT = 0.002
PRICING_GPT3_5_TURBO_GENERATION = 0.002
CHARS_PER_TOKEN = 3.4
NUM_IMPLEMENTATION_STRATEGIES = 5
MAX_DEBUGGING_ITERATIONS = 10