mirror of
https://github.com/aljazceru/goose.git
synced 2025-12-26 10:34:22 +01:00
feat: track cost and token usage in log file (#80)
This commit is contained in:
@@ -8,7 +8,7 @@ dependencies = [
|
||||
"attrs>=23.2.0",
|
||||
"rich>=13.7.1",
|
||||
"ruamel-yaml>=0.18.6",
|
||||
"ai-exchange>=0.9.0",
|
||||
"ai-exchange>=0.9.2",
|
||||
"click>=8.1.7",
|
||||
"prompt-toolkit>=3.0.47",
|
||||
]
|
||||
@@ -62,5 +62,3 @@ dev-dependencies = [
|
||||
"mkdocs-include-markdown-plugin>=6.2.2",
|
||||
"mkdocs-callouts>=1.14.0",
|
||||
]
|
||||
|
||||
|
||||
|
||||
19
src/goose/_logger.py
Normal file
19
src/goose/_logger.py
Normal file
@@ -0,0 +1,19 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
_LOGGER_NAME = "goose"
|
||||
_LOGGER_FILE_NAME = "goose.log"
|
||||
|
||||
|
||||
def setup_logging(log_file_directory: Path, log_level: str = "INFO") -> None:
|
||||
logger = logging.getLogger(_LOGGER_NAME)
|
||||
logger.setLevel(getattr(logging, log_level))
|
||||
log_file_directory.mkdir(parents=True, exist_ok=True)
|
||||
file_handler = logging.FileHandler(log_file_directory / _LOGGER_FILE_NAME)
|
||||
logger.addHandler(file_handler)
|
||||
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||
file_handler.setFormatter(formatter)
|
||||
|
||||
|
||||
def get_logger() -> logging.Logger:
|
||||
return logging.getLogger(_LOGGER_NAME)
|
||||
@@ -17,6 +17,7 @@ GOOSE_GLOBAL_PATH = Path("~/.config/goose").expanduser()
|
||||
PROFILES_CONFIG_PATH = GOOSE_GLOBAL_PATH.joinpath("profiles.yaml")
|
||||
SESSIONS_PATH = GOOSE_GLOBAL_PATH.joinpath("sessions")
|
||||
SESSION_FILE_SUFFIX = ".jsonl"
|
||||
LOG_PATH = GOOSE_GLOBAL_PATH.joinpath("logs")
|
||||
|
||||
|
||||
@cache
|
||||
|
||||
@@ -66,7 +66,8 @@ def list_toolkits() -> None:
|
||||
@session.command(name="start")
|
||||
@click.option("--profile")
|
||||
@click.option("--plan", type=click.Path(exists=True))
|
||||
def session_start(profile: str, plan: Optional[str] = None) -> None:
|
||||
@click.option("--log-level", type=click.Choice(["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]), default="INFO")
|
||||
def session_start(profile: str, log_level: str, plan: Optional[str] = None) -> None:
|
||||
"""Start a new goose session"""
|
||||
if plan:
|
||||
yaml = YAML()
|
||||
@@ -74,7 +75,7 @@ def session_start(profile: str, plan: Optional[str] = None) -> None:
|
||||
_plan = yaml.load(f)
|
||||
else:
|
||||
_plan = None
|
||||
session = Session(profile=profile, plan=_plan)
|
||||
session = Session(profile=profile, plan=_plan, log_level=log_level)
|
||||
session.run()
|
||||
|
||||
|
||||
|
||||
@@ -12,16 +12,13 @@ from rich.panel import Panel
|
||||
from rich.status import Status
|
||||
|
||||
from goose.build import build_exchange
|
||||
from goose.cli.config import (
|
||||
default_profiles,
|
||||
ensure_config,
|
||||
read_config,
|
||||
session_path,
|
||||
)
|
||||
from goose.cli.config import default_profiles, ensure_config, read_config, session_path, LOG_PATH
|
||||
from goose._logger import get_logger, setup_logging
|
||||
from goose.cli.prompt.goose_prompt_session import GoosePromptSession
|
||||
from goose.notifier import Notifier
|
||||
from goose.profile import Profile
|
||||
from goose.utils import droid, load_plugins
|
||||
from goose.utils._cost_calculator import get_total_cost_message
|
||||
from goose.utils.session_file import read_from_file, write_to_file
|
||||
|
||||
RESUME_MESSAGE = "I see we were interrupted. How can I help you?"
|
||||
@@ -90,6 +87,7 @@ class Session:
|
||||
name: Optional[str] = None,
|
||||
profile: Optional[str] = None,
|
||||
plan: Optional[dict] = None,
|
||||
log_level: Optional[str] = "INFO",
|
||||
**kwargs: Dict[str, Any],
|
||||
) -> None:
|
||||
self.name = name
|
||||
@@ -97,6 +95,7 @@ class Session:
|
||||
self.notifier = SessionNotifier(self.status_indicator)
|
||||
|
||||
self.exchange = build_exchange(profile=load_profile(profile), notifier=self.notifier)
|
||||
setup_logging(log_file_directory=LOG_PATH, log_level=log_level)
|
||||
|
||||
if name is not None and self.session_file_path.exists():
|
||||
messages = self.load_session()
|
||||
@@ -173,6 +172,7 @@ class Session:
|
||||
message = Message.user(text=user_input.text) if user_input.to_continue() else None
|
||||
|
||||
self.save_session()
|
||||
self._log_cost()
|
||||
|
||||
def reply(self) -> None:
|
||||
"""Reply to the last user message, calling tools as needed
|
||||
@@ -256,6 +256,10 @@ class Session:
|
||||
self.name = user_entered_session_name if user_entered_session_name else droid()
|
||||
print(f"Saving to [bold cyan]{self.session_file_path}[/bold cyan]")
|
||||
|
||||
def _log_cost(self) -> None:
|
||||
get_logger().info(get_total_cost_message(self.exchange.get_token_usage()))
|
||||
print("You can view the cost and token usage in the log directory", LOG_PATH)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
session = Session()
|
||||
|
||||
39
src/goose/utils/_cost_calculator.py
Normal file
39
src/goose/utils/_cost_calculator.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from typing import Optional
|
||||
from exchange.providers.base import Usage
|
||||
|
||||
PRICES = {
|
||||
"gpt-4o": (5.00, 15.00),
|
||||
"gpt-4o-2024-08-06": (2.50, 10.00),
|
||||
"gpt-4o-mini": (0.150, 0.600),
|
||||
"gpt-4o-mini-2024-07-18": (0.150, 0.600),
|
||||
"o1-preview": (15.00, 60.00),
|
||||
"o1-mini": (3.00, 12.00),
|
||||
"claude-3-5-sonnet-20240620": (3.00, 15.00),
|
||||
"anthropic.claude-3-5-sonnet-20240620-v1:0": (3.00, 15.00),
|
||||
"claude-3-opus-20240229": (15.00, 75.00),
|
||||
"anthropic.claude-3-opus-20240229-v1:0": (15.00, 75.00),
|
||||
"claude-3-haiku-20240307": (0.25, 1.25),
|
||||
"anthropic.claude-3-haiku-20240307-v1:0": (0.25, 1.25),
|
||||
}
|
||||
|
||||
|
||||
def _calculate_cost(model: str, token_usage: Usage) -> Optional[float]:
|
||||
model_name = model.lower()
|
||||
if model_name in PRICES:
|
||||
input_token_price, output_token_price = PRICES[model_name]
|
||||
return (input_token_price * token_usage.input_tokens + output_token_price * token_usage.output_tokens) / 1000000
|
||||
return None
|
||||
|
||||
|
||||
def get_total_cost_message(token_usages: dict[str, Usage]) -> str:
|
||||
total_cost = 0
|
||||
message = ""
|
||||
for model, token_usage in token_usages.items():
|
||||
cost = _calculate_cost(model, token_usage)
|
||||
if cost is not None:
|
||||
message += f"Cost for model {model} {str(token_usage)}: ${cost:.2f}\n"
|
||||
total_cost += cost
|
||||
else:
|
||||
message += f"Cost for model {model} {str(token_usage)}: Not available\n"
|
||||
message += f"Total cost: ${total_cost:.2f}"
|
||||
return message
|
||||
@@ -167,3 +167,14 @@ def test_generate_session_name(create_session_with_mock_configs):
|
||||
session.generate_session_name()
|
||||
|
||||
assert session.name == SPECIFIED_SESSION_NAME
|
||||
|
||||
|
||||
def test_log_log_cost(create_session_with_mock_configs):
|
||||
session = create_session_with_mock_configs()
|
||||
mock_logger = MagicMock()
|
||||
cost_message = "You have used 100 tokens"
|
||||
with patch("exchange.Exchange.get_token_usage", return_value={}), patch(
|
||||
"goose.cli.session.get_total_cost_message", return_value=cost_message
|
||||
), patch("goose.cli.session.get_logger", return_value=mock_logger):
|
||||
session._log_cost()
|
||||
mock_logger.info.assert_called_once_with(cost_message)
|
||||
|
||||
38
tests/utils/test_cost_calculator.py
Normal file
38
tests/utils/test_cost_calculator.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from goose.utils._cost_calculator import _calculate_cost, get_total_cost_message
|
||||
from exchange.providers.base import Usage
|
||||
|
||||
|
||||
def test_calculate_cost():
|
||||
cost = _calculate_cost("gpt-4o", Usage(input_tokens=10000, output_tokens=600, total_tokens=10600))
|
||||
assert cost == 0.059
|
||||
|
||||
|
||||
def test_get_total_cost_message():
|
||||
message = get_total_cost_message(
|
||||
{
|
||||
"gpt-4o": Usage(input_tokens=10000, output_tokens=600, total_tokens=10600),
|
||||
"gpt-4o-mini": Usage(input_tokens=3000000, output_tokens=4000000, total_tokens=7000000),
|
||||
}
|
||||
)
|
||||
expected_message = (
|
||||
"Cost for model gpt-4o Usage(input_tokens=10000, output_tokens=600, total_tokens=10600): $0.06\n"
|
||||
+ "Cost for model gpt-4o-mini Usage(input_tokens=3000000, output_tokens=4000000, total_tokens=7000000)"
|
||||
+ ": $2.85\nTotal cost: $2.91"
|
||||
)
|
||||
assert message == expected_message
|
||||
|
||||
|
||||
def test_get_total_cost_message_with_non_available_pricing():
|
||||
message = get_total_cost_message(
|
||||
{
|
||||
"non_pricing_model": Usage(input_tokens=10000, output_tokens=600, total_tokens=10600),
|
||||
"gpt-4o-mini": Usage(input_tokens=3000000, output_tokens=4000000, total_tokens=7000000),
|
||||
}
|
||||
)
|
||||
expected_message = (
|
||||
"Cost for model non_pricing_model Usage(input_tokens=10000, output_tokens=600, total_tokens=10600): "
|
||||
+ "Not available\n"
|
||||
+ "Cost for model gpt-4o-mini Usage(input_tokens=3000000, output_tokens=4000000, total_tokens=7000000)"
|
||||
+ ": $2.85\nTotal cost: $2.85"
|
||||
)
|
||||
assert message == expected_message
|
||||
Reference in New Issue
Block a user