diff --git a/.gitignore b/.gitignore index caab83d7..41f629f0 100644 --- a/.gitignore +++ b/.gitignore @@ -31,6 +31,9 @@ ui/desktop/src/bin/goose_llm.dll # Hermit .hermit/ +# Claude +.claude + debug_*.txt # Docs diff --git a/crates/goose-cli/src/recipes/extract_from_cli.rs b/crates/goose-cli/src/recipes/extract_from_cli.rs index 199c7091..56113b75 100644 --- a/crates/goose-cli/src/recipes/extract_from_cli.rs +++ b/crates/goose-cli/src/recipes/extract_from_cli.rs @@ -32,6 +32,7 @@ pub fn extract_recipe_info_from_cli( path: recipe_file_path.to_string_lossy().to_string(), name, values: None, + sequential_when_repeated: true, }; all_sub_recipes.push(additional_sub_recipe); } diff --git a/crates/goose-cli/src/session/mod.rs b/crates/goose-cli/src/session/mod.rs index c6ccb09f..5669c142 100644 --- a/crates/goose-cli/src/session/mod.rs +++ b/crates/goose-cli/src/session/mod.rs @@ -4,8 +4,11 @@ mod export; mod input; mod output; mod prompt; +mod task_execution_display; mod thinking; +use crate::session::task_execution_display::TASK_EXECUTION_NOTIFICATION_TYPE; + pub use self::export::message_to_markdown; pub use builder::{build_session, SessionBuilderConfig, SessionSettings}; use console::Color; @@ -17,6 +20,8 @@ use goose::permission::PermissionConfirmation; use goose::providers::base::Provider; pub use goose::session::Identifier; use goose::utils::safe_truncate; +use std::io::Write; +use task_execution_display::format_task_execution_notification; use anyhow::{Context, Result}; use completion::GooseCompleter; @@ -1008,7 +1013,7 @@ impl Session { match method.as_str() { "notifications/message" => { let data = o.get("data").unwrap_or(&Value::Null); - let (formatted_message, subagent_id, _notification_type) = match data { + let (formatted_message, subagent_id, message_notification_type) = match data { Value::String(s) => (s.clone(), None, None), Value::Object(o) => { // Check for subagent notification structure first @@ -1059,6 +1064,8 @@ impl Session { } else if let 
Some(Value::String(output)) = o.get("output") { // Fallback for other MCP notification types (output.to_owned(), None, None) + } else if let Some(result) = format_task_execution_notification(data) { + result } else { (data.to_string(), None, None) } @@ -1077,7 +1084,19 @@ impl Session { } else { progress_bars.log(&formatted_message); } - } else { + } else if let Some(ref notification_type) = message_notification_type { + if notification_type == TASK_EXECUTION_NOTIFICATION_TYPE { + if interactive { + let _ = progress_bars.hide(); + print!("{}", formatted_message); + std::io::stdout().flush().unwrap(); + } else { + print!("{}", formatted_message); + std::io::stdout().flush().unwrap(); + } + } + } + else { // Non-subagent notification, display immediately with compact spacing if interactive { let _ = progress_bars.hide(); diff --git a/crates/goose-cli/src/session/task_execution_display/mod.rs b/crates/goose-cli/src/session/task_execution_display/mod.rs new file mode 100644 index 00000000..ec6c41ff --- /dev/null +++ b/crates/goose-cli/src/session/task_execution_display/mod.rs @@ -0,0 +1,247 @@ +use goose::agents::sub_recipe_execution_tool::lib::TaskStatus; +use goose::agents::sub_recipe_execution_tool::notification_events::{ + TaskExecutionNotificationEvent, TaskInfo, +}; +use serde_json::Value; +use std::sync::atomic::{AtomicBool, Ordering}; + +#[cfg(test)] +mod tests; + +const CLEAR_SCREEN: &str = "\x1b[2J\x1b[H"; +const MOVE_TO_PROGRESS_LINE: &str = "\x1b[4;1H"; +const CLEAR_TO_EOL: &str = "\x1b[K"; +const CLEAR_BELOW: &str = "\x1b[J"; +pub const TASK_EXECUTION_NOTIFICATION_TYPE: &str = "task_execution"; + +static INITIAL_SHOWN: AtomicBool = AtomicBool::new(false); + +fn format_result_data_for_display(result_data: &Value) -> String { + match result_data { + Value::String(s) => strip_ansi_codes(s), + Value::Object(obj) => { + if let Some(partial_output) = obj.get("partial_output").and_then(|v| v.as_str()) { + format!("Partial output: {}", partial_output) + } else { + 
serde_json::to_string_pretty(obj).unwrap_or_default() + } + } + Value::Array(arr) => serde_json::to_string_pretty(arr).unwrap_or_default(), + Value::Bool(b) => b.to_string(), + Value::Number(n) => n.to_string(), + Value::Null => "null".to_string(), + } +} + +fn process_output_for_display(output: &str) -> String { + const MAX_OUTPUT_LINES: usize = 2; + const OUTPUT_PREVIEW_LENGTH: usize = 100; + + let lines: Vec<&str> = output.lines().collect(); + let recent_lines = if lines.len() > MAX_OUTPUT_LINES { + &lines[lines.len() - MAX_OUTPUT_LINES..] + } else { + &lines + }; + + let clean_output = recent_lines.join(" ... "); + let stripped = strip_ansi_codes(&clean_output); + truncate_with_ellipsis(&stripped, OUTPUT_PREVIEW_LENGTH) +} + +fn truncate_with_ellipsis(text: &str, max_len: usize) -> String { + if text.len() > max_len { + let mut end = max_len.saturating_sub(3); + while end > 0 && !text.is_char_boundary(end) { + end -= 1; + } + format!("{}...", &text[..end]) + } else { + text.to_string() + } +} + +fn strip_ansi_codes(text: &str) -> String { + let mut result = String::new(); + let mut chars = text.chars(); + + while let Some(ch) = chars.next() { + if ch == '\x1b' { + if let Some(next_ch) = chars.next() { + if next_ch == '[' { + // This is an ANSI escape sequence, consume until alphabetic character + loop { + match chars.next() { + Some(c) if c.is_ascii_alphabetic() => break, + Some(_) => continue, + None => break, + } + } + } else { + // Not an ANSI sequence, keep both characters + result.push(ch); + result.push(next_ch); + } + } else { + // End of string after \x1b + result.push(ch); + } + } else { + result.push(ch); + } + } + + result +} + +pub fn format_task_execution_notification( + data: &Value, +) -> Option<(String, Option, Option)> { + if let Ok(event) = serde_json::from_value::(data.clone()) { + return Some(match event { + TaskExecutionNotificationEvent::LineOutput { output, .. 
} => ( + format!("{}\n", output), + None, + Some(TASK_EXECUTION_NOTIFICATION_TYPE.to_string()), + ), + TaskExecutionNotificationEvent::TasksUpdate { .. } => { + let formatted_display = format_tasks_update_from_event(&event); + ( + formatted_display, + None, + Some(TASK_EXECUTION_NOTIFICATION_TYPE.to_string()), + ) + } + TaskExecutionNotificationEvent::TasksComplete { .. } => { + let formatted_summary = format_tasks_complete_from_event(&event); + ( + formatted_summary, + None, + Some(TASK_EXECUTION_NOTIFICATION_TYPE.to_string()), + ) + } + }); + } + None +} + +fn format_tasks_update_from_event(event: &TaskExecutionNotificationEvent) -> String { + if let TaskExecutionNotificationEvent::TasksUpdate { stats, tasks } = event { + let mut display = String::new(); + + if !INITIAL_SHOWN.swap(true, Ordering::SeqCst) { + display.push_str(CLEAR_SCREEN); + display.push_str("šŸŽÆ Task Execution Dashboard\n"); + display.push_str("═══════════════════════════\n\n"); + } else { + display.push_str(MOVE_TO_PROGRESS_LINE); + } + + display.push_str(&format!( + "šŸ“Š Progress: {} total | ā³ {} pending | šŸƒ {} running | āœ… {} completed | āŒ {} failed", + stats.total, stats.pending, stats.running, stats.completed, stats.failed + )); + display.push_str(&format!("{}\n\n", CLEAR_TO_EOL)); + + let mut sorted_tasks = tasks.clone(); + sorted_tasks.sort_by(|a, b| a.id.cmp(&b.id)); + + for task in sorted_tasks { + display.push_str(&format_task_display(&task)); + } + + display.push_str(CLEAR_BELOW); + display + } else { + String::new() + } +} + +fn format_tasks_complete_from_event(event: &TaskExecutionNotificationEvent) -> String { + if let TaskExecutionNotificationEvent::TasksComplete { + stats, + failed_tasks, + } = event + { + let mut summary = String::new(); + summary.push_str("Execution Complete!\n"); + summary.push_str("═══════════════════════\n"); + + summary.push_str(&format!("Total Tasks: {}\n", stats.total)); + summary.push_str(&format!("āœ… Completed: {}\n", stats.completed)); + 
summary.push_str(&format!("āŒ Failed: {}\n", stats.failed)); + summary.push_str(&format!("šŸ“ˆ Success Rate: {:.1}%\n", stats.success_rate)); + + if !failed_tasks.is_empty() { + summary.push_str("\nāŒ Failed Tasks:\n"); + for task in failed_tasks { + summary.push_str(&format!(" • {}\n", task.name)); + if let Some(error) = &task.error { + summary.push_str(&format!(" Error: {}\n", error)); + } + } + } + + summary.push_str("\nšŸ“ Generating summary...\n"); + summary + } else { + String::new() + } +} + +fn format_task_display(task: &TaskInfo) -> String { + let mut task_display = String::new(); + + let status_icon = match task.status { + TaskStatus::Pending => "ā³", + TaskStatus::Running => "šŸƒ", + TaskStatus::Completed => "āœ…", + TaskStatus::Failed => "āŒ", + }; + + task_display.push_str(&format!( + "{} {} ({}){}\n", + status_icon, task.task_name, task.task_type, CLEAR_TO_EOL + )); + + if !task.task_metadata.is_empty() { + task_display.push_str(&format!( + " šŸ“‹ Parameters: {}{}\n", + task.task_metadata, CLEAR_TO_EOL + )); + } + + if let Some(duration_secs) = task.duration_secs { + task_display.push_str(&format!(" ā±ļø {:.1}s{}\n", duration_secs, CLEAR_TO_EOL)); + } + + if matches!(task.status, TaskStatus::Running) && !task.current_output.trim().is_empty() { + let processed_output = process_output_for_display(&task.current_output); + if !processed_output.is_empty() { + task_display.push_str(&format!(" šŸ’¬ {}{}\n", processed_output, CLEAR_TO_EOL)); + } + } + + if matches!(task.status, TaskStatus::Completed) { + if let Some(result_data) = &task.result_data { + let result_preview = format_result_data_for_display(result_data); + if !result_preview.is_empty() { + task_display.push_str(&format!(" šŸ“„ {}{}\n", result_preview, CLEAR_TO_EOL)); + } + } + } + + if matches!(task.status, TaskStatus::Failed) { + if let Some(error) = &task.error { + let error_preview = truncate_with_ellipsis(error, 80); + task_display.push_str(&format!( + " āš ļø {}{}\n", + 
error_preview.replace('\n', " "), + CLEAR_TO_EOL + )); + } + } + + task_display.push_str(&format!("{}\n", CLEAR_TO_EOL)); + task_display +} diff --git a/crates/goose-cli/src/session/task_execution_display/tests.rs b/crates/goose-cli/src/session/task_execution_display/tests.rs new file mode 100644 index 00000000..fb532850 --- /dev/null +++ b/crates/goose-cli/src/session/task_execution_display/tests.rs @@ -0,0 +1,337 @@ +use super::*; +use goose::agents::sub_recipe_execution_tool::notification_events::{ + FailedTaskInfo, TaskCompletionStats, TaskExecutionStats, +}; +use serde_json::json; + +#[test] +fn test_strip_ansi_codes() { + assert_eq!(strip_ansi_codes("hello world"), "hello world"); + assert_eq!(strip_ansi_codes("\x1b[31mred text\x1b[0m"), "red text"); + assert_eq!( + strip_ansi_codes("\x1b[1;32mbold green\x1b[0m"), + "bold green" + ); + assert_eq!( + strip_ansi_codes("normal\x1b[33myellow\x1b[0mnormal"), + "normalyellownormal" + ); + assert_eq!(strip_ansi_codes("\x1bhello"), "\x1bhello"); + assert_eq!(strip_ansi_codes("hello\x1b"), "hello\x1b"); + assert_eq!(strip_ansi_codes(""), ""); +} + +#[test] +fn test_truncate_with_ellipsis() { + assert_eq!(truncate_with_ellipsis("hello", 10), "hello"); + assert_eq!(truncate_with_ellipsis("hello", 5), "hello"); + assert_eq!(truncate_with_ellipsis("hello world", 8), "hello..."); + assert_eq!(truncate_with_ellipsis("hello", 3), "..."); + assert_eq!(truncate_with_ellipsis("hello", 2), "..."); + assert_eq!(truncate_with_ellipsis("hello", 1), "..."); + assert_eq!(truncate_with_ellipsis("", 5), ""); +} + +#[test] +fn test_process_output_for_display() { + assert_eq!(process_output_for_display("hello world"), "hello world"); + assert_eq!( + process_output_for_display("line1\nline2"), + "line1 ... line2" + ); + + let input = "line1\nline2\nline3\nline4"; + let result = process_output_for_display(input); + assert_eq!(result, "line3 ... 
line4"); + + let long_line = "a".repeat(150); + let result = process_output_for_display(&long_line); + assert!(result.len() <= 100); + assert!(result.ends_with("...")); + + let ansi_output = "\x1b[31mred line 1\x1b[0m\n\x1b[32mgreen line 2\x1b[0m"; + let result = process_output_for_display(ansi_output); + assert_eq!(result, "red line 1 ... green line 2"); + + assert_eq!(process_output_for_display(""), ""); +} + +#[test] +fn test_format_result_data_for_display() { + let string_val = json!("hello world"); + assert_eq!(format_result_data_for_display(&string_val), "hello world"); + + let ansi_string = json!("\x1b[31mred text\x1b[0m"); + assert_eq!(format_result_data_for_display(&ansi_string), "red text"); + + assert_eq!(format_result_data_for_display(&json!(true)), "true"); + assert_eq!(format_result_data_for_display(&json!(false)), "false"); + assert_eq!(format_result_data_for_display(&json!(42)), "42"); + assert_eq!(format_result_data_for_display(&json!(3.14)), "3.14"); + assert_eq!(format_result_data_for_display(&json!(null)), "null"); + + let partial_obj = json!({ + "partial_output": "some output", + "other_field": "ignored" + }); + assert_eq!( + format_result_data_for_display(&partial_obj), + "Partial output: some output" + ); + + let obj = json!({"key": "value", "num": 42}); + let result = format_result_data_for_display(&obj); + assert!(result.contains("key")); + assert!(result.contains("value")); + + let arr = json!([1, 2, 3]); + let result = format_result_data_for_display(&arr); + assert!(result.contains("1")); + assert!(result.contains("2")); + assert!(result.contains("3")); +} + +#[test] +fn test_format_task_execution_notification_line_output() { + let _event = TaskExecutionNotificationEvent::LineOutput { + task_id: "task-1".to_string(), + output: "Hello World".to_string(), + }; + + let data = json!({ + "subtype": "line_output", + "task_id": "task-1", + "output": "Hello World" + }); + + let result = format_task_execution_notification(&data); + 
assert!(result.is_some()); + + let (formatted, second, third) = result.unwrap(); + assert_eq!(formatted, "Hello World\n"); + assert_eq!(second, None); + assert_eq!(third, Some("task_execution".to_string())); +} + +#[test] +fn test_format_task_execution_notification_invalid_data() { + let invalid_data = json!({ + "invalid": "structure" + }); + + let result = format_task_execution_notification(&invalid_data); + assert_eq!(result, None); + + let incomplete_data = json!({ + "subtype": "line_output" + }); + + let result = format_task_execution_notification(&incomplete_data); + assert_eq!(result, None); +} + +#[test] +fn test_format_tasks_update_from_event() { + INITIAL_SHOWN.store(false, Ordering::SeqCst); + + let stats = TaskExecutionStats::new(3, 1, 1, 1, 0); + let tasks = vec![ + TaskInfo { + id: "task-1".to_string(), + status: TaskStatus::Running, + duration_secs: Some(1.5), + current_output: "Processing...".to_string(), + task_type: "sub_recipe".to_string(), + task_name: "test-task".to_string(), + task_metadata: "param=value".to_string(), + error: None, + result_data: None, + }, + TaskInfo { + id: "task-2".to_string(), + status: TaskStatus::Completed, + duration_secs: Some(2.3), + current_output: "".to_string(), + task_type: "text_instruction".to_string(), + task_name: "another-task".to_string(), + task_metadata: "".to_string(), + error: None, + result_data: Some(json!({"result": "success"})), + }, + ]; + + let event = TaskExecutionNotificationEvent::TasksUpdate { stats, tasks }; + let result = format_tasks_update_from_event(&event); + + assert!(result.contains("šŸŽÆ Task Execution Dashboard")); + assert!(result.contains("═══════════════════════════")); + assert!(result.contains("šŸ“Š Progress: 3 total")); + assert!(result.contains("ā³ 1 pending")); + assert!(result.contains("šŸƒ 1 running")); + assert!(result.contains("āœ… 1 completed")); + assert!(result.contains("āŒ 0 failed")); + assert!(result.contains("šŸƒ test-task")); + assert!(result.contains("āœ… 
another-task")); + assert!(result.contains("šŸ“‹ Parameters: param=value")); + assert!(result.contains("ā±ļø 1.5s")); + assert!(result.contains("šŸ’¬ Processing...")); + + let result2 = format_tasks_update_from_event(&event); + assert!(!result2.contains("šŸŽÆ Task Execution Dashboard")); + assert!(result2.contains(MOVE_TO_PROGRESS_LINE)); +} + +#[test] +fn test_format_tasks_complete_from_event() { + let stats = TaskCompletionStats::new(5, 4, 1); + let failed_tasks = vec![FailedTaskInfo { + id: "task-3".to_string(), + name: "failed-task".to_string(), + error: Some("Connection timeout".to_string()), + }]; + + let event = TaskExecutionNotificationEvent::TasksComplete { + stats, + failed_tasks, + }; + let result = format_tasks_complete_from_event(&event); + + assert!(result.contains("Execution Complete!")); + assert!(result.contains("═══════════════════════")); + assert!(result.contains("Total Tasks: 5")); + assert!(result.contains("āœ… Completed: 4")); + assert!(result.contains("āŒ Failed: 1")); + assert!(result.contains("šŸ“ˆ Success Rate: 80.0%")); + assert!(result.contains("āŒ Failed Tasks:")); + assert!(result.contains("• failed-task")); + assert!(result.contains("Error: Connection timeout")); + assert!(result.contains("šŸ“ Generating summary...")); +} + +#[test] +fn test_format_tasks_complete_from_event_no_failures() { + let stats = TaskCompletionStats::new(3, 3, 0); + let failed_tasks = vec![]; + + let event = TaskExecutionNotificationEvent::TasksComplete { + stats, + failed_tasks, + }; + let result = format_tasks_complete_from_event(&event); + + assert!(!result.contains("āŒ Failed Tasks:")); + assert!(result.contains("šŸ“ˆ Success Rate: 100.0%")); + assert!(result.contains("āŒ Failed: 0")); +} + +#[test] +fn test_format_task_display_running() { + let task = TaskInfo { + id: "task-1".to_string(), + status: TaskStatus::Running, + duration_secs: Some(1.5), + current_output: "Processing data...\nAlmost done...".to_string(), + task_type: 
"sub_recipe".to_string(), + task_name: "data-processor".to_string(), + task_metadata: "input=file.txt,output=result.json".to_string(), + error: None, + result_data: None, + }; + + let result = format_task_display(&task); + + assert!(result.contains("šŸƒ data-processor (sub_recipe)")); + assert!(result.contains("šŸ“‹ Parameters: input=file.txt,output=result.json")); + assert!(result.contains("ā±ļø 1.5s")); + assert!(result.contains("šŸ’¬ Processing data... ... Almost done...")); +} + +#[test] +fn test_format_task_display_completed() { + let task = TaskInfo { + id: "task-2".to_string(), + status: TaskStatus::Completed, + duration_secs: Some(3.2), + current_output: "".to_string(), + task_type: "text_instruction".to_string(), + task_name: "analyzer".to_string(), + task_metadata: "".to_string(), + error: None, + result_data: Some(json!({"status": "success", "count": 42})), + }; + + let result = format_task_display(&task); + + assert!(result.contains("āœ… analyzer (text_instruction)")); + assert!(result.contains("ā±ļø 3.2s")); + assert!(!result.contains("šŸ“‹ Parameters")); + assert!(result.contains("šŸ“„")); +} + +#[test] +fn test_format_task_display_failed() { + let task = TaskInfo { + id: "task-3".to_string(), + status: TaskStatus::Failed, + duration_secs: None, + current_output: "".to_string(), + task_type: "sub_recipe".to_string(), + task_name: "failing-task".to_string(), + task_metadata: "".to_string(), + error: Some( + "Network connection failed after multiple retries. The server is unreachable." 
+ .to_string(), + ), + result_data: None, + }; + + let result = format_task_display(&task); + + assert!(result.contains("āŒ failing-task (sub_recipe)")); + assert!(!result.contains("ā±ļø")); + assert!(result.contains("āš ļø")); + assert!(result.contains("Network connection failed after multiple retries")); +} + +#[test] +fn test_format_task_display_pending() { + let task = TaskInfo { + id: "task-4".to_string(), + status: TaskStatus::Pending, + duration_secs: None, + current_output: "".to_string(), + task_type: "sub_recipe".to_string(), + task_name: "waiting-task".to_string(), + task_metadata: "priority=high".to_string(), + error: None, + result_data: None, + }; + + let result = format_task_display(&task); + + assert!(result.contains("ā³ waiting-task (sub_recipe)")); + assert!(result.contains("šŸ“‹ Parameters: priority=high")); + assert!(!result.contains("ā±ļø")); + assert!(!result.contains("šŸ’¬")); + assert!(!result.contains("šŸ“„")); + assert!(!result.contains("āš ļø")); +} + +#[test] +fn test_format_task_display_empty_current_output() { + let task = TaskInfo { + id: "task-5".to_string(), + status: TaskStatus::Running, + duration_secs: Some(0.5), + current_output: " \n\t \n ".to_string(), + task_type: "sub_recipe".to_string(), + task_name: "quiet-task".to_string(), + task_metadata: "".to_string(), + error: None, + result_data: None, + }; + + let result = format_task_display(&task); + + assert!(!result.contains("šŸ’¬")); +} diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index c52807d5..2ef7a183 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -12,6 +12,7 @@ use crate::agents::final_output_tool::{FINAL_OUTPUT_CONTINUATION_MESSAGE, FINAL_ use crate::agents::sub_recipe_execution_tool::sub_recipe_execute_task_tool::{ self, SUB_RECIPE_EXECUTE_TASK_TOOL_NAME, }; +use crate::agents::sub_recipe_execution_tool::tasks_manager::TasksManager; use 
crate::agents::sub_recipe_manager::SubRecipeManager; use crate::config::{Config, ExtensionConfigManager, PermissionManager}; use crate::message::{push_message, Message}; @@ -63,6 +64,7 @@ pub struct Agent { pub(super) provider: Mutex>>, pub(super) extension_manager: RwLock, pub(super) sub_recipe_manager: Mutex, + pub(super) tasks_manager: TasksManager, pub(super) final_output_tool: Mutex>, pub(super) frontend_tools: Mutex>, pub(super) frontend_instructions: Mutex>, @@ -137,6 +139,7 @@ impl Agent { provider: Mutex::new(None), extension_manager: RwLock::new(ExtensionManager::new()), sub_recipe_manager: Mutex::new(SubRecipeManager::new()), + tasks_manager: TasksManager::new(), final_output_tool: Mutex::new(None), frontend_tools: Mutex::new(HashMap::new()), frontend_instructions: Mutex::new(None), @@ -291,10 +294,18 @@ impl Agent { let sub_recipe_manager = self.sub_recipe_manager.lock().await; let result: ToolCallResult = if sub_recipe_manager.is_sub_recipe_tool(&tool_call.name) { sub_recipe_manager - .dispatch_sub_recipe_tool_call(&tool_call.name, tool_call.arguments.clone()) + .dispatch_sub_recipe_tool_call( + &tool_call.name, + tool_call.arguments.clone(), + &self.tasks_manager, + ) .await } else if tool_call.name == SUB_RECIPE_EXECUTE_TASK_TOOL_NAME { - sub_recipe_execute_task_tool::run_tasks(tool_call.arguments.clone()).await + sub_recipe_execute_task_tool::run_tasks( + tool_call.arguments.clone(), + &self.tasks_manager, + ) + .await } else if tool_call.name == PLATFORM_READ_RESOURCE_TOOL_NAME { // Check if the tool is read_resource and handle it separately ToolCallResult::from( diff --git a/crates/goose/src/agents/recipe_tools/mod.rs b/crates/goose/src/agents/recipe_tools/mod.rs index 5f2f95fc..90603c88 100644 --- a/crates/goose/src/agents/recipe_tools/mod.rs +++ b/crates/goose/src/agents/recipe_tools/mod.rs @@ -1 +1,2 @@ +pub mod param_utils; pub mod sub_recipe_tools; diff --git a/crates/goose/src/agents/recipe_tools/param_utils/mod.rs 
b/crates/goose/src/agents/recipe_tools/param_utils/mod.rs new file mode 100644 index 00000000..bd8468c0 --- /dev/null +++ b/crates/goose/src/agents/recipe_tools/param_utils/mod.rs @@ -0,0 +1,38 @@ +use anyhow::Result; +use serde_json::Value; +use std::collections::HashMap; + +use crate::recipe::SubRecipe; + +pub fn prepare_command_params( + sub_recipe: &SubRecipe, + params_from_tool_call: Vec, +) -> Result>> { + let base_params = sub_recipe.values.clone().unwrap_or_default(); + + if params_from_tool_call.is_empty() { + return Ok(vec![base_params]); + } + + let result = params_from_tool_call + .into_iter() + .map(|tool_param| { + let mut param_map = base_params.clone(); + if let Some(param_obj) = tool_param.as_object() { + for (key, value) in param_obj { + let value_str = value + .as_str() + .map(String::from) + .unwrap_or_else(|| value.to_string()); + param_map.entry(key.clone()).or_insert(value_str); + } + } + param_map + }) + .collect(); + + Ok(result) +} + +#[cfg(test)] +mod tests; diff --git a/crates/goose/src/agents/recipe_tools/param_utils/tests.rs b/crates/goose/src/agents/recipe_tools/param_utils/tests.rs new file mode 100644 index 00000000..583338d6 --- /dev/null +++ b/crates/goose/src/agents/recipe_tools/param_utils/tests.rs @@ -0,0 +1,140 @@ +use std::collections::HashMap; + +use crate::recipe::SubRecipe; +use serde_json::json; + +use crate::agents::recipe_tools::param_utils::prepare_command_params; + +fn setup_default_sub_recipe() -> SubRecipe { + let sub_recipe = SubRecipe { + name: "test_sub_recipe".to_string(), + path: "test_sub_recipe.yaml".to_string(), + values: Some(HashMap::from([("key1".to_string(), "value1".to_string())])), + sequential_when_repeated: true, + }; + sub_recipe +} + +mod prepare_command_params_tests { + use super::*; + + #[test] + fn test_return_command_param() { + let parameter_array = vec![json!(HashMap::from([( + "key2".to_string(), + "value2".to_string() + )]))]; + let mut sub_recipe = setup_default_sub_recipe(); + 
sub_recipe.values = Some(HashMap::from([("key1".to_string(), "value1".to_string())])); + + let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); + assert_eq!( + vec![HashMap::from([ + ("key1".to_string(), "value1".to_string()), + ("key2".to_string(), "value2".to_string()) + ]),], + result + ); + } + + #[test] + fn test_return_command_param_when_value_override_passed_param_value() { + let parameter_array = vec![json!(HashMap::from([( + "key2".to_string(), + "different_value".to_string() + )]))]; + let mut sub_recipe = setup_default_sub_recipe(); + sub_recipe.values = Some(HashMap::from([ + ("key1".to_string(), "value1".to_string()), + ("key2".to_string(), "value2".to_string()), + ])); + + let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); + assert_eq!( + vec![HashMap::from([ + ("key1".to_string(), "value1".to_string()), + ("key2".to_string(), "value2".to_string()) + ]),], + result + ); + } + + #[test] + fn test_return_empty_command_param() { + let parameter_array = vec![]; + let mut sub_recipe = setup_default_sub_recipe(); + sub_recipe.values = None; + + let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); + assert_eq!(result, vec![HashMap::new()]); + } + + mod multiple_tool_parameters { + use super::*; + + #[test] + fn test_return_command_param_when_all_values_from_tool_call_parameters() { + let parameter_array = vec![ + json!(HashMap::from([ + ("key1".to_string(), "key1_value1".to_string()), + ("key2".to_string(), "key2_value1".to_string()) + ])), + json!(HashMap::from([ + ("key1".to_string(), "key1_value2".to_string()), + ("key2".to_string(), "key2_value2".to_string()) + ])), + ]; + let mut sub_recipe = setup_default_sub_recipe(); + sub_recipe.values = None; + + let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); + assert_eq!( + vec![ + HashMap::from([ + ("key1".to_string(), "key1_value1".to_string()), + ("key2".to_string(), "key2_value1".to_string()), + ]), + 
HashMap::from([ + ("key1".to_string(), "key1_value2".to_string()), + ("key2".to_string(), "key2_value2".to_string()), + ]), + ], + result + ); + } + + #[test] + fn test_merge_base_values_with_tool_parameters() { + let parameter_array = vec![ + json!(HashMap::from([( + "key2".to_string(), + "override_value1".to_string() + )])), + json!(HashMap::from([( + "key2".to_string(), + "override_value2".to_string() + )])), + ]; + let mut sub_recipe = setup_default_sub_recipe(); + sub_recipe.values = Some(HashMap::from([ + ("key1".to_string(), "base_value".to_string()), + ("key2".to_string(), "original_value".to_string()), + ])); + + let result = prepare_command_params(&sub_recipe, parameter_array).unwrap(); + assert_eq!( + vec![ + HashMap::from([ + ("key1".to_string(), "base_value".to_string()), + ("key2".to_string(), "original_value".to_string()), + ]), + HashMap::from([ + ("key1".to_string(), "base_value".to_string()), + ("key2".to_string(), "original_value".to_string()), + ]), + ], + result + ); + } + } +} diff --git a/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs b/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs index 928cf8bd..810c4a60 100644 --- a/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs +++ b/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs @@ -1,22 +1,35 @@ -use std::{collections::HashMap, fs}; +use std::collections::HashSet; +use std::fs; use anyhow::Result; use mcp_core::tool::{Tool, ToolAnnotations}; use serde_json::{json, Map, Value}; -use crate::agents::sub_recipe_execution_tool::lib::Task; +use crate::agents::sub_recipe_execution_tool::lib::{ExecutionMode, Task}; +use crate::agents::sub_recipe_execution_tool::tasks_manager::TasksManager; use crate::recipe::{Recipe, RecipeParameter, RecipeParameterRequirement, SubRecipe}; +use super::param_utils::prepare_command_params; + pub const SUB_RECIPE_TASK_TOOL_NAME_PREFIX: &str = "subrecipe__create_task"; pub fn create_sub_recipe_task_tool(sub_recipe: &SubRecipe) -> Tool { let 
input_schema = get_input_schema(sub_recipe).unwrap(); Tool::new( format!("{}_{}", SUB_RECIPE_TASK_TOOL_NAME_PREFIX, sub_recipe.name), - "Before running this sub recipe, you should first create a task with this tool and then pass the task to the task executor".to_string(), + format!( + "Create one or more tasks to run the '{}' sub recipe. \ + Provide an array of parameter sets in the 'task_parameters' field:\n\ + - For a single task: provide an array with one parameter set\n\ + - For multiple tasks: provide an array with multiple parameter sets, each with different values\n\n\ + Each task will run the same sub recipe but with different parameter values. \ + This is useful when you need to execute the same sub recipe multiple times with varying inputs. \ + After creating the tasks and execution_mode is provided, pass them to the task executor to run these tasks", + sub_recipe.name + ), input_schema, Some(ToolAnnotations { - title: Some(format!("create sub recipe task {}", sub_recipe.name)), + title: Some(format!("create multiple sub recipe tasks for {}", sub_recipe.name)), read_only_hint: false, destructive_hint: true, idempotent_hint: false, @@ -25,6 +38,64 @@ pub fn create_sub_recipe_task_tool(sub_recipe: &SubRecipe) -> Tool { ) } +fn extract_task_parameters(params: &Value) -> Vec { + params + .get("task_parameters") + .and_then(|v| v.as_array()) + .cloned() + .unwrap_or_default() +} + +fn create_tasks_from_params( + sub_recipe: &SubRecipe, + command_params: &[std::collections::HashMap], +) -> Vec { + let tasks: Vec = command_params + .iter() + .map(|task_command_param| { + let payload = json!({ + "sub_recipe": { + "name": sub_recipe.name.clone(), + "command_parameters": task_command_param, + "recipe_path": sub_recipe.path.clone(), + "sequential_when_repeated": sub_recipe.sequential_when_repeated + } + }); + Task { + id: uuid::Uuid::new_v4().to_string(), + task_type: "sub_recipe".to_string(), + payload, + } + }) + .collect(); + + tasks +} + +fn 
create_task_execution_payload(tasks: &[Task], sub_recipe: &SubRecipe) -> Value { + let task_ids: Vec = tasks.iter().map(|task| task.id.clone()).collect(); + json!({ + "task_ids": task_ids, + "execution_mode": if sub_recipe.sequential_when_repeated { ExecutionMode::Sequential } else { ExecutionMode::Parallel }, + }) +} + +pub async fn create_sub_recipe_task( + sub_recipe: &SubRecipe, + params: Value, + tasks_manager: &TasksManager, +) -> Result { + let task_params_array = extract_task_parameters(¶ms); + let command_params = prepare_command_params(sub_recipe, task_params_array.clone())?; + let tasks = create_tasks_from_params(sub_recipe, &command_params); + let task_execution_payload = create_task_execution_payload(&tasks, sub_recipe); + + let tasks_json = serde_json::to_string(&task_execution_payload) + .map_err(|e| anyhow::anyhow!("Failed to serialize task list: {}", e))?; + tasks_manager.save_tasks(tasks.clone()).await; + Ok(tasks_json) +} + fn get_sub_recipe_parameter_definition( sub_recipe: &SubRecipe, ) -> Result>> { @@ -34,22 +105,55 @@ fn get_sub_recipe_parameter_definition( Ok(recipe.parameters) } -fn get_input_schema(sub_recipe: &SubRecipe) -> Result { - let mut sub_recipe_params_map = HashMap::::new(); +fn get_params_with_values(sub_recipe: &SubRecipe) -> HashSet { + let mut sub_recipe_params_with_values = HashSet::::new(); if let Some(params_with_value) = &sub_recipe.values { - for (param_name, param_value) in params_with_value { - sub_recipe_params_map.insert(param_name.clone(), param_value.clone()); + for param_name in params_with_value.keys() { + sub_recipe_params_with_values.insert(param_name.clone()); } } + sub_recipe_params_with_values +} + +fn create_input_schema(param_properties: Map, param_required: Vec) -> Value { + let mut properties = Map::new(); + if !param_properties.is_empty() { + properties.insert( + "task_parameters".to_string(), + json!({ + "type": "array", + "description": "Array of parameter sets for creating tasks. 
\ + For a single task, provide an array with one element. \ + For multiple tasks, provide an array with multiple elements, each with different parameter values. \ + If there is no parameter set, provide an empty array.", + "items": { + "type": "object", + "properties": param_properties, + "required": param_required + }, + }) + ); + } + json!({ + "type": "object", + "properties": properties, + }) +} + +fn get_input_schema(sub_recipe: &SubRecipe) -> Result { + let sub_recipe_params_with_values = get_params_with_values(sub_recipe); + let parameter_definition = get_sub_recipe_parameter_definition(sub_recipe)?; + + let mut param_properties = Map::new(); + let mut param_required = Vec::new(); + if let Some(parameters) = parameter_definition { - let mut properties = Map::new(); - let mut required = Vec::new(); for param in parameters { - if sub_recipe_params_map.contains_key(¶m.key) { + if sub_recipe_params_with_values.contains(¶m.key.clone()) { continue; } - properties.insert( + param_properties.insert( param.key.clone(), json!({ "type": param.input_type.to_string(), @@ -57,60 +161,11 @@ fn get_input_schema(sub_recipe: &SubRecipe) -> Result { }), ); if !matches!(param.requirement, RecipeParameterRequirement::Optional) { - required.push(param.key); + param_required.push(param.key); } } - Ok(json!({ - "type": "object", - "properties": properties, - "required": required - })) - } else { - Ok(json!({ - "type": "object", - "properties": {} - })) } -} - -fn prepare_command_params( - sub_recipe: &SubRecipe, - params_from_tool_call: Value, -) -> Result> { - let mut sub_recipe_params = HashMap::::new(); - if let Some(params_with_value) = &sub_recipe.values { - for (param_name, param_value) in params_with_value { - sub_recipe_params.insert(param_name.clone(), param_value.clone()); - } - } - if let Some(params_map) = params_from_tool_call.as_object() { - for (key, value) in params_map { - sub_recipe_params.insert( - key.to_string(), - 
value.as_str().unwrap_or(&value.to_string()).to_string(), - ); - } - } - Ok(sub_recipe_params) -} - -pub async fn create_sub_recipe_task(sub_recipe: &SubRecipe, params: Value) -> Result { - let command_params = prepare_command_params(sub_recipe, params)?; - let payload = json!({ - "sub_recipe": { - "name": sub_recipe.name.clone(), - "command_parameters": command_params, - "recipe_path": sub_recipe.path.clone(), - } - }); - let task = Task { - id: uuid::Uuid::new_v4().to_string(), - task_type: "sub_recipe".to_string(), - payload, - }; - let task_json = serde_json::to_string(&task) - .map_err(|e| anyhow::anyhow!("Failed to serialize Task: {}", e))?; - Ok(task_json) + Ok(create_input_schema(param_properties, param_required)) } #[cfg(test)] diff --git a/crates/goose/src/agents/recipe_tools/sub_recipe_tools/tests.rs b/crates/goose/src/agents/recipe_tools/sub_recipe_tools/tests.rs index 11ce390a..0b682b0b 100644 --- a/crates/goose/src/agents/recipe_tools/sub_recipe_tools/tests.rs +++ b/crates/goose/src/agents/recipe_tools/sub_recipe_tools/tests.rs @@ -3,66 +3,48 @@ mod tests { use std::collections::HashMap; use crate::recipe::SubRecipe; + use serde_json::json; + use serde_json::Value; + use tempfile::TempDir; - fn setup_sub_recipe() -> SubRecipe { + fn setup_default_sub_recipe() -> SubRecipe { let sub_recipe = SubRecipe { name: "test_sub_recipe".to_string(), path: "test_sub_recipe.yaml".to_string(), values: Some(HashMap::from([("key1".to_string(), "value1".to_string())])), + sequential_when_repeated: true, }; sub_recipe } - mod prepare_command_params_tests { - use std::collections::HashMap; - use crate::{ - agents::recipe_tools::sub_recipe_tools::{ - prepare_command_params, tests::tests::setup_sub_recipe, - }, - recipe::SubRecipe, - }; + mod get_input_schema { + use super::*; + use crate::agents::recipe_tools::sub_recipe_tools::get_input_schema; - #[test] - fn test_prepare_command_params_basic() { - let mut params = HashMap::new(); - params.insert("key2".to_string(), 
"value2".to_string()); - - let sub_recipe = setup_sub_recipe(); - - let params_value = serde_json::to_value(params).unwrap(); - let result = prepare_command_params(&sub_recipe, params_value).unwrap(); - assert_eq!(result.len(), 2); - assert_eq!(result.get("key1"), Some(&"value1".to_string())); - assert_eq!(result.get("key2"), Some(&"value2".to_string())); + fn prepare_sub_recipe(sub_recipe_file_content: &str) -> (SubRecipe, TempDir) { + let mut sub_recipe = setup_default_sub_recipe(); + let temp_dir = tempfile::tempdir().unwrap(); + let temp_file = temp_dir.path().join(sub_recipe.path.clone()); + std::fs::write(&temp_file, sub_recipe_file_content).unwrap(); + sub_recipe.path = temp_file.to_string_lossy().to_string(); + (sub_recipe, temp_dir) } - #[test] - fn test_prepare_command_params_empty() { - let sub_recipe = SubRecipe { - name: "test_sub_recipe".to_string(), - path: "test_sub_recipe.yaml".to_string(), - values: None, - }; - let params: HashMap = HashMap::new(); - let params_value = serde_json::to_value(params).unwrap(); - let result = prepare_command_params(&sub_recipe, params_value).unwrap(); - assert_eq!(result.len(), 0); + fn verify_task_parameters(result: Value, expected_task_parameters_items: Value) { + let task_parameters = result + .get("properties") + .unwrap() + .as_object() + .unwrap() + .get("task_parameters") + .unwrap() + .as_object() + .unwrap(); + let task_parameters_items = task_parameters.get("items").unwrap(); + assert_eq!(&expected_task_parameters_items, task_parameters_items); } - } - mod get_input_schema_tests { - use crate::{ - agents::recipe_tools::sub_recipe_tools::{ - get_input_schema, tests::tests::setup_sub_recipe, - }, - recipe::SubRecipe, - }; - - #[test] - fn test_get_input_schema_with_parameters() { - let sub_recipe = setup_sub_recipe(); - - let sub_recipe_file_content = r#"{ + const SUB_RECIPE_FILE_CONTENT_WITH_TWO_PARAMS: &str = r#"{ "version": "1.0.0", "title": "Test Recipe", "description": "A test recipe", @@ -83,73 +65,67 
@@ mod tests { ] }"#; - let temp_dir = tempfile::tempdir().unwrap(); - let temp_file = temp_dir.path().join("test_sub_recipe.yaml"); - std::fs::write(&temp_file, sub_recipe_file_content).unwrap(); - - let mut sub_recipe = sub_recipe; - sub_recipe.path = temp_file.to_string_lossy().to_string(); + #[test] + fn test_with_one_param_in_tool_input() { + let (mut sub_recipe, _temp_dir) = + prepare_sub_recipe(SUB_RECIPE_FILE_CONTENT_WITH_TWO_PARAMS); + sub_recipe.values = Some(HashMap::from([("key1".to_string(), "value1".to_string())])); let result = get_input_schema(&sub_recipe).unwrap(); - // Verify the schema structure - assert_eq!(result["type"], "object"); - assert!(result["properties"].is_object()); - - let properties = result["properties"].as_object().unwrap(); - assert_eq!(properties.len(), 1); - - let key2_prop = &properties["key2"]; - assert_eq!(key2_prop["type"], "number"); - assert_eq!(key2_prop["description"], "An optional parameter"); - - let required = result["required"].as_array().unwrap(); - assert_eq!(required.len(), 0); + verify_task_parameters( + result, + json!({ + "type": "object", + "properties": { + "key2": { "type": "number", "description": "An optional parameter" } + }, + "required": [] + }), + ); } #[test] - fn test_get_input_schema_no_parameters_values() { - let sub_recipe = SubRecipe { - name: "test_sub_recipe".to_string(), - path: "test_sub_recipe.yaml".to_string(), - values: None, - }; - - let sub_recipe_file_content = r#"{ - "version": "1.0.0", - "title": "Test Recipe", - "description": "A test recipe", - "prompt": "Test prompt", - "parameters": [ - { - "key": "key1", - "input_type": "string", - "requirement": "required", - "description": "A test parameter" - } - ] - }"#; - - let temp_dir = tempfile::tempdir().unwrap(); - let temp_file = temp_dir.path().join("test_sub_recipe.yaml"); - std::fs::write(&temp_file, sub_recipe_file_content).unwrap(); - - let mut sub_recipe = sub_recipe; - sub_recipe.path = temp_file.to_string_lossy().to_string(); 
+ fn test_without_param_in_tool_input() { + let (mut sub_recipe, _temp_dir) = + prepare_sub_recipe(SUB_RECIPE_FILE_CONTENT_WITH_TWO_PARAMS); + sub_recipe.values = Some(HashMap::from([ + ("key1".to_string(), "value1".to_string()), + ("key2".to_string(), "value2".to_string()), + ])); let result = get_input_schema(&sub_recipe).unwrap(); - assert_eq!(result["type"], "object"); - assert!(result["properties"].is_object()); + assert_eq!( + None, + result + .get("properties") + .unwrap() + .as_object() + .unwrap() + .get("task_parameters") + ); + } - let properties = result["properties"].as_object().unwrap(); - assert_eq!(properties.len(), 1); + #[test] + fn test_with_all_params_in_tool_input() { + let (mut sub_recipe, _temp_dir) = + prepare_sub_recipe(SUB_RECIPE_FILE_CONTENT_WITH_TWO_PARAMS); + sub_recipe.values = None; - let key1_prop = &properties["key1"]; - assert_eq!(key1_prop["type"], "string"); - assert_eq!(key1_prop["description"], "A test parameter"); - assert_eq!(result["required"].as_array().unwrap().len(), 1); - assert_eq!(result["required"][0], "key1"); + let result = get_input_schema(&sub_recipe).unwrap(); + + verify_task_parameters( + result, + json!({ + "type": "object", + "properties": { + "key1": { "type": "string", "description": "A test parameter" }, + "key2": { "type": "number", "description": "An optional parameter" } + }, + "required": ["key1"] + }), + ); } } } diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/executor.rs b/crates/goose/src/agents/sub_recipe_execution_tool/executor.rs deleted file mode 100644 index b796d412..00000000 --- a/crates/goose/src/agents/sub_recipe_execution_tool/executor.rs +++ /dev/null @@ -1,103 +0,0 @@ -use std::sync::atomic::{AtomicBool, AtomicUsize}; -use std::sync::Arc; -use tokio::sync::mpsc; -use tokio::time::Instant; - -use crate::agents::sub_recipe_execution_tool::lib::{ - Config, ExecutionResponse, ExecutionStats, Task, TaskResult, -}; -use 
crate::agents::sub_recipe_execution_tool::tasks::process_task; -use crate::agents::sub_recipe_execution_tool::workers::{run_scaler, spawn_worker, SharedState}; - -pub async fn execute_single_task(task: &Task, config: Config) -> ExecutionResponse { - let start_time = Instant::now(); - let result = process_task(task, config.timeout_seconds).await; - - let execution_time = start_time.elapsed().as_millis(); - let completed = if result.status == "success" { 1 } else { 0 }; - let failed = if result.status == "failed" { 1 } else { 0 }; - - ExecutionResponse { - status: "completed".to_string(), - results: vec![result], - stats: ExecutionStats { - total_tasks: 1, - completed, - failed, - execution_time_ms: execution_time, - }, - } -} - -// Main parallel execution function -pub async fn parallel_execute(tasks: Vec, config: Config) -> ExecutionResponse { - let start_time = Instant::now(); - let task_count = tasks.len(); - - // Create channels - let (task_tx, task_rx) = mpsc::channel::(task_count); - let (result_tx, mut result_rx) = mpsc::channel::(task_count); - - // Initialize shared state - let shared_state = Arc::new(SharedState { - task_receiver: Arc::new(tokio::sync::Mutex::new(task_rx)), - result_sender: result_tx, - active_workers: Arc::new(AtomicUsize::new(0)), - should_stop: Arc::new(AtomicBool::new(false)), - completed_tasks: Arc::new(AtomicUsize::new(0)), - }); - - // Send all tasks to the queue - for task in tasks.clone() { - let _ = task_tx.send(task).await; - } - // Close sender so workers know when queue is empty - drop(task_tx); - - // Start initial workers - let mut worker_handles = Vec::new(); - for i in 0..config.initial_workers { - let handle = spawn_worker(shared_state.clone(), i, config.timeout_seconds); - worker_handles.push(handle); - } - - // Start the scaler - let scaler_state = shared_state.clone(); - let scaler_handle = tokio::spawn(async move { - run_scaler( - scaler_state, - task_count, - config.max_workers, - config.timeout_seconds, - ) - 
.await; - }); - - // Collect results - let mut results = Vec::new(); - while let Some(result) = result_rx.recv().await { - results.push(result); - if results.len() >= task_count { - break; - } - } - - // Wait for scaler to finish - let _ = scaler_handle.await; - - // Calculate stats - let execution_time = start_time.elapsed().as_millis(); - let completed = results.iter().filter(|r| r.status == "success").count(); - let failed = results.iter().filter(|r| r.status == "failed").count(); - - ExecutionResponse { - status: "completed".to_string(), - results, - stats: ExecutionStats { - total_tasks: task_count, - completed, - failed, - execution_time_ms: execution_time, - }, - } -} diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/executor/mod.rs b/crates/goose/src/agents/sub_recipe_execution_tool/executor/mod.rs new file mode 100644 index 00000000..bb73b73e --- /dev/null +++ b/crates/goose/src/agents/sub_recipe_execution_tool/executor/mod.rs @@ -0,0 +1,197 @@ +use mcp_core::protocol::JsonRpcMessage; +use std::sync::atomic::AtomicUsize; +use std::sync::Arc; +use tokio::sync::mpsc; +use tokio::time::Instant; + +use crate::agents::sub_recipe_execution_tool::lib::{ + ExecutionResponse, ExecutionStats, SharedState, Task, TaskResult, TaskStatus, +}; +use crate::agents::sub_recipe_execution_tool::task_execution_tracker::{ + DisplayMode, TaskExecutionTracker, +}; +use crate::agents::sub_recipe_execution_tool::tasks::process_task; +use crate::agents::sub_recipe_execution_tool::workers::spawn_worker; + +#[cfg(test)] +mod tests; + +const EXECUTION_STATUS_COMPLETED: &str = "completed"; +const DEFAULT_MAX_WORKERS: usize = 10; + +pub async fn execute_single_task( + task: &Task, + notifier: mpsc::Sender, +) -> ExecutionResponse { + let start_time = Instant::now(); + let task_execution_tracker = Arc::new(TaskExecutionTracker::new( + vec![task.clone()], + DisplayMode::SingleTaskOutput, + notifier, + )); + let result = process_task(task, task_execution_tracker).await; + let 
execution_time = start_time.elapsed().as_millis(); + let stats = calculate_stats(&[result.clone()], execution_time); + + ExecutionResponse { + status: EXECUTION_STATUS_COMPLETED.to_string(), + results: vec![result], + stats, + } +} + +pub async fn execute_tasks_in_parallel( + tasks: Vec, + notifier: mpsc::Sender, +) -> ExecutionResponse { + let task_execution_tracker = Arc::new(TaskExecutionTracker::new( + tasks.clone(), + DisplayMode::MultipleTasksOutput, + notifier, + )); + let start_time = Instant::now(); + let task_count = tasks.len(); + + if task_count == 0 { + return create_empty_response(); + } + + task_execution_tracker.refresh_display().await; + + let (task_tx, task_rx, result_tx, mut result_rx) = create_channels(task_count); + + if let Err(e) = send_tasks_to_channel(tasks, task_tx).await { + tracing::error!("Task execution failed: {}", e); + return create_error_response(e); + } + + let shared_state = create_shared_state(task_rx, result_tx, task_execution_tracker.clone()); + + let worker_count = std::cmp::min(task_count, DEFAULT_MAX_WORKERS); + let mut worker_handles = Vec::new(); + for i in 0..worker_count { + let handle = spawn_worker(shared_state.clone(), i); + worker_handles.push(handle); + } + + let results = collect_results(&mut result_rx, task_execution_tracker.clone(), task_count).await; + + for handle in worker_handles { + if let Err(e) = handle.await { + tracing::error!("Worker error: {}", e); + } + } + + task_execution_tracker.send_tasks_complete().await; + + let execution_time = start_time.elapsed().as_millis(); + let stats = calculate_stats(&results, execution_time); + + ExecutionResponse { + status: EXECUTION_STATUS_COMPLETED.to_string(), + results, + stats, + } +} + +fn calculate_stats(results: &[TaskResult], execution_time_ms: u128) -> ExecutionStats { + let completed = results + .iter() + .filter(|r| matches!(r.status, TaskStatus::Completed)) + .count(); + let failed = results + .iter() + .filter(|r| matches!(r.status, TaskStatus::Failed)) 
+ .count(); + + ExecutionStats { + total_tasks: results.len(), + completed, + failed, + execution_time_ms, + } +} + +fn create_channels( + task_count: usize, +) -> ( + mpsc::Sender, + mpsc::Receiver, + mpsc::Sender, + mpsc::Receiver, +) { + let (task_tx, task_rx) = mpsc::channel::(task_count); + let (result_tx, result_rx) = mpsc::channel::(task_count); + (task_tx, task_rx, result_tx, result_rx) +} + +fn create_shared_state( + task_rx: mpsc::Receiver, + result_tx: mpsc::Sender, + task_execution_tracker: Arc, +) -> Arc { + Arc::new(SharedState { + task_receiver: Arc::new(tokio::sync::Mutex::new(task_rx)), + result_sender: result_tx, + active_workers: Arc::new(AtomicUsize::new(0)), + task_execution_tracker, + }) +} + +async fn send_tasks_to_channel( + tasks: Vec, + task_tx: mpsc::Sender, +) -> Result<(), String> { + for task in tasks { + task_tx + .send(task) + .await + .map_err(|e| format!("Failed to queue task: {}", e))?; + } + Ok(()) +} + +fn create_empty_response() -> ExecutionResponse { + ExecutionResponse { + status: EXECUTION_STATUS_COMPLETED.to_string(), + results: vec![], + stats: ExecutionStats { + total_tasks: 0, + completed: 0, + failed: 0, + execution_time_ms: 0, + }, + } +} + +async fn collect_results( + result_rx: &mut mpsc::Receiver, + task_execution_tracker: Arc, + expected_count: usize, +) -> Vec { + let mut results = Vec::new(); + while let Some(result) = result_rx.recv().await { + task_execution_tracker + .complete_task(&result.task_id, result.clone()) + .await; + results.push(result); + if results.len() >= expected_count { + break; + } + } + results +} + +fn create_error_response(error: String) -> ExecutionResponse { + tracing::error!("Creating error response: {}", error); + ExecutionResponse { + status: "failed".to_string(), + results: vec![], + stats: ExecutionStats { + total_tasks: 0, + completed: 0, + failed: 1, + execution_time_ms: 0, + }, + } +} diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/executor/tests.rs 
b/crates/goose/src/agents/sub_recipe_execution_tool/executor/tests.rs new file mode 100644 index 00000000..76385b87 --- /dev/null +++ b/crates/goose/src/agents/sub_recipe_execution_tool/executor/tests.rs @@ -0,0 +1,100 @@ +use super::{calculate_stats, create_empty_response, create_error_response}; +use crate::agents::sub_recipe_execution_tool::lib::{TaskResult, TaskStatus}; +use serde_json::json; + +fn create_test_task_result(task_id: &str, status: TaskStatus) -> TaskResult { + let is_failed = matches!(status, TaskStatus::Failed); + TaskResult { + task_id: task_id.to_string(), + status, + data: Some(json!({"output": "test output"})), + error: if is_failed { + Some("Test error".to_string()) + } else { + None + }, + } +} + +#[test] +fn test_calculate_stats() { + let results = vec![ + create_test_task_result("task1", TaskStatus::Completed), + create_test_task_result("task2", TaskStatus::Completed), + create_test_task_result("task3", TaskStatus::Failed), + create_test_task_result("task4", TaskStatus::Completed), + ]; + + let stats = calculate_stats(&results, 1500); + + assert_eq!(stats.total_tasks, 4); + assert_eq!(stats.completed, 3); + assert_eq!(stats.failed, 1); + assert_eq!(stats.execution_time_ms, 1500); +} + +#[test] +fn test_calculate_stats_empty_results() { + let results = vec![]; + let stats = calculate_stats(&results, 0); + + assert_eq!(stats.total_tasks, 0); + assert_eq!(stats.completed, 0); + assert_eq!(stats.failed, 0); + assert_eq!(stats.execution_time_ms, 0); +} + +#[test] +fn test_calculate_stats_all_completed() { + let results = vec![ + create_test_task_result("task1", TaskStatus::Completed), + create_test_task_result("task2", TaskStatus::Completed), + ]; + + let stats = calculate_stats(&results, 800); + + assert_eq!(stats.total_tasks, 2); + assert_eq!(stats.completed, 2); + assert_eq!(stats.failed, 0); + assert_eq!(stats.execution_time_ms, 800); +} + +#[test] +fn test_calculate_stats_all_failed() { + let results = vec![ + 
create_test_task_result("task1", TaskStatus::Failed), + create_test_task_result("task2", TaskStatus::Failed), + ]; + + let stats = calculate_stats(&results, 1200); + + assert_eq!(stats.total_tasks, 2); + assert_eq!(stats.completed, 0); + assert_eq!(stats.failed, 2); + assert_eq!(stats.execution_time_ms, 1200); +} + +#[test] +fn test_create_empty_response() { + let response = create_empty_response(); + + assert_eq!(response.status, "completed"); + assert_eq!(response.results.len(), 0); + assert_eq!(response.stats.total_tasks, 0); + assert_eq!(response.stats.completed, 0); + assert_eq!(response.stats.failed, 0); + assert_eq!(response.stats.execution_time_ms, 0); +} + +#[test] +fn test_create_error_response() { + let error_msg = "Test error message"; + let response = create_error_response(error_msg.to_string()); + + assert_eq!(response.status, "failed"); + assert_eq!(response.results.len(), 0); + assert_eq!(response.stats.total_tasks, 0); + assert_eq!(response.stats.completed, 0); + assert_eq!(response.stats.failed, 1); + assert_eq!(response.stats.execution_time_ms, 0); +} diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/lib.rs b/crates/goose/src/agents/sub_recipe_execution_tool/lib.rs deleted file mode 100644 index 9df784a4..00000000 --- a/crates/goose/src/agents/sub_recipe_execution_tool/lib.rs +++ /dev/null @@ -1,38 +0,0 @@ -use crate::agents::sub_recipe_execution_tool::executor::execute_single_task; -pub use crate::agents::sub_recipe_execution_tool::executor::parallel_execute; -pub use crate::agents::sub_recipe_execution_tool::types::{ - Config, ExecutionResponse, ExecutionStats, Task, TaskResult, -}; - -use serde_json::Value; - -pub async fn execute_tasks(input: Value, execution_mode: &str) -> Result { - let tasks: Vec = - serde_json::from_value(input.get("tasks").ok_or("Missing tasks field")?.clone()) - .map_err(|e| format!("Failed to parse tasks: {}", e))?; - - let config: Config = if let Some(config_value) = input.get("config") { - 
serde_json::from_value(config_value.clone()) - .map_err(|e| format!("Failed to parse config: {}", e))? - } else { - Config::default() - }; - let task_count = tasks.len(); - match execution_mode { - "sequential" => { - if task_count == 1 { - let response = execute_single_task(&tasks[0], config).await; - serde_json::to_value(response) - .map_err(|e| format!("Failed to serialize response: {}", e)) - } else { - Err("Sequential execution mode requires exactly one task".to_string()) - } - } - "parallel" => { - let response = parallel_execute(tasks, config).await; - serde_json::to_value(response) - .map_err(|e| format!("Failed to serialize response: {}", e)) - } - _ => Err("Invalid execution mode".to_string()), - } -} diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/lib/mod.rs b/crates/goose/src/agents/sub_recipe_execution_tool/lib/mod.rs new file mode 100644 index 00000000..446b6011 --- /dev/null +++ b/crates/goose/src/agents/sub_recipe_execution_tool/lib/mod.rs @@ -0,0 +1,127 @@ +use crate::agents::sub_recipe_execution_tool::executor::{ + execute_single_task, execute_tasks_in_parallel, +}; +pub use crate::agents::sub_recipe_execution_tool::task_types::{ + ExecutionMode, ExecutionResponse, ExecutionStats, SharedState, Task, TaskResult, TaskStatus, +}; +use crate::agents::sub_recipe_execution_tool::tasks_manager::TasksManager; + +#[cfg(test)] +mod tests; + +use mcp_core::protocol::JsonRpcMessage; +use serde_json::{json, Value}; +use tokio::sync::mpsc; + +pub async fn execute_tasks( + input: Value, + execution_mode: ExecutionMode, + notifier: mpsc::Sender, + tasks_manager: &TasksManager, +) -> Result { + let task_ids: Vec = serde_json::from_value( + input + .get("task_ids") + .ok_or("Missing task_ids field")? 
+ .clone(), + ) + .map_err(|e| format!("Failed to parse task_ids: {}", e))?; + + let mut tasks = Vec::new(); + for task_id in &task_ids { + match tasks_manager.get_task(task_id).await { + Some(task) => tasks.push(task), + None => { + return Err(format!( + "Task with ID '{}' not found in TasksManager", + task_id + )) + } + } + } + + let task_count = tasks.len(); + match execution_mode { + ExecutionMode::Sequential => { + if task_count == 1 { + let response = execute_single_task(&tasks[0], notifier).await; + handle_response(response) + } else { + Err("Sequential execution mode requires exactly one task".to_string()) + } + } + + ExecutionMode::Parallel => { + if tasks.iter().any(|task| task.get_sequential_when_repeated()) { + Ok(json!( + { + "execution_mode": ExecutionMode::Sequential, + "task_ids": task_ids, + "results": ["the tasks should be executed sequentially, no matter how user requests it. Please use the subrecipe__execute_task tool to execute the tasks sequentially."] + } + )) + } else { + let response: ExecutionResponse = + execute_tasks_in_parallel(tasks, notifier.clone()).await; + handle_response(response) + } + } + } +} + +fn extract_failed_tasks(results: &[TaskResult]) -> Vec { + results + .iter() + .filter(|r| matches!(r.status, TaskStatus::Failed)) + .map(format_failed_task_error) + .collect() +} + +fn format_failed_task_error(result: &TaskResult) -> String { + let error_msg = result.error.as_deref().unwrap_or("Unknown error"); + let partial_output = result + .data + .as_ref() + .and_then(|d| d.get("partial_output")) + .and_then(|v| v.as_str()) + .filter(|s| !s.trim().is_empty()) + .unwrap_or("No output captured"); + + format!( + "Task '{}' ({}): {}\nOutput: {}", + result.task_id, + get_task_description(result), + error_msg, + partial_output + ) +} + +fn format_error_summary( + failed_count: usize, + total_count: usize, + failed_tasks: Vec, +) -> String { + format!( + "{}/{} tasks failed:\n{}", + failed_count, + total_count, + failed_tasks.join("\n") + 
) +} + +fn handle_response(response: ExecutionResponse) -> Result { + if response.stats.failed > 0 { + let failed_tasks = extract_failed_tasks(&response.results); + let error_summary = format_error_summary( + response.stats.failed, + response.stats.total_tasks, + failed_tasks, + ); + return Err(error_summary); + } + serde_json::to_value(response).map_err(|e| format!("Failed to serialize response: {}", e)) +} + +fn get_task_description(result: &TaskResult) -> String { + format!("ID: {}", result.task_id) +} diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/lib/tests.rs b/crates/goose/src/agents/sub_recipe_execution_tool/lib/tests.rs new file mode 100644 index 00000000..957b1127 --- /dev/null +++ b/crates/goose/src/agents/sub_recipe_execution_tool/lib/tests.rs @@ -0,0 +1,216 @@ +use super::{ + extract_failed_tasks, format_error_summary, format_failed_task_error, get_task_description, + handle_response, +}; +use crate::agents::sub_recipe_execution_tool::lib::{ + ExecutionResponse, ExecutionStats, TaskResult, TaskStatus, +}; +use serde_json::json; + +fn create_test_task_result(task_id: &str, status: TaskStatus, error: Option) -> TaskResult { + TaskResult { + task_id: task_id.to_string(), + status, + data: Some(json!({"partial_output": "test output"})), + error, + } +} + +fn create_test_execution_response( + results: Vec, + failed_count: usize, +) -> ExecutionResponse { + ExecutionResponse { + status: "completed".to_string(), + results: results.clone(), + stats: ExecutionStats { + total_tasks: results.len(), + completed: results.len() - failed_count, + failed: failed_count, + execution_time_ms: 1000, + }, + } +} + +#[test] +fn test_extract_failed_tasks() { + let results = vec![ + create_test_task_result("task1", TaskStatus::Completed, None), + create_test_task_result( + "task2", + TaskStatus::Failed, + Some("Error message".to_string()), + ), + create_test_task_result("task3", TaskStatus::Completed, None), + create_test_task_result( + "task4", + 
TaskStatus::Failed, + Some("Another error".to_string()), + ), + ]; + + let failed_tasks = extract_failed_tasks(&results); + + assert_eq!(failed_tasks.len(), 2); + assert!(failed_tasks[0].contains("task2")); + assert!(failed_tasks[0].contains("Error message")); + assert!(failed_tasks[1].contains("task4")); + assert!(failed_tasks[1].contains("Another error")); +} + +#[test] +fn test_extract_failed_tasks_empty() { + let results = vec![ + create_test_task_result("task1", TaskStatus::Completed, None), + create_test_task_result("task2", TaskStatus::Completed, None), + ]; + + let failed_tasks = extract_failed_tasks(&results); + + assert_eq!(failed_tasks.len(), 0); +} + +#[test] +fn test_format_failed_task_error_with_error_message() { + let result = create_test_task_result( + "task1", + TaskStatus::Failed, + Some("Test error message".to_string()), + ); + + let formatted = format_failed_task_error(&result); + + assert!(formatted.contains("task1")); + assert!(formatted.contains("Test error message")); + assert!(formatted.contains("test output")); + assert!(formatted.contains("ID: task1")); +} + +#[test] +fn test_format_failed_task_error_without_error_message() { + let result = create_test_task_result("task2", TaskStatus::Failed, None); + + let formatted = format_failed_task_error(&result); + + assert!(formatted.contains("task2")); + assert!(formatted.contains("Unknown error")); + assert!(formatted.contains("test output")); +} + +#[test] +fn test_format_failed_task_error_empty_partial_output() { + let mut result = + create_test_task_result("task3", TaskStatus::Failed, Some("Error".to_string())); + result.data = Some(json!({"partial_output": ""})); + + let formatted = format_failed_task_error(&result); + + assert!(formatted.contains("No output captured")); +} + +#[test] +fn test_format_failed_task_error_no_partial_output() { + let mut result = + create_test_task_result("task4", TaskStatus::Failed, Some("Error".to_string())); + result.data = Some(json!({})); + + let formatted = 
format_failed_task_error(&result); + + assert!(formatted.contains("No output captured")); +} + +#[test] +fn test_format_failed_task_error_no_data() { + let mut result = + create_test_task_result("task5", TaskStatus::Failed, Some("Error".to_string())); + result.data = None; + + let formatted = format_failed_task_error(&result); + + assert!(formatted.contains("No output captured")); +} + +#[test] +fn test_format_error_summary() { + let failed_tasks = vec![ + "Task 'task1': Error 1\nOutput: output1".to_string(), + "Task 'task2': Error 2\nOutput: output2".to_string(), + ]; + + let summary = format_error_summary(2, 5, failed_tasks); + + assert_eq!(summary, "2/5 tasks failed:\nTask 'task1': Error 1\nOutput: output1\nTask 'task2': Error 2\nOutput: output2"); +} + +#[test] +fn test_format_error_summary_single_failure() { + let failed_tasks = vec!["Task 'task1': Error\nOutput: output".to_string()]; + + let summary = format_error_summary(1, 3, failed_tasks); + + assert_eq!( + summary, + "1/3 tasks failed:\nTask 'task1': Error\nOutput: output" + ); +} + +#[test] +fn test_handle_response_success() { + let results = vec![ + create_test_task_result("task1", TaskStatus::Completed, None), + create_test_task_result("task2", TaskStatus::Completed, None), + ]; + let response = create_test_execution_response(results, 0); + + let result = handle_response(response); + + assert!(result.is_ok()); + let value = result.unwrap(); + assert_eq!(value["status"], "completed"); + assert_eq!(value["stats"]["failed"], 0); +} + +#[test] +fn test_handle_response_with_failures() { + let results = vec![ + create_test_task_result("task1", TaskStatus::Completed, None), + create_test_task_result("task2", TaskStatus::Failed, Some("Test error".to_string())), + ]; + let response = create_test_execution_response(results, 1); + + let result = handle_response(response); + + assert!(result.is_err()); + let error = result.unwrap_err(); + assert!(error.contains("1/2 tasks failed")); + 
assert!(error.contains("task2")); + assert!(error.contains("Test error")); +} + +#[test] +fn test_handle_response_all_failures() { + let results = vec![ + create_test_task_result("task1", TaskStatus::Failed, Some("Error 1".to_string())), + create_test_task_result("task2", TaskStatus::Failed, Some("Error 2".to_string())), + ]; + let response = create_test_execution_response(results, 2); + + let result = handle_response(response); + + assert!(result.is_err()); + let error = result.unwrap_err(); + assert!(error.contains("2/2 tasks failed")); + assert!(error.contains("task1")); + assert!(error.contains("task2")); + assert!(error.contains("Error 1")); + assert!(error.contains("Error 2")); +} + +#[test] +fn test_get_task_description() { + let result = create_test_task_result("test_task_123", TaskStatus::Completed, None); + + let description = get_task_description(&result); + + assert_eq!(description, "ID: test_task_123"); +} diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/mod.rs b/crates/goose/src/agents/sub_recipe_execution_tool/mod.rs index a49791e2..0b7af3b5 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/mod.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/mod.rs @@ -1,6 +1,10 @@ mod executor; pub mod lib; +pub mod notification_events; pub mod sub_recipe_execute_task_tool; +mod task_execution_tracker; +mod task_types; mod tasks; -mod types; +pub mod tasks_manager; +pub mod utils; mod workers; diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/notification_events.rs b/crates/goose/src/agents/sub_recipe_execution_tool/notification_events.rs new file mode 100644 index 00000000..2a6134ea --- /dev/null +++ b/crates/goose/src/agents/sub_recipe_execution_tool/notification_events.rs @@ -0,0 +1,204 @@ +use crate::agents::sub_recipe_execution_tool::task_types::TaskStatus; +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "subtype")] +pub enum 
TaskExecutionNotificationEvent { + #[serde(rename = "line_output")] + LineOutput { task_id: String, output: String }, + #[serde(rename = "tasks_update")] + TasksUpdate { + stats: TaskExecutionStats, + tasks: Vec<TaskInfo>, + }, + #[serde(rename = "tasks_complete")] + TasksComplete { + stats: TaskCompletionStats, + failed_tasks: Vec<FailedTaskInfo>, + }, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TaskExecutionStats { + pub total: usize, + pub pending: usize, + pub running: usize, + pub completed: usize, + pub failed: usize, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TaskCompletionStats { + pub total: usize, + pub completed: usize, + pub failed: usize, + pub success_rate: f64, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TaskInfo { + pub id: String, + pub status: TaskStatus, + pub duration_secs: Option<f64>, + pub current_output: String, + pub task_type: String, + pub task_name: String, + pub task_metadata: String, + pub error: Option<String>, + pub result_data: Option<Value>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct FailedTaskInfo { + pub id: String, + pub name: String, + pub error: Option<String>, +} + +impl TaskExecutionNotificationEvent { + pub fn line_output(task_id: String, output: String) -> Self { + Self::LineOutput { task_id, output } + } + + pub fn tasks_update(stats: TaskExecutionStats, tasks: Vec<TaskInfo>) -> Self { + Self::TasksUpdate { stats, tasks } + } + + pub fn tasks_complete(stats: TaskCompletionStats, failed_tasks: Vec<FailedTaskInfo>) -> Self { + Self::TasksComplete { + stats, + failed_tasks, + } + } + + /// Convert event to JSON format for MCP notification + pub fn to_notification_data(&self) -> serde_json::Value { + let mut event_data = serde_json::to_value(self).expect("Failed to serialize event"); + + // Add the type field at the root level + if let serde_json::Value::Object(ref mut map) = event_data { + map.insert( + "type".to_string(), + serde_json::Value::String("task_execution".to_string()), + ); + } + + event_data + } +} + 
+impl TaskExecutionStats { + pub fn new( + total: usize, + pending: usize, + running: usize, + completed: usize, + failed: usize, + ) -> Self { + Self { + total, + pending, + running, + completed, + failed, + } + } +} + +impl TaskCompletionStats { + pub fn new(total: usize, completed: usize, failed: usize) -> Self { + let success_rate = if total > 0 { + (completed as f64 / total as f64) * 100.0 + } else { + 0.0 + }; + + Self { + total, + completed, + failed, + success_rate, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_line_output_event_serialization() { + let event = TaskExecutionNotificationEvent::line_output( + "task-1".to_string(), + "Hello World".to_string(), + ); + + let notification_data = event.to_notification_data(); + assert_eq!(notification_data["type"], "task_execution"); + assert_eq!(notification_data["subtype"], "line_output"); + assert_eq!(notification_data["task_id"], "task-1"); + assert_eq!(notification_data["output"], "Hello World"); + } + + #[test] + fn test_tasks_update_event_serialization() { + let stats = TaskExecutionStats::new(5, 2, 1, 1, 1); + let tasks = vec![TaskInfo { + id: "task-1".to_string(), + status: TaskStatus::Running, + duration_secs: Some(1.5), + current_output: "Processing...".to_string(), + task_type: "sub_recipe".to_string(), + task_name: "test-task".to_string(), + task_metadata: "param=value".to_string(), + error: None, + result_data: None, + }]; + + let event = TaskExecutionNotificationEvent::tasks_update(stats, tasks); + let notification_data = event.to_notification_data(); + + assert_eq!(notification_data["type"], "task_execution"); + assert_eq!(notification_data["subtype"], "tasks_update"); + assert_eq!(notification_data["stats"]["total"], 5); + assert_eq!(notification_data["tasks"].as_array().unwrap().len(), 1); + } + + #[test] + fn test_event_roundtrip_serialization() { + let original_event = TaskExecutionNotificationEvent::line_output( + "task-1".to_string(), + "Test 
output".to_string(), + ); + + // Serialize to JSON + let json_data = original_event.to_notification_data(); + + // Deserialize back to event (excluding the type field) + let mut event_data = json_data.clone(); + if let serde_json::Value::Object(ref mut map) = event_data { + map.remove("type"); + } + + let deserialized_event: TaskExecutionNotificationEvent = + serde_json::from_value(event_data).expect("Failed to deserialize"); + + match (original_event, deserialized_event) { + ( + TaskExecutionNotificationEvent::LineOutput { + task_id: id1, + output: out1, + }, + TaskExecutionNotificationEvent::LineOutput { + task_id: id2, + output: out2, + }, + ) => { + assert_eq!(id1, id2); + assert_eq!(out1, out2); + } + _ => panic!("Event types don't match after roundtrip"), + } + } +} diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs b/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs index 46738b81..e5f9062f 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs @@ -2,17 +2,22 @@ use mcp_core::{tool::ToolAnnotations, Content, Tool, ToolError}; use serde_json::Value; use crate::agents::{ - sub_recipe_execution_tool::lib::execute_tasks, tool_execution::ToolCallResult, + sub_recipe_execution_tool::lib::execute_tasks, + sub_recipe_execution_tool::task_types::ExecutionMode, + sub_recipe_execution_tool::tasks_manager::TasksManager, tool_execution::ToolCallResult, }; +use mcp_core::protocol::JsonRpcMessage; +use tokio::sync::mpsc; +use tokio_stream; pub const SUB_RECIPE_EXECUTE_TASK_TOOL_NAME: &str = "sub_recipe__execute_task"; pub fn create_sub_recipe_execute_task_tool() -> Tool { Tool::new( SUB_RECIPE_EXECUTE_TASK_TOOL_NAME, "Only use this tool when you execute sub recipe task. 
-EXECUTION STRATEGY: -- DEFAULT: Execute tasks sequentially (one at a time) unless user explicitly requests parallel execution -- PARALLEL: Only when user explicitly uses keywords like 'parallel', 'simultaneously', 'at the same time', 'concurrently' +EXECUTION STRATEGY DECISION: +1. If the tasks are created with execution_mode, use the execution_mode. +2. Execute tasks sequentially unless user explicitly requests parallel execution. PARALLEL: User uses keywords like 'parallel', 'simultaneously', 'at the same time', 'concurrently' IMPLEMENTATION: - Sequential execution: Call this tool multiple times, passing exactly ONE task per call @@ -32,69 +37,15 @@ EXAMPLES: "default": "sequential", "description": "Execution strategy for multiple tasks. Use 'sequential' (default) unless user explicitly requests parallel execution with words like 'parallel', 'simultaneously', 'at the same time', or 'concurrently'." }, - "tasks": { + "task_ids": { "type": "array", "items": { - "type": "object", - "properties": { - "id": { - "type": "string", - "description": "Unique identifier for the task" - }, - "task_type": { - "type": "string", - "enum": ["sub_recipe", "text_instruction"], - "default": "sub_recipe", - "description": "the type of task to execute, can be one of: sub_recipe, text_instruction" - }, - "payload": { - "type": "object", - "properties": { - "sub_recipe": { - "type": "object", - "description": "sub recipe to execute", - "properties": { - "name": { - "type": "string", - "description": "name of the sub recipe to execute" - }, - "recipe_path": { - "type": "string", - "description": "path of the sub recipe file" - }, - "command_parameters": { - "type": "object", - "description": "parameters to pass to run recipe command with sub recipe file" - } - } - }, - "text_instruction": { - "type": "string", - "description": "text instruction to execute" - } - } - } - }, - "required": ["id", "payload"] - }, - "description": "The tasks to run in parallel" - }, - "config": { - "type": 
"object", - "properties": { - "timeout_seconds": { - "type": "number" - }, - "max_workers": { - "type": "number" - }, - "initial_workers": { - "type": "number" - } + "type": "string", + "description": "Unique identifier for the task" } } }, - "required": ["tasks"] + "required": ["task_ids"] }), Some(ToolAnnotations { title: Some("Run tasks in parallel".to_string()), @@ -106,19 +57,38 @@ EXAMPLES: ) } -pub async fn run_tasks(execute_data: Value) -> ToolCallResult { - let execute_data_clone = execute_data.clone(); - let default_execution_mode_value = Value::String("sequential".to_string()); - let execution_mode = execute_data_clone - .get("execution_mode") - .unwrap_or(&default_execution_mode_value) - .as_str() - .unwrap_or("sequential"); - match execute_tasks(execute_data, execution_mode).await { - Ok(result) => { - let output = serde_json::to_string(&result).unwrap(); - ToolCallResult::from(Ok(vec![Content::text(output)])) +pub async fn run_tasks(execute_data: Value, tasks_manager: &TasksManager) -> ToolCallResult { + let (notification_tx, notification_rx) = mpsc::channel::(100); + + let tasks_manager_clone = tasks_manager.clone(); + let result_future = async move { + let execute_data_clone = execute_data.clone(); + let execution_mode = execute_data_clone + .get("execution_mode") + .and_then(|v| serde_json::from_value::(v.clone()).ok()) + .unwrap_or_default(); + + match execute_tasks( + execute_data, + execution_mode, + notification_tx, + &tasks_manager_clone, + ) + .await + { + Ok(result) => { + let output = serde_json::to_string(&result).unwrap(); + Ok(vec![Content::text(output)]) + } + Err(e) => Err(ToolError::ExecutionError(e.to_string())), } - Err(e) => ToolCallResult::from(Err(ToolError::ExecutionError(e.to_string()))), + }; + + // Convert receiver to stream + let notification_stream = tokio_stream::wrappers::ReceiverStream::new(notification_rx); + + ToolCallResult { + result: Box::new(Box::pin(result_future)), + notification_stream: 
Some(Box::new(notification_stream)), } } diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/task_execution_tracker.rs b/crates/goose/src/agents/sub_recipe_execution_tool/task_execution_tracker.rs new file mode 100644 index 00000000..a906a59a --- /dev/null +++ b/crates/goose/src/agents/sub_recipe_execution_tool/task_execution_tracker.rs @@ -0,0 +1,292 @@ +use mcp_core::protocol::{JsonRpcMessage, JsonRpcNotification}; +use serde_json::json; +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::{mpsc, RwLock}; +use tokio::time::{sleep, Duration, Instant}; + +use crate::agents::sub_recipe_execution_tool::notification_events::{ + FailedTaskInfo, TaskCompletionStats, TaskExecutionNotificationEvent, TaskExecutionStats, + TaskInfo as EventTaskInfo, +}; +use crate::agents::sub_recipe_execution_tool::task_types::{ + Task, TaskInfo, TaskResult, TaskStatus, +}; +use crate::agents::sub_recipe_execution_tool::utils::{count_by_status, get_task_name}; +use serde_json::Value; + +#[derive(Debug, Clone, PartialEq)] +pub enum DisplayMode { + MultipleTasksOutput, + SingleTaskOutput, +} + +const THROTTLE_INTERVAL_MS: u64 = 250; +const COMPLETION_NOTIFICATION_DELAY_MS: u64 = 500; + +fn format_task_metadata(task_info: &TaskInfo) -> String { + if let Some(params) = task_info.task.get_command_parameters() { + if params.is_empty() { + return String::new(); + } + + params + .iter() + .map(|(key, value)| { + let value_str = match value { + Value::String(s) => s.clone(), + _ => value.to_string(), + }; + format!("{}={}", key, value_str) + }) + .collect::>() + .join(",") + } else { + String::new() + } +} + +pub struct TaskExecutionTracker { + tasks: Arc>>, + last_refresh: Arc>, + notifier: mpsc::Sender, + display_mode: DisplayMode, +} + +impl TaskExecutionTracker { + pub fn new( + tasks: Vec, + display_mode: DisplayMode, + notifier: mpsc::Sender, + ) -> Self { + let task_map = tasks + .into_iter() + .map(|task| { + let task_id = task.id.clone(); + ( + task_id, + 
TaskInfo { + task, + status: TaskStatus::Pending, + start_time: None, + end_time: None, + result: None, + current_output: String::new(), + }, + ) + }) + .collect(); + + Self { + tasks: Arc::new(RwLock::new(task_map)), + last_refresh: Arc::new(RwLock::new(Instant::now())), + notifier, + display_mode, + } + } + + pub async fn start_task(&self, task_id: &str) { + let mut tasks = self.tasks.write().await; + if let Some(task_info) = tasks.get_mut(task_id) { + task_info.status = TaskStatus::Running; + task_info.start_time = Some(Instant::now()); + } + drop(tasks); + self.force_refresh_display().await; + } + + pub async fn complete_task(&self, task_id: &str, result: TaskResult) { + let mut tasks = self.tasks.write().await; + if let Some(task_info) = tasks.get_mut(task_id) { + task_info.status = result.status.clone(); + task_info.end_time = Some(Instant::now()); + task_info.result = Some(result); + } + drop(tasks); + self.force_refresh_display().await; + } + + pub async fn get_current_output(&self, task_id: &str) -> Option { + let tasks = self.tasks.read().await; + tasks + .get(task_id) + .map(|task_info| task_info.current_output.clone()) + } + + pub async fn send_live_output(&self, task_id: &str, line: &str) { + match self.display_mode { + DisplayMode::SingleTaskOutput => { + let tasks = self.tasks.read().await; + let task_info = tasks.get(task_id); + + let formatted_line = if let Some(task_info) = task_info { + let task_name = get_task_name(task_info); + let task_type = task_info.task.task_type.clone(); + let metadata = format_task_metadata(task_info); + + if metadata.is_empty() { + format!("[{} ({})] {}", task_name, task_type, line) + } else { + format!("[{} ({}) {}] {}", task_name, task_type, metadata, line) + } + } else { + line.to_string() + }; + drop(tasks); + + let event = TaskExecutionNotificationEvent::line_output( + task_id.to_string(), + formatted_line, + ); + + if let Err(e) = + self.notifier + .try_send(JsonRpcMessage::Notification(JsonRpcNotification { + 
jsonrpc: "2.0".to_string(), + method: "notifications/message".to_string(), + params: Some(json!({ + "data": event.to_notification_data() + })), + })) + { + tracing::warn!("Failed to send live output notification: {}", e); + } + } + DisplayMode::MultipleTasksOutput => { + let mut tasks = self.tasks.write().await; + if let Some(task_info) = tasks.get_mut(task_id) { + task_info.current_output.push_str(line); + task_info.current_output.push('\n'); + } + drop(tasks); + + if !self.should_throttle_refresh().await { + self.refresh_display().await; + } + } + } + } + + async fn should_throttle_refresh(&self) -> bool { + let now = Instant::now(); + let mut last_refresh = self.last_refresh.write().await; + + if now.duration_since(*last_refresh) > Duration::from_millis(THROTTLE_INTERVAL_MS) { + *last_refresh = now; + false + } else { + true + } + } + + async fn send_tasks_update(&self) { + let tasks = self.tasks.read().await; + let task_list: Vec<_> = tasks.values().collect(); + let (total, pending, running, completed, failed) = count_by_status(&tasks); + + let stats = TaskExecutionStats::new(total, pending, running, completed, failed); + + let event_tasks: Vec = task_list + .iter() + .map(|task_info| { + let now = Instant::now(); + EventTaskInfo { + id: task_info.task.id.clone(), + status: task_info.status.clone(), + duration_secs: task_info.start_time.map(|start| { + if let Some(end) = task_info.end_time { + end.duration_since(start).as_secs_f64() + } else { + now.duration_since(start).as_secs_f64() + } + }), + current_output: task_info.current_output.clone(), + task_type: task_info.task.task_type.clone(), + task_name: get_task_name(task_info).to_string(), + task_metadata: format_task_metadata(task_info), + error: task_info.error().cloned(), + result_data: task_info.data().cloned(), + } + }) + .collect(); + + let event = TaskExecutionNotificationEvent::tasks_update(stats, event_tasks); + + if let Err(e) = self + .notifier + 
.try_send(JsonRpcMessage::Notification(JsonRpcNotification { + jsonrpc: "2.0".to_string(), + method: "notifications/message".to_string(), + params: Some(json!({ + "data": event.to_notification_data() + })), + })) + { + tracing::warn!("Failed to send tasks update notification: {}", e); + } + } + + pub async fn refresh_display(&self) { + match self.display_mode { + DisplayMode::MultipleTasksOutput => { + self.send_tasks_update().await; + } + DisplayMode::SingleTaskOutput => { + // No dashboard display needed for single task output mode + // Live output is handled via send_live_output method + } + } + } + + // Force refresh without throttling - used for important status changes + async fn force_refresh_display(&self) { + match self.display_mode { + DisplayMode::MultipleTasksOutput => { + // Reset throttle timer to allow immediate update + let mut last_refresh = self.last_refresh.write().await; + *last_refresh = Instant::now() - Duration::from_millis(THROTTLE_INTERVAL_MS + 1); + drop(last_refresh); + + self.send_tasks_update().await; + } + DisplayMode::SingleTaskOutput => { + // No dashboard display needed for single task output mode + } + } + } + + pub async fn send_tasks_complete(&self) { + let tasks = self.tasks.read().await; + let (total, _, _, completed, failed) = count_by_status(&tasks); + + let stats = TaskCompletionStats::new(total, completed, failed); + + let failed_tasks: Vec = tasks + .values() + .filter(|task_info| matches!(task_info.status, TaskStatus::Failed)) + .map(|task_info| FailedTaskInfo { + id: task_info.task.id.clone(), + name: get_task_name(task_info).to_string(), + error: task_info.error().cloned(), + }) + .collect(); + + let event = TaskExecutionNotificationEvent::tasks_complete(stats, failed_tasks); + + if let Err(e) = self + .notifier + .try_send(JsonRpcMessage::Notification(JsonRpcNotification { + jsonrpc: "2.0".to_string(), + method: "notifications/message".to_string(), + params: Some(json!({ + "data": event.to_notification_data() + })), + 
})) + { + tracing::warn!("Failed to send tasks complete notification: {}", e); + } + + // Brief delay to ensure completion notification is processed + sleep(Duration::from_millis(COMPLETION_NOTIFICATION_DELAY_MS)).await; + } +} diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/task_types.rs b/crates/goose/src/agents/sub_recipe_execution_tool/task_types.rs new file mode 100644 index 00000000..4515bb84 --- /dev/null +++ b/crates/goose/src/agents/sub_recipe_execution_tool/task_types.rs @@ -0,0 +1,145 @@ +use serde::{Deserialize, Serialize}; +use serde_json::{Map, Value}; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; +use tokio::sync::mpsc; + +use crate::agents::sub_recipe_execution_tool::task_execution_tracker::TaskExecutionTracker; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)] +#[serde(rename_all = "lowercase")] +pub enum ExecutionMode { + #[default] + Sequential, + Parallel, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Task { + pub id: String, + pub task_type: String, + pub payload: Value, +} + +impl Task { + pub fn get_sub_recipe(&self) -> Option<&Map> { + (self.task_type == "sub_recipe") + .then(|| self.payload.get("sub_recipe")?.as_object()) + .flatten() + } + + pub fn get_sequential_when_repeated(&self) -> bool { + self.get_sub_recipe() + .and_then(|sr| sr.get("sequential_when_repeated").and_then(|v| v.as_bool())) + .unwrap_or_default() + } + + pub fn get_command_parameters(&self) -> Option<&Map> { + self.get_sub_recipe() + .and_then(|sr| sr.get("command_parameters")) + .and_then(|cp| cp.as_object()) + } + + pub fn get_sub_recipe_name(&self) -> Option<&str> { + self.get_sub_recipe() + .and_then(|sr| sr.get("name")) + .and_then(|name| name.as_str()) + } + + pub fn get_sub_recipe_path(&self) -> Option<&str> { + self.get_sub_recipe() + .and_then(|sr| sr.get("recipe_path")) + .and_then(|path| path.as_str()) + } + + pub fn get_text_instruction(&self) -> Option<&str> { + if 
self.task_type != "sub_recipe" { + self.payload + .get("text_instruction") + .and_then(|text| text.as_str()) + } else { + None + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TaskResult { + pub task_id: String, + pub status: TaskStatus, + #[serde(skip_serializing_if = "Option::is_none")] + pub data: Option<Value>, + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option<String>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum TaskStatus { + Pending, + Running, + Completed, + Failed, +} + +impl std::fmt::Display for TaskStatus { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + TaskStatus::Pending => write!(f, "Pending"), + TaskStatus::Running => write!(f, "Running"), + TaskStatus::Completed => write!(f, "Completed"), + TaskStatus::Failed => write!(f, "Failed"), + } + } +} + +#[derive(Debug, Clone)] +pub struct TaskInfo { + pub task: Task, + pub status: TaskStatus, + pub start_time: Option<tokio::time::Instant>, + pub end_time: Option<tokio::time::Instant>, + pub result: Option<TaskResult>, + pub current_output: String, +} + +impl TaskInfo { + pub fn error(&self) -> Option<&String> { + self.result.as_ref().and_then(|r| r.error.as_ref()) + } + + pub fn data(&self) -> Option<&Value> { + self.result.as_ref().and_then(|r| r.data.as_ref()) + } +} + +pub struct SharedState { + pub task_receiver: Arc<tokio::sync::Mutex<mpsc::Receiver<Task>>>, + pub result_sender: mpsc::Sender<TaskResult>, + pub active_workers: Arc<AtomicUsize>, + pub task_execution_tracker: Arc<TaskExecutionTracker>, +} + +impl SharedState { + pub fn increment_active_workers(&self) { + self.active_workers.fetch_add(1, Ordering::SeqCst); + } + + pub fn decrement_active_workers(&self) { + self.active_workers.fetch_sub(1, Ordering::SeqCst); + } +} + +#[derive(Debug, Serialize)] +pub struct ExecutionStats { + pub total_tasks: usize, + pub completed: usize, + pub failed: usize, + pub execution_time_ms: u128, +} + +#[derive(Debug, Serialize)] +pub struct ExecutionResponse { + pub status: String, + pub results: Vec<TaskResult>, + pub stats: ExecutionStats, +} diff --git 
a/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs b/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs index 4e4584aa..a3ab4140 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs @@ -1,76 +1,98 @@ use serde_json::Value; use std::process::Stdio; -use std::time::Duration; +use std::sync::Arc; use tokio::io::{AsyncBufReadExt, BufReader}; use tokio::process::Command; -use tokio::time::timeout; -use crate::agents::sub_recipe_execution_tool::types::{Task, TaskResult}; +use crate::agents::sub_recipe_execution_tool::task_execution_tracker::TaskExecutionTracker; +use crate::agents::sub_recipe_execution_tool::task_types::{Task, TaskResult, TaskStatus}; -// Process a single task based on its type -pub async fn process_task(task: &Task, timeout_seconds: u64) -> TaskResult { - let task_clone = task.clone(); - let timeout_duration = Duration::from_secs(timeout_seconds); - - // Execute with timeout - match timeout(timeout_duration, execute_task(task_clone)).await { - Ok(Ok(data)) => TaskResult { +pub async fn process_task( + task: &Task, + task_execution_tracker: Arc, +) -> TaskResult { + match get_task_result(task.clone(), task_execution_tracker).await { + Ok(data) => TaskResult { task_id: task.id.clone(), - status: "success".to_string(), + status: TaskStatus::Completed, data: Some(data), error: None, }, - Ok(Err(error)) => TaskResult { + Err(error) => TaskResult { task_id: task.id.clone(), - status: "failed".to_string(), + status: TaskStatus::Failed, data: None, error: Some(error), }, - Err(_) => TaskResult { - task_id: task.id.clone(), - status: "failed".to_string(), - data: None, - error: Some("Task timeout".to_string()), - }, } } -async fn execute_task(task: Task) -> Result { +async fn get_task_result( + task: Task, + task_execution_tracker: Arc, +) -> Result { + let (command, output_identifier) = build_command(&task)?; + let (stdout_output, stderr_output, success) = 
run_command( + command, + &output_identifier, + &task.id, + task_execution_tracker, + ) + .await?; + + if success { + process_output(stdout_output) + } else { + Err(format!("Command failed:\n{}", stderr_output)) + } +} + +fn build_command(task: &Task) -> Result<(Command, String), String> { + let task_error = |field: &str| format!("Task {}: Missing {}", task.id, field); + let mut output_identifier = task.id.clone(); let mut command = if task.task_type == "sub_recipe" { - let sub_recipe = task.payload.get("sub_recipe").unwrap(); - let sub_recipe_name = sub_recipe.get("name").unwrap().as_str().unwrap(); - let path = sub_recipe.get("recipe_path").unwrap().as_str().unwrap(); - let command_parameters = sub_recipe.get("command_parameters").unwrap(); + let sub_recipe_name = task + .get_sub_recipe_name() + .ok_or_else(|| task_error("sub_recipe name"))?; + let path = task + .get_sub_recipe_path() + .ok_or_else(|| task_error("sub_recipe path"))?; + let command_parameters = task + .get_command_parameters() + .ok_or_else(|| task_error("command_parameters"))?; + output_identifier = format!("sub-recipe {}", sub_recipe_name); let mut cmd = Command::new("goose"); - cmd.arg("run").arg("--recipe").arg(path); - if let Some(params_map) = command_parameters.as_object() { - for (key, value) in params_map { - let key_str = key.to_string(); - let value_str = value.as_str().unwrap_or(&value.to_string()).to_string(); - cmd.arg("--params") - .arg(format!("{}={}", key_str, value_str)); - } + cmd.arg("run").arg("--recipe").arg(path).arg("--no-session"); + + for (key, value) in command_parameters { + let key_str = key.to_string(); + let value_str = value.as_str().unwrap_or(&value.to_string()).to_string(); + cmd.arg("--params") + .arg(format!("{}={}", key_str, value_str)); } cmd } else { let text = task - .payload - .get("text_instruction") - .unwrap() - .as_str() - .unwrap(); + .get_text_instruction() + .ok_or_else(|| task_error("text_instruction"))?; let mut cmd = Command::new("goose"); 
cmd.arg("run").arg("--text").arg(text); cmd }; - // Configure to capture stdout command.stdout(Stdio::piped()); command.stderr(Stdio::piped()); + Ok((command, output_identifier)) +} - // Spawn the child process +async fn run_command( + mut command: Command, + output_identifier: &str, + task_id: &str, + task_execution_tracker: Arc, +) -> Result<(String, String, bool), String> { let mut child = command .spawn() .map_err(|e| format!("Failed to spawn goose: {}", e))?; @@ -78,30 +100,20 @@ async fn execute_task(task: Task) -> Result { let stdout = child.stdout.take().expect("Failed to capture stdout"); let stderr = child.stderr.take().expect("Failed to capture stderr"); - let mut stdout_reader = BufReader::new(stdout).lines(); - let mut stderr_reader = BufReader::new(stderr).lines(); - - // Spawn background tasks to read from stdout and stderr - let output_identifier_clone = output_identifier.clone(); - let stdout_task = tokio::spawn(async move { - let mut buffer = String::new(); - while let Ok(Some(line)) = stdout_reader.next_line().await { - println!("[{}] {}", output_identifier_clone, line); - buffer.push_str(&line); - buffer.push('\n'); - } - buffer - }); - - let stderr_task = tokio::spawn(async move { - let mut buffer = String::new(); - while let Ok(Some(line)) = stderr_reader.next_line().await { - eprintln!("[stderr for {}] {}", output_identifier, line); - buffer.push_str(&line); - buffer.push('\n'); - } - buffer - }); + let stdout_task = spawn_output_reader( + stdout, + output_identifier, + false, + task_id, + task_execution_tracker.clone(), + ); + let stderr_task = spawn_output_reader( + stderr, + output_identifier, + true, + task_id, + task_execution_tracker.clone(), + ); let status = child .wait() @@ -111,9 +123,63 @@ async fn execute_task(task: Task) -> Result { let stdout_output = stdout_task.await.unwrap(); let stderr_output = stderr_task.await.unwrap(); - if status.success() { - Ok(Value::String(stdout_output)) + Ok((stdout_output, stderr_output, 
status.success())) +} + +fn spawn_output_reader( + reader: impl tokio::io::AsyncRead + Unpin + Send + 'static, + output_identifier: &str, + is_stderr: bool, + task_id: &str, + task_execution_tracker: Arc<TaskExecutionTracker>, +) -> tokio::task::JoinHandle<String> { + let output_identifier = output_identifier.to_string(); + let task_id = task_id.to_string(); + tokio::spawn(async move { + let mut buffer = String::new(); + let mut lines = BufReader::new(reader).lines(); + while let Ok(Some(line)) = lines.next_line().await { + buffer.push_str(&line); + buffer.push('\n'); + + if !is_stderr { + task_execution_tracker + .send_live_output(&task_id, &line) + .await; + } else { + tracing::warn!("Task stderr [{}]: {}", output_identifier, line); + } + } + buffer + }) +} + +fn extract_json_from_line(line: &str) -> Option<String> { + let start = line.find('{')?; + let end = line.rfind('}')?; + + if start >= end { + return None; + } + + let potential_json = &line[start..=end]; + if serde_json::from_str::<Value>(potential_json).is_ok() { + Some(potential_json.to_string()) } else { - Err(format!("Command failed:\n{}", stderr_output)) + None + } +} + +fn process_output(stdout_output: String) -> Result<Value, String> { + let last_line = stdout_output + .lines() + .filter(|line| !line.trim().is_empty()) + .next_back() + .unwrap_or(""); + + if let Some(json_string) = extract_json_from_line(last_line) { + Ok(Value::String(json_string)) + } else { + Ok(Value::String(stdout_output)) } } diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/tasks_manager.rs b/crates/goose/src/agents/sub_recipe_execution_tool/tasks_manager.rs new file mode 100644 index 00000000..433478be --- /dev/null +++ b/crates/goose/src/agents/sub_recipe_execution_tool/tasks_manager.rs @@ -0,0 +1,86 @@ +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::RwLock; + +use crate::agents::sub_recipe_execution_tool::task_types::Task; + +#[derive(Debug, Clone)] +pub struct TasksManager { + tasks: Arc<RwLock<HashMap<String, Task>>>, +} + +impl Default for TasksManager { + fn default() -> 
Self { + Self::new() + } +} + +impl TasksManager { + pub fn new() -> Self { + Self { + tasks: Arc::new(RwLock::new(HashMap::new())), + } + } + + pub async fn save_tasks(&self, tasks: Vec<Task>) { + let mut task_map = self.tasks.write().await; + for task in tasks { + task_map.insert(task.id.clone(), task); + } + } + + pub async fn get_task(&self, task_id: &str) -> Option<Task> { + let tasks = self.tasks.read().await; + tasks.get(task_id).cloned() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + fn create_test_task(id: &str, sub_recipe_name: &str) -> Task { + Task { + id: id.to_string(), + task_type: "sub_recipe".to_string(), + payload: json!({ + "sub_recipe": { + "name": sub_recipe_name, + "command_parameters": {}, + "recipe_path": "/test/path" + } + }), + } + } + + #[tokio::test] + async fn test_save_and_get_task() { + let manager = TasksManager::new(); + let tasks = vec![create_test_task("task1", "weather")]; + + manager.save_tasks(tasks).await; + + let retrieved = manager.get_task("task1").await; + assert!(retrieved.is_some()); + assert_eq!(retrieved.unwrap().id, "task1"); + } + + #[tokio::test] + async fn test_save_multiple_tasks() { + let manager = TasksManager::new(); + let tasks = vec![ + create_test_task("task1", "weather"), + create_test_task("task2", "news"), + ]; + + manager.save_tasks(tasks).await; + + let task1 = manager.get_task("task1").await; + let task2 = manager.get_task("task2").await; + assert!(task1.is_some()); + assert!(task2.is_some()); + assert_eq!(task1.unwrap().id, "task1"); + assert_eq!(task2.unwrap().id, "task2"); + } +} diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/types.rs b/crates/goose/src/agents/sub_recipe_execution_tool/types.rs deleted file mode 100644 index ede71dbf..00000000 --- a/crates/goose/src/agents/sub_recipe_execution_tool/types.rs +++ /dev/null @@ -1,69 +0,0 @@ -use serde::{Deserialize, Serialize}; -use serde_json::Value; - -// Task definition that LLMs will send -#[derive(Debug, Clone, 
Serialize, Deserialize)] -pub struct Task { - pub id: String, - pub task_type: String, - pub payload: Value, -} - -// Result for each task -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct TaskResult { - pub task_id: String, - pub status: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub data: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub error: Option, -} - -// Configuration for the parallel executor -#[derive(Debug, Clone, Deserialize)] -pub struct Config { - #[serde(default = "default_max_workers")] - pub max_workers: usize, - #[serde(default = "default_timeout")] - pub timeout_seconds: u64, - #[serde(default = "default_initial_workers")] - pub initial_workers: usize, -} - -impl Default for Config { - fn default() -> Self { - Self { - max_workers: default_max_workers(), - timeout_seconds: default_timeout(), - initial_workers: default_initial_workers(), - } - } -} - -fn default_max_workers() -> usize { - 10 -} -fn default_timeout() -> u64 { - 300 -} -fn default_initial_workers() -> usize { - 2 -} - -// Stats for the execution -#[derive(Debug, Serialize)] -pub struct ExecutionStats { - pub total_tasks: usize, - pub completed: usize, - pub failed: usize, - pub execution_time_ms: u128, -} - -// Main response structure -#[derive(Debug, Serialize)] -pub struct ExecutionResponse { - pub status: String, - pub results: Vec, - pub stats: ExecutionStats, -} diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs b/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs new file mode 100644 index 00000000..1ead865e --- /dev/null +++ b/crates/goose/src/agents/sub_recipe_execution_tool/utils/mod.rs @@ -0,0 +1,27 @@ +use std::collections::HashMap; + +use crate::agents::sub_recipe_execution_tool::task_types::{TaskInfo, TaskStatus}; + +pub fn get_task_name(task_info: &TaskInfo) -> &str { + task_info + .task + .get_sub_recipe_name() + .unwrap_or(&task_info.task.id) +} + +pub fn count_by_status(tasks: 
&HashMap<String, TaskInfo>) -> (usize, usize, usize, usize, usize) { let total = tasks.len(); let (pending, running, completed, failed) = tasks.values().fold( (0, 0, 0, 0), |(pending, running, completed, failed), task| match task.status { TaskStatus::Pending => (pending + 1, running, completed, failed), TaskStatus::Running => (pending, running + 1, completed, failed), TaskStatus::Completed => (pending, running, completed + 1, failed), TaskStatus::Failed => (pending, running, completed, failed + 1), }, ); (total, pending, running, completed, failed) } + +#[cfg(test)] +mod tests; diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs b/crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs new file mode 100644 index 00000000..de5bac92 --- /dev/null +++ b/crates/goose/src/agents/sub_recipe_execution_tool/utils/tests.rs @@ -0,0 +1,154 @@ +use crate::agents::sub_recipe_execution_tool::task_types::{Task, TaskInfo, TaskStatus}; +use crate::agents::sub_recipe_execution_tool::utils::{count_by_status, get_task_name}; +use serde_json::json; +use std::collections::HashMap; + +fn create_task_info_with_defaults(task: Task, status: TaskStatus) -> TaskInfo { + TaskInfo { + task, + status, + start_time: None, + end_time: None, + result: None, + current_output: String::new(), + } +} + +mod test_get_task_name { + use super::*; + + #[test] + fn test_extracts_sub_recipe_name() { + let sub_recipe_task = Task { + id: "task_1".to_string(), + task_type: "sub_recipe".to_string(), + payload: json!({ + "sub_recipe": { + "name": "my_recipe", + "recipe_path": "/path/to/recipe" + } + }), + }; + + let task_info = create_task_info_with_defaults(sub_recipe_task, TaskStatus::Pending); + + assert_eq!(get_task_name(&task_info), "my_recipe"); + } + + #[test] + fn falls_back_to_task_id_for_text_instruction() { + let text_task = Task { + id: "task_2".to_string(), + task_type: "text_instruction".to_string(), + payload: json!({"text_instruction": "do something"}), + };
+ + let task_info = create_task_info_with_defaults(text_task, TaskStatus::Pending); + + assert_eq!(get_task_name(&task_info), "task_2"); + } + + #[test] + fn falls_back_to_task_id_when_sub_recipe_name_missing() { + let malformed_task = Task { + id: "task_3".to_string(), + task_type: "sub_recipe".to_string(), + payload: json!({ + "sub_recipe": { + "recipe_path": "/path/to/recipe" + // missing "name" field + } + }), + }; + + let task_info = create_task_info_with_defaults(malformed_task, TaskStatus::Pending); + + assert_eq!(get_task_name(&task_info), "task_3"); + } + + #[test] + fn falls_back_to_task_id_when_sub_recipe_missing() { + let malformed_task = Task { + id: "task_4".to_string(), + task_type: "sub_recipe".to_string(), + payload: json!({}), // missing "sub_recipe" field + }; + + let task_info = create_task_info_with_defaults(malformed_task, TaskStatus::Pending); + + assert_eq!(get_task_name(&task_info), "task_4"); + } +} + +mod count_by_status { + use super::*; + + fn create_test_task(id: &str, status: TaskStatus) -> TaskInfo { + let task = Task { + id: id.to_string(), + task_type: "test".to_string(), + payload: json!({}), + }; + create_task_info_with_defaults(task, status) + } + + #[test] + fn counts_empty_map() { + let tasks = HashMap::new(); + let (total, pending, running, completed, failed) = count_by_status(&tasks); + assert_eq!( + (total, pending, running, completed, failed), + (0, 0, 0, 0, 0) + ); + } + + #[test] + fn counts_single_status() { + let mut tasks = HashMap::new(); + tasks.insert( + "task1".to_string(), + create_test_task("task1", TaskStatus::Pending), + ); + tasks.insert( + "task2".to_string(), + create_test_task("task2", TaskStatus::Pending), + ); + + let (total, pending, running, completed, failed) = count_by_status(&tasks); + assert_eq!( + (total, pending, running, completed, failed), + (2, 2, 0, 0, 0) + ); + } + + #[test] + fn counts_mixed_statuses() { + let mut tasks = HashMap::new(); + tasks.insert( + "task1".to_string(), + 
create_test_task("task1", TaskStatus::Pending), + ); + tasks.insert( + "task2".to_string(), + create_test_task("task2", TaskStatus::Running), + ); + tasks.insert( + "task3".to_string(), + create_test_task("task3", TaskStatus::Completed), + ); + tasks.insert( + "task4".to_string(), + create_test_task("task4", TaskStatus::Failed), + ); + tasks.insert( + "task5".to_string(), + create_test_task("task5", TaskStatus::Completed), + ); + + let (total, pending, running, completed, failed) = count_by_status(&tasks); + assert_eq!( + (total, pending, running, completed, failed), + (5, 1, 1, 2, 1) + ); + } +} diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs b/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs index e48f19c4..fefbf0eb 100644 --- a/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs +++ b/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs @@ -1,133 +1,30 @@ +use crate::agents::sub_recipe_execution_tool::task_types::{SharedState, Task}; use crate::agents::sub_recipe_execution_tool::tasks::process_task; -use crate::agents::sub_recipe_execution_tool::types::{Task, TaskResult}; -use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::Arc; -use tokio::sync::mpsc; -use tokio::time::{sleep, Duration}; -#[cfg(test)] -mod tests { - use super::*; - use crate::agents::sub_recipe_execution_tool::types::Task; - - #[tokio::test] - async fn test_spawn_worker_returns_handle() { - // Create a simple shared state for testing - let (task_tx, task_rx) = mpsc::channel::(1); - let (result_tx, _result_rx) = mpsc::channel::(1); - - let shared_state = Arc::new(SharedState { - task_receiver: Arc::new(tokio::sync::Mutex::new(task_rx)), - result_sender: result_tx, - active_workers: Arc::new(AtomicUsize::new(0)), - should_stop: Arc::new(AtomicBool::new(false)), - completed_tasks: Arc::new(AtomicUsize::new(0)), - }); - - // Test that spawn_worker returns a JoinHandle - let handle = spawn_worker(shared_state.clone(), 0, 
5); - - // Verify it's a JoinHandle by checking we can abort it - assert!(!handle.is_finished()); - - // Signal stop and close the channel to let the worker exit - shared_state.should_stop.store(true, Ordering::SeqCst); - drop(task_tx); // Close the channel - - // Wait for the worker to finish - let result = handle.await; - assert!(result.is_ok()); - } +async fn receive_task(state: &SharedState) -> Option { + let mut receiver = state.task_receiver.lock().await; + receiver.recv().await } -pub struct SharedState { - pub task_receiver: Arc>>, - pub result_sender: mpsc::Sender, - pub active_workers: Arc, - pub should_stop: Arc, - pub completed_tasks: Arc, -} - -// Spawn a worker task -pub fn spawn_worker( - state: Arc, - worker_id: usize, - timeout_seconds: u64, -) -> tokio::task::JoinHandle<()> { - state.active_workers.fetch_add(1, Ordering::SeqCst); +pub fn spawn_worker(state: Arc, worker_id: usize) -> tokio::task::JoinHandle<()> { + state.increment_active_workers(); tokio::spawn(async move { - worker_loop(state, worker_id, timeout_seconds).await; + worker_loop(state, worker_id).await; }) } -async fn worker_loop(state: Arc, _worker_id: usize, timeout_seconds: u64) { - loop { - // Try to receive a task - let task = { - let mut receiver = state.task_receiver.lock().await; - receiver.recv().await - }; +async fn worker_loop(state: Arc, _worker_id: usize) { + while let Some(task) = receive_task(&state).await { + state.task_execution_tracker.start_task(&task.id).await; + let result = process_task(&task, state.task_execution_tracker.clone()).await; - match task { - Some(task) => { - // Process the task - let result = process_task(&task, timeout_seconds).await; - - // Send result - let _ = state.result_sender.send(result).await; - - // Update completed count - state.completed_tasks.fetch_add(1, Ordering::SeqCst); - } - None => { - // Channel closed, exit worker - break; - } - } - - // Check if we should stop - if state.should_stop.load(Ordering::SeqCst) { + if let Err(e) = 
state.result_sender.send(result).await { + tracing::error!("Worker failed to send result: {}", e); break; } } - // Worker is exiting - state.active_workers.fetch_sub(1, Ordering::SeqCst); -} - -// Scaling controller that monitors queue and spawns workers -pub async fn run_scaler( - state: Arc, - task_count: usize, - max_workers: usize, - timeout_seconds: u64, -) { - let mut worker_count = 0; - - loop { - sleep(Duration::from_millis(100)).await; - - let active = state.active_workers.load(Ordering::SeqCst); - let completed = state.completed_tasks.load(Ordering::SeqCst); - let pending = task_count.saturating_sub(completed); - - // Simple scaling logic: spawn worker if many pending tasks and under limit - if pending > active * 2 && active < max_workers && worker_count < max_workers { - let _handle = spawn_worker(state.clone(), worker_count, timeout_seconds); - worker_count += 1; - } - - // If all tasks completed, signal stop - if completed >= task_count { - state.should_stop.store(true, Ordering::SeqCst); - break; - } - - // If no active workers and tasks remaining, spawn one - if active == 0 && pending > 0 { - let _handle = spawn_worker(state.clone(), worker_count, timeout_seconds); - worker_count += 1; - } - } + state.decrement_active_workers(); } diff --git a/crates/goose/src/agents/sub_recipe_manager.rs b/crates/goose/src/agents/sub_recipe_manager.rs index 2441684b..33229b97 100644 --- a/crates/goose/src/agents/sub_recipe_manager.rs +++ b/crates/goose/src/agents/sub_recipe_manager.rs @@ -7,6 +7,7 @@ use crate::{ recipe_tools::sub_recipe_tools::{ create_sub_recipe_task, create_sub_recipe_task_tool, SUB_RECIPE_TASK_TOOL_NAME_PREFIX, }, + sub_recipe_execution_tool::tasks_manager::TasksManager, tool_execution::ToolCallResult, }, recipe::SubRecipe, @@ -34,12 +35,6 @@ impl SubRecipeManager { pub fn add_sub_recipe_tools(&mut self, sub_recipes_to_add: Vec) { for sub_recipe in sub_recipes_to_add { - // let sub_recipe_key = format!( - // "{}_{}", - // 
SUB_RECIPE_TOOL_NAME_PREFIX, - // sub_recipe.name.clone() - // ); - // let tool = create_sub_recipe_tool(&sub_recipe); let sub_recipe_key = format!( "{}_{}", SUB_RECIPE_TASK_TOOL_NAME_PREFIX, @@ -59,43 +54,22 @@ impl SubRecipeManager { &self, tool_name: &str, params: Value, + tasks_manager: &TasksManager, ) -> ToolCallResult { - let result = self.call_sub_recipe_tool(tool_name, params).await; + let result = self + .call_sub_recipe_tool(tool_name, params, tasks_manager) + .await; match result { Ok(call_result) => ToolCallResult::from(Ok(call_result)), Err(e) => ToolCallResult::from(Err(ToolError::ExecutionError(e.to_string()))), } } - // async fn call_sub_recipe_tool( - // &self, - // tool_name: &str, - // params: Value, - // ) -> Result, ToolError> { - // let sub_recipe = self.sub_recipes.get(tool_name).ok_or_else(|| { - // let sub_recipe_name = tool_name - // .strip_prefix(SUB_RECIPE_TOOL_NAME_PREFIX) - // .and_then(|s| s.strip_prefix("_")) - // .ok_or_else(|| { - // ToolError::InvalidParameters(format!( - // "Invalid sub-recipe tool name format: {}", - // tool_name - // )) - // }) - // .unwrap(); - - // ToolError::InvalidParameters(format!("Sub-recipe '{}' not found", sub_recipe_name)) - // })?; - - // let output = run_sub_recipe(sub_recipe, params).await.map_err(|e| { - // ToolError::ExecutionError(format!("Sub-recipe execution failed: {}", e)) - // })?; - // Ok(vec![Content::text(output)]) - // } async fn call_sub_recipe_tool( &self, tool_name: &str, params: Value, + tasks_manager: &TasksManager, ) -> Result, ToolError> { let sub_recipe = self.sub_recipes.get(tool_name).ok_or_else(|| { let sub_recipe_name = tool_name @@ -111,11 +85,10 @@ impl SubRecipeManager { ToolError::InvalidParameters(format!("Sub-recipe '{}' not found", sub_recipe_name)) })?; - - let output = create_sub_recipe_task(sub_recipe, params) + let output = create_sub_recipe_task(sub_recipe, params, tasks_manager) .await .map_err(|e| { - ToolError::ExecutionError(format!("Sub-recipe execution 
failed: {}", e)) + ToolError::ExecutionError(format!("Sub-recipe task creation failed: {}", e)) })?; Ok(vec![Content::text(output)]) } diff --git a/crates/goose/src/recipe/mod.rs b/crates/goose/src/recipe/mod.rs index c38c0697..135cdd32 100644 --- a/crates/goose/src/recipe/mod.rs +++ b/crates/goose/src/recipe/mod.rs @@ -137,6 +137,13 @@ pub struct SubRecipe { pub path: String, #[serde(default, deserialize_with = "deserialize_value_map_as_string")] pub values: Option<HashMap<String, String>>, + #[serde(default)] + pub sequential_when_repeated: bool, +} +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct Execution { + #[serde(default)] + pub parallel: bool, } fn deserialize_value_map_as_string<'de, D>(