Extract openai API calls and retry at lowest level (#3696)

* Extract OpenAI API calls and retry at lowest level

* Forgot a test

* Gotta fix my local docker config so I can let pre-commit hooks run, ugh

* fix: merge artifact

* Fix linting

* Update memory.vector.utils

* feat: make sure resp exists

* fix: raise error message if created

* feat: rename file

* fix: partial test fix

* fix: update comments

* fix: linting

* fix: remove broken test

* fix: require a model to exist

* fix: BaseError issue

* fix: runtime error

* Fix mock response in test_make_agent

* add 429 as errors to retry

---------

Co-authored-by: k-boikov <64261260+k-boikov@users.noreply.github.com>
Co-authored-by: Nicholas Tindle <nick@ntindle.com>
Co-authored-by: Reinier van der Leer <github@pwuts.nl>
Co-authored-by: Nicholas Tindle <nicktindle@outlook.com>
Co-authored-by: Luke K (pr-0f3t) <2609441+lc0rp@users.noreply.github.com>
Co-authored-by: Merwane Hamadi <merwanehamadi@gmail.com>
Author: James Collins
Date: 2023-06-14 07:59:26 -07:00
Committed by: GitHub
Parent: 49d1a5a17b
Commit: 6e6e7fcc9a
10 changed files with 400 additions and 245 deletions
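The hunks below move the raw OpenAI calls into a small provider module (the new autogpt/llm/providers/openai.py further down), where they are wrapped once with @meter_api and @retry_api(). A minimal caller-side sketch of the resulting pattern; the message content and API key are illustrative, real callers pass values from Config:

from autogpt.llm.providers import openai as iopenai

# The provider function is already decorated with @meter_api and @retry_api(),
# so a single plain call gets cost tracking plus retries on 429/502 for free.
response = iopenai.create_chat_completion(
    messages=[{"role": "user", "content": "Say hello."}],
    model="gpt-3.5-turbo",
    api_key="sk-...",  # illustrative placeholder
)
print(response.choices[0].message["content"])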

View File

@@ -185,6 +185,9 @@ def start_agent(name: str, task: str, prompt: str, agent: Agent, model=None) ->
first_message = f"""You are {name}. Respond with: "Acknowledged"."""
agent_intro = f"{voice_name} here, Reporting for duty!"
if model is None:
model = config.smart_llm_model
# Create agent
if agent.config.speak_mode:
say_text(agent_intro, 1)

View File

@@ -5,8 +5,6 @@ from typing import List, Optional
import openai
from openai import Model
from autogpt.config import Config
from autogpt.llm.base import MessageDict
from autogpt.llm.modelsinfo import COSTS
from autogpt.logs import logger
from autogpt.singleton import Singleton
@@ -27,52 +25,7 @@ class ApiManager(metaclass=Singleton):
self.total_budget = 0.0
self.models = None
def create_chat_completion(
self,
messages: list[MessageDict],
model: str | None = None,
temperature: float = None,
max_tokens: int | None = None,
deployment_id=None,
) -> str:
"""
Create a chat completion and update the cost.
Args:
messages (list): The list of messages to send to the API.
model (str): The model to use for the API call.
temperature (float): The temperature to use for the API call.
max_tokens (int): The maximum number of tokens for the API call.
Returns:
str: The AI's response.
"""
cfg = Config()
if temperature is None:
temperature = cfg.temperature
if deployment_id is not None:
response = openai.ChatCompletion.create(
deployment_id=deployment_id,
model=model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
api_key=cfg.openai_api_key,
)
else:
response = openai.ChatCompletion.create(
model=model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
api_key=cfg.openai_api_key,
)
if not hasattr(response, "error"):
logger.debug(f"Response: {response}")
prompt_tokens = response.usage.prompt_tokens
completion_tokens = response.usage.completion_tokens
self.update_cost(prompt_tokens, completion_tokens, model)
return response
def update_cost(self, prompt_tokens, completion_tokens, model: str):
def update_cost(self, prompt_tokens, completion_tokens, model):
"""
Update the total cost, prompt tokens, and completion tokens.

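With create_chat_completion removed, ApiManager is reduced to the budget and cost bookkeeping that the new metering decorator drives. A minimal sketch of the remaining surface; the token counts are illustrative and the getter names come from the tests further down:

from autogpt.llm.api_manager import ApiManager

api_manager = ApiManager()  # Singleton: every instantiation yields the same instance
# meter_api (in the provider module below) calls update_cost after each API response.
api_manager.update_cost(prompt_tokens=10, completion_tokens=20, model="gpt-3.5-turbo")
print(api_manager.get_total_prompt_tokens())      # 10
print(api_manager.get_total_completion_tokens())  # 20
print(api_manager.get_total_cost())               # derived from the COSTS entry for the model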
View File

@@ -7,6 +7,9 @@ from typing import List, Literal, TypedDict
MessageRole = Literal["system", "user", "assistant"]
MessageType = Literal["ai_response", "action_result"]
TText = list[int]
"""Token array representing tokenized text"""
class MessageDict(TypedDict):
role: MessageRole

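TText above is simply a list of token ids. A small sketch of producing one, assuming the tiktoken package and its cl100k_base encoding (neither is part of this diff):

import tiktoken

from autogpt.llm.base import TText

encoding = tiktoken.get_encoding("cl100k_base")
tokens: TText = encoding.encode("Hello, world!")  # a plain list[int] of token ids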
View File

@@ -1,4 +1,23 @@
from autogpt.llm.base import ChatModelInfo, EmbeddingModelInfo, TextModelInfo
import functools
import time
from typing import List
from unittest.mock import patch
import openai
import openai.api_resources.abstract.engine_api_resource as engine_api_resource
from colorama import Fore, Style
from openai.error import APIError, RateLimitError, Timeout
from openai.openai_object import OpenAIObject
from autogpt.llm.api_manager import ApiManager
from autogpt.llm.base import (
ChatModelInfo,
EmbeddingModelInfo,
MessageDict,
TextModelInfo,
TText,
)
from autogpt.logs import logger
OPEN_AI_CHAT_MODELS = {
info.name: info
@@ -72,3 +91,160 @@ OPEN_AI_MODELS: dict[str, ChatModelInfo | EmbeddingModelInfo | TextModelInfo] =
**OPEN_AI_TEXT_MODELS,
**OPEN_AI_EMBEDDING_MODELS,
}
def meter_api(func):
"""Adds ApiManager metering to functions which make OpenAI API calls"""
api_manager = ApiManager()
openai_obj_processor = openai.util.convert_to_openai_object
def update_usage_with_response(response: OpenAIObject):
try:
usage = response.usage
logger.debug(f"Reported usage from call to model {response.model}: {usage}")
api_manager.update_cost(
response.usage.prompt_tokens,
response.usage.completion_tokens if "completion_tokens" in usage else 0,
response.model,
)
except Exception as err:
logger.warn(f"Failed to update API costs: {err.__class__.__name__}: {err}")
def metering_wrapper(*args, **kwargs):
openai_obj = openai_obj_processor(*args, **kwargs)
if isinstance(openai_obj, OpenAIObject) and "usage" in openai_obj:
update_usage_with_response(openai_obj)
return openai_obj
def metered_func(*args, **kwargs):
with patch.object(
engine_api_resource.util,
"convert_to_openai_object",
side_effect=metering_wrapper,
):
return func(*args, **kwargs)
return metered_func
def retry_api(
num_retries: int = 10,
backoff_base: float = 2.0,
warn_user: bool = True,
):
"""Retry an OpenAI API call.
Args:
num_retries int: Number of retries. Defaults to 10.
backoff_base float: Base for exponential backoff. Defaults to 2.
warn_user bool: Whether to warn the user. Defaults to True.
"""
retry_limit_msg = f"{Fore.RED}Error: " f"Reached rate limit, passing...{Fore.RESET}"
api_key_error_msg = (
f"Please double check that you have setup a "
f"{Fore.CYAN + Style.BRIGHT}PAID{Style.RESET_ALL} OpenAI API Account. You can "
f"read more here: {Fore.CYAN}https://docs.agpt.co/setup/#getting-an-api-key{Fore.RESET}"
)
backoff_msg = (
f"{Fore.RED}Error: API Bad gateway. Waiting {{backoff}} seconds...{Fore.RESET}"
)
def _wrapper(func):
@functools.wraps(func)
def _wrapped(*args, **kwargs):
user_warned = not warn_user
num_attempts = num_retries + 1 # +1 for the first attempt
for attempt in range(1, num_attempts + 1):
try:
return func(*args, **kwargs)
except RateLimitError:
if attempt == num_attempts:
raise
logger.debug(retry_limit_msg)
if not user_warned:
logger.double_check(api_key_error_msg)
user_warned = True
except (APIError, Timeout) as e:
if (e.http_status not in [502, 429]) or (attempt == num_attempts):
raise
backoff = backoff_base ** (attempt + 2)
logger.debug(backoff_msg.format(backoff=backoff))
time.sleep(backoff)
return _wrapped
return _wrapper
@meter_api
@retry_api()
def create_chat_completion(
messages: List[MessageDict],
*_,
**kwargs,
) -> OpenAIObject:
"""Create a chat completion using the OpenAI API
Args:
messages: A list of messages to feed to the chatbot.
kwargs: Other arguments to pass to the OpenAI API chat completion call.
Returns:
OpenAIObject: The ChatCompletion response from OpenAI
"""
completion: OpenAIObject = openai.ChatCompletion.create(
messages=messages,
**kwargs,
)
if not hasattr(completion, "error"):
logger.debug(f"Response: {completion}")
return completion
@meter_api
@retry_api()
def create_text_completion(
prompt: str,
*_,
**kwargs,
) -> OpenAIObject:
"""Create a text completion using the OpenAI API
Args:
prompt: A text prompt to feed to the LLM
kwargs: Other arguments to pass to the OpenAI API text completion call.
Returns:
OpenAIObject: The Completion response from OpenAI
"""
return openai.Completion.create(
prompt=prompt,
**kwargs,
)
@meter_api
@retry_api()
def create_embedding(
input: str | TText | List[str] | List[TText],
*_,
**kwargs,
) -> OpenAIObject:
"""Create an embedding using the OpenAI API
Args:
input: The text to embed.
kwargs: Other arguments to pass to the OpenAI API embedding call.
Returns:
OpenAIObject: The Embedding response from OpenAI
"""
return openai.Embedding.create(
input=input,
**kwargs,
)

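The backoff in retry_api grows as backoff_base ** (attempt + 2). Assuming the sleep runs after every failed attempt except the final one, the default schedule works out as in this sketch (not part of the diff):

def backoff_schedule(num_retries: int = 10, backoff_base: float = 2.0) -> list[float]:
    """Sleep durations retry_api would use if every attempt kept failing retryably."""
    num_attempts = num_retries + 1  # +1 for the first attempt, as in the decorator
    # The final attempt re-raises instead of sleeping, so it is excluded here.
    return [backoff_base ** (attempt + 2) for attempt in range(1, num_attempts)]

print(backoff_schedule())
# [8.0, 16.0, 32.0, 64.0, 128.0, 256.0, 512.0, 1024.0, 2048.0, 4096.0]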
View File

@@ -1,119 +1,24 @@
from __future__ import annotations
import functools
import time
from typing import List, Literal, Optional
from unittest.mock import patch
import openai
import openai.api_resources.abstract.engine_api_resource as engine_api_resource
import openai.util
from colorama import Fore, Style
from openai.error import APIError, RateLimitError
from openai.openai_object import OpenAIObject
from colorama import Fore
from autogpt.config import Config
from autogpt.logs import logger
from ..api_manager import ApiManager
from ..base import ChatSequence, Message
from ..providers import openai as iopenai
from .token_counter import *
def metered(func):
"""Adds ApiManager metering to functions which make OpenAI API calls"""
api_manager = ApiManager()
openai_obj_processor = openai.util.convert_to_openai_object
def update_usage_with_response(response: OpenAIObject):
try:
usage = response.usage
logger.debug(f"Reported usage from call to model {response.model}: {usage}")
api_manager.update_cost(
response.usage.prompt_tokens,
response.usage.completion_tokens if "completion_tokens" in usage else 0,
response.model,
)
except Exception as err:
logger.warn(f"Failed to update API costs: {err.__class__.__name__}: {err}")
def metering_wrapper(*args, **kwargs):
openai_obj = openai_obj_processor(*args, **kwargs)
if isinstance(openai_obj, OpenAIObject) and "usage" in openai_obj:
update_usage_with_response(openai_obj)
return openai_obj
def metered_func(*args, **kwargs):
with patch.object(
engine_api_resource.util,
"convert_to_openai_object",
side_effect=metering_wrapper,
):
return func(*args, **kwargs)
return metered_func
def retry_openai_api(
num_retries: int = 10,
backoff_base: float = 2.0,
warn_user: bool = True,
):
"""Retry an OpenAI API call.
Args:
num_retries int: Number of retries. Defaults to 10.
backoff_base float: Base for exponential backoff. Defaults to 2.
warn_user bool: Whether to warn the user. Defaults to True.
"""
retry_limit_msg = f"{Fore.RED}Error: " f"Reached rate limit, passing...{Fore.RESET}"
api_key_error_msg = (
f"Please double check that you have setup a "
f"{Fore.CYAN + Style.BRIGHT}PAID{Style.RESET_ALL} OpenAI API Account. You can "
f"read more here: {Fore.CYAN}https://docs.agpt.co/setup/#getting-an-api-key{Fore.RESET}"
)
backoff_msg = (
f"{Fore.RED}Error: API Bad gateway. Waiting {{backoff}} seconds...{Fore.RESET}"
)
def _wrapper(func):
@functools.wraps(func)
def _wrapped(*args, **kwargs):
user_warned = not warn_user
num_attempts = num_retries + 1 # +1 for the first attempt
for attempt in range(1, num_attempts + 1):
try:
return func(*args, **kwargs)
except RateLimitError:
if attempt == num_attempts:
raise
logger.debug(retry_limit_msg)
if not user_warned:
logger.double_check(api_key_error_msg)
user_warned = True
except APIError as e:
if (e.http_status not in [502, 429]) or (attempt == num_attempts):
raise
backoff = backoff_base ** (attempt + 2)
logger.debug(backoff_msg.format(backoff=backoff))
time.sleep(backoff)
return _wrapped
return _wrapper
def call_ai_function(
function: str,
args: list,
description: str,
model: str | None = None,
config: Config = None,
model: Optional[str] = None,
config: Optional[Config] = None,
) -> str:
"""Call an AI function
@@ -150,8 +55,6 @@ def call_ai_function(
return create_chat_completion(prompt=prompt, temperature=0)
@metered
@retry_openai_api()
def create_text_completion(
prompt: str,
model: Optional[str],
@@ -169,24 +72,23 @@ def create_text_completion(
else:
kwargs = {"model": model}
response = openai.Completion.create(
**kwargs,
response = iopenai.create_text_completion(
prompt=prompt,
**kwargs,
temperature=temperature,
max_tokens=max_output_tokens,
api_key=cfg.openai_api_key,
)
logger.debug(f"Response: {response}")
return response.choices[0].text
# Overly simple abstraction until we create something better
# simple retry mechanism when getting a rate error or a bad gateway
@metered
@retry_openai_api()
def create_chat_completion(
prompt: ChatSequence,
model: Optional[str] = None,
temperature: float = None,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
) -> str:
"""Create a chat completion using the OpenAI API
@@ -209,41 +111,48 @@ def create_chat_completion(
logger.debug(
f"{Fore.GREEN}Creating chat completion with model {model}, temperature {temperature}, max_tokens {max_tokens}{Fore.RESET}"
)
chat_completion_kwargs = {
"model": model,
"temperature": temperature,
"max_tokens": max_tokens,
}
for plugin in cfg.plugins:
if plugin.can_handle_chat_completion(
messages=prompt.raw(),
model=model,
temperature=temperature,
max_tokens=max_tokens,
**chat_completion_kwargs,
):
message = plugin.handle_chat_completion(
messages=prompt.raw(),
model=model,
temperature=temperature,
max_tokens=max_tokens,
**chat_completion_kwargs,
)
if message is not None:
return message
api_manager = ApiManager()
response = None
chat_completion_kwargs["api_key"] = cfg.openai_api_key
if cfg.use_azure:
kwargs = {"deployment_id": cfg.get_azure_deployment_id_for_model(model)}
else:
kwargs = {"model": model}
chat_completion_kwargs["deployment_id"] = cfg.get_azure_deployment_id_for_model(
model
)
response = api_manager.create_chat_completion(
**kwargs,
response = iopenai.create_chat_completion(
messages=prompt.raw(),
temperature=temperature,
max_tokens=max_tokens,
**chat_completion_kwargs,
)
logger.debug(f"Response: {response}")
resp = ""
if not hasattr(response, "error"):
resp = response.choices[0].message["content"]
else:
logger.error(response.error)
raise RuntimeError(response.error)
resp = response.choices[0].message["content"]
for plugin in cfg.plugins:
if not plugin.can_handle_on_response():
continue
resp = plugin.on_response(resp)
return resp

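A minimal caller-side sketch of the refactored create_text_completion above; the prompt text is illustrative, the remaining parameters are assumed to be named temperature and max_output_tokens as the hunk suggests, and config.smart_llm_model is the attribute used earlier in this diff:

from autogpt.config import Config
from autogpt.llm.utils import create_text_completion

cfg = Config()
# Retries and cost metering now happen inside iopenai.create_text_completion;
# this helper only assembles kwargs and returns response.choices[0].text.
text = create_text_completion(
    prompt="Summarize why retries belong at the lowest level.",
    model=cfg.smart_llm_model,
    temperature=0,
    max_output_tokens=128,
)
print(text)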
View File

@@ -1,16 +1,14 @@
from typing import Any, overload
import numpy as np
import openai
from autogpt.config import Config
from autogpt.llm.utils import metered, retry_openai_api
from autogpt.llm.base import TText
from autogpt.llm.providers import openai as iopenai
from autogpt.logs import logger
Embedding = list[np.float32] | np.ndarray[Any, np.dtype[np.float32]]
"""Embedding vector"""
TText = list[int]
"""Token array representing text"""
@overload
@@ -23,8 +21,6 @@ def get_embedding(input: list[str] | list[TText]) -> list[Embedding]:
...
@metered
@retry_openai_api()
def get_embedding(
input: str | TText | list[str] | list[TText],
) -> Embedding | list[Embedding]:
@@ -57,10 +53,10 @@ def get_embedding(
+ (f" via Azure deployment '{kwargs['engine']}'" if cfg.use_azure else "")
)
embeddings = openai.Embedding.create(
input=input,
api_key=cfg.openai_api_key,
embeddings = iopenai.create_embedding(
input,
**kwargs,
api_key=cfg.openai_api_key,
).data
if not multiple:

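A minimal usage sketch of get_embedding after the change; the input strings are illustrative, and the module path follows the commit message ("Update memory.vector.utils"):

from autogpt.memory.vector.utils import get_embedding

# Per the overloads above: a single string yields one Embedding,
# a list of strings yields a list of Embeddings. Retrying and metering
# happen inside iopenai.create_embedding.
single = get_embedding("the quick brown fox")
batch = get_embedding(["first chunk", "second chunk"])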
View File

@@ -0,0 +1,67 @@
from unittest.mock import MagicMock, patch
import pytest
from autogpt.llm.api_manager import COSTS, ApiManager
from autogpt.llm.providers import openai
api_manager = ApiManager()
@pytest.fixture(autouse=True)
def reset_api_manager():
api_manager.reset()
yield
@pytest.fixture(autouse=True)
def mock_costs():
with patch.dict(
COSTS,
{
"gpt-3.5-turbo": {"prompt": 0.002, "completion": 0.002},
"text-embedding-ada-002": {"prompt": 0.0004, "completion": 0},
},
clear=True,
):
yield
class TestProviderOpenAI:
@staticmethod
def test_create_chat_completion_debug_mode(caplog):
"""Test if debug mode logs response."""
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Who won the world series in 2020?"},
]
model = "gpt-3.5-turbo"
with patch("openai.ChatCompletion.create") as mock_create:
mock_response = MagicMock()
del mock_response.error
mock_response.usage.prompt_tokens = 10
mock_response.usage.completion_tokens = 20
mock_create.return_value = mock_response
openai.create_chat_completion(messages, model=model)
assert "Response" in caplog.text
@staticmethod
def test_create_chat_completion_empty_messages():
"""Test if empty messages result in zero tokens and cost."""
messages = []
model = "gpt-3.5-turbo"
with patch("openai.ChatCompletion.create") as mock_create:
mock_response = MagicMock()
del mock_response.error
mock_response.usage.prompt_tokens = 0
mock_response.usage.completion_tokens = 0
mock_create.return_value = mock_response
openai.create_chat_completion(messages, model=model)
assert api_manager.get_total_prompt_tokens() == 0
assert api_manager.get_total_completion_tokens() == 0
assert api_manager.get_total_cost() == 0

View File

@@ -1,4 +1,4 @@
from unittest.mock import MagicMock, patch
from unittest.mock import patch
import pytest
@@ -27,68 +27,6 @@ def mock_costs():
class TestApiManager:
@staticmethod
def test_create_chat_completion_debug_mode(caplog):
"""Test if debug mode logs response."""
api_manager_debug = ApiManager(debug=True)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Who won the world series in 2020?"},
]
model = "gpt-3.5-turbo"
with patch("openai.ChatCompletion.create") as mock_create:
mock_response = MagicMock()
del mock_response.error
mock_response.usage.prompt_tokens = 10
mock_response.usage.completion_tokens = 20
mock_create.return_value = mock_response
api_manager_debug.create_chat_completion(messages, model=model)
assert "Response" in caplog.text
@staticmethod
def test_create_chat_completion_empty_messages():
"""Test if empty messages result in zero tokens and cost."""
messages = []
model = "gpt-3.5-turbo"
with patch("openai.ChatCompletion.create") as mock_create:
mock_response = MagicMock()
del mock_response.error
mock_response.usage.prompt_tokens = 0
mock_response.usage.completion_tokens = 0
mock_create.return_value = mock_response
api_manager.create_chat_completion(messages, model=model)
assert api_manager.get_total_prompt_tokens() == 0
assert api_manager.get_total_completion_tokens() == 0
assert api_manager.get_total_cost() == 0
@staticmethod
def test_create_chat_completion_valid_inputs():
"""Test if valid inputs result in correct tokens and cost."""
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Who won the world series in 2020?"},
]
model = "gpt-3.5-turbo"
with patch("openai.ChatCompletion.create") as mock_create:
mock_response = MagicMock()
del mock_response.error
mock_response.usage.prompt_tokens = 10
mock_response.usage.completion_tokens = 20
mock_create.return_value = mock_response
api_manager.create_chat_completion(messages, model=model)
assert api_manager.get_total_prompt_tokens() == 10
assert api_manager.get_total_completion_tokens() == 20
assert api_manager.get_total_cost() == (10 * 0.002 + 20 * 0.002) / 1000
def test_getter_methods(self):
"""Test the getter methods for total tokens, cost, and budget."""
api_manager.update_cost(60, 120, "gpt-3.5-turbo")

View File

@@ -11,7 +11,7 @@ def test_make_agent(agent: Agent, mocker: MockerFixture) -> None:
mock = mocker.patch("openai.ChatCompletion.create")
response = MagicMock()
# del response.error
del response.error
response.choices[0].messages[0].content = "Test message"
response.usage.prompt_tokens = 1
response.usage.completion_tokens = 1

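The fix above turns the commented-out deletion into a real one, so hasattr(response, "error") is False and the success path is exercised. A standalone sketch of why deleting the attribute works on a MagicMock:

from unittest.mock import MagicMock

response = MagicMock()
assert hasattr(response, "error")     # MagicMock auto-creates attributes on first access
del response.error                    # deletion makes later access raise AttributeError
assert not hasattr(response, "error")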
View File

@@ -0,0 +1,110 @@
import pytest
from openai.error import APIError, RateLimitError
from autogpt.llm.providers import openai
@pytest.fixture(params=[RateLimitError, APIError])
def error(request):
if request.param == APIError:
return request.param("Error", http_status=502)
else:
return request.param("Error")
def error_factory(error_instance, error_count, retry_count, warn_user=True):
"""Creates errors"""
class RaisesError:
def __init__(self):
self.count = 0
@openai.retry_api(
num_retries=retry_count, backoff_base=0.001, warn_user=warn_user
)
def __call__(self):
self.count += 1
if self.count <= error_count:
raise error_instance
return self.count
return RaisesError()
def test_retry_open_api_no_error(capsys):
"""Tests the retry functionality with no errors expected"""
@openai.retry_api()
def f():
return 1
result = f()
assert result == 1
output = capsys.readouterr()
assert output.out == ""
assert output.err == ""
@pytest.mark.parametrize(
"error_count, retry_count, failure",
[(2, 10, False), (2, 2, False), (10, 2, True), (3, 2, True), (1, 0, True)],
ids=["passing", "passing_edge", "failing", "failing_edge", "failing_no_retries"],
)
def test_retry_open_api_passing(capsys, error, error_count, retry_count, failure):
"""Tests the retry with simulated errors [RateLimitError, APIError], but should ulimately pass"""
call_count = min(error_count, retry_count) + 1
raises = error_factory(error, error_count, retry_count)
if failure:
with pytest.raises(type(error)):
raises()
else:
result = raises()
assert result == call_count
assert raises.count == call_count
output = capsys.readouterr()
if error_count and retry_count:
if type(error) == RateLimitError:
assert "Reached rate limit, passing..." in output.out
assert "Please double check" in output.out
if type(error) == APIError:
assert "API Bad gateway" in output.out
else:
assert output.out == ""
def test_retry_open_api_rate_limit_no_warn(capsys):
"""Tests the retry logic with a rate limit error"""
error_count = 2
retry_count = 10
raises = error_factory(RateLimitError, error_count, retry_count, warn_user=False)
result = raises()
call_count = min(error_count, retry_count) + 1
assert result == call_count
assert raises.count == call_count
output = capsys.readouterr()
assert "Reached rate limit, passing..." in output.out
assert "Please double check" not in output.out
def test_retry_openapi_other_api_error(capsys):
"""Tests the Retry logic with a non rate limit error such as HTTP500"""
error_count = 2
retry_count = 10
raises = error_factory(APIError("Error", http_status=500), error_count, retry_count)
with pytest.raises(APIError):
raises()
call_count = 1
assert raises.count == call_count
output = capsys.readouterr()
assert output.out == ""
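How the expected call counts in test_retry_open_api_passing work out, shown as a small arithmetic sketch mirroring call_count = min(error_count, retry_count) + 1:

cases = [(2, 10, False), (2, 2, False), (10, 2, True), (3, 2, True), (1, 0, True)]
for error_count, retry_count, failure in cases:
    # One initial attempt plus at most retry_count retries; the stub stops
    # raising after error_count calls, so the wrapped callable runs this often:
    call_count = min(error_count, retry_count) + 1
    outcome = "raises" if failure else f"returns {call_count}"
    print(f"errors={error_count}, retries={retry_count}: {call_count} calls, {outcome}")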