mirror of
https://github.com/aljazceru/goose.git
synced 2026-01-05 07:24:28 +01:00
test: convert Google Gemini tests to VCR (#118)
Signed-off-by: Adrian Cole <adrian.cole@elastic.co>
This commit is contained in:
@@ -0,0 +1,66 @@
|
||||
interactions:
|
||||
- request:
|
||||
body: '{"system_instruction": {"parts": [{"text": "You are a helpful assistant."}]},
|
||||
"contents": [{"role": "user", "parts": [{"text": "Hello"}]}]}'
|
||||
headers:
|
||||
accept:
|
||||
- '*/*'
|
||||
accept-encoding:
|
||||
- gzip, deflate
|
||||
connection:
|
||||
- keep-alive
|
||||
content-length:
|
||||
- '139'
|
||||
content-type:
|
||||
- application/json
|
||||
host:
|
||||
- generativelanguage.googleapis.com
|
||||
user-agent:
|
||||
- python-httpx/0.27.2
|
||||
method: POST
|
||||
uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent?key=test_google_api_key
|
||||
response:
|
||||
body:
|
||||
string: "{\n \"candidates\": [\n {\n \"content\": {\n \"parts\":
|
||||
[\n {\n \"text\": \"Hello! \U0001F44B How can I help
|
||||
you today? \U0001F60A \\n\"\n }\n ],\n \"role\": \"model\"\n
|
||||
\ },\n \"finishReason\": \"STOP\",\n \"index\": 0,\n \"safetyRatings\":
|
||||
[\n {\n \"category\": \"HARM_CATEGORY_SEXUALLY_EXPLICIT\",\n
|
||||
\ \"probability\": \"NEGLIGIBLE\"\n },\n {\n \"category\":
|
||||
\"HARM_CATEGORY_HATE_SPEECH\",\n \"probability\": \"NEGLIGIBLE\"\n
|
||||
\ },\n {\n \"category\": \"HARM_CATEGORY_HARASSMENT\",\n
|
||||
\ \"probability\": \"NEGLIGIBLE\"\n },\n {\n \"category\":
|
||||
\"HARM_CATEGORY_DANGEROUS_CONTENT\",\n \"probability\": \"NEGLIGIBLE\"\n
|
||||
\ }\n ]\n }\n ],\n \"usageMetadata\": {\n \"promptTokenCount\":
|
||||
8,\n \"candidatesTokenCount\": 12,\n \"totalTokenCount\": 20\n }\n}\n"
|
||||
headers:
|
||||
Alt-Svc:
|
||||
- h3=":443"; ma=2592000,h3-29=":443"; ma=2592000
|
||||
Cache-Control:
|
||||
- private
|
||||
Content-Type:
|
||||
- application/json; charset=UTF-8
|
||||
Date:
|
||||
- Wed, 02 Oct 2024 01:06:50 GMT
|
||||
Server:
|
||||
- scaffolding on HTTPServer2
|
||||
Server-Timing:
|
||||
- gfet4t7; dur=426
|
||||
Transfer-Encoding:
|
||||
- chunked
|
||||
Vary:
|
||||
- Origin
|
||||
- X-Origin
|
||||
- Referer
|
||||
X-Content-Type-Options:
|
||||
- nosniff
|
||||
X-Frame-Options:
|
||||
- SAMEORIGIN
|
||||
X-XSS-Protection:
|
||||
- '0'
|
||||
content-length:
|
||||
- '855'
|
||||
status:
|
||||
code: 200
|
||||
message: OK
|
||||
version: 1
|
||||
@@ -0,0 +1,73 @@
|
||||
interactions:
|
||||
- request:
|
||||
body: '{"system_instruction": {"parts": [{"text": "You are a helpful assistant.
|
||||
Expect to need to read a file using read_file."}]}, "contents": [{"role": "user",
|
||||
"parts": [{"text": "What are the contents of this file? test.txt"}]}], "tools":
|
||||
{"functionDeclarations": [{"name": "read_file", "description": "Read the contents
|
||||
of the file.", "parameters": {"type": "object", "properties": {"filename": {"type":
|
||||
"string", "description": "The path to the file, which can be relative or\nabsolute.
|
||||
If it is a plain filename, it is assumed to be in the\ncurrent working directory."}},
|
||||
"required": ["filename"]}}]}}'
|
||||
headers:
|
||||
accept:
|
||||
- '*/*'
|
||||
accept-encoding:
|
||||
- gzip, deflate
|
||||
connection:
|
||||
- keep-alive
|
||||
content-length:
|
||||
- '600'
|
||||
content-type:
|
||||
- application/json
|
||||
host:
|
||||
- generativelanguage.googleapis.com
|
||||
user-agent:
|
||||
- python-httpx/0.27.2
|
||||
method: POST
|
||||
uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent?key=test_google_api_key
|
||||
response:
|
||||
body:
|
||||
string: "{\n \"candidates\": [\n {\n \"content\": {\n \"parts\":
|
||||
[\n {\n \"functionCall\": {\n \"name\": \"read_file\",\n
|
||||
\ \"args\": {\n \"filename\": \"test.txt\"\n }\n
|
||||
\ }\n }\n ],\n \"role\": \"model\"\n },\n
|
||||
\ \"finishReason\": \"STOP\",\n \"index\": 0,\n \"safetyRatings\":
|
||||
[\n {\n \"category\": \"HARM_CATEGORY_SEXUALLY_EXPLICIT\",\n
|
||||
\ \"probability\": \"NEGLIGIBLE\"\n },\n {\n \"category\":
|
||||
\"HARM_CATEGORY_HATE_SPEECH\",\n \"probability\": \"NEGLIGIBLE\"\n
|
||||
\ },\n {\n \"category\": \"HARM_CATEGORY_HARASSMENT\",\n
|
||||
\ \"probability\": \"NEGLIGIBLE\"\n },\n {\n \"category\":
|
||||
\"HARM_CATEGORY_DANGEROUS_CONTENT\",\n \"probability\": \"NEGLIGIBLE\"\n
|
||||
\ }\n ]\n }\n ],\n \"usageMetadata\": {\n \"promptTokenCount\":
|
||||
101,\n \"candidatesTokenCount\": 17,\n \"totalTokenCount\": 118\n }\n}\n"
|
||||
headers:
|
||||
Alt-Svc:
|
||||
- h3=":443"; ma=2592000,h3-29=":443"; ma=2592000
|
||||
Cache-Control:
|
||||
- private
|
||||
Content-Type:
|
||||
- application/json; charset=UTF-8
|
||||
Date:
|
||||
- Wed, 02 Oct 2024 01:06:51 GMT
|
||||
Server:
|
||||
- scaffolding on HTTPServer2
|
||||
Server-Timing:
|
||||
- gfet4t7; dur=449
|
||||
Transfer-Encoding:
|
||||
- chunked
|
||||
Vary:
|
||||
- Origin
|
||||
- X-Origin
|
||||
- Referer
|
||||
X-Content-Type-Options:
|
||||
- nosniff
|
||||
X-Frame-Options:
|
||||
- SAMEORIGIN
|
||||
X-XSS-Protection:
|
||||
- '0'
|
||||
content-length:
|
||||
- '947'
|
||||
status:
|
||||
code: 200
|
||||
message: OK
|
||||
version: 1
|
||||
@@ -53,6 +53,23 @@ def default_azure_env(monkeypatch):
|
||||
monkeypatch.setenv("AZURE_CHAT_COMPLETIONS_KEY", AZURE_API_KEY)
|
||||
|
||||
|
||||
GOOGLE_API_KEY = "test_google_api_key"
|
||||
|
||||
|
||||
@pytest.fixture
def default_google_env(monkeypatch):
    """
    Make sure GOOGLE_API_KEY is present so that GoogleProvider.from_env()
    does not fail on a missing environment variable.

    A key that is already exported is left untouched: when a VCR test runs
    for the first time, or after its cassette recording was deleted, real
    requests are sent and need the real key. Later runs replay the recorded
    data, so the placeholder value set here is enough.
    """
    if "GOOGLE_API_KEY" in os.environ:
        # Respect the caller's own key so that recording a cassette works.
        return
    monkeypatch.setenv("GOOGLE_API_KEY", GOOGLE_API_KEY)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def vcr_config():
|
||||
"""
|
||||
@@ -85,6 +102,8 @@ def scrub_request_url(request):
|
||||
request.uri = re.sub(r"/deployments/[^/]+", f"/deployments/{AZURE_DEPLOYMENT_NAME}", request.uri)
|
||||
request.headers["host"] = AZURE_ENDPOINT.replace("https://", "")
|
||||
request.headers["api-key"] = AZURE_API_KEY
|
||||
elif "generativelanguage.googleapis.com" in request.uri:
|
||||
request.uri = re.sub(r"([?&])key=[^&]+", r"\1key=" + GOOGLE_API_KEY, request.uri)
|
||||
|
||||
return request
|
||||
|
||||
@@ -93,8 +112,10 @@ def scrub_response_headers(response):
|
||||
"""
|
||||
This scrubs sensitive response headers. Note they are case-sensitive!
|
||||
"""
|
||||
response["headers"]["openai-organization"] = OPENAI_ORG_ID
|
||||
response["headers"]["Set-Cookie"] = "test_set_cookie"
|
||||
if "openai-organization" in response["headers"]:
|
||||
response["headers"]["openai-organization"] = OPENAI_ORG_ID
|
||||
if "Set-Cookie" in response["headers"]:
|
||||
response["headers"]["Set-Cookie"] = "test_set_cookie"
|
||||
return response
|
||||
|
||||
|
||||
@@ -102,7 +123,7 @@ def complete(provider_cls: Type[Provider], model: str, **kwargs) -> Tuple[Messag
|
||||
provider = provider_cls.from_env()
|
||||
system = "You are a helpful assistant."
|
||||
messages = [Message.user("Hello")]
|
||||
return provider.complete(model=model, system=system, messages=messages, tools=None, **kwargs)
|
||||
return provider.complete(model=model, system=system, messages=messages, tools=(), **kwargs)
|
||||
|
||||
|
||||
def tools(provider_cls: Type[Provider], model: str, **kwargs) -> Tuple[Message, Usage]:
|
||||
@@ -128,4 +149,4 @@ def vision(provider_cls: Type[Provider], model: str, **kwargs) -> Tuple[Message,
|
||||
content=[ToolResult(tool_use_id="xyz", output='"image:tests/test_image.png"')],
|
||||
),
|
||||
]
|
||||
return provider.complete(model=model, system=system, messages=messages, tools=None, **kwargs)
|
||||
return provider.complete(model=model, system=system, messages=messages, tools=(), **kwargs)
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from exchange import Message, Text
|
||||
from exchange.content import ToolResult, ToolUse
|
||||
from exchange.providers.base import MissingProviderEnvVariableError
|
||||
from exchange.providers.google import GoogleProvider
|
||||
from exchange.tool import Tool
|
||||
from .conftest import complete, tools
|
||||
|
||||
GOOGLE_MODEL = os.getenv("GOOGLE_MODEL", "gemini-1.5-flash")
|
||||
|
||||
|
||||
def example_fn(param: str) -> None:
|
||||
@@ -30,12 +32,6 @@ def test_from_env_throw_error_when_missing_api_key():
|
||||
assert "https://ai.google.dev/gemini-api/docs/api-key" in context.value.message
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@patch.dict(os.environ, {"GOOGLE_API_KEY": "test_api_key"})
|
||||
def google_provider():
|
||||
return GoogleProvider.from_env()
|
||||
|
||||
|
||||
def test_google_response_to_text_message() -> None:
|
||||
response = {"candidates": [{"content": {"parts": [{"text": "Hello from Gemini!"}], "role": "model"}}]}
|
||||
message = GoogleProvider.google_response_to_message(response)
|
||||
@@ -105,54 +101,40 @@ def test_messages_to_google_spec() -> None:
|
||||
assert actual_spec == expected_spec
|
||||
|
||||
|
||||
@patch("httpx.Client.post")
|
||||
@patch("logging.warning")
|
||||
@patch("logging.error")
|
||||
def test_google_completion(mock_error, mock_warning, mock_post, google_provider):
|
||||
mock_response = {
|
||||
"candidates": [{"content": {"parts": [{"text": "Hello from Gemini!"}], "role": "model"}}],
|
||||
"usageMetadata": {"promptTokenCount": 3, "candidatesTokenCount": 10, "totalTokenCount": 13},
|
||||
}
|
||||
@pytest.mark.vcr()
|
||||
def test_google_complete(default_google_env):
|
||||
reply_message, reply_usage = complete(GoogleProvider, GOOGLE_MODEL)
|
||||
|
||||
# First attempts fail with status code 429, 2nd succeeds
|
||||
def create_response(status_code, json_data=None):
|
||||
response = httpx.Response(status_code)
|
||||
response._content = httpx._content.json_dumps(json_data or {}).encode()
|
||||
response._request = httpx.Request("POST", "https://generativelanguage.googleapis.com/v1beta/")
|
||||
return response
|
||||
|
||||
mock_post.side_effect = [
|
||||
create_response(429), # 1st attempt
|
||||
create_response(200, mock_response), # Final success
|
||||
]
|
||||
|
||||
model = "gemini-1.5-flash"
|
||||
system = "You are a helpful assistant."
|
||||
messages = [Message.user("Hello, Gemini")]
|
||||
|
||||
reply_message, reply_usage = google_provider.complete(model=model, system=system, messages=messages)
|
||||
|
||||
assert reply_message.content == [Text(text="Hello from Gemini!")]
|
||||
assert reply_usage.total_tokens == 13
|
||||
assert mock_post.call_count == 2
|
||||
mock_post.assert_any_call(
|
||||
"models/gemini-1.5-flash:generateContent",
|
||||
json={
|
||||
"system_instruction": {"parts": [{"text": "You are a helpful assistant."}]},
|
||||
"contents": [{"role": "user", "parts": [{"text": "Hello, Gemini"}]}],
|
||||
},
|
||||
)
|
||||
assert reply_message.content == [Text("Hello! 👋 How can I help you today? 😊 \n")]
|
||||
assert reply_usage.total_tokens == 20
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_google_integration():
|
||||
provider = GoogleProvider.from_env()
|
||||
model = "gemini-1.5-flash" # updated model to a known valid model
|
||||
system = "You are a helpful assistant."
|
||||
messages = [Message.user("Hello, Gemini")]
|
||||
|
||||
# Run the completion
|
||||
reply = provider.complete(model=model, system=system, messages=messages)
|
||||
def test_google_complete_integration():
|
||||
reply = complete(GoogleProvider, GOOGLE_MODEL)
|
||||
|
||||
assert reply[0].content is not None
|
||||
print("Completion content from Google:", reply[0].content)
|
||||
|
||||
|
||||
@pytest.mark.vcr()
def test_google_tools(default_google_env):
    """Check the tool-call reply and token usage against the recorded cassette."""
    message, usage = tools(GoogleProvider, GOOGLE_MODEL)

    # The first content part of the reply is expected to be the tool call.
    tool_use = message.content[0]
    assert isinstance(tool_use, ToolUse), f"Expected ToolUse, but was {type(tool_use).__name__}"

    # The test expects the id to equal the function name, "read_file".
    assert tool_use.id == "read_file"
    assert tool_use.name == "read_file"
    assert tool_use.parameters == {"filename": "test.txt"}

    # Exact value is stable because the response is replayed from the cassette.
    assert usage.total_tokens == 118
|
||||
|
||||
|
||||
@pytest.mark.integration
def test_google_tools_integration():
    """Run the tool-call scenario and check only the response-independent parts."""
    result = tools(GoogleProvider, GOOGLE_MODEL)

    # result[0] is the reply message; its first content part should be the tool call.
    message = result[0]
    tool_use = message.content[0]
    assert isinstance(tool_use, ToolUse), f"Expected ToolUse, but was {type(tool_use).__name__}"

    # Unlike the cassette-based test, only require that some id is present.
    assert tool_use.id is not None
    assert tool_use.name == "read_file"
    assert tool_use.parameters == {"filename": "test.txt"}
|
||||
|
||||
Reference in New Issue
Block a user