diff --git a/autogpt/llm/api_manager.py b/autogpt/llm/api_manager.py index 7442579d..7a384562 100644 --- a/autogpt/llm/api_manager.py +++ b/autogpt/llm/api_manager.py @@ -6,8 +6,8 @@ import openai from openai import Model from autogpt.config import Config -from autogpt.llm.base import MessageDict -from autogpt.llm.modelsinfo import COSTS +from autogpt.llm.base import CompletionModelInfo, MessageDict +from autogpt.llm.providers.openai import OPEN_AI_MODELS from autogpt.logs import logger from autogpt.singleton import Singleton @@ -83,13 +83,14 @@ class ApiManager(metaclass=Singleton): """ # the .model property in API responses can contain version suffixes like -v2 model = model[:-3] if model.endswith("-v2") else model + model_info = OPEN_AI_MODELS[model] self.total_prompt_tokens += prompt_tokens self.total_completion_tokens += completion_tokens - self.total_cost += ( - prompt_tokens * COSTS[model]["prompt"] - + completion_tokens * COSTS[model]["completion"] - ) / 1000 + self.total_cost += prompt_tokens * model_info.prompt_token_cost / 1000 + if isinstance(model_info, CompletionModelInfo): + self.total_cost += completion_tokens * model_info.completion_token_cost / 1000 + logger.debug(f"Total running cost: ${self.total_cost:.3f}") def set_total_budget(self, total_budget): diff --git a/autogpt/llm/base.py b/autogpt/llm/base.py index 76bd3db1..43cc0ad9 100644 --- a/autogpt/llm/base.py +++ b/autogpt/llm/base.py @@ -31,22 +31,27 @@ class ModelInfo: Would be lovely to eventually get this directly from APIs, but needs to be scraped from websites for now. 
- """ name: str - prompt_token_cost: float - completion_token_cost: float max_tokens: int + prompt_token_cost: float @dataclass -class ChatModelInfo(ModelInfo): +class CompletionModelInfo(ModelInfo): + """Struct for generic completion model information.""" + + completion_token_cost: float + + +@dataclass +class ChatModelInfo(CompletionModelInfo): """Struct for chat model information.""" @dataclass -class TextModelInfo(ModelInfo): +class TextModelInfo(CompletionModelInfo): """Struct for text completion model information.""" diff --git a/autogpt/llm/modelsinfo.py b/autogpt/llm/modelsinfo.py deleted file mode 100644 index 425472de..00000000 --- a/autogpt/llm/modelsinfo.py +++ /dev/null @@ -1,11 +0,0 @@ -COSTS = { - "gpt-3.5-turbo": {"prompt": 0.002, "completion": 0.002}, - "gpt-3.5-turbo-0301": {"prompt": 0.002, "completion": 0.002}, - "gpt-4-0314": {"prompt": 0.03, "completion": 0.06}, - "gpt-4": {"prompt": 0.03, "completion": 0.06}, - "gpt-4-0314": {"prompt": 0.03, "completion": 0.06}, - "gpt-4-32k": {"prompt": 0.06, "completion": 0.12}, - "gpt-4-32k-0314": {"prompt": 0.06, "completion": 0.12}, - "text-embedding-ada-002": {"prompt": 0.0004, "completion": 0.0}, - "text-davinci-003": {"prompt": 0.02, "completion": 0.02}, -} diff --git a/autogpt/llm/providers/openai.py b/autogpt/llm/providers/openai.py index acaf0671..b4254cd1 100644 --- a/autogpt/llm/providers/openai.py +++ b/autogpt/llm/providers/openai.py @@ -3,23 +3,23 @@ from autogpt.llm.base import ChatModelInfo, EmbeddingModelInfo, TextModelInfo OPEN_AI_CHAT_MODELS = { info.name: info for info in [ - ChatModelInfo( - name="gpt-3.5-turbo", - prompt_token_cost=0.002, - completion_token_cost=0.002, - max_tokens=4096, - ), ChatModelInfo( name="gpt-3.5-turbo-0301", - prompt_token_cost=0.002, + prompt_token_cost=0.0015, completion_token_cost=0.002, max_tokens=4096, ), ChatModelInfo( - name="gpt-4", - prompt_token_cost=0.03, - completion_token_cost=0.06, - max_tokens=8192, + name="gpt-3.5-turbo-0613", + 
prompt_token_cost=0.0015, + completion_token_cost=0.002, + max_tokens=4096, + ), + ChatModelInfo( + name="gpt-3.5-turbo-16k-0613", + prompt_token_cost=0.003, + completion_token_cost=0.004, + max_tokens=16384, ), ChatModelInfo( name="gpt-4-0314", @@ -28,10 +28,10 @@ OPEN_AI_CHAT_MODELS = { max_tokens=8192, ), ChatModelInfo( - name="gpt-4-32k", - prompt_token_cost=0.06, - completion_token_cost=0.12, - max_tokens=32768, + name="gpt-4-0613", + prompt_token_cost=0.03, + completion_token_cost=0.06, + max_tokens=8192, ), ChatModelInfo( name="gpt-4-32k-0314", @@ -39,8 +39,19 @@ OPEN_AI_CHAT_MODELS = { completion_token_cost=0.12, max_tokens=32768, ), + ChatModelInfo( + name="gpt-4-32k-0613", + prompt_token_cost=0.06, + completion_token_cost=0.12, + max_tokens=32768, + ), ] } +# Set aliases for rolling model IDs +OPEN_AI_CHAT_MODELS["gpt-3.5-turbo"] = OPEN_AI_CHAT_MODELS["gpt-3.5-turbo-0301"] +OPEN_AI_CHAT_MODELS["gpt-3.5-turbo-16k"] = OPEN_AI_CHAT_MODELS["gpt-3.5-turbo-16k-0613"] +OPEN_AI_CHAT_MODELS["gpt-4"] = OPEN_AI_CHAT_MODELS["gpt-4-0314"] +OPEN_AI_CHAT_MODELS["gpt-4-32k"] = OPEN_AI_CHAT_MODELS["gpt-4-32k-0314"] OPEN_AI_TEXT_MODELS = { info.name: info @@ -59,8 +70,7 @@ OPEN_AI_EMBEDDING_MODELS = { for info in [ EmbeddingModelInfo( name="text-embedding-ada-002", - prompt_token_cost=0.0004, - completion_token_cost=0.0, + prompt_token_cost=0.0001, max_tokens=8191, embedding_dimensions=1536, ), diff --git a/tests/unit/test_api_manager.py b/tests/unit/test_api_manager.py index 9585fba7..e259f56a 100644 --- a/tests/unit/test_api_manager.py +++ b/tests/unit/test_api_manager.py @@ -1,8 +1,9 @@ from unittest.mock import MagicMock, patch import pytest +from pytest_mock import MockerFixture -from autogpt.llm.api_manager import COSTS, ApiManager +from autogpt.llm.api_manager import OPEN_AI_MODELS, ApiManager api_manager = ApiManager() @@ -14,16 +15,17 @@ def reset_api_manager(): @pytest.fixture(autouse=True) -def mock_costs(): - with patch.dict( - COSTS, - { - "gpt-3.5-turbo": 
{"prompt": 0.002, "completion": 0.002}, - "text-embedding-ada-002": {"prompt": 0.0004, "completion": 0}, - }, - clear=True, - ): - yield +def mock_costs(mocker: MockerFixture): + mocker.patch.multiple( + OPEN_AI_MODELS["gpt-3.5-turbo"], + prompt_token_cost=0.0013, + completion_token_cost=0.0025, + ) + mocker.patch.multiple( + OPEN_AI_MODELS["text-embedding-ada-002"], + prompt_token_cost=0.0004, + ) + yield class TestApiManager: @@ -87,15 +89,15 @@ class TestApiManager: assert api_manager.get_total_prompt_tokens() == 10 assert api_manager.get_total_completion_tokens() == 20 - assert api_manager.get_total_cost() == (10 * 0.002 + 20 * 0.002) / 1000 + assert api_manager.get_total_cost() == (10 * 0.0013 + 20 * 0.0025) / 1000 def test_getter_methods(self): """Test the getter methods for total tokens, cost, and budget.""" - api_manager.update_cost(60, 120, "gpt-3.5-turbo") + api_manager.update_cost(600, 1200, "gpt-3.5-turbo") api_manager.set_total_budget(10.0) - assert api_manager.get_total_prompt_tokens() == 60 - assert api_manager.get_total_completion_tokens() == 120 - assert api_manager.get_total_cost() == (60 * 0.002 + 120 * 0.002) / 1000 + assert api_manager.get_total_prompt_tokens() == 600 + assert api_manager.get_total_completion_tokens() == 1200 + assert api_manager.get_total_cost() == (600 * 0.0013 + 1200 * 0.0025) / 1000 assert api_manager.get_total_budget() == 10.0 @staticmethod @@ -107,7 +109,7 @@ class TestApiManager: assert api_manager.get_total_budget() == total_budget @staticmethod - def test_update_cost(): + def test_update_cost_completion_model(): """Test if updating the cost works correctly.""" prompt_tokens = 50 completion_tokens = 100 @@ -115,9 +117,24 @@ class TestApiManager: api_manager.update_cost(prompt_tokens, completion_tokens, model) - assert api_manager.get_total_prompt_tokens() == 50 - assert api_manager.get_total_completion_tokens() == 100 - assert api_manager.get_total_cost() == (50 * 0.002 + 100 * 0.002) / 1000 + assert 
api_manager.get_total_prompt_tokens() == prompt_tokens + assert api_manager.get_total_completion_tokens() == completion_tokens + assert ( + api_manager.get_total_cost() + == (prompt_tokens * 0.0013 + completion_tokens * 0.0025) / 1000 + ) + + @staticmethod + def test_update_cost_embedding_model(): + """Test if updating the cost works correctly.""" + prompt_tokens = 1337 + model = "text-embedding-ada-002" + + api_manager.update_cost(prompt_tokens, 0, model) + + assert api_manager.get_total_prompt_tokens() == prompt_tokens + assert api_manager.get_total_completion_tokens() == 0 + assert api_manager.get_total_cost() == (prompt_tokens * 0.0004) / 1000 @staticmethod def test_get_models():