Speed up CI (#4930)

* Match requests in cassette by hash

* Strip requests more extensively for VCR

* Sort JSON keys on cassette save

* Strip max_tokens from cassettes

* Improve logging in retry decorator

* Raise when quota exceeded

* Clean up @retry_api

* Fix @retry_api

* Remove dead tests/vcr/openai_filter.py

* Add debug logging to execute_python_file

* Make Docker CI pass
Reinier van der Leer
2023-07-10 17:26:13 +02:00
committed by GitHub
parent 2b56996a27
commit 5e39dd1d26
8 changed files with 110 additions and 114 deletions

View File

@@ -148,8 +148,8 @@ jobs:
       - name: Run pytest with coverage
         run: |
-          pytest -v --cov=autogpt --cov-branch --cov-report term-missing --cov-report xml \
-            --numprocesses=4 --durations=10 \
+          pytest -vv --cov=autogpt --cov-branch --cov-report term-missing --cov-report xml \
+            --numprocesses=logical --durations=10 \
             tests/unit tests/integration tests/challenges
           python tests/challenges/utils/build_current_score.py
         env:
@@ -247,7 +247,7 @@ jobs:
            gh api repos/$REPO/issues/$PR_NUMBER/comments -X POST -F body="You changed AutoGPT's behaviour. The cassettes have been updated and will be merged to the submodule when this Pull Request gets merged."
          fi
-      - name: Upload logs as artifact
+      - name: Upload logs to artifact
        if: always()
        uses: actions/upload-artifact@v3
        with:
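
Note on the pytest change above: `--numprocesses=4` pinned the pytest-xdist worker pool to four processes regardless of the runner, while `--numprocesses=logical` sizes it to the machine. A rough sketch of what the modes resolve to, assuming pytest-xdist's documented behavior of consulting psutil when it is installed and falling back to `os.cpu_count()` otherwise:

```python
import os

import psutil  # optional dependency consulted by pytest-xdist for "auto"/"logical"

workers_fixed = 4                                 # --numprocesses=4
workers_logical = psutil.cpu_count(logical=True)  # --numprocesses=logical: one worker per logical core
workers_auto = psutil.cpu_count(logical=False)    # --numprocesses=auto: one worker per physical core
fallback = os.cpu_count()                         # what xdist falls back to without psutil

print(workers_fixed, workers_logical, workers_auto, fallback)
```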

View File

@@ -103,6 +103,9 @@ def execute_python_file(filename: str, agent: Agent) -> str:
         )

     if we_are_running_in_a_docker_container():
+        logger.debug(
+            f"Auto-GPT is running in a Docker container; executing {file_path} directly..."
+        )
         result = subprocess.run(
             ["python", str(file_path)],
             capture_output=True,
@@ -114,6 +117,7 @@ def execute_python_file(filename: str, agent: Agent) -> str:
         else:
             return f"Error: {result.stderr}"

+    logger.debug("Auto-GPT is not running in a Docker container")
     try:
         client = docker.from_env()
         # You can replace this with the desired Python image/version
@@ -122,10 +126,10 @@ def execute_python_file(filename: str, agent: Agent) -> str:
         image_name = "python:3-alpine"
         try:
             client.images.get(image_name)
-            logger.warn(f"Image '{image_name}' found locally")
+            logger.debug(f"Image '{image_name}' found locally")
         except ImageNotFound:
             logger.info(
-                f"Image '{image_name}' not found locally, pulling from Docker Hub"
+                f"Image '{image_name}' not found locally, pulling from Docker Hub..."
             )
             # Use the low-level API to stream the pull response
             low_level_client = docker.APIClient()
@@ -138,6 +142,7 @@ def execute_python_file(filename: str, agent: Agent) -> str:
             elif status:
                 logger.info(status)

+        logger.debug(f"Running {file_path} in a {image_name} container...")
         container: DockerContainer = client.containers.run(
             image_name,
             ["python", str(file_path.relative_to(agent.workspace.root))],

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
 import functools
 import time
 from dataclasses import dataclass
-from typing import List, Optional
+from typing import Callable, List, Optional
 from unittest.mock import patch

 import openai
@@ -112,7 +112,7 @@ OPEN_AI_MODELS: dict[str, ChatModelInfo | EmbeddingModelInfo | TextModelInfo] =
 }


-def meter_api(func):
+def meter_api(func: Callable):
     """Adds ApiManager metering to functions which make OpenAI API calls"""
     from autogpt.llm.api_manager import ApiManager
@@ -150,7 +150,7 @@ def meter_api(func):
 def retry_api(
-    num_retries: int = 10,
+    max_retries: int = 10,
     backoff_base: float = 2.0,
     warn_user: bool = True,
 ):
@@ -162,43 +162,49 @@ def retry_api(
         warn_user bool: Whether to warn the user. Defaults to True.
     """
     error_messages = {
-        ServiceUnavailableError: f"{Fore.RED}Error: The OpenAI API engine is currently overloaded, passing...{Fore.RESET}",
-        RateLimitError: f"{Fore.RED}Error: Reached rate limit, passing...{Fore.RESET}",
+        ServiceUnavailableError: f"{Fore.RED}Error: The OpenAI API engine is currently overloaded{Fore.RESET}",
+        RateLimitError: f"{Fore.RED}Error: Reached rate limit{Fore.RESET}",
     }
     api_key_error_msg = (
         f"Please double check that you have setup a "
         f"{Fore.CYAN + Style.BRIGHT}PAID{Style.RESET_ALL} OpenAI API Account. You can "
         f"read more here: {Fore.CYAN}https://docs.agpt.co/setup/#getting-an-api-key{Fore.RESET}"
     )
-    backoff_msg = (
-        f"{Fore.RED}Error: API Bad gateway. Waiting {{backoff}} seconds...{Fore.RESET}"
-    )
+    backoff_msg = f"{Fore.RED}Waiting {{backoff}} seconds...{Fore.RESET}"

-    def _wrapper(func):
+    def _wrapper(func: Callable):
         @functools.wraps(func)
         def _wrapped(*args, **kwargs):
             user_warned = not warn_user
-            num_attempts = num_retries + 1  # +1 for the first attempt
-            for attempt in range(1, num_attempts + 1):
+            max_attempts = max_retries + 1  # +1 for the first attempt
+            for attempt in range(1, max_attempts + 1):
                 try:
                     return func(*args, **kwargs)

                 except (RateLimitError, ServiceUnavailableError) as e:
-                    if attempt == num_attempts:
+                    if attempt >= max_attempts or (
+                        # User's API quota exceeded
+                        isinstance(e, RateLimitError)
+                        and (err := getattr(e, "error", {}))
+                        and err.get("code") == "insufficient_quota"
+                    ):
                         raise

                     error_msg = error_messages[type(e)]
-                    logger.debug(error_msg)
+                    logger.warn(error_msg)
                     if not user_warned:
                         logger.double_check(api_key_error_msg)
+                        logger.debug(f"Status: {e.http_status}")
+                        logger.debug(f"Response body: {e.json_body}")
+                        logger.debug(f"Response headers: {e.headers}")
                         user_warned = True

                 except (APIError, Timeout) as e:
-                    if (e.http_status not in [429, 502]) or (attempt == num_attempts):
+                    if (e.http_status not in [429, 502]) or (attempt == max_attempts):
                         raise

                 backoff = backoff_base ** (attempt + 2)
-                logger.debug(backoff_msg.format(backoff=backoff))
+                logger.warn(backoff_msg.format(backoff=backoff))
                 time.sleep(backoff)

         return _wrapped
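
Two behaviors in this hunk are worth spelling out: a `RateLimitError` carrying the `insufficient_quota` code is now re-raised immediately (retrying cannot fix an exhausted quota), and the exponential backoff starts at `backoff_base ** 3`, not `backoff_base ** 1`. A quick illustration of the schedule with the default `backoff_base=2.0`:

```python
backoff_base = 2.0
for attempt in range(1, 4):
    # backoff = backoff_base ** (attempt + 2), as in _wrapped above
    print(f"attempt {attempt}: wait {backoff_base ** (attempt + 2):.0f}s")
# attempt 1: wait 8s
# attempt 2: wait 16s
# attempt 3: wait 32s
```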

View File

@@ -20,7 +20,7 @@ def error_factory(error_instance, error_count, retry_count, warn_user=True):
             self.count = 0

         @openai.retry_api(
-            num_retries=retry_count, backoff_base=0.001, warn_user=warn_user
+            max_retries=retry_count, backoff_base=0.001, warn_user=warn_user
         )
         def __call__(self):
             self.count += 1
@@ -69,16 +69,11 @@ def test_retry_open_api_passing(capsys, error, error_count, retry_count, failure
     if error_count and retry_count:
         if type(error) == RateLimitError:
-            assert "Reached rate limit, passing..." in output.out
+            assert "Reached rate limit" in output.out
             assert "Please double check" in output.out
         if type(error) == ServiceUnavailableError:
-            assert (
-                "The OpenAI API engine is currently overloaded, passing..."
-                in output.out
-            )
+            assert "The OpenAI API engine is currently overloaded" in output.out
             assert "Please double check" in output.out
-        if type(error) == APIError:
-            assert "API Bad gateway" in output.out
     else:
         assert output.out == ""
@@ -96,7 +91,7 @@ def test_retry_open_api_rate_limit_no_warn(capsys):
     output = capsys.readouterr()
-    assert "Reached rate limit, passing..." in output.out
+    assert "Reached rate limit" in output.out
     assert "Please double check" not in output.out
@@ -115,7 +110,7 @@ def test_retry_open_api_service_unavairable_no_warn(capsys):
     output = capsys.readouterr()
-    assert "The OpenAI API engine is currently overloaded, passing..." in output.out
+    assert "The OpenAI API engine is currently overloaded" in output.out
     assert "Please double check" not in output.out

View File

@@ -1,10 +1,16 @@
 import os
+from hashlib import sha256

 import openai.api_requestor
 import pytest
 from pytest_mock import MockerFixture

-from .vcr_filter import PROXY, before_record_request, before_record_response
+from .vcr_filter import (
+    PROXY,
+    before_record_request,
+    before_record_response,
+    freeze_request_body,
+)

 DEFAULT_RECORD_MODE = "new_episodes"
 BASE_VCR_CONFIG = {
@@ -12,10 +18,13 @@ BASE_VCR_CONFIG = {
     "before_record_response": before_record_response,
     "filter_headers": [
         "Authorization",
+        "AGENT-MODE",
+        "AGENT-TYPE",
         "OpenAI-Organization",
         "X-OpenAI-Client-User-Agent",
+        "User-Agent",
     ],
-    "match_on": ["method", "body"],
+    "match_on": ["method", "headers"],
 }
@@ -41,7 +50,7 @@ def vcr_cassette_dir(request):
     return os.path.join("tests/Auto-GPT-test-cassettes", test_name)


-def patch_api_base(requestor):
+def patch_api_base(requestor: openai.api_requestor.APIRequestor):
     new_api_base = f"{PROXY}/v1"
     requestor.api_base = new_api_base
     return requestor
@@ -49,23 +58,35 @@ def patch_api_base(requestor):
 @pytest.fixture
 def patched_api_requestor(mocker: MockerFixture):
-    original_init = openai.api_requestor.APIRequestor.__init__
-    original_validate_headers = openai.api_requestor.APIRequestor._validate_headers
+    init_requestor = openai.api_requestor.APIRequestor.__init__
+    prepare_request = openai.api_requestor.APIRequestor._prepare_request_raw

-    def patched_init(requestor, *args, **kwargs):
-        original_init(requestor, *args, **kwargs)
+    def patched_init_requestor(requestor, *args, **kwargs):
+        init_requestor(requestor, *args, **kwargs)
         patch_api_base(requestor)

-    def patched_validate_headers(self, supplied_headers):
-        headers = original_validate_headers(self, supplied_headers)
-        headers["AGENT-MODE"] = os.environ.get("AGENT_MODE")
-        headers["AGENT-TYPE"] = os.environ.get("AGENT_TYPE")
-        return headers
+    def patched_prepare_request(self, *args, **kwargs):
+        url, headers, data = prepare_request(self, *args, **kwargs)
+
+        if PROXY:
+            headers["AGENT-MODE"] = os.environ.get("AGENT_MODE")
+            headers["AGENT-TYPE"] = os.environ.get("AGENT_TYPE")
+
+        # Add hash header for cheap & fast matching on cassette playback
+        headers["X-Content-Hash"] = sha256(
+            freeze_request_body(data), usedforsecurity=False
+        ).hexdigest()
+
+        return url, headers, data

     if PROXY:
-        mocker.patch("openai.api_requestor.APIRequestor.__init__", new=patched_init)
         mocker.patch.object(
             openai.api_requestor.APIRequestor,
-            "_validate_headers",
-            new=patched_validate_headers,
+            "__init__",
+            new=patched_init_requestor,
         )
+    mocker.patch.object(
+        openai.api_requestor.APIRequestor,
+        "_prepare_request_raw",
+        new=patched_prepare_request,
+    )
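
This is the other half of the "match requests by hash" change: with the volatile headers filtered out and `X-Content-Hash` attached to every outgoing request, `"match_on": ["method", "headers"]` is equivalent to matching on canonicalized bodies, but costs only a string comparison at playback time. A sketch of the invariant, reusing the digest computation from the fixture above:

```python
from hashlib import sha256

def content_hash(frozen_body: bytes) -> str:
    # Same digest the patched _prepare_request_raw attaches as X-Content-Hash
    return sha256(frozen_body, usedforsecurity=False).hexdigest()

# Identical frozen bodies -> identical headers -> cassette hit;
# any body difference changes the hash and falls through to recording.
assert content_hash(b'{"messages": []}') == content_hash(b'{"messages": []}')
assert content_hash(b'{"messages": []}') != content_hash(b'{"messages": [1]}')
```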

View File

@@ -1,52 +0,0 @@
-import json
-import re
-
-
-def replace_timestamp_in_request(request):
-    # Check if the request body contains a JSON object
-    try:
-        if not request or not request.body:
-            return request
-        body = json.loads(request.body)
-    except ValueError:
-        return request
-
-    if "messages" not in body:
-        return request
-    for message in body["messages"]:
-        if "content" in message and "role" in message and message["role"] == "system":
-            timestamp_regex = re.compile(r"\w{3} \w{3} \d{2} \d{2}:\d{2}:\d{2} \d{4}")
-            message["content"] = timestamp_regex.sub(
-                "Tue Jan 01 00:00:00 2000", message["content"]
-            )
-
-    request.body = json.dumps(body)
-    return request
-
-
-def before_record_response(response):
-    if "Transfer-Encoding" in response["headers"]:
-        del response["headers"]["Transfer-Encoding"]
-    return response
-
-
-def before_record_request(request):
-    filtered_request = filter_hostnames(request)
-    filtered_request_without_dynamic_data = replace_timestamp_in_request(
-        filtered_request
-    )
-    return filtered_request_without_dynamic_data
-
-
-def filter_hostnames(request):
-    allowed_hostnames = [
-        "api.openai.com",
-        "localhost:50337",
-    ]  # List of hostnames you want to allow
-
-    if any(hostname in request.url for hostname in allowed_hostnames):
-        return request
-    else:
-        return None
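
The only real logic this dead module contained, freezing ctime-style timestamps so recorded request bodies stay byte-stable, lives on in `vcr_filter.py`'s REPLACEMENTS. For reference, the trick it used:

```python
import re

timestamp_regex = re.compile(r"\w{3} \w{3} \d{2} \d{2}:\d{2}:\d{2} \d{4}")
print(timestamp_regex.sub("Tue Jan 01 00:00:00 2000", "It is Mon Jul 10 17:26:13 2023."))
# -> "It is Tue Jan 01 00:00:00 2000."
```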

View File

@@ -1,8 +1,12 @@
+import contextlib
 import json
 import os
 import re
+from io import BytesIO
 from typing import Any, Dict, List

+from vcr.request import Request
+
 PROXY = os.environ.get("PROXY")

 REPLACEMENTS: List[Dict[str, str]] = [
@@ -39,19 +43,20 @@ def replace_message_content(content: str, replacements: List[Dict[str, str]]) ->
     return content

-def replace_timestamp_in_request(request: Any) -> Any:
+def freeze_request_body(json_body: str | bytes) -> bytes:
+    """Remove any dynamic items from the request body"""
+
     try:
-        if not request or not request.body:
-            return request
-        body = json.loads(request.body)
+        body = json.loads(json_body)
     except ValueError:
-        return request
+        return json_body if type(json_body) == bytes else json_body.encode()

     if "messages" not in body:
-        return request
-    body[
-        "max_tokens"
-    ] = 0  # this field is inconsistent between requests and not used at the moment.
+        return json.dumps(body, sort_keys=True).encode()
+
+    if "max_tokens" in body:
+        del body["max_tokens"]
+
     for message in body["messages"]:
         if "content" in message and "role" in message:
             if message["role"] == "system":
@@ -59,7 +64,20 @@ def replace_timestamp_in_request(request: Any) -> Any:
                     message["content"], REPLACEMENTS
                 )

-    request.body = json.dumps(body)
+    return json.dumps(body, sort_keys=True).encode()
+
+
+def freeze_request(request: Request) -> Request:
+    if not request or not request.body:
+        return request
+
+    with contextlib.suppress(ValueError):
+        request.body = freeze_request_body(
+            request.body.getvalue()
+            if isinstance(request.body, BytesIO)
+            else request.body
+        )
+
     return request
@@ -69,20 +87,23 @@ def before_record_response(response: Dict[str, Any]) -> Dict[str, Any]:
     return response

-def before_record_request(request: Any) -> Any:
+def before_record_request(request: Request) -> Request | None:
     request = replace_request_hostname(request, ORIGINAL_URL, NEW_URL)

     filtered_request = filter_hostnames(request)
-    filtered_request_without_dynamic_data = replace_timestamp_in_request(
-        filtered_request
-    )
+    if not filtered_request:
+        return None
+
+    filtered_request_without_dynamic_data = freeze_request(filtered_request)
     return filtered_request_without_dynamic_data

 from urllib.parse import urlparse, urlunparse

-def replace_request_hostname(request: Any, original_url: str, new_hostname: str) -> Any:
+def replace_request_hostname(
+    request: Request, original_url: str, new_hostname: str
+) -> Request:
     parsed_url = urlparse(request.uri)
     if parsed_url.hostname in original_url:
@@ -94,7 +115,7 @@ def replace_request_hostname(request: Any, original_url: str, new_hostname: str)
     return request

-def filter_hostnames(request: Any) -> Any:
+def filter_hostnames(request: Request) -> Request | None:
     # Add your implementation here for filtering hostnames
     if any(hostname in request.url for hostname in ALLOWED_HOSTNAMES):
         return request
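
Why `sort_keys=True` matters here: `freeze_request_body` feeds both the saved cassette and the `X-Content-Hash` digest, so serialization must be deterministic regardless of the order in which the client happened to build the payload dict. A two-line demonstration:

```python
import json

a = json.dumps({"model": "gpt-4", "messages": []}, sort_keys=True)
b = json.dumps({"messages": [], "model": "gpt-4"}, sort_keys=True)
assert a == b  # identical bytes -> identical sha256 digest
```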