mirror of https://github.com/aljazceru/Auto-GPT.git (synced 2026-01-04 14:54:32 +01:00)
Extract openai API calls and retry at lowest level (#3696)
* Extract OpenAI API calls and retry at lowest level
* Forgot a test
* Gotta fix my local docker config so I can let pre-commit hooks run, ugh
* fix: merge artifact
* Fix linting
* Update memory.vector.utils
* feat: make sure resp exists
* fix: raise error message if created
* feat: rename file
* fix: partial test fix
* fix: update comments
* fix: linting
* fix: remove broken test
* fix: require a model to exist
* fix: BaseError issue
* fix: runtime error
* Fix mock response in test_make_agent
* add 429 as errors to retry

---------

Co-authored-by: k-boikov <64261260+k-boikov@users.noreply.github.com>
Co-authored-by: Nicholas Tindle <nick@ntindle.com>
Co-authored-by: Reinier van der Leer <github@pwuts.nl>
Co-authored-by: Nicholas Tindle <nicktindle@outlook.com>
Co-authored-by: Luke K (pr-0f3t) <2609441+lc0rp@users.noreply.github.com>
Co-authored-by: Merwane Hamadi <merwanehamadi@gmail.com>
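The heart of the change is that retries and usage metering now live in decorators applied at the lowest level, directly around the OpenAI calls, instead of being repeated at every call site. The snippet below is a minimal, self-contained sketch of the exponential-backoff pattern that `retry_api` (shown in the diff further down) uses; `TransientAPIError`, `retry_with_backoff`, and `flaky_call` are illustrative names, not project code, with the simplified exception standing in for `RateLimitError`/`APIError`.

```python
import functools
import time


class TransientAPIError(Exception):
    """Stand-in for provider errors worth retrying (rate limit, bad gateway)."""


def retry_with_backoff(num_retries: int = 10, backoff_base: float = 2.0):
    """Retry the wrapped callable on TransientAPIError with exponential backoff."""

    def decorator(func):
        @functools.wraps(func)
        def wrapped(*args, **kwargs):
            num_attempts = num_retries + 1  # +1 for the first attempt
            for attempt in range(1, num_attempts + 1):
                try:
                    return func(*args, **kwargs)
                except TransientAPIError:
                    if attempt == num_attempts:
                        raise  # out of retries, surface the error to the caller
                    time.sleep(backoff_base ** (attempt + 2))

        return wrapped

    return decorator


@retry_with_backoff(num_retries=3, backoff_base=0.01)
def flaky_call() -> str:
    # Placeholder for an OpenAI request; a real caller would hit the API here.
    return "ok"


print(flaky_call())  # "ok"
```

Centralizing the behavior in one decorator also makes it testable in isolation, which is exactly what the new tests/unit/test_retry_provider_openai.py at the bottom of this diff does.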
@@ -185,6 +185,9 @@ def start_agent(name: str, task: str, prompt: str, agent: Agent, model=None) ->
    first_message = f"""You are {name}. Respond with: "Acknowledged"."""
    agent_intro = f"{voice_name} here, Reporting for duty!"

    if model is None:
        model = config.smart_llm_model

    # Create agent
    if agent.config.speak_mode:
        say_text(agent_intro, 1)

@@ -5,8 +5,6 @@ from typing import List, Optional
import openai
from openai import Model

from autogpt.config import Config
from autogpt.llm.base import MessageDict
from autogpt.llm.modelsinfo import COSTS
from autogpt.logs import logger
from autogpt.singleton import Singleton
@@ -27,52 +25,7 @@ class ApiManager(metaclass=Singleton):
        self.total_budget = 0.0
        self.models = None

    def create_chat_completion(
        self,
        messages: list[MessageDict],
        model: str | None = None,
        temperature: float = None,
        max_tokens: int | None = None,
        deployment_id=None,
    ) -> str:
        """
        Create a chat completion and update the cost.
        Args:
            messages (list): The list of messages to send to the API.
            model (str): The model to use for the API call.
            temperature (float): The temperature to use for the API call.
            max_tokens (int): The maximum number of tokens for the API call.
        Returns:
            str: The AI's response.
        """
        cfg = Config()
        if temperature is None:
            temperature = cfg.temperature
        if deployment_id is not None:
            response = openai.ChatCompletion.create(
                deployment_id=deployment_id,
                model=model,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
                api_key=cfg.openai_api_key,
            )
        else:
            response = openai.ChatCompletion.create(
                model=model,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
                api_key=cfg.openai_api_key,
            )
        if not hasattr(response, "error"):
            logger.debug(f"Response: {response}")
        prompt_tokens = response.usage.prompt_tokens
        completion_tokens = response.usage.completion_tokens
        self.update_cost(prompt_tokens, completion_tokens, model)
        return response

    def update_cost(self, prompt_tokens, completion_tokens, model: str):
    def update_cost(self, prompt_tokens, completion_tokens, model):
        """
        Update the total cost, prompt tokens, and completion tokens.

@@ -7,6 +7,9 @@ from typing import List, Literal, TypedDict
MessageRole = Literal["system", "user", "assistant"]
MessageType = Literal["ai_response", "action_result"]

TText = list[int]
"""Token array representing tokenized text"""


class MessageDict(TypedDict):
    role: MessageRole

@@ -1,4 +1,23 @@
from autogpt.llm.base import ChatModelInfo, EmbeddingModelInfo, TextModelInfo
import functools
import time
from typing import List
from unittest.mock import patch

import openai
import openai.api_resources.abstract.engine_api_resource as engine_api_resource
from colorama import Fore, Style
from openai.error import APIError, RateLimitError, Timeout
from openai.openai_object import OpenAIObject

from autogpt.llm.api_manager import ApiManager
from autogpt.llm.base import (
    ChatModelInfo,
    EmbeddingModelInfo,
    MessageDict,
    TextModelInfo,
    TText,
)
from autogpt.logs import logger

OPEN_AI_CHAT_MODELS = {
    info.name: info
@@ -72,3 +91,160 @@ OPEN_AI_MODELS: dict[str, ChatModelInfo | EmbeddingModelInfo | TextModelInfo] =
    **OPEN_AI_TEXT_MODELS,
    **OPEN_AI_EMBEDDING_MODELS,
}


def meter_api(func):
    """Adds ApiManager metering to functions which make OpenAI API calls"""
    api_manager = ApiManager()

    openai_obj_processor = openai.util.convert_to_openai_object

    def update_usage_with_response(response: OpenAIObject):
        try:
            usage = response.usage
            logger.debug(f"Reported usage from call to model {response.model}: {usage}")
            api_manager.update_cost(
                response.usage.prompt_tokens,
                response.usage.completion_tokens if "completion_tokens" in usage else 0,
                response.model,
            )
        except Exception as err:
            logger.warn(f"Failed to update API costs: {err.__class__.__name__}: {err}")

    def metering_wrapper(*args, **kwargs):
        openai_obj = openai_obj_processor(*args, **kwargs)
        if isinstance(openai_obj, OpenAIObject) and "usage" in openai_obj:
            update_usage_with_response(openai_obj)
        return openai_obj

    def metered_func(*args, **kwargs):
        with patch.object(
            engine_api_resource.util,
            "convert_to_openai_object",
            side_effect=metering_wrapper,
        ):
            return func(*args, **kwargs)

    return metered_func


def retry_api(
    num_retries: int = 10,
    backoff_base: float = 2.0,
    warn_user: bool = True,
):
    """Retry an OpenAI API call.

    Args:
        num_retries int: Number of retries. Defaults to 10.
        backoff_base float: Base for exponential backoff. Defaults to 2.
        warn_user bool: Whether to warn the user. Defaults to True.
    """
    retry_limit_msg = f"{Fore.RED}Error: " f"Reached rate limit, passing...{Fore.RESET}"
    api_key_error_msg = (
        f"Please double check that you have setup a "
        f"{Fore.CYAN + Style.BRIGHT}PAID{Style.RESET_ALL} OpenAI API Account. You can "
        f"read more here: {Fore.CYAN}https://docs.agpt.co/setup/#getting-an-api-key{Fore.RESET}"
    )
    backoff_msg = (
        f"{Fore.RED}Error: API Bad gateway. Waiting {{backoff}} seconds...{Fore.RESET}"
    )

    def _wrapper(func):
        @functools.wraps(func)
        def _wrapped(*args, **kwargs):
            user_warned = not warn_user
            num_attempts = num_retries + 1  # +1 for the first attempt
            for attempt in range(1, num_attempts + 1):
                try:
                    return func(*args, **kwargs)

                except RateLimitError:
                    if attempt == num_attempts:
                        raise

                    logger.debug(retry_limit_msg)
                    if not user_warned:
                        logger.double_check(api_key_error_msg)
                        user_warned = True

                except (APIError, Timeout) as e:
                    if (e.http_status not in [502, 429]) or (attempt == num_attempts):
                        raise

                backoff = backoff_base ** (attempt + 2)
                logger.debug(backoff_msg.format(backoff=backoff))
                time.sleep(backoff)

        return _wrapped

    return _wrapper


@meter_api
@retry_api()
def create_chat_completion(
    messages: List[MessageDict],
    *_,
    **kwargs,
) -> OpenAIObject:
    """Create a chat completion using the OpenAI API

    Args:
        messages: A list of messages to feed to the chatbot.
        kwargs: Other arguments to pass to the OpenAI API chat completion call.
    Returns:
        OpenAIObject: The ChatCompletion response from OpenAI

    """
    completion: OpenAIObject = openai.ChatCompletion.create(
        messages=messages,
        **kwargs,
    )
    if not hasattr(completion, "error"):
        logger.debug(f"Response: {completion}")
    return completion


@meter_api
@retry_api()
def create_text_completion(
    prompt: str,
    *_,
    **kwargs,
) -> OpenAIObject:
    """Create a text completion using the OpenAI API

    Args:
        prompt: A text prompt to feed to the LLM
        kwargs: Other arguments to pass to the OpenAI API text completion call.
    Returns:
        OpenAIObject: The Completion response from OpenAI

    """
    return openai.Completion.create(
        prompt=prompt,
        **kwargs,
    )


@meter_api
@retry_api()
def create_embedding(
    input: str | TText | List[str] | List[TText],
    *_,
    **kwargs,
) -> OpenAIObject:
    """Create an embedding using the OpenAI API

    Args:
        input: The text to embed.
        kwargs: Other arguments to pass to the OpenAI API embedding call.
    Returns:
        OpenAIObject: The Embedding response from OpenAI

    """
    return openai.Embedding.create(
        input=input,
        **kwargs,
    )

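With the three wrapped functions above in place, the rest of the codebase only needs to import the provider module and call it; the following hunks switch the call sites over. A hedged usage sketch, assuming this commit's module layout; the model name and the API key literal are placeholders, and real callers pass these values through `Config` exactly as the diffs below show.

```python
from autogpt.llm.providers import openai as iopenai

messages = [{"role": "user", "content": "Say hello."}]

# Metering and retries happen inside the wrapper; the call site stays simple.
response = iopenai.create_chat_completion(
    messages,
    model="gpt-3.5-turbo",  # placeholder; real callers take this from Config
    temperature=0.0,
    max_tokens=100,
    api_key="sk-...",  # placeholder; real callers pass cfg.openai_api_key
)
print(response.choices[0].message["content"])
```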
@@ -1,119 +1,24 @@
from __future__ import annotations

import functools
import time
from typing import List, Literal, Optional
from unittest.mock import patch

import openai
import openai.api_resources.abstract.engine_api_resource as engine_api_resource
import openai.util
from colorama import Fore, Style
from openai.error import APIError, RateLimitError
from openai.openai_object import OpenAIObject
from colorama import Fore

from autogpt.config import Config
from autogpt.logs import logger

from ..api_manager import ApiManager
from ..base import ChatSequence, Message
from ..providers import openai as iopenai
from .token_counter import *


def metered(func):
    """Adds ApiManager metering to functions which make OpenAI API calls"""
    api_manager = ApiManager()

    openai_obj_processor = openai.util.convert_to_openai_object

    def update_usage_with_response(response: OpenAIObject):
        try:
            usage = response.usage
            logger.debug(f"Reported usage from call to model {response.model}: {usage}")
            api_manager.update_cost(
                response.usage.prompt_tokens,
                response.usage.completion_tokens if "completion_tokens" in usage else 0,
                response.model,
            )
        except Exception as err:
            logger.warn(f"Failed to update API costs: {err.__class__.__name__}: {err}")

    def metering_wrapper(*args, **kwargs):
        openai_obj = openai_obj_processor(*args, **kwargs)
        if isinstance(openai_obj, OpenAIObject) and "usage" in openai_obj:
            update_usage_with_response(openai_obj)
        return openai_obj

    def metered_func(*args, **kwargs):
        with patch.object(
            engine_api_resource.util,
            "convert_to_openai_object",
            side_effect=metering_wrapper,
        ):
            return func(*args, **kwargs)

    return metered_func


def retry_openai_api(
    num_retries: int = 10,
    backoff_base: float = 2.0,
    warn_user: bool = True,
):
    """Retry an OpenAI API call.

    Args:
        num_retries int: Number of retries. Defaults to 10.
        backoff_base float: Base for exponential backoff. Defaults to 2.
        warn_user bool: Whether to warn the user. Defaults to True.
    """
    retry_limit_msg = f"{Fore.RED}Error: " f"Reached rate limit, passing...{Fore.RESET}"
    api_key_error_msg = (
        f"Please double check that you have setup a "
        f"{Fore.CYAN + Style.BRIGHT}PAID{Style.RESET_ALL} OpenAI API Account. You can "
        f"read more here: {Fore.CYAN}https://docs.agpt.co/setup/#getting-an-api-key{Fore.RESET}"
    )
    backoff_msg = (
        f"{Fore.RED}Error: API Bad gateway. Waiting {{backoff}} seconds...{Fore.RESET}"
    )

    def _wrapper(func):
        @functools.wraps(func)
        def _wrapped(*args, **kwargs):
            user_warned = not warn_user
            num_attempts = num_retries + 1  # +1 for the first attempt
            for attempt in range(1, num_attempts + 1):
                try:
                    return func(*args, **kwargs)

                except RateLimitError:
                    if attempt == num_attempts:
                        raise

                    logger.debug(retry_limit_msg)
                    if not user_warned:
                        logger.double_check(api_key_error_msg)
                        user_warned = True

                except APIError as e:
                    if (e.http_status not in [502, 429]) or (attempt == num_attempts):
                        raise

                backoff = backoff_base ** (attempt + 2)
                logger.debug(backoff_msg.format(backoff=backoff))
                time.sleep(backoff)

        return _wrapped

    return _wrapper


def call_ai_function(
    function: str,
    args: list,
    description: str,
    model: str | None = None,
    config: Config = None,
    model: Optional[str] = None,
    config: Optional[Config] = None,
) -> str:
    """Call an AI function

@@ -150,8 +55,6 @@ def call_ai_function(
    return create_chat_completion(prompt=prompt, temperature=0)


@metered
@retry_openai_api()
def create_text_completion(
    prompt: str,
    model: Optional[str],
@@ -169,24 +72,23 @@ def create_text_completion(
    else:
        kwargs = {"model": model}

    response = openai.Completion.create(
        **kwargs,
    response = iopenai.create_text_completion(
        prompt=prompt,
        **kwargs,
        temperature=temperature,
        max_tokens=max_output_tokens,
        api_key=cfg.openai_api_key,
    )
    logger.debug(f"Response: {response}")

    return response.choices[0].text


# Overly simple abstraction until we create something better
# simple retry mechanism when getting a rate error or a bad gateway
@metered
@retry_openai_api()
def create_chat_completion(
    prompt: ChatSequence,
    model: Optional[str] = None,
    temperature: float = None,
    temperature: Optional[float] = None,
    max_tokens: Optional[int] = None,
) -> str:
    """Create a chat completion using the OpenAI API

@@ -209,41 +111,48 @@ def create_chat_completion(
    logger.debug(
        f"{Fore.GREEN}Creating chat completion with model {model}, temperature {temperature}, max_tokens {max_tokens}{Fore.RESET}"
    )
    chat_completion_kwargs = {
        "model": model,
        "temperature": temperature,
        "max_tokens": max_tokens,
    }

    for plugin in cfg.plugins:
        if plugin.can_handle_chat_completion(
            messages=prompt.raw(),
            model=model,
            temperature=temperature,
            max_tokens=max_tokens,
            **chat_completion_kwargs,
        ):
            message = plugin.handle_chat_completion(
                messages=prompt.raw(),
                model=model,
                temperature=temperature,
                max_tokens=max_tokens,
                **chat_completion_kwargs,
            )
            if message is not None:
                return message
    api_manager = ApiManager()
    response = None

    chat_completion_kwargs["api_key"] = cfg.openai_api_key
    if cfg.use_azure:
        kwargs = {"deployment_id": cfg.get_azure_deployment_id_for_model(model)}
    else:
        kwargs = {"model": model}
        chat_completion_kwargs["deployment_id"] = cfg.get_azure_deployment_id_for_model(
            model
        )

    response = api_manager.create_chat_completion(
        **kwargs,
    response = iopenai.create_chat_completion(
        messages=prompt.raw(),
        temperature=temperature,
        max_tokens=max_tokens,
        **chat_completion_kwargs,
    )
    logger.debug(f"Response: {response}")

    resp = ""
    if not hasattr(response, "error"):
        resp = response.choices[0].message["content"]
    else:
        logger.error(response.error)
        raise RuntimeError(response.error)

    resp = response.choices[0].message["content"]
    for plugin in cfg.plugins:
        if not plugin.can_handle_on_response():
            continue
        resp = plugin.on_response(resp)

    return resp

@@ -1,16 +1,14 @@
from typing import Any, overload

import numpy as np
import openai

from autogpt.config import Config
from autogpt.llm.utils import metered, retry_openai_api
from autogpt.llm.base import TText
from autogpt.llm.providers import openai as iopenai
from autogpt.logs import logger

Embedding = list[np.float32] | np.ndarray[Any, np.dtype[np.float32]]
"""Embedding vector"""
TText = list[int]
"""Token array representing text"""


@overload
@@ -23,8 +21,6 @@ def get_embedding(input: list[str] | list[TText]) -> list[Embedding]:
    ...


@metered
@retry_openai_api()
def get_embedding(
    input: str | TText | list[str] | list[TText],
) -> Embedding | list[Embedding]:
@@ -57,10 +53,10 @@ def get_embedding(
        + (f" via Azure deployment '{kwargs['engine']}'" if cfg.use_azure else "")
    )

    embeddings = openai.Embedding.create(
        input=input,
        api_key=cfg.openai_api_key,
    embeddings = iopenai.create_embedding(
        input,
        **kwargs,
        api_key=cfg.openai_api_key,
    ).data

    if not multiple:

tests/integration/test_provider_openai.py | 67 (new file)
@@ -0,0 +1,67 @@
from unittest.mock import MagicMock, patch

import pytest

from autogpt.llm.api_manager import COSTS, ApiManager
from autogpt.llm.providers import openai

api_manager = ApiManager()


@pytest.fixture(autouse=True)
def reset_api_manager():
    api_manager.reset()
    yield


@pytest.fixture(autouse=True)
def mock_costs():
    with patch.dict(
        COSTS,
        {
            "gpt-3.5-turbo": {"prompt": 0.002, "completion": 0.002},
            "text-embedding-ada-002": {"prompt": 0.0004, "completion": 0},
        },
        clear=True,
    ):
        yield


class TestProviderOpenAI:
    @staticmethod
    def test_create_chat_completion_debug_mode(caplog):
        """Test if debug mode logs response."""
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Who won the world series in 2020?"},
        ]
        model = "gpt-3.5-turbo"
        with patch("openai.ChatCompletion.create") as mock_create:
            mock_response = MagicMock()
            del mock_response.error
            mock_response.usage.prompt_tokens = 10
            mock_response.usage.completion_tokens = 20
            mock_create.return_value = mock_response

            openai.create_chat_completion(messages, model=model)

            assert "Response" in caplog.text

    @staticmethod
    def test_create_chat_completion_empty_messages():
        """Test if empty messages result in zero tokens and cost."""
        messages = []
        model = "gpt-3.5-turbo"

        with patch("openai.ChatCompletion.create") as mock_create:
            mock_response = MagicMock()
            del mock_response.error
            mock_response.usage.prompt_tokens = 0
            mock_response.usage.completion_tokens = 0
            mock_create.return_value = mock_response

            openai.create_chat_completion(messages, model=model)

            assert api_manager.get_total_prompt_tokens() == 0
            assert api_manager.get_total_completion_tokens() == 0
            assert api_manager.get_total_cost() == 0

@@ -1,4 +1,4 @@
from unittest.mock import MagicMock, patch
from unittest.mock import patch

import pytest

@@ -27,68 +27,6 @@ def mock_costs():


class TestApiManager:
    @staticmethod
    def test_create_chat_completion_debug_mode(caplog):
        """Test if debug mode logs response."""
        api_manager_debug = ApiManager(debug=True)
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Who won the world series in 2020?"},
        ]
        model = "gpt-3.5-turbo"

        with patch("openai.ChatCompletion.create") as mock_create:
            mock_response = MagicMock()
            del mock_response.error
            mock_response.usage.prompt_tokens = 10
            mock_response.usage.completion_tokens = 20
            mock_create.return_value = mock_response

            api_manager_debug.create_chat_completion(messages, model=model)

            assert "Response" in caplog.text

    @staticmethod
    def test_create_chat_completion_empty_messages():
        """Test if empty messages result in zero tokens and cost."""
        messages = []
        model = "gpt-3.5-turbo"

        with patch("openai.ChatCompletion.create") as mock_create:
            mock_response = MagicMock()
            del mock_response.error
            mock_response.usage.prompt_tokens = 0
            mock_response.usage.completion_tokens = 0
            mock_create.return_value = mock_response

            api_manager.create_chat_completion(messages, model=model)

            assert api_manager.get_total_prompt_tokens() == 0
            assert api_manager.get_total_completion_tokens() == 0
            assert api_manager.get_total_cost() == 0

    @staticmethod
    def test_create_chat_completion_valid_inputs():
        """Test if valid inputs result in correct tokens and cost."""
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Who won the world series in 2020?"},
        ]
        model = "gpt-3.5-turbo"

        with patch("openai.ChatCompletion.create") as mock_create:
            mock_response = MagicMock()
            del mock_response.error
            mock_response.usage.prompt_tokens = 10
            mock_response.usage.completion_tokens = 20
            mock_create.return_value = mock_response

            api_manager.create_chat_completion(messages, model=model)

            assert api_manager.get_total_prompt_tokens() == 10
            assert api_manager.get_total_completion_tokens() == 20
            assert api_manager.get_total_cost() == (10 * 0.002 + 20 * 0.002) / 1000

    def test_getter_methods(self):
        """Test the getter methods for total tokens, cost, and budget."""
        api_manager.update_cost(60, 120, "gpt-3.5-turbo")

@@ -11,7 +11,7 @@ def test_make_agent(agent: Agent, mocker: MockerFixture) -> None:
    mock = mocker.patch("openai.ChatCompletion.create")

    response = MagicMock()
    # del response.error
    del response.error
    response.choices[0].messages[0].content = "Test message"
    response.usage.prompt_tokens = 1
    response.usage.completion_tokens = 1

tests/unit/test_retry_provider_openai.py | 110 (new file)
@@ -0,0 +1,110 @@
import pytest
from openai.error import APIError, RateLimitError

from autogpt.llm.providers import openai


@pytest.fixture(params=[RateLimitError, APIError])
def error(request):
    if request.param == APIError:
        return request.param("Error", http_status=502)
    else:
        return request.param("Error")


def error_factory(error_instance, error_count, retry_count, warn_user=True):
    """Creates errors"""

    class RaisesError:
        def __init__(self):
            self.count = 0

        @openai.retry_api(
            num_retries=retry_count, backoff_base=0.001, warn_user=warn_user
        )
        def __call__(self):
            self.count += 1
            if self.count <= error_count:
                raise error_instance
            return self.count

    return RaisesError()


def test_retry_open_api_no_error(capsys):
    """Tests the retry functionality with no errors expected"""

    @openai.retry_api()
    def f():
        return 1

    result = f()
    assert result == 1

    output = capsys.readouterr()
    assert output.out == ""
    assert output.err == ""


@pytest.mark.parametrize(
    "error_count, retry_count, failure",
    [(2, 10, False), (2, 2, False), (10, 2, True), (3, 2, True), (1, 0, True)],
    ids=["passing", "passing_edge", "failing", "failing_edge", "failing_no_retries"],
)
def test_retry_open_api_passing(capsys, error, error_count, retry_count, failure):
    """Tests the retry with simulated errors [RateLimitError, APIError], but should ultimately pass"""
    call_count = min(error_count, retry_count) + 1

    raises = error_factory(error, error_count, retry_count)
    if failure:
        with pytest.raises(type(error)):
            raises()
    else:
        result = raises()
        assert result == call_count

    assert raises.count == call_count

    output = capsys.readouterr()

    if error_count and retry_count:
        if type(error) == RateLimitError:
            assert "Reached rate limit, passing..." in output.out
            assert "Please double check" in output.out
        if type(error) == APIError:
            assert "API Bad gateway" in output.out
    else:
        assert output.out == ""


def test_retry_open_api_rate_limit_no_warn(capsys):
    """Tests the retry logic with a rate limit error"""
    error_count = 2
    retry_count = 10

    raises = error_factory(RateLimitError, error_count, retry_count, warn_user=False)
    result = raises()
    call_count = min(error_count, retry_count) + 1
    assert result == call_count
    assert raises.count == call_count

    output = capsys.readouterr()

    assert "Reached rate limit, passing..." in output.out
    assert "Please double check" not in output.out


def test_retry_openapi_other_api_error(capsys):
    """Tests the Retry logic with a non rate limit error such as HTTP500"""
    error_count = 2
    retry_count = 10

    raises = error_factory(APIError("Error", http_status=500), error_count, retry_count)

    with pytest.raises(APIError):
        raises()
    call_count = 1
    assert raises.count == call_count

    output = capsys.readouterr()
    assert output.out == ""

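The `error_factory` fixture above exercises `retry_api` the same way application code can use it: decorate any function that performs an OpenAI call and let the decorator absorb transient failures. A small sketch assuming this commit's module path; the counter and the simulated rate limit are illustrative, not project code.

```python
from openai.error import RateLimitError

from autogpt.llm.providers import openai as iopenai

attempts = {"n": 0}


@iopenai.retry_api(num_retries=3, backoff_base=0.01, warn_user=False)
def sometimes_rate_limited() -> int:
    # Placeholder for a real OpenAI call: fails twice, then succeeds.
    attempts["n"] += 1
    if attempts["n"] < 3:
        raise RateLimitError("simulated rate limit")
    return attempts["n"]


assert sometimes_rate_limited() == 3  # two retries, success on the third attempt
```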