diff --git a/packages/exchange/tests/providers/cassettes/test_google_complete.yaml b/packages/exchange/tests/providers/cassettes/test_google_complete.yaml new file mode 100644 index 00000000..56ec5761 --- /dev/null +++ b/packages/exchange/tests/providers/cassettes/test_google_complete.yaml @@ -0,0 +1,66 @@ +interactions: +- request: + body: '{"system_instruction": {"parts": [{"text": "You are a helpful assistant."}]}, + "contents": [{"role": "user", "parts": [{"text": "Hello"}]}]}' + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '139' + content-type: + - application/json + host: + - generativelanguage.googleapis.com + user-agent: + - python-httpx/0.27.2 + method: POST + uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent?key=test_google_api_key + response: + body: + string: "{\n \"candidates\": [\n {\n \"content\": {\n \"parts\": + [\n {\n \"text\": \"Hello! \U0001F44B How can I help + you today? \U0001F60A \\n\"\n }\n ],\n \"role\": \"model\"\n + \ },\n \"finishReason\": \"STOP\",\n \"index\": 0,\n \"safetyRatings\": + [\n {\n \"category\": \"HARM_CATEGORY_SEXUALLY_EXPLICIT\",\n + \ \"probability\": \"NEGLIGIBLE\"\n },\n {\n \"category\": + \"HARM_CATEGORY_HATE_SPEECH\",\n \"probability\": \"NEGLIGIBLE\"\n + \ },\n {\n \"category\": \"HARM_CATEGORY_HARASSMENT\",\n + \ \"probability\": \"NEGLIGIBLE\"\n },\n {\n \"category\": + \"HARM_CATEGORY_DANGEROUS_CONTENT\",\n \"probability\": \"NEGLIGIBLE\"\n + \ }\n ]\n }\n ],\n \"usageMetadata\": {\n \"promptTokenCount\": + 8,\n \"candidatesTokenCount\": 12,\n \"totalTokenCount\": 20\n }\n}\n" + headers: + Alt-Svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + Cache-Control: + - private + Content-Type: + - application/json; charset=UTF-8 + Date: + - Wed, 02 Oct 2024 01:06:50 GMT + Server: + - scaffolding on HTTPServer2 + Server-Timing: + - gfet4t7; dur=426 + Transfer-Encoding: + - chunked + Vary: + - Origin + - X-Origin + - Referer + X-Content-Type-Options: + - nosniff + X-Frame-Options: + - SAMEORIGIN + X-XSS-Protection: + - '0' + content-length: + - '855' + status: + code: 200 + message: OK +version: 1 diff --git a/packages/exchange/tests/providers/cassettes/test_google_tools.yaml b/packages/exchange/tests/providers/cassettes/test_google_tools.yaml new file mode 100644 index 00000000..f742f784 --- /dev/null +++ b/packages/exchange/tests/providers/cassettes/test_google_tools.yaml @@ -0,0 +1,73 @@ +interactions: +- request: + body: '{"system_instruction": {"parts": [{"text": "You are a helpful assistant. + Expect to need to read a file using read_file."}]}, "contents": [{"role": "user", + "parts": [{"text": "What are the contents of this file? test.txt"}]}], "tools": + {"functionDeclarations": [{"name": "read_file", "description": "Read the contents + of the file.", "parameters": {"type": "object", "properties": {"filename": {"type": + "string", "description": "The path to the file, which can be relative or\nabsolute. 
+ If it is a plain filename, it is assumed to be in the\ncurrent working directory."}}, + "required": ["filename"]}}]}}' + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '600' + content-type: + - application/json + host: + - generativelanguage.googleapis.com + user-agent: + - python-httpx/0.27.2 + method: POST + uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent?key=test_google_api_key + response: + body: + string: "{\n \"candidates\": [\n {\n \"content\": {\n \"parts\": + [\n {\n \"functionCall\": {\n \"name\": \"read_file\",\n + \ \"args\": {\n \"filename\": \"test.txt\"\n }\n + \ }\n }\n ],\n \"role\": \"model\"\n },\n + \ \"finishReason\": \"STOP\",\n \"index\": 0,\n \"safetyRatings\": + [\n {\n \"category\": \"HARM_CATEGORY_SEXUALLY_EXPLICIT\",\n + \ \"probability\": \"NEGLIGIBLE\"\n },\n {\n \"category\": + \"HARM_CATEGORY_HATE_SPEECH\",\n \"probability\": \"NEGLIGIBLE\"\n + \ },\n {\n \"category\": \"HARM_CATEGORY_HARASSMENT\",\n + \ \"probability\": \"NEGLIGIBLE\"\n },\n {\n \"category\": + \"HARM_CATEGORY_DANGEROUS_CONTENT\",\n \"probability\": \"NEGLIGIBLE\"\n + \ }\n ]\n }\n ],\n \"usageMetadata\": {\n \"promptTokenCount\": + 101,\n \"candidatesTokenCount\": 17,\n \"totalTokenCount\": 118\n }\n}\n" + headers: + Alt-Svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + Cache-Control: + - private + Content-Type: + - application/json; charset=UTF-8 + Date: + - Wed, 02 Oct 2024 01:06:51 GMT + Server: + - scaffolding on HTTPServer2 + Server-Timing: + - gfet4t7; dur=449 + Transfer-Encoding: + - chunked + Vary: + - Origin + - X-Origin + - Referer + X-Content-Type-Options: + - nosniff + X-Frame-Options: + - SAMEORIGIN + X-XSS-Protection: + - '0' + content-length: + - '947' + status: + code: 200 + message: OK +version: 1 diff --git a/packages/exchange/tests/providers/conftest.py b/packages/exchange/tests/providers/conftest.py index 010504e8..83d97be9 100644 --- a/packages/exchange/tests/providers/conftest.py +++ b/packages/exchange/tests/providers/conftest.py @@ -53,6 +53,23 @@ def default_azure_env(monkeypatch): monkeypatch.setenv("AZURE_CHAT_COMPLETIONS_KEY", AZURE_API_KEY) +GOOGLE_API_KEY = "test_google_api_key" + + +@pytest.fixture +def default_google_env(monkeypatch): + """ + This fixture prevents GoogleProvider.from_env() from erring on missing + environment variables. + + When running VCR tests for the first time or after deleting a cassette + recording, set required environment variables, so that real requests don't + fail. Subsequent runs use the recorded data, so don't need them. + """ + if "GOOGLE_API_KEY" not in os.environ: + monkeypatch.setenv("GOOGLE_API_KEY", GOOGLE_API_KEY) + + @pytest.fixture(scope="module") def vcr_config(): """ @@ -85,6 +102,8 @@ def scrub_request_url(request): request.uri = re.sub(r"/deployments/[^/]+", f"/deployments/{AZURE_DEPLOYMENT_NAME}", request.uri) request.headers["host"] = AZURE_ENDPOINT.replace("https://", "") request.headers["api-key"] = AZURE_API_KEY + elif "generativelanguage.googleapis.com" in request.uri: + request.uri = re.sub(r"([?&])key=[^&]+", r"\1key=" + GOOGLE_API_KEY, request.uri) return request @@ -93,8 +112,10 @@ def scrub_response_headers(response): """ This scrubs sensitive response headers. Note they are case-sensitive! 
""" - response["headers"]["openai-organization"] = OPENAI_ORG_ID - response["headers"]["Set-Cookie"] = "test_set_cookie" + if "openai-organization" in response["headers"]: + response["headers"]["openai-organization"] = OPENAI_ORG_ID + if "Set-Cookie" in response["headers"]: + response["headers"]["Set-Cookie"] = "test_set_cookie" return response @@ -102,7 +123,7 @@ def complete(provider_cls: Type[Provider], model: str, **kwargs) -> Tuple[Messag provider = provider_cls.from_env() system = "You are a helpful assistant." messages = [Message.user("Hello")] - return provider.complete(model=model, system=system, messages=messages, tools=None, **kwargs) + return provider.complete(model=model, system=system, messages=messages, tools=(), **kwargs) def tools(provider_cls: Type[Provider], model: str, **kwargs) -> Tuple[Message, Usage]: @@ -128,4 +149,4 @@ def vision(provider_cls: Type[Provider], model: str, **kwargs) -> Tuple[Message, content=[ToolResult(tool_use_id="xyz", output='"image:tests/test_image.png"')], ), ] - return provider.complete(model=model, system=system, messages=messages, tools=None, **kwargs) + return provider.complete(model=model, system=system, messages=messages, tools=(), **kwargs) diff --git a/packages/exchange/tests/providers/test_google.py b/packages/exchange/tests/providers/test_google.py index 76ae4c8d..16bb4462 100644 --- a/packages/exchange/tests/providers/test_google.py +++ b/packages/exchange/tests/providers/test_google.py @@ -1,13 +1,15 @@ import os from unittest.mock import patch -import httpx import pytest from exchange import Message, Text from exchange.content import ToolResult, ToolUse from exchange.providers.base import MissingProviderEnvVariableError from exchange.providers.google import GoogleProvider from exchange.tool import Tool +from .conftest import complete, tools + +GOOGLE_MODEL = os.getenv("GOOGLE_MODEL", "gemini-1.5-flash") def example_fn(param: str) -> None: @@ -30,12 +32,6 @@ def test_from_env_throw_error_when_missing_api_key(): assert "https://ai.google.dev/gemini-api/docs/api-key" in context.value.message -@pytest.fixture -@patch.dict(os.environ, {"GOOGLE_API_KEY": "test_api_key"}) -def google_provider(): - return GoogleProvider.from_env() - - def test_google_response_to_text_message() -> None: response = {"candidates": [{"content": {"parts": [{"text": "Hello from Gemini!"}], "role": "model"}}]} message = GoogleProvider.google_response_to_message(response) @@ -105,54 +101,40 @@ def test_messages_to_google_spec() -> None: assert actual_spec == expected_spec -@patch("httpx.Client.post") -@patch("logging.warning") -@patch("logging.error") -def test_google_completion(mock_error, mock_warning, mock_post, google_provider): - mock_response = { - "candidates": [{"content": {"parts": [{"text": "Hello from Gemini!"}], "role": "model"}}], - "usageMetadata": {"promptTokenCount": 3, "candidatesTokenCount": 10, "totalTokenCount": 13}, - } +@pytest.mark.vcr() +def test_google_complete(default_google_env): + reply_message, reply_usage = complete(GoogleProvider, GOOGLE_MODEL) - # First attempts fail with status code 429, 2nd succeeds - def create_response(status_code, json_data=None): - response = httpx.Response(status_code) - response._content = httpx._content.json_dumps(json_data or {}).encode() - response._request = httpx.Request("POST", "https://generativelanguage.googleapis.com/v1beta/") - return response - - mock_post.side_effect = [ - create_response(429), # 1st attempt - create_response(200, mock_response), # Final success - ] - - model = "gemini-1.5-flash" 
- system = "You are a helpful assistant." - messages = [Message.user("Hello, Gemini")] - - reply_message, reply_usage = google_provider.complete(model=model, system=system, messages=messages) - - assert reply_message.content == [Text(text="Hello from Gemini!")] - assert reply_usage.total_tokens == 13 - assert mock_post.call_count == 2 - mock_post.assert_any_call( - "models/gemini-1.5-flash:generateContent", - json={ - "system_instruction": {"parts": [{"text": "You are a helpful assistant."}]}, - "contents": [{"role": "user", "parts": [{"text": "Hello, Gemini"}]}], - }, - ) + assert reply_message.content == [Text("Hello! 👋 How can I help you today? 😊 \n")] + assert reply_usage.total_tokens == 20 @pytest.mark.integration -def test_google_integration(): - provider = GoogleProvider.from_env() - model = "gemini-1.5-flash" # updated model to a known valid model - system = "You are a helpful assistant." - messages = [Message.user("Hello, Gemini")] - - # Run the completion - reply = provider.complete(model=model, system=system, messages=messages) +def test_google_complete_integration(): + reply = complete(GoogleProvider, GOOGLE_MODEL) assert reply[0].content is not None print("Completion content from Google:", reply[0].content) + + +@pytest.mark.vcr() +def test_google_tools(default_google_env): + reply_message, reply_usage = tools(GoogleProvider, GOOGLE_MODEL) + + tool_use = reply_message.content[0] + assert isinstance(tool_use, ToolUse), f"Expected ToolUse, but was {type(tool_use).__name__}" + assert tool_use.id == "read_file" + assert tool_use.name == "read_file" + assert tool_use.parameters == {"filename": "test.txt"} + assert reply_usage.total_tokens == 118 + + +@pytest.mark.integration +def test_google_tools_integration(): + reply = tools(GoogleProvider, GOOGLE_MODEL) + + tool_use = reply[0].content[0] + assert isinstance(tool_use, ToolUse), f"Expected ToolUse, but was {type(tool_use).__name__}" + assert tool_use.id is not None + assert tool_use.name == "read_file" + assert tool_use.parameters == {"filename": "test.txt"}
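
For reference, a minimal standalone sketch of the key scrubbing that the scrub_request_url hook in conftest.py applies to Google requests, which is why the recorded cassette URIs contain key=test_google_api_key rather than a real key. The sample URI below is illustrative only; real URIs are produced by GoogleProvider's httpx client.

    import re

    GOOGLE_API_KEY = "test_google_api_key"  # placeholder written into the cassettes

    def scrub_google_key(uri: str) -> str:
        # Same substitution used in conftest.py: the real key passed as a query
        # parameter is swapped for the test placeholder before the cassette is saved.
        return re.sub(r"([?&])key=[^&]+", r"\1key=" + GOOGLE_API_KEY, uri)

    # Illustrative URI only (not taken from a recording).
    uri = "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent?key=REAL_KEY"
    assert scrub_google_key(uri).endswith(":generateContent?key=test_google_api_key")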