Files
goose/packages/exchange/tests/providers/conftest.py
2024-10-02 11:05:43 -07:00

132 lines
4.8 KiB
Python

import os
import re
from typing import Type, Tuple
import pytest
from exchange import Message, ToolUse, ToolResult, Tool
from exchange.providers import Usage, Provider
from tests.conftest import read_file
OPENAI_API_KEY = "test_openai_api_key"
OPENAI_ORG_ID = "test_openai_org_key"
OPENAI_PROJECT_ID = "test_openai_project_id"
@pytest.fixture
def default_openai_env(monkeypatch):
"""
This fixture prevents OpenAIProvider.from_env() from erring on missing
environment variables.
When running VCR tests for the first time or after deleting a cassette
recording, set required environment variables, so that real requests don't
fail. Subsequent runs use the recorded data, so don't need them.
"""
if "OPENAI_API_KEY" not in os.environ:
monkeypatch.setenv("OPENAI_API_KEY", OPENAI_API_KEY)
AZURE_ENDPOINT = "https://test.openai.azure.com"
AZURE_DEPLOYMENT_NAME = "test-azure-deployment"
AZURE_API_VERSION = "2024-05-01-preview"
AZURE_API_KEY = "test_azure_api_key"
@pytest.fixture
def default_azure_env(monkeypatch):
"""
This fixture prevents AzureProvider.from_env() from erring on missing
environment variables.
When running VCR tests for the first time or after deleting a cassette
recording, set required environment variables, so that real requests don't
fail. Subsequent runs use the recorded data, so don't need them.
"""
if "AZURE_CHAT_COMPLETIONS_HOST_NAME" not in os.environ:
monkeypatch.setenv("AZURE_CHAT_COMPLETIONS_HOST_NAME", AZURE_ENDPOINT)
if "AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME" not in os.environ:
monkeypatch.setenv("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME", AZURE_DEPLOYMENT_NAME)
if "AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION" not in os.environ:
monkeypatch.setenv("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION", AZURE_API_VERSION)
if "AZURE_CHAT_COMPLETIONS_KEY" not in os.environ:
monkeypatch.setenv("AZURE_CHAT_COMPLETIONS_KEY", AZURE_API_KEY)
@pytest.fixture(scope="module")
def vcr_config():
"""
This scrubs sensitive data and gunzips bodies when in recording mode.
Without this, you would leak cookies and auth tokens in the cassettes.
Also, depending on the request, some responses would be binary encoded
while others plain json. This ensures all bodies are human-readable.
"""
return {
"decode_compressed_response": True,
"filter_headers": [
("authorization", "Bearer " + OPENAI_API_KEY),
("openai-organization", OPENAI_ORG_ID),
("openai-project", OPENAI_PROJECT_ID),
("cookie", None),
],
"before_record_request": scrub_request_url,
"before_record_response": scrub_response_headers,
}
def scrub_request_url(request):
"""
This scrubs sensitive request data in provider-specific way. Note that headers
are case-sensitive!
"""
if "openai.azure.com" in request.uri:
request.uri = re.sub(r"https://[^/]+", AZURE_ENDPOINT, request.uri)
request.uri = re.sub(r"/deployments/[^/]+", f"/deployments/{AZURE_DEPLOYMENT_NAME}", request.uri)
request.headers["host"] = AZURE_ENDPOINT.replace("https://", "")
request.headers["api-key"] = AZURE_API_KEY
return request
def scrub_response_headers(response):
"""
This scrubs sensitive response headers. Note they are case-sensitive!
"""
response["headers"]["openai-organization"] = OPENAI_ORG_ID
response["headers"]["Set-Cookie"] = "test_set_cookie"
return response
def complete(provider_cls: Type[Provider], model: str, **kwargs) -> Tuple[Message, Usage]:
provider = provider_cls.from_env()
system = "You are a helpful assistant."
messages = [Message.user("Hello")]
return provider.complete(model=model, system=system, messages=messages, tools=None, **kwargs)
def tools(provider_cls: Type[Provider], model: str, **kwargs) -> Tuple[Message, Usage]:
provider = provider_cls.from_env()
system = "You are a helpful assistant. Expect to need to read a file using read_file."
messages = [Message.user("What are the contents of this file? test.txt")]
return provider.complete(
model=model, system=system, messages=messages, tools=(Tool.from_function(read_file),), **kwargs
)
def vision(provider_cls: Type[Provider], model: str, **kwargs) -> Tuple[Message, Usage]:
provider = provider_cls.from_env()
system = "You are a helpful assistant."
messages = [
Message.user("What does the first entry in the menu say?"),
Message(
role="assistant",
content=[ToolUse(id="xyz", name="screenshot", parameters={})],
),
Message(
role="user",
content=[ToolResult(tool_use_id="xyz", output='"image:tests/test_image.png"')],
),
]
return provider.complete(model=model, system=system, messages=messages, tools=None, **kwargs)