mirror of
https://github.com/aljazceru/goose.git
synced 2025-12-20 15:44:25 +01:00
feat: handle images mentioned in messages (#1202)
This commit is contained in:
@@ -3,7 +3,8 @@ use crate::model::ModelConfig;
|
|||||||
use crate::providers::base::Usage;
|
use crate::providers::base::Usage;
|
||||||
use crate::providers::errors::ProviderError;
|
use crate::providers::errors::ProviderError;
|
||||||
use crate::providers::utils::{
|
use crate::providers::utils::{
|
||||||
convert_image, is_valid_function_name, sanitize_function_name, ImageFormat,
|
convert_image, detect_image_path, is_valid_function_name, load_image_file,
|
||||||
|
sanitize_function_name, ImageFormat,
|
||||||
};
|
};
|
||||||
use anyhow::{anyhow, Error};
|
use anyhow::{anyhow, Error};
|
||||||
use mcp_core::ToolError;
|
use mcp_core::ToolError;
|
||||||
@@ -26,8 +27,22 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec<
|
|||||||
match content {
|
match content {
|
||||||
MessageContent::Text(text) => {
|
MessageContent::Text(text) => {
|
||||||
if !text.text.is_empty() {
|
if !text.text.is_empty() {
|
||||||
|
// Check for image paths in the text
|
||||||
|
if let Some(image_path) = detect_image_path(&text.text) {
|
||||||
|
// Try to load and convert the image
|
||||||
|
if let Ok(image) = load_image_file(image_path) {
|
||||||
|
converted["content"] = json!([
|
||||||
|
{"type": "text", "text": text.text},
|
||||||
|
convert_image(&image, image_format)
|
||||||
|
]);
|
||||||
|
} else {
|
||||||
|
// If image loading fails, just use the text
|
||||||
converted["content"] = json!(text.text);
|
converted["content"] = json!(text.text);
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
converted["content"] = json!(text.text);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
MessageContent::ToolRequest(request) => match &request.tool_call {
|
MessageContent::ToolRequest(request) => match &request.tool_call {
|
||||||
Ok(tool_call) => {
|
Ok(tool_call) => {
|
||||||
@@ -622,6 +637,40 @@ mod tests {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_format_messages_with_image_path() -> anyhow::Result<()> {
|
||||||
|
// Create a temporary PNG file with valid PNG magic numbers
|
||||||
|
let temp_dir = tempfile::tempdir()?;
|
||||||
|
let png_path = temp_dir.path().join("test.png");
|
||||||
|
let png_data = [
|
||||||
|
0x89, 0x50, 0x4E, 0x47, // PNG magic number
|
||||||
|
0x0D, 0x0A, 0x1A, 0x0A, // PNG header
|
||||||
|
0x00, 0x00, 0x00, 0x0D, // Rest of fake PNG data
|
||||||
|
];
|
||||||
|
std::fs::write(&png_path, &png_data)?;
|
||||||
|
let png_path_str = png_path.to_str().unwrap();
|
||||||
|
|
||||||
|
// Create message with image path
|
||||||
|
let message = Message::user().with_text(format!("Here is an image: {}", png_path_str));
|
||||||
|
let spec = format_messages(&[message], &ImageFormat::OpenAi);
|
||||||
|
|
||||||
|
assert_eq!(spec.len(), 1);
|
||||||
|
assert_eq!(spec[0]["role"], "user");
|
||||||
|
|
||||||
|
// Content should be an array with text and image
|
||||||
|
let content = spec[0]["content"].as_array().unwrap();
|
||||||
|
assert_eq!(content.len(), 2);
|
||||||
|
assert_eq!(content[0]["type"], "text");
|
||||||
|
assert!(content[0]["text"].as_str().unwrap().contains(png_path_str));
|
||||||
|
assert_eq!(content[1]["type"], "image_url");
|
||||||
|
assert!(content[1]["image_url"]["url"]
|
||||||
|
.as_str()
|
||||||
|
.unwrap()
|
||||||
|
.starts_with("data:image/png;base64,"));
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_response_to_message_text() -> anyhow::Result<()> {
|
fn test_response_to_message_text() -> anyhow::Result<()> {
|
||||||
let response = json!({
|
let response = json!({
|
||||||
|
|||||||
@@ -1,9 +1,12 @@
|
|||||||
use super::base::Usage;
|
use super::base::Usage;
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
|
use base64::Engine;
|
||||||
use regex::Regex;
|
use regex::Regex;
|
||||||
use reqwest::{Response, StatusCode};
|
use reqwest::{Response, StatusCode};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::{json, Map, Value};
|
use serde_json::{json, Map, Value};
|
||||||
|
use std::io::Read;
|
||||||
|
use std::path::Path;
|
||||||
|
|
||||||
use crate::providers::errors::ProviderError;
|
use crate::providers::errors::ProviderError;
|
||||||
use mcp_core::content::ImageContent;
|
use mcp_core::content::ImageContent;
|
||||||
@@ -110,6 +113,91 @@ pub fn get_model(data: &Value) -> String {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Check if a file is actually an image by examining its magic bytes
|
||||||
|
fn is_image_file(path: &Path) -> bool {
|
||||||
|
if let Ok(mut file) = std::fs::File::open(path) {
|
||||||
|
let mut buffer = [0u8; 8]; // Large enough for most image magic numbers
|
||||||
|
if file.read(&mut buffer).is_ok() {
|
||||||
|
// Check magic numbers for common image formats
|
||||||
|
return match &buffer[0..4] {
|
||||||
|
// PNG: 89 50 4E 47
|
||||||
|
[0x89, 0x50, 0x4E, 0x47] => true,
|
||||||
|
// JPEG: FF D8 FF
|
||||||
|
[0xFF, 0xD8, 0xFF, _] => true,
|
||||||
|
_ => false,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
false
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Detect if a string contains a path to an image file
|
||||||
|
pub fn detect_image_path(text: &str) -> Option<&str> {
|
||||||
|
// Basic image file extension check
|
||||||
|
let extensions = [".png", ".jpg", ".jpeg"];
|
||||||
|
|
||||||
|
// Find any word that ends with an image extension
|
||||||
|
for word in text.split_whitespace() {
|
||||||
|
if extensions
|
||||||
|
.iter()
|
||||||
|
.any(|ext| word.to_lowercase().ends_with(ext))
|
||||||
|
{
|
||||||
|
let path = Path::new(word);
|
||||||
|
// Check if it's an absolute path and file exists
|
||||||
|
if path.is_absolute() && path.is_file() {
|
||||||
|
// Verify it's actually an image file
|
||||||
|
if is_image_file(path) {
|
||||||
|
return Some(word);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Convert a local image file to base64 encoded ImageContent
|
||||||
|
pub fn load_image_file(path: &str) -> Result<ImageContent, ProviderError> {
|
||||||
|
let path = Path::new(path);
|
||||||
|
|
||||||
|
// Verify it's an image before proceeding
|
||||||
|
if !is_image_file(path) {
|
||||||
|
return Err(ProviderError::RequestFailed(
|
||||||
|
"File is not a valid image".to_string(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read the file
|
||||||
|
let bytes = std::fs::read(path)
|
||||||
|
.map_err(|e| ProviderError::RequestFailed(format!("Failed to read image file: {}", e)))?;
|
||||||
|
|
||||||
|
// Detect mime type from extension
|
||||||
|
let mime_type = match path.extension().and_then(|e| e.to_str()) {
|
||||||
|
Some(ext) => match ext.to_lowercase().as_str() {
|
||||||
|
"png" => "image/png",
|
||||||
|
"jpg" | "jpeg" => "image/jpeg",
|
||||||
|
_ => {
|
||||||
|
return Err(ProviderError::RequestFailed(
|
||||||
|
"Unsupported image format".to_string(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
},
|
||||||
|
None => {
|
||||||
|
return Err(ProviderError::RequestFailed(
|
||||||
|
"Unknown image format".to_string(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Convert to base64
|
||||||
|
let data = base64::prelude::BASE64_STANDARD.encode(&bytes);
|
||||||
|
|
||||||
|
Ok(ImageContent {
|
||||||
|
mime_type: mime_type.to_string(),
|
||||||
|
data,
|
||||||
|
annotations: None,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
pub fn unescape_json_values(value: &Value) -> Value {
|
pub fn unescape_json_values(value: &Value) -> Value {
|
||||||
match value {
|
match value {
|
||||||
Value::Object(map) => {
|
Value::Object(map) => {
|
||||||
@@ -165,6 +253,83 @@ pub fn emit_debug_trace<T: serde::Serialize>(
|
|||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
|
use std::io::Write;
|
||||||
|
use tempfile::NamedTempFile;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_detect_image_path() {
|
||||||
|
// Create a temporary PNG file with valid PNG magic numbers
|
||||||
|
let temp_dir = tempfile::tempdir().unwrap();
|
||||||
|
let png_path = temp_dir.path().join("test.png");
|
||||||
|
let png_data = [
|
||||||
|
0x89, 0x50, 0x4E, 0x47, // PNG magic number
|
||||||
|
0x0D, 0x0A, 0x1A, 0x0A, // PNG header
|
||||||
|
0x00, 0x00, 0x00, 0x0D, // Rest of fake PNG data
|
||||||
|
];
|
||||||
|
std::fs::write(&png_path, &png_data).unwrap();
|
||||||
|
let png_path_str = png_path.to_str().unwrap();
|
||||||
|
|
||||||
|
// Create a fake PNG (wrong magic numbers)
|
||||||
|
let fake_png_path = temp_dir.path().join("fake.png");
|
||||||
|
std::fs::write(&fake_png_path, b"not a real png").unwrap();
|
||||||
|
|
||||||
|
// Test with valid PNG file using absolute path
|
||||||
|
let text = format!("Here is an image {}", png_path_str);
|
||||||
|
assert_eq!(detect_image_path(&text), Some(png_path_str));
|
||||||
|
|
||||||
|
// Test with non-image file that has .png extension
|
||||||
|
let text = format!("Here is a fake image {}", fake_png_path.to_str().unwrap());
|
||||||
|
assert_eq!(detect_image_path(&text), None);
|
||||||
|
|
||||||
|
// Test with non-existent file
|
||||||
|
let text = "Here is a fake.png that doesn't exist";
|
||||||
|
assert_eq!(detect_image_path(text), None);
|
||||||
|
|
||||||
|
// Test with non-image file
|
||||||
|
let text = "Here is a file.txt";
|
||||||
|
assert_eq!(detect_image_path(text), None);
|
||||||
|
|
||||||
|
// Test with relative path (should not match)
|
||||||
|
let text = "Here is a relative/path/image.png";
|
||||||
|
assert_eq!(detect_image_path(text), None);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_load_image_file() {
|
||||||
|
// Create a temporary PNG file with valid PNG magic numbers
|
||||||
|
let temp_dir = tempfile::tempdir().unwrap();
|
||||||
|
let png_path = temp_dir.path().join("test.png");
|
||||||
|
let png_data = [
|
||||||
|
0x89, 0x50, 0x4E, 0x47, // PNG magic number
|
||||||
|
0x0D, 0x0A, 0x1A, 0x0A, // PNG header
|
||||||
|
0x00, 0x00, 0x00, 0x0D, // Rest of fake PNG data
|
||||||
|
];
|
||||||
|
std::fs::write(&png_path, &png_data).unwrap();
|
||||||
|
let png_path_str = png_path.to_str().unwrap();
|
||||||
|
|
||||||
|
// Create a fake PNG (wrong magic numbers)
|
||||||
|
let fake_png_path = temp_dir.path().join("fake.png");
|
||||||
|
std::fs::write(&fake_png_path, b"not a real png").unwrap();
|
||||||
|
let fake_png_path_str = fake_png_path.to_str().unwrap();
|
||||||
|
|
||||||
|
// Test loading valid PNG file
|
||||||
|
let result = load_image_file(png_path_str);
|
||||||
|
assert!(result.is_ok());
|
||||||
|
let image = result.unwrap();
|
||||||
|
assert_eq!(image.mime_type, "image/png");
|
||||||
|
|
||||||
|
// Test loading fake PNG file
|
||||||
|
let result = load_image_file(fake_png_path_str);
|
||||||
|
assert!(result.is_err());
|
||||||
|
assert!(result
|
||||||
|
.unwrap_err()
|
||||||
|
.to_string()
|
||||||
|
.contains("not a valid image"));
|
||||||
|
|
||||||
|
// Test non-existent file
|
||||||
|
let result = load_image_file("nonexistent.png");
|
||||||
|
assert!(result.is_err());
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_sanitize_function_name() {
|
fn test_sanitize_function_name() {
|
||||||
|
|||||||
Reference in New Issue
Block a user