From 9dbc0d95ebb34062cd911bd08c210ad36043f601 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Mon, 23 Sep 2024 14:46:42 -0700 Subject: [PATCH] feat: track cost and token usage in log file (#80) --- pyproject.toml | 4 +-- src/goose/_logger.py | 19 ++++++++++++++ src/goose/cli/config.py | 1 + src/goose/cli/main.py | 5 ++-- src/goose/cli/session.py | 16 +++++++----- src/goose/utils/_cost_calculator.py | 39 +++++++++++++++++++++++++++++ tests/cli/test_session.py | 11 ++++++++ tests/utils/test_cost_calculator.py | 38 ++++++++++++++++++++++++++++ 8 files changed, 122 insertions(+), 11 deletions(-) create mode 100644 src/goose/_logger.py create mode 100644 src/goose/utils/_cost_calculator.py create mode 100644 tests/utils/test_cost_calculator.py diff --git a/pyproject.toml b/pyproject.toml index b51e5742..bc6b6ac6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ dependencies = [ "attrs>=23.2.0", "rich>=13.7.1", "ruamel-yaml>=0.18.6", - "ai-exchange>=0.9.0", + "ai-exchange>=0.9.2", "click>=8.1.7", "prompt-toolkit>=3.0.47", ] @@ -62,5 +62,3 @@ dev-dependencies = [ "mkdocs-include-markdown-plugin>=6.2.2", "mkdocs-callouts>=1.14.0", ] - - diff --git a/src/goose/_logger.py b/src/goose/_logger.py new file mode 100644 index 00000000..a364ceed --- /dev/null +++ b/src/goose/_logger.py @@ -0,0 +1,19 @@ +import logging +from pathlib import Path + +_LOGGER_NAME = "goose" +_LOGGER_FILE_NAME = "goose.log" + + +def setup_logging(log_file_directory: Path, log_level: str = "INFO") -> None: + logger = logging.getLogger(_LOGGER_NAME) + logger.setLevel(getattr(logging, log_level)) + log_file_directory.mkdir(parents=True, exist_ok=True) + file_handler = logging.FileHandler(log_file_directory / _LOGGER_FILE_NAME) + logger.addHandler(file_handler) + formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") + file_handler.setFormatter(formatter) + + +def get_logger() -> logging.Logger: + return logging.getLogger(_LOGGER_NAME) diff --git a/src/goose/cli/config.py b/src/goose/cli/config.py index f875b49e..ee901941 100644 --- a/src/goose/cli/config.py +++ b/src/goose/cli/config.py @@ -17,6 +17,7 @@ GOOSE_GLOBAL_PATH = Path("~/.config/goose").expanduser() PROFILES_CONFIG_PATH = GOOSE_GLOBAL_PATH.joinpath("profiles.yaml") SESSIONS_PATH = GOOSE_GLOBAL_PATH.joinpath("sessions") SESSION_FILE_SUFFIX = ".jsonl" +LOG_PATH = GOOSE_GLOBAL_PATH.joinpath("logs") @cache diff --git a/src/goose/cli/main.py b/src/goose/cli/main.py index 898ad2a6..5d3a6b31 100644 --- a/src/goose/cli/main.py +++ b/src/goose/cli/main.py @@ -66,7 +66,8 @@ def list_toolkits() -> None: @session.command(name="start") @click.option("--profile") @click.option("--plan", type=click.Path(exists=True)) -def session_start(profile: str, plan: Optional[str] = None) -> None: +@click.option("--log-level", type=click.Choice(["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]), default="INFO") +def session_start(profile: str, log_level: str, plan: Optional[str] = None) -> None: """Start a new goose session""" if plan: yaml = YAML() @@ -74,7 +75,7 @@ def session_start(profile: str, plan: Optional[str] = None) -> None: _plan = yaml.load(f) else: _plan = None - session = Session(profile=profile, plan=_plan) + session = Session(profile=profile, plan=_plan, log_level=log_level) session.run() diff --git a/src/goose/cli/session.py b/src/goose/cli/session.py index 1f72a302..b7cdeaa0 100644 --- a/src/goose/cli/session.py +++ b/src/goose/cli/session.py @@ -12,16 +12,13 @@ from rich.panel import Panel from rich.status import Status from goose.build import build_exchange -from goose.cli.config import ( - default_profiles, - ensure_config, - read_config, - session_path, -) +from goose.cli.config import default_profiles, ensure_config, read_config, session_path, LOG_PATH +from goose._logger import get_logger, setup_logging from goose.cli.prompt.goose_prompt_session import GoosePromptSession from goose.notifier import Notifier from goose.profile import Profile from goose.utils import droid, load_plugins +from goose.utils._cost_calculator import get_total_cost_message from goose.utils.session_file import read_from_file, write_to_file RESUME_MESSAGE = "I see we were interrupted. How can I help you?" @@ -90,6 +87,7 @@ class Session: name: Optional[str] = None, profile: Optional[str] = None, plan: Optional[dict] = None, + log_level: Optional[str] = "INFO", **kwargs: Dict[str, Any], ) -> None: self.name = name @@ -97,6 +95,7 @@ class Session: self.notifier = SessionNotifier(self.status_indicator) self.exchange = build_exchange(profile=load_profile(profile), notifier=self.notifier) + setup_logging(log_file_directory=LOG_PATH, log_level=log_level) if name is not None and self.session_file_path.exists(): messages = self.load_session() @@ -173,6 +172,7 @@ class Session: message = Message.user(text=user_input.text) if user_input.to_continue() else None self.save_session() + self._log_cost() def reply(self) -> None: """Reply to the last user message, calling tools as needed @@ -256,6 +256,10 @@ class Session: self.name = user_entered_session_name if user_entered_session_name else droid() print(f"Saving to [bold cyan]{self.session_file_path}[/bold cyan]") + def _log_cost(self) -> None: + get_logger().info(get_total_cost_message(self.exchange.get_token_usage())) + print("You can view the cost and token usage in the log directory", LOG_PATH) + if __name__ == "__main__": session = Session() diff --git a/src/goose/utils/_cost_calculator.py b/src/goose/utils/_cost_calculator.py new file mode 100644 index 00000000..ae2f1379 --- /dev/null +++ b/src/goose/utils/_cost_calculator.py @@ -0,0 +1,39 @@ +from typing import Optional +from exchange.providers.base import Usage + +PRICES = { + "gpt-4o": (5.00, 15.00), + "gpt-4o-2024-08-06": (2.50, 10.00), + "gpt-4o-mini": (0.150, 0.600), + "gpt-4o-mini-2024-07-18": (0.150, 0.600), + "o1-preview": (15.00, 60.00), + "o1-mini": (3.00, 12.00), + "claude-3-5-sonnet-20240620": (3.00, 15.00), + "anthropic.claude-3-5-sonnet-20240620-v1:0": (3.00, 15.00), + "claude-3-opus-20240229": (15.00, 75.00), + "anthropic.claude-3-opus-20240229-v1:0": (15.00, 75.00), + "claude-3-haiku-20240307": (0.25, 1.25), + "anthropic.claude-3-haiku-20240307-v1:0": (0.25, 1.25), +} + + +def _calculate_cost(model: str, token_usage: Usage) -> Optional[float]: + model_name = model.lower() + if model_name in PRICES: + input_token_price, output_token_price = PRICES[model_name] + return (input_token_price * token_usage.input_tokens + output_token_price * token_usage.output_tokens) / 1000000 + return None + + +def get_total_cost_message(token_usages: dict[str, Usage]) -> str: + total_cost = 0 + message = "" + for model, token_usage in token_usages.items(): + cost = _calculate_cost(model, token_usage) + if cost is not None: + message += f"Cost for model {model} {str(token_usage)}: ${cost:.2f}\n" + total_cost += cost + else: + message += f"Cost for model {model} {str(token_usage)}: Not available\n" + message += f"Total cost: ${total_cost:.2f}" + return message diff --git a/tests/cli/test_session.py b/tests/cli/test_session.py index 79a7c4a2..79b6d2bb 100644 --- a/tests/cli/test_session.py +++ b/tests/cli/test_session.py @@ -167,3 +167,14 @@ def test_generate_session_name(create_session_with_mock_configs): session.generate_session_name() assert session.name == SPECIFIED_SESSION_NAME + + +def test_log_log_cost(create_session_with_mock_configs): + session = create_session_with_mock_configs() + mock_logger = MagicMock() + cost_message = "You have used 100 tokens" + with patch("exchange.Exchange.get_token_usage", return_value={}), patch( + "goose.cli.session.get_total_cost_message", return_value=cost_message + ), patch("goose.cli.session.get_logger", return_value=mock_logger): + session._log_cost() + mock_logger.info.assert_called_once_with(cost_message) diff --git a/tests/utils/test_cost_calculator.py b/tests/utils/test_cost_calculator.py new file mode 100644 index 00000000..6a28509e --- /dev/null +++ b/tests/utils/test_cost_calculator.py @@ -0,0 +1,38 @@ +from goose.utils._cost_calculator import _calculate_cost, get_total_cost_message +from exchange.providers.base import Usage + + +def test_calculate_cost(): + cost = _calculate_cost("gpt-4o", Usage(input_tokens=10000, output_tokens=600, total_tokens=10600)) + assert cost == 0.059 + + +def test_get_total_cost_message(): + message = get_total_cost_message( + { + "gpt-4o": Usage(input_tokens=10000, output_tokens=600, total_tokens=10600), + "gpt-4o-mini": Usage(input_tokens=3000000, output_tokens=4000000, total_tokens=7000000), + } + ) + expected_message = ( + "Cost for model gpt-4o Usage(input_tokens=10000, output_tokens=600, total_tokens=10600): $0.06\n" + + "Cost for model gpt-4o-mini Usage(input_tokens=3000000, output_tokens=4000000, total_tokens=7000000)" + + ": $2.85\nTotal cost: $2.91" + ) + assert message == expected_message + + +def test_get_total_cost_message_with_non_available_pricing(): + message = get_total_cost_message( + { + "non_pricing_model": Usage(input_tokens=10000, output_tokens=600, total_tokens=10600), + "gpt-4o-mini": Usage(input_tokens=3000000, output_tokens=4000000, total_tokens=7000000), + } + ) + expected_message = ( + "Cost for model non_pricing_model Usage(input_tokens=10000, output_tokens=600, total_tokens=10600): " + + "Not available\n" + + "Cost for model gpt-4o-mini Usage(input_tokens=3000000, output_tokens=4000000, total_tokens=7000000)" + + ": $2.85\nTotal cost: $2.85" + ) + assert message == expected_message