fix: Cost calculation enhancement (#207)

This commit is contained in:
Lifei Zhou
2024-10-31 13:58:51 +11:00
committed by GitHub
parent 0ab1966e93
commit aa324ce507
6 changed files with 89 additions and 27 deletions

View File

@@ -2,18 +2,19 @@ import json
import traceback
from copy import deepcopy
from typing import Mapping
from attrs import define, evolve, field, Factory
from exchange.langfuse_wrapper import observe_wrapper
from attrs import Factory, define, evolve, field
from tiktoken import get_encoding
from exchange.checkpoint import Checkpoint, CheckpointData
from exchange.content import Text, ToolResult, ToolUse
from exchange.langfuse_wrapper import observe_wrapper
from exchange.message import Message
from exchange.moderators import Moderator
from exchange.moderators.truncate import ContextTruncate
from exchange.providers import Provider, Usage
from exchange.tool import Tool
from exchange.token_usage_collector import _token_usage_collector
from exchange.tool import Tool
def validate_tool_output(output: str) -> None:

View File

@@ -1,3 +1,4 @@
from datetime import datetime
import logging
import traceback
from pathlib import Path
@@ -168,6 +169,7 @@ class Session:
Args:
new_session (bool): True when starting a new session, False when resuming.
"""
time_start = datetime.now()
if is_existing_session(self.session_file_path) and new_session:
self._prompt_overwrite_session()
@@ -196,7 +198,8 @@ class Session:
message = Message.user(text=user_input.text) if user_input.to_continue() else None
self._remove_empty_session()
self._log_cost()
time_end = datetime.now()
self._log_cost(start_time=time_start, end_time=time_end)
@observe_wrapper()
def reply(self) -> None:
@@ -281,8 +284,8 @@ class Session:
def load_session(self) -> list[Message]:
    """Return the messages stored for this session.

    Delegates to ``read_or_create_file`` with this session's file path.
    NOTE(review): judging by the helper's name it creates the file when it
    is missing -- confirm against its definition.
    """
    session_path = self.session_file_path
    return read_or_create_file(session_path)
def _log_cost(self, start_time: datetime, end_time: datetime) -> None:
    """Log the cost summary for this session and tell the user where to find it.

    Args:
        start_time: Moment the session run started.
        end_time: Moment the session run finished.
    """
    # Build the message first so the logger call stays short and the
    # arguments passed to the cost calculator are easy to read.
    cost_message = get_total_cost_message(
        self.exchange.get_token_usage(),
        self.name,
        start_time,
        end_time,
    )
    get_logger().info(cost_message)
    print(f"[dim]you can view the cost and token usage in the log directory {LOG_PATH}[/]")
def _prompt_overwrite_session(self) -> None:

View File

@@ -1,6 +1,10 @@
from datetime import datetime
from typing import Optional
from exchange.providers.base import Usage
from goose.utils.time_utils import formatted_time
PRICES = {
"gpt-4o": (2.50, 10.00),
"gpt-4o-2024-08-06": (2.50, 10.00),
@@ -8,14 +12,26 @@ PRICES = {
"gpt-4o-mini": (0.150, 0.600),
"gpt-4o-mini-2024-07-18": (0.150, 0.600),
"o1-preview": (15.00, 60.00),
"o1-preview-2024-09-12": (15.00, 60.00),
"o1-mini": (3.00, 12.00),
"claude-3-5-sonnet-20240620": (3.00, 15.00),
"o1-mini-2024-09-12": (3.00, 12.00),
"claude-3-5-sonnet-latest": (3.00, 15.00), # Claude 3.5 Sonnet model
"claude-3-5-sonnet-2": (3.00, 15.00),
"claude-3-5-sonnet-20241022": (3.00, 15.00),
"anthropic.claude-3-5-sonnet-20241022-v2:0": (3.00, 15.00),
"claude-3-5-sonnet-v2@20241022": (3.00, 15.00),
"claude-3-5-sonnet-20240620": (3.00, 15.00),
"anthropic.claude-3-5-sonnet-20240620-v1:0": (3.00, 15.00),
"claude-3-opus-latest": (15.00, 75.00), # Claude Opus 3 model
"claude-3-opus-20240229": (15.00, 75.00),
"anthropic.claude-3-opus-20240229-v1:0": (15.00, 75.00),
"claude-3-haiku-20240307": (0.25, 1.25),
"claude-3-opus@20240229": (15.00, 75.00),
"claude-3-sonnet-20240229": (3.00, 15.00), # Claude Sonnet 3 model
"anthropic.claude-3-sonnet-20240229-v1:0": (3.00, 15.00),
"claude-3-sonnet@20240229": (3.00, 15.00),
"claude-3-haiku-20240307": (0.25, 1.25), # Claude Haiku 3 model
"anthropic.claude-3-haiku-20240307-v1:0": (0.25, 1.25),
"claude-3-haiku@20240307": (0.25, 1.25),
}
@@ -27,15 +43,20 @@ def _calculate_cost(model: str, token_usage: Usage) -> Optional[float]:
return None
def get_total_cost_message(
    token_usages: dict[str, Usage], session_name: str, start_time: datetime, end_time: datetime
) -> str:
    """Build the multi-line cost report for one session.

    One line is produced per model, each prefixed with the session name,
    followed by a final summary line carrying the session's time range and
    the total cost across all models that have known pricing.

    Args:
        token_usages: Mapping of model name to its accumulated token usage.
        session_name: Name of the session, repeated on every line.
        start_time: When the session started.
        end_time: When the session ended.

    Returns:
        The report as a single newline-separated string (no trailing newline).
    """
    total_cost = 0
    session_name_prefix = f"Session name: {session_name}"
    lines = []
    for model, token_usage in token_usages.items():
        cost = _calculate_cost(model, token_usage)
        # Models without a pricing entry are still listed, but do not
        # contribute to the total.
        if cost is None:
            cost_text = "Not available"
        else:
            cost_text = f"${cost:.2f}"
            total_cost += cost
        lines.append(f"{session_name_prefix} | Cost for model {model} {token_usage!s}: {cost_text}")

    datetime_range = f"{formatted_time(start_time)} - {formatted_time(end_time)}"
    lines.append(f"{datetime_range} | {session_name_prefix} | Total cost: ${total_cost:.2f}")
    return "\n".join(lines)

View File

@@ -0,0 +1,5 @@
from datetime import datetime
def formatted_time(time: datetime) -> str:
    """Return *time* as an ISO 8601 string in the local timezone, to the second.

    ``astimezone()`` with no argument converts to the system's local
    timezone; a naive datetime is treated as already being local time.
    Sub-second precision is dropped, e.g. ``2024-10-20T01:02:03+00:00``.
    """
    local_time = time.astimezone()
    return local_time.isoformat(timespec="seconds")

View File

@@ -1,3 +1,4 @@
from datetime import datetime
import os
from typing import Union
from unittest.mock import MagicMock, mock_open, patch
@@ -159,13 +160,15 @@ def test_process_first_message_return_last_exchange_message(create_session_with_
def test_log_log_cost(create_session_with_mock_configs):
    """_log_cost should log exactly the message built by get_total_cost_message."""
    session = create_session_with_mock_configs()
    mock_logger = MagicMock()
    start_time = datetime(2024, 10, 20, 1, 2, 3)
    end_time = datetime(2024, 10, 21, 2, 3, 4)
    cost_message = "You have used 100 tokens"
    with (
        patch("exchange.Exchange.get_token_usage", return_value={}),
        patch("goose.cli.session.get_total_cost_message", return_value=cost_message),
        patch("goose.cli.session.get_logger", return_value=mock_logger),
    ):
        # _log_cost requires the session's start and end times; call it once
        # so the single-call assertion below is meaningful.
        session._log_cost(start_time, end_time)
        mock_logger.info.assert_called_once_with(cost_message)

View File

@@ -1,7 +1,27 @@
from unittest.mock import patch
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch
import pytest
from goose.utils._cost_calculator import _calculate_cost, get_total_cost_message
from exchange.providers.base import Usage
from goose.utils._cost_calculator import _calculate_cost, get_total_cost_message
SESSION_NAME = "test_session"
START_TIME = datetime(2024, 10, 20, 1, 2, 3, tzinfo=timezone.utc)
END_TIME = datetime(2024, 10, 21, 2, 3, 4, tzinfo=timezone.utc)
@pytest.fixture
def start_time():
    """Provide a datetime stand-in whose ``astimezone()`` returns START_TIME.

    NOTE(review): presumably this keeps the expected timestamps independent
    of the timezone of the machine running the tests -- confirm.
    """
    fake_start = MagicMock(spec=datetime)
    fake_start.configure_mock(**{"astimezone.return_value": START_TIME})
    return fake_start
@pytest.fixture
def end_time():
    """Provide a datetime stand-in whose ``astimezone()`` returns END_TIME.

    NOTE(review): presumably this keeps the expected timestamps independent
    of the timezone of the machine running the tests -- confirm.
    """
    fake_end = MagicMock(spec=datetime)
    fake_end.configure_mock(**{"astimezone.return_value": END_TIME})
    return fake_end
@pytest.fixture
@@ -16,32 +36,41 @@ def test_calculate_cost(mock_prices):
assert cost == 0.059
def test_get_total_cost_message(mock_prices, start_time, end_time):
    """Every per-model line and the summary carry the session name; the summary adds the time range."""
    message = get_total_cost_message(
        {
            "gpt-4o": Usage(input_tokens=10000, output_tokens=600, total_tokens=10600),
            "gpt-4o-mini": Usage(input_tokens=3000000, output_tokens=4000000, total_tokens=7000000),
        },
        SESSION_NAME,
        start_time,
        end_time,
    )
    expected_message = (
        "Session name: test_session | Cost for model gpt-4o Usage(input_tokens=10000, output_tokens=600,"
        " total_tokens=10600): $0.06\n"
        "Session name: test_session | Cost for model gpt-4o-mini Usage(input_tokens=3000000, output_tokens=4000000, "
        "total_tokens=7000000): $2.85\n"
        "2024-10-20T01:02:03+00:00 - 2024-10-21T02:03:04+00:00 | Session name: test_session | Total cost: $2.91"
    )
    assert message == expected_message
def test_get_total_cost_message_with_non_available_pricing(mock_prices, start_time, end_time):
    """A model without pricing is reported as 'Not available' and excluded from the total."""
    message = get_total_cost_message(
        {
            "non_pricing_model": Usage(input_tokens=10000, output_tokens=600, total_tokens=10600),
            "gpt-4o-mini": Usage(input_tokens=3000000, output_tokens=4000000, total_tokens=7000000),
        },
        SESSION_NAME,
        start_time,
        end_time,
    )
    expected_message = (
        "Session name: test_session | Cost for model non_pricing_model Usage(input_tokens=10000, output_tokens=600,"
        " total_tokens=10600): Not available\n"
        "Session name: test_session | Cost for model gpt-4o-mini Usage(input_tokens=3000000, output_tokens=4000000,"
        " total_tokens=7000000): $2.85\n"
        "2024-10-20T01:02:03+00:00 - 2024-10-21T02:03:04+00:00 | Session name: test_session | Total cost: $2.85"
    )
    assert message == expected_message