Mirror of https://github.com/aljazceru/Auto-GPT.git (synced 2025-12-24 09:24:27 +01:00)
Speed up CI (#4930)
* Match requests in cassette by hash
* Strip requests more extensively for VCR
* Sort JSON keys on cassette save
* Strip max_tokens from cassettes
* Improve logging in retry decorator
* Raise when quota exceeded
* Clean up @retry_api
* Fix @retry_api
* Remove dead tests/vcr/openai_filter.py
* Add debug logging to execute_python_file
* Make Docker CI pass
committed by GitHub
parent 2b56996a27
commit 5e39dd1d26
.github/workflows/ci.yml (vendored): 6 changes
@@ -148,8 +148,8 @@ jobs:

       - name: Run pytest with coverage
         run: |
-          pytest -v --cov=autogpt --cov-branch --cov-report term-missing --cov-report xml \
-            --numprocesses=4 --durations=10 \
+          pytest -vv --cov=autogpt --cov-branch --cov-report term-missing --cov-report xml \
+            --numprocesses=logical --durations=10 \
             tests/unit tests/integration tests/challenges
           python tests/challenges/utils/build_current_score.py
         env:
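Side note on the flags above: -vv raises pytest's verbosity for fuller assertion diffs, and pytest-xdist's --numprocesses=logical sizes the worker pool to the runner's logical CPU count instead of a hard-coded 4 (resolving "logical" requires psutil). A rough, illustrative sketch of how such a value could resolve; resolve_numprocesses is a hypothetical helper, not xdist's actual code:

# Illustrative only: how a "logical" worker count can be resolved.
import os

import psutil  # pytest-xdist resolves "logical" via psutil


def resolve_numprocesses(value: str) -> int:
    if value == "logical":
        # All hyperthreads, not just physical cores
        return psutil.cpu_count(logical=True) or os.cpu_count() or 1
    return int(value)


print(resolve_numprocesses("logical"))  # e.g. 4 on a 2-core / 4-thread runner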
@@ -247,7 +247,7 @@ jobs:
           gh api repos/$REPO/issues/$PR_NUMBER/comments -X POST -F body="You changed AutoGPT's behaviour. The cassettes have been updated and will be merged to the submodule when this Pull Request gets merged."
         fi

-      - name: Upload logs as artifact
+      - name: Upload logs to artifact
         if: always()
         uses: actions/upload-artifact@v3
         with:
@@ -103,6 +103,9 @@ def execute_python_file(filename: str, agent: Agent) -> str:
     )

     if we_are_running_in_a_docker_container():
+        logger.debug(
+            f"Auto-GPT is running in a Docker container; executing {file_path} directly..."
+        )
         result = subprocess.run(
             ["python", str(file_path)],
             capture_output=True,
@@ -114,6 +117,7 @@ def execute_python_file(filename: str, agent: Agent) -> str:
         else:
             return f"Error: {result.stderr}"

+    logger.debug("Auto-GPT is not running in a Docker container")
     try:
         client = docker.from_env()
         # You can replace this with the desired Python image/version
@@ -122,10 +126,10 @@ def execute_python_file(filename: str, agent: Agent) -> str:
         image_name = "python:3-alpine"
         try:
             client.images.get(image_name)
-            logger.warn(f"Image '{image_name}' found locally")
+            logger.debug(f"Image '{image_name}' found locally")
         except ImageNotFound:
             logger.info(
-                f"Image '{image_name}' not found locally, pulling from Docker Hub"
+                f"Image '{image_name}' not found locally, pulling from Docker Hub..."
             )
             # Use the low-level API to stream the pull response
             low_level_client = docker.APIClient()
@@ -138,6 +142,7 @@ def execute_python_file(filename: str, agent: Agent) -> str:
                 elif status:
                     logger.info(status)

+        logger.debug(f"Running {file_path} in a {image_name} container...")
         container: DockerContainer = client.containers.run(
             image_name,
             ["python", str(file_path.relative_to(agent.workspace.root))],
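we_are_running_in_a_docker_container() itself is not shown in this diff; the common implementation of such a check (a minimal sketch under that assumption, not necessarily Auto-GPT's exact code) tests for the /.dockerenv marker file that Docker creates in a container's root filesystem:

import os


def we_are_running_in_a_docker_container() -> bool:
    # Docker creates /.dockerenv at the container filesystem root;
    # its presence is a cheap, if not bulletproof, containment check.
    return os.path.exists("/.dockerenv")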
@@ -3,7 +3,7 @@ from __future__ import annotations

 import functools
 import time
 from dataclasses import dataclass
-from typing import List, Optional
+from typing import Callable, List, Optional
 from unittest.mock import patch

 import openai
@@ -112,7 +112,7 @@ OPEN_AI_MODELS: dict[str, ChatModelInfo | EmbeddingModelInfo | TextModelInfo] =
 }


-def meter_api(func):
+def meter_api(func: Callable):
     """Adds ApiManager metering to functions which make OpenAI API calls"""
     from autogpt.llm.api_manager import ApiManager

@@ -150,7 +150,7 @@ def meter_api(func):


 def retry_api(
-    num_retries: int = 10,
+    max_retries: int = 10,
     backoff_base: float = 2.0,
     warn_user: bool = True,
 ):
@@ -162,43 +162,49 @@ def retry_api(
         warn_user bool: Whether to warn the user. Defaults to True.
     """
     error_messages = {
-        ServiceUnavailableError: f"{Fore.RED}Error: The OpenAI API engine is currently overloaded, passing...{Fore.RESET}",
-        RateLimitError: f"{Fore.RED}Error: Reached rate limit, passing...{Fore.RESET}",
+        ServiceUnavailableError: f"{Fore.RED}Error: The OpenAI API engine is currently overloaded{Fore.RESET}",
+        RateLimitError: f"{Fore.RED}Error: Reached rate limit{Fore.RESET}",
     }
     api_key_error_msg = (
         f"Please double check that you have setup a "
         f"{Fore.CYAN + Style.BRIGHT}PAID{Style.RESET_ALL} OpenAI API Account. You can "
         f"read more here: {Fore.CYAN}https://docs.agpt.co/setup/#getting-an-api-key{Fore.RESET}"
     )
-    backoff_msg = (
-        f"{Fore.RED}Error: API Bad gateway. Waiting {{backoff}} seconds...{Fore.RESET}"
-    )
+    backoff_msg = f"{Fore.RED}Waiting {{backoff}} seconds...{Fore.RESET}"

-    def _wrapper(func):
+    def _wrapper(func: Callable):
         @functools.wraps(func)
         def _wrapped(*args, **kwargs):
             user_warned = not warn_user
-            num_attempts = num_retries + 1  # +1 for the first attempt
-            for attempt in range(1, num_attempts + 1):
+            max_attempts = max_retries + 1  # +1 for the first attempt
+            for attempt in range(1, max_attempts + 1):
                 try:
                     return func(*args, **kwargs)

                 except (RateLimitError, ServiceUnavailableError) as e:
-                    if attempt == num_attempts:
+                    if attempt >= max_attempts or (
+                        # User's API quota exceeded
+                        isinstance(e, RateLimitError)
+                        and (err := getattr(e, "error", {}))
+                        and err.get("code") == "insufficient_quota"
+                    ):
                         raise

                     error_msg = error_messages[type(e)]
-                    logger.debug(error_msg)
+                    logger.warn(error_msg)
                     if not user_warned:
                         logger.double_check(api_key_error_msg)
+                        logger.debug(f"Status: {e.http_status}")
+                        logger.debug(f"Response body: {e.json_body}")
+                        logger.debug(f"Response headers: {e.headers}")
                         user_warned = True

                 except (APIError, Timeout) as e:
-                    if (e.http_status not in [429, 502]) or (attempt == num_attempts):
+                    if (e.http_status not in [429, 502]) or (attempt == max_attempts):
                         raise

                 backoff = backoff_base ** (attempt + 2)
-                logger.debug(backoff_msg.format(backoff=backoff))
+                logger.warn(backoff_msg.format(backoff=backoff))
                 time.sleep(backoff)

         return _wrapped
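For reference, the backoff schedule above is backoff_base ** (attempt + 2), so with the default backoff_base=2.0 the first retry waits 2**3 = 8 seconds and each subsequent retry doubles the wait. A quick check:

# Backoff schedule used by retry_api: backoff_base ** (attempt + 2)
backoff_base = 2.0

for attempt in range(1, 6):
    print(f"attempt {attempt}: wait {backoff_base ** (attempt + 2):.0f}s")
# attempt 1: wait 8s
# attempt 2: wait 16s
# attempt 3: wait 32s
# attempt 4: wait 64s
# attempt 5: wait 128s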
Submodule tests/Auto-GPT-test-cassettes updated: b36e755eef...d584872257
@@ -20,7 +20,7 @@ def error_factory(error_instance, error_count, retry_count, warn_user=True):
             self.count = 0

         @openai.retry_api(
-            num_retries=retry_count, backoff_base=0.001, warn_user=warn_user
+            max_retries=retry_count, backoff_base=0.001, warn_user=warn_user
         )
         def __call__(self):
             self.count += 1
@@ -69,16 +69,11 @@ def test_retry_open_api_passing(capsys, error, error_count, retry_count, failure

     if error_count and retry_count:
         if type(error) == RateLimitError:
-            assert "Reached rate limit, passing..." in output.out
+            assert "Reached rate limit" in output.out
             assert "Please double check" in output.out
         if type(error) == ServiceUnavailableError:
-            assert (
-                "The OpenAI API engine is currently overloaded, passing..."
-                in output.out
-            )
+            assert "The OpenAI API engine is currently overloaded" in output.out
             assert "Please double check" in output.out
-        if type(error) == APIError:
-            assert "API Bad gateway" in output.out
     else:
         assert output.out == ""
@@ -96,7 +91,7 @@ def test_retry_open_api_rate_limit_no_warn(capsys):

     output = capsys.readouterr()

-    assert "Reached rate limit, passing..." in output.out
+    assert "Reached rate limit" in output.out
     assert "Please double check" not in output.out
@@ -115,7 +110,7 @@ def test_retry_open_api_service_unavairable_no_warn(capsys):

     output = capsys.readouterr()

-    assert "The OpenAI API engine is currently overloaded, passing..." in output.out
+    assert "The OpenAI API engine is currently overloaded" in output.out
     assert "Please double check" not in output.out
@@ -1,10 +1,16 @@
 import os
+from hashlib import sha256

 import openai.api_requestor
 import pytest
 from pytest_mock import MockerFixture

-from .vcr_filter import PROXY, before_record_request, before_record_response
+from .vcr_filter import (
+    PROXY,
+    before_record_request,
+    before_record_response,
+    freeze_request_body,
+)

 DEFAULT_RECORD_MODE = "new_episodes"
 BASE_VCR_CONFIG = {
@@ -12,10 +18,13 @@ BASE_VCR_CONFIG = {
     "before_record_response": before_record_response,
     "filter_headers": [
         "Authorization",
         "AGENT-MODE",
         "AGENT-TYPE",
+        "OpenAI-Organization",
+        "X-OpenAI-Client-User-Agent",
+        "User-Agent",
     ],
-    "match_on": ["method", "body"],
+    "match_on": ["method", "headers"],
 }

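Matching on ["method", "headers"] is what makes playback cheap: every recorded request now carries an X-Content-Hash header (added in patched_api_requestor below), so VCR compares one short digest instead of re-parsing and diffing JSON bodies. The same idea written as an explicit vcrpy matcher; this is a hypothetical sketch for illustration, while the PR relies on the built-in headers matcher:

from vcr.request import Request


def match_content_hash(r1: Request, r2: Request) -> bool:
    # One digest comparison replaces a structural body diff.
    return r1.headers.get("X-Content-Hash") == r2.headers.get("X-Content-Hash")


# Hypothetical registration:
#   my_vcr = vcr.VCR()
#   my_vcr.register_matcher("content_hash", match_content_hash)
#   my_vcr.match_on = ("method", "content_hash")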
@@ -41,7 +50,7 @@ def vcr_cassette_dir(request):
     return os.path.join("tests/Auto-GPT-test-cassettes", test_name)


-def patch_api_base(requestor):
+def patch_api_base(requestor: openai.api_requestor.APIRequestor):
     new_api_base = f"{PROXY}/v1"
     requestor.api_base = new_api_base
     return requestor
@@ -49,23 +58,35 @@ def patch_api_base(requestor):

 @pytest.fixture
 def patched_api_requestor(mocker: MockerFixture):
-    original_init = openai.api_requestor.APIRequestor.__init__
-    original_validate_headers = openai.api_requestor.APIRequestor._validate_headers
+    init_requestor = openai.api_requestor.APIRequestor.__init__
+    prepare_request = openai.api_requestor.APIRequestor._prepare_request_raw

-    def patched_init(requestor, *args, **kwargs):
-        original_init(requestor, *args, **kwargs)
+    def patched_init_requestor(requestor, *args, **kwargs):
+        init_requestor(requestor, *args, **kwargs)
         patch_api_base(requestor)

-    def patched_validate_headers(self, supplied_headers):
-        headers = original_validate_headers(self, supplied_headers)
-        headers["AGENT-MODE"] = os.environ.get("AGENT_MODE")
-        headers["AGENT-TYPE"] = os.environ.get("AGENT_TYPE")
-        return headers
+    def patched_prepare_request(self, *args, **kwargs):
+        url, headers, data = prepare_request(self, *args, **kwargs)
+
+        if PROXY:
+            headers["AGENT-MODE"] = os.environ.get("AGENT_MODE")
+            headers["AGENT-TYPE"] = os.environ.get("AGENT_TYPE")
+
+        # Add hash header for cheap & fast matching on cassette playback
+        headers["X-Content-Hash"] = sha256(
+            freeze_request_body(data), usedforsecurity=False
+        ).hexdigest()
+
+        return url, headers, data

     if PROXY:
-        mocker.patch("openai.api_requestor.APIRequestor.__init__", new=patched_init)
         mocker.patch.object(
             openai.api_requestor.APIRequestor,
-            "_validate_headers",
-            new=patched_validate_headers,
+            "__init__",
+            new=patched_init_requestor,
         )
+    mocker.patch.object(
+        openai.api_requestor.APIRequestor,
+        "_prepare_request_raw",
+        new=patched_prepare_request,
+    )
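The digest is computed over the frozen request body, so two requests that differ only in key order or in stripped fields (see freeze_request_body below) hash identically. A standalone illustration of the same computation, with a made-up body:

from hashlib import sha256

body = b'{"messages": [], "model": "gpt-3.5-turbo"}'
# usedforsecurity=False marks this as a non-cryptographic use (Python 3.9+)
digest = sha256(body, usedforsecurity=False).hexdigest()
print(digest[:16], "...")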
tests/vcr/openai_filter.py (deleted):

@@ -1,52 +0,0 @@
-import json
-import re
-
-
-def replace_timestamp_in_request(request):
-    # Check if the request body contains a JSON object
-
-    try:
-        if not request or not request.body:
-            return request
-        body = json.loads(request.body)
-    except ValueError:
-        return request
-
-    if "messages" not in body:
-        return request
-
-    for message in body["messages"]:
-        if "content" in message and "role" in message and message["role"] == "system":
-            timestamp_regex = re.compile(r"\w{3} \w{3} \d{2} \d{2}:\d{2}:\d{2} \d{4}")
-            message["content"] = timestamp_regex.sub(
-                "Tue Jan 01 00:00:00 2000", message["content"]
-            )
-
-    request.body = json.dumps(body)
-    return request
-
-
-def before_record_response(response):
-    if "Transfer-Encoding" in response["headers"]:
-        del response["headers"]["Transfer-Encoding"]
-    return response
-
-
-def before_record_request(request):
-    filtered_request = filter_hostnames(request)
-    filtered_request_without_dynamic_data = replace_timestamp_in_request(
-        filtered_request
-    )
-    return filtered_request_without_dynamic_data
-
-
-def filter_hostnames(request):
-    allowed_hostnames = [
-        "api.openai.com",
-        "localhost:50337",
-    ]  # List of hostnames you want to allow
-
-    if any(hostname in request.url for hostname in allowed_hostnames):
-        return request
-    else:
-        return None
@@ -1,8 +1,12 @@
+import contextlib
 import json
 import os
 import re
+from io import BytesIO
 from typing import Any, Dict, List

+from vcr.request import Request
+
 PROXY = os.environ.get("PROXY")

 REPLACEMENTS: List[Dict[str, str]] = [
@@ -39,19 +43,20 @@ def replace_message_content(content: str, replacements: List[Dict[str, str]]) ->
     return content


-def replace_timestamp_in_request(request: Any) -> Any:
+def freeze_request_body(json_body: str | bytes) -> bytes:
+    """Remove any dynamic items from the request body"""
+
     try:
-        if not request or not request.body:
-            return request
-        body = json.loads(request.body)
+        body = json.loads(json_body)
     except ValueError:
-        return request
+        return json_body if type(json_body) == bytes else json_body.encode()

     if "messages" not in body:
-        return request
-    body[
-        "max_tokens"
-    ] = 0  # this field is inconsistent between requests and not used at the moment.
+        return json.dumps(body, sort_keys=True).encode()
+
+    if "max_tokens" in body:
+        del body["max_tokens"]

     for message in body["messages"]:
         if "content" in message and "role" in message:
             if message["role"] == "system":
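json.dumps(body, sort_keys=True) is what keeps the digest stable: logically identical bodies serialize to identical bytes regardless of the key order the client produced, and dropping max_tokens removes the one field that varies between otherwise-equal requests. A condensed, self-contained restatement of that logic (freeze is a stand-in for freeze_request_body, not the PR's code verbatim):

import json
from hashlib import sha256


def freeze(body: dict) -> bytes:
    body = dict(body)
    body.pop("max_tokens", None)  # inconsistent between runs, so dropped
    return json.dumps(body, sort_keys=True).encode()


a = {"model": "gpt-4", "max_tokens": 100, "messages": []}
b = {"messages": [], "model": "gpt-4", "max_tokens": 500}
assert sha256(freeze(a)).hexdigest() == sha256(freeze(b)).hexdigest()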
@@ -59,7 +64,20 @@ def replace_timestamp_in_request(request: Any) -> Any:
                     message["content"], REPLACEMENTS
                 )

-    request.body = json.dumps(body)
+    return json.dumps(body, sort_keys=True).encode()
+
+
+def freeze_request(request: Request) -> Request:
+    if not request or not request.body:
+        return request
+
+    with contextlib.suppress(ValueError):
+        request.body = freeze_request_body(
+            request.body.getvalue()
+            if isinstance(request.body, BytesIO)
+            else request.body
+        )
+
     return request
@@ -69,20 +87,23 @@ def before_record_response(response: Dict[str, Any]) -> Dict[str, Any]:
     return response


-def before_record_request(request: Any) -> Any:
+def before_record_request(request: Request) -> Request | None:
     request = replace_request_hostname(request, ORIGINAL_URL, NEW_URL)

     filtered_request = filter_hostnames(request)
-    filtered_request_without_dynamic_data = replace_timestamp_in_request(
-        filtered_request
-    )
+    if not filtered_request:
+        return None
+
+    filtered_request_without_dynamic_data = freeze_request(filtered_request)
     return filtered_request_without_dynamic_data


 from urllib.parse import urlparse, urlunparse


-def replace_request_hostname(request: Any, original_url: str, new_hostname: str) -> Any:
+def replace_request_hostname(
+    request: Request, original_url: str, new_hostname: str
+) -> Request:
     parsed_url = urlparse(request.uri)

     if parsed_url.hostname in original_url:
@@ -94,7 +115,7 @@ def replace_request_hostname(request: Any, original_url: str, new_hostname: str)
     return request


-def filter_hostnames(request: Any) -> Any:
+def filter_hostnames(request: Request) -> Request | None:
     # Add your implementation here for filtering hostnames
     if any(hostname in request.url for hostname in ALLOWED_HOSTNAMES):
         return request