mirror of
https://github.com/aljazceru/goose.git
synced 2025-12-17 06:04:23 +01:00
feat: run sub recipe multiple times in parallel (Experimental feature) (#3274)
Co-authored-by: Wendy Tang <wendytang@squareup.com>
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -31,6 +31,9 @@ ui/desktop/src/bin/goose_llm.dll
|
||||
# Hermit
|
||||
.hermit/
|
||||
|
||||
# Claude
|
||||
.claude
|
||||
|
||||
debug_*.txt
|
||||
|
||||
# Docs
|
||||
|
||||
@@ -32,6 +32,7 @@ pub fn extract_recipe_info_from_cli(
|
||||
path: recipe_file_path.to_string_lossy().to_string(),
|
||||
name,
|
||||
values: None,
|
||||
sequential_when_repeated: true,
|
||||
};
|
||||
all_sub_recipes.push(additional_sub_recipe);
|
||||
}
|
||||
|
||||
@@ -4,8 +4,11 @@ mod export;
|
||||
mod input;
|
||||
mod output;
|
||||
mod prompt;
|
||||
mod task_execution_display;
|
||||
mod thinking;
|
||||
|
||||
use crate::session::task_execution_display::TASK_EXECUTION_NOTIFICATION_TYPE;
|
||||
|
||||
pub use self::export::message_to_markdown;
|
||||
pub use builder::{build_session, SessionBuilderConfig, SessionSettings};
|
||||
use console::Color;
|
||||
@@ -17,6 +20,8 @@ use goose::permission::PermissionConfirmation;
|
||||
use goose::providers::base::Provider;
|
||||
pub use goose::session::Identifier;
|
||||
use goose::utils::safe_truncate;
|
||||
use std::io::Write;
|
||||
use task_execution_display::format_task_execution_notification;
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use completion::GooseCompleter;
|
||||
@@ -1008,7 +1013,7 @@ impl Session {
|
||||
match method.as_str() {
|
||||
"notifications/message" => {
|
||||
let data = o.get("data").unwrap_or(&Value::Null);
|
||||
let (formatted_message, subagent_id, _notification_type) = match data {
|
||||
let (formatted_message, subagent_id, message_notification_type) = match data {
|
||||
Value::String(s) => (s.clone(), None, None),
|
||||
Value::Object(o) => {
|
||||
// Check for subagent notification structure first
|
||||
@@ -1059,6 +1064,8 @@ impl Session {
|
||||
} else if let Some(Value::String(output)) = o.get("output") {
|
||||
// Fallback for other MCP notification types
|
||||
(output.to_owned(), None, None)
|
||||
} else if let Some(result) = format_task_execution_notification(data) {
|
||||
result
|
||||
} else {
|
||||
(data.to_string(), None, None)
|
||||
}
|
||||
@@ -1077,7 +1084,19 @@ impl Session {
|
||||
} else {
|
||||
progress_bars.log(&formatted_message);
|
||||
}
|
||||
} else if let Some(ref notification_type) = message_notification_type {
|
||||
if notification_type == TASK_EXECUTION_NOTIFICATION_TYPE {
|
||||
if interactive {
|
||||
let _ = progress_bars.hide();
|
||||
print!("{}", formatted_message);
|
||||
std::io::stdout().flush().unwrap();
|
||||
} else {
|
||||
print!("{}", formatted_message);
|
||||
std::io::stdout().flush().unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
// Non-subagent notification, display immediately with compact spacing
|
||||
if interactive {
|
||||
let _ = progress_bars.hide();
|
||||
|
||||
247
crates/goose-cli/src/session/task_execution_display/mod.rs
Normal file
247
crates/goose-cli/src/session/task_execution_display/mod.rs
Normal file
@@ -0,0 +1,247 @@
|
||||
use goose::agents::sub_recipe_execution_tool::lib::TaskStatus;
|
||||
use goose::agents::sub_recipe_execution_tool::notification_events::{
|
||||
TaskExecutionNotificationEvent, TaskInfo,
|
||||
};
|
||||
use serde_json::Value;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
const CLEAR_SCREEN: &str = "\x1b[2J\x1b[H";
|
||||
const MOVE_TO_PROGRESS_LINE: &str = "\x1b[4;1H";
|
||||
const CLEAR_TO_EOL: &str = "\x1b[K";
|
||||
const CLEAR_BELOW: &str = "\x1b[J";
|
||||
pub const TASK_EXECUTION_NOTIFICATION_TYPE: &str = "task_execution";
|
||||
|
||||
static INITIAL_SHOWN: AtomicBool = AtomicBool::new(false);
|
||||
|
||||
fn format_result_data_for_display(result_data: &Value) -> String {
|
||||
match result_data {
|
||||
Value::String(s) => strip_ansi_codes(s),
|
||||
Value::Object(obj) => {
|
||||
if let Some(partial_output) = obj.get("partial_output").and_then(|v| v.as_str()) {
|
||||
format!("Partial output: {}", partial_output)
|
||||
} else {
|
||||
serde_json::to_string_pretty(obj).unwrap_or_default()
|
||||
}
|
||||
}
|
||||
Value::Array(arr) => serde_json::to_string_pretty(arr).unwrap_or_default(),
|
||||
Value::Bool(b) => b.to_string(),
|
||||
Value::Number(n) => n.to_string(),
|
||||
Value::Null => "null".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
fn process_output_for_display(output: &str) -> String {
|
||||
const MAX_OUTPUT_LINES: usize = 2;
|
||||
const OUTPUT_PREVIEW_LENGTH: usize = 100;
|
||||
|
||||
let lines: Vec<&str> = output.lines().collect();
|
||||
let recent_lines = if lines.len() > MAX_OUTPUT_LINES {
|
||||
&lines[lines.len() - MAX_OUTPUT_LINES..]
|
||||
} else {
|
||||
&lines
|
||||
};
|
||||
|
||||
let clean_output = recent_lines.join(" ... ");
|
||||
let stripped = strip_ansi_codes(&clean_output);
|
||||
truncate_with_ellipsis(&stripped, OUTPUT_PREVIEW_LENGTH)
|
||||
}
|
||||
|
||||
fn truncate_with_ellipsis(text: &str, max_len: usize) -> String {
|
||||
if text.len() > max_len {
|
||||
let mut end = max_len.saturating_sub(3);
|
||||
while end > 0 && !text.is_char_boundary(end) {
|
||||
end -= 1;
|
||||
}
|
||||
format!("{}...", &text[..end])
|
||||
} else {
|
||||
text.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
fn strip_ansi_codes(text: &str) -> String {
|
||||
let mut result = String::new();
|
||||
let mut chars = text.chars();
|
||||
|
||||
while let Some(ch) = chars.next() {
|
||||
if ch == '\x1b' {
|
||||
if let Some(next_ch) = chars.next() {
|
||||
if next_ch == '[' {
|
||||
// This is an ANSI escape sequence, consume until alphabetic character
|
||||
loop {
|
||||
match chars.next() {
|
||||
Some(c) if c.is_ascii_alphabetic() => break,
|
||||
Some(_) => continue,
|
||||
None => break,
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Not an ANSI sequence, keep both characters
|
||||
result.push(ch);
|
||||
result.push(next_ch);
|
||||
}
|
||||
} else {
|
||||
// End of string after \x1b
|
||||
result.push(ch);
|
||||
}
|
||||
} else {
|
||||
result.push(ch);
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
pub fn format_task_execution_notification(
|
||||
data: &Value,
|
||||
) -> Option<(String, Option<String>, Option<String>)> {
|
||||
if let Ok(event) = serde_json::from_value::<TaskExecutionNotificationEvent>(data.clone()) {
|
||||
return Some(match event {
|
||||
TaskExecutionNotificationEvent::LineOutput { output, .. } => (
|
||||
format!("{}\n", output),
|
||||
None,
|
||||
Some(TASK_EXECUTION_NOTIFICATION_TYPE.to_string()),
|
||||
),
|
||||
TaskExecutionNotificationEvent::TasksUpdate { .. } => {
|
||||
let formatted_display = format_tasks_update_from_event(&event);
|
||||
(
|
||||
formatted_display,
|
||||
None,
|
||||
Some(TASK_EXECUTION_NOTIFICATION_TYPE.to_string()),
|
||||
)
|
||||
}
|
||||
TaskExecutionNotificationEvent::TasksComplete { .. } => {
|
||||
let formatted_summary = format_tasks_complete_from_event(&event);
|
||||
(
|
||||
formatted_summary,
|
||||
None,
|
||||
Some(TASK_EXECUTION_NOTIFICATION_TYPE.to_string()),
|
||||
)
|
||||
}
|
||||
});
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
fn format_tasks_update_from_event(event: &TaskExecutionNotificationEvent) -> String {
|
||||
if let TaskExecutionNotificationEvent::TasksUpdate { stats, tasks } = event {
|
||||
let mut display = String::new();
|
||||
|
||||
if !INITIAL_SHOWN.swap(true, Ordering::SeqCst) {
|
||||
display.push_str(CLEAR_SCREEN);
|
||||
display.push_str("🎯 Task Execution Dashboard\n");
|
||||
display.push_str("═══════════════════════════\n\n");
|
||||
} else {
|
||||
display.push_str(MOVE_TO_PROGRESS_LINE);
|
||||
}
|
||||
|
||||
display.push_str(&format!(
|
||||
"📊 Progress: {} total | ⏳ {} pending | 🏃 {} running | ✅ {} completed | ❌ {} failed",
|
||||
stats.total, stats.pending, stats.running, stats.completed, stats.failed
|
||||
));
|
||||
display.push_str(&format!("{}\n\n", CLEAR_TO_EOL));
|
||||
|
||||
let mut sorted_tasks = tasks.clone();
|
||||
sorted_tasks.sort_by(|a, b| a.id.cmp(&b.id));
|
||||
|
||||
for task in sorted_tasks {
|
||||
display.push_str(&format_task_display(&task));
|
||||
}
|
||||
|
||||
display.push_str(CLEAR_BELOW);
|
||||
display
|
||||
} else {
|
||||
String::new()
|
||||
}
|
||||
}
|
||||
|
||||
fn format_tasks_complete_from_event(event: &TaskExecutionNotificationEvent) -> String {
|
||||
if let TaskExecutionNotificationEvent::TasksComplete {
|
||||
stats,
|
||||
failed_tasks,
|
||||
} = event
|
||||
{
|
||||
let mut summary = String::new();
|
||||
summary.push_str("Execution Complete!\n");
|
||||
summary.push_str("═══════════════════════\n");
|
||||
|
||||
summary.push_str(&format!("Total Tasks: {}\n", stats.total));
|
||||
summary.push_str(&format!("✅ Completed: {}\n", stats.completed));
|
||||
summary.push_str(&format!("❌ Failed: {}\n", stats.failed));
|
||||
summary.push_str(&format!("📈 Success Rate: {:.1}%\n", stats.success_rate));
|
||||
|
||||
if !failed_tasks.is_empty() {
|
||||
summary.push_str("\n❌ Failed Tasks:\n");
|
||||
for task in failed_tasks {
|
||||
summary.push_str(&format!(" • {}\n", task.name));
|
||||
if let Some(error) = &task.error {
|
||||
summary.push_str(&format!(" Error: {}\n", error));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
summary.push_str("\n📝 Generating summary...\n");
|
||||
summary
|
||||
} else {
|
||||
String::new()
|
||||
}
|
||||
}
|
||||
|
||||
fn format_task_display(task: &TaskInfo) -> String {
|
||||
let mut task_display = String::new();
|
||||
|
||||
let status_icon = match task.status {
|
||||
TaskStatus::Pending => "⏳",
|
||||
TaskStatus::Running => "🏃",
|
||||
TaskStatus::Completed => "✅",
|
||||
TaskStatus::Failed => "❌",
|
||||
};
|
||||
|
||||
task_display.push_str(&format!(
|
||||
"{} {} ({}){}\n",
|
||||
status_icon, task.task_name, task.task_type, CLEAR_TO_EOL
|
||||
));
|
||||
|
||||
if !task.task_metadata.is_empty() {
|
||||
task_display.push_str(&format!(
|
||||
" 📋 Parameters: {}{}\n",
|
||||
task.task_metadata, CLEAR_TO_EOL
|
||||
));
|
||||
}
|
||||
|
||||
if let Some(duration_secs) = task.duration_secs {
|
||||
task_display.push_str(&format!(" ⏱️ {:.1}s{}\n", duration_secs, CLEAR_TO_EOL));
|
||||
}
|
||||
|
||||
if matches!(task.status, TaskStatus::Running) && !task.current_output.trim().is_empty() {
|
||||
let processed_output = process_output_for_display(&task.current_output);
|
||||
if !processed_output.is_empty() {
|
||||
task_display.push_str(&format!(" 💬 {}{}\n", processed_output, CLEAR_TO_EOL));
|
||||
}
|
||||
}
|
||||
|
||||
if matches!(task.status, TaskStatus::Completed) {
|
||||
if let Some(result_data) = &task.result_data {
|
||||
let result_preview = format_result_data_for_display(result_data);
|
||||
if !result_preview.is_empty() {
|
||||
task_display.push_str(&format!(" 📄 {}{}\n", result_preview, CLEAR_TO_EOL));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if matches!(task.status, TaskStatus::Failed) {
|
||||
if let Some(error) = &task.error {
|
||||
let error_preview = truncate_with_ellipsis(error, 80);
|
||||
task_display.push_str(&format!(
|
||||
" ⚠️ {}{}\n",
|
||||
error_preview.replace('\n', " "),
|
||||
CLEAR_TO_EOL
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
task_display.push_str(&format!("{}\n", CLEAR_TO_EOL));
|
||||
task_display
|
||||
}
|
||||
337
crates/goose-cli/src/session/task_execution_display/tests.rs
Normal file
337
crates/goose-cli/src/session/task_execution_display/tests.rs
Normal file
@@ -0,0 +1,337 @@
|
||||
use super::*;
|
||||
use goose::agents::sub_recipe_execution_tool::notification_events::{
|
||||
FailedTaskInfo, TaskCompletionStats, TaskExecutionStats,
|
||||
};
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn test_strip_ansi_codes() {
|
||||
assert_eq!(strip_ansi_codes("hello world"), "hello world");
|
||||
assert_eq!(strip_ansi_codes("\x1b[31mred text\x1b[0m"), "red text");
|
||||
assert_eq!(
|
||||
strip_ansi_codes("\x1b[1;32mbold green\x1b[0m"),
|
||||
"bold green"
|
||||
);
|
||||
assert_eq!(
|
||||
strip_ansi_codes("normal\x1b[33myellow\x1b[0mnormal"),
|
||||
"normalyellownormal"
|
||||
);
|
||||
assert_eq!(strip_ansi_codes("\x1bhello"), "\x1bhello");
|
||||
assert_eq!(strip_ansi_codes("hello\x1b"), "hello\x1b");
|
||||
assert_eq!(strip_ansi_codes(""), "");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_truncate_with_ellipsis() {
|
||||
assert_eq!(truncate_with_ellipsis("hello", 10), "hello");
|
||||
assert_eq!(truncate_with_ellipsis("hello", 5), "hello");
|
||||
assert_eq!(truncate_with_ellipsis("hello world", 8), "hello...");
|
||||
assert_eq!(truncate_with_ellipsis("hello", 3), "...");
|
||||
assert_eq!(truncate_with_ellipsis("hello", 2), "...");
|
||||
assert_eq!(truncate_with_ellipsis("hello", 1), "...");
|
||||
assert_eq!(truncate_with_ellipsis("", 5), "");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_process_output_for_display() {
|
||||
assert_eq!(process_output_for_display("hello world"), "hello world");
|
||||
assert_eq!(
|
||||
process_output_for_display("line1\nline2"),
|
||||
"line1 ... line2"
|
||||
);
|
||||
|
||||
let input = "line1\nline2\nline3\nline4";
|
||||
let result = process_output_for_display(input);
|
||||
assert_eq!(result, "line3 ... line4");
|
||||
|
||||
let long_line = "a".repeat(150);
|
||||
let result = process_output_for_display(&long_line);
|
||||
assert!(result.len() <= 100);
|
||||
assert!(result.ends_with("..."));
|
||||
|
||||
let ansi_output = "\x1b[31mred line 1\x1b[0m\n\x1b[32mgreen line 2\x1b[0m";
|
||||
let result = process_output_for_display(ansi_output);
|
||||
assert_eq!(result, "red line 1 ... green line 2");
|
||||
|
||||
assert_eq!(process_output_for_display(""), "");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_result_data_for_display() {
|
||||
let string_val = json!("hello world");
|
||||
assert_eq!(format_result_data_for_display(&string_val), "hello world");
|
||||
|
||||
let ansi_string = json!("\x1b[31mred text\x1b[0m");
|
||||
assert_eq!(format_result_data_for_display(&ansi_string), "red text");
|
||||
|
||||
assert_eq!(format_result_data_for_display(&json!(true)), "true");
|
||||
assert_eq!(format_result_data_for_display(&json!(false)), "false");
|
||||
assert_eq!(format_result_data_for_display(&json!(42)), "42");
|
||||
assert_eq!(format_result_data_for_display(&json!(3.14)), "3.14");
|
||||
assert_eq!(format_result_data_for_display(&json!(null)), "null");
|
||||
|
||||
let partial_obj = json!({
|
||||
"partial_output": "some output",
|
||||
"other_field": "ignored"
|
||||
});
|
||||
assert_eq!(
|
||||
format_result_data_for_display(&partial_obj),
|
||||
"Partial output: some output"
|
||||
);
|
||||
|
||||
let obj = json!({"key": "value", "num": 42});
|
||||
let result = format_result_data_for_display(&obj);
|
||||
assert!(result.contains("key"));
|
||||
assert!(result.contains("value"));
|
||||
|
||||
let arr = json!([1, 2, 3]);
|
||||
let result = format_result_data_for_display(&arr);
|
||||
assert!(result.contains("1"));
|
||||
assert!(result.contains("2"));
|
||||
assert!(result.contains("3"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_task_execution_notification_line_output() {
|
||||
let _event = TaskExecutionNotificationEvent::LineOutput {
|
||||
task_id: "task-1".to_string(),
|
||||
output: "Hello World".to_string(),
|
||||
};
|
||||
|
||||
let data = json!({
|
||||
"subtype": "line_output",
|
||||
"task_id": "task-1",
|
||||
"output": "Hello World"
|
||||
});
|
||||
|
||||
let result = format_task_execution_notification(&data);
|
||||
assert!(result.is_some());
|
||||
|
||||
let (formatted, second, third) = result.unwrap();
|
||||
assert_eq!(formatted, "Hello World\n");
|
||||
assert_eq!(second, None);
|
||||
assert_eq!(third, Some("task_execution".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_task_execution_notification_invalid_data() {
|
||||
let invalid_data = json!({
|
||||
"invalid": "structure"
|
||||
});
|
||||
|
||||
let result = format_task_execution_notification(&invalid_data);
|
||||
assert_eq!(result, None);
|
||||
|
||||
let incomplete_data = json!({
|
||||
"subtype": "line_output"
|
||||
});
|
||||
|
||||
let result = format_task_execution_notification(&incomplete_data);
|
||||
assert_eq!(result, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_tasks_update_from_event() {
|
||||
INITIAL_SHOWN.store(false, Ordering::SeqCst);
|
||||
|
||||
let stats = TaskExecutionStats::new(3, 1, 1, 1, 0);
|
||||
let tasks = vec![
|
||||
TaskInfo {
|
||||
id: "task-1".to_string(),
|
||||
status: TaskStatus::Running,
|
||||
duration_secs: Some(1.5),
|
||||
current_output: "Processing...".to_string(),
|
||||
task_type: "sub_recipe".to_string(),
|
||||
task_name: "test-task".to_string(),
|
||||
task_metadata: "param=value".to_string(),
|
||||
error: None,
|
||||
result_data: None,
|
||||
},
|
||||
TaskInfo {
|
||||
id: "task-2".to_string(),
|
||||
status: TaskStatus::Completed,
|
||||
duration_secs: Some(2.3),
|
||||
current_output: "".to_string(),
|
||||
task_type: "text_instruction".to_string(),
|
||||
task_name: "another-task".to_string(),
|
||||
task_metadata: "".to_string(),
|
||||
error: None,
|
||||
result_data: Some(json!({"result": "success"})),
|
||||
},
|
||||
];
|
||||
|
||||
let event = TaskExecutionNotificationEvent::TasksUpdate { stats, tasks };
|
||||
let result = format_tasks_update_from_event(&event);
|
||||
|
||||
assert!(result.contains("🎯 Task Execution Dashboard"));
|
||||
assert!(result.contains("═══════════════════════════"));
|
||||
assert!(result.contains("📊 Progress: 3 total"));
|
||||
assert!(result.contains("⏳ 1 pending"));
|
||||
assert!(result.contains("🏃 1 running"));
|
||||
assert!(result.contains("✅ 1 completed"));
|
||||
assert!(result.contains("❌ 0 failed"));
|
||||
assert!(result.contains("🏃 test-task"));
|
||||
assert!(result.contains("✅ another-task"));
|
||||
assert!(result.contains("📋 Parameters: param=value"));
|
||||
assert!(result.contains("⏱️ 1.5s"));
|
||||
assert!(result.contains("💬 Processing..."));
|
||||
|
||||
let result2 = format_tasks_update_from_event(&event);
|
||||
assert!(!result2.contains("🎯 Task Execution Dashboard"));
|
||||
assert!(result2.contains(MOVE_TO_PROGRESS_LINE));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_tasks_complete_from_event() {
|
||||
let stats = TaskCompletionStats::new(5, 4, 1);
|
||||
let failed_tasks = vec![FailedTaskInfo {
|
||||
id: "task-3".to_string(),
|
||||
name: "failed-task".to_string(),
|
||||
error: Some("Connection timeout".to_string()),
|
||||
}];
|
||||
|
||||
let event = TaskExecutionNotificationEvent::TasksComplete {
|
||||
stats,
|
||||
failed_tasks,
|
||||
};
|
||||
let result = format_tasks_complete_from_event(&event);
|
||||
|
||||
assert!(result.contains("Execution Complete!"));
|
||||
assert!(result.contains("═══════════════════════"));
|
||||
assert!(result.contains("Total Tasks: 5"));
|
||||
assert!(result.contains("✅ Completed: 4"));
|
||||
assert!(result.contains("❌ Failed: 1"));
|
||||
assert!(result.contains("📈 Success Rate: 80.0%"));
|
||||
assert!(result.contains("❌ Failed Tasks:"));
|
||||
assert!(result.contains("• failed-task"));
|
||||
assert!(result.contains("Error: Connection timeout"));
|
||||
assert!(result.contains("📝 Generating summary..."));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_tasks_complete_from_event_no_failures() {
|
||||
let stats = TaskCompletionStats::new(3, 3, 0);
|
||||
let failed_tasks = vec![];
|
||||
|
||||
let event = TaskExecutionNotificationEvent::TasksComplete {
|
||||
stats,
|
||||
failed_tasks,
|
||||
};
|
||||
let result = format_tasks_complete_from_event(&event);
|
||||
|
||||
assert!(!result.contains("❌ Failed Tasks:"));
|
||||
assert!(result.contains("📈 Success Rate: 100.0%"));
|
||||
assert!(result.contains("❌ Failed: 0"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_task_display_running() {
|
||||
let task = TaskInfo {
|
||||
id: "task-1".to_string(),
|
||||
status: TaskStatus::Running,
|
||||
duration_secs: Some(1.5),
|
||||
current_output: "Processing data...\nAlmost done...".to_string(),
|
||||
task_type: "sub_recipe".to_string(),
|
||||
task_name: "data-processor".to_string(),
|
||||
task_metadata: "input=file.txt,output=result.json".to_string(),
|
||||
error: None,
|
||||
result_data: None,
|
||||
};
|
||||
|
||||
let result = format_task_display(&task);
|
||||
|
||||
assert!(result.contains("🏃 data-processor (sub_recipe)"));
|
||||
assert!(result.contains("📋 Parameters: input=file.txt,output=result.json"));
|
||||
assert!(result.contains("⏱️ 1.5s"));
|
||||
assert!(result.contains("💬 Processing data... ... Almost done..."));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_task_display_completed() {
|
||||
let task = TaskInfo {
|
||||
id: "task-2".to_string(),
|
||||
status: TaskStatus::Completed,
|
||||
duration_secs: Some(3.2),
|
||||
current_output: "".to_string(),
|
||||
task_type: "text_instruction".to_string(),
|
||||
task_name: "analyzer".to_string(),
|
||||
task_metadata: "".to_string(),
|
||||
error: None,
|
||||
result_data: Some(json!({"status": "success", "count": 42})),
|
||||
};
|
||||
|
||||
let result = format_task_display(&task);
|
||||
|
||||
assert!(result.contains("✅ analyzer (text_instruction)"));
|
||||
assert!(result.contains("⏱️ 3.2s"));
|
||||
assert!(!result.contains("📋 Parameters"));
|
||||
assert!(result.contains("📄"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_task_display_failed() {
|
||||
let task = TaskInfo {
|
||||
id: "task-3".to_string(),
|
||||
status: TaskStatus::Failed,
|
||||
duration_secs: None,
|
||||
current_output: "".to_string(),
|
||||
task_type: "sub_recipe".to_string(),
|
||||
task_name: "failing-task".to_string(),
|
||||
task_metadata: "".to_string(),
|
||||
error: Some(
|
||||
"Network connection failed after multiple retries. The server is unreachable."
|
||||
.to_string(),
|
||||
),
|
||||
result_data: None,
|
||||
};
|
||||
|
||||
let result = format_task_display(&task);
|
||||
|
||||
assert!(result.contains("❌ failing-task (sub_recipe)"));
|
||||
assert!(!result.contains("⏱️"));
|
||||
assert!(result.contains("⚠️"));
|
||||
assert!(result.contains("Network connection failed after multiple retries"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_task_display_pending() {
|
||||
let task = TaskInfo {
|
||||
id: "task-4".to_string(),
|
||||
status: TaskStatus::Pending,
|
||||
duration_secs: None,
|
||||
current_output: "".to_string(),
|
||||
task_type: "sub_recipe".to_string(),
|
||||
task_name: "waiting-task".to_string(),
|
||||
task_metadata: "priority=high".to_string(),
|
||||
error: None,
|
||||
result_data: None,
|
||||
};
|
||||
|
||||
let result = format_task_display(&task);
|
||||
|
||||
assert!(result.contains("⏳ waiting-task (sub_recipe)"));
|
||||
assert!(result.contains("📋 Parameters: priority=high"));
|
||||
assert!(!result.contains("⏱️"));
|
||||
assert!(!result.contains("💬"));
|
||||
assert!(!result.contains("📄"));
|
||||
assert!(!result.contains("⚠️"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_task_display_empty_current_output() {
|
||||
let task = TaskInfo {
|
||||
id: "task-5".to_string(),
|
||||
status: TaskStatus::Running,
|
||||
duration_secs: Some(0.5),
|
||||
current_output: " \n\t \n ".to_string(),
|
||||
task_type: "sub_recipe".to_string(),
|
||||
task_name: "quiet-task".to_string(),
|
||||
task_metadata: "".to_string(),
|
||||
error: None,
|
||||
result_data: None,
|
||||
};
|
||||
|
||||
let result = format_task_display(&task);
|
||||
|
||||
assert!(!result.contains("💬"));
|
||||
}
|
||||
@@ -12,6 +12,7 @@ use crate::agents::final_output_tool::{FINAL_OUTPUT_CONTINUATION_MESSAGE, FINAL_
|
||||
use crate::agents::sub_recipe_execution_tool::sub_recipe_execute_task_tool::{
|
||||
self, SUB_RECIPE_EXECUTE_TASK_TOOL_NAME,
|
||||
};
|
||||
use crate::agents::sub_recipe_execution_tool::tasks_manager::TasksManager;
|
||||
use crate::agents::sub_recipe_manager::SubRecipeManager;
|
||||
use crate::config::{Config, ExtensionConfigManager, PermissionManager};
|
||||
use crate::message::{push_message, Message};
|
||||
@@ -63,6 +64,7 @@ pub struct Agent {
|
||||
pub(super) provider: Mutex<Option<Arc<dyn Provider>>>,
|
||||
pub(super) extension_manager: RwLock<ExtensionManager>,
|
||||
pub(super) sub_recipe_manager: Mutex<SubRecipeManager>,
|
||||
pub(super) tasks_manager: TasksManager,
|
||||
pub(super) final_output_tool: Mutex<Option<FinalOutputTool>>,
|
||||
pub(super) frontend_tools: Mutex<HashMap<String, FrontendTool>>,
|
||||
pub(super) frontend_instructions: Mutex<Option<String>>,
|
||||
@@ -137,6 +139,7 @@ impl Agent {
|
||||
provider: Mutex::new(None),
|
||||
extension_manager: RwLock::new(ExtensionManager::new()),
|
||||
sub_recipe_manager: Mutex::new(SubRecipeManager::new()),
|
||||
tasks_manager: TasksManager::new(),
|
||||
final_output_tool: Mutex::new(None),
|
||||
frontend_tools: Mutex::new(HashMap::new()),
|
||||
frontend_instructions: Mutex::new(None),
|
||||
@@ -291,10 +294,18 @@ impl Agent {
|
||||
let sub_recipe_manager = self.sub_recipe_manager.lock().await;
|
||||
let result: ToolCallResult = if sub_recipe_manager.is_sub_recipe_tool(&tool_call.name) {
|
||||
sub_recipe_manager
|
||||
.dispatch_sub_recipe_tool_call(&tool_call.name, tool_call.arguments.clone())
|
||||
.dispatch_sub_recipe_tool_call(
|
||||
&tool_call.name,
|
||||
tool_call.arguments.clone(),
|
||||
&self.tasks_manager,
|
||||
)
|
||||
.await
|
||||
} else if tool_call.name == SUB_RECIPE_EXECUTE_TASK_TOOL_NAME {
|
||||
sub_recipe_execute_task_tool::run_tasks(tool_call.arguments.clone()).await
|
||||
sub_recipe_execute_task_tool::run_tasks(
|
||||
tool_call.arguments.clone(),
|
||||
&self.tasks_manager,
|
||||
)
|
||||
.await
|
||||
} else if tool_call.name == PLATFORM_READ_RESOURCE_TOOL_NAME {
|
||||
// Check if the tool is read_resource and handle it separately
|
||||
ToolCallResult::from(
|
||||
|
||||
@@ -1 +1,2 @@
|
||||
pub mod param_utils;
|
||||
pub mod sub_recipe_tools;
|
||||
|
||||
38
crates/goose/src/agents/recipe_tools/param_utils/mod.rs
Normal file
38
crates/goose/src/agents/recipe_tools/param_utils/mod.rs
Normal file
@@ -0,0 +1,38 @@
|
||||
use anyhow::Result;
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::recipe::SubRecipe;
|
||||
|
||||
pub fn prepare_command_params(
|
||||
sub_recipe: &SubRecipe,
|
||||
params_from_tool_call: Vec<Value>,
|
||||
) -> Result<Vec<HashMap<String, String>>> {
|
||||
let base_params = sub_recipe.values.clone().unwrap_or_default();
|
||||
|
||||
if params_from_tool_call.is_empty() {
|
||||
return Ok(vec![base_params]);
|
||||
}
|
||||
|
||||
let result = params_from_tool_call
|
||||
.into_iter()
|
||||
.map(|tool_param| {
|
||||
let mut param_map = base_params.clone();
|
||||
if let Some(param_obj) = tool_param.as_object() {
|
||||
for (key, value) in param_obj {
|
||||
let value_str = value
|
||||
.as_str()
|
||||
.map(String::from)
|
||||
.unwrap_or_else(|| value.to_string());
|
||||
param_map.entry(key.clone()).or_insert(value_str);
|
||||
}
|
||||
}
|
||||
param_map
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
140
crates/goose/src/agents/recipe_tools/param_utils/tests.rs
Normal file
140
crates/goose/src/agents/recipe_tools/param_utils/tests.rs
Normal file
@@ -0,0 +1,140 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::recipe::SubRecipe;
|
||||
use serde_json::json;
|
||||
|
||||
use crate::agents::recipe_tools::param_utils::prepare_command_params;
|
||||
|
||||
fn setup_default_sub_recipe() -> SubRecipe {
|
||||
let sub_recipe = SubRecipe {
|
||||
name: "test_sub_recipe".to_string(),
|
||||
path: "test_sub_recipe.yaml".to_string(),
|
||||
values: Some(HashMap::from([("key1".to_string(), "value1".to_string())])),
|
||||
sequential_when_repeated: true,
|
||||
};
|
||||
sub_recipe
|
||||
}
|
||||
|
||||
mod prepare_command_params_tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_return_command_param() {
|
||||
let parameter_array = vec![json!(HashMap::from([(
|
||||
"key2".to_string(),
|
||||
"value2".to_string()
|
||||
)]))];
|
||||
let mut sub_recipe = setup_default_sub_recipe();
|
||||
sub_recipe.values = Some(HashMap::from([("key1".to_string(), "value1".to_string())]));
|
||||
|
||||
let result = prepare_command_params(&sub_recipe, parameter_array).unwrap();
|
||||
assert_eq!(
|
||||
vec![HashMap::from([
|
||||
("key1".to_string(), "value1".to_string()),
|
||||
("key2".to_string(), "value2".to_string())
|
||||
]),],
|
||||
result
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_return_command_param_when_value_override_passed_param_value() {
|
||||
let parameter_array = vec![json!(HashMap::from([(
|
||||
"key2".to_string(),
|
||||
"different_value".to_string()
|
||||
)]))];
|
||||
let mut sub_recipe = setup_default_sub_recipe();
|
||||
sub_recipe.values = Some(HashMap::from([
|
||||
("key1".to_string(), "value1".to_string()),
|
||||
("key2".to_string(), "value2".to_string()),
|
||||
]));
|
||||
|
||||
let result = prepare_command_params(&sub_recipe, parameter_array).unwrap();
|
||||
assert_eq!(
|
||||
vec![HashMap::from([
|
||||
("key1".to_string(), "value1".to_string()),
|
||||
("key2".to_string(), "value2".to_string())
|
||||
]),],
|
||||
result
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_return_empty_command_param() {
|
||||
let parameter_array = vec![];
|
||||
let mut sub_recipe = setup_default_sub_recipe();
|
||||
sub_recipe.values = None;
|
||||
|
||||
let result = prepare_command_params(&sub_recipe, parameter_array).unwrap();
|
||||
assert_eq!(result, vec![HashMap::new()]);
|
||||
}
|
||||
|
||||
mod multiple_tool_parameters {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_return_command_param_when_all_values_from_tool_call_parameters() {
|
||||
let parameter_array = vec![
|
||||
json!(HashMap::from([
|
||||
("key1".to_string(), "key1_value1".to_string()),
|
||||
("key2".to_string(), "key2_value1".to_string())
|
||||
])),
|
||||
json!(HashMap::from([
|
||||
("key1".to_string(), "key1_value2".to_string()),
|
||||
("key2".to_string(), "key2_value2".to_string())
|
||||
])),
|
||||
];
|
||||
let mut sub_recipe = setup_default_sub_recipe();
|
||||
sub_recipe.values = None;
|
||||
|
||||
let result = prepare_command_params(&sub_recipe, parameter_array).unwrap();
|
||||
assert_eq!(
|
||||
vec![
|
||||
HashMap::from([
|
||||
("key1".to_string(), "key1_value1".to_string()),
|
||||
("key2".to_string(), "key2_value1".to_string()),
|
||||
]),
|
||||
HashMap::from([
|
||||
("key1".to_string(), "key1_value2".to_string()),
|
||||
("key2".to_string(), "key2_value2".to_string()),
|
||||
]),
|
||||
],
|
||||
result
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_merge_base_values_with_tool_parameters() {
|
||||
let parameter_array = vec![
|
||||
json!(HashMap::from([(
|
||||
"key2".to_string(),
|
||||
"override_value1".to_string()
|
||||
)])),
|
||||
json!(HashMap::from([(
|
||||
"key2".to_string(),
|
||||
"override_value2".to_string()
|
||||
)])),
|
||||
];
|
||||
let mut sub_recipe = setup_default_sub_recipe();
|
||||
sub_recipe.values = Some(HashMap::from([
|
||||
("key1".to_string(), "base_value".to_string()),
|
||||
("key2".to_string(), "original_value".to_string()),
|
||||
]));
|
||||
|
||||
let result = prepare_command_params(&sub_recipe, parameter_array).unwrap();
|
||||
assert_eq!(
|
||||
vec![
|
||||
HashMap::from([
|
||||
("key1".to_string(), "base_value".to_string()),
|
||||
("key2".to_string(), "original_value".to_string()),
|
||||
]),
|
||||
HashMap::from([
|
||||
("key1".to_string(), "base_value".to_string()),
|
||||
("key2".to_string(), "original_value".to_string()),
|
||||
]),
|
||||
],
|
||||
result
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,22 +1,35 @@
|
||||
use std::{collections::HashMap, fs};
|
||||
use std::collections::HashSet;
|
||||
use std::fs;
|
||||
|
||||
use anyhow::Result;
|
||||
use mcp_core::tool::{Tool, ToolAnnotations};
|
||||
use serde_json::{json, Map, Value};
|
||||
|
||||
use crate::agents::sub_recipe_execution_tool::lib::Task;
|
||||
use crate::agents::sub_recipe_execution_tool::lib::{ExecutionMode, Task};
|
||||
use crate::agents::sub_recipe_execution_tool::tasks_manager::TasksManager;
|
||||
use crate::recipe::{Recipe, RecipeParameter, RecipeParameterRequirement, SubRecipe};
|
||||
|
||||
use super::param_utils::prepare_command_params;
|
||||
|
||||
pub const SUB_RECIPE_TASK_TOOL_NAME_PREFIX: &str = "subrecipe__create_task";
|
||||
|
||||
pub fn create_sub_recipe_task_tool(sub_recipe: &SubRecipe) -> Tool {
|
||||
let input_schema = get_input_schema(sub_recipe).unwrap();
|
||||
Tool::new(
|
||||
format!("{}_{}", SUB_RECIPE_TASK_TOOL_NAME_PREFIX, sub_recipe.name),
|
||||
"Before running this sub recipe, you should first create a task with this tool and then pass the task to the task executor".to_string(),
|
||||
format!(
|
||||
"Create one or more tasks to run the '{}' sub recipe. \
|
||||
Provide an array of parameter sets in the 'task_parameters' field:\n\
|
||||
- For a single task: provide an array with one parameter set\n\
|
||||
- For multiple tasks: provide an array with multiple parameter sets, each with different values\n\n\
|
||||
Each task will run the same sub recipe but with different parameter values. \
|
||||
This is useful when you need to execute the same sub recipe multiple times with varying inputs. \
|
||||
After creating the tasks and execution_mode is provided, pass them to the task executor to run these tasks",
|
||||
sub_recipe.name
|
||||
),
|
||||
input_schema,
|
||||
Some(ToolAnnotations {
|
||||
title: Some(format!("create sub recipe task {}", sub_recipe.name)),
|
||||
title: Some(format!("create multiple sub recipe tasks for {}", sub_recipe.name)),
|
||||
read_only_hint: false,
|
||||
destructive_hint: true,
|
||||
idempotent_hint: false,
|
||||
@@ -25,6 +38,64 @@ pub fn create_sub_recipe_task_tool(sub_recipe: &SubRecipe) -> Tool {
|
||||
)
|
||||
}
|
||||
|
||||
fn extract_task_parameters(params: &Value) -> Vec<Value> {
|
||||
params
|
||||
.get("task_parameters")
|
||||
.and_then(|v| v.as_array())
|
||||
.cloned()
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
fn create_tasks_from_params(
|
||||
sub_recipe: &SubRecipe,
|
||||
command_params: &[std::collections::HashMap<String, String>],
|
||||
) -> Vec<Task> {
|
||||
let tasks: Vec<Task> = command_params
|
||||
.iter()
|
||||
.map(|task_command_param| {
|
||||
let payload = json!({
|
||||
"sub_recipe": {
|
||||
"name": sub_recipe.name.clone(),
|
||||
"command_parameters": task_command_param,
|
||||
"recipe_path": sub_recipe.path.clone(),
|
||||
"sequential_when_repeated": sub_recipe.sequential_when_repeated
|
||||
}
|
||||
});
|
||||
Task {
|
||||
id: uuid::Uuid::new_v4().to_string(),
|
||||
task_type: "sub_recipe".to_string(),
|
||||
payload,
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
tasks
|
||||
}
|
||||
|
||||
fn create_task_execution_payload(tasks: &[Task], sub_recipe: &SubRecipe) -> Value {
|
||||
let task_ids: Vec<String> = tasks.iter().map(|task| task.id.clone()).collect();
|
||||
json!({
|
||||
"task_ids": task_ids,
|
||||
"execution_mode": if sub_recipe.sequential_when_repeated { ExecutionMode::Sequential } else { ExecutionMode::Parallel },
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn create_sub_recipe_task(
|
||||
sub_recipe: &SubRecipe,
|
||||
params: Value,
|
||||
tasks_manager: &TasksManager,
|
||||
) -> Result<String> {
|
||||
let task_params_array = extract_task_parameters(¶ms);
|
||||
let command_params = prepare_command_params(sub_recipe, task_params_array.clone())?;
|
||||
let tasks = create_tasks_from_params(sub_recipe, &command_params);
|
||||
let task_execution_payload = create_task_execution_payload(&tasks, sub_recipe);
|
||||
|
||||
let tasks_json = serde_json::to_string(&task_execution_payload)
|
||||
.map_err(|e| anyhow::anyhow!("Failed to serialize task list: {}", e))?;
|
||||
tasks_manager.save_tasks(tasks.clone()).await;
|
||||
Ok(tasks_json)
|
||||
}
|
||||
|
||||
fn get_sub_recipe_parameter_definition(
|
||||
sub_recipe: &SubRecipe,
|
||||
) -> Result<Option<Vec<RecipeParameter>>> {
|
||||
@@ -34,22 +105,55 @@ fn get_sub_recipe_parameter_definition(
|
||||
Ok(recipe.parameters)
|
||||
}
|
||||
|
||||
fn get_input_schema(sub_recipe: &SubRecipe) -> Result<Value> {
|
||||
let mut sub_recipe_params_map = HashMap::<String, String>::new();
|
||||
fn get_params_with_values(sub_recipe: &SubRecipe) -> HashSet<String> {
|
||||
let mut sub_recipe_params_with_values = HashSet::<String>::new();
|
||||
if let Some(params_with_value) = &sub_recipe.values {
|
||||
for (param_name, param_value) in params_with_value {
|
||||
sub_recipe_params_map.insert(param_name.clone(), param_value.clone());
|
||||
for param_name in params_with_value.keys() {
|
||||
sub_recipe_params_with_values.insert(param_name.clone());
|
||||
}
|
||||
}
|
||||
let parameter_definition = get_sub_recipe_parameter_definition(sub_recipe)?;
|
||||
if let Some(parameters) = parameter_definition {
|
||||
sub_recipe_params_with_values
|
||||
}
|
||||
|
||||
fn create_input_schema(param_properties: Map<String, Value>, param_required: Vec<String>) -> Value {
|
||||
let mut properties = Map::new();
|
||||
let mut required = Vec::new();
|
||||
if !param_properties.is_empty() {
|
||||
properties.insert(
|
||||
"task_parameters".to_string(),
|
||||
json!({
|
||||
"type": "array",
|
||||
"description": "Array of parameter sets for creating tasks. \
|
||||
For a single task, provide an array with one element. \
|
||||
For multiple tasks, provide an array with multiple elements, each with different parameter values. \
|
||||
If there is no parameter set, provide an empty array.",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": param_properties,
|
||||
"required": param_required
|
||||
},
|
||||
})
|
||||
);
|
||||
}
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
})
|
||||
}
|
||||
|
||||
fn get_input_schema(sub_recipe: &SubRecipe) -> Result<Value> {
|
||||
let sub_recipe_params_with_values = get_params_with_values(sub_recipe);
|
||||
|
||||
let parameter_definition = get_sub_recipe_parameter_definition(sub_recipe)?;
|
||||
|
||||
let mut param_properties = Map::new();
|
||||
let mut param_required = Vec::new();
|
||||
|
||||
if let Some(parameters) = parameter_definition {
|
||||
for param in parameters {
|
||||
if sub_recipe_params_map.contains_key(¶m.key) {
|
||||
if sub_recipe_params_with_values.contains(¶m.key.clone()) {
|
||||
continue;
|
||||
}
|
||||
properties.insert(
|
||||
param_properties.insert(
|
||||
param.key.clone(),
|
||||
json!({
|
||||
"type": param.input_type.to_string(),
|
||||
@@ -57,60 +161,11 @@ fn get_input_schema(sub_recipe: &SubRecipe) -> Result<Value> {
|
||||
}),
|
||||
);
|
||||
if !matches!(param.requirement, RecipeParameterRequirement::Optional) {
|
||||
required.push(param.key);
|
||||
param_required.push(param.key);
|
||||
}
|
||||
}
|
||||
Ok(json!({
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
"required": required
|
||||
}))
|
||||
} else {
|
||||
Ok(json!({
|
||||
"type": "object",
|
||||
"properties": {}
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
fn prepare_command_params(
|
||||
sub_recipe: &SubRecipe,
|
||||
params_from_tool_call: Value,
|
||||
) -> Result<HashMap<String, String>> {
|
||||
let mut sub_recipe_params = HashMap::<String, String>::new();
|
||||
if let Some(params_with_value) = &sub_recipe.values {
|
||||
for (param_name, param_value) in params_with_value {
|
||||
sub_recipe_params.insert(param_name.clone(), param_value.clone());
|
||||
}
|
||||
}
|
||||
if let Some(params_map) = params_from_tool_call.as_object() {
|
||||
for (key, value) in params_map {
|
||||
sub_recipe_params.insert(
|
||||
key.to_string(),
|
||||
value.as_str().unwrap_or(&value.to_string()).to_string(),
|
||||
);
|
||||
}
|
||||
}
|
||||
Ok(sub_recipe_params)
|
||||
}
|
||||
|
||||
pub async fn create_sub_recipe_task(sub_recipe: &SubRecipe, params: Value) -> Result<String> {
|
||||
let command_params = prepare_command_params(sub_recipe, params)?;
|
||||
let payload = json!({
|
||||
"sub_recipe": {
|
||||
"name": sub_recipe.name.clone(),
|
||||
"command_parameters": command_params,
|
||||
"recipe_path": sub_recipe.path.clone(),
|
||||
}
|
||||
});
|
||||
let task = Task {
|
||||
id: uuid::Uuid::new_v4().to_string(),
|
||||
task_type: "sub_recipe".to_string(),
|
||||
payload,
|
||||
};
|
||||
let task_json = serde_json::to_string(&task)
|
||||
.map_err(|e| anyhow::anyhow!("Failed to serialize Task: {}", e))?;
|
||||
Ok(task_json)
|
||||
Ok(create_input_schema(param_properties, param_required))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -3,66 +3,48 @@ mod tests {
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::recipe::SubRecipe;
|
||||
use serde_json::json;
|
||||
use serde_json::Value;
|
||||
use tempfile::TempDir;
|
||||
|
||||
fn setup_sub_recipe() -> SubRecipe {
|
||||
fn setup_default_sub_recipe() -> SubRecipe {
|
||||
let sub_recipe = SubRecipe {
|
||||
name: "test_sub_recipe".to_string(),
|
||||
path: "test_sub_recipe.yaml".to_string(),
|
||||
values: Some(HashMap::from([("key1".to_string(), "value1".to_string())])),
|
||||
sequential_when_repeated: true,
|
||||
};
|
||||
sub_recipe
|
||||
}
|
||||
mod prepare_command_params_tests {
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::{
|
||||
agents::recipe_tools::sub_recipe_tools::{
|
||||
prepare_command_params, tests::tests::setup_sub_recipe,
|
||||
},
|
||||
recipe::SubRecipe,
|
||||
};
|
||||
mod get_input_schema {
|
||||
use super::*;
|
||||
use crate::agents::recipe_tools::sub_recipe_tools::get_input_schema;
|
||||
|
||||
#[test]
|
||||
fn test_prepare_command_params_basic() {
|
||||
let mut params = HashMap::new();
|
||||
params.insert("key2".to_string(), "value2".to_string());
|
||||
|
||||
let sub_recipe = setup_sub_recipe();
|
||||
|
||||
let params_value = serde_json::to_value(params).unwrap();
|
||||
let result = prepare_command_params(&sub_recipe, params_value).unwrap();
|
||||
assert_eq!(result.len(), 2);
|
||||
assert_eq!(result.get("key1"), Some(&"value1".to_string()));
|
||||
assert_eq!(result.get("key2"), Some(&"value2".to_string()));
|
||||
fn prepare_sub_recipe(sub_recipe_file_content: &str) -> (SubRecipe, TempDir) {
|
||||
let mut sub_recipe = setup_default_sub_recipe();
|
||||
let temp_dir = tempfile::tempdir().unwrap();
|
||||
let temp_file = temp_dir.path().join(sub_recipe.path.clone());
|
||||
std::fs::write(&temp_file, sub_recipe_file_content).unwrap();
|
||||
sub_recipe.path = temp_file.to_string_lossy().to_string();
|
||||
(sub_recipe, temp_dir)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_prepare_command_params_empty() {
|
||||
let sub_recipe = SubRecipe {
|
||||
name: "test_sub_recipe".to_string(),
|
||||
path: "test_sub_recipe.yaml".to_string(),
|
||||
values: None,
|
||||
};
|
||||
let params: HashMap<String, String> = HashMap::new();
|
||||
let params_value = serde_json::to_value(params).unwrap();
|
||||
let result = prepare_command_params(&sub_recipe, params_value).unwrap();
|
||||
assert_eq!(result.len(), 0);
|
||||
}
|
||||
fn verify_task_parameters(result: Value, expected_task_parameters_items: Value) {
|
||||
let task_parameters = result
|
||||
.get("properties")
|
||||
.unwrap()
|
||||
.as_object()
|
||||
.unwrap()
|
||||
.get("task_parameters")
|
||||
.unwrap()
|
||||
.as_object()
|
||||
.unwrap();
|
||||
let task_parameters_items = task_parameters.get("items").unwrap();
|
||||
assert_eq!(&expected_task_parameters_items, task_parameters_items);
|
||||
}
|
||||
|
||||
mod get_input_schema_tests {
|
||||
use crate::{
|
||||
agents::recipe_tools::sub_recipe_tools::{
|
||||
get_input_schema, tests::tests::setup_sub_recipe,
|
||||
},
|
||||
recipe::SubRecipe,
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn test_get_input_schema_with_parameters() {
|
||||
let sub_recipe = setup_sub_recipe();
|
||||
|
||||
let sub_recipe_file_content = r#"{
|
||||
const SUB_RECIPE_FILE_CONTENT_WITH_TWO_PARAMS: &str = r#"{
|
||||
"version": "1.0.0",
|
||||
"title": "Test Recipe",
|
||||
"description": "A test recipe",
|
||||
@@ -83,73 +65,67 @@ mod tests {
|
||||
]
|
||||
}"#;
|
||||
|
||||
let temp_dir = tempfile::tempdir().unwrap();
|
||||
let temp_file = temp_dir.path().join("test_sub_recipe.yaml");
|
||||
std::fs::write(&temp_file, sub_recipe_file_content).unwrap();
|
||||
|
||||
let mut sub_recipe = sub_recipe;
|
||||
sub_recipe.path = temp_file.to_string_lossy().to_string();
|
||||
#[test]
|
||||
fn test_with_one_param_in_tool_input() {
|
||||
let (mut sub_recipe, _temp_dir) =
|
||||
prepare_sub_recipe(SUB_RECIPE_FILE_CONTENT_WITH_TWO_PARAMS);
|
||||
sub_recipe.values = Some(HashMap::from([("key1".to_string(), "value1".to_string())]));
|
||||
|
||||
let result = get_input_schema(&sub_recipe).unwrap();
|
||||
|
||||
// Verify the schema structure
|
||||
assert_eq!(result["type"], "object");
|
||||
assert!(result["properties"].is_object());
|
||||
|
||||
let properties = result["properties"].as_object().unwrap();
|
||||
assert_eq!(properties.len(), 1);
|
||||
|
||||
let key2_prop = &properties["key2"];
|
||||
assert_eq!(key2_prop["type"], "number");
|
||||
assert_eq!(key2_prop["description"], "An optional parameter");
|
||||
|
||||
let required = result["required"].as_array().unwrap();
|
||||
assert_eq!(required.len(), 0);
|
||||
verify_task_parameters(
|
||||
result,
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"key2": { "type": "number", "description": "An optional parameter" }
|
||||
},
|
||||
"required": []
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_get_input_schema_no_parameters_values() {
|
||||
let sub_recipe = SubRecipe {
|
||||
name: "test_sub_recipe".to_string(),
|
||||
path: "test_sub_recipe.yaml".to_string(),
|
||||
values: None,
|
||||
};
|
||||
|
||||
let sub_recipe_file_content = r#"{
|
||||
"version": "1.0.0",
|
||||
"title": "Test Recipe",
|
||||
"description": "A test recipe",
|
||||
"prompt": "Test prompt",
|
||||
"parameters": [
|
||||
{
|
||||
"key": "key1",
|
||||
"input_type": "string",
|
||||
"requirement": "required",
|
||||
"description": "A test parameter"
|
||||
}
|
||||
]
|
||||
}"#;
|
||||
|
||||
let temp_dir = tempfile::tempdir().unwrap();
|
||||
let temp_file = temp_dir.path().join("test_sub_recipe.yaml");
|
||||
std::fs::write(&temp_file, sub_recipe_file_content).unwrap();
|
||||
|
||||
let mut sub_recipe = sub_recipe;
|
||||
sub_recipe.path = temp_file.to_string_lossy().to_string();
|
||||
fn test_without_param_in_tool_input() {
|
||||
let (mut sub_recipe, _temp_dir) =
|
||||
prepare_sub_recipe(SUB_RECIPE_FILE_CONTENT_WITH_TWO_PARAMS);
|
||||
sub_recipe.values = Some(HashMap::from([
|
||||
("key1".to_string(), "value1".to_string()),
|
||||
("key2".to_string(), "value2".to_string()),
|
||||
]));
|
||||
|
||||
let result = get_input_schema(&sub_recipe).unwrap();
|
||||
|
||||
assert_eq!(result["type"], "object");
|
||||
assert!(result["properties"].is_object());
|
||||
assert_eq!(
|
||||
None,
|
||||
result
|
||||
.get("properties")
|
||||
.unwrap()
|
||||
.as_object()
|
||||
.unwrap()
|
||||
.get("task_parameters")
|
||||
);
|
||||
}
|
||||
|
||||
let properties = result["properties"].as_object().unwrap();
|
||||
assert_eq!(properties.len(), 1);
|
||||
#[test]
|
||||
fn test_with_all_params_in_tool_input() {
|
||||
let (mut sub_recipe, _temp_dir) =
|
||||
prepare_sub_recipe(SUB_RECIPE_FILE_CONTENT_WITH_TWO_PARAMS);
|
||||
sub_recipe.values = None;
|
||||
|
||||
let key1_prop = &properties["key1"];
|
||||
assert_eq!(key1_prop["type"], "string");
|
||||
assert_eq!(key1_prop["description"], "A test parameter");
|
||||
assert_eq!(result["required"].as_array().unwrap().len(), 1);
|
||||
assert_eq!(result["required"][0], "key1");
|
||||
let result = get_input_schema(&sub_recipe).unwrap();
|
||||
|
||||
verify_task_parameters(
|
||||
result,
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"key1": { "type": "string", "description": "A test parameter" },
|
||||
"key2": { "type": "number", "description": "An optional parameter" }
|
||||
},
|
||||
"required": ["key1"]
|
||||
}),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,103 +0,0 @@
|
||||
use std::sync::atomic::{AtomicBool, AtomicUsize};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::time::Instant;
|
||||
|
||||
use crate::agents::sub_recipe_execution_tool::lib::{
|
||||
Config, ExecutionResponse, ExecutionStats, Task, TaskResult,
|
||||
};
|
||||
use crate::agents::sub_recipe_execution_tool::tasks::process_task;
|
||||
use crate::agents::sub_recipe_execution_tool::workers::{run_scaler, spawn_worker, SharedState};
|
||||
|
||||
pub async fn execute_single_task(task: &Task, config: Config) -> ExecutionResponse {
|
||||
let start_time = Instant::now();
|
||||
let result = process_task(task, config.timeout_seconds).await;
|
||||
|
||||
let execution_time = start_time.elapsed().as_millis();
|
||||
let completed = if result.status == "success" { 1 } else { 0 };
|
||||
let failed = if result.status == "failed" { 1 } else { 0 };
|
||||
|
||||
ExecutionResponse {
|
||||
status: "completed".to_string(),
|
||||
results: vec![result],
|
||||
stats: ExecutionStats {
|
||||
total_tasks: 1,
|
||||
completed,
|
||||
failed,
|
||||
execution_time_ms: execution_time,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Main parallel execution function
|
||||
pub async fn parallel_execute(tasks: Vec<Task>, config: Config) -> ExecutionResponse {
|
||||
let start_time = Instant::now();
|
||||
let task_count = tasks.len();
|
||||
|
||||
// Create channels
|
||||
let (task_tx, task_rx) = mpsc::channel::<Task>(task_count);
|
||||
let (result_tx, mut result_rx) = mpsc::channel::<TaskResult>(task_count);
|
||||
|
||||
// Initialize shared state
|
||||
let shared_state = Arc::new(SharedState {
|
||||
task_receiver: Arc::new(tokio::sync::Mutex::new(task_rx)),
|
||||
result_sender: result_tx,
|
||||
active_workers: Arc::new(AtomicUsize::new(0)),
|
||||
should_stop: Arc::new(AtomicBool::new(false)),
|
||||
completed_tasks: Arc::new(AtomicUsize::new(0)),
|
||||
});
|
||||
|
||||
// Send all tasks to the queue
|
||||
for task in tasks.clone() {
|
||||
let _ = task_tx.send(task).await;
|
||||
}
|
||||
// Close sender so workers know when queue is empty
|
||||
drop(task_tx);
|
||||
|
||||
// Start initial workers
|
||||
let mut worker_handles = Vec::new();
|
||||
for i in 0..config.initial_workers {
|
||||
let handle = spawn_worker(shared_state.clone(), i, config.timeout_seconds);
|
||||
worker_handles.push(handle);
|
||||
}
|
||||
|
||||
// Start the scaler
|
||||
let scaler_state = shared_state.clone();
|
||||
let scaler_handle = tokio::spawn(async move {
|
||||
run_scaler(
|
||||
scaler_state,
|
||||
task_count,
|
||||
config.max_workers,
|
||||
config.timeout_seconds,
|
||||
)
|
||||
.await;
|
||||
});
|
||||
|
||||
// Collect results
|
||||
let mut results = Vec::new();
|
||||
while let Some(result) = result_rx.recv().await {
|
||||
results.push(result);
|
||||
if results.len() >= task_count {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for scaler to finish
|
||||
let _ = scaler_handle.await;
|
||||
|
||||
// Calculate stats
|
||||
let execution_time = start_time.elapsed().as_millis();
|
||||
let completed = results.iter().filter(|r| r.status == "success").count();
|
||||
let failed = results.iter().filter(|r| r.status == "failed").count();
|
||||
|
||||
ExecutionResponse {
|
||||
status: "completed".to_string(),
|
||||
results,
|
||||
stats: ExecutionStats {
|
||||
total_tasks: task_count,
|
||||
completed,
|
||||
failed,
|
||||
execution_time_ms: execution_time,
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,197 @@
|
||||
use mcp_core::protocol::JsonRpcMessage;
|
||||
use std::sync::atomic::AtomicUsize;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::time::Instant;
|
||||
|
||||
use crate::agents::sub_recipe_execution_tool::lib::{
|
||||
ExecutionResponse, ExecutionStats, SharedState, Task, TaskResult, TaskStatus,
|
||||
};
|
||||
use crate::agents::sub_recipe_execution_tool::task_execution_tracker::{
|
||||
DisplayMode, TaskExecutionTracker,
|
||||
};
|
||||
use crate::agents::sub_recipe_execution_tool::tasks::process_task;
|
||||
use crate::agents::sub_recipe_execution_tool::workers::spawn_worker;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
const EXECUTION_STATUS_COMPLETED: &str = "completed";
|
||||
const DEFAULT_MAX_WORKERS: usize = 10;
|
||||
|
||||
pub async fn execute_single_task(
|
||||
task: &Task,
|
||||
notifier: mpsc::Sender<JsonRpcMessage>,
|
||||
) -> ExecutionResponse {
|
||||
let start_time = Instant::now();
|
||||
let task_execution_tracker = Arc::new(TaskExecutionTracker::new(
|
||||
vec![task.clone()],
|
||||
DisplayMode::SingleTaskOutput,
|
||||
notifier,
|
||||
));
|
||||
let result = process_task(task, task_execution_tracker).await;
|
||||
let execution_time = start_time.elapsed().as_millis();
|
||||
let stats = calculate_stats(&[result.clone()], execution_time);
|
||||
|
||||
ExecutionResponse {
|
||||
status: EXECUTION_STATUS_COMPLETED.to_string(),
|
||||
results: vec![result],
|
||||
stats,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn execute_tasks_in_parallel(
|
||||
tasks: Vec<Task>,
|
||||
notifier: mpsc::Sender<JsonRpcMessage>,
|
||||
) -> ExecutionResponse {
|
||||
let task_execution_tracker = Arc::new(TaskExecutionTracker::new(
|
||||
tasks.clone(),
|
||||
DisplayMode::MultipleTasksOutput,
|
||||
notifier,
|
||||
));
|
||||
let start_time = Instant::now();
|
||||
let task_count = tasks.len();
|
||||
|
||||
if task_count == 0 {
|
||||
return create_empty_response();
|
||||
}
|
||||
|
||||
task_execution_tracker.refresh_display().await;
|
||||
|
||||
let (task_tx, task_rx, result_tx, mut result_rx) = create_channels(task_count);
|
||||
|
||||
if let Err(e) = send_tasks_to_channel(tasks, task_tx).await {
|
||||
tracing::error!("Task execution failed: {}", e);
|
||||
return create_error_response(e);
|
||||
}
|
||||
|
||||
let shared_state = create_shared_state(task_rx, result_tx, task_execution_tracker.clone());
|
||||
|
||||
let worker_count = std::cmp::min(task_count, DEFAULT_MAX_WORKERS);
|
||||
let mut worker_handles = Vec::new();
|
||||
for i in 0..worker_count {
|
||||
let handle = spawn_worker(shared_state.clone(), i);
|
||||
worker_handles.push(handle);
|
||||
}
|
||||
|
||||
let results = collect_results(&mut result_rx, task_execution_tracker.clone(), task_count).await;
|
||||
|
||||
for handle in worker_handles {
|
||||
if let Err(e) = handle.await {
|
||||
tracing::error!("Worker error: {}", e);
|
||||
}
|
||||
}
|
||||
|
||||
task_execution_tracker.send_tasks_complete().await;
|
||||
|
||||
let execution_time = start_time.elapsed().as_millis();
|
||||
let stats = calculate_stats(&results, execution_time);
|
||||
|
||||
ExecutionResponse {
|
||||
status: EXECUTION_STATUS_COMPLETED.to_string(),
|
||||
results,
|
||||
stats,
|
||||
}
|
||||
}
|
||||
|
||||
fn calculate_stats(results: &[TaskResult], execution_time_ms: u128) -> ExecutionStats {
|
||||
let completed = results
|
||||
.iter()
|
||||
.filter(|r| matches!(r.status, TaskStatus::Completed))
|
||||
.count();
|
||||
let failed = results
|
||||
.iter()
|
||||
.filter(|r| matches!(r.status, TaskStatus::Failed))
|
||||
.count();
|
||||
|
||||
ExecutionStats {
|
||||
total_tasks: results.len(),
|
||||
completed,
|
||||
failed,
|
||||
execution_time_ms,
|
||||
}
|
||||
}
|
||||
|
||||
fn create_channels(
|
||||
task_count: usize,
|
||||
) -> (
|
||||
mpsc::Sender<Task>,
|
||||
mpsc::Receiver<Task>,
|
||||
mpsc::Sender<TaskResult>,
|
||||
mpsc::Receiver<TaskResult>,
|
||||
) {
|
||||
let (task_tx, task_rx) = mpsc::channel::<Task>(task_count);
|
||||
let (result_tx, result_rx) = mpsc::channel::<TaskResult>(task_count);
|
||||
(task_tx, task_rx, result_tx, result_rx)
|
||||
}
|
||||
|
||||
fn create_shared_state(
|
||||
task_rx: mpsc::Receiver<Task>,
|
||||
result_tx: mpsc::Sender<TaskResult>,
|
||||
task_execution_tracker: Arc<TaskExecutionTracker>,
|
||||
) -> Arc<SharedState> {
|
||||
Arc::new(SharedState {
|
||||
task_receiver: Arc::new(tokio::sync::Mutex::new(task_rx)),
|
||||
result_sender: result_tx,
|
||||
active_workers: Arc::new(AtomicUsize::new(0)),
|
||||
task_execution_tracker,
|
||||
})
|
||||
}
|
||||
|
||||
async fn send_tasks_to_channel(
|
||||
tasks: Vec<Task>,
|
||||
task_tx: mpsc::Sender<Task>,
|
||||
) -> Result<(), String> {
|
||||
for task in tasks {
|
||||
task_tx
|
||||
.send(task)
|
||||
.await
|
||||
.map_err(|e| format!("Failed to queue task: {}", e))?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn create_empty_response() -> ExecutionResponse {
|
||||
ExecutionResponse {
|
||||
status: EXECUTION_STATUS_COMPLETED.to_string(),
|
||||
results: vec![],
|
||||
stats: ExecutionStats {
|
||||
total_tasks: 0,
|
||||
completed: 0,
|
||||
failed: 0,
|
||||
execution_time_ms: 0,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
async fn collect_results(
|
||||
result_rx: &mut mpsc::Receiver<TaskResult>,
|
||||
task_execution_tracker: Arc<TaskExecutionTracker>,
|
||||
expected_count: usize,
|
||||
) -> Vec<TaskResult> {
|
||||
let mut results = Vec::new();
|
||||
while let Some(result) = result_rx.recv().await {
|
||||
task_execution_tracker
|
||||
.complete_task(&result.task_id, result.clone())
|
||||
.await;
|
||||
results.push(result);
|
||||
if results.len() >= expected_count {
|
||||
break;
|
||||
}
|
||||
}
|
||||
results
|
||||
}
|
||||
|
||||
fn create_error_response(error: String) -> ExecutionResponse {
|
||||
tracing::error!("Creating error response: {}", error);
|
||||
ExecutionResponse {
|
||||
status: "failed".to_string(),
|
||||
results: vec![],
|
||||
stats: ExecutionStats {
|
||||
total_tasks: 0,
|
||||
completed: 0,
|
||||
failed: 1,
|
||||
execution_time_ms: 0,
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,100 @@
|
||||
use super::{calculate_stats, create_empty_response, create_error_response};
|
||||
use crate::agents::sub_recipe_execution_tool::lib::{TaskResult, TaskStatus};
|
||||
use serde_json::json;
|
||||
|
||||
fn create_test_task_result(task_id: &str, status: TaskStatus) -> TaskResult {
|
||||
let is_failed = matches!(status, TaskStatus::Failed);
|
||||
TaskResult {
|
||||
task_id: task_id.to_string(),
|
||||
status,
|
||||
data: Some(json!({"output": "test output"})),
|
||||
error: if is_failed {
|
||||
Some("Test error".to_string())
|
||||
} else {
|
||||
None
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_calculate_stats() {
|
||||
let results = vec![
|
||||
create_test_task_result("task1", TaskStatus::Completed),
|
||||
create_test_task_result("task2", TaskStatus::Completed),
|
||||
create_test_task_result("task3", TaskStatus::Failed),
|
||||
create_test_task_result("task4", TaskStatus::Completed),
|
||||
];
|
||||
|
||||
let stats = calculate_stats(&results, 1500);
|
||||
|
||||
assert_eq!(stats.total_tasks, 4);
|
||||
assert_eq!(stats.completed, 3);
|
||||
assert_eq!(stats.failed, 1);
|
||||
assert_eq!(stats.execution_time_ms, 1500);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_calculate_stats_empty_results() {
|
||||
let results = vec![];
|
||||
let stats = calculate_stats(&results, 0);
|
||||
|
||||
assert_eq!(stats.total_tasks, 0);
|
||||
assert_eq!(stats.completed, 0);
|
||||
assert_eq!(stats.failed, 0);
|
||||
assert_eq!(stats.execution_time_ms, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_calculate_stats_all_completed() {
|
||||
let results = vec![
|
||||
create_test_task_result("task1", TaskStatus::Completed),
|
||||
create_test_task_result("task2", TaskStatus::Completed),
|
||||
];
|
||||
|
||||
let stats = calculate_stats(&results, 800);
|
||||
|
||||
assert_eq!(stats.total_tasks, 2);
|
||||
assert_eq!(stats.completed, 2);
|
||||
assert_eq!(stats.failed, 0);
|
||||
assert_eq!(stats.execution_time_ms, 800);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_calculate_stats_all_failed() {
|
||||
let results = vec![
|
||||
create_test_task_result("task1", TaskStatus::Failed),
|
||||
create_test_task_result("task2", TaskStatus::Failed),
|
||||
];
|
||||
|
||||
let stats = calculate_stats(&results, 1200);
|
||||
|
||||
assert_eq!(stats.total_tasks, 2);
|
||||
assert_eq!(stats.completed, 0);
|
||||
assert_eq!(stats.failed, 2);
|
||||
assert_eq!(stats.execution_time_ms, 1200);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_empty_response() {
|
||||
let response = create_empty_response();
|
||||
|
||||
assert_eq!(response.status, "completed");
|
||||
assert_eq!(response.results.len(), 0);
|
||||
assert_eq!(response.stats.total_tasks, 0);
|
||||
assert_eq!(response.stats.completed, 0);
|
||||
assert_eq!(response.stats.failed, 0);
|
||||
assert_eq!(response.stats.execution_time_ms, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_error_response() {
|
||||
let error_msg = "Test error message";
|
||||
let response = create_error_response(error_msg.to_string());
|
||||
|
||||
assert_eq!(response.status, "failed");
|
||||
assert_eq!(response.results.len(), 0);
|
||||
assert_eq!(response.stats.total_tasks, 0);
|
||||
assert_eq!(response.stats.completed, 0);
|
||||
assert_eq!(response.stats.failed, 1);
|
||||
assert_eq!(response.stats.execution_time_ms, 0);
|
||||
}
|
||||
@@ -1,38 +0,0 @@
|
||||
use crate::agents::sub_recipe_execution_tool::executor::execute_single_task;
|
||||
pub use crate::agents::sub_recipe_execution_tool::executor::parallel_execute;
|
||||
pub use crate::agents::sub_recipe_execution_tool::types::{
|
||||
Config, ExecutionResponse, ExecutionStats, Task, TaskResult,
|
||||
};
|
||||
|
||||
use serde_json::Value;
|
||||
|
||||
pub async fn execute_tasks(input: Value, execution_mode: &str) -> Result<Value, String> {
|
||||
let tasks: Vec<Task> =
|
||||
serde_json::from_value(input.get("tasks").ok_or("Missing tasks field")?.clone())
|
||||
.map_err(|e| format!("Failed to parse tasks: {}", e))?;
|
||||
|
||||
let config: Config = if let Some(config_value) = input.get("config") {
|
||||
serde_json::from_value(config_value.clone())
|
||||
.map_err(|e| format!("Failed to parse config: {}", e))?
|
||||
} else {
|
||||
Config::default()
|
||||
};
|
||||
let task_count = tasks.len();
|
||||
match execution_mode {
|
||||
"sequential" => {
|
||||
if task_count == 1 {
|
||||
let response = execute_single_task(&tasks[0], config).await;
|
||||
serde_json::to_value(response)
|
||||
.map_err(|e| format!("Failed to serialize response: {}", e))
|
||||
} else {
|
||||
Err("Sequential execution mode requires exactly one task".to_string())
|
||||
}
|
||||
}
|
||||
"parallel" => {
|
||||
let response = parallel_execute(tasks, config).await;
|
||||
serde_json::to_value(response)
|
||||
.map_err(|e| format!("Failed to serialize response: {}", e))
|
||||
}
|
||||
_ => Err("Invalid execution mode".to_string()),
|
||||
}
|
||||
}
|
||||
127
crates/goose/src/agents/sub_recipe_execution_tool/lib/mod.rs
Normal file
127
crates/goose/src/agents/sub_recipe_execution_tool/lib/mod.rs
Normal file
@@ -0,0 +1,127 @@
|
||||
use crate::agents::sub_recipe_execution_tool::executor::{
|
||||
execute_single_task, execute_tasks_in_parallel,
|
||||
};
|
||||
pub use crate::agents::sub_recipe_execution_tool::task_types::{
|
||||
ExecutionMode, ExecutionResponse, ExecutionStats, SharedState, Task, TaskResult, TaskStatus,
|
||||
};
|
||||
use crate::agents::sub_recipe_execution_tool::tasks_manager::TasksManager;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
use mcp_core::protocol::JsonRpcMessage;
|
||||
use serde_json::{json, Value};
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
pub async fn execute_tasks(
|
||||
input: Value,
|
||||
execution_mode: ExecutionMode,
|
||||
notifier: mpsc::Sender<JsonRpcMessage>,
|
||||
tasks_manager: &TasksManager,
|
||||
) -> Result<Value, String> {
|
||||
let task_ids: Vec<String> = serde_json::from_value(
|
||||
input
|
||||
.get("task_ids")
|
||||
.ok_or("Missing task_ids field")?
|
||||
.clone(),
|
||||
)
|
||||
.map_err(|e| format!("Failed to parse task_ids: {}", e))?;
|
||||
|
||||
let mut tasks = Vec::new();
|
||||
for task_id in &task_ids {
|
||||
match tasks_manager.get_task(task_id).await {
|
||||
Some(task) => tasks.push(task),
|
||||
None => {
|
||||
return Err(format!(
|
||||
"Task with ID '{}' not found in TasksManager",
|
||||
task_id
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let task_count = tasks.len();
|
||||
match execution_mode {
|
||||
ExecutionMode::Sequential => {
|
||||
if task_count == 1 {
|
||||
let response = execute_single_task(&tasks[0], notifier).await;
|
||||
handle_response(response)
|
||||
} else {
|
||||
Err("Sequential execution mode requires exactly one task".to_string())
|
||||
}
|
||||
}
|
||||
|
||||
ExecutionMode::Parallel => {
|
||||
if tasks.iter().any(|task| task.get_sequential_when_repeated()) {
|
||||
Ok(json!(
|
||||
{
|
||||
"execution_mode": ExecutionMode::Sequential,
|
||||
"task_ids": task_ids,
|
||||
"results": ["the tasks should be executed sequentially, no matter how user requests it. Please use the subrecipe__execute_task tool to execute the tasks sequentially."]
|
||||
}
|
||||
))
|
||||
} else {
|
||||
let response: ExecutionResponse =
|
||||
execute_tasks_in_parallel(tasks, notifier.clone()).await;
|
||||
handle_response(response)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_failed_tasks(results: &[TaskResult]) -> Vec<String> {
|
||||
results
|
||||
.iter()
|
||||
.filter(|r| matches!(r.status, TaskStatus::Failed))
|
||||
.map(format_failed_task_error)
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn format_failed_task_error(result: &TaskResult) -> String {
|
||||
let error_msg = result.error.as_deref().unwrap_or("Unknown error");
|
||||
let partial_output = result
|
||||
.data
|
||||
.as_ref()
|
||||
.and_then(|d| d.get("partial_output"))
|
||||
.and_then(|v| v.as_str())
|
||||
.filter(|s| !s.trim().is_empty())
|
||||
.unwrap_or("No output captured");
|
||||
|
||||
format!(
|
||||
"Task '{}' ({}): {}\nOutput: {}",
|
||||
result.task_id,
|
||||
get_task_description(result),
|
||||
error_msg,
|
||||
partial_output
|
||||
)
|
||||
}
|
||||
|
||||
fn format_error_summary(
|
||||
failed_count: usize,
|
||||
total_count: usize,
|
||||
failed_tasks: Vec<String>,
|
||||
) -> String {
|
||||
format!(
|
||||
"{}/{} tasks failed:\n{}",
|
||||
failed_count,
|
||||
total_count,
|
||||
failed_tasks.join("\n")
|
||||
)
|
||||
}
|
||||
|
||||
fn handle_response(response: ExecutionResponse) -> Result<Value, String> {
|
||||
if response.stats.failed > 0 {
|
||||
let failed_tasks = extract_failed_tasks(&response.results);
|
||||
let error_summary = format_error_summary(
|
||||
response.stats.failed,
|
||||
response.stats.total_tasks,
|
||||
failed_tasks,
|
||||
);
|
||||
return Err(error_summary);
|
||||
}
|
||||
serde_json::to_value(response).map_err(|e| format!("Failed to serialize response: {}", e))
|
||||
}
|
||||
|
||||
fn get_task_description(result: &TaskResult) -> String {
|
||||
format!("ID: {}", result.task_id)
|
||||
}
|
||||
216
crates/goose/src/agents/sub_recipe_execution_tool/lib/tests.rs
Normal file
216
crates/goose/src/agents/sub_recipe_execution_tool/lib/tests.rs
Normal file
@@ -0,0 +1,216 @@
|
||||
use super::{
|
||||
extract_failed_tasks, format_error_summary, format_failed_task_error, get_task_description,
|
||||
handle_response,
|
||||
};
|
||||
use crate::agents::sub_recipe_execution_tool::lib::{
|
||||
ExecutionResponse, ExecutionStats, TaskResult, TaskStatus,
|
||||
};
|
||||
use serde_json::json;
|
||||
|
||||
fn create_test_task_result(task_id: &str, status: TaskStatus, error: Option<String>) -> TaskResult {
|
||||
TaskResult {
|
||||
task_id: task_id.to_string(),
|
||||
status,
|
||||
data: Some(json!({"partial_output": "test output"})),
|
||||
error,
|
||||
}
|
||||
}
|
||||
|
||||
fn create_test_execution_response(
|
||||
results: Vec<TaskResult>,
|
||||
failed_count: usize,
|
||||
) -> ExecutionResponse {
|
||||
ExecutionResponse {
|
||||
status: "completed".to_string(),
|
||||
results: results.clone(),
|
||||
stats: ExecutionStats {
|
||||
total_tasks: results.len(),
|
||||
completed: results.len() - failed_count,
|
||||
failed: failed_count,
|
||||
execution_time_ms: 1000,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_failed_tasks() {
|
||||
let results = vec![
|
||||
create_test_task_result("task1", TaskStatus::Completed, None),
|
||||
create_test_task_result(
|
||||
"task2",
|
||||
TaskStatus::Failed,
|
||||
Some("Error message".to_string()),
|
||||
),
|
||||
create_test_task_result("task3", TaskStatus::Completed, None),
|
||||
create_test_task_result(
|
||||
"task4",
|
||||
TaskStatus::Failed,
|
||||
Some("Another error".to_string()),
|
||||
),
|
||||
];
|
||||
|
||||
let failed_tasks = extract_failed_tasks(&results);
|
||||
|
||||
assert_eq!(failed_tasks.len(), 2);
|
||||
assert!(failed_tasks[0].contains("task2"));
|
||||
assert!(failed_tasks[0].contains("Error message"));
|
||||
assert!(failed_tasks[1].contains("task4"));
|
||||
assert!(failed_tasks[1].contains("Another error"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_failed_tasks_empty() {
|
||||
let results = vec![
|
||||
create_test_task_result("task1", TaskStatus::Completed, None),
|
||||
create_test_task_result("task2", TaskStatus::Completed, None),
|
||||
];
|
||||
|
||||
let failed_tasks = extract_failed_tasks(&results);
|
||||
|
||||
assert_eq!(failed_tasks.len(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_failed_task_error_with_error_message() {
|
||||
let result = create_test_task_result(
|
||||
"task1",
|
||||
TaskStatus::Failed,
|
||||
Some("Test error message".to_string()),
|
||||
);
|
||||
|
||||
let formatted = format_failed_task_error(&result);
|
||||
|
||||
assert!(formatted.contains("task1"));
|
||||
assert!(formatted.contains("Test error message"));
|
||||
assert!(formatted.contains("test output"));
|
||||
assert!(formatted.contains("ID: task1"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_failed_task_error_without_error_message() {
|
||||
let result = create_test_task_result("task2", TaskStatus::Failed, None);
|
||||
|
||||
let formatted = format_failed_task_error(&result);
|
||||
|
||||
assert!(formatted.contains("task2"));
|
||||
assert!(formatted.contains("Unknown error"));
|
||||
assert!(formatted.contains("test output"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_failed_task_error_empty_partial_output() {
|
||||
let mut result =
|
||||
create_test_task_result("task3", TaskStatus::Failed, Some("Error".to_string()));
|
||||
result.data = Some(json!({"partial_output": ""}));
|
||||
|
||||
let formatted = format_failed_task_error(&result);
|
||||
|
||||
assert!(formatted.contains("No output captured"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_failed_task_error_no_partial_output() {
|
||||
let mut result =
|
||||
create_test_task_result("task4", TaskStatus::Failed, Some("Error".to_string()));
|
||||
result.data = Some(json!({}));
|
||||
|
||||
let formatted = format_failed_task_error(&result);
|
||||
|
||||
assert!(formatted.contains("No output captured"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_failed_task_error_no_data() {
|
||||
let mut result =
|
||||
create_test_task_result("task5", TaskStatus::Failed, Some("Error".to_string()));
|
||||
result.data = None;
|
||||
|
||||
let formatted = format_failed_task_error(&result);
|
||||
|
||||
assert!(formatted.contains("No output captured"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_error_summary() {
|
||||
let failed_tasks = vec![
|
||||
"Task 'task1': Error 1\nOutput: output1".to_string(),
|
||||
"Task 'task2': Error 2\nOutput: output2".to_string(),
|
||||
];
|
||||
|
||||
let summary = format_error_summary(2, 5, failed_tasks);
|
||||
|
||||
assert_eq!(summary, "2/5 tasks failed:\nTask 'task1': Error 1\nOutput: output1\nTask 'task2': Error 2\nOutput: output2");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_error_summary_single_failure() {
|
||||
let failed_tasks = vec!["Task 'task1': Error\nOutput: output".to_string()];
|
||||
|
||||
let summary = format_error_summary(1, 3, failed_tasks);
|
||||
|
||||
assert_eq!(
|
||||
summary,
|
||||
"1/3 tasks failed:\nTask 'task1': Error\nOutput: output"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_handle_response_success() {
|
||||
let results = vec![
|
||||
create_test_task_result("task1", TaskStatus::Completed, None),
|
||||
create_test_task_result("task2", TaskStatus::Completed, None),
|
||||
];
|
||||
let response = create_test_execution_response(results, 0);
|
||||
|
||||
let result = handle_response(response);
|
||||
|
||||
assert!(result.is_ok());
|
||||
let value = result.unwrap();
|
||||
assert_eq!(value["status"], "completed");
|
||||
assert_eq!(value["stats"]["failed"], 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_handle_response_with_failures() {
|
||||
let results = vec![
|
||||
create_test_task_result("task1", TaskStatus::Completed, None),
|
||||
create_test_task_result("task2", TaskStatus::Failed, Some("Test error".to_string())),
|
||||
];
|
||||
let response = create_test_execution_response(results, 1);
|
||||
|
||||
let result = handle_response(response);
|
||||
|
||||
assert!(result.is_err());
|
||||
let error = result.unwrap_err();
|
||||
assert!(error.contains("1/2 tasks failed"));
|
||||
assert!(error.contains("task2"));
|
||||
assert!(error.contains("Test error"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_handle_response_all_failures() {
|
||||
let results = vec![
|
||||
create_test_task_result("task1", TaskStatus::Failed, Some("Error 1".to_string())),
|
||||
create_test_task_result("task2", TaskStatus::Failed, Some("Error 2".to_string())),
|
||||
];
|
||||
let response = create_test_execution_response(results, 2);
|
||||
|
||||
let result = handle_response(response);
|
||||
|
||||
assert!(result.is_err());
|
||||
let error = result.unwrap_err();
|
||||
assert!(error.contains("2/2 tasks failed"));
|
||||
assert!(error.contains("task1"));
|
||||
assert!(error.contains("task2"));
|
||||
assert!(error.contains("Error 1"));
|
||||
assert!(error.contains("Error 2"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_get_task_description() {
|
||||
let result = create_test_task_result("test_task_123", TaskStatus::Completed, None);
|
||||
|
||||
let description = get_task_description(&result);
|
||||
|
||||
assert_eq!(description, "ID: test_task_123");
|
||||
}
|
||||
@@ -1,6 +1,10 @@
|
||||
mod executor;
|
||||
pub mod lib;
|
||||
pub mod notification_events;
|
||||
pub mod sub_recipe_execute_task_tool;
|
||||
mod task_execution_tracker;
|
||||
mod task_types;
|
||||
mod tasks;
|
||||
mod types;
|
||||
pub mod tasks_manager;
|
||||
pub mod utils;
|
||||
mod workers;
|
||||
|
||||
@@ -0,0 +1,204 @@
|
||||
use crate::agents::sub_recipe_execution_tool::task_types::TaskStatus;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "subtype")]
|
||||
pub enum TaskExecutionNotificationEvent {
|
||||
#[serde(rename = "line_output")]
|
||||
LineOutput { task_id: String, output: String },
|
||||
#[serde(rename = "tasks_update")]
|
||||
TasksUpdate {
|
||||
stats: TaskExecutionStats,
|
||||
tasks: Vec<TaskInfo>,
|
||||
},
|
||||
#[serde(rename = "tasks_complete")]
|
||||
TasksComplete {
|
||||
stats: TaskCompletionStats,
|
||||
failed_tasks: Vec<FailedTaskInfo>,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TaskExecutionStats {
|
||||
pub total: usize,
|
||||
pub pending: usize,
|
||||
pub running: usize,
|
||||
pub completed: usize,
|
||||
pub failed: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TaskCompletionStats {
|
||||
pub total: usize,
|
||||
pub completed: usize,
|
||||
pub failed: usize,
|
||||
pub success_rate: f64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TaskInfo {
|
||||
pub id: String,
|
||||
pub status: TaskStatus,
|
||||
pub duration_secs: Option<f64>,
|
||||
pub current_output: String,
|
||||
pub task_type: String,
|
||||
pub task_name: String,
|
||||
pub task_metadata: String,
|
||||
pub error: Option<String>,
|
||||
pub result_data: Option<Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct FailedTaskInfo {
|
||||
pub id: String,
|
||||
pub name: String,
|
||||
pub error: Option<String>,
|
||||
}
|
||||
|
||||
impl TaskExecutionNotificationEvent {
|
||||
pub fn line_output(task_id: String, output: String) -> Self {
|
||||
Self::LineOutput { task_id, output }
|
||||
}
|
||||
|
||||
pub fn tasks_update(stats: TaskExecutionStats, tasks: Vec<TaskInfo>) -> Self {
|
||||
Self::TasksUpdate { stats, tasks }
|
||||
}
|
||||
|
||||
pub fn tasks_complete(stats: TaskCompletionStats, failed_tasks: Vec<FailedTaskInfo>) -> Self {
|
||||
Self::TasksComplete {
|
||||
stats,
|
||||
failed_tasks,
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert event to JSON format for MCP notification
|
||||
pub fn to_notification_data(&self) -> serde_json::Value {
|
||||
let mut event_data = serde_json::to_value(self).expect("Failed to serialize event");
|
||||
|
||||
// Add the type field at the root level
|
||||
if let serde_json::Value::Object(ref mut map) = event_data {
|
||||
map.insert(
|
||||
"type".to_string(),
|
||||
serde_json::Value::String("task_execution".to_string()),
|
||||
);
|
||||
}
|
||||
|
||||
event_data
|
||||
}
|
||||
}
|
||||
|
||||
impl TaskExecutionStats {
|
||||
pub fn new(
|
||||
total: usize,
|
||||
pending: usize,
|
||||
running: usize,
|
||||
completed: usize,
|
||||
failed: usize,
|
||||
) -> Self {
|
||||
Self {
|
||||
total,
|
||||
pending,
|
||||
running,
|
||||
completed,
|
||||
failed,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TaskCompletionStats {
|
||||
pub fn new(total: usize, completed: usize, failed: usize) -> Self {
|
||||
let success_rate = if total > 0 {
|
||||
(completed as f64 / total as f64) * 100.0
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
Self {
|
||||
total,
|
||||
completed,
|
||||
failed,
|
||||
success_rate,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_line_output_event_serialization() {
|
||||
let event = TaskExecutionNotificationEvent::line_output(
|
||||
"task-1".to_string(),
|
||||
"Hello World".to_string(),
|
||||
);
|
||||
|
||||
let notification_data = event.to_notification_data();
|
||||
assert_eq!(notification_data["type"], "task_execution");
|
||||
assert_eq!(notification_data["subtype"], "line_output");
|
||||
assert_eq!(notification_data["task_id"], "task-1");
|
||||
assert_eq!(notification_data["output"], "Hello World");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tasks_update_event_serialization() {
|
||||
let stats = TaskExecutionStats::new(5, 2, 1, 1, 1);
|
||||
let tasks = vec![TaskInfo {
|
||||
id: "task-1".to_string(),
|
||||
status: TaskStatus::Running,
|
||||
duration_secs: Some(1.5),
|
||||
current_output: "Processing...".to_string(),
|
||||
task_type: "sub_recipe".to_string(),
|
||||
task_name: "test-task".to_string(),
|
||||
task_metadata: "param=value".to_string(),
|
||||
error: None,
|
||||
result_data: None,
|
||||
}];
|
||||
|
||||
let event = TaskExecutionNotificationEvent::tasks_update(stats, tasks);
|
||||
let notification_data = event.to_notification_data();
|
||||
|
||||
assert_eq!(notification_data["type"], "task_execution");
|
||||
assert_eq!(notification_data["subtype"], "tasks_update");
|
||||
assert_eq!(notification_data["stats"]["total"], 5);
|
||||
assert_eq!(notification_data["tasks"].as_array().unwrap().len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_event_roundtrip_serialization() {
|
||||
let original_event = TaskExecutionNotificationEvent::line_output(
|
||||
"task-1".to_string(),
|
||||
"Test output".to_string(),
|
||||
);
|
||||
|
||||
// Serialize to JSON
|
||||
let json_data = original_event.to_notification_data();
|
||||
|
||||
// Deserialize back to event (excluding the type field)
|
||||
let mut event_data = json_data.clone();
|
||||
if let serde_json::Value::Object(ref mut map) = event_data {
|
||||
map.remove("type");
|
||||
}
|
||||
|
||||
let deserialized_event: TaskExecutionNotificationEvent =
|
||||
serde_json::from_value(event_data).expect("Failed to deserialize");
|
||||
|
||||
match (original_event, deserialized_event) {
|
||||
(
|
||||
TaskExecutionNotificationEvent::LineOutput {
|
||||
task_id: id1,
|
||||
output: out1,
|
||||
},
|
||||
TaskExecutionNotificationEvent::LineOutput {
|
||||
task_id: id2,
|
||||
output: out2,
|
||||
},
|
||||
) => {
|
||||
assert_eq!(id1, id2);
|
||||
assert_eq!(out1, out2);
|
||||
}
|
||||
_ => panic!("Event types don't match after roundtrip"),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2,17 +2,22 @@ use mcp_core::{tool::ToolAnnotations, Content, Tool, ToolError};
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::agents::{
|
||||
sub_recipe_execution_tool::lib::execute_tasks, tool_execution::ToolCallResult,
|
||||
sub_recipe_execution_tool::lib::execute_tasks,
|
||||
sub_recipe_execution_tool::task_types::ExecutionMode,
|
||||
sub_recipe_execution_tool::tasks_manager::TasksManager, tool_execution::ToolCallResult,
|
||||
};
|
||||
use mcp_core::protocol::JsonRpcMessage;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_stream;
|
||||
|
||||
pub const SUB_RECIPE_EXECUTE_TASK_TOOL_NAME: &str = "sub_recipe__execute_task";
|
||||
pub fn create_sub_recipe_execute_task_tool() -> Tool {
|
||||
Tool::new(
|
||||
SUB_RECIPE_EXECUTE_TASK_TOOL_NAME,
|
||||
"Only use this tool when you execute sub recipe task.
|
||||
EXECUTION STRATEGY:
|
||||
- DEFAULT: Execute tasks sequentially (one at a time) unless user explicitly requests parallel execution
|
||||
- PARALLEL: Only when user explicitly uses keywords like 'parallel', 'simultaneously', 'at the same time', 'concurrently'
|
||||
EXECUTION STRATEGY DECISION:
|
||||
1. If the tasks are created with execution_mode, use the execution_mode.
|
||||
2. Execute tasks sequentially unless user explicitly requests parallel execution. PARALLEL: User uses keywords like 'parallel', 'simultaneously', 'at the same time', 'concurrently'
|
||||
|
||||
IMPLEMENTATION:
|
||||
- Sequential execution: Call this tool multiple times, passing exactly ONE task per call
|
||||
@@ -32,69 +37,15 @@ EXAMPLES:
|
||||
"default": "sequential",
|
||||
"description": "Execution strategy for multiple tasks. Use 'sequential' (default) unless user explicitly requests parallel execution with words like 'parallel', 'simultaneously', 'at the same time', or 'concurrently'."
|
||||
},
|
||||
"tasks": {
|
||||
"task_ids": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"id": {
|
||||
"type": "string",
|
||||
"description": "Unique identifier for the task"
|
||||
},
|
||||
"task_type": {
|
||||
"type": "string",
|
||||
"enum": ["sub_recipe", "text_instruction"],
|
||||
"default": "sub_recipe",
|
||||
"description": "the type of task to execute, can be one of: sub_recipe, text_instruction"
|
||||
},
|
||||
"payload": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"sub_recipe": {
|
||||
"type": "object",
|
||||
"description": "sub recipe to execute",
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": "string",
|
||||
"description": "name of the sub recipe to execute"
|
||||
},
|
||||
"recipe_path": {
|
||||
"type": "string",
|
||||
"description": "path of the sub recipe file"
|
||||
},
|
||||
"command_parameters": {
|
||||
"type": "object",
|
||||
"description": "parameters to pass to run recipe command with sub recipe file"
|
||||
}
|
||||
}
|
||||
},
|
||||
"text_instruction": {
|
||||
"type": "string",
|
||||
"description": "text instruction to execute"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": ["id", "payload"]
|
||||
},
|
||||
"description": "The tasks to run in parallel"
|
||||
},
|
||||
"config": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"timeout_seconds": {
|
||||
"type": "number"
|
||||
},
|
||||
"max_workers": {
|
||||
"type": "number"
|
||||
},
|
||||
"initial_workers": {
|
||||
"type": "number"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": ["tasks"]
|
||||
"required": ["task_ids"]
|
||||
}),
|
||||
Some(ToolAnnotations {
|
||||
title: Some("Run tasks in parallel".to_string()),
|
||||
@@ -106,19 +57,38 @@ EXAMPLES:
|
||||
)
|
||||
}
|
||||
|
||||
pub async fn run_tasks(execute_data: Value) -> ToolCallResult {
|
||||
pub async fn run_tasks(execute_data: Value, tasks_manager: &TasksManager) -> ToolCallResult {
|
||||
let (notification_tx, notification_rx) = mpsc::channel::<JsonRpcMessage>(100);
|
||||
|
||||
let tasks_manager_clone = tasks_manager.clone();
|
||||
let result_future = async move {
|
||||
let execute_data_clone = execute_data.clone();
|
||||
let default_execution_mode_value = Value::String("sequential".to_string());
|
||||
let execution_mode = execute_data_clone
|
||||
.get("execution_mode")
|
||||
.unwrap_or(&default_execution_mode_value)
|
||||
.as_str()
|
||||
.unwrap_or("sequential");
|
||||
match execute_tasks(execute_data, execution_mode).await {
|
||||
.and_then(|v| serde_json::from_value::<ExecutionMode>(v.clone()).ok())
|
||||
.unwrap_or_default();
|
||||
|
||||
match execute_tasks(
|
||||
execute_data,
|
||||
execution_mode,
|
||||
notification_tx,
|
||||
&tasks_manager_clone,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(result) => {
|
||||
let output = serde_json::to_string(&result).unwrap();
|
||||
ToolCallResult::from(Ok(vec![Content::text(output)]))
|
||||
Ok(vec![Content::text(output)])
|
||||
}
|
||||
Err(e) => ToolCallResult::from(Err(ToolError::ExecutionError(e.to_string()))),
|
||||
Err(e) => Err(ToolError::ExecutionError(e.to_string())),
|
||||
}
|
||||
};
|
||||
|
||||
// Convert receiver to stream
|
||||
let notification_stream = tokio_stream::wrappers::ReceiverStream::new(notification_rx);
|
||||
|
||||
ToolCallResult {
|
||||
result: Box::new(Box::pin(result_future)),
|
||||
notification_stream: Some(Box::new(notification_stream)),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,292 @@
|
||||
use mcp_core::protocol::{JsonRpcMessage, JsonRpcNotification};
|
||||
use serde_json::json;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{mpsc, RwLock};
|
||||
use tokio::time::{sleep, Duration, Instant};
|
||||
|
||||
use crate::agents::sub_recipe_execution_tool::notification_events::{
|
||||
FailedTaskInfo, TaskCompletionStats, TaskExecutionNotificationEvent, TaskExecutionStats,
|
||||
TaskInfo as EventTaskInfo,
|
||||
};
|
||||
use crate::agents::sub_recipe_execution_tool::task_types::{
|
||||
Task, TaskInfo, TaskResult, TaskStatus,
|
||||
};
|
||||
use crate::agents::sub_recipe_execution_tool::utils::{count_by_status, get_task_name};
|
||||
use serde_json::Value;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum DisplayMode {
|
||||
MultipleTasksOutput,
|
||||
SingleTaskOutput,
|
||||
}
|
||||
|
||||
const THROTTLE_INTERVAL_MS: u64 = 250;
|
||||
const COMPLETION_NOTIFICATION_DELAY_MS: u64 = 500;
|
||||
|
||||
fn format_task_metadata(task_info: &TaskInfo) -> String {
|
||||
if let Some(params) = task_info.task.get_command_parameters() {
|
||||
if params.is_empty() {
|
||||
return String::new();
|
||||
}
|
||||
|
||||
params
|
||||
.iter()
|
||||
.map(|(key, value)| {
|
||||
let value_str = match value {
|
||||
Value::String(s) => s.clone(),
|
||||
_ => value.to_string(),
|
||||
};
|
||||
format!("{}={}", key, value_str)
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join(",")
|
||||
} else {
|
||||
String::new()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct TaskExecutionTracker {
|
||||
tasks: Arc<RwLock<HashMap<String, TaskInfo>>>,
|
||||
last_refresh: Arc<RwLock<Instant>>,
|
||||
notifier: mpsc::Sender<JsonRpcMessage>,
|
||||
display_mode: DisplayMode,
|
||||
}
|
||||
|
||||
impl TaskExecutionTracker {
|
||||
pub fn new(
|
||||
tasks: Vec<Task>,
|
||||
display_mode: DisplayMode,
|
||||
notifier: mpsc::Sender<JsonRpcMessage>,
|
||||
) -> Self {
|
||||
let task_map = tasks
|
||||
.into_iter()
|
||||
.map(|task| {
|
||||
let task_id = task.id.clone();
|
||||
(
|
||||
task_id,
|
||||
TaskInfo {
|
||||
task,
|
||||
status: TaskStatus::Pending,
|
||||
start_time: None,
|
||||
end_time: None,
|
||||
result: None,
|
||||
current_output: String::new(),
|
||||
},
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
||||
Self {
|
||||
tasks: Arc::new(RwLock::new(task_map)),
|
||||
last_refresh: Arc::new(RwLock::new(Instant::now())),
|
||||
notifier,
|
||||
display_mode,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn start_task(&self, task_id: &str) {
|
||||
let mut tasks = self.tasks.write().await;
|
||||
if let Some(task_info) = tasks.get_mut(task_id) {
|
||||
task_info.status = TaskStatus::Running;
|
||||
task_info.start_time = Some(Instant::now());
|
||||
}
|
||||
drop(tasks);
|
||||
self.force_refresh_display().await;
|
||||
}
|
||||
|
||||
pub async fn complete_task(&self, task_id: &str, result: TaskResult) {
|
||||
let mut tasks = self.tasks.write().await;
|
||||
if let Some(task_info) = tasks.get_mut(task_id) {
|
||||
task_info.status = result.status.clone();
|
||||
task_info.end_time = Some(Instant::now());
|
||||
task_info.result = Some(result);
|
||||
}
|
||||
drop(tasks);
|
||||
self.force_refresh_display().await;
|
||||
}
|
||||
|
||||
pub async fn get_current_output(&self, task_id: &str) -> Option<String> {
|
||||
let tasks = self.tasks.read().await;
|
||||
tasks
|
||||
.get(task_id)
|
||||
.map(|task_info| task_info.current_output.clone())
|
||||
}
|
||||
|
||||
pub async fn send_live_output(&self, task_id: &str, line: &str) {
|
||||
match self.display_mode {
|
||||
DisplayMode::SingleTaskOutput => {
|
||||
let tasks = self.tasks.read().await;
|
||||
let task_info = tasks.get(task_id);
|
||||
|
||||
let formatted_line = if let Some(task_info) = task_info {
|
||||
let task_name = get_task_name(task_info);
|
||||
let task_type = task_info.task.task_type.clone();
|
||||
let metadata = format_task_metadata(task_info);
|
||||
|
||||
if metadata.is_empty() {
|
||||
format!("[{} ({})] {}", task_name, task_type, line)
|
||||
} else {
|
||||
format!("[{} ({}) {}] {}", task_name, task_type, metadata, line)
|
||||
}
|
||||
} else {
|
||||
line.to_string()
|
||||
};
|
||||
drop(tasks);
|
||||
|
||||
let event = TaskExecutionNotificationEvent::line_output(
|
||||
task_id.to_string(),
|
||||
formatted_line,
|
||||
);
|
||||
|
||||
if let Err(e) =
|
||||
self.notifier
|
||||
.try_send(JsonRpcMessage::Notification(JsonRpcNotification {
|
||||
jsonrpc: "2.0".to_string(),
|
||||
method: "notifications/message".to_string(),
|
||||
params: Some(json!({
|
||||
"data": event.to_notification_data()
|
||||
})),
|
||||
}))
|
||||
{
|
||||
tracing::warn!("Failed to send live output notification: {}", e);
|
||||
}
|
||||
}
|
||||
DisplayMode::MultipleTasksOutput => {
|
||||
let mut tasks = self.tasks.write().await;
|
||||
if let Some(task_info) = tasks.get_mut(task_id) {
|
||||
task_info.current_output.push_str(line);
|
||||
task_info.current_output.push('\n');
|
||||
}
|
||||
drop(tasks);
|
||||
|
||||
if !self.should_throttle_refresh().await {
|
||||
self.refresh_display().await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn should_throttle_refresh(&self) -> bool {
|
||||
let now = Instant::now();
|
||||
let mut last_refresh = self.last_refresh.write().await;
|
||||
|
||||
if now.duration_since(*last_refresh) > Duration::from_millis(THROTTLE_INTERVAL_MS) {
|
||||
*last_refresh = now;
|
||||
false
|
||||
} else {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
async fn send_tasks_update(&self) {
|
||||
let tasks = self.tasks.read().await;
|
||||
let task_list: Vec<_> = tasks.values().collect();
|
||||
let (total, pending, running, completed, failed) = count_by_status(&tasks);
|
||||
|
||||
let stats = TaskExecutionStats::new(total, pending, running, completed, failed);
|
||||
|
||||
let event_tasks: Vec<EventTaskInfo> = task_list
|
||||
.iter()
|
||||
.map(|task_info| {
|
||||
let now = Instant::now();
|
||||
EventTaskInfo {
|
||||
id: task_info.task.id.clone(),
|
||||
status: task_info.status.clone(),
|
||||
duration_secs: task_info.start_time.map(|start| {
|
||||
if let Some(end) = task_info.end_time {
|
||||
end.duration_since(start).as_secs_f64()
|
||||
} else {
|
||||
now.duration_since(start).as_secs_f64()
|
||||
}
|
||||
}),
|
||||
current_output: task_info.current_output.clone(),
|
||||
task_type: task_info.task.task_type.clone(),
|
||||
task_name: get_task_name(task_info).to_string(),
|
||||
task_metadata: format_task_metadata(task_info),
|
||||
error: task_info.error().cloned(),
|
||||
result_data: task_info.data().cloned(),
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
let event = TaskExecutionNotificationEvent::tasks_update(stats, event_tasks);
|
||||
|
||||
if let Err(e) = self
|
||||
.notifier
|
||||
.try_send(JsonRpcMessage::Notification(JsonRpcNotification {
|
||||
jsonrpc: "2.0".to_string(),
|
||||
method: "notifications/message".to_string(),
|
||||
params: Some(json!({
|
||||
"data": event.to_notification_data()
|
||||
})),
|
||||
}))
|
||||
{
|
||||
tracing::warn!("Failed to send tasks update notification: {}", e);
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn refresh_display(&self) {
|
||||
match self.display_mode {
|
||||
DisplayMode::MultipleTasksOutput => {
|
||||
self.send_tasks_update().await;
|
||||
}
|
||||
DisplayMode::SingleTaskOutput => {
|
||||
// No dashboard display needed for single task output mode
|
||||
// Live output is handled via send_live_output method
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Force refresh without throttling - used for important status changes
|
||||
async fn force_refresh_display(&self) {
|
||||
match self.display_mode {
|
||||
DisplayMode::MultipleTasksOutput => {
|
||||
// Reset throttle timer to allow immediate update
|
||||
let mut last_refresh = self.last_refresh.write().await;
|
||||
*last_refresh = Instant::now() - Duration::from_millis(THROTTLE_INTERVAL_MS + 1);
|
||||
drop(last_refresh);
|
||||
|
||||
self.send_tasks_update().await;
|
||||
}
|
||||
DisplayMode::SingleTaskOutput => {
|
||||
// No dashboard display needed for single task output mode
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn send_tasks_complete(&self) {
|
||||
let tasks = self.tasks.read().await;
|
||||
let (total, _, _, completed, failed) = count_by_status(&tasks);
|
||||
|
||||
let stats = TaskCompletionStats::new(total, completed, failed);
|
||||
|
||||
let failed_tasks: Vec<FailedTaskInfo> = tasks
|
||||
.values()
|
||||
.filter(|task_info| matches!(task_info.status, TaskStatus::Failed))
|
||||
.map(|task_info| FailedTaskInfo {
|
||||
id: task_info.task.id.clone(),
|
||||
name: get_task_name(task_info).to_string(),
|
||||
error: task_info.error().cloned(),
|
||||
})
|
||||
.collect();
|
||||
|
||||
let event = TaskExecutionNotificationEvent::tasks_complete(stats, failed_tasks);
|
||||
|
||||
if let Err(e) = self
|
||||
.notifier
|
||||
.try_send(JsonRpcMessage::Notification(JsonRpcNotification {
|
||||
jsonrpc: "2.0".to_string(),
|
||||
method: "notifications/message".to_string(),
|
||||
params: Some(json!({
|
||||
"data": event.to_notification_data()
|
||||
})),
|
||||
}))
|
||||
{
|
||||
tracing::warn!("Failed to send tasks complete notification: {}", e);
|
||||
}
|
||||
|
||||
// Brief delay to ensure completion notification is processed
|
||||
sleep(Duration::from_millis(COMPLETION_NOTIFICATION_DELAY_MS)).await;
|
||||
}
|
||||
}
|
||||
145
crates/goose/src/agents/sub_recipe_execution_tool/task_types.rs
Normal file
145
crates/goose/src/agents/sub_recipe_execution_tool/task_types.rs
Normal file
@@ -0,0 +1,145 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{Map, Value};
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
use crate::agents::sub_recipe_execution_tool::task_execution_tracker::TaskExecutionTracker;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum ExecutionMode {
|
||||
#[default]
|
||||
Sequential,
|
||||
Parallel,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Task {
|
||||
pub id: String,
|
||||
pub task_type: String,
|
||||
pub payload: Value,
|
||||
}
|
||||
|
||||
impl Task {
|
||||
pub fn get_sub_recipe(&self) -> Option<&Map<String, Value>> {
|
||||
(self.task_type == "sub_recipe")
|
||||
.then(|| self.payload.get("sub_recipe")?.as_object())
|
||||
.flatten()
|
||||
}
|
||||
|
||||
pub fn get_sequential_when_repeated(&self) -> bool {
|
||||
self.get_sub_recipe()
|
||||
.and_then(|sr| sr.get("sequential_when_repeated").and_then(|v| v.as_bool()))
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
pub fn get_command_parameters(&self) -> Option<&Map<String, Value>> {
|
||||
self.get_sub_recipe()
|
||||
.and_then(|sr| sr.get("command_parameters"))
|
||||
.and_then(|cp| cp.as_object())
|
||||
}
|
||||
|
||||
pub fn get_sub_recipe_name(&self) -> Option<&str> {
|
||||
self.get_sub_recipe()
|
||||
.and_then(|sr| sr.get("name"))
|
||||
.and_then(|name| name.as_str())
|
||||
}
|
||||
|
||||
pub fn get_sub_recipe_path(&self) -> Option<&str> {
|
||||
self.get_sub_recipe()
|
||||
.and_then(|sr| sr.get("recipe_path"))
|
||||
.and_then(|path| path.as_str())
|
||||
}
|
||||
|
||||
pub fn get_text_instruction(&self) -> Option<&str> {
|
||||
if self.task_type != "sub_recipe" {
|
||||
self.payload
|
||||
.get("text_instruction")
|
||||
.and_then(|text| text.as_str())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TaskResult {
|
||||
pub task_id: String,
|
||||
pub status: TaskStatus,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub data: Option<Value>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub error: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum TaskStatus {
|
||||
Pending,
|
||||
Running,
|
||||
Completed,
|
||||
Failed,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for TaskStatus {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
TaskStatus::Pending => write!(f, "Pending"),
|
||||
TaskStatus::Running => write!(f, "Running"),
|
||||
TaskStatus::Completed => write!(f, "Completed"),
|
||||
TaskStatus::Failed => write!(f, "Failed"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TaskInfo {
|
||||
pub task: Task,
|
||||
pub status: TaskStatus,
|
||||
pub start_time: Option<tokio::time::Instant>,
|
||||
pub end_time: Option<tokio::time::Instant>,
|
||||
pub result: Option<TaskResult>,
|
||||
pub current_output: String,
|
||||
}
|
||||
|
||||
impl TaskInfo {
|
||||
pub fn error(&self) -> Option<&String> {
|
||||
self.result.as_ref().and_then(|r| r.error.as_ref())
|
||||
}
|
||||
|
||||
pub fn data(&self) -> Option<&Value> {
|
||||
self.result.as_ref().and_then(|r| r.data.as_ref())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct SharedState {
|
||||
pub task_receiver: Arc<tokio::sync::Mutex<mpsc::Receiver<Task>>>,
|
||||
pub result_sender: mpsc::Sender<TaskResult>,
|
||||
pub active_workers: Arc<AtomicUsize>,
|
||||
pub task_execution_tracker: Arc<TaskExecutionTracker>,
|
||||
}
|
||||
|
||||
impl SharedState {
|
||||
pub fn increment_active_workers(&self) {
|
||||
self.active_workers.fetch_add(1, Ordering::SeqCst);
|
||||
}
|
||||
|
||||
pub fn decrement_active_workers(&self) {
|
||||
self.active_workers.fetch_sub(1, Ordering::SeqCst);
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct ExecutionStats {
|
||||
pub total_tasks: usize,
|
||||
pub completed: usize,
|
||||
pub failed: usize,
|
||||
pub execution_time_ms: u128,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct ExecutionResponse {
|
||||
pub status: String,
|
||||
pub results: Vec<TaskResult>,
|
||||
pub stats: ExecutionStats,
|
||||
}
|
||||
@@ -1,76 +1,98 @@
|
||||
use serde_json::Value;
|
||||
use std::process::Stdio;
|
||||
use std::time::Duration;
|
||||
use std::sync::Arc;
|
||||
use tokio::io::{AsyncBufReadExt, BufReader};
|
||||
use tokio::process::Command;
|
||||
use tokio::time::timeout;
|
||||
|
||||
use crate::agents::sub_recipe_execution_tool::types::{Task, TaskResult};
|
||||
use crate::agents::sub_recipe_execution_tool::task_execution_tracker::TaskExecutionTracker;
|
||||
use crate::agents::sub_recipe_execution_tool::task_types::{Task, TaskResult, TaskStatus};
|
||||
|
||||
// Process a single task based on its type
|
||||
pub async fn process_task(task: &Task, timeout_seconds: u64) -> TaskResult {
|
||||
let task_clone = task.clone();
|
||||
let timeout_duration = Duration::from_secs(timeout_seconds);
|
||||
|
||||
// Execute with timeout
|
||||
match timeout(timeout_duration, execute_task(task_clone)).await {
|
||||
Ok(Ok(data)) => TaskResult {
|
||||
pub async fn process_task(
|
||||
task: &Task,
|
||||
task_execution_tracker: Arc<TaskExecutionTracker>,
|
||||
) -> TaskResult {
|
||||
match get_task_result(task.clone(), task_execution_tracker).await {
|
||||
Ok(data) => TaskResult {
|
||||
task_id: task.id.clone(),
|
||||
status: "success".to_string(),
|
||||
status: TaskStatus::Completed,
|
||||
data: Some(data),
|
||||
error: None,
|
||||
},
|
||||
Ok(Err(error)) => TaskResult {
|
||||
Err(error) => TaskResult {
|
||||
task_id: task.id.clone(),
|
||||
status: "failed".to_string(),
|
||||
status: TaskStatus::Failed,
|
||||
data: None,
|
||||
error: Some(error),
|
||||
},
|
||||
Err(_) => TaskResult {
|
||||
task_id: task.id.clone(),
|
||||
status: "failed".to_string(),
|
||||
data: None,
|
||||
error: Some("Task timeout".to_string()),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
async fn execute_task(task: Task) -> Result<Value, String> {
|
||||
async fn get_task_result(
|
||||
task: Task,
|
||||
task_execution_tracker: Arc<TaskExecutionTracker>,
|
||||
) -> Result<Value, String> {
|
||||
let (command, output_identifier) = build_command(&task)?;
|
||||
let (stdout_output, stderr_output, success) = run_command(
|
||||
command,
|
||||
&output_identifier,
|
||||
&task.id,
|
||||
task_execution_tracker,
|
||||
)
|
||||
.await?;
|
||||
|
||||
if success {
|
||||
process_output(stdout_output)
|
||||
} else {
|
||||
Err(format!("Command failed:\n{}", stderr_output))
|
||||
}
|
||||
}
|
||||
|
||||
fn build_command(task: &Task) -> Result<(Command, String), String> {
|
||||
let task_error = |field: &str| format!("Task {}: Missing {}", task.id, field);
|
||||
|
||||
let mut output_identifier = task.id.clone();
|
||||
let mut command = if task.task_type == "sub_recipe" {
|
||||
let sub_recipe = task.payload.get("sub_recipe").unwrap();
|
||||
let sub_recipe_name = sub_recipe.get("name").unwrap().as_str().unwrap();
|
||||
let path = sub_recipe.get("recipe_path").unwrap().as_str().unwrap();
|
||||
let command_parameters = sub_recipe.get("command_parameters").unwrap();
|
||||
let sub_recipe_name = task
|
||||
.get_sub_recipe_name()
|
||||
.ok_or_else(|| task_error("sub_recipe name"))?;
|
||||
let path = task
|
||||
.get_sub_recipe_path()
|
||||
.ok_or_else(|| task_error("sub_recipe path"))?;
|
||||
let command_parameters = task
|
||||
.get_command_parameters()
|
||||
.ok_or_else(|| task_error("command_parameters"))?;
|
||||
|
||||
output_identifier = format!("sub-recipe {}", sub_recipe_name);
|
||||
let mut cmd = Command::new("goose");
|
||||
cmd.arg("run").arg("--recipe").arg(path);
|
||||
if let Some(params_map) = command_parameters.as_object() {
|
||||
for (key, value) in params_map {
|
||||
cmd.arg("run").arg("--recipe").arg(path).arg("--no-session");
|
||||
|
||||
for (key, value) in command_parameters {
|
||||
let key_str = key.to_string();
|
||||
let value_str = value.as_str().unwrap_or(&value.to_string()).to_string();
|
||||
cmd.arg("--params")
|
||||
.arg(format!("{}={}", key_str, value_str));
|
||||
}
|
||||
}
|
||||
cmd
|
||||
} else {
|
||||
let text = task
|
||||
.payload
|
||||
.get("text_instruction")
|
||||
.unwrap()
|
||||
.as_str()
|
||||
.unwrap();
|
||||
.get_text_instruction()
|
||||
.ok_or_else(|| task_error("text_instruction"))?;
|
||||
let mut cmd = Command::new("goose");
|
||||
cmd.arg("run").arg("--text").arg(text);
|
||||
cmd
|
||||
};
|
||||
|
||||
// Configure to capture stdout
|
||||
command.stdout(Stdio::piped());
|
||||
command.stderr(Stdio::piped());
|
||||
Ok((command, output_identifier))
|
||||
}
|
||||
|
||||
// Spawn the child process
|
||||
async fn run_command(
|
||||
mut command: Command,
|
||||
output_identifier: &str,
|
||||
task_id: &str,
|
||||
task_execution_tracker: Arc<TaskExecutionTracker>,
|
||||
) -> Result<(String, String, bool), String> {
|
||||
let mut child = command
|
||||
.spawn()
|
||||
.map_err(|e| format!("Failed to spawn goose: {}", e))?;
|
||||
@@ -78,30 +100,20 @@ async fn execute_task(task: Task) -> Result<Value, String> {
|
||||
let stdout = child.stdout.take().expect("Failed to capture stdout");
|
||||
let stderr = child.stderr.take().expect("Failed to capture stderr");
|
||||
|
||||
let mut stdout_reader = BufReader::new(stdout).lines();
|
||||
let mut stderr_reader = BufReader::new(stderr).lines();
|
||||
|
||||
// Spawn background tasks to read from stdout and stderr
|
||||
let output_identifier_clone = output_identifier.clone();
|
||||
let stdout_task = tokio::spawn(async move {
|
||||
let mut buffer = String::new();
|
||||
while let Ok(Some(line)) = stdout_reader.next_line().await {
|
||||
println!("[{}] {}", output_identifier_clone, line);
|
||||
buffer.push_str(&line);
|
||||
buffer.push('\n');
|
||||
}
|
||||
buffer
|
||||
});
|
||||
|
||||
let stderr_task = tokio::spawn(async move {
|
||||
let mut buffer = String::new();
|
||||
while let Ok(Some(line)) = stderr_reader.next_line().await {
|
||||
eprintln!("[stderr for {}] {}", output_identifier, line);
|
||||
buffer.push_str(&line);
|
||||
buffer.push('\n');
|
||||
}
|
||||
buffer
|
||||
});
|
||||
let stdout_task = spawn_output_reader(
|
||||
stdout,
|
||||
output_identifier,
|
||||
false,
|
||||
task_id,
|
||||
task_execution_tracker.clone(),
|
||||
);
|
||||
let stderr_task = spawn_output_reader(
|
||||
stderr,
|
||||
output_identifier,
|
||||
true,
|
||||
task_id,
|
||||
task_execution_tracker.clone(),
|
||||
);
|
||||
|
||||
let status = child
|
||||
.wait()
|
||||
@@ -111,9 +123,63 @@ async fn execute_task(task: Task) -> Result<Value, String> {
|
||||
let stdout_output = stdout_task.await.unwrap();
|
||||
let stderr_output = stderr_task.await.unwrap();
|
||||
|
||||
if status.success() {
|
||||
Ok(Value::String(stdout_output))
|
||||
Ok((stdout_output, stderr_output, status.success()))
|
||||
}
|
||||
|
||||
fn spawn_output_reader(
|
||||
reader: impl tokio::io::AsyncRead + Unpin + Send + 'static,
|
||||
output_identifier: &str,
|
||||
is_stderr: bool,
|
||||
task_id: &str,
|
||||
task_execution_tracker: Arc<TaskExecutionTracker>,
|
||||
) -> tokio::task::JoinHandle<String> {
|
||||
let output_identifier = output_identifier.to_string();
|
||||
let task_id = task_id.to_string();
|
||||
tokio::spawn(async move {
|
||||
let mut buffer = String::new();
|
||||
let mut lines = BufReader::new(reader).lines();
|
||||
while let Ok(Some(line)) = lines.next_line().await {
|
||||
buffer.push_str(&line);
|
||||
buffer.push('\n');
|
||||
|
||||
if !is_stderr {
|
||||
task_execution_tracker
|
||||
.send_live_output(&task_id, &line)
|
||||
.await;
|
||||
} else {
|
||||
Err(format!("Command failed:\n{}", stderr_output))
|
||||
tracing::warn!("Task stderr [{}]: {}", output_identifier, line);
|
||||
}
|
||||
}
|
||||
buffer
|
||||
})
|
||||
}
|
||||
|
||||
fn extract_json_from_line(line: &str) -> Option<String> {
|
||||
let start = line.find('{')?;
|
||||
let end = line.rfind('}')?;
|
||||
|
||||
if start >= end {
|
||||
return None;
|
||||
}
|
||||
|
||||
let potential_json = &line[start..=end];
|
||||
if serde_json::from_str::<Value>(potential_json).is_ok() {
|
||||
Some(potential_json.to_string())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn process_output(stdout_output: String) -> Result<Value, String> {
|
||||
let last_line = stdout_output
|
||||
.lines()
|
||||
.filter(|line| !line.trim().is_empty())
|
||||
.next_back()
|
||||
.unwrap_or("");
|
||||
|
||||
if let Some(json_string) = extract_json_from_line(last_line) {
|
||||
Ok(Value::String(json_string))
|
||||
} else {
|
||||
Ok(Value::String(stdout_output))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,86 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
use crate::agents::sub_recipe_execution_tool::task_types::Task;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TasksManager {
|
||||
tasks: Arc<RwLock<HashMap<String, Task>>>,
|
||||
}
|
||||
|
||||
impl Default for TasksManager {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl TasksManager {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
tasks: Arc::new(RwLock::new(HashMap::new())),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn save_tasks(&self, tasks: Vec<Task>) {
|
||||
let mut task_map = self.tasks.write().await;
|
||||
for task in tasks {
|
||||
task_map.insert(task.id.clone(), task);
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_task(&self, task_id: &str) -> Option<Task> {
|
||||
let tasks = self.tasks.read().await;
|
||||
tasks.get(task_id).cloned()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
fn create_test_task(id: &str, sub_recipe_name: &str) -> Task {
|
||||
Task {
|
||||
id: id.to_string(),
|
||||
task_type: "sub_recipe".to_string(),
|
||||
payload: json!({
|
||||
"sub_recipe": {
|
||||
"name": sub_recipe_name,
|
||||
"command_parameters": {},
|
||||
"recipe_path": "/test/path"
|
||||
}
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_save_and_get_task() {
|
||||
let manager = TasksManager::new();
|
||||
let tasks = vec![create_test_task("task1", "weather")];
|
||||
|
||||
manager.save_tasks(tasks).await;
|
||||
|
||||
let retrieved = manager.get_task("task1").await;
|
||||
assert!(retrieved.is_some());
|
||||
assert_eq!(retrieved.unwrap().id, "task1");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_save_multiple_tasks() {
|
||||
let manager = TasksManager::new();
|
||||
let tasks = vec![
|
||||
create_test_task("task1", "weather"),
|
||||
create_test_task("task2", "news"),
|
||||
];
|
||||
|
||||
manager.save_tasks(tasks).await;
|
||||
|
||||
let task1 = manager.get_task("task1").await;
|
||||
let task2 = manager.get_task("task2").await;
|
||||
assert!(task1.is_some());
|
||||
assert!(task2.is_some());
|
||||
assert_eq!(task1.unwrap().id, "task1");
|
||||
assert_eq!(task2.unwrap().id, "task2");
|
||||
}
|
||||
}
|
||||
@@ -1,69 +0,0 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
// Task definition that LLMs will send
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Task {
|
||||
pub id: String,
|
||||
pub task_type: String,
|
||||
pub payload: Value,
|
||||
}
|
||||
|
||||
// Result for each task
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TaskResult {
|
||||
pub task_id: String,
|
||||
pub status: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub data: Option<Value>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub error: Option<String>,
|
||||
}
|
||||
|
||||
// Configuration for the parallel executor
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct Config {
|
||||
#[serde(default = "default_max_workers")]
|
||||
pub max_workers: usize,
|
||||
#[serde(default = "default_timeout")]
|
||||
pub timeout_seconds: u64,
|
||||
#[serde(default = "default_initial_workers")]
|
||||
pub initial_workers: usize,
|
||||
}
|
||||
|
||||
impl Default for Config {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
max_workers: default_max_workers(),
|
||||
timeout_seconds: default_timeout(),
|
||||
initial_workers: default_initial_workers(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn default_max_workers() -> usize {
|
||||
10
|
||||
}
|
||||
fn default_timeout() -> u64 {
|
||||
300
|
||||
}
|
||||
fn default_initial_workers() -> usize {
|
||||
2
|
||||
}
|
||||
|
||||
// Stats for the execution
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct ExecutionStats {
|
||||
pub total_tasks: usize,
|
||||
pub completed: usize,
|
||||
pub failed: usize,
|
||||
pub execution_time_ms: u128,
|
||||
}
|
||||
|
||||
// Main response structure
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct ExecutionResponse {
|
||||
pub status: String,
|
||||
pub results: Vec<TaskResult>,
|
||||
pub stats: ExecutionStats,
|
||||
}
|
||||
@@ -0,0 +1,27 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::agents::sub_recipe_execution_tool::task_types::{TaskInfo, TaskStatus};
|
||||
|
||||
pub fn get_task_name(task_info: &TaskInfo) -> &str {
|
||||
task_info
|
||||
.task
|
||||
.get_sub_recipe_name()
|
||||
.unwrap_or(&task_info.task.id)
|
||||
}
|
||||
|
||||
pub fn count_by_status(tasks: &HashMap<String, TaskInfo>) -> (usize, usize, usize, usize, usize) {
|
||||
let total = tasks.len();
|
||||
let (pending, running, completed, failed) = tasks.values().fold(
|
||||
(0, 0, 0, 0),
|
||||
|(pending, running, completed, failed), task| match task.status {
|
||||
TaskStatus::Pending => (pending + 1, running, completed, failed),
|
||||
TaskStatus::Running => (pending, running + 1, completed, failed),
|
||||
TaskStatus::Completed => (pending, running, completed + 1, failed),
|
||||
TaskStatus::Failed => (pending, running, completed, failed + 1),
|
||||
},
|
||||
);
|
||||
(total, pending, running, completed, failed)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
154
crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs
Normal file
154
crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs
Normal file
@@ -0,0 +1,154 @@
|
||||
use crate::agents::sub_recipe_execution_tool::task_types::{Task, TaskInfo, TaskStatus};
|
||||
use crate::agents::sub_recipe_execution_tool::utils::{count_by_status, get_task_name};
|
||||
use serde_json::json;
|
||||
use std::collections::HashMap;
|
||||
|
||||
fn create_task_info_with_defaults(task: Task, status: TaskStatus) -> TaskInfo {
|
||||
TaskInfo {
|
||||
task,
|
||||
status,
|
||||
start_time: None,
|
||||
end_time: None,
|
||||
result: None,
|
||||
current_output: String::new(),
|
||||
}
|
||||
}
|
||||
|
||||
mod test_get_task_name {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_extracts_sub_recipe_name() {
|
||||
let sub_recipe_task = Task {
|
||||
id: "task_1".to_string(),
|
||||
task_type: "sub_recipe".to_string(),
|
||||
payload: json!({
|
||||
"sub_recipe": {
|
||||
"name": "my_recipe",
|
||||
"recipe_path": "/path/to/recipe"
|
||||
}
|
||||
}),
|
||||
};
|
||||
|
||||
let task_info = create_task_info_with_defaults(sub_recipe_task, TaskStatus::Pending);
|
||||
|
||||
assert_eq!(get_task_name(&task_info), "my_recipe");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn falls_back_to_task_id_for_text_instruction() {
|
||||
let text_task = Task {
|
||||
id: "task_2".to_string(),
|
||||
task_type: "text_instruction".to_string(),
|
||||
payload: json!({"text_instruction": "do something"}),
|
||||
};
|
||||
|
||||
let task_info = create_task_info_with_defaults(text_task, TaskStatus::Pending);
|
||||
|
||||
assert_eq!(get_task_name(&task_info), "task_2");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn falls_back_to_task_id_when_sub_recipe_name_missing() {
|
||||
let malformed_task = Task {
|
||||
id: "task_3".to_string(),
|
||||
task_type: "sub_recipe".to_string(),
|
||||
payload: json!({
|
||||
"sub_recipe": {
|
||||
"recipe_path": "/path/to/recipe"
|
||||
// missing "name" field
|
||||
}
|
||||
}),
|
||||
};
|
||||
|
||||
let task_info = create_task_info_with_defaults(malformed_task, TaskStatus::Pending);
|
||||
|
||||
assert_eq!(get_task_name(&task_info), "task_3");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn falls_back_to_task_id_when_sub_recipe_missing() {
|
||||
let malformed_task = Task {
|
||||
id: "task_4".to_string(),
|
||||
task_type: "sub_recipe".to_string(),
|
||||
payload: json!({}), // missing "sub_recipe" field
|
||||
};
|
||||
|
||||
let task_info = create_task_info_with_defaults(malformed_task, TaskStatus::Pending);
|
||||
|
||||
assert_eq!(get_task_name(&task_info), "task_4");
|
||||
}
|
||||
}
|
||||
|
||||
mod count_by_status {
|
||||
use super::*;
|
||||
|
||||
fn create_test_task(id: &str, status: TaskStatus) -> TaskInfo {
|
||||
let task = Task {
|
||||
id: id.to_string(),
|
||||
task_type: "test".to_string(),
|
||||
payload: json!({}),
|
||||
};
|
||||
create_task_info_with_defaults(task, status)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn counts_empty_map() {
|
||||
let tasks = HashMap::new();
|
||||
let (total, pending, running, completed, failed) = count_by_status(&tasks);
|
||||
assert_eq!(
|
||||
(total, pending, running, completed, failed),
|
||||
(0, 0, 0, 0, 0)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn counts_single_status() {
|
||||
let mut tasks = HashMap::new();
|
||||
tasks.insert(
|
||||
"task1".to_string(),
|
||||
create_test_task("task1", TaskStatus::Pending),
|
||||
);
|
||||
tasks.insert(
|
||||
"task2".to_string(),
|
||||
create_test_task("task2", TaskStatus::Pending),
|
||||
);
|
||||
|
||||
let (total, pending, running, completed, failed) = count_by_status(&tasks);
|
||||
assert_eq!(
|
||||
(total, pending, running, completed, failed),
|
||||
(2, 2, 0, 0, 0)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn counts_mixed_statuses() {
|
||||
let mut tasks = HashMap::new();
|
||||
tasks.insert(
|
||||
"task1".to_string(),
|
||||
create_test_task("task1", TaskStatus::Pending),
|
||||
);
|
||||
tasks.insert(
|
||||
"task2".to_string(),
|
||||
create_test_task("task2", TaskStatus::Running),
|
||||
);
|
||||
tasks.insert(
|
||||
"task3".to_string(),
|
||||
create_test_task("task3", TaskStatus::Completed),
|
||||
);
|
||||
tasks.insert(
|
||||
"task4".to_string(),
|
||||
create_test_task("task4", TaskStatus::Failed),
|
||||
);
|
||||
tasks.insert(
|
||||
"task5".to_string(),
|
||||
create_test_task("task5", TaskStatus::Completed),
|
||||
);
|
||||
|
||||
let (total, pending, running, completed, failed) = count_by_status(&tasks);
|
||||
assert_eq!(
|
||||
(total, pending, running, completed, failed),
|
||||
(5, 1, 1, 2, 1)
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -1,133 +1,30 @@
|
||||
use crate::agents::sub_recipe_execution_tool::task_types::{SharedState, Task};
|
||||
use crate::agents::sub_recipe_execution_tool::tasks::process_task;
|
||||
use crate::agents::sub_recipe_execution_tool::types::{Task, TaskResult};
|
||||
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::time::{sleep, Duration};
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::agents::sub_recipe_execution_tool::types::Task;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_spawn_worker_returns_handle() {
|
||||
// Create a simple shared state for testing
|
||||
let (task_tx, task_rx) = mpsc::channel::<Task>(1);
|
||||
let (result_tx, _result_rx) = mpsc::channel::<TaskResult>(1);
|
||||
|
||||
let shared_state = Arc::new(SharedState {
|
||||
task_receiver: Arc::new(tokio::sync::Mutex::new(task_rx)),
|
||||
result_sender: result_tx,
|
||||
active_workers: Arc::new(AtomicUsize::new(0)),
|
||||
should_stop: Arc::new(AtomicBool::new(false)),
|
||||
completed_tasks: Arc::new(AtomicUsize::new(0)),
|
||||
});
|
||||
|
||||
// Test that spawn_worker returns a JoinHandle
|
||||
let handle = spawn_worker(shared_state.clone(), 0, 5);
|
||||
|
||||
// Verify it's a JoinHandle by checking we can abort it
|
||||
assert!(!handle.is_finished());
|
||||
|
||||
// Signal stop and close the channel to let the worker exit
|
||||
shared_state.should_stop.store(true, Ordering::SeqCst);
|
||||
drop(task_tx); // Close the channel
|
||||
|
||||
// Wait for the worker to finish
|
||||
let result = handle.await;
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
async fn receive_task(state: &SharedState) -> Option<Task> {
|
||||
let mut receiver = state.task_receiver.lock().await;
|
||||
receiver.recv().await
|
||||
}
|
||||
|
||||
pub struct SharedState {
|
||||
pub task_receiver: Arc<tokio::sync::Mutex<mpsc::Receiver<Task>>>,
|
||||
pub result_sender: mpsc::Sender<TaskResult>,
|
||||
pub active_workers: Arc<AtomicUsize>,
|
||||
pub should_stop: Arc<AtomicBool>,
|
||||
pub completed_tasks: Arc<AtomicUsize>,
|
||||
}
|
||||
|
||||
// Spawn a worker task
|
||||
pub fn spawn_worker(
|
||||
state: Arc<SharedState>,
|
||||
worker_id: usize,
|
||||
timeout_seconds: u64,
|
||||
) -> tokio::task::JoinHandle<()> {
|
||||
state.active_workers.fetch_add(1, Ordering::SeqCst);
|
||||
pub fn spawn_worker(state: Arc<SharedState>, worker_id: usize) -> tokio::task::JoinHandle<()> {
|
||||
state.increment_active_workers();
|
||||
|
||||
tokio::spawn(async move {
|
||||
worker_loop(state, worker_id, timeout_seconds).await;
|
||||
worker_loop(state, worker_id).await;
|
||||
})
|
||||
}
|
||||
|
||||
async fn worker_loop(state: Arc<SharedState>, _worker_id: usize, timeout_seconds: u64) {
|
||||
loop {
|
||||
// Try to receive a task
|
||||
let task = {
|
||||
let mut receiver = state.task_receiver.lock().await;
|
||||
receiver.recv().await
|
||||
};
|
||||
async fn worker_loop(state: Arc<SharedState>, _worker_id: usize) {
|
||||
while let Some(task) = receive_task(&state).await {
|
||||
state.task_execution_tracker.start_task(&task.id).await;
|
||||
let result = process_task(&task, state.task_execution_tracker.clone()).await;
|
||||
|
||||
match task {
|
||||
Some(task) => {
|
||||
// Process the task
|
||||
let result = process_task(&task, timeout_seconds).await;
|
||||
|
||||
// Send result
|
||||
let _ = state.result_sender.send(result).await;
|
||||
|
||||
// Update completed count
|
||||
state.completed_tasks.fetch_add(1, Ordering::SeqCst);
|
||||
}
|
||||
None => {
|
||||
// Channel closed, exit worker
|
||||
if let Err(e) = state.result_sender.send(result).await {
|
||||
tracing::error!("Worker failed to send result: {}", e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Check if we should stop
|
||||
if state.should_stop.load(Ordering::SeqCst) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Worker is exiting
|
||||
state.active_workers.fetch_sub(1, Ordering::SeqCst);
|
||||
}
|
||||
|
||||
// Scaling controller that monitors queue and spawns workers
|
||||
pub async fn run_scaler(
|
||||
state: Arc<SharedState>,
|
||||
task_count: usize,
|
||||
max_workers: usize,
|
||||
timeout_seconds: u64,
|
||||
) {
|
||||
let mut worker_count = 0;
|
||||
|
||||
loop {
|
||||
sleep(Duration::from_millis(100)).await;
|
||||
|
||||
let active = state.active_workers.load(Ordering::SeqCst);
|
||||
let completed = state.completed_tasks.load(Ordering::SeqCst);
|
||||
let pending = task_count.saturating_sub(completed);
|
||||
|
||||
// Simple scaling logic: spawn worker if many pending tasks and under limit
|
||||
if pending > active * 2 && active < max_workers && worker_count < max_workers {
|
||||
let _handle = spawn_worker(state.clone(), worker_count, timeout_seconds);
|
||||
worker_count += 1;
|
||||
}
|
||||
|
||||
// If all tasks completed, signal stop
|
||||
if completed >= task_count {
|
||||
state.should_stop.store(true, Ordering::SeqCst);
|
||||
break;
|
||||
}
|
||||
|
||||
// If no active workers and tasks remaining, spawn one
|
||||
if active == 0 && pending > 0 {
|
||||
let _handle = spawn_worker(state.clone(), worker_count, timeout_seconds);
|
||||
worker_count += 1;
|
||||
}
|
||||
}
|
||||
state.decrement_active_workers();
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ use crate::{
|
||||
recipe_tools::sub_recipe_tools::{
|
||||
create_sub_recipe_task, create_sub_recipe_task_tool, SUB_RECIPE_TASK_TOOL_NAME_PREFIX,
|
||||
},
|
||||
sub_recipe_execution_tool::tasks_manager::TasksManager,
|
||||
tool_execution::ToolCallResult,
|
||||
},
|
||||
recipe::SubRecipe,
|
||||
@@ -34,12 +35,6 @@ impl SubRecipeManager {
|
||||
|
||||
pub fn add_sub_recipe_tools(&mut self, sub_recipes_to_add: Vec<SubRecipe>) {
|
||||
for sub_recipe in sub_recipes_to_add {
|
||||
// let sub_recipe_key = format!(
|
||||
// "{}_{}",
|
||||
// SUB_RECIPE_TOOL_NAME_PREFIX,
|
||||
// sub_recipe.name.clone()
|
||||
// );
|
||||
// let tool = create_sub_recipe_tool(&sub_recipe);
|
||||
let sub_recipe_key = format!(
|
||||
"{}_{}",
|
||||
SUB_RECIPE_TASK_TOOL_NAME_PREFIX,
|
||||
@@ -59,43 +54,22 @@ impl SubRecipeManager {
|
||||
&self,
|
||||
tool_name: &str,
|
||||
params: Value,
|
||||
tasks_manager: &TasksManager,
|
||||
) -> ToolCallResult {
|
||||
let result = self.call_sub_recipe_tool(tool_name, params).await;
|
||||
let result = self
|
||||
.call_sub_recipe_tool(tool_name, params, tasks_manager)
|
||||
.await;
|
||||
match result {
|
||||
Ok(call_result) => ToolCallResult::from(Ok(call_result)),
|
||||
Err(e) => ToolCallResult::from(Err(ToolError::ExecutionError(e.to_string()))),
|
||||
}
|
||||
}
|
||||
|
||||
// async fn call_sub_recipe_tool(
|
||||
// &self,
|
||||
// tool_name: &str,
|
||||
// params: Value,
|
||||
// ) -> Result<Vec<Content>, ToolError> {
|
||||
// let sub_recipe = self.sub_recipes.get(tool_name).ok_or_else(|| {
|
||||
// let sub_recipe_name = tool_name
|
||||
// .strip_prefix(SUB_RECIPE_TOOL_NAME_PREFIX)
|
||||
// .and_then(|s| s.strip_prefix("_"))
|
||||
// .ok_or_else(|| {
|
||||
// ToolError::InvalidParameters(format!(
|
||||
// "Invalid sub-recipe tool name format: {}",
|
||||
// tool_name
|
||||
// ))
|
||||
// })
|
||||
// .unwrap();
|
||||
|
||||
// ToolError::InvalidParameters(format!("Sub-recipe '{}' not found", sub_recipe_name))
|
||||
// })?;
|
||||
|
||||
// let output = run_sub_recipe(sub_recipe, params).await.map_err(|e| {
|
||||
// ToolError::ExecutionError(format!("Sub-recipe execution failed: {}", e))
|
||||
// })?;
|
||||
// Ok(vec![Content::text(output)])
|
||||
// }
|
||||
async fn call_sub_recipe_tool(
|
||||
&self,
|
||||
tool_name: &str,
|
||||
params: Value,
|
||||
tasks_manager: &TasksManager,
|
||||
) -> Result<Vec<Content>, ToolError> {
|
||||
let sub_recipe = self.sub_recipes.get(tool_name).ok_or_else(|| {
|
||||
let sub_recipe_name = tool_name
|
||||
@@ -111,11 +85,10 @@ impl SubRecipeManager {
|
||||
|
||||
ToolError::InvalidParameters(format!("Sub-recipe '{}' not found", sub_recipe_name))
|
||||
})?;
|
||||
|
||||
let output = create_sub_recipe_task(sub_recipe, params)
|
||||
let output = create_sub_recipe_task(sub_recipe, params, tasks_manager)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
ToolError::ExecutionError(format!("Sub-recipe execution failed: {}", e))
|
||||
ToolError::ExecutionError(format!("Sub-recipe task createion failed: {}", e))
|
||||
})?;
|
||||
Ok(vec![Content::text(output)])
|
||||
}
|
||||
|
||||
@@ -137,6 +137,13 @@ pub struct SubRecipe {
|
||||
pub path: String,
|
||||
#[serde(default, deserialize_with = "deserialize_value_map_as_string")]
|
||||
pub values: Option<HashMap<String, String>>,
|
||||
#[serde(default)]
|
||||
pub sequential_when_repeated: bool,
|
||||
}
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct Execution {
|
||||
#[serde(default)]
|
||||
pub parallel: bool,
|
||||
}
|
||||
|
||||
fn deserialize_value_map_as_string<'de, D>(
|
||||
|
||||
Reference in New Issue
Block a user