test(agent): Fix VCRpy request header filter for cross-platform cassette reuse (#7040)

- Move filtering logic from tests/vcr/__init__.py to tests/vcr/vcr_filter.py
- Ignore all `X-Stainless-*` headers for cassette matching, e.g. `X-Stainless-OS` and `X-Stainless-Runtime-Version`
- Remove deprecated OpenAI proxy logic
- Reorder methods in vcr_filter.py for readability
This commit is contained in:
Reinier van der Leer
2024-03-22 13:08:15 +01:00
committed by GitHub
parent 20041d65bf
commit 6dd76afad5
2 changed files with 64 additions and 88 deletions

View File

@@ -10,7 +10,6 @@ from openai._utils import is_given
from pytest_mock import MockerFixture from pytest_mock import MockerFixture
from .vcr_filter import ( from .vcr_filter import (
PROXY,
before_record_request, before_record_request,
before_record_response, before_record_response,
freeze_request_body, freeze_request_body,
@@ -20,15 +19,6 @@ DEFAULT_RECORD_MODE = "new_episodes"
BASE_VCR_CONFIG = { BASE_VCR_CONFIG = {
"before_record_request": before_record_request, "before_record_request": before_record_request,
"before_record_response": before_record_response, "before_record_response": before_record_response,
"filter_headers": [
"Authorization",
"AGENT-MODE",
"AGENT-TYPE",
"Cookie",
"OpenAI-Organization",
"X-OpenAI-Client-User-Agent",
"User-Agent",
],
"match_on": ["method", "headers"], "match_on": ["method", "headers"],
} }
@@ -69,10 +59,6 @@ def cached_openai_client(mocker: MockerFixture) -> OpenAI:
options.headers = headers options.headers = headers
data: dict = options.json_data data: dict = options.json_data
if PROXY:
headers["AGENT-MODE"] = os.environ.get("AGENT_MODE", Omit())
headers["AGENT-TYPE"] = os.environ.get("AGENT_TYPE", Omit())
logging.getLogger("cached_openai_client").debug( logging.getLogger("cached_openai_client").debug(
f"Outgoing API request: {headers}\n{data if data else None}" f"Outgoing API request: {headers}\n{data if data else None}"
) )
@@ -82,8 +68,6 @@ def cached_openai_client(mocker: MockerFixture) -> OpenAI:
freeze_request_body(data), usedforsecurity=False freeze_request_body(data), usedforsecurity=False
).hexdigest() ).hexdigest()
if PROXY:
client.base_url = f"{PROXY}/v1"
mocker.patch.object( mocker.patch.object(
client, client,
"_prepare_options", "_prepare_options",

View File

@@ -1,16 +1,27 @@
import contextlib import contextlib
import json import json
import os
import re import re
from io import BytesIO from io import BytesIO
from typing import Any, Dict, List from typing import Any
from urllib.parse import urlparse, urlunparse
from vcr.request import Request from vcr.request import Request
PROXY = os.environ.get("PROXY") HOSTNAMES_TO_CACHE: list[str] = [
"api.openai.com",
"localhost:50337",
"duckduckgo.com",
]
REPLACEMENTS: List[Dict[str, str]] = [ IGNORE_REQUEST_HEADERS: set[str | re.Pattern] = {
"Authorization",
"Cookie",
"OpenAI-Organization",
"X-OpenAI-Client-User-Agent",
"User-Agent",
re.compile(r"X-Stainless-[\w\-]+", re.IGNORECASE),
}
LLM_MESSAGE_REPLACEMENTS: list[dict[str, str]] = [
{ {
"regex": r"\w{3} \w{3} {1,2}\d{1,2} \d{2}:\d{2}:\d{2} \d{4}", "regex": r"\w{3} \w{3} {1,2}\d{1,2} \d{2}:\d{2}:\d{2} \d{4}",
"replacement": "Tue Jan 1 00:00:00 2000", "replacement": "Tue Jan 1 00:00:00 2000",
@@ -21,46 +32,33 @@ REPLACEMENTS: List[Dict[str, str]] = [
}, },
] ]
ALLOWED_HOSTNAMES: List[str] = [ OPENAI_URL = "api.openai.com"
"api.openai.com",
"localhost:50337",
"duckduckgo.com",
]
if PROXY:
ALLOWED_HOSTNAMES.append(PROXY)
ORIGINAL_URL = PROXY
else:
ORIGINAL_URL = "no_ci"
NEW_URL = "api.openai.com"
def replace_message_content(content: str, replacements: List[Dict[str, str]]) -> str: def before_record_request(request: Request) -> Request | None:
for replacement in replacements: if not should_cache_request(request):
pattern = re.compile(replacement["regex"]) return None
content = pattern.sub(replacement["replacement"], content)
return content request = filter_request_headers(request)
request = freeze_request(request)
return request
def freeze_request_body(body: dict) -> bytes: def should_cache_request(request: Request) -> bool:
"""Remove any dynamic items from the request body""" return any(hostname in request.url for hostname in HOSTNAMES_TO_CACHE)
if "messages" not in body:
return json.dumps(body, sort_keys=True).encode()
if "max_tokens" in body: def filter_request_headers(request: Request) -> Request:
del body["max_tokens"] for header_name in list(request.headers):
if any(
for message in body["messages"]: (
if "content" in message and "role" in message: (type(ignore) is str and ignore.lower() == header_name.lower())
if message["role"] == "system": or (isinstance(ignore, re.Pattern) and ignore.match(header_name))
message["content"] = replace_message_content(
message["content"], REPLACEMENTS
) )
for ignore in IGNORE_REQUEST_HEADERS
return json.dumps(body, sort_keys=True).encode() ):
del request.headers[header_name]
return request
def freeze_request(request: Request) -> Request: def freeze_request(request: Request) -> Request:
@@ -79,40 +77,34 @@ def freeze_request(request: Request) -> Request:
return request return request
def before_record_response(response: Dict[str, Any]) -> Dict[str, Any]: def freeze_request_body(body: dict) -> bytes:
"""Remove any dynamic items from the request body"""
if "messages" not in body:
return json.dumps(body, sort_keys=True).encode()
if "max_tokens" in body:
del body["max_tokens"]
for message in body["messages"]:
if "content" in message and "role" in message:
if message["role"] == "system":
message["content"] = replace_message_content(
message["content"], LLM_MESSAGE_REPLACEMENTS
)
return json.dumps(body, sort_keys=True).encode()
def replace_message_content(content: str, replacements: list[dict[str, str]]) -> str:
for replacement in replacements:
pattern = re.compile(replacement["regex"])
content = pattern.sub(replacement["replacement"], content)
return content
def before_record_response(response: dict[str, Any]) -> dict[str, Any]:
if "Transfer-Encoding" in response["headers"]: if "Transfer-Encoding" in response["headers"]:
del response["headers"]["Transfer-Encoding"] del response["headers"]["Transfer-Encoding"]
return response return response
def before_record_request(request: Request) -> Request | None:
request = replace_request_hostname(request, ORIGINAL_URL, NEW_URL)
filtered_request = filter_hostnames(request)
if not filtered_request:
return None
filtered_request_without_dynamic_data = freeze_request(filtered_request)
return filtered_request_without_dynamic_data
def replace_request_hostname(
request: Request, original_url: str, new_hostname: str
) -> Request:
parsed_url = urlparse(request.uri)
if parsed_url.hostname in original_url:
new_path = parsed_url.path.replace("/proxy_function", "")
request.uri = urlunparse(
parsed_url._replace(netloc=new_hostname, path=new_path, scheme="https")
)
return request
def filter_hostnames(request: Request) -> Request | None:
# Add your implementation here for filtering hostnames
if any(hostname in request.url for hostname in ALLOWED_HOSTNAMES):
return request
else:
return None