mirror of
https://github.com/aljazceru/goose.git
synced 2026-01-09 09:24:27 +01:00
fix: Cost calculation enhancement (#207)
This commit is contained in:
@@ -2,18 +2,19 @@ import json
|
||||
import traceback
|
||||
from copy import deepcopy
|
||||
from typing import Mapping
|
||||
from attrs import define, evolve, field, Factory
|
||||
from exchange.langfuse_wrapper import observe_wrapper
|
||||
|
||||
from attrs import Factory, define, evolve, field
|
||||
from tiktoken import get_encoding
|
||||
|
||||
from exchange.checkpoint import Checkpoint, CheckpointData
|
||||
from exchange.content import Text, ToolResult, ToolUse
|
||||
from exchange.langfuse_wrapper import observe_wrapper
|
||||
from exchange.message import Message
|
||||
from exchange.moderators import Moderator
|
||||
from exchange.moderators.truncate import ContextTruncate
|
||||
from exchange.providers import Provider, Usage
|
||||
from exchange.tool import Tool
|
||||
from exchange.token_usage_collector import _token_usage_collector
|
||||
from exchange.tool import Tool
|
||||
|
||||
|
||||
def validate_tool_output(output: str) -> None:
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from datetime import datetime
|
||||
import logging
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
@@ -168,6 +169,7 @@ class Session:
|
||||
Args:
|
||||
new_session (bool): True when starting a new session, False when resuming.
|
||||
"""
|
||||
time_start = datetime.now()
|
||||
if is_existing_session(self.session_file_path) and new_session:
|
||||
self._prompt_overwrite_session()
|
||||
|
||||
@@ -196,7 +198,8 @@ class Session:
|
||||
message = Message.user(text=user_input.text) if user_input.to_continue() else None
|
||||
|
||||
self._remove_empty_session()
|
||||
self._log_cost()
|
||||
time_end = datetime.now()
|
||||
self._log_cost(start_time=time_start, end_time=time_end)
|
||||
|
||||
@observe_wrapper()
|
||||
def reply(self) -> None:
|
||||
@@ -281,8 +284,8 @@ class Session:
|
||||
def load_session(self) -> list[Message]:
    """Return the result of ``read_or_create_file`` for this session's file path.

    NOTE(review): judging by the helper's name, the file is presumably created
    when it does not exist yet - confirm against ``read_or_create_file``.
    """
    session_messages = read_or_create_file(self.session_file_path)
    return session_messages
|
||||
|
||||
def _log_cost(self, start_time: datetime, end_time: datetime) -> None:
    """Log the cost report for this session and point the user at the log directory.

    Args:
        start_time: Start of the time range shown in the report.
        end_time: End of the time range shown in the report.
    """
    log_info = get_logger().info
    usage_by_model = self.exchange.get_token_usage()
    log_info(get_total_cost_message(usage_by_model, self.name, start_time, end_time))
    # The report itself only goes to the log; tell the user where to find it.
    print(f"[dim]you can view the cost and token usage in the log directory {LOG_PATH}[/]")
|
||||
|
||||
def _prompt_overwrite_session(self) -> None:
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from exchange.providers.base import Usage
|
||||
|
||||
from goose.utils.time_utils import formatted_time
|
||||
|
||||
PRICES = {
|
||||
"gpt-4o": (2.50, 10.00),
|
||||
"gpt-4o-2024-08-06": (2.50, 10.00),
|
||||
@@ -8,14 +12,26 @@ PRICES = {
|
||||
"gpt-4o-mini": (0.150, 0.600),
|
||||
"gpt-4o-mini-2024-07-18": (0.150, 0.600),
|
||||
"o1-preview": (15.00, 60.00),
|
||||
"o1-preview-2024-09-12": (15.00, 60.00),
|
||||
"o1-mini": (3.00, 12.00),
|
||||
"claude-3-5-sonnet-20240620": (3.00, 15.00),
|
||||
"o1-mini-2024-09-12": (3.00, 12.00),
|
||||
"claude-3-5-sonnet-latest": (3.00, 15.00), # Claude 3.5 Sonnet model
|
||||
"claude-3-5-sonnet-2": (3.00, 15.00),
|
||||
"claude-3-5-sonnet-20241022": (3.00, 15.00),
|
||||
"anthropic.claude-3-5-sonnet-20241022-v2:0": (3.00, 15.00),
|
||||
"claude-3-5-sonnet-v2@20241022": (3.00, 15.00),
|
||||
"claude-3-5-sonnet-20240620": (3.00, 15.00),
|
||||
"anthropic.claude-3-5-sonnet-20240620-v1:0": (3.00, 15.00),
|
||||
"claude-3-opus-latest": (15.00, 75.00), # Claude Opus 3 model
|
||||
"claude-3-opus-20240229": (15.00, 75.00),
|
||||
"anthropic.claude-3-opus-20240229-v1:0": (15.00, 75.00),
|
||||
"claude-3-haiku-20240307": (0.25, 1.25),
|
||||
"claude-3-opus@20240229": (15.00, 75.00),
|
||||
"claude-3-sonnet-20240229": (3.00, 15.00), # Claude Sonnet 3 model
|
||||
"anthropic.claude-3-sonnet-20240229-v1:0": (3.00, 15.00),
|
||||
"claude-3-sonnet@20240229": (3.00, 15.00),
|
||||
"claude-3-haiku-20240307": (0.25, 1.25), # Claude Haiku 3 model
|
||||
"anthropic.claude-3-haiku-20240307-v1:0": (0.25, 1.25),
|
||||
"claude-3-haiku@20240307": (0.25, 1.25),
|
||||
}
|
||||
|
||||
|
||||
@@ -27,15 +43,20 @@ def _calculate_cost(model: str, token_usage: Usage) -> Optional[float]:
|
||||
return None
|
||||
|
||||
|
||||
def get_total_cost_message(
    token_usages: dict[str, Usage], session_name: str, start_time: datetime, end_time: datetime
) -> str:
    """Build the cost report for a session: one line per model plus a total line.

    Every line is prefixed with the session name. Models for which
    ``_calculate_cost`` returns ``None`` are reported as "Not available" and do
    not contribute to the total.

    Args:
        token_usages: Token usage keyed by model name.
        session_name: Name of the session, included in every line.
        start_time: Start of the time range shown on the total line.
        end_time: End of the time range shown on the total line.

    Returns:
        The report as a newline-separated string with no trailing newline.
    """
    session_name_prefix = f"Session name: {session_name}"
    total_cost = 0.0
    report_lines = []

    for model, token_usage in token_usages.items():
        cost = _calculate_cost(model, token_usage)
        if cost is None:
            cost_text = "Not available"
        else:
            cost_text = f"${cost:.2f}"
            total_cost += cost
        report_lines.append(f"{session_name_prefix} | Cost for model {model} {token_usage!s}: {cost_text}")

    datetime_range = f"{formatted_time(start_time)} - {formatted_time(end_time)}"
    report_lines.append(f"{datetime_range} | {session_name_prefix} | Total cost: ${total_cost:.2f}")
    return "\n".join(report_lines)
|
||||
|
||||
5
src/goose/utils/time_utils.py
Normal file
5
src/goose/utils/time_utils.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
def formatted_time(time: datetime) -> str:
    """Return *time* as an ISO 8601 string with a UTC offset, to whole seconds.

    The value is converted with ``astimezone()`` (no argument), i.e. to the
    system's local timezone; a naive datetime is treated as local time by that
    call. Sub-second precision is dropped from the output.
    """
    local_time = time.astimezone()
    return local_time.isoformat(timespec="seconds")
|
||||
@@ -1,3 +1,4 @@
|
||||
from datetime import datetime
|
||||
import os
|
||||
from typing import Union
|
||||
from unittest.mock import MagicMock, mock_open, patch
|
||||
@@ -159,13 +160,15 @@ def test_process_first_message_return_last_exchange_message(create_session_with_
|
||||
def test_log_log_cost(create_session_with_mock_configs):
    """_log_cost passes the session name and time range on, then logs the returned message."""
    session = create_session_with_mock_configs()
    mock_logger = MagicMock()
    start_time = datetime(2024, 10, 20, 1, 2, 3)
    end_time = datetime(2024, 10, 21, 2, 3, 4)
    cost_message = "You have used 100 tokens"
    with (
        patch("exchange.Exchange.get_token_usage", return_value={}),
        patch("goose.cli.session.get_total_cost_message", return_value=cost_message) as mock_get_message,
        patch("goose.cli.session.get_logger", return_value=mock_logger),
    ):
        session._log_cost(start_time, end_time)

    mock_get_message.assert_called_once()
    # The first positional argument is the token usage; only the remaining
    # arguments (session name and time range) are pinned here.
    assert mock_get_message.call_args.args[1:] == (session.name, start_time, end_time)
    mock_logger.info.assert_called_once_with(cost_message)
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,27 @@
|
||||
from unittest.mock import patch
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from goose.utils._cost_calculator import _calculate_cost, get_total_cost_message
|
||||
from exchange.providers.base import Usage
|
||||
from goose.utils._cost_calculator import _calculate_cost, get_total_cost_message
|
||||
|
||||
SESSION_NAME = "test_session"
|
||||
START_TIME = datetime(2024, 10, 20, 1, 2, 3, tzinfo=timezone.utc)
|
||||
END_TIME = datetime(2024, 10, 21, 2, 3, 4, tzinfo=timezone.utc)
|
||||
|
||||
|
||||
@pytest.fixture
def start_time():
    """Datetime stand-in whose ``astimezone()`` always returns ``START_TIME``.

    A mock is used so that anything formatting this value via ``astimezone()``
    produces the same output regardless of the machine's local timezone.
    """
    mock_start_time = MagicMock(spec=datetime)
    mock_start_time.astimezone.return_value = START_TIME
    return mock_start_time
|
||||
|
||||
|
||||
@pytest.fixture
def end_time():
    """Datetime stand-in whose ``astimezone()`` always returns ``END_TIME``.

    A mock is used so that anything formatting this value via ``astimezone()``
    produces the same output regardless of the machine's local timezone.
    """
    mock_end_time = MagicMock(spec=datetime)
    mock_end_time.astimezone.return_value = END_TIME
    return mock_end_time
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -16,32 +36,41 @@ def test_calculate_cost(mock_prices):
|
||||
assert cost == 0.059
|
||||
|
||||
|
||||
def test_get_total_cost_message(mock_prices, start_time, end_time):
    """Each priced model gets its own line, followed by a time-ranged total line."""
    message = get_total_cost_message(
        {
            "gpt-4o": Usage(input_tokens=10000, output_tokens=600, total_tokens=10600),
            "gpt-4o-mini": Usage(input_tokens=3000000, output_tokens=4000000, total_tokens=7000000),
        },
        SESSION_NAME,
        start_time,
        end_time,
    )
    expected_message = (
        "Session name: test_session | Cost for model gpt-4o Usage(input_tokens=10000, output_tokens=600,"
        " total_tokens=10600): $0.06\n"
        "Session name: test_session | Cost for model gpt-4o-mini Usage(input_tokens=3000000, output_tokens=4000000, "
        "total_tokens=7000000): $2.85\n"
        "2024-10-20T01:02:03+00:00 - 2024-10-21T02:03:04+00:00 | Session name: test_session | Total cost: $2.91"
    )
    assert message == expected_message
|
||||
|
||||
|
||||
def test_get_total_cost_message_with_non_available_pricing(mock_prices, start_time, end_time):
    """A model without pricing is reported as "Not available" and left out of the total."""
    message = get_total_cost_message(
        {
            "non_pricing_model": Usage(input_tokens=10000, output_tokens=600, total_tokens=10600),
            "gpt-4o-mini": Usage(input_tokens=3000000, output_tokens=4000000, total_tokens=7000000),
        },
        SESSION_NAME,
        start_time,
        end_time,
    )
    expected_message = (
        "Session name: test_session | Cost for model non_pricing_model Usage(input_tokens=10000, output_tokens=600,"
        " total_tokens=10600): Not available\n"
        "Session name: test_session | Cost for model gpt-4o-mini Usage(input_tokens=3000000, output_tokens=4000000,"
        " total_tokens=7000000): $2.85\n"
        "2024-10-20T01:02:03+00:00 - 2024-10-21T02:03:04+00:00 | Session name: test_session | Total cost: $2.85"
    )
    assert message == expected_message
|
||||
|
||||
Reference in New Issue
Block a user