mirror of
https://github.com/aljazceru/goose.git
synced 2025-12-19 07:04:21 +01:00
feat: handle images mentioned in messages (#1202)
This commit is contained in:
@@ -3,7 +3,8 @@ use crate::model::ModelConfig;
|
||||
use crate::providers::base::Usage;
|
||||
use crate::providers::errors::ProviderError;
|
||||
use crate::providers::utils::{
|
||||
convert_image, is_valid_function_name, sanitize_function_name, ImageFormat,
|
||||
convert_image, detect_image_path, is_valid_function_name, load_image_file,
|
||||
sanitize_function_name, ImageFormat,
|
||||
};
|
||||
use anyhow::{anyhow, Error};
|
||||
use mcp_core::ToolError;
|
||||
@@ -26,7 +27,21 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec<
|
||||
match content {
|
||||
MessageContent::Text(text) => {
|
||||
if !text.text.is_empty() {
|
||||
converted["content"] = json!(text.text);
|
||||
// Check for image paths in the text
|
||||
if let Some(image_path) = detect_image_path(&text.text) {
|
||||
// Try to load and convert the image
|
||||
if let Ok(image) = load_image_file(image_path) {
|
||||
converted["content"] = json!([
|
||||
{"type": "text", "text": text.text},
|
||||
convert_image(&image, image_format)
|
||||
]);
|
||||
} else {
|
||||
// If image loading fails, just use the text
|
||||
converted["content"] = json!(text.text);
|
||||
}
|
||||
} else {
|
||||
converted["content"] = json!(text.text);
|
||||
}
|
||||
}
|
||||
}
|
||||
MessageContent::ToolRequest(request) => match &request.tool_call {
|
||||
@@ -622,6 +637,40 @@ mod tests {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_messages_with_image_path() -> anyhow::Result<()> {
|
||||
// Create a temporary PNG file with valid PNG magic numbers
|
||||
let temp_dir = tempfile::tempdir()?;
|
||||
let png_path = temp_dir.path().join("test.png");
|
||||
let png_data = [
|
||||
0x89, 0x50, 0x4E, 0x47, // PNG magic number
|
||||
0x0D, 0x0A, 0x1A, 0x0A, // PNG header
|
||||
0x00, 0x00, 0x00, 0x0D, // Rest of fake PNG data
|
||||
];
|
||||
std::fs::write(&png_path, &png_data)?;
|
||||
let png_path_str = png_path.to_str().unwrap();
|
||||
|
||||
// Create message with image path
|
||||
let message = Message::user().with_text(format!("Here is an image: {}", png_path_str));
|
||||
let spec = format_messages(&[message], &ImageFormat::OpenAi);
|
||||
|
||||
assert_eq!(spec.len(), 1);
|
||||
assert_eq!(spec[0]["role"], "user");
|
||||
|
||||
// Content should be an array with text and image
|
||||
let content = spec[0]["content"].as_array().unwrap();
|
||||
assert_eq!(content.len(), 2);
|
||||
assert_eq!(content[0]["type"], "text");
|
||||
assert!(content[0]["text"].as_str().unwrap().contains(png_path_str));
|
||||
assert_eq!(content[1]["type"], "image_url");
|
||||
assert!(content[1]["image_url"]["url"]
|
||||
.as_str()
|
||||
.unwrap()
|
||||
.starts_with("data:image/png;base64,"));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_response_to_message_text() -> anyhow::Result<()> {
|
||||
let response = json!({
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
use super::base::Usage;
|
||||
use anyhow::Result;
|
||||
use base64::Engine;
|
||||
use regex::Regex;
|
||||
use reqwest::{Response, StatusCode};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Map, Value};
|
||||
use std::io::Read;
|
||||
use std::path::Path;
|
||||
|
||||
use crate::providers::errors::ProviderError;
|
||||
use mcp_core::content::ImageContent;
|
||||
@@ -110,6 +113,91 @@ pub fn get_model(data: &Value) -> String {
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a file is actually an image by examining its magic bytes
|
||||
fn is_image_file(path: &Path) -> bool {
|
||||
if let Ok(mut file) = std::fs::File::open(path) {
|
||||
let mut buffer = [0u8; 8]; // Large enough for most image magic numbers
|
||||
if file.read(&mut buffer).is_ok() {
|
||||
// Check magic numbers for common image formats
|
||||
return match &buffer[0..4] {
|
||||
// PNG: 89 50 4E 47
|
||||
[0x89, 0x50, 0x4E, 0x47] => true,
|
||||
// JPEG: FF D8 FF
|
||||
[0xFF, 0xD8, 0xFF, _] => true,
|
||||
_ => false,
|
||||
};
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
/// Detect if a string contains a path to an image file
|
||||
pub fn detect_image_path(text: &str) -> Option<&str> {
|
||||
// Basic image file extension check
|
||||
let extensions = [".png", ".jpg", ".jpeg"];
|
||||
|
||||
// Find any word that ends with an image extension
|
||||
for word in text.split_whitespace() {
|
||||
if extensions
|
||||
.iter()
|
||||
.any(|ext| word.to_lowercase().ends_with(ext))
|
||||
{
|
||||
let path = Path::new(word);
|
||||
// Check if it's an absolute path and file exists
|
||||
if path.is_absolute() && path.is_file() {
|
||||
// Verify it's actually an image file
|
||||
if is_image_file(path) {
|
||||
return Some(word);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Convert a local image file to base64 encoded ImageContent
|
||||
pub fn load_image_file(path: &str) -> Result<ImageContent, ProviderError> {
|
||||
let path = Path::new(path);
|
||||
|
||||
// Verify it's an image before proceeding
|
||||
if !is_image_file(path) {
|
||||
return Err(ProviderError::RequestFailed(
|
||||
"File is not a valid image".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
// Read the file
|
||||
let bytes = std::fs::read(path)
|
||||
.map_err(|e| ProviderError::RequestFailed(format!("Failed to read image file: {}", e)))?;
|
||||
|
||||
// Detect mime type from extension
|
||||
let mime_type = match path.extension().and_then(|e| e.to_str()) {
|
||||
Some(ext) => match ext.to_lowercase().as_str() {
|
||||
"png" => "image/png",
|
||||
"jpg" | "jpeg" => "image/jpeg",
|
||||
_ => {
|
||||
return Err(ProviderError::RequestFailed(
|
||||
"Unsupported image format".to_string(),
|
||||
))
|
||||
}
|
||||
},
|
||||
None => {
|
||||
return Err(ProviderError::RequestFailed(
|
||||
"Unknown image format".to_string(),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
// Convert to base64
|
||||
let data = base64::prelude::BASE64_STANDARD.encode(&bytes);
|
||||
|
||||
Ok(ImageContent {
|
||||
mime_type: mime_type.to_string(),
|
||||
data,
|
||||
annotations: None,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn unescape_json_values(value: &Value) -> Value {
|
||||
match value {
|
||||
Value::Object(map) => {
|
||||
@@ -165,6 +253,83 @@ pub fn emit_debug_trace<T: serde::Serialize>(
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
use std::io::Write;
|
||||
use tempfile::NamedTempFile;
|
||||
|
||||
#[test]
|
||||
fn test_detect_image_path() {
|
||||
// Create a temporary PNG file with valid PNG magic numbers
|
||||
let temp_dir = tempfile::tempdir().unwrap();
|
||||
let png_path = temp_dir.path().join("test.png");
|
||||
let png_data = [
|
||||
0x89, 0x50, 0x4E, 0x47, // PNG magic number
|
||||
0x0D, 0x0A, 0x1A, 0x0A, // PNG header
|
||||
0x00, 0x00, 0x00, 0x0D, // Rest of fake PNG data
|
||||
];
|
||||
std::fs::write(&png_path, &png_data).unwrap();
|
||||
let png_path_str = png_path.to_str().unwrap();
|
||||
|
||||
// Create a fake PNG (wrong magic numbers)
|
||||
let fake_png_path = temp_dir.path().join("fake.png");
|
||||
std::fs::write(&fake_png_path, b"not a real png").unwrap();
|
||||
|
||||
// Test with valid PNG file using absolute path
|
||||
let text = format!("Here is an image {}", png_path_str);
|
||||
assert_eq!(detect_image_path(&text), Some(png_path_str));
|
||||
|
||||
// Test with non-image file that has .png extension
|
||||
let text = format!("Here is a fake image {}", fake_png_path.to_str().unwrap());
|
||||
assert_eq!(detect_image_path(&text), None);
|
||||
|
||||
// Test with non-existent file
|
||||
let text = "Here is a fake.png that doesn't exist";
|
||||
assert_eq!(detect_image_path(text), None);
|
||||
|
||||
// Test with non-image file
|
||||
let text = "Here is a file.txt";
|
||||
assert_eq!(detect_image_path(text), None);
|
||||
|
||||
// Test with relative path (should not match)
|
||||
let text = "Here is a relative/path/image.png";
|
||||
assert_eq!(detect_image_path(text), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_load_image_file() {
|
||||
// Create a temporary PNG file with valid PNG magic numbers
|
||||
let temp_dir = tempfile::tempdir().unwrap();
|
||||
let png_path = temp_dir.path().join("test.png");
|
||||
let png_data = [
|
||||
0x89, 0x50, 0x4E, 0x47, // PNG magic number
|
||||
0x0D, 0x0A, 0x1A, 0x0A, // PNG header
|
||||
0x00, 0x00, 0x00, 0x0D, // Rest of fake PNG data
|
||||
];
|
||||
std::fs::write(&png_path, &png_data).unwrap();
|
||||
let png_path_str = png_path.to_str().unwrap();
|
||||
|
||||
// Create a fake PNG (wrong magic numbers)
|
||||
let fake_png_path = temp_dir.path().join("fake.png");
|
||||
std::fs::write(&fake_png_path, b"not a real png").unwrap();
|
||||
let fake_png_path_str = fake_png_path.to_str().unwrap();
|
||||
|
||||
// Test loading valid PNG file
|
||||
let result = load_image_file(png_path_str);
|
||||
assert!(result.is_ok());
|
||||
let image = result.unwrap();
|
||||
assert_eq!(image.mime_type, "image/png");
|
||||
|
||||
// Test loading fake PNG file
|
||||
let result = load_image_file(fake_png_path_str);
|
||||
assert!(result.is_err());
|
||||
assert!(result
|
||||
.unwrap_err()
|
||||
.to_string()
|
||||
.contains("not a valid image"));
|
||||
|
||||
// Test non-existent file
|
||||
let result = load_image_file("nonexistent.png");
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sanitize_function_name() {
|
||||
|
||||
Reference in New Issue
Block a user