From 9679f07d09edabd01909f4f9527f437ed53a5799 Mon Sep 17 00:00:00 2001 From: Drew Hintz Date: Sun, 13 Oct 2024 17:40:03 -0500 Subject: [PATCH] feat: add vision support for Google (#141) --- .../exchange/src/exchange/providers/google.py | 19 +++++++++++++++---- .../exchange/tests/test_integration_vision.py | 1 + 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/packages/exchange/src/exchange/providers/google.py b/packages/exchange/src/exchange/providers/google.py index fe83cd60..1bcac320 100644 --- a/packages/exchange/src/exchange/providers/google.py +++ b/packages/exchange/src/exchange/providers/google.py @@ -7,7 +7,7 @@ from exchange import Message, Tool from exchange.content import Text, ToolResult, ToolUse from exchange.providers.base import Provider, Usage from tenacity import retry, wait_fixed, stop_after_attempt -from exchange.providers.utils import raise_for_status, retry_if_status +from exchange.providers.utils import raise_for_status, retry_if_status, encode_image GOOGLE_HOST = "https://generativelanguage.googleapis.com/v1beta" @@ -111,9 +111,20 @@ class GoogleProvider(Provider): elif isinstance(content, ToolUse): converted["parts"].append({"functionCall": {"name": content.name, "args": content.parameters}}) elif isinstance(content, ToolResult): - converted["parts"].append( - {"functionResponse": {"name": content.tool_use_id, "response": {"content": content.output}}} - ) + if content.output.startswith('"image:'): + image_path = content.output.replace('"image:', "").replace('"', "") + converted["parts"].append( + { + "inline_data": { + "mime_type": "image/png", + "data": f"{encode_image(image_path)}", + } + } + ) + else: + converted["parts"].append( + {"functionResponse": {"name": content.tool_use_id, "response": {"content": content.output}}} + ) messages_spec.append(converted) if not messages_spec: diff --git a/packages/exchange/tests/test_integration_vision.py b/packages/exchange/tests/test_integration_vision.py index 20f165ad..6adf3f04 100644 --- a/packages/exchange/tests/test_integration_vision.py +++ b/packages/exchange/tests/test_integration_vision.py @@ -9,6 +9,7 @@ from exchange.providers import get_provider cases = [ (get_provider("openai"), os.getenv("OPENAI_MODEL", "gpt-4o-mini")), + (get_provider("google"), os.getenv("GOOGLE_MODEL", "gemini-1.5-flash")), ]