From e461e6845d0a2a9859bbfd268f02c398eb06cb5c Mon Sep 17 00:00:00 2001 From: Michael Neale Date: Wed, 12 Feb 2025 14:36:46 -0800 Subject: [PATCH] feat: handle images mentioned in messages (#1202) --- crates/goose/src/providers/formats/openai.rs | 53 +++++- crates/goose/src/providers/utils.rs | 165 +++++++++++++++++++ 2 files changed, 216 insertions(+), 2 deletions(-) diff --git a/crates/goose/src/providers/formats/openai.rs b/crates/goose/src/providers/formats/openai.rs index 009b4289..b564a544 100644 --- a/crates/goose/src/providers/formats/openai.rs +++ b/crates/goose/src/providers/formats/openai.rs @@ -3,7 +3,8 @@ use crate::model::ModelConfig; use crate::providers::base::Usage; use crate::providers::errors::ProviderError; use crate::providers::utils::{ - convert_image, is_valid_function_name, sanitize_function_name, ImageFormat, + convert_image, detect_image_path, is_valid_function_name, load_image_file, + sanitize_function_name, ImageFormat, }; use anyhow::{anyhow, Error}; use mcp_core::ToolError; @@ -26,7 +27,21 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec< match content { MessageContent::Text(text) => { if !text.text.is_empty() { - converted["content"] = json!(text.text); + // Check for image paths in the text + if let Some(image_path) = detect_image_path(&text.text) { + // Try to load and convert the image + if let Ok(image) = load_image_file(image_path) { + converted["content"] = json!([ + {"type": "text", "text": text.text}, + convert_image(&image, image_format) + ]); + } else { + // If image loading fails, just use the text + converted["content"] = json!(text.text); + } + } else { + converted["content"] = json!(text.text); + } } } MessageContent::ToolRequest(request) => match &request.tool_call { @@ -622,6 +637,40 @@ mod tests { Ok(()) } + #[test] + fn test_format_messages_with_image_path() -> anyhow::Result<()> { + // Create a temporary PNG file with valid PNG magic numbers + let temp_dir = tempfile::tempdir()?; + let png_path = temp_dir.path().join("test.png"); + let png_data = [ + 0x89, 0x50, 0x4E, 0x47, // PNG magic number + 0x0D, 0x0A, 0x1A, 0x0A, // PNG header + 0x00, 0x00, 0x00, 0x0D, // Rest of fake PNG data + ]; + std::fs::write(&png_path, &png_data)?; + let png_path_str = png_path.to_str().unwrap(); + + // Create message with image path + let message = Message::user().with_text(format!("Here is an image: {}", png_path_str)); + let spec = format_messages(&[message], &ImageFormat::OpenAi); + + assert_eq!(spec.len(), 1); + assert_eq!(spec[0]["role"], "user"); + + // Content should be an array with text and image + let content = spec[0]["content"].as_array().unwrap(); + assert_eq!(content.len(), 2); + assert_eq!(content[0]["type"], "text"); + assert!(content[0]["text"].as_str().unwrap().contains(png_path_str)); + assert_eq!(content[1]["type"], "image_url"); + assert!(content[1]["image_url"]["url"] + .as_str() + .unwrap() + .starts_with("data:image/png;base64,")); + + Ok(()) + } + #[test] fn test_response_to_message_text() -> anyhow::Result<()> { let response = json!({ diff --git a/crates/goose/src/providers/utils.rs b/crates/goose/src/providers/utils.rs index e7011411..52881470 100644 --- a/crates/goose/src/providers/utils.rs +++ b/crates/goose/src/providers/utils.rs @@ -1,9 +1,12 @@ use super::base::Usage; use anyhow::Result; +use base64::Engine; use regex::Regex; use reqwest::{Response, StatusCode}; use serde::{Deserialize, Serialize}; use serde_json::{json, Map, Value}; +use std::io::Read; +use std::path::Path; use crate::providers::errors::ProviderError; use mcp_core::content::ImageContent; @@ -110,6 +113,91 @@ pub fn get_model(data: &Value) -> String { } } +/// Check if a file is actually an image by examining its magic bytes +fn is_image_file(path: &Path) -> bool { + if let Ok(mut file) = std::fs::File::open(path) { + let mut buffer = [0u8; 8]; // Large enough for most image magic numbers + if file.read(&mut buffer).is_ok() { + // Check magic numbers for common image formats + return match &buffer[0..4] { + // PNG: 89 50 4E 47 + [0x89, 0x50, 0x4E, 0x47] => true, + // JPEG: FF D8 FF + [0xFF, 0xD8, 0xFF, _] => true, + _ => false, + }; + } + } + false +} + +/// Detect if a string contains a path to an image file +pub fn detect_image_path(text: &str) -> Option<&str> { + // Basic image file extension check + let extensions = [".png", ".jpg", ".jpeg"]; + + // Find any word that ends with an image extension + for word in text.split_whitespace() { + if extensions + .iter() + .any(|ext| word.to_lowercase().ends_with(ext)) + { + let path = Path::new(word); + // Check if it's an absolute path and file exists + if path.is_absolute() && path.is_file() { + // Verify it's actually an image file + if is_image_file(path) { + return Some(word); + } + } + } + } + None +} + +/// Convert a local image file to base64 encoded ImageContent +pub fn load_image_file(path: &str) -> Result { + let path = Path::new(path); + + // Verify it's an image before proceeding + if !is_image_file(path) { + return Err(ProviderError::RequestFailed( + "File is not a valid image".to_string(), + )); + } + + // Read the file + let bytes = std::fs::read(path) + .map_err(|e| ProviderError::RequestFailed(format!("Failed to read image file: {}", e)))?; + + // Detect mime type from extension + let mime_type = match path.extension().and_then(|e| e.to_str()) { + Some(ext) => match ext.to_lowercase().as_str() { + "png" => "image/png", + "jpg" | "jpeg" => "image/jpeg", + _ => { + return Err(ProviderError::RequestFailed( + "Unsupported image format".to_string(), + )) + } + }, + None => { + return Err(ProviderError::RequestFailed( + "Unknown image format".to_string(), + )) + } + }; + + // Convert to base64 + let data = base64::prelude::BASE64_STANDARD.encode(&bytes); + + Ok(ImageContent { + mime_type: mime_type.to_string(), + data, + annotations: None, + }) +} + pub fn unescape_json_values(value: &Value) -> Value { match value { Value::Object(map) => { @@ -165,6 +253,83 @@ pub fn emit_debug_trace( mod tests { use super::*; use serde_json::json; + use std::io::Write; + use tempfile::NamedTempFile; + + #[test] + fn test_detect_image_path() { + // Create a temporary PNG file with valid PNG magic numbers + let temp_dir = tempfile::tempdir().unwrap(); + let png_path = temp_dir.path().join("test.png"); + let png_data = [ + 0x89, 0x50, 0x4E, 0x47, // PNG magic number + 0x0D, 0x0A, 0x1A, 0x0A, // PNG header + 0x00, 0x00, 0x00, 0x0D, // Rest of fake PNG data + ]; + std::fs::write(&png_path, &png_data).unwrap(); + let png_path_str = png_path.to_str().unwrap(); + + // Create a fake PNG (wrong magic numbers) + let fake_png_path = temp_dir.path().join("fake.png"); + std::fs::write(&fake_png_path, b"not a real png").unwrap(); + + // Test with valid PNG file using absolute path + let text = format!("Here is an image {}", png_path_str); + assert_eq!(detect_image_path(&text), Some(png_path_str)); + + // Test with non-image file that has .png extension + let text = format!("Here is a fake image {}", fake_png_path.to_str().unwrap()); + assert_eq!(detect_image_path(&text), None); + + // Test with non-existent file + let text = "Here is a fake.png that doesn't exist"; + assert_eq!(detect_image_path(text), None); + + // Test with non-image file + let text = "Here is a file.txt"; + assert_eq!(detect_image_path(text), None); + + // Test with relative path (should not match) + let text = "Here is a relative/path/image.png"; + assert_eq!(detect_image_path(text), None); + } + + #[test] + fn test_load_image_file() { + // Create a temporary PNG file with valid PNG magic numbers + let temp_dir = tempfile::tempdir().unwrap(); + let png_path = temp_dir.path().join("test.png"); + let png_data = [ + 0x89, 0x50, 0x4E, 0x47, // PNG magic number + 0x0D, 0x0A, 0x1A, 0x0A, // PNG header + 0x00, 0x00, 0x00, 0x0D, // Rest of fake PNG data + ]; + std::fs::write(&png_path, &png_data).unwrap(); + let png_path_str = png_path.to_str().unwrap(); + + // Create a fake PNG (wrong magic numbers) + let fake_png_path = temp_dir.path().join("fake.png"); + std::fs::write(&fake_png_path, b"not a real png").unwrap(); + let fake_png_path_str = fake_png_path.to_str().unwrap(); + + // Test loading valid PNG file + let result = load_image_file(png_path_str); + assert!(result.is_ok()); + let image = result.unwrap(); + assert_eq!(image.mime_type, "image/png"); + + // Test loading fake PNG file + let result = load_image_file(fake_png_path_str); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("not a valid image")); + + // Test non-existent file + let result = load_image_file("nonexistent.png"); + assert!(result.is_err()); + } #[test] fn test_sanitize_function_name() {