mirror of
https://github.com/aljazceru/Auto-GPT.git
synced 2025-12-17 14:04:27 +01:00
test(agent): Fix VCRpy request header filter for cross-platform cassette reuse (#7040)
- Move filtering logic from tests/vcr/__init__.py to tests/vcr/vcr_filter.py
- Ignore all `X-Stainless-*` headers for cassette matching, e.g. `X-Stainless-OS` and `X-Stainless-Runtime-Version`
- Remove deprecated OpenAI proxy logic
- Reorder methods in vcr_filter.py for readability
This commit is contained in:
committed by
GitHub
parent
20041d65bf
commit
6dd76afad5
@@ -10,7 +10,6 @@ from openai._utils import is_given
|
|||||||
from pytest_mock import MockerFixture
|
from pytest_mock import MockerFixture
|
||||||
|
|
||||||
from .vcr_filter import (
|
from .vcr_filter import (
|
||||||
PROXY,
|
|
||||||
before_record_request,
|
before_record_request,
|
||||||
before_record_response,
|
before_record_response,
|
||||||
freeze_request_body,
|
freeze_request_body,
|
||||||
@@ -20,15 +19,6 @@ DEFAULT_RECORD_MODE = "new_episodes"
|
|||||||
BASE_VCR_CONFIG = {
|
BASE_VCR_CONFIG = {
|
||||||
"before_record_request": before_record_request,
|
"before_record_request": before_record_request,
|
||||||
"before_record_response": before_record_response,
|
"before_record_response": before_record_response,
|
||||||
"filter_headers": [
|
|
||||||
"Authorization",
|
|
||||||
"AGENT-MODE",
|
|
||||||
"AGENT-TYPE",
|
|
||||||
"Cookie",
|
|
||||||
"OpenAI-Organization",
|
|
||||||
"X-OpenAI-Client-User-Agent",
|
|
||||||
"User-Agent",
|
|
||||||
],
|
|
||||||
"match_on": ["method", "headers"],
|
"match_on": ["method", "headers"],
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -69,10 +59,6 @@ def cached_openai_client(mocker: MockerFixture) -> OpenAI:
|
|||||||
options.headers = headers
|
options.headers = headers
|
||||||
data: dict = options.json_data
|
data: dict = options.json_data
|
||||||
|
|
||||||
if PROXY:
|
|
||||||
headers["AGENT-MODE"] = os.environ.get("AGENT_MODE", Omit())
|
|
||||||
headers["AGENT-TYPE"] = os.environ.get("AGENT_TYPE", Omit())
|
|
||||||
|
|
||||||
logging.getLogger("cached_openai_client").debug(
|
logging.getLogger("cached_openai_client").debug(
|
||||||
f"Outgoing API request: {headers}\n{data if data else None}"
|
f"Outgoing API request: {headers}\n{data if data else None}"
|
||||||
)
|
)
|
||||||
@@ -82,8 +68,6 @@ def cached_openai_client(mocker: MockerFixture) -> OpenAI:
|
|||||||
freeze_request_body(data), usedforsecurity=False
|
freeze_request_body(data), usedforsecurity=False
|
||||||
).hexdigest()
|
).hexdigest()
|
||||||
|
|
||||||
if PROXY:
|
|
||||||
client.base_url = f"{PROXY}/v1"
|
|
||||||
mocker.patch.object(
|
mocker.patch.object(
|
||||||
client,
|
client,
|
||||||
"_prepare_options",
|
"_prepare_options",
|
||||||
|
|||||||
@@ -1,16 +1,27 @@
|
|||||||
import contextlib
|
import contextlib
|
||||||
import json
|
import json
|
||||||
import os
|
|
||||||
import re
|
import re
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from typing import Any, Dict, List
|
from typing import Any
|
||||||
from urllib.parse import urlparse, urlunparse
|
|
||||||
|
|
||||||
from vcr.request import Request
|
from vcr.request import Request
|
||||||
|
|
||||||
PROXY = os.environ.get("PROXY")
|
HOSTNAMES_TO_CACHE: list[str] = [
|
||||||
|
"api.openai.com",
|
||||||
|
"localhost:50337",
|
||||||
|
"duckduckgo.com",
|
||||||
|
]
|
||||||
|
|
||||||
REPLACEMENTS: List[Dict[str, str]] = [
|
IGNORE_REQUEST_HEADERS: set[str | re.Pattern] = {
|
||||||
|
"Authorization",
|
||||||
|
"Cookie",
|
||||||
|
"OpenAI-Organization",
|
||||||
|
"X-OpenAI-Client-User-Agent",
|
||||||
|
"User-Agent",
|
||||||
|
re.compile(r"X-Stainless-[\w\-]+", re.IGNORECASE),
|
||||||
|
}
|
||||||
|
|
||||||
|
LLM_MESSAGE_REPLACEMENTS: list[dict[str, str]] = [
|
||||||
{
|
{
|
||||||
"regex": r"\w{3} \w{3} {1,2}\d{1,2} \d{2}:\d{2}:\d{2} \d{4}",
|
"regex": r"\w{3} \w{3} {1,2}\d{1,2} \d{2}:\d{2}:\d{2} \d{4}",
|
||||||
"replacement": "Tue Jan 1 00:00:00 2000",
|
"replacement": "Tue Jan 1 00:00:00 2000",
|
||||||
@@ -21,46 +32,33 @@ REPLACEMENTS: List[Dict[str, str]] = [
|
|||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
ALLOWED_HOSTNAMES: List[str] = [
|
OPENAI_URL = "api.openai.com"
|
||||||
"api.openai.com",
|
|
||||||
"localhost:50337",
|
|
||||||
"duckduckgo.com",
|
|
||||||
]
|
|
||||||
|
|
||||||
if PROXY:
|
|
||||||
ALLOWED_HOSTNAMES.append(PROXY)
|
|
||||||
ORIGINAL_URL = PROXY
|
|
||||||
else:
|
|
||||||
ORIGINAL_URL = "no_ci"
|
|
||||||
|
|
||||||
NEW_URL = "api.openai.com"
|
|
||||||
|
|
||||||
|
|
||||||
def replace_message_content(content: str, replacements: List[Dict[str, str]]) -> str:
|
def before_record_request(request: Request) -> Request | None:
|
||||||
for replacement in replacements:
|
if not should_cache_request(request):
|
||||||
pattern = re.compile(replacement["regex"])
|
return None
|
||||||
content = pattern.sub(replacement["replacement"], content)
|
|
||||||
|
|
||||||
return content
|
request = filter_request_headers(request)
|
||||||
|
request = freeze_request(request)
|
||||||
|
return request
|
||||||
|
|
||||||
|
|
||||||
def freeze_request_body(body: dict) -> bytes:
|
def should_cache_request(request: Request) -> bool:
|
||||||
"""Remove any dynamic items from the request body"""
|
return any(hostname in request.url for hostname in HOSTNAMES_TO_CACHE)
|
||||||
|
|
||||||
if "messages" not in body:
|
|
||||||
return json.dumps(body, sort_keys=True).encode()
|
|
||||||
|
|
||||||
if "max_tokens" in body:
|
def filter_request_headers(request: Request) -> Request:
|
||||||
del body["max_tokens"]
|
for header_name in list(request.headers):
|
||||||
|
if any(
|
||||||
for message in body["messages"]:
|
(
|
||||||
if "content" in message and "role" in message:
|
(type(ignore) is str and ignore.lower() == header_name.lower())
|
||||||
if message["role"] == "system":
|
or (isinstance(ignore, re.Pattern) and ignore.match(header_name))
|
||||||
message["content"] = replace_message_content(
|
)
|
||||||
message["content"], REPLACEMENTS
|
for ignore in IGNORE_REQUEST_HEADERS
|
||||||
)
|
):
|
||||||
|
del request.headers[header_name]
|
||||||
return json.dumps(body, sort_keys=True).encode()
|
return request
|
||||||
|
|
||||||
|
|
||||||
def freeze_request(request: Request) -> Request:
|
def freeze_request(request: Request) -> Request:
|
||||||
@@ -79,40 +77,34 @@ def freeze_request(request: Request) -> Request:
|
|||||||
return request
|
return request
|
||||||
|
|
||||||
|
|
||||||
def before_record_response(response: Dict[str, Any]) -> Dict[str, Any]:
|
def freeze_request_body(body: dict) -> bytes:
|
||||||
|
"""Remove any dynamic items from the request body"""
|
||||||
|
|
||||||
|
if "messages" not in body:
|
||||||
|
return json.dumps(body, sort_keys=True).encode()
|
||||||
|
|
||||||
|
if "max_tokens" in body:
|
||||||
|
del body["max_tokens"]
|
||||||
|
|
||||||
|
for message in body["messages"]:
|
||||||
|
if "content" in message and "role" in message:
|
||||||
|
if message["role"] == "system":
|
||||||
|
message["content"] = replace_message_content(
|
||||||
|
message["content"], LLM_MESSAGE_REPLACEMENTS
|
||||||
|
)
|
||||||
|
|
||||||
|
return json.dumps(body, sort_keys=True).encode()
|
||||||
|
|
||||||
|
|
||||||
|
def replace_message_content(content: str, replacements: list[dict[str, str]]) -> str:
|
||||||
|
for replacement in replacements:
|
||||||
|
pattern = re.compile(replacement["regex"])
|
||||||
|
content = pattern.sub(replacement["replacement"], content)
|
||||||
|
|
||||||
|
return content
|
||||||
|
|
||||||
|
|
||||||
|
def before_record_response(response: dict[str, Any]) -> dict[str, Any]:
|
||||||
if "Transfer-Encoding" in response["headers"]:
|
if "Transfer-Encoding" in response["headers"]:
|
||||||
del response["headers"]["Transfer-Encoding"]
|
del response["headers"]["Transfer-Encoding"]
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
def before_record_request(request: Request) -> Request | None:
|
|
||||||
request = replace_request_hostname(request, ORIGINAL_URL, NEW_URL)
|
|
||||||
|
|
||||||
filtered_request = filter_hostnames(request)
|
|
||||||
if not filtered_request:
|
|
||||||
return None
|
|
||||||
|
|
||||||
filtered_request_without_dynamic_data = freeze_request(filtered_request)
|
|
||||||
return filtered_request_without_dynamic_data
|
|
||||||
|
|
||||||
|
|
||||||
def replace_request_hostname(
|
|
||||||
request: Request, original_url: str, new_hostname: str
|
|
||||||
) -> Request:
|
|
||||||
parsed_url = urlparse(request.uri)
|
|
||||||
|
|
||||||
if parsed_url.hostname in original_url:
|
|
||||||
new_path = parsed_url.path.replace("/proxy_function", "")
|
|
||||||
request.uri = urlunparse(
|
|
||||||
parsed_url._replace(netloc=new_hostname, path=new_path, scheme="https")
|
|
||||||
)
|
|
||||||
|
|
||||||
return request
|
|
||||||
|
|
||||||
|
|
||||||
def filter_hostnames(request: Request) -> Request | None:
|
|
||||||
# Add your implementation here for filtering hostnames
|
|
||||||
if any(hostname in request.url for hostname in ALLOWED_HOSTNAMES):
|
|
||||||
return request
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|||||||
Reference in New Issue
Block a user