Extract OpenAI API retry handler and unify ADA embeddings calls. (#3191)
* Extract retry logic, unify embedding functions
* Add some docstrings
* Remove embedding creation from API manager
* Add test suite for retry handler
* Make api manager fixture
* Fix typing
* Streamline tests
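At a glance, the change collapses the two embedding helpers into a single entry point. A minimal before/after sketch of a memory-backend call site, using only names that appear in this diff:

    # before: backends imported the ad-hoc retrying helper
    from autogpt.llm_utils import create_embedding_with_ada
    vector = create_embedding_with_ada(data)

    # after: one shared path; retries now come from the extracted decorator
    from autogpt.llm_utils import get_ada_embedding
    vector = get_ada_embedding(data)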
@@ -65,32 +65,6 @@ class ApiManager:
             self.update_cost(prompt_tokens, completion_tokens, model)
         return response
 
-    def embedding_create(
-        self,
-        text_list: List[str],
-        model: str = "text-embedding-ada-002",
-    ) -> List[float]:
-        """
-        Create an embedding for the given input text using the specified model.
-
-        Args:
-            text_list (List[str]): Input text for which the embedding is to be created.
-            model (str, optional): The model to use for generating the embedding.
-
-        Returns:
-            List[float]: The generated embedding as a list of float values.
-        """
-        if cfg.use_azure:
-            response = openai.Embedding.create(
-                input=text_list,
-                engine=cfg.get_azure_deployment_id_for_model(model),
-            )
-        else:
-            response = openai.Embedding.create(input=text_list, model=model)
-
-        self.update_cost(response.usage.prompt_tokens, 0, model)
-        return response["data"][0]["embedding"]
-
     def update_cost(self, prompt_tokens, completion_tokens, model):
         """
         Update the total cost, prompt tokens, and completion tokens.
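With embedding_create removed, ApiManager keeps only the accounting. A sketch of the resulting division of labor (mirroring the new get_ada_embedding later in this diff): the caller creates the embedding and merely reports usage.

    # embedding creation happens in autogpt.llm_utils, retried by the decorator
    embedding = create_embedding(text, model="text-embedding-ada-002")
    # ApiManager is only told about token usage for cost tracking
    api_manager.update_cost(
        prompt_tokens=embedding.usage.prompt_tokens,
        completion_tokens=0,
        model="text-embedding-ada-002",
    )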
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import functools
 import time
 from typing import List, Optional
 
@@ -13,10 +14,62 @@ from autogpt.logs import logger
 from autogpt.types.openai import Message
 
 CFG = Config()
 
 openai.api_key = CFG.openai_api_key
 
+
+def retry_openai_api(
+    num_retries: int = 10,
+    backoff_base: float = 2.0,
+    warn_user: bool = True,
+):
+    """Retry an OpenAI API call.
+
+    Args:
+        num_retries int: Number of retries. Defaults to 10.
+        backoff_base float: Base for exponential backoff. Defaults to 2.
+        warn_user bool: Whether to warn the user. Defaults to True.
+    """
+    retry_limit_msg = f"{Fore.RED}Error: " f"Reached rate limit, passing...{Fore.RESET}"
+    api_key_error_msg = (
+        f"Please double check that you have setup a "
+        f"{Fore.CYAN + Style.BRIGHT}PAID{Style.RESET_ALL} OpenAI API Account. You can "
+        f"read more here: {Fore.CYAN}https://github.com/Significant-Gravitas/Auto-GPT#openai-api-keys-configuration{Fore.RESET}"
+    )
+    backoff_msg = (
+        f"{Fore.RED}Error: API Bad gateway. Waiting {{backoff}} seconds...{Fore.RESET}"
+    )
+
+    def _wrapper(func):
+        @functools.wraps(func)
+        def _wrapped(*args, **kwargs):
+            user_warned = not warn_user
+            num_attempts = num_retries + 1  # +1 for the first attempt
+            for attempt in range(1, num_attempts + 1):
+                try:
+                    return func(*args, **kwargs)
+
+                except RateLimitError:
+                    if attempt == num_attempts:
+                        raise
+
+                    logger.debug(retry_limit_msg)
+                    if not user_warned:
+                        logger.double_check(api_key_error_msg)
+                        user_warned = True
+
+                except APIError as e:
+                    if (e.http_status != 502) or (attempt == num_attempts):
+                        raise
+
+                backoff = backoff_base ** (attempt + 2)
+                logger.debug(backoff_msg.format(backoff=backoff))
+                time.sleep(backoff)
+
+        return _wrapped
+
+    return _wrapper
+
+
 def call_ai_function(
     function: str, args: list, description: str, model: str | None = None
 ) -> str:
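Since retry_openai_api is parameterized, it is applied with a call. A minimal usage sketch (query_openai is a hypothetical function, not part of this diff; the retry rules are exactly the ones implemented above):

    @retry_openai_api(num_retries=5, backoff_base=2.0, warn_user=False)
    def query_openai(messages):
        # retried on RateLimitError, and on APIError only when http_status == 502;
        # any other error, or exhausting all attempts, re-raises to the caller
        return openai.ChatCompletion.create(model="gpt-3.5-turbo", messages=messages)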
@@ -154,32 +207,46 @@ def create_chat_completion(
     return resp
 
 
-def get_ada_embedding(text):
+def get_ada_embedding(text: str) -> List[int]:
+    """Get an embedding from the ada model.
+
+    Args:
+        text (str): The text to embed.
+
+    Returns:
+        List[int]: The embedding.
+    """
+    model = "text-embedding-ada-002"
     text = text.replace("\n", " ")
-    return api_manager.embedding_create(
-        text_list=[text], model="text-embedding-ada-002"
+    if CFG.use_azure:
+        kwargs = {"engine": CFG.get_azure_deployment_id_for_model(model)}
+    else:
+        kwargs = {"model": model}
+
+    embedding = create_embedding(text, **kwargs)
+    api_manager.update_cost(
+        prompt_tokens=embedding.usage.prompt_tokens,
+        completion_tokens=0,
+        model=model,
     )
+    return embedding["data"][0]["embedding"]
 
 
-def create_embedding_with_ada(text) -> list:
-    """Create an embedding with text-ada-002 using the OpenAI SDK"""
-    num_retries = 10
-    for attempt in range(num_retries):
-        backoff = 2 ** (attempt + 2)
-        try:
-            return api_manager.embedding_create(
-                text_list=[text], model="text-embedding-ada-002"
-            )
-        except RateLimitError:
-            pass
-        except (APIError, Timeout) as e:
-            if e.http_status != 502:
-                raise
-            if attempt == num_retries - 1:
-                raise
-        if CFG.debug_mode:
-            print(
-                f"{Fore.RED}Error: ",
-                f"API Bad gateway. Waiting {backoff} seconds...{Fore.RESET}",
-            )
-        time.sleep(backoff)
+@retry_openai_api()
+def create_embedding(
+    text: str,
+    *_,
+    **kwargs,
+) -> openai.Embedding:
+    """Create an embedding using the OpenAI API
+
+    Args:
+        text (str): The text to embed.
+        kwargs: Other arguments to pass to the OpenAI API embedding creation call.
+
+    Returns:
+        openai.Embedding: The embedding object.
+    """
+    return openai.Embedding.create(input=[text], **kwargs)
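Call-site view of the new split, as a sketch grounded in the two functions above: get_ada_embedding is the high-level path the memory backends use, create_embedding the thin retried SDK wrapper.

    # high-level: Azure/OpenAI selection and cost tracking handled internally
    vector = get_ada_embedding("The quick brown fox")  # EMBED_DIM (1536) floats

    # low-level: kwargs pass straight through to openai.Embedding.create
    response = create_embedding("The quick brown fox", model="text-embedding-ada-002")
    vector = response["data"][0]["embedding"]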
@@ -7,7 +7,7 @@ from typing import Any, List
 import numpy as np
 import orjson
 
-from autogpt.llm_utils import create_embedding_with_ada
+from autogpt.llm_utils import get_ada_embedding
 from autogpt.memory.base import MemoryProviderSingleton
 
 EMBED_DIM = 1536
@@ -63,7 +63,7 @@ class LocalCache(MemoryProviderSingleton):
             return ""
         self.data.texts.append(text)
 
-        embedding = create_embedding_with_ada(text)
+        embedding = get_ada_embedding(text)
 
         vector = np.array(embedding).astype(np.float32)
         vector = vector[np.newaxis, :]
@@ -111,7 +111,7 @@ class LocalCache(MemoryProviderSingleton):
 
         Returns: List[str]
         """
-        embedding = create_embedding_with_ada(text)
+        embedding = get_ada_embedding(text)
 
         scores = np.dot(self.data.embeddings, embedding)
 
@@ -1,7 +1,7 @@
 import pinecone
 from colorama import Fore, Style
 
-from autogpt.llm_utils import create_embedding_with_ada
+from autogpt.llm_utils import get_ada_embedding
 from autogpt.logs import logger
 from autogpt.memory.base import MemoryProviderSingleton
 
@@ -44,7 +44,7 @@ class PineconeMemory(MemoryProviderSingleton):
         self.index = pinecone.Index(table_name)
 
     def add(self, data):
-        vector = create_embedding_with_ada(data)
+        vector = get_ada_embedding(data)
         # no metadata here. We may wish to change that long term.
         self.index.upsert([(str(self.vec_num), vector, {"raw_text": data})])
         _text = f"Inserting data into memory at index: {self.vec_num}:\n data: {data}"
@@ -64,7 +64,7 @@ class PineconeMemory(MemoryProviderSingleton):
         :param data: The data to compare to.
         :param num_relevant: The number of relevant data to return. Defaults to 5
         """
-        query_embedding = create_embedding_with_ada(data)
+        query_embedding = get_ada_embedding(data)
         results = self.index.query(
             query_embedding, top_k=num_relevant, include_metadata=True
         )
@@ -10,7 +10,7 @@ from redis.commands.search.field import TextField, VectorField
 from redis.commands.search.indexDefinition import IndexDefinition, IndexType
 from redis.commands.search.query import Query
 
-from autogpt.llm_utils import create_embedding_with_ada
+from autogpt.llm_utils import get_ada_embedding
 from autogpt.logs import logger
 from autogpt.memory.base import MemoryProviderSingleton
 
@@ -88,7 +88,7 @@ class RedisMemory(MemoryProviderSingleton):
         """
         if "Command Error:" in data:
             return ""
-        vector = create_embedding_with_ada(data)
+        vector = get_ada_embedding(data)
         vector = np.array(vector).astype(np.float32).tobytes()
         data_dict = {b"data": data, "embedding": vector}
         pipe = self.redis.pipeline()
@@ -130,7 +130,7 @@ class RedisMemory(MemoryProviderSingleton):
 
         Returns: A list of the most relevant data.
         """
-        query_embedding = create_embedding_with_ada(data)
+        query_embedding = get_ada_embedding(data)
         base_query = f"*=>[KNN {num_relevant} @embedding $vector AS vector_score]"
         query = (
             Query(base_query)
@@ -3,6 +3,8 @@ from pathlib import Path
 import pytest
 from dotenv import load_dotenv
 
+from autogpt.api_manager import ApiManager
+from autogpt.api_manager import api_manager as api_manager_
 from autogpt.config import Config
 from autogpt.workspace import Workspace
 
@@ -29,3 +31,11 @@ def config(workspace: Workspace) -> Config:
     config.workspace_path = workspace.root
     yield config
     config.workspace_path = old_ws_path
+
+
+@pytest.fixture()
+def api_manager() -> ApiManager:
+    old_attrs = api_manager_.__dict__.copy()
+    api_manager_.reset()
+    yield api_manager_
+    api_manager_.__dict__.update(old_attrs)
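The fixture snapshots the singleton's attribute dict, resets it for the test, and restores it on teardown, so tests may mutate the shared cost counters freely. A sketch of a consuming test (illustrative only; the calls mirror test_api_manager.py below):

    def test_tracks_prompt_tokens(api_manager):
        api_manager.update_cost(60, 120, "gpt-3.5-turbo")
        assert api_manager.get_total_prompt_tokens() == 60
        # teardown restores the pre-test ApiManager state automatically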
@@ -86,37 +86,6 @@ class TestApiManager:
         assert api_manager.get_total_completion_tokens() == 20
         assert api_manager.get_total_cost() == (10 * 0.002 + 20 * 0.002) / 1000
 
-    @staticmethod
-    def test_embedding_create_invalid_model():
-        """Test if an invalid model for embedding raises a KeyError."""
-        text_list = ["Hello, how are you?"]
-        model = "invalid-model"
-
-        with patch("openai.Embedding.create") as mock_create:
-            mock_response = MagicMock()
-            mock_response.usage.prompt_tokens = 5
-            mock_create.side_effect = KeyError("Invalid model")
-            with pytest.raises(KeyError):
-                api_manager.embedding_create(text_list, model=model)
-
-    @staticmethod
-    def test_embedding_create_valid_inputs():
-        """Test if valid inputs for embedding result in correct tokens and cost."""
-        text_list = ["Hello, how are you?"]
-        model = "text-embedding-ada-002"
-
-        with patch("openai.Embedding.create") as mock_create:
-            mock_response = MagicMock()
-            mock_response.usage.prompt_tokens = 5
-            mock_response["data"] = [{"embedding": [0.1, 0.2, 0.3]}]
-            mock_create.return_value = mock_response
-
-            api_manager.embedding_create(text_list, model=model)
-
-            assert api_manager.get_total_prompt_tokens() == 5
-            assert api_manager.get_total_completion_tokens() == 0
-            assert api_manager.get_total_cost() == (5 * 0.0004) / 1000
-
     def test_getter_methods(self):
         """Test the getter methods for total tokens, cost, and budget."""
         api_manager.update_cost(60, 120, "gpt-3.5-turbo")
tests/test_llm_utils.py (new file, 129 lines)
@@ -0,0 +1,129 @@
+import pytest
+from openai.error import APIError, RateLimitError
+
+from autogpt.llm_utils import get_ada_embedding, retry_openai_api
+from autogpt.modelsinfo import COSTS
+
+
+@pytest.fixture(params=[RateLimitError, APIError])
+def error(request):
+    if request.param == APIError:
+        return request.param("Error", http_status=502)
+    else:
+        return request.param("Error")
+
+
+@pytest.fixture
+def mock_create_embedding(mocker):
+    mock_response = mocker.MagicMock()
+    mock_response.usage.prompt_tokens = 5
+    mock_response.__getitem__.side_effect = lambda key: [{"embedding": [0.1, 0.2, 0.3]}]
+    return mocker.patch(
+        "autogpt.llm_utils.create_embedding", return_value=mock_response
+    )
+
+
+def error_factory(error_instance, error_count, retry_count, warn_user=True):
+    class RaisesError:
+        def __init__(self):
+            self.count = 0
+
+        @retry_openai_api(
+            num_retries=retry_count, backoff_base=0.001, warn_user=warn_user
+        )
+        def __call__(self):
+            self.count += 1
+            if self.count <= error_count:
+                raise error_instance
+            return self.count
+
+    return RaisesError()
+
+
+def test_retry_open_api_no_error(capsys):
+    @retry_openai_api()
+    def f():
+        return 1
+
+    result = f()
+    assert result == 1
+
+    output = capsys.readouterr()
+    assert output.out == ""
+    assert output.err == ""
+
+
+@pytest.mark.parametrize(
+    "error_count, retry_count, failure",
+    [(2, 10, False), (2, 2, False), (10, 2, True), (3, 2, True), (1, 0, True)],
+    ids=["passing", "passing_edge", "failing", "failing_edge", "failing_no_retries"],
+)
+def test_retry_open_api_passing(capsys, error, error_count, retry_count, failure):
+    call_count = min(error_count, retry_count) + 1
+
+    raises = error_factory(error, error_count, retry_count)
+    if failure:
+        with pytest.raises(type(error)):
+            raises()
+    else:
+        result = raises()
+        assert result == call_count
+
+    assert raises.count == call_count
+
+    output = capsys.readouterr()
+
+    if error_count and retry_count:
+        if type(error) == RateLimitError:
+            assert "Reached rate limit, passing..." in output.out
+            assert "Please double check" in output.out
+        if type(error) == APIError:
+            assert "API Bad gateway" in output.out
+    else:
+        assert output.out == ""
+
+
+def test_retry_open_api_rate_limit_no_warn(capsys):
+    error_count = 2
+    retry_count = 10
+
+    raises = error_factory(RateLimitError, error_count, retry_count, warn_user=False)
+    result = raises()
+    call_count = min(error_count, retry_count) + 1
+    assert result == call_count
+    assert raises.count == call_count
+
+    output = capsys.readouterr()
+
+    assert "Reached rate limit, passing..." in output.out
+    assert "Please double check" not in output.out
+
+
+def test_retry_openapi_other_api_error(capsys):
+    error_count = 2
+    retry_count = 10
+
+    raises = error_factory(APIError("Error", http_status=500), error_count, retry_count)
+
+    with pytest.raises(APIError):
+        raises()
+    call_count = 1
+    assert raises.count == call_count
+
+    output = capsys.readouterr()
+    assert output.out == ""
+
+
+def test_get_ada_embedding(mock_create_embedding, api_manager):
+    model = "text-embedding-ada-002"
+    embedding = get_ada_embedding("test")
+    mock_create_embedding.assert_called_once_with(
+        "test", model="text-embedding-ada-002"
+    )
+
+    assert embedding == [0.1, 0.2, 0.3]
+
+    cost = COSTS[model]["prompt"]
+    assert api_manager.get_total_prompt_tokens() == 5
+    assert api_manager.get_total_completion_tokens() == 0
+    assert api_manager.get_total_cost() == (5 * cost) / 1000
@@ -21,7 +21,7 @@ def LocalCache():
 @pytest.fixture
 def mock_embed_with_ada(mocker):
     mocker.patch(
-        "autogpt.memory.local.create_embedding_with_ada",
+        "autogpt.memory.local.get_ada_embedding",
         return_value=[0.1] * EMBED_DIM,
     )
 