feat: handle images mentioned in messages (#1202)

This commit is contained in:
Michael Neale
2025-02-12 14:36:46 -08:00
committed by GitHub
parent cfb1bba9b5
commit e461e6845d
2 changed files with 216 additions and 2 deletions

View File

@@ -3,7 +3,8 @@ use crate::model::ModelConfig;
use crate::providers::base::Usage;
use crate::providers::errors::ProviderError;
use crate::providers::utils::{
convert_image, is_valid_function_name, sanitize_function_name, ImageFormat,
convert_image, detect_image_path, is_valid_function_name, load_image_file,
sanitize_function_name, ImageFormat,
};
use anyhow::{anyhow, Error};
use mcp_core::ToolError;
@@ -26,7 +27,21 @@ pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec<
match content {
MessageContent::Text(text) => {
if !text.text.is_empty() {
converted["content"] = json!(text.text);
// Check for image paths in the text
if let Some(image_path) = detect_image_path(&text.text) {
// Try to load and convert the image
if let Ok(image) = load_image_file(image_path) {
converted["content"] = json!([
{"type": "text", "text": text.text},
convert_image(&image, image_format)
]);
} else {
// If image loading fails, just use the text
converted["content"] = json!(text.text);
}
} else {
converted["content"] = json!(text.text);
}
}
}
MessageContent::ToolRequest(request) => match &request.tool_call {
@@ -622,6 +637,40 @@ mod tests {
Ok(())
}
#[test]
fn test_format_messages_with_image_path() -> anyhow::Result<()> {
    // Minimal byte sequence that starts with the PNG signature; it is not a
    // decodable image, only enough to satisfy the magic-byte check.
    const FAKE_PNG: [u8; 12] = [
        0x89, 0x50, 0x4E, 0x47, // PNG magic number
        0x0D, 0x0A, 0x1A, 0x0A, // PNG header
        0x00, 0x00, 0x00, 0x0D, // Rest of fake PNG data
    ];

    // Write the fixture into a scratch directory.
    let scratch = tempfile::tempdir()?;
    let image_file = scratch.path().join("test.png");
    std::fs::write(&image_file, FAKE_PNG)?;
    let image_file_str = image_file.to_str().unwrap();

    // A user message whose text mentions the absolute path of the image.
    let user_message =
        Message::user().with_text(format!("Here is an image: {}", image_file_str));
    let formatted = format_messages(&[user_message], &ImageFormat::OpenAi);

    assert_eq!(formatted.len(), 1);
    let entry = &formatted[0];
    assert_eq!(entry["role"], "user");

    // The content is expected to be a two-element array: the text, then the image.
    let parts = entry["content"].as_array().unwrap();
    assert_eq!(parts.len(), 2);

    let text_part = &parts[0];
    assert_eq!(text_part["type"], "text");
    let rendered_text = text_part["text"].as_str().unwrap();
    assert!(rendered_text.contains(image_file_str));

    let image_part = &parts[1];
    assert_eq!(image_part["type"], "image_url");
    let data_url = image_part["image_url"]["url"].as_str().unwrap();
    assert!(data_url.starts_with("data:image/png;base64,"));

    Ok(())
}
#[test]
fn test_response_to_message_text() -> anyhow::Result<()> {
let response = json!({

View File

@@ -1,9 +1,12 @@
use super::base::Usage;
use anyhow::Result;
use base64::Engine;
use regex::Regex;
use reqwest::{Response, StatusCode};
use serde::{Deserialize, Serialize};
use serde_json::{json, Map, Value};
use std::io::Read;
use std::path::Path;
use crate::providers::errors::ProviderError;
use mcp_core::content::ImageContent;
@@ -110,6 +113,91 @@ pub fn get_model(data: &Value) -> String {
}
}
/// Reports whether the file at `path` begins with a known image signature.
///
/// Only the leading magic bytes are inspected; the file name and extension
/// are ignored. Recognised signatures:
///
/// - PNG: `89 50 4E 47`
/// - JPEG: `FF D8 FF`
///
/// Returns `false` when the file cannot be opened or read, or when it is too
/// short to contain one of the signatures above.
fn is_image_file(path: &Path) -> bool {
    // Longest prefix that any supported signature needs.
    const MAGIC_LEN: u64 = 4;

    let file = match std::fs::File::open(path) {
        Ok(file) => file,
        Err(_) => return false,
    };

    // A single `read` call may return fewer bytes than requested even when
    // more are available, so collect the prefix with `take` + `read_to_end`,
    // which keeps reading until MAGIC_LEN bytes are gathered or EOF is hit.
    let mut header = Vec::new();
    if file.take(MAGIC_LEN).read_to_end(&mut header).is_err() {
        return false;
    }

    matches!(
        header.as_slice(),
        [0x89, 0x50, 0x4E, 0x47, ..] | [0xFF, 0xD8, 0xFF, ..]
    )
}
/// Finds the first whitespace-separated token in `text` that names an image file on disk.
///
/// A token qualifies when all of the following hold:
///
/// 1. it ends, case-insensitively, with `.png`, `.jpg` or `.jpeg`;
/// 2. it is an absolute path for which `Path::is_file` returns `true`;
/// 3. `is_image_file` accepts the file's leading bytes.
///
/// Returns the matching token as a slice of `text`, or `None` when no token
/// qualifies. Because the text is split on whitespace, a path that contains
/// spaces is never matched as a whole.
pub fn detect_image_path(text: &str) -> Option<&str> {
    const IMAGE_EXTENSIONS: [&str; 3] = [".png", ".jpg", ".jpeg"];

    text.split_whitespace().find(|token| {
        // Lowercase once per token rather than once per extension tried.
        let lowered = token.to_lowercase();
        if !IMAGE_EXTENSIONS.iter().any(|ext| lowered.ends_with(ext)) {
            return false;
        }

        // Cheap checks first; the file contents are only read for tokens
        // that already look like an absolute path to an existing file.
        let candidate = Path::new(*token);
        candidate.is_absolute() && candidate.is_file() && is_image_file(candidate)
    })
}
/// Loads the image at `path` and returns it as base64-encoded `ImageContent`.
///
/// The file must pass `is_image_file` before it is read. The MIME type is
/// chosen from the file extension (case-insensitive): `png` maps to
/// `image/png`, `jpg` and `jpeg` map to `image/jpeg`.
///
/// # Errors
///
/// Every failure is reported as `ProviderError::RequestFailed`:
///
/// - `is_image_file` rejects the path;
/// - reading the file fails;
/// - the path has no extension, or the extension is not valid UTF-8;
/// - the extension is not one of `png`, `jpg`, `jpeg`.
pub fn load_image_file(path: &str) -> Result<ImageContent, ProviderError> {
    let image_path = Path::new(path);

    // Refuse anything that does not start with a known image signature.
    if !is_image_file(image_path) {
        let reason = "File is not a valid image".to_string();
        return Err(ProviderError::RequestFailed(reason));
    }

    // Pull the whole file into memory.
    let raw = match std::fs::read(image_path) {
        Ok(contents) => contents,
        Err(e) => {
            return Err(ProviderError::RequestFailed(format!(
                "Failed to read image file: {}",
                e
            )))
        }
    };

    // The MIME type comes from the extension, not from the file contents.
    let extension = image_path
        .extension()
        .and_then(|e| e.to_str())
        .map(str::to_lowercase)
        .ok_or_else(|| ProviderError::RequestFailed("Unknown image format".to_string()))?;
    let mime_type = match extension.as_str() {
        "png" => "image/png",
        "jpg" | "jpeg" => "image/jpeg",
        _ => {
            return Err(ProviderError::RequestFailed(
                "Unsupported image format".to_string(),
            ))
        }
    };

    Ok(ImageContent {
        mime_type: mime_type.to_string(),
        data: base64::prelude::BASE64_STANDARD.encode(&raw),
        annotations: None,
    })
}
pub fn unescape_json_values(value: &Value) -> Value {
match value {
Value::Object(map) => {
@@ -165,6 +253,83 @@ pub fn emit_debug_trace<T: serde::Serialize>(
mod tests {
use super::*;
use serde_json::json;
use std::io::Write;
use tempfile::NamedTempFile;
#[test]
fn test_detect_image_path() {
    // Bytes that start with the PNG signature; enough to pass the magic-byte check.
    const PNG_HEADER: [u8; 12] = [
        0x89, 0x50, 0x4E, 0x47, // PNG magic number
        0x0D, 0x0A, 0x1A, 0x0A, // PNG header
        0x00, 0x00, 0x00, 0x0D, // Rest of fake PNG data
    ];

    let workspace = tempfile::tempdir().unwrap();

    // Fixture 1: a file with a .png name and PNG magic bytes.
    let real_png = workspace.path().join("test.png");
    std::fs::write(&real_png, PNG_HEADER).unwrap();
    let real_png_str = real_png.to_str().unwrap();

    // Fixture 2: a file with a .png name but non-image contents.
    let bogus_png = workspace.path().join("fake.png");
    std::fs::write(&bogus_png, b"not a real png").unwrap();
    let bogus_png_str = bogus_png.to_str().unwrap();

    // An absolute path to a genuine PNG is returned as-is.
    let mentions_real = format!("Here is an image {}", real_png_str);
    assert_eq!(detect_image_path(&mentions_real), Some(real_png_str));

    // The right extension with the wrong magic bytes is rejected.
    let mentions_bogus = format!("Here is a fake image {}", bogus_png_str);
    assert_eq!(detect_image_path(&mentions_bogus), None);

    // A file name that does not exist on disk.
    assert_eq!(
        detect_image_path("Here is a fake.png that doesn't exist"),
        None
    );

    // A token without an image extension.
    assert_eq!(detect_image_path("Here is a file.txt"), None);

    // Relative paths are never matched.
    assert_eq!(
        detect_image_path("Here is a relative/path/image.png"),
        None
    );
}
#[test]
fn test_load_image_file() {
    // Bytes that start with the PNG signature; enough to pass the magic-byte check.
    const PNG_HEADER: [u8; 12] = [
        0x89, 0x50, 0x4E, 0x47, // PNG magic number
        0x0D, 0x0A, 0x1A, 0x0A, // PNG header
        0x00, 0x00, 0x00, 0x0D, // Rest of fake PNG data
    ];

    let workspace = tempfile::tempdir().unwrap();

    // Fixture 1: a file with a .png name and PNG magic bytes.
    let real_png = workspace.path().join("test.png");
    std::fs::write(&real_png, PNG_HEADER).unwrap();
    let real_png_str = real_png.to_str().unwrap();

    // Fixture 2: a file with a .png name but non-image contents.
    let bogus_png = workspace.path().join("fake.png");
    std::fs::write(&bogus_png, b"not a real png").unwrap();
    let bogus_png_str = bogus_png.to_str().unwrap();

    // A genuine PNG loads and is tagged with the PNG MIME type.
    let image = load_image_file(real_png_str).unwrap();
    assert_eq!(image.mime_type, "image/png");

    // Wrong magic bytes are rejected with a descriptive error.
    let rejection = load_image_file(bogus_png_str).unwrap_err();
    assert!(rejection.to_string().contains("not a valid image"));

    // A path that does not exist is an error as well.
    assert!(load_image_file("nonexistent.png").is_err());
}
#[test]
fn test_sanitize_function_name() {