Extract OpenAI API retry handler and unify ADA embeddings calls. (#3191)

* Extract retry logic, unify embedding functions

* Add some docstrings

* Remove embedding creation from API manager

* Add test suite for retry handler

* Make api manager fixture

* Fix typing

* Streamline tests
This commit is contained in:
James Collins
2023-04-25 11:12:24 -07:00
committed by GitHub
parent 940b115f0a
commit 2619740daa
9 changed files with 242 additions and 93 deletions

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
import functools
import time
from typing import List, Optional
@@ -13,10 +14,62 @@ from autogpt.logs import logger
from autogpt.types.openai import Message
CFG = Config()
openai.api_key = CFG.openai_api_key
def retry_openai_api(
num_retries: int = 10,
backoff_base: float = 2.0,
warn_user: bool = True,
):
"""Retry an OpenAI API call.
Args:
num_retries int: Number of retries. Defaults to 10.
backoff_base float: Base for exponential backoff. Defaults to 2.
warn_user bool: Whether to warn the user. Defaults to True.
"""
retry_limit_msg = f"{Fore.RED}Error: " f"Reached rate limit, passing...{Fore.RESET}"
api_key_error_msg = (
f"Please double check that you have setup a "
f"{Fore.CYAN + Style.BRIGHT}PAID{Style.RESET_ALL} OpenAI API Account. You can "
f"read more here: {Fore.CYAN}https://github.com/Significant-Gravitas/Auto-GPT#openai-api-keys-configuration{Fore.RESET}"
)
backoff_msg = (
f"{Fore.RED}Error: API Bad gateway. Waiting {{backoff}} seconds...{Fore.RESET}"
)
def _wrapper(func):
@functools.wraps(func)
def _wrapped(*args, **kwargs):
user_warned = not warn_user
num_attempts = num_retries + 1 # +1 for the first attempt
for attempt in range(1, num_attempts + 1):
try:
return func(*args, **kwargs)
except RateLimitError:
if attempt == num_attempts:
raise
logger.debug(retry_limit_msg)
if not user_warned:
logger.double_check(api_key_error_msg)
user_warned = True
except APIError as e:
if (e.http_status != 502) or (attempt == num_attempts):
raise
backoff = backoff_base ** (attempt + 2)
logger.debug(backoff_msg.format(backoff=backoff))
time.sleep(backoff)
return _wrapped
return _wrapper
def call_ai_function(
function: str, args: list, description: str, model: str | None = None
) -> str:
@@ -154,32 +207,46 @@ def create_chat_completion(
return resp
def get_ada_embedding(text):
def get_ada_embedding(text: str) -> List[int]:
"""Get an embedding from the ada model.
Args:
text (str): The text to embed.
Returns:
List[int]: The embedding.
"""
model = "text-embedding-ada-002"
text = text.replace("\n", " ")
return api_manager.embedding_create(
text_list=[text], model="text-embedding-ada-002"
if CFG.use_azure:
kwargs = {"engine": CFG.get_azure_deployment_id_for_model(model)}
else:
kwargs = {"model": model}
embedding = create_embedding(text, **kwargs)
api_manager.update_cost(
prompt_tokens=embedding.usage.prompt_tokens,
completion_tokens=0,
model=model,
)
return embedding["data"][0]["embedding"]
def create_embedding_with_ada(text) -> list:
"""Create an embedding with text-ada-002 using the OpenAI SDK"""
num_retries = 10
for attempt in range(num_retries):
backoff = 2 ** (attempt + 2)
try:
return api_manager.embedding_create(
text_list=[text], model="text-embedding-ada-002"
)
except RateLimitError:
pass
except (APIError, Timeout) as e:
if e.http_status != 502:
raise
if attempt == num_retries - 1:
raise
if CFG.debug_mode:
print(
f"{Fore.RED}Error: ",
f"API Bad gateway. Waiting {backoff} seconds...{Fore.RESET}",
)
time.sleep(backoff)
@retry_openai_api()
def create_embedding(
text: str,
*_,
**kwargs,
) -> openai.Embedding:
"""Create an embedding using the OpenAI API
Args:
text (str): The text to embed.
kwargs: Other arguments to pass to the OpenAI API embedding creation call.
Returns:
openai.Embedding: The embedding object.
"""
return openai.Embedding.create(input=[text], **kwargs)