test: convert Google Gemini tests to VCR (#118)

Signed-off-by: Adrian Cole <adrian.cole@elastic.co>
This commit is contained in:
Adrian Cole
2024-10-10 08:04:26 +08:00
committed by GitHub
parent 8276e9b01f
commit 3eb5a93f1b
4 changed files with 197 additions and 55 deletions

View File

@@ -0,0 +1,66 @@
interactions:
- request:
body: '{"system_instruction": {"parts": [{"text": "You are a helpful assistant."}]},
"contents": [{"role": "user", "parts": [{"text": "Hello"}]}]}'
headers:
accept:
- '*/*'
accept-encoding:
- gzip, deflate
connection:
- keep-alive
content-length:
- '139'
content-type:
- application/json
host:
- generativelanguage.googleapis.com
user-agent:
- python-httpx/0.27.2
method: POST
uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent?key=test_google_api_key
response:
body:
string: "{\n \"candidates\": [\n {\n \"content\": {\n \"parts\":
[\n {\n \"text\": \"Hello! \U0001F44B How can I help
you today? \U0001F60A \\n\"\n }\n ],\n \"role\": \"model\"\n
\ },\n \"finishReason\": \"STOP\",\n \"index\": 0,\n \"safetyRatings\":
[\n {\n \"category\": \"HARM_CATEGORY_SEXUALLY_EXPLICIT\",\n
\ \"probability\": \"NEGLIGIBLE\"\n },\n {\n \"category\":
\"HARM_CATEGORY_HATE_SPEECH\",\n \"probability\": \"NEGLIGIBLE\"\n
\ },\n {\n \"category\": \"HARM_CATEGORY_HARASSMENT\",\n
\ \"probability\": \"NEGLIGIBLE\"\n },\n {\n \"category\":
\"HARM_CATEGORY_DANGEROUS_CONTENT\",\n \"probability\": \"NEGLIGIBLE\"\n
\ }\n ]\n }\n ],\n \"usageMetadata\": {\n \"promptTokenCount\":
8,\n \"candidatesTokenCount\": 12,\n \"totalTokenCount\": 20\n }\n}\n"
headers:
Alt-Svc:
- h3=":443"; ma=2592000,h3-29=":443"; ma=2592000
Cache-Control:
- private
Content-Type:
- application/json; charset=UTF-8
Date:
- Wed, 02 Oct 2024 01:06:50 GMT
Server:
- scaffolding on HTTPServer2
Server-Timing:
- gfet4t7; dur=426
Transfer-Encoding:
- chunked
Vary:
- Origin
- X-Origin
- Referer
X-Content-Type-Options:
- nosniff
X-Frame-Options:
- SAMEORIGIN
X-XSS-Protection:
- '0'
content-length:
- '855'
status:
code: 200
message: OK
version: 1

View File

@@ -0,0 +1,73 @@
interactions:
- request:
body: '{"system_instruction": {"parts": [{"text": "You are a helpful assistant.
Expect to need to read a file using read_file."}]}, "contents": [{"role": "user",
"parts": [{"text": "What are the contents of this file? test.txt"}]}], "tools":
{"functionDeclarations": [{"name": "read_file", "description": "Read the contents
of the file.", "parameters": {"type": "object", "properties": {"filename": {"type":
"string", "description": "The path to the file, which can be relative or\nabsolute.
If it is a plain filename, it is assumed to be in the\ncurrent working directory."}},
"required": ["filename"]}}]}}'
headers:
accept:
- '*/*'
accept-encoding:
- gzip, deflate
connection:
- keep-alive
content-length:
- '600'
content-type:
- application/json
host:
- generativelanguage.googleapis.com
user-agent:
- python-httpx/0.27.2
method: POST
uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent?key=test_google_api_key
response:
body:
string: "{\n \"candidates\": [\n {\n \"content\": {\n \"parts\":
[\n {\n \"functionCall\": {\n \"name\": \"read_file\",\n
\ \"args\": {\n \"filename\": \"test.txt\"\n }\n
\ }\n }\n ],\n \"role\": \"model\"\n },\n
\ \"finishReason\": \"STOP\",\n \"index\": 0,\n \"safetyRatings\":
[\n {\n \"category\": \"HARM_CATEGORY_SEXUALLY_EXPLICIT\",\n
\ \"probability\": \"NEGLIGIBLE\"\n },\n {\n \"category\":
\"HARM_CATEGORY_HATE_SPEECH\",\n \"probability\": \"NEGLIGIBLE\"\n
\ },\n {\n \"category\": \"HARM_CATEGORY_HARASSMENT\",\n
\ \"probability\": \"NEGLIGIBLE\"\n },\n {\n \"category\":
\"HARM_CATEGORY_DANGEROUS_CONTENT\",\n \"probability\": \"NEGLIGIBLE\"\n
\ }\n ]\n }\n ],\n \"usageMetadata\": {\n \"promptTokenCount\":
101,\n \"candidatesTokenCount\": 17,\n \"totalTokenCount\": 118\n }\n}\n"
headers:
Alt-Svc:
- h3=":443"; ma=2592000,h3-29=":443"; ma=2592000
Cache-Control:
- private
Content-Type:
- application/json; charset=UTF-8
Date:
- Wed, 02 Oct 2024 01:06:51 GMT
Server:
- scaffolding on HTTPServer2
Server-Timing:
- gfet4t7; dur=449
Transfer-Encoding:
- chunked
Vary:
- Origin
- X-Origin
- Referer
X-Content-Type-Options:
- nosniff
X-Frame-Options:
- SAMEORIGIN
X-XSS-Protection:
- '0'
content-length:
- '947'
status:
code: 200
message: OK
version: 1

View File

@@ -53,6 +53,23 @@ def default_azure_env(monkeypatch):
monkeypatch.setenv("AZURE_CHAT_COMPLETIONS_KEY", AZURE_API_KEY)
GOOGLE_API_KEY = "test_google_api_key"
@pytest.fixture
def default_google_env(monkeypatch):
    """
    Provide a placeholder GOOGLE_API_KEY when the environment defines none,
    so GoogleProvider.from_env() does not fail on a missing variable.

    A key that is already exported is left untouched. Export a real key when
    recording a cassette for the first time (or re-recording after deleting
    one) so the live request succeeds; later runs replay the recording and do
    not need it.
    """
    # A key supplied by the caller's environment always takes precedence.
    if "GOOGLE_API_KEY" in os.environ:
        return
    monkeypatch.setenv("GOOGLE_API_KEY", GOOGLE_API_KEY)
@pytest.fixture(scope="module")
def vcr_config():
"""
@@ -85,6 +102,8 @@ def scrub_request_url(request):
request.uri = re.sub(r"/deployments/[^/]+", f"/deployments/{AZURE_DEPLOYMENT_NAME}", request.uri)
request.headers["host"] = AZURE_ENDPOINT.replace("https://", "")
request.headers["api-key"] = AZURE_API_KEY
elif "generativelanguage.googleapis.com" in request.uri:
request.uri = re.sub(r"([?&])key=[^&]+", r"\1key=" + GOOGLE_API_KEY, request.uri)
return request
@@ -93,8 +112,10 @@ def scrub_response_headers(response):
"""
This scrubs sensitive response headers. Note they are case-sensitive!
"""
response["headers"]["openai-organization"] = OPENAI_ORG_ID
response["headers"]["Set-Cookie"] = "test_set_cookie"
if "openai-organization" in response["headers"]:
response["headers"]["openai-organization"] = OPENAI_ORG_ID
if "Set-Cookie" in response["headers"]:
response["headers"]["Set-Cookie"] = "test_set_cookie"
return response
@@ -102,7 +123,7 @@ def complete(provider_cls: Type[Provider], model: str, **kwargs) -> Tuple[Messag
provider = provider_cls.from_env()
system = "You are a helpful assistant."
messages = [Message.user("Hello")]
return provider.complete(model=model, system=system, messages=messages, tools=None, **kwargs)
return provider.complete(model=model, system=system, messages=messages, tools=(), **kwargs)
def tools(provider_cls: Type[Provider], model: str, **kwargs) -> Tuple[Message, Usage]:
@@ -128,4 +149,4 @@ def vision(provider_cls: Type[Provider], model: str, **kwargs) -> Tuple[Message,
content=[ToolResult(tool_use_id="xyz", output='"image:tests/test_image.png"')],
),
]
return provider.complete(model=model, system=system, messages=messages, tools=None, **kwargs)
return provider.complete(model=model, system=system, messages=messages, tools=(), **kwargs)

View File

@@ -1,13 +1,15 @@
import os
from unittest.mock import patch
import httpx
import pytest
from exchange import Message, Text
from exchange.content import ToolResult, ToolUse
from exchange.providers.base import MissingProviderEnvVariableError
from exchange.providers.google import GoogleProvider
from exchange.tool import Tool
from .conftest import complete, tools
GOOGLE_MODEL = os.getenv("GOOGLE_MODEL", "gemini-1.5-flash")
def example_fn(param: str) -> None:
@@ -30,12 +32,6 @@ def test_from_env_throw_error_when_missing_api_key():
assert "https://ai.google.dev/gemini-api/docs/api-key" in context.value.message
@pytest.fixture
@patch.dict(os.environ, {"GOOGLE_API_KEY": "test_api_key"})
def google_provider():
return GoogleProvider.from_env()
def test_google_response_to_text_message() -> None:
response = {"candidates": [{"content": {"parts": [{"text": "Hello from Gemini!"}], "role": "model"}}]}
message = GoogleProvider.google_response_to_message(response)
@@ -105,54 +101,40 @@ def test_messages_to_google_spec() -> None:
assert actual_spec == expected_spec
@patch("httpx.Client.post")
@patch("logging.warning")
@patch("logging.error")
def test_google_completion(mock_error, mock_warning, mock_post, google_provider):
mock_response = {
"candidates": [{"content": {"parts": [{"text": "Hello from Gemini!"}], "role": "model"}}],
"usageMetadata": {"promptTokenCount": 3, "candidatesTokenCount": 10, "totalTokenCount": 13},
}
@pytest.mark.vcr()
def test_google_complete(default_google_env):
reply_message, reply_usage = complete(GoogleProvider, GOOGLE_MODEL)
# First attempts fail with status code 429, 2nd succeeds
def create_response(status_code, json_data=None):
response = httpx.Response(status_code)
response._content = httpx._content.json_dumps(json_data or {}).encode()
response._request = httpx.Request("POST", "https://generativelanguage.googleapis.com/v1beta/")
return response
mock_post.side_effect = [
create_response(429), # 1st attempt
create_response(200, mock_response), # Final success
]
model = "gemini-1.5-flash"
system = "You are a helpful assistant."
messages = [Message.user("Hello, Gemini")]
reply_message, reply_usage = google_provider.complete(model=model, system=system, messages=messages)
assert reply_message.content == [Text(text="Hello from Gemini!")]
assert reply_usage.total_tokens == 13
assert mock_post.call_count == 2
mock_post.assert_any_call(
"models/gemini-1.5-flash:generateContent",
json={
"system_instruction": {"parts": [{"text": "You are a helpful assistant."}]},
"contents": [{"role": "user", "parts": [{"text": "Hello, Gemini"}]}],
},
)
assert reply_message.content == [Text("Hello! 👋 How can I help you today? 😊 \n")]
assert reply_usage.total_tokens == 20
@pytest.mark.integration
def test_google_integration():
provider = GoogleProvider.from_env()
model = "gemini-1.5-flash" # updated model to a known valid model
system = "You are a helpful assistant."
messages = [Message.user("Hello, Gemini")]
# Run the completion
reply = provider.complete(model=model, system=system, messages=messages)
def test_google_complete_integration():
reply = complete(GoogleProvider, GOOGLE_MODEL)
assert reply[0].content is not None
print("Completion content from Google:", reply[0].content)
@pytest.mark.vcr()
def test_google_tools(default_google_env):
    """
    Run the shared tools() scenario against GoogleProvider under the vcr
    marker and check the parsed tool call and the reported token usage.
    """
    message, usage = tools(GoogleProvider, GOOGLE_MODEL)

    first_part = message.content[0]
    assert isinstance(first_part, ToolUse), f"Expected ToolUse, but was {type(first_part).__name__}"
    # This test expects the tool-use id to equal the function name.
    assert first_part.id == "read_file"
    assert first_part.name == "read_file"
    assert first_part.parameters == {"filename": "test.txt"}
    assert usage.total_tokens == 118
@pytest.mark.integration
def test_google_tools_integration():
    """
    Run the shared tools() scenario against GoogleProvider, marked as an
    integration test, and check that the reply starts with a read_file call.
    """
    message = tools(GoogleProvider, GOOGLE_MODEL)[0]

    first_part = message.content[0]
    assert isinstance(first_part, ToolUse), f"Expected ToolUse, but was {type(first_part).__name__}"
    # Only the presence of an id is checked here, not its value.
    assert first_part.id is not None
    assert first_part.name == "read_file"
    assert first_part.parameters == {"filename": "test.txt"}