feat: run sub recipe multiple times in parallel (Experimental feature) (#3274)

Co-authored-by: Wendy Tang <wendytang@squareup.com>
This commit is contained in:
Lifei Zhou
2025-07-17 08:39:35 +10:00
committed by GitHub
parent 3b90282b49
commit e5a55dbddc
30 changed files with 2757 additions and 674 deletions

3
.gitignore vendored
View File

@@ -31,6 +31,9 @@ ui/desktop/src/bin/goose_llm.dll
# Hermit
.hermit/
# Claude
.claude
debug_*.txt
# Docs

View File

@@ -32,6 +32,7 @@ pub fn extract_recipe_info_from_cli(
path: recipe_file_path.to_string_lossy().to_string(),
name,
values: None,
sequential_when_repeated: true,
};
all_sub_recipes.push(additional_sub_recipe);
}

View File

@@ -4,8 +4,11 @@ mod export;
mod input;
mod output;
mod prompt;
mod task_execution_display;
mod thinking;
use crate::session::task_execution_display::TASK_EXECUTION_NOTIFICATION_TYPE;
pub use self::export::message_to_markdown;
pub use builder::{build_session, SessionBuilderConfig, SessionSettings};
use console::Color;
@@ -17,6 +20,8 @@ use goose::permission::PermissionConfirmation;
use goose::providers::base::Provider;
pub use goose::session::Identifier;
use goose::utils::safe_truncate;
use std::io::Write;
use task_execution_display::format_task_execution_notification;
use anyhow::{Context, Result};
use completion::GooseCompleter;
@@ -1008,7 +1013,7 @@ impl Session {
match method.as_str() {
"notifications/message" => {
let data = o.get("data").unwrap_or(&Value::Null);
let (formatted_message, subagent_id, _notification_type) = match data {
let (formatted_message, subagent_id, message_notification_type) = match data {
Value::String(s) => (s.clone(), None, None),
Value::Object(o) => {
// Check for subagent notification structure first
@@ -1059,6 +1064,8 @@ impl Session {
} else if let Some(Value::String(output)) = o.get("output") {
// Fallback for other MCP notification types
(output.to_owned(), None, None)
} else if let Some(result) = format_task_execution_notification(data) {
result
} else {
(data.to_string(), None, None)
}
@@ -1077,7 +1084,19 @@ impl Session {
} else {
progress_bars.log(&formatted_message);
}
} else {
} else if let Some(ref notification_type) = message_notification_type {
if notification_type == TASK_EXECUTION_NOTIFICATION_TYPE {
if interactive {
let _ = progress_bars.hide();
print!("{}", formatted_message);
std::io::stdout().flush().unwrap();
} else {
print!("{}", formatted_message);
std::io::stdout().flush().unwrap();
}
}
}
else {
// Non-subagent notification, display immediately with compact spacing
if interactive {
let _ = progress_bars.hide();

View File

@@ -0,0 +1,247 @@
use goose::agents::sub_recipe_execution_tool::lib::TaskStatus;
use goose::agents::sub_recipe_execution_tool::notification_events::{
TaskExecutionNotificationEvent, TaskInfo,
};
use serde_json::Value;
use std::sync::atomic::{AtomicBool, Ordering};
#[cfg(test)]
mod tests;
const CLEAR_SCREEN: &str = "\x1b[2J\x1b[H";
const MOVE_TO_PROGRESS_LINE: &str = "\x1b[4;1H";
const CLEAR_TO_EOL: &str = "\x1b[K";
const CLEAR_BELOW: &str = "\x1b[J";
pub const TASK_EXECUTION_NOTIFICATION_TYPE: &str = "task_execution";
static INITIAL_SHOWN: AtomicBool = AtomicBool::new(false);
/// Render a task's `result_data` JSON value as a short human-readable string.
///
/// Strings are ANSI-stripped; objects prefer a `partial_output` field when
/// present; everything else falls back to pretty-printed JSON or the value's
/// plain string form.
fn format_result_data_for_display(result_data: &Value) -> String {
    match result_data {
        Value::Null => "null".to_string(),
        Value::Bool(flag) => flag.to_string(),
        Value::Number(num) => num.to_string(),
        Value::String(text) => strip_ansi_codes(text),
        Value::Array(items) => serde_json::to_string_pretty(items).unwrap_or_default(),
        Value::Object(map) => match map.get("partial_output").and_then(Value::as_str) {
            Some(partial) => format!("Partial output: {}", partial),
            None => serde_json::to_string_pretty(map).unwrap_or_default(),
        },
    }
}
/// Condense live task output into a one-line preview: keep only the last two
/// lines, join them with " ... ", strip ANSI escapes, and cap the result at
/// 100 characters.
fn process_output_for_display(output: &str) -> String {
    const MAX_OUTPUT_LINES: usize = 2;
    const OUTPUT_PREVIEW_LENGTH: usize = 100;

    let all_lines: Vec<&str> = output.lines().collect();
    // Only the tail of the output is interesting while a task is running.
    let start = all_lines.len().saturating_sub(MAX_OUTPUT_LINES);
    let joined = all_lines[start..].join(" ... ");
    truncate_with_ellipsis(&strip_ansi_codes(&joined), OUTPUT_PREVIEW_LENGTH)
}
/// Truncate `text` (measured in bytes) to roughly `max_len`, appending "..."
/// when truncation occurs. The cut point is walked back to the nearest char
/// boundary so multi-byte characters are never split. Note: for very small
/// `max_len` the "..." suffix alone may exceed the limit — callers accept
/// that (see the unit tests).
fn truncate_with_ellipsis(text: &str, max_len: usize) -> String {
    if text.len() <= max_len {
        return text.to_string();
    }
    let mut cut = max_len.saturating_sub(3);
    // is_char_boundary(0) is always true, so this loop terminates.
    while !text.is_char_boundary(cut) {
        cut -= 1;
    }
    format!("{}...", &text[..cut])
}
/// Remove ANSI CSI escape sequences ("\x1b[" ... up to and including the
/// first ASCII letter) from `text`. A lone trailing ESC, or an ESC followed
/// by anything other than '[', is kept verbatim.
fn strip_ansi_codes(text: &str) -> String {
    let mut cleaned = String::with_capacity(text.len());
    let mut iter = text.chars();
    while let Some(c) = iter.next() {
        if c != '\x1b' {
            cleaned.push(c);
            continue;
        }
        match iter.next() {
            // "\x1b[" opens a CSI sequence: drop everything through the
            // terminating ASCII letter (or end of input).
            Some('[') => {
                for seq_ch in iter.by_ref() {
                    if seq_ch.is_ascii_alphabetic() {
                        break;
                    }
                }
            }
            // ESC not followed by '[': not an escape sequence, keep both.
            Some(other) => {
                cleaned.push(c);
                cleaned.push(other);
            }
            // ESC at the very end of the string: keep it.
            None => cleaned.push(c),
        }
    }
    cleaned
}
/// Try to interpret a notification payload as a task-execution event.
///
/// Returns `(formatted_text, subagent_id, notification_type)` — the triple
/// the session notification handler expects — or `None` when `data` does not
/// deserialize into a `TaskExecutionNotificationEvent`. The subagent id is
/// always `None` for these events; the type tag is always
/// `TASK_EXECUTION_NOTIFICATION_TYPE`.
pub fn format_task_execution_notification(
    data: &Value,
) -> Option<(String, Option<String>, Option<String>)> {
    let event = serde_json::from_value::<TaskExecutionNotificationEvent>(data.clone()).ok()?;
    let text = match &event {
        TaskExecutionNotificationEvent::LineOutput { output, .. } => format!("{}\n", output),
        TaskExecutionNotificationEvent::TasksUpdate { .. } => {
            format_tasks_update_from_event(&event)
        }
        TaskExecutionNotificationEvent::TasksComplete { .. } => {
            format_tasks_complete_from_event(&event)
        }
    };
    Some((text, None, Some(TASK_EXECUTION_NOTIFICATION_TYPE.to_string())))
}
/// Build the in-place progress dashboard frame for a `TasksUpdate` event.
///
/// The first frame clears the screen and prints the dashboard header; later
/// frames reposition the cursor (`MOVE_TO_PROGRESS_LINE`) and overwrite the
/// previous frame, clearing leftovers with `CLEAR_BELOW`. Returns an empty
/// string for any other event variant.
fn format_tasks_update_from_event(event: &TaskExecutionNotificationEvent) -> String {
    if let TaskExecutionNotificationEvent::TasksUpdate { stats, tasks } = event {
        let mut frame = String::new();

        // swap() returns the previous value: false only on the very first frame.
        let first_frame = !INITIAL_SHOWN.swap(true, Ordering::SeqCst);
        if first_frame {
            frame.push_str(CLEAR_SCREEN);
            frame.push_str("🎯 Task Execution Dashboard\n");
            frame.push_str("═══════════════════════════\n\n");
        } else {
            frame.push_str(MOVE_TO_PROGRESS_LINE);
        }

        frame.push_str(&format!(
            "📊 Progress: {} total | ⏳ {} pending | 🏃 {} running | ✅ {} completed | ❌ {} failed",
            stats.total, stats.pending, stats.running, stats.completed, stats.failed
        ));
        frame.push_str(&format!("{}\n\n", CLEAR_TO_EOL));

        // Stable ordering so each redraw lines up with the previous frame.
        let mut ordered = tasks.clone();
        ordered.sort_by(|left, right| left.id.cmp(&right.id));
        for task in &ordered {
            frame.push_str(&format_task_display(task));
        }

        frame.push_str(CLEAR_BELOW);
        frame
    } else {
        String::new()
    }
}
/// Build the final summary printed for a `TasksComplete` event.
///
/// Shows overall counts and the success rate, then a bulleted list of failed
/// tasks with their errors (when present). Returns an empty string for any
/// other event variant.
fn format_tasks_complete_from_event(event: &TaskExecutionNotificationEvent) -> String {
    if let TaskExecutionNotificationEvent::TasksComplete {
        stats,
        failed_tasks,
    } = event
    {
        let mut summary = String::new();
        summary.push_str("Execution Complete!\n");
        summary.push_str("═══════════════════════\n");
        summary.push_str(&format!("Total Tasks: {}\n", stats.total));
        summary.push_str(&format!("✅ Completed: {}\n", stats.completed));
        summary.push_str(&format!("❌ Failed: {}\n", stats.failed));
        summary.push_str(&format!("📈 Success Rate: {:.1}%\n", stats.success_rate));
        if !failed_tasks.is_empty() {
            summary.push_str("\n❌ Failed Tasks:\n");
            for task in failed_tasks {
                // Bullet prefix: the unit test asserts "• <name>" lines.
                summary.push_str(&format!("• {}\n", task.name));
                if let Some(error) = &task.error {
                    summary.push_str(&format!("  Error: {}\n", error));
                }
            }
        }
        summary.push_str("\n📝 Generating summary...\n");
        summary
    } else {
        String::new()
    }
}
/// Render a single task's status block for the dashboard.
///
/// Every line ends with `CLEAR_TO_EOL` so a redrawn frame fully overwrites
/// the previous (possibly longer) frame content. Sections are conditional on
/// status: output preview only while Running, result preview only when
/// Completed, error preview only when Failed.
fn format_task_display(task: &TaskInfo) -> String {
    let mut task_display = String::new();
    // Status icons — the unit tests assert these exact glyphs per state
    // ("⏳ waiting-task", "🏃 test-task", "✅ analyzer", "❌ failing-task").
    let status_icon = match task.status {
        TaskStatus::Pending => "⏳",
        TaskStatus::Running => "🏃",
        TaskStatus::Completed => "✅",
        TaskStatus::Failed => "❌",
    };
    task_display.push_str(&format!(
        "{} {} ({}){}\n",
        status_icon, task.task_name, task.task_type, CLEAR_TO_EOL
    ));
    if !task.task_metadata.is_empty() {
        task_display.push_str(&format!(
            " 📋 Parameters: {}{}\n",
            task.task_metadata, CLEAR_TO_EOL
        ));
    }
    if let Some(duration_secs) = task.duration_secs {
        task_display.push_str(&format!(" ⏱️ {:.1}s{}\n", duration_secs, CLEAR_TO_EOL));
    }
    // Live output preview only for running tasks with non-blank output.
    if matches!(task.status, TaskStatus::Running) && !task.current_output.trim().is_empty() {
        let processed_output = process_output_for_display(&task.current_output);
        if !processed_output.is_empty() {
            task_display.push_str(&format!(" 💬 {}{}\n", processed_output, CLEAR_TO_EOL));
        }
    }
    if matches!(task.status, TaskStatus::Completed) {
        if let Some(result_data) = &task.result_data {
            let result_preview = format_result_data_for_display(result_data);
            if !result_preview.is_empty() {
                task_display.push_str(&format!(" 📄 {}{}\n", result_preview, CLEAR_TO_EOL));
            }
        }
    }
    if matches!(task.status, TaskStatus::Failed) {
        if let Some(error) = &task.error {
            // Flatten to a single truncated line so the dashboard stays aligned.
            let error_preview = truncate_with_ellipsis(error, 80);
            task_display.push_str(&format!(
                " ⚠️ {}{}\n",
                error_preview.replace('\n', " "),
                CLEAR_TO_EOL
            ));
        }
    }
    task_display.push_str(&format!("{}\n", CLEAR_TO_EOL));
    task_display
}

View File

@@ -0,0 +1,337 @@
use super::*;
use goose::agents::sub_recipe_execution_tool::notification_events::{
FailedTaskInfo, TaskCompletionStats, TaskExecutionStats,
};
use serde_json::json;
// Covers passthrough text, simple and multi-parameter CSI sequences,
// sequences embedded mid-string, lone/trailing ESC bytes, and empty input.
#[test]
fn test_strip_ansi_codes() {
    assert_eq!(strip_ansi_codes("hello world"), "hello world");
    assert_eq!(strip_ansi_codes("\x1b[31mred text\x1b[0m"), "red text");
    assert_eq!(
        strip_ansi_codes("\x1b[1;32mbold green\x1b[0m"),
        "bold green"
    );
    assert_eq!(
        strip_ansi_codes("normal\x1b[33myellow\x1b[0mnormal"),
        "normalyellownormal"
    );
    // ESC not followed by '[' (and a trailing ESC) must be preserved verbatim.
    assert_eq!(strip_ansi_codes("\x1bhello"), "\x1bhello");
    assert_eq!(strip_ansi_codes("hello\x1b"), "hello\x1b");
    assert_eq!(strip_ansi_codes(""), "");
}
// Verifies no-op below the limit, "..." replacement on overflow, and that
// tiny limits (1-3) still yield a bare "..." rather than panicking.
#[test]
fn test_truncate_with_ellipsis() {
    assert_eq!(truncate_with_ellipsis("hello", 10), "hello");
    assert_eq!(truncate_with_ellipsis("hello", 5), "hello");
    assert_eq!(truncate_with_ellipsis("hello world", 8), "hello...");
    assert_eq!(truncate_with_ellipsis("hello", 3), "...");
    assert_eq!(truncate_with_ellipsis("hello", 2), "...");
    assert_eq!(truncate_with_ellipsis("hello", 1), "...");
    assert_eq!(truncate_with_ellipsis("", 5), "");
}
// Verifies the preview keeps only the last two lines (joined by " ... "),
// caps length at 100 chars with an ellipsis, and strips ANSI escapes.
#[test]
fn test_process_output_for_display() {
    assert_eq!(process_output_for_display("hello world"), "hello world");
    assert_eq!(
        process_output_for_display("line1\nline2"),
        "line1 ... line2"
    );
    // Only the final two lines survive.
    let input = "line1\nline2\nline3\nline4";
    let result = process_output_for_display(input);
    assert_eq!(result, "line3 ... line4");
    // Over-long single line gets truncated with an ellipsis.
    let long_line = "a".repeat(150);
    let result = process_output_for_display(&long_line);
    assert!(result.len() <= 100);
    assert!(result.ends_with("..."));
    // ANSI color codes are removed before display.
    let ansi_output = "\x1b[31mred line 1\x1b[0m\n\x1b[32mgreen line 2\x1b[0m";
    let result = process_output_for_display(ansi_output);
    assert_eq!(result, "red line 1 ... green line 2");
    assert_eq!(process_output_for_display(""), "");
}
// Covers each JSON value kind: strings (with ANSI stripping), booleans,
// numbers, null, objects (with and without the special "partial_output"
// field), and arrays.
#[test]
fn test_format_result_data_for_display() {
    let string_val = json!("hello world");
    assert_eq!(format_result_data_for_display(&string_val), "hello world");
    let ansi_string = json!("\x1b[31mred text\x1b[0m");
    assert_eq!(format_result_data_for_display(&ansi_string), "red text");
    assert_eq!(format_result_data_for_display(&json!(true)), "true");
    assert_eq!(format_result_data_for_display(&json!(false)), "false");
    assert_eq!(format_result_data_for_display(&json!(42)), "42");
    assert_eq!(format_result_data_for_display(&json!(3.14)), "3.14");
    assert_eq!(format_result_data_for_display(&json!(null)), "null");
    // Objects with "partial_output" show only that field, prefixed.
    let partial_obj = json!({
        "partial_output": "some output",
        "other_field": "ignored"
    });
    assert_eq!(
        format_result_data_for_display(&partial_obj),
        "Partial output: some output"
    );
    // Plain objects/arrays fall back to pretty-printed JSON.
    let obj = json!({"key": "value", "num": 42});
    let result = format_result_data_for_display(&obj);
    assert!(result.contains("key"));
    assert!(result.contains("value"));
    let arr = json!([1, 2, 3]);
    let result = format_result_data_for_display(&arr);
    assert!(result.contains("1"));
    assert!(result.contains("2"));
    assert!(result.contains("3"));
}
// A "line_output" payload should format as the raw output plus a newline,
// with no subagent id and the task-execution notification type tag.
#[test]
fn test_format_task_execution_notification_line_output() {
    // Illustrates the event this JSON payload corresponds to; intentionally
    // unused — the function under test deserializes from raw JSON.
    let _event = TaskExecutionNotificationEvent::LineOutput {
        task_id: "task-1".to_string(),
        output: "Hello World".to_string(),
    };
    let data = json!({
        "subtype": "line_output",
        "task_id": "task-1",
        "output": "Hello World"
    });
    let result = format_task_execution_notification(&data);
    assert!(result.is_some());
    let (formatted, second, third) = result.unwrap();
    assert_eq!(formatted, "Hello World\n");
    assert_eq!(second, None);
    assert_eq!(third, Some("task_execution".to_string()));
}
// Payloads that do not deserialize into a TaskExecutionNotificationEvent —
// unknown structure, or a known subtype missing required fields — yield None.
#[test]
fn test_format_task_execution_notification_invalid_data() {
    let invalid_data = json!({
        "invalid": "structure"
    });
    let result = format_task_execution_notification(&invalid_data);
    assert_eq!(result, None);
    let incomplete_data = json!({
        "subtype": "line_output"
    });
    let result = format_task_execution_notification(&incomplete_data);
    assert_eq!(result, None);
}
// Exercises a full dashboard frame: header on the first render, stats line,
// per-task lines, then header suppression + cursor repositioning on the
// second render. Resets the process-wide INITIAL_SHOWN flag first since it
// is global state shared with production code.
#[test]
fn test_format_tasks_update_from_event() {
    INITIAL_SHOWN.store(false, Ordering::SeqCst);
    let stats = TaskExecutionStats::new(3, 1, 1, 1, 0);
    let tasks = vec![
        TaskInfo {
            id: "task-1".to_string(),
            status: TaskStatus::Running,
            duration_secs: Some(1.5),
            current_output: "Processing...".to_string(),
            task_type: "sub_recipe".to_string(),
            task_name: "test-task".to_string(),
            task_metadata: "param=value".to_string(),
            error: None,
            result_data: None,
        },
        TaskInfo {
            id: "task-2".to_string(),
            status: TaskStatus::Completed,
            duration_secs: Some(2.3),
            current_output: "".to_string(),
            task_type: "text_instruction".to_string(),
            task_name: "another-task".to_string(),
            task_metadata: "".to_string(),
            error: None,
            result_data: Some(json!({"result": "success"})),
        },
    ];
    let event = TaskExecutionNotificationEvent::TasksUpdate { stats, tasks };
    let result = format_tasks_update_from_event(&event);
    assert!(result.contains("🎯 Task Execution Dashboard"));
    assert!(result.contains("═══════════════════════════"));
    assert!(result.contains("📊 Progress: 3 total"));
    assert!(result.contains("⏳ 1 pending"));
    assert!(result.contains("🏃 1 running"));
    assert!(result.contains("✅ 1 completed"));
    assert!(result.contains("❌ 0 failed"));
    assert!(result.contains("🏃 test-task"));
    assert!(result.contains("✅ another-task"));
    assert!(result.contains("📋 Parameters: param=value"));
    assert!(result.contains("⏱️ 1.5s"));
    assert!(result.contains("💬 Processing..."));
    // Second frame: no header, just cursor repositioning for the redraw.
    let result2 = format_tasks_update_from_event(&event);
    assert!(!result2.contains("🎯 Task Execution Dashboard"));
    assert!(result2.contains(MOVE_TO_PROGRESS_LINE));
}
// Completion summary with one failure: checks counts, the computed success
// rate, the bulleted failed-task list with its error, and the trailing
// "Generating summary" line.
#[test]
fn test_format_tasks_complete_from_event() {
    let stats = TaskCompletionStats::new(5, 4, 1);
    let failed_tasks = vec![FailedTaskInfo {
        id: "task-3".to_string(),
        name: "failed-task".to_string(),
        error: Some("Connection timeout".to_string()),
    }];
    let event = TaskExecutionNotificationEvent::TasksComplete {
        stats,
        failed_tasks,
    };
    let result = format_tasks_complete_from_event(&event);
    assert!(result.contains("Execution Complete!"));
    assert!(result.contains("═══════════════════════"));
    assert!(result.contains("Total Tasks: 5"));
    assert!(result.contains("✅ Completed: 4"));
    assert!(result.contains("❌ Failed: 1"));
    assert!(result.contains("📈 Success Rate: 80.0%"));
    assert!(result.contains("❌ Failed Tasks:"));
    assert!(result.contains("• failed-task"));
    assert!(result.contains("Error: Connection timeout"));
    assert!(result.contains("📝 Generating summary..."));
}
// All tasks succeeded: the failed-tasks section must be omitted entirely
// and the success rate reported as 100%.
#[test]
fn test_format_tasks_complete_from_event_no_failures() {
    let stats = TaskCompletionStats::new(3, 3, 0);
    let failed_tasks = vec![];
    let event = TaskExecutionNotificationEvent::TasksComplete {
        stats,
        failed_tasks,
    };
    let result = format_tasks_complete_from_event(&event);
    assert!(!result.contains("❌ Failed Tasks:"));
    assert!(result.contains("📈 Success Rate: 100.0%"));
    assert!(result.contains("❌ Failed: 0"));
}
// Running task: shows the 🏃 icon, parameters, duration, and a live output
// preview with multi-line output joined by " ... ".
#[test]
fn test_format_task_display_running() {
    let task = TaskInfo {
        id: "task-1".to_string(),
        status: TaskStatus::Running,
        duration_secs: Some(1.5),
        current_output: "Processing data...\nAlmost done...".to_string(),
        task_type: "sub_recipe".to_string(),
        task_name: "data-processor".to_string(),
        task_metadata: "input=file.txt,output=result.json".to_string(),
        error: None,
        result_data: None,
    };
    let result = format_task_display(&task);
    assert!(result.contains("🏃 data-processor (sub_recipe)"));
    assert!(result.contains("📋 Parameters: input=file.txt,output=result.json"));
    assert!(result.contains("⏱️ 1.5s"));
    assert!(result.contains("💬 Processing data... ... Almost done..."));
}
// Completed task: shows the ✅ icon, duration, and a 📄 result preview;
// empty metadata must suppress the parameters line.
#[test]
fn test_format_task_display_completed() {
    let task = TaskInfo {
        id: "task-2".to_string(),
        status: TaskStatus::Completed,
        duration_secs: Some(3.2),
        current_output: "".to_string(),
        task_type: "text_instruction".to_string(),
        task_name: "analyzer".to_string(),
        task_metadata: "".to_string(),
        error: None,
        result_data: Some(json!({"status": "success", "count": 42})),
    };
    let result = format_task_display(&task);
    assert!(result.contains("✅ analyzer (text_instruction)"));
    assert!(result.contains("⏱️ 3.2s"));
    assert!(!result.contains("📋 Parameters"));
    assert!(result.contains("📄"));
}
// Failed task: shows the ❌ icon and a ⚠️ truncated error preview; with no
// duration recorded, the timer line must be absent.
#[test]
fn test_format_task_display_failed() {
    let task = TaskInfo {
        id: "task-3".to_string(),
        status: TaskStatus::Failed,
        duration_secs: None,
        current_output: "".to_string(),
        task_type: "sub_recipe".to_string(),
        task_name: "failing-task".to_string(),
        task_metadata: "".to_string(),
        error: Some(
            "Network connection failed after multiple retries. The server is unreachable."
                .to_string(),
        ),
        result_data: None,
    };
    let result = format_task_display(&task);
    assert!(result.contains("❌ failing-task (sub_recipe)"));
    assert!(!result.contains("⏱️"));
    assert!(result.contains("⚠️"));
    assert!(result.contains("Network connection failed after multiple retries"));
}
// Pending task: shows the ⏳ icon and parameters only — no duration, output
// preview, result preview, or error sections.
#[test]
fn test_format_task_display_pending() {
    let task = TaskInfo {
        id: "task-4".to_string(),
        status: TaskStatus::Pending,
        duration_secs: None,
        current_output: "".to_string(),
        task_type: "sub_recipe".to_string(),
        task_name: "waiting-task".to_string(),
        task_metadata: "priority=high".to_string(),
        error: None,
        result_data: None,
    };
    let result = format_task_display(&task);
    assert!(result.contains("⏳ waiting-task (sub_recipe)"));
    assert!(result.contains("📋 Parameters: priority=high"));
    assert!(!result.contains("⏱️"));
    assert!(!result.contains("💬"));
    assert!(!result.contains("📄"));
    assert!(!result.contains("⚠️"));
}
// Whitespace-only output on a running task must not produce a 💬 preview
// line (the output is trimmed before the check).
#[test]
fn test_format_task_display_empty_current_output() {
    let task = TaskInfo {
        id: "task-5".to_string(),
        status: TaskStatus::Running,
        duration_secs: Some(0.5),
        current_output: "   \n\t  \n   ".to_string(),
        task_type: "sub_recipe".to_string(),
        task_name: "quiet-task".to_string(),
        task_metadata: "".to_string(),
        error: None,
        result_data: None,
    };
    let result = format_task_display(&task);
    assert!(!result.contains("💬"));
}

View File

@@ -12,6 +12,7 @@ use crate::agents::final_output_tool::{FINAL_OUTPUT_CONTINUATION_MESSAGE, FINAL_
use crate::agents::sub_recipe_execution_tool::sub_recipe_execute_task_tool::{
self, SUB_RECIPE_EXECUTE_TASK_TOOL_NAME,
};
use crate::agents::sub_recipe_execution_tool::tasks_manager::TasksManager;
use crate::agents::sub_recipe_manager::SubRecipeManager;
use crate::config::{Config, ExtensionConfigManager, PermissionManager};
use crate::message::{push_message, Message};
@@ -63,6 +64,7 @@ pub struct Agent {
pub(super) provider: Mutex<Option<Arc<dyn Provider>>>,
pub(super) extension_manager: RwLock<ExtensionManager>,
pub(super) sub_recipe_manager: Mutex<SubRecipeManager>,
pub(super) tasks_manager: TasksManager,
pub(super) final_output_tool: Mutex<Option<FinalOutputTool>>,
pub(super) frontend_tools: Mutex<HashMap<String, FrontendTool>>,
pub(super) frontend_instructions: Mutex<Option<String>>,
@@ -137,6 +139,7 @@ impl Agent {
provider: Mutex::new(None),
extension_manager: RwLock::new(ExtensionManager::new()),
sub_recipe_manager: Mutex::new(SubRecipeManager::new()),
tasks_manager: TasksManager::new(),
final_output_tool: Mutex::new(None),
frontend_tools: Mutex::new(HashMap::new()),
frontend_instructions: Mutex::new(None),
@@ -291,10 +294,18 @@ impl Agent {
let sub_recipe_manager = self.sub_recipe_manager.lock().await;
let result: ToolCallResult = if sub_recipe_manager.is_sub_recipe_tool(&tool_call.name) {
sub_recipe_manager
.dispatch_sub_recipe_tool_call(&tool_call.name, tool_call.arguments.clone())
.dispatch_sub_recipe_tool_call(
&tool_call.name,
tool_call.arguments.clone(),
&self.tasks_manager,
)
.await
} else if tool_call.name == SUB_RECIPE_EXECUTE_TASK_TOOL_NAME {
sub_recipe_execute_task_tool::run_tasks(tool_call.arguments.clone()).await
sub_recipe_execute_task_tool::run_tasks(
tool_call.arguments.clone(),
&self.tasks_manager,
)
.await
} else if tool_call.name == PLATFORM_READ_RESOURCE_TOOL_NAME {
// Check if the tool is read_resource and handle it separately
ToolCallResult::from(

View File

@@ -1 +1,2 @@
pub mod param_utils;
pub mod sub_recipe_tools;

View File

@@ -0,0 +1,38 @@
use anyhow::Result;
use serde_json::Value;
use std::collections::HashMap;
use crate::recipe::SubRecipe;
/// Expand a sub-recipe tool call into one parameter map per task.
///
/// The sub-recipe's pre-configured `values` always win: tool-call parameters
/// are only inserted for keys not already present (`entry().or_insert`).
/// With no tool-call parameter sets at all, a single task using just the
/// base values is produced.
pub fn prepare_command_params(
    sub_recipe: &SubRecipe,
    params_from_tool_call: Vec<Value>,
) -> Result<Vec<HashMap<String, String>>> {
    let base_params = sub_recipe.values.clone().unwrap_or_default();
    if params_from_tool_call.is_empty() {
        return Ok(vec![base_params]);
    }

    let mut merged_sets = Vec::with_capacity(params_from_tool_call.len());
    for tool_param in params_from_tool_call {
        let mut merged = base_params.clone();
        if let Some(fields) = tool_param.as_object() {
            for (key, value) in fields {
                // JSON strings are used verbatim; any other JSON value keeps
                // its serialized text form.
                let rendered = value
                    .as_str()
                    .map(String::from)
                    .unwrap_or_else(|| value.to_string());
                merged.entry(key.clone()).or_insert(rendered);
            }
        }
        merged_sets.push(merged);
    }
    Ok(merged_sets)
}
#[cfg(test)]
mod tests;

View File

@@ -0,0 +1,140 @@
use std::collections::HashMap;
use crate::recipe::SubRecipe;
use serde_json::json;
use crate::agents::recipe_tools::param_utils::prepare_command_params;
/// Build the `SubRecipe` fixture shared by these tests: one pre-configured
/// value (`key1=value1`) and sequential repetition enabled.
fn setup_default_sub_recipe() -> SubRecipe {
    // Return the literal directly instead of a redundant let-then-return
    // (clippy::let_and_return).
    SubRecipe {
        name: "test_sub_recipe".to_string(),
        path: "test_sub_recipe.yaml".to_string(),
        values: Some(HashMap::from([("key1".to_string(), "value1".to_string())])),
        sequential_when_repeated: true,
    }
}
// Unit tests for `prepare_command_params`: merging of the sub-recipe's
// pre-configured values with tool-call parameter sets.
mod prepare_command_params_tests {
    use super::*;

    // One tool-call parameter set: base value and new key are merged.
    #[test]
    fn test_return_command_param() {
        let parameter_array = vec![json!(HashMap::from([(
            "key2".to_string(),
            "value2".to_string()
        )]))];
        let mut sub_recipe = setup_default_sub_recipe();
        sub_recipe.values = Some(HashMap::from([("key1".to_string(), "value1".to_string())]));
        let result = prepare_command_params(&sub_recipe, parameter_array).unwrap();
        assert_eq!(
            vec![HashMap::from([
                ("key1".to_string(), "value1".to_string()),
                ("key2".to_string(), "value2".to_string())
            ]),],
            result
        );
    }

    // Pre-configured values take precedence over tool-call parameters
    // for the same key.
    #[test]
    fn test_return_command_param_when_value_override_passed_param_value() {
        let parameter_array = vec![json!(HashMap::from([(
            "key2".to_string(),
            "different_value".to_string()
        )]))];
        let mut sub_recipe = setup_default_sub_recipe();
        sub_recipe.values = Some(HashMap::from([
            ("key1".to_string(), "value1".to_string()),
            ("key2".to_string(), "value2".to_string()),
        ]));
        let result = prepare_command_params(&sub_recipe, parameter_array).unwrap();
        assert_eq!(
            vec![HashMap::from([
                ("key1".to_string(), "value1".to_string()),
                ("key2".to_string(), "value2".to_string())
            ]),],
            result
        );
    }

    // No tool parameters and no base values: a single empty parameter map
    // is still produced (one task with no parameters).
    #[test]
    fn test_return_empty_command_param() {
        let parameter_array = vec![];
        let mut sub_recipe = setup_default_sub_recipe();
        sub_recipe.values = None;
        let result = prepare_command_params(&sub_recipe, parameter_array).unwrap();
        assert_eq!(result, vec![HashMap::new()]);
    }

    // Multiple tool-call parameter sets fan out into one map per task.
    mod multiple_tool_parameters {
        use super::*;

        // With no base values, each tool-call set is used as-is.
        #[test]
        fn test_return_command_param_when_all_values_from_tool_call_parameters() {
            let parameter_array = vec![
                json!(HashMap::from([
                    ("key1".to_string(), "key1_value1".to_string()),
                    ("key2".to_string(), "key2_value1".to_string())
                ])),
                json!(HashMap::from([
                    ("key1".to_string(), "key1_value2".to_string()),
                    ("key2".to_string(), "key2_value2".to_string())
                ])),
            ];
            let mut sub_recipe = setup_default_sub_recipe();
            sub_recipe.values = None;
            let result = prepare_command_params(&sub_recipe, parameter_array).unwrap();
            assert_eq!(
                vec![
                    HashMap::from([
                        ("key1".to_string(), "key1_value1".to_string()),
                        ("key2".to_string(), "key2_value1".to_string()),
                    ]),
                    HashMap::from([
                        ("key1".to_string(), "key1_value2".to_string()),
                        ("key2".to_string(), "key2_value2".to_string()),
                    ]),
                ],
                result
            );
        }

        // Base values override every tool-call set — both resulting maps
        // keep the pre-configured values.
        #[test]
        fn test_merge_base_values_with_tool_parameters() {
            let parameter_array = vec![
                json!(HashMap::from([(
                    "key2".to_string(),
                    "override_value1".to_string()
                )])),
                json!(HashMap::from([(
                    "key2".to_string(),
                    "override_value2".to_string()
                )])),
            ];
            let mut sub_recipe = setup_default_sub_recipe();
            sub_recipe.values = Some(HashMap::from([
                ("key1".to_string(), "base_value".to_string()),
                ("key2".to_string(), "original_value".to_string()),
            ]));
            let result = prepare_command_params(&sub_recipe, parameter_array).unwrap();
            assert_eq!(
                vec![
                    HashMap::from([
                        ("key1".to_string(), "base_value".to_string()),
                        ("key2".to_string(), "original_value".to_string()),
                    ]),
                    HashMap::from([
                        ("key1".to_string(), "base_value".to_string()),
                        ("key2".to_string(), "original_value".to_string()),
                    ]),
                ],
                result
            );
        }
    }
}

View File

@@ -1,22 +1,35 @@
use std::{collections::HashMap, fs};
use std::collections::HashSet;
use std::fs;
use anyhow::Result;
use mcp_core::tool::{Tool, ToolAnnotations};
use serde_json::{json, Map, Value};
use crate::agents::sub_recipe_execution_tool::lib::Task;
use crate::agents::sub_recipe_execution_tool::lib::{ExecutionMode, Task};
use crate::agents::sub_recipe_execution_tool::tasks_manager::TasksManager;
use crate::recipe::{Recipe, RecipeParameter, RecipeParameterRequirement, SubRecipe};
use super::param_utils::prepare_command_params;
pub const SUB_RECIPE_TASK_TOOL_NAME_PREFIX: &str = "subrecipe__create_task";
pub fn create_sub_recipe_task_tool(sub_recipe: &SubRecipe) -> Tool {
let input_schema = get_input_schema(sub_recipe).unwrap();
Tool::new(
format!("{}_{}", SUB_RECIPE_TASK_TOOL_NAME_PREFIX, sub_recipe.name),
"Before running this sub recipe, you should first create a task with this tool and then pass the task to the task executor".to_string(),
format!(
"Create one or more tasks to run the '{}' sub recipe. \
Provide an array of parameter sets in the 'task_parameters' field:\n\
- For a single task: provide an array with one parameter set\n\
- For multiple tasks: provide an array with multiple parameter sets, each with different values\n\n\
Each task will run the same sub recipe but with different parameter values. \
This is useful when you need to execute the same sub recipe multiple times with varying inputs. \
After creating the tasks and execution_mode is provided, pass them to the task executor to run these tasks",
sub_recipe.name
),
input_schema,
Some(ToolAnnotations {
title: Some(format!("create sub recipe task {}", sub_recipe.name)),
title: Some(format!("create multiple sub recipe tasks for {}", sub_recipe.name)),
read_only_hint: false,
destructive_hint: true,
idempotent_hint: false,
@@ -25,6 +38,64 @@ pub fn create_sub_recipe_task_tool(sub_recipe: &SubRecipe) -> Tool {
)
}
fn extract_task_parameters(params: &Value) -> Vec<Value> {
params
.get("task_parameters")
.and_then(|v| v.as_array())
.cloned()
.unwrap_or_default()
}
/// Build one `Task` per prepared parameter set.
///
/// Each task gets a fresh UUID and a `sub_recipe` payload carrying the
/// recipe name/path, the resolved command parameters for that task, and the
/// sequential/parallel repetition flag.
fn create_tasks_from_params(
    sub_recipe: &SubRecipe,
    command_params: &[std::collections::HashMap<String, String>],
) -> Vec<Task> {
    // Return the collected Vec directly (the intermediate `let tasks = ...;
    // tasks` binding was redundant — clippy::let_and_return).
    command_params
        .iter()
        .map(|task_command_param| {
            let payload = json!({
                "sub_recipe": {
                    "name": sub_recipe.name.clone(),
                    "command_parameters": task_command_param,
                    "recipe_path": sub_recipe.path.clone(),
                    "sequential_when_repeated": sub_recipe.sequential_when_repeated
                }
            });
            Task {
                id: uuid::Uuid::new_v4().to_string(),
                task_type: "sub_recipe".to_string(),
                payload,
            }
        })
        .collect()
}
fn create_task_execution_payload(tasks: &[Task], sub_recipe: &SubRecipe) -> Value {
let task_ids: Vec<String> = tasks.iter().map(|task| task.id.clone()).collect();
json!({
"task_ids": task_ids,
"execution_mode": if sub_recipe.sequential_when_repeated { ExecutionMode::Sequential } else { ExecutionMode::Parallel },
})
}
/// Handle a sub-recipe "create task" tool call: build one task per parameter
/// set, register the tasks with the `TasksManager`, and return the serialized
/// execution payload (task ids + execution mode) for the model to pass on to
/// the task executor.
///
/// # Errors
/// Returns an error if parameter preparation fails or the execution payload
/// cannot be serialized.
pub async fn create_sub_recipe_task(
    sub_recipe: &SubRecipe,
    params: Value,
    tasks_manager: &TasksManager,
) -> Result<String> {
    let task_params_array = extract_task_parameters(&params);
    // Pass by value — `task_params_array` is not used afterwards, so the
    // previous `.clone()` was a needless allocation.
    let command_params = prepare_command_params(sub_recipe, task_params_array)?;
    let tasks = create_tasks_from_params(sub_recipe, &command_params);
    let task_execution_payload = create_task_execution_payload(&tasks, sub_recipe);
    let tasks_json = serde_json::to_string(&task_execution_payload)
        .map_err(|e| anyhow::anyhow!("Failed to serialize task list: {}", e))?;
    // `tasks` is no longer needed locally; hand it over without cloning.
    tasks_manager.save_tasks(tasks).await;
    Ok(tasks_json)
}
fn get_sub_recipe_parameter_definition(
sub_recipe: &SubRecipe,
) -> Result<Option<Vec<RecipeParameter>>> {
@@ -34,22 +105,55 @@ fn get_sub_recipe_parameter_definition(
Ok(recipe.parameters)
}
fn get_input_schema(sub_recipe: &SubRecipe) -> Result<Value> {
let mut sub_recipe_params_map = HashMap::<String, String>::new();
fn get_params_with_values(sub_recipe: &SubRecipe) -> HashSet<String> {
let mut sub_recipe_params_with_values = HashSet::<String>::new();
if let Some(params_with_value) = &sub_recipe.values {
for (param_name, param_value) in params_with_value {
sub_recipe_params_map.insert(param_name.clone(), param_value.clone());
for param_name in params_with_value.keys() {
sub_recipe_params_with_values.insert(param_name.clone());
}
}
sub_recipe_params_with_values
}
/// Wrap per-parameter JSON-schema properties into the tool's input schema.
///
/// When the sub-recipe exposes no free parameters, the schema is an object
/// with no properties (the model may call the tool with empty arguments);
/// otherwise a single `task_parameters` array field is declared, where each
/// element is one parameter set producing one task.
fn create_input_schema(param_properties: Map<String, Value>, param_required: Vec<String>) -> Value {
    let mut properties = Map::new();
    if !param_properties.is_empty() {
        properties.insert(
            "task_parameters".to_string(),
            json!({
                "type": "array",
                "description": "Array of parameter sets for creating tasks. \
                    For a single task, provide an array with one element. \
                    For multiple tasks, provide an array with multiple elements, each with different parameter values. \
                    If there is no parameter set, provide an empty array.",
                "items": {
                    "type": "object",
                    "properties": param_properties,
                    "required": param_required
                },
            })
        );
    }
    json!({
        "type": "object",
        "properties": properties,
    })
}
fn get_input_schema(sub_recipe: &SubRecipe) -> Result<Value> {
let sub_recipe_params_with_values = get_params_with_values(sub_recipe);
let parameter_definition = get_sub_recipe_parameter_definition(sub_recipe)?;
let mut param_properties = Map::new();
let mut param_required = Vec::new();
if let Some(parameters) = parameter_definition {
let mut properties = Map::new();
let mut required = Vec::new();
for param in parameters {
if sub_recipe_params_map.contains_key(&param.key) {
if sub_recipe_params_with_values.contains(&param.key.clone()) {
continue;
}
properties.insert(
param_properties.insert(
param.key.clone(),
json!({
"type": param.input_type.to_string(),
@@ -57,60 +161,11 @@ fn get_input_schema(sub_recipe: &SubRecipe) -> Result<Value> {
}),
);
if !matches!(param.requirement, RecipeParameterRequirement::Optional) {
required.push(param.key);
param_required.push(param.key);
}
}
Ok(json!({
"type": "object",
"properties": properties,
"required": required
}))
} else {
Ok(json!({
"type": "object",
"properties": {}
}))
}
}
fn prepare_command_params(
sub_recipe: &SubRecipe,
params_from_tool_call: Value,
) -> Result<HashMap<String, String>> {
let mut sub_recipe_params = HashMap::<String, String>::new();
if let Some(params_with_value) = &sub_recipe.values {
for (param_name, param_value) in params_with_value {
sub_recipe_params.insert(param_name.clone(), param_value.clone());
}
}
if let Some(params_map) = params_from_tool_call.as_object() {
for (key, value) in params_map {
sub_recipe_params.insert(
key.to_string(),
value.as_str().unwrap_or(&value.to_string()).to_string(),
);
}
}
Ok(sub_recipe_params)
}
pub async fn create_sub_recipe_task(sub_recipe: &SubRecipe, params: Value) -> Result<String> {
let command_params = prepare_command_params(sub_recipe, params)?;
let payload = json!({
"sub_recipe": {
"name": sub_recipe.name.clone(),
"command_parameters": command_params,
"recipe_path": sub_recipe.path.clone(),
}
});
let task = Task {
id: uuid::Uuid::new_v4().to_string(),
task_type: "sub_recipe".to_string(),
payload,
};
let task_json = serde_json::to_string(&task)
.map_err(|e| anyhow::anyhow!("Failed to serialize Task: {}", e))?;
Ok(task_json)
Ok(create_input_schema(param_properties, param_required))
}
#[cfg(test)]

View File

@@ -3,66 +3,48 @@ mod tests {
use std::collections::HashMap;
use crate::recipe::SubRecipe;
use serde_json::json;
use serde_json::Value;
use tempfile::TempDir;
fn setup_sub_recipe() -> SubRecipe {
fn setup_default_sub_recipe() -> SubRecipe {
let sub_recipe = SubRecipe {
name: "test_sub_recipe".to_string(),
path: "test_sub_recipe.yaml".to_string(),
values: Some(HashMap::from([("key1".to_string(), "value1".to_string())])),
sequential_when_repeated: true,
};
sub_recipe
}
mod prepare_command_params_tests {
use std::collections::HashMap;
use crate::{
agents::recipe_tools::sub_recipe_tools::{
prepare_command_params, tests::tests::setup_sub_recipe,
},
recipe::SubRecipe,
};
mod get_input_schema {
use super::*;
use crate::agents::recipe_tools::sub_recipe_tools::get_input_schema;
#[test]
fn test_prepare_command_params_basic() {
let mut params = HashMap::new();
params.insert("key2".to_string(), "value2".to_string());
let sub_recipe = setup_sub_recipe();
let params_value = serde_json::to_value(params).unwrap();
let result = prepare_command_params(&sub_recipe, params_value).unwrap();
assert_eq!(result.len(), 2);
assert_eq!(result.get("key1"), Some(&"value1".to_string()));
assert_eq!(result.get("key2"), Some(&"value2".to_string()));
fn prepare_sub_recipe(sub_recipe_file_content: &str) -> (SubRecipe, TempDir) {
let mut sub_recipe = setup_default_sub_recipe();
let temp_dir = tempfile::tempdir().unwrap();
let temp_file = temp_dir.path().join(sub_recipe.path.clone());
std::fs::write(&temp_file, sub_recipe_file_content).unwrap();
sub_recipe.path = temp_file.to_string_lossy().to_string();
(sub_recipe, temp_dir)
}
#[test]
fn test_prepare_command_params_empty() {
let sub_recipe = SubRecipe {
name: "test_sub_recipe".to_string(),
path: "test_sub_recipe.yaml".to_string(),
values: None,
};
let params: HashMap<String, String> = HashMap::new();
let params_value = serde_json::to_value(params).unwrap();
let result = prepare_command_params(&sub_recipe, params_value).unwrap();
assert_eq!(result.len(), 0);
fn verify_task_parameters(result: Value, expected_task_parameters_items: Value) {
let task_parameters = result
.get("properties")
.unwrap()
.as_object()
.unwrap()
.get("task_parameters")
.unwrap()
.as_object()
.unwrap();
let task_parameters_items = task_parameters.get("items").unwrap();
assert_eq!(&expected_task_parameters_items, task_parameters_items);
}
}
mod get_input_schema_tests {
use crate::{
agents::recipe_tools::sub_recipe_tools::{
get_input_schema, tests::tests::setup_sub_recipe,
},
recipe::SubRecipe,
};
#[test]
fn test_get_input_schema_with_parameters() {
let sub_recipe = setup_sub_recipe();
let sub_recipe_file_content = r#"{
const SUB_RECIPE_FILE_CONTENT_WITH_TWO_PARAMS: &str = r#"{
"version": "1.0.0",
"title": "Test Recipe",
"description": "A test recipe",
@@ -83,73 +65,67 @@ mod tests {
]
}"#;
let temp_dir = tempfile::tempdir().unwrap();
let temp_file = temp_dir.path().join("test_sub_recipe.yaml");
std::fs::write(&temp_file, sub_recipe_file_content).unwrap();
let mut sub_recipe = sub_recipe;
sub_recipe.path = temp_file.to_string_lossy().to_string();
#[test]
fn test_with_one_param_in_tool_input() {
let (mut sub_recipe, _temp_dir) =
prepare_sub_recipe(SUB_RECIPE_FILE_CONTENT_WITH_TWO_PARAMS);
sub_recipe.values = Some(HashMap::from([("key1".to_string(), "value1".to_string())]));
let result = get_input_schema(&sub_recipe).unwrap();
// Verify the schema structure
assert_eq!(result["type"], "object");
assert!(result["properties"].is_object());
let properties = result["properties"].as_object().unwrap();
assert_eq!(properties.len(), 1);
let key2_prop = &properties["key2"];
assert_eq!(key2_prop["type"], "number");
assert_eq!(key2_prop["description"], "An optional parameter");
let required = result["required"].as_array().unwrap();
assert_eq!(required.len(), 0);
verify_task_parameters(
result,
json!({
"type": "object",
"properties": {
"key2": { "type": "number", "description": "An optional parameter" }
},
"required": []
}),
);
}
#[test]
fn test_get_input_schema_no_parameters_values() {
let sub_recipe = SubRecipe {
name: "test_sub_recipe".to_string(),
path: "test_sub_recipe.yaml".to_string(),
values: None,
};
let sub_recipe_file_content = r#"{
"version": "1.0.0",
"title": "Test Recipe",
"description": "A test recipe",
"prompt": "Test prompt",
"parameters": [
{
"key": "key1",
"input_type": "string",
"requirement": "required",
"description": "A test parameter"
}
]
}"#;
let temp_dir = tempfile::tempdir().unwrap();
let temp_file = temp_dir.path().join("test_sub_recipe.yaml");
std::fs::write(&temp_file, sub_recipe_file_content).unwrap();
let mut sub_recipe = sub_recipe;
sub_recipe.path = temp_file.to_string_lossy().to_string();
fn test_without_param_in_tool_input() {
let (mut sub_recipe, _temp_dir) =
prepare_sub_recipe(SUB_RECIPE_FILE_CONTENT_WITH_TWO_PARAMS);
sub_recipe.values = Some(HashMap::from([
("key1".to_string(), "value1".to_string()),
("key2".to_string(), "value2".to_string()),
]));
let result = get_input_schema(&sub_recipe).unwrap();
assert_eq!(result["type"], "object");
assert!(result["properties"].is_object());
assert_eq!(
None,
result
.get("properties")
.unwrap()
.as_object()
.unwrap()
.get("task_parameters")
);
}
let properties = result["properties"].as_object().unwrap();
assert_eq!(properties.len(), 1);
#[test]
fn test_with_all_params_in_tool_input() {
let (mut sub_recipe, _temp_dir) =
prepare_sub_recipe(SUB_RECIPE_FILE_CONTENT_WITH_TWO_PARAMS);
sub_recipe.values = None;
let key1_prop = &properties["key1"];
assert_eq!(key1_prop["type"], "string");
assert_eq!(key1_prop["description"], "A test parameter");
assert_eq!(result["required"].as_array().unwrap().len(), 1);
assert_eq!(result["required"][0], "key1");
let result = get_input_schema(&sub_recipe).unwrap();
verify_task_parameters(
result,
json!({
"type": "object",
"properties": {
"key1": { "type": "string", "description": "A test parameter" },
"key2": { "type": "number", "description": "An optional parameter" }
},
"required": ["key1"]
}),
);
}
}
}

View File

@@ -1,103 +0,0 @@
use std::sync::atomic::{AtomicBool, AtomicUsize};
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio::time::Instant;
use crate::agents::sub_recipe_execution_tool::lib::{
Config, ExecutionResponse, ExecutionStats, Task, TaskResult,
};
use crate::agents::sub_recipe_execution_tool::tasks::process_task;
use crate::agents::sub_recipe_execution_tool::workers::{run_scaler, spawn_worker, SharedState};
/// Run exactly one task and wrap its outcome in the standard response shape.
pub async fn execute_single_task(task: &Task, config: Config) -> ExecutionResponse {
    let started = Instant::now();
    let result = process_task(task, config.timeout_seconds).await;
    let elapsed_ms = started.elapsed().as_millis();

    // A single result contributes to exactly one of the two counters.
    let (completed, failed) = match result.status.as_str() {
        "success" => (1, 0),
        "failed" => (0, 1),
        _ => (0, 0),
    };

    ExecutionResponse {
        status: "completed".to_string(),
        results: vec![result],
        stats: ExecutionStats {
            total_tasks: 1,
            completed,
            failed,
            execution_time_ms: elapsed_ms,
        },
    }
}
/// Execute `tasks` concurrently with a worker pool that a background scaler
/// may grow up to `config.max_workers`, and aggregate every result into one
/// `ExecutionResponse`.
pub async fn parallel_execute(tasks: Vec<Task>, config: Config) -> ExecutionResponse {
    let start_time = Instant::now();
    let task_count = tasks.len();

    // Bounded channels sized to the task count so queueing never blocks.
    let (task_tx, task_rx) = mpsc::channel::<Task>(task_count);
    let (result_tx, mut result_rx) = mpsc::channel::<TaskResult>(task_count);

    // State shared between the workers and the scaler.
    let shared_state = Arc::new(SharedState {
        task_receiver: Arc::new(tokio::sync::Mutex::new(task_rx)),
        result_sender: result_tx,
        active_workers: Arc::new(AtomicUsize::new(0)),
        should_stop: Arc::new(AtomicBool::new(false)),
        completed_tasks: Arc::new(AtomicUsize::new(0)),
    });

    // Queue all tasks. `tasks` is moved here — the previous version cloned
    // the entire Vec needlessly before sending.
    for task in tasks {
        let _ = task_tx.send(task).await;
    }
    // Close the sender so workers observe a finished, drained queue.
    drop(task_tx);

    // Start the initial worker pool.
    let mut worker_handles = Vec::new();
    for i in 0..config.initial_workers {
        let handle = spawn_worker(shared_state.clone(), i, config.timeout_seconds);
        worker_handles.push(handle);
    }

    // Start the scaler, which may spawn additional workers under load.
    let scaler_state = shared_state.clone();
    let scaler_handle = tokio::spawn(async move {
        run_scaler(
            scaler_state,
            task_count,
            config.max_workers,
            config.timeout_seconds,
        )
        .await;
    });

    // Collect exactly one result per queued task.
    let mut results = Vec::new();
    while let Some(result) = result_rx.recv().await {
        results.push(result);
        if results.len() >= task_count {
            break;
        }
    }

    // Wait for the scaler to finish.
    // NOTE(review): worker handles are dropped without being awaited; workers
    // exit once the queue drains, but their join errors are never observed.
    let _ = scaler_handle.await;

    let execution_time = start_time.elapsed().as_millis();
    let completed = results.iter().filter(|r| r.status == "success").count();
    let failed = results.iter().filter(|r| r.status == "failed").count();
    ExecutionResponse {
        status: "completed".to_string(),
        results,
        stats: ExecutionStats {
            total_tasks: task_count,
            completed,
            failed,
            execution_time_ms: execution_time,
        },
    }
}

View File

@@ -0,0 +1,197 @@
use mcp_core::protocol::JsonRpcMessage;
use std::sync::atomic::AtomicUsize;
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio::time::Instant;
use crate::agents::sub_recipe_execution_tool::lib::{
ExecutionResponse, ExecutionStats, SharedState, Task, TaskResult, TaskStatus,
};
use crate::agents::sub_recipe_execution_tool::task_execution_tracker::{
DisplayMode, TaskExecutionTracker,
};
use crate::agents::sub_recipe_execution_tool::tasks::process_task;
use crate::agents::sub_recipe_execution_tool::workers::spawn_worker;
#[cfg(test)]
mod tests;
const EXECUTION_STATUS_COMPLETED: &str = "completed";
const DEFAULT_MAX_WORKERS: usize = 10;
/// Run exactly one task, streaming its output through `notifier` in
/// single-task display mode, and return the wrapped result plus stats.
pub async fn execute_single_task(
    task: &Task,
    notifier: mpsc::Sender<JsonRpcMessage>,
) -> ExecutionResponse {
    let start_time = Instant::now();
    // Single-task display mode streams this task's output directly.
    let task_execution_tracker = Arc::new(TaskExecutionTracker::new(
        vec![task.clone()],
        DisplayMode::SingleTaskOutput,
        notifier,
    ));
    let result = process_task(task, task_execution_tracker).await;
    let execution_time = start_time.elapsed().as_millis();
    // `slice::from_ref` lets us compute stats without cloning the result.
    let stats = calculate_stats(std::slice::from_ref(&result), execution_time);
    ExecutionResponse {
        status: EXECUTION_STATUS_COMPLETED.to_string(),
        results: vec![result],
        stats,
    }
}
/// Execute all `tasks` concurrently with a bounded worker pool (capped at
/// `DEFAULT_MAX_WORKERS`), streaming progress notifications through
/// `notifier`, and return the aggregated results and stats.
pub async fn execute_tasks_in_parallel(
    tasks: Vec<Task>,
    notifier: mpsc::Sender<JsonRpcMessage>,
) -> ExecutionResponse {
    // Tracker drives the multi-task progress display over the notifier.
    let task_execution_tracker = Arc::new(TaskExecutionTracker::new(
        tasks.clone(),
        DisplayMode::MultipleTasksOutput,
        notifier,
    ));
    let start_time = Instant::now();
    let task_count = tasks.len();
    if task_count == 0 {
        return create_empty_response();
    }
    task_execution_tracker.refresh_display().await;
    // Channels are sized to hold every task, so queueing up front never blocks.
    let (task_tx, task_rx, result_tx, mut result_rx) = create_channels(task_count);
    if let Err(e) = send_tasks_to_channel(tasks, task_tx).await {
        tracing::error!("Task execution failed: {}", e);
        return create_error_response(e);
    }
    let shared_state = create_shared_state(task_rx, result_tx, task_execution_tracker.clone());
    // Never spawn more workers than there are tasks.
    let worker_count = std::cmp::min(task_count, DEFAULT_MAX_WORKERS);
    let mut worker_handles = Vec::new();
    for i in 0..worker_count {
        let handle = spawn_worker(shared_state.clone(), i);
        worker_handles.push(handle);
    }
    // Wait for one result per task, then join the (now idle) workers.
    let results = collect_results(&mut result_rx, task_execution_tracker.clone(), task_count).await;
    for handle in worker_handles {
        if let Err(e) = handle.await {
            tracing::error!("Worker error: {}", e);
        }
    }
    task_execution_tracker.send_tasks_complete().await;
    let execution_time = start_time.elapsed().as_millis();
    let stats = calculate_stats(&results, execution_time);
    ExecutionResponse {
        status: EXECUTION_STATUS_COMPLETED.to_string(),
        results,
        stats,
    }
}
/// Tally terminal task states and package them with the elapsed time.
fn calculate_stats(results: &[TaskResult], execution_time_ms: u128) -> ExecutionStats {
    // Single pass over the results instead of two filtered counts.
    let mut completed = 0;
    let mut failed = 0;
    for result in results {
        match result.status {
            TaskStatus::Completed => completed += 1,
            TaskStatus::Failed => failed += 1,
            _ => {}
        }
    }
    ExecutionStats {
        total_tasks: results.len(),
        completed,
        failed,
        execution_time_ms,
    }
}
/// Create the task-dispatch and result channels, both bounded at
/// `task_count` so every task can be queued without blocking.
fn create_channels(
    task_count: usize,
) -> (
    mpsc::Sender<Task>,
    mpsc::Receiver<Task>,
    mpsc::Sender<TaskResult>,
    mpsc::Receiver<TaskResult>,
) {
    let task_channel = mpsc::channel::<Task>(task_count);
    let result_channel = mpsc::channel::<TaskResult>(task_count);
    (task_channel.0, task_channel.1, result_channel.0, result_channel.1)
}
/// Bundle the channels and tracker into the state shared by all workers.
fn create_shared_state(
    task_rx: mpsc::Receiver<Task>,
    result_tx: mpsc::Sender<TaskResult>,
    task_execution_tracker: Arc<TaskExecutionTracker>,
) -> Arc<SharedState> {
    // The receiver sits behind an async Mutex so multiple workers can pull
    // tasks from the same queue.
    let task_receiver = Arc::new(tokio::sync::Mutex::new(task_rx));
    Arc::new(SharedState {
        task_receiver,
        result_sender: result_tx,
        active_workers: Arc::new(AtomicUsize::new(0)),
        task_execution_tracker,
    })
}
/// Queue every task for the workers. A send only fails when the receiving
/// side is gone; that failure is surfaced as a string error.
async fn send_tasks_to_channel(
    tasks: Vec<Task>,
    task_tx: mpsc::Sender<Task>,
) -> Result<(), String> {
    for task in tasks.into_iter() {
        if let Err(e) = task_tx.send(task).await {
            return Err(format!("Failed to queue task: {}", e));
        }
    }
    Ok(())
}
/// Response for a zero-task invocation: "completed" with all-zero stats.
fn create_empty_response() -> ExecutionResponse {
    let stats = ExecutionStats {
        total_tasks: 0,
        completed: 0,
        failed: 0,
        execution_time_ms: 0,
    };
    ExecutionResponse {
        status: EXECUTION_STATUS_COMPLETED.to_string(),
        results: Vec::new(),
        stats,
    }
}
/// Drain the result channel, marking each task complete in the tracker as
/// its result arrives, until every expected result has landed (or all
/// senders have been dropped).
async fn collect_results(
    result_rx: &mut mpsc::Receiver<TaskResult>,
    task_execution_tracker: Arc<TaskExecutionTracker>,
    expected_count: usize,
) -> Vec<TaskResult> {
    let mut collected = Vec::with_capacity(expected_count);
    while let Some(task_result) = result_rx.recv().await {
        task_execution_tracker
            .complete_task(&task_result.task_id, task_result.clone())
            .await;
        collected.push(task_result);
        if collected.len() >= expected_count {
            break;
        }
    }
    collected
}
/// Response for an execution that failed before any task could run.
fn create_error_response(error: String) -> ExecutionResponse {
    tracing::error!("Creating error response: {}", error);
    // NOTE(review): the stats report one failure but zero total tasks; the
    // existing unit tests pin this exact shape, so it is preserved as-is.
    let stats = ExecutionStats {
        total_tasks: 0,
        completed: 0,
        failed: 1,
        execution_time_ms: 0,
    };
    ExecutionResponse {
        status: "failed".to_string(),
        results: Vec::new(),
        stats,
    }
}

View File

@@ -0,0 +1,100 @@
use super::{calculate_stats, create_empty_response, create_error_response};
use crate::agents::sub_recipe_execution_tool::lib::{TaskResult, TaskStatus};
use serde_json::json;
fn create_test_task_result(task_id: &str, status: TaskStatus) -> TaskResult {
let is_failed = matches!(status, TaskStatus::Failed);
TaskResult {
task_id: task_id.to_string(),
status,
data: Some(json!({"output": "test output"})),
error: if is_failed {
Some("Test error".to_string())
} else {
None
},
}
}
// Mixed outcomes: counters must partition by terminal status.
#[test]
fn test_calculate_stats() {
    let results = vec![
        create_test_task_result("task1", TaskStatus::Completed),
        create_test_task_result("task2", TaskStatus::Completed),
        create_test_task_result("task3", TaskStatus::Failed),
        create_test_task_result("task4", TaskStatus::Completed),
    ];
    let stats = calculate_stats(&results, 1500);
    assert_eq!(stats.total_tasks, 4);
    assert_eq!(stats.completed, 3);
    assert_eq!(stats.failed, 1);
    assert_eq!(stats.execution_time_ms, 1500);
}

// An empty result set yields all-zero stats.
#[test]
fn test_calculate_stats_empty_results() {
    let results = vec![];
    let stats = calculate_stats(&results, 0);
    assert_eq!(stats.total_tasks, 0);
    assert_eq!(stats.completed, 0);
    assert_eq!(stats.failed, 0);
    assert_eq!(stats.execution_time_ms, 0);
}

// All-success: the failed counter stays zero.
#[test]
fn test_calculate_stats_all_completed() {
    let results = vec![
        create_test_task_result("task1", TaskStatus::Completed),
        create_test_task_result("task2", TaskStatus::Completed),
    ];
    let stats = calculate_stats(&results, 800);
    assert_eq!(stats.total_tasks, 2);
    assert_eq!(stats.completed, 2);
    assert_eq!(stats.failed, 0);
    assert_eq!(stats.execution_time_ms, 800);
}

// All-failure: the completed counter stays zero.
#[test]
fn test_calculate_stats_all_failed() {
    let results = vec![
        create_test_task_result("task1", TaskStatus::Failed),
        create_test_task_result("task2", TaskStatus::Failed),
    ];
    let stats = calculate_stats(&results, 1200);
    assert_eq!(stats.total_tasks, 2);
    assert_eq!(stats.completed, 0);
    assert_eq!(stats.failed, 2);
    assert_eq!(stats.execution_time_ms, 1200);
}

// Zero-task response reports "completed" with empty stats.
#[test]
fn test_create_empty_response() {
    let response = create_empty_response();
    assert_eq!(response.status, "completed");
    assert_eq!(response.results.len(), 0);
    assert_eq!(response.stats.total_tasks, 0);
    assert_eq!(response.stats.completed, 0);
    assert_eq!(response.stats.failed, 0);
    assert_eq!(response.stats.execution_time_ms, 0);
}

// Pins the current error-response shape: status "failed", one failure
// recorded against zero total tasks (see create_error_response).
#[test]
fn test_create_error_response() {
    let error_msg = "Test error message";
    let response = create_error_response(error_msg.to_string());
    assert_eq!(response.status, "failed");
    assert_eq!(response.results.len(), 0);
    assert_eq!(response.stats.total_tasks, 0);
    assert_eq!(response.stats.completed, 0);
    assert_eq!(response.stats.failed, 1);
    assert_eq!(response.stats.execution_time_ms, 0);
}

View File

@@ -1,38 +0,0 @@
use crate::agents::sub_recipe_execution_tool::executor::execute_single_task;
pub use crate::agents::sub_recipe_execution_tool::executor::parallel_execute;
pub use crate::agents::sub_recipe_execution_tool::types::{
Config, ExecutionResponse, ExecutionStats, Task, TaskResult,
};
use serde_json::Value;
/// Parse `tasks` (and an optional `config`) out of `input` and run them in
/// the requested mode: "sequential" accepts exactly one task per call,
/// "parallel" fans out to the worker pool. Returns the serialized
/// `ExecutionResponse` on success.
pub async fn execute_tasks(input: Value, execution_mode: &str) -> Result<Value, String> {
    let tasks: Vec<Task> =
        serde_json::from_value(input.get("tasks").ok_or("Missing tasks field")?.clone())
            .map_err(|e| format!("Failed to parse tasks: {}", e))?;
    // Fall back to default settings when no config object is supplied.
    let config: Config = if let Some(config_value) = input.get("config") {
        serde_json::from_value(config_value.clone())
            .map_err(|e| format!("Failed to parse config: {}", e))?
    } else {
        Config::default()
    };
    let task_count = tasks.len();
    match execution_mode {
        "sequential" => {
            if task_count == 1 {
                let response = execute_single_task(&tasks[0], config).await;
                serde_json::to_value(response)
                    .map_err(|e| format!("Failed to serialize response: {}", e))
            } else {
                Err("Sequential execution mode requires exactly one task".to_string())
            }
        }
        "parallel" => {
            let response = parallel_execute(tasks, config).await;
            serde_json::to_value(response)
                .map_err(|e| format!("Failed to serialize response: {}", e))
        }
        _ => Err("Invalid execution mode".to_string()),
    }
}

View File

@@ -0,0 +1,127 @@
use crate::agents::sub_recipe_execution_tool::executor::{
execute_single_task, execute_tasks_in_parallel,
};
pub use crate::agents::sub_recipe_execution_tool::task_types::{
ExecutionMode, ExecutionResponse, ExecutionStats, SharedState, Task, TaskResult, TaskStatus,
};
use crate::agents::sub_recipe_execution_tool::tasks_manager::TasksManager;
#[cfg(test)]
mod tests;
use mcp_core::protocol::JsonRpcMessage;
use serde_json::{json, Value};
use tokio::sync::mpsc;
/// Look up the tasks referenced by `input.task_ids` in the `TasksManager`
/// and execute them in the requested mode.
///
/// Sequential mode accepts exactly one task per call. Parallel mode is
/// downgraded when any task is flagged `sequential_when_repeated`: instead of
/// executing, it returns a JSON instruction telling the model to re-invoke
/// the tool sequentially.
pub async fn execute_tasks(
    input: Value,
    execution_mode: ExecutionMode,
    notifier: mpsc::Sender<JsonRpcMessage>,
    tasks_manager: &TasksManager,
) -> Result<Value, String> {
    let task_ids: Vec<String> = serde_json::from_value(
        input
            .get("task_ids")
            .ok_or("Missing task_ids field")?
            .clone(),
    )
    .map_err(|e| format!("Failed to parse task_ids: {}", e))?;
    // Resolve every id before running anything; one unknown id aborts the call.
    let mut tasks = Vec::new();
    for task_id in &task_ids {
        match tasks_manager.get_task(task_id).await {
            Some(task) => tasks.push(task),
            None => {
                return Err(format!(
                    "Task with ID '{}' not found in TasksManager",
                    task_id
                ))
            }
        }
    }
    let task_count = tasks.len();
    match execution_mode {
        ExecutionMode::Sequential => {
            if task_count == 1 {
                let response = execute_single_task(&tasks[0], notifier).await;
                handle_response(response)
            } else {
                Err("Sequential execution mode requires exactly one task".to_string())
            }
        }
        ExecutionMode::Parallel => {
            if tasks.iter().any(|task| task.get_sequential_when_repeated()) {
                // Do not execute: instruct the caller to retry sequentially.
                Ok(json!(
                    {
                        "execution_mode": ExecutionMode::Sequential,
                        "task_ids": task_ids,
                        "results": ["the tasks should be executed sequentially, no matter how user requests it. Please use the subrecipe__execute_task tool to execute the tasks sequentially."]
                    }
                ))
            } else {
                let response: ExecutionResponse =
                    execute_tasks_in_parallel(tasks, notifier.clone()).await;
                handle_response(response)
            }
        }
    }
}
fn extract_failed_tasks(results: &[TaskResult]) -> Vec<String> {
results
.iter()
.filter(|r| matches!(r.status, TaskStatus::Failed))
.map(format_failed_task_error)
.collect()
}
/// Render one failed task as a human-readable error entry, including any
/// non-empty partial output captured before the failure.
fn format_failed_task_error(result: &TaskResult) -> String {
    // Prefer the recorded error message; fall back to a generic one.
    let error_msg = match result.error.as_deref() {
        Some(msg) => msg,
        None => "Unknown error",
    };
    let captured = result.data.as_ref().and_then(|d| d.get("partial_output"));
    let partial_output = captured
        .and_then(|v| v.as_str())
        .filter(|s| !s.trim().is_empty())
        .unwrap_or("No output captured");
    format!(
        "Task '{}' ({}): {}\nOutput: {}",
        result.task_id,
        get_task_description(result),
        error_msg,
        partial_output
    )
}
/// Build the failure summary: a count header followed by one pre-formatted
/// entry per failed task.
fn format_error_summary(
    failed_count: usize,
    total_count: usize,
    failed_tasks: Vec<String>,
) -> String {
    let details = failed_tasks.join("\n");
    format!("{failed_count}/{total_count} tasks failed:\n{details}")
}
/// Turn an `ExecutionResponse` into the tool result: any failed task makes
/// the whole execution an `Err` with a formatted summary; otherwise the
/// response is serialized as JSON.
fn handle_response(response: ExecutionResponse) -> Result<Value, String> {
    match response.stats.failed {
        0 => serde_json::to_value(response)
            .map_err(|e| format!("Failed to serialize response: {}", e)),
        failed_count => {
            let failed_tasks = extract_failed_tasks(&response.results);
            Err(format_error_summary(
                failed_count,
                response.stats.total_tasks,
                failed_tasks,
            ))
        }
    }
}
/// Short identification string for a task; only the id is available here.
fn get_task_description(result: &TaskResult) -> String {
    let id = &result.task_id;
    format!("ID: {id}")
}

View File

@@ -0,0 +1,216 @@
use super::{
extract_failed_tasks, format_error_summary, format_failed_task_error, get_task_description,
handle_response,
};
use crate::agents::sub_recipe_execution_tool::lib::{
ExecutionResponse, ExecutionStats, TaskResult, TaskStatus,
};
use serde_json::json;
/// Test fixture: a result carrying the canned partial-output payload.
fn create_test_task_result(task_id: &str, status: TaskStatus, error: Option<String>) -> TaskResult {
    let data = json!({"partial_output": "test output"});
    TaskResult {
        task_id: task_id.to_string(),
        status,
        data: Some(data),
        error,
    }
}
/// Test fixture: wrap `results` in a "completed" response whose stats are
/// derived from the result count and the supplied failure count.
fn create_test_execution_response(
    results: Vec<TaskResult>,
    failed_count: usize,
) -> ExecutionResponse {
    // Read the length before moving `results` into the response; the previous
    // version cloned the entire Vec just to call `.len()` afterwards.
    let total = results.len();
    ExecutionResponse {
        status: "completed".to_string(),
        results,
        stats: ExecutionStats {
            total_tasks: total,
            completed: total - failed_count,
            failed: failed_count,
            execution_time_ms: 1000,
        },
    }
}
#[test]
fn test_extract_failed_tasks() {
let results = vec![
create_test_task_result("task1", TaskStatus::Completed, None),
create_test_task_result(
"task2",
TaskStatus::Failed,
Some("Error message".to_string()),
),
create_test_task_result("task3", TaskStatus::Completed, None),
create_test_task_result(
"task4",
TaskStatus::Failed,
Some("Another error".to_string()),
),
];
let failed_tasks = extract_failed_tasks(&results);
assert_eq!(failed_tasks.len(), 2);
assert!(failed_tasks[0].contains("task2"));
assert!(failed_tasks[0].contains("Error message"));
assert!(failed_tasks[1].contains("task4"));
assert!(failed_tasks[1].contains("Another error"));
}
#[test]
fn test_extract_failed_tasks_empty() {
let results = vec![
create_test_task_result("task1", TaskStatus::Completed, None),
create_test_task_result("task2", TaskStatus::Completed, None),
];
let failed_tasks = extract_failed_tasks(&results);
assert_eq!(failed_tasks.len(), 0);
}
#[test]
fn test_format_failed_task_error_with_error_message() {
let result = create_test_task_result(
"task1",
TaskStatus::Failed,
Some("Test error message".to_string()),
);
let formatted = format_failed_task_error(&result);
assert!(formatted.contains("task1"));
assert!(formatted.contains("Test error message"));
assert!(formatted.contains("test output"));
assert!(formatted.contains("ID: task1"));
}
#[test]
fn test_format_failed_task_error_without_error_message() {
let result = create_test_task_result("task2", TaskStatus::Failed, None);
let formatted = format_failed_task_error(&result);
assert!(formatted.contains("task2"));
assert!(formatted.contains("Unknown error"));
assert!(formatted.contains("test output"));
}
#[test]
fn test_format_failed_task_error_empty_partial_output() {
let mut result =
create_test_task_result("task3", TaskStatus::Failed, Some("Error".to_string()));
result.data = Some(json!({"partial_output": ""}));
let formatted = format_failed_task_error(&result);
assert!(formatted.contains("No output captured"));
}
#[test]
fn test_format_failed_task_error_no_partial_output() {
let mut result =
create_test_task_result("task4", TaskStatus::Failed, Some("Error".to_string()));
result.data = Some(json!({}));
let formatted = format_failed_task_error(&result);
assert!(formatted.contains("No output captured"));
}
#[test]
fn test_format_failed_task_error_no_data() {
let mut result =
create_test_task_result("task5", TaskStatus::Failed, Some("Error".to_string()));
result.data = None;
let formatted = format_failed_task_error(&result);
assert!(formatted.contains("No output captured"));
}
#[test]
fn test_format_error_summary() {
let failed_tasks = vec![
"Task 'task1': Error 1\nOutput: output1".to_string(),
"Task 'task2': Error 2\nOutput: output2".to_string(),
];
let summary = format_error_summary(2, 5, failed_tasks);
assert_eq!(summary, "2/5 tasks failed:\nTask 'task1': Error 1\nOutput: output1\nTask 'task2': Error 2\nOutput: output2");
}
#[test]
fn test_format_error_summary_single_failure() {
let failed_tasks = vec!["Task 'task1': Error\nOutput: output".to_string()];
let summary = format_error_summary(1, 3, failed_tasks);
assert_eq!(
summary,
"1/3 tasks failed:\nTask 'task1': Error\nOutput: output"
);
}
#[test]
fn test_handle_response_success() {
let results = vec![
create_test_task_result("task1", TaskStatus::Completed, None),
create_test_task_result("task2", TaskStatus::Completed, None),
];
let response = create_test_execution_response(results, 0);
let result = handle_response(response);
assert!(result.is_ok());
let value = result.unwrap();
assert_eq!(value["status"], "completed");
assert_eq!(value["stats"]["failed"], 0);
}
#[test]
fn test_handle_response_with_failures() {
let results = vec![
create_test_task_result("task1", TaskStatus::Completed, None),
create_test_task_result("task2", TaskStatus::Failed, Some("Test error".to_string())),
];
let response = create_test_execution_response(results, 1);
let result = handle_response(response);
assert!(result.is_err());
let error = result.unwrap_err();
assert!(error.contains("1/2 tasks failed"));
assert!(error.contains("task2"));
assert!(error.contains("Test error"));
}
#[test]
fn test_handle_response_all_failures() {
let results = vec![
create_test_task_result("task1", TaskStatus::Failed, Some("Error 1".to_string())),
create_test_task_result("task2", TaskStatus::Failed, Some("Error 2".to_string())),
];
let response = create_test_execution_response(results, 2);
let result = handle_response(response);
assert!(result.is_err());
let error = result.unwrap_err();
assert!(error.contains("2/2 tasks failed"));
assert!(error.contains("task1"));
assert!(error.contains("task2"));
assert!(error.contains("Error 1"));
assert!(error.contains("Error 2"));
}
#[test]
fn test_get_task_description() {
let result = create_test_task_result("test_task_123", TaskStatus::Completed, None);
let description = get_task_description(&result);
assert_eq!(description, "ID: test_task_123");
}

View File

@@ -1,6 +1,10 @@
mod executor;
pub mod lib;
pub mod notification_events;
pub mod sub_recipe_execute_task_tool;
mod task_execution_tracker;
mod task_types;
mod tasks;
mod types;
pub mod tasks_manager;
pub mod utils;
mod workers;

View File

@@ -0,0 +1,204 @@
use crate::agents::sub_recipe_execution_tool::task_types::TaskStatus;
use serde::{Deserialize, Serialize};
use serde_json::Value;
/// Progress/result notification emitted while executing sub-recipe tasks.
/// Serialized with a `subtype` tag so consumers can dispatch on the variant.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "subtype")]
pub enum TaskExecutionNotificationEvent {
    /// A single line of live output from one running task.
    #[serde(rename = "line_output")]
    LineOutput { task_id: String, output: String },
    /// Periodic snapshot of aggregate stats plus per-task details.
    #[serde(rename = "tasks_update")]
    TasksUpdate {
        stats: TaskExecutionStats,
        tasks: Vec<TaskInfo>,
    },
    /// Final summary once all tasks have finished.
    #[serde(rename = "tasks_complete")]
    TasksComplete {
        stats: TaskCompletionStats,
        failed_tasks: Vec<FailedTaskInfo>,
    },
}
/// Aggregate per-state task counters sent with `tasks_update` notifications.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskExecutionStats {
    pub total: usize,
    pub pending: usize,
    pub running: usize,
    pub completed: usize,
    pub failed: usize,
}

/// Final counters sent with `tasks_complete`; `success_rate` is a percentage
/// (0.0–100.0) computed by `TaskCompletionStats::new`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskCompletionStats {
    pub total: usize,
    pub completed: usize,
    pub failed: usize,
    pub success_rate: f64,
}

/// Per-task detail row included in `tasks_update` notifications.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskInfo {
    pub id: String,
    pub status: TaskStatus,
    /// Seconds the task has run so far, if it has started.
    pub duration_secs: Option<f64>,
    pub current_output: String,
    pub task_type: String,
    pub task_name: String,
    pub task_metadata: String,
    pub error: Option<String>,
    pub result_data: Option<Value>,
}

/// Minimal identification of a failed task for the completion summary.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FailedTaskInfo {
    pub id: String,
    pub name: String,
    pub error: Option<String>,
}
impl TaskExecutionNotificationEvent {
    /// Shorthand constructor for a single line of task output.
    pub fn line_output(task_id: String, output: String) -> Self {
        Self::LineOutput { task_id, output }
    }

    /// Shorthand constructor for a progress snapshot of all tasks.
    pub fn tasks_update(stats: TaskExecutionStats, tasks: Vec<TaskInfo>) -> Self {
        Self::TasksUpdate { stats, tasks }
    }

    /// Shorthand constructor for the final completion summary.
    pub fn tasks_complete(stats: TaskCompletionStats, failed_tasks: Vec<FailedTaskInfo>) -> Self {
        Self::TasksComplete {
            stats,
            failed_tasks,
        }
    }

    /// Convert the event to the JSON payload used for MCP notifications,
    /// tagging it with `"type": "task_execution"` at the root level.
    pub fn to_notification_data(&self) -> serde_json::Value {
        let mut event_data = serde_json::to_value(self).expect("Failed to serialize event");
        if let serde_json::Value::Object(map) = &mut event_data {
            map.insert(
                "type".to_string(),
                serde_json::Value::String("task_execution".to_string()),
            );
        }
        event_data
    }
}
impl TaskExecutionStats {
    /// Build a stats snapshot from the individual state counters.
    pub fn new(
        total: usize,
        pending: usize,
        running: usize,
        completed: usize,
        failed: usize,
    ) -> Self {
        Self {
            total,
            pending,
            running,
            completed,
            failed,
        }
    }
}
impl TaskCompletionStats {
    /// Build completion stats, deriving `success_rate` as a percentage of
    /// completed over total.
    pub fn new(total: usize, completed: usize, failed: usize) -> Self {
        // Guard against division by zero when no tasks ran.
        let success_rate = match total {
            0 => 0.0,
            _ => (completed as f64 / total as f64) * 100.0,
        };
        Self {
            total,
            completed,
            failed,
            success_rate,
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_line_output_event_serialization() {
let event = TaskExecutionNotificationEvent::line_output(
"task-1".to_string(),
"Hello World".to_string(),
);
let notification_data = event.to_notification_data();
assert_eq!(notification_data["type"], "task_execution");
assert_eq!(notification_data["subtype"], "line_output");
assert_eq!(notification_data["task_id"], "task-1");
assert_eq!(notification_data["output"], "Hello World");
}
#[test]
fn test_tasks_update_event_serialization() {
let stats = TaskExecutionStats::new(5, 2, 1, 1, 1);
let tasks = vec![TaskInfo {
id: "task-1".to_string(),
status: TaskStatus::Running,
duration_secs: Some(1.5),
current_output: "Processing...".to_string(),
task_type: "sub_recipe".to_string(),
task_name: "test-task".to_string(),
task_metadata: "param=value".to_string(),
error: None,
result_data: None,
}];
let event = TaskExecutionNotificationEvent::tasks_update(stats, tasks);
let notification_data = event.to_notification_data();
assert_eq!(notification_data["type"], "task_execution");
assert_eq!(notification_data["subtype"], "tasks_update");
assert_eq!(notification_data["stats"]["total"], 5);
assert_eq!(notification_data["tasks"].as_array().unwrap().len(), 1);
}
#[test]
fn test_event_roundtrip_serialization() {
let original_event = TaskExecutionNotificationEvent::line_output(
"task-1".to_string(),
"Test output".to_string(),
);
// Serialize to JSON
let json_data = original_event.to_notification_data();
// Deserialize back to event (excluding the type field)
let mut event_data = json_data.clone();
if let serde_json::Value::Object(ref mut map) = event_data {
map.remove("type");
}
let deserialized_event: TaskExecutionNotificationEvent =
serde_json::from_value(event_data).expect("Failed to deserialize");
match (original_event, deserialized_event) {
(
TaskExecutionNotificationEvent::LineOutput {
task_id: id1,
output: out1,
},
TaskExecutionNotificationEvent::LineOutput {
task_id: id2,
output: out2,
},
) => {
assert_eq!(id1, id2);
assert_eq!(out1, out2);
}
_ => panic!("Event types don't match after roundtrip"),
}
}
}

View File

@@ -2,17 +2,22 @@ use mcp_core::{tool::ToolAnnotations, Content, Tool, ToolError};
use serde_json::Value;
use crate::agents::{
sub_recipe_execution_tool::lib::execute_tasks, tool_execution::ToolCallResult,
sub_recipe_execution_tool::lib::execute_tasks,
sub_recipe_execution_tool::task_types::ExecutionMode,
sub_recipe_execution_tool::tasks_manager::TasksManager, tool_execution::ToolCallResult,
};
use mcp_core::protocol::JsonRpcMessage;
use tokio::sync::mpsc;
use tokio_stream;
pub const SUB_RECIPE_EXECUTE_TASK_TOOL_NAME: &str = "sub_recipe__execute_task";
pub fn create_sub_recipe_execute_task_tool() -> Tool {
Tool::new(
SUB_RECIPE_EXECUTE_TASK_TOOL_NAME,
"Only use this tool when you execute sub recipe task.
EXECUTION STRATEGY:
- DEFAULT: Execute tasks sequentially (one at a time) unless user explicitly requests parallel execution
- PARALLEL: Only when user explicitly uses keywords like 'parallel', 'simultaneously', 'at the same time', 'concurrently'
EXECUTION STRATEGY DECISION:
1. If the tasks are created with execution_mode, use the execution_mode.
2. Execute tasks sequentially unless user explicitly requests parallel execution. PARALLEL: User uses keywords like 'parallel', 'simultaneously', 'at the same time', 'concurrently'
IMPLEMENTATION:
- Sequential execution: Call this tool multiple times, passing exactly ONE task per call
@@ -32,69 +37,15 @@ EXAMPLES:
"default": "sequential",
"description": "Execution strategy for multiple tasks. Use 'sequential' (default) unless user explicitly requests parallel execution with words like 'parallel', 'simultaneously', 'at the same time', or 'concurrently'."
},
"tasks": {
"task_ids": {
"type": "array",
"items": {
"type": "object",
"properties": {
"id": {
"type": "string",
"description": "Unique identifier for the task"
},
"task_type": {
"type": "string",
"enum": ["sub_recipe", "text_instruction"],
"default": "sub_recipe",
"description": "the type of task to execute, can be one of: sub_recipe, text_instruction"
},
"payload": {
"type": "object",
"properties": {
"sub_recipe": {
"type": "object",
"description": "sub recipe to execute",
"properties": {
"name": {
"type": "string",
"description": "name of the sub recipe to execute"
},
"recipe_path": {
"type": "string",
"description": "path of the sub recipe file"
},
"command_parameters": {
"type": "object",
"description": "parameters to pass to run recipe command with sub recipe file"
}
}
},
"text_instruction": {
"type": "string",
"description": "text instruction to execute"
}
}
}
},
"required": ["id", "payload"]
},
"description": "The tasks to run in parallel"
},
"config": {
"type": "object",
"properties": {
"timeout_seconds": {
"type": "number"
},
"max_workers": {
"type": "number"
},
"initial_workers": {
"type": "number"
}
"type": "string",
"description": "Unique identifier for the task"
}
}
},
"required": ["tasks"]
"required": ["task_ids"]
}),
Some(ToolAnnotations {
title: Some("Run tasks in parallel".to_string()),
@@ -106,19 +57,38 @@ EXAMPLES:
)
}
pub async fn run_tasks(execute_data: Value) -> ToolCallResult {
let execute_data_clone = execute_data.clone();
let default_execution_mode_value = Value::String("sequential".to_string());
let execution_mode = execute_data_clone
.get("execution_mode")
.unwrap_or(&default_execution_mode_value)
.as_str()
.unwrap_or("sequential");
match execute_tasks(execute_data, execution_mode).await {
Ok(result) => {
let output = serde_json::to_string(&result).unwrap();
ToolCallResult::from(Ok(vec![Content::text(output)]))
pub async fn run_tasks(execute_data: Value, tasks_manager: &TasksManager) -> ToolCallResult {
let (notification_tx, notification_rx) = mpsc::channel::<JsonRpcMessage>(100);
let tasks_manager_clone = tasks_manager.clone();
let result_future = async move {
let execute_data_clone = execute_data.clone();
let execution_mode = execute_data_clone
.get("execution_mode")
.and_then(|v| serde_json::from_value::<ExecutionMode>(v.clone()).ok())
.unwrap_or_default();
match execute_tasks(
execute_data,
execution_mode,
notification_tx,
&tasks_manager_clone,
)
.await
{
Ok(result) => {
let output = serde_json::to_string(&result).unwrap();
Ok(vec![Content::text(output)])
}
Err(e) => Err(ToolError::ExecutionError(e.to_string())),
}
Err(e) => ToolCallResult::from(Err(ToolError::ExecutionError(e.to_string()))),
};
// Convert receiver to stream
let notification_stream = tokio_stream::wrappers::ReceiverStream::new(notification_rx);
ToolCallResult {
result: Box::new(Box::pin(result_future)),
notification_stream: Some(Box::new(notification_stream)),
}
}

View File

@@ -0,0 +1,292 @@
use mcp_core::protocol::{JsonRpcMessage, JsonRpcNotification};
use serde_json::json;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{mpsc, RwLock};
use tokio::time::{sleep, Duration, Instant};
use crate::agents::sub_recipe_execution_tool::notification_events::{
FailedTaskInfo, TaskCompletionStats, TaskExecutionNotificationEvent, TaskExecutionStats,
TaskInfo as EventTaskInfo,
};
use crate::agents::sub_recipe_execution_tool::task_types::{
Task, TaskInfo, TaskResult, TaskStatus,
};
use crate::agents::sub_recipe_execution_tool::utils::{count_by_status, get_task_name};
use serde_json::Value;
/// How task progress is rendered to the consumer of the notifications.
#[derive(Debug, Clone, PartialEq)]
pub enum DisplayMode {
    /// Aggregate dashboard: periodic `tasks_update` snapshots covering all tasks.
    MultipleTasksOutput,
    /// Direct streaming: each output line is forwarded as a `line_output` event.
    SingleTaskOutput,
}
/// Minimum interval between throttled dashboard refreshes.
const THROTTLE_INTERVAL_MS: u64 = 250;
/// Grace period after the final notification so it can be processed before shutdown.
const COMPLETION_NOTIFICATION_DELAY_MS: u64 = 500;
/// Renders a task's command parameters as a compact `key=value,key=value` string.
///
/// Returns an empty string when the task has no sub-recipe parameters or the
/// parameter map is empty. JSON string values are emitted verbatim; any other
/// JSON value uses its serialized form.
fn format_task_metadata(task_info: &TaskInfo) -> String {
    let params = match task_info.task.get_command_parameters() {
        Some(p) if !p.is_empty() => p,
        _ => return String::new(),
    };
    let mut pairs: Vec<String> = Vec::with_capacity(params.len());
    for (key, value) in params {
        let rendered = if let Value::String(s) = value {
            s.clone()
        } else {
            value.to_string()
        };
        pairs.push(format!("{}={}", key, rendered));
    }
    pairs.join(",")
}
/// Tracks per-task state for a batch of tasks and pushes progress
/// notifications (JSON-RPC `notifications/message`) through `notifier`.
pub struct TaskExecutionTracker {
    // Shared mutable map: task id -> live status/output bookkeeping.
    tasks: Arc<RwLock<HashMap<String, TaskInfo>>>,
    // Timestamp of the last dashboard refresh, used for throttling.
    last_refresh: Arc<RwLock<Instant>>,
    // Channel that delivers notifications to the consumer.
    notifier: mpsc::Sender<JsonRpcMessage>,
    // Dashboard vs. line-streaming rendering (see DisplayMode).
    display_mode: DisplayMode,
}
impl TaskExecutionTracker {
    /// Builds a tracker for `tasks`, all starting as `Pending`, that reports
    /// progress through `notifier` according to `display_mode`.
    pub fn new(
        tasks: Vec<Task>,
        display_mode: DisplayMode,
        notifier: mpsc::Sender<JsonRpcMessage>,
    ) -> Self {
        // Index the tasks by id with empty bookkeeping.
        let task_map = tasks
            .into_iter()
            .map(|task| {
                let task_id = task.id.clone();
                (
                    task_id,
                    TaskInfo {
                        task,
                        status: TaskStatus::Pending,
                        start_time: None,
                        end_time: None,
                        result: None,
                        current_output: String::new(),
                    },
                )
            })
            .collect();
        Self {
            tasks: Arc::new(RwLock::new(task_map)),
            last_refresh: Arc::new(RwLock::new(Instant::now())),
            notifier,
            display_mode,
        }
    }

    /// Marks a task as running and pushes an immediate (unthrottled) update.
    pub async fn start_task(&self, task_id: &str) {
        let mut tasks = self.tasks.write().await;
        if let Some(task_info) = tasks.get_mut(task_id) {
            task_info.status = TaskStatus::Running;
            task_info.start_time = Some(Instant::now());
        }
        // Release the write lock before refreshing, which re-acquires a read lock.
        drop(tasks);
        self.force_refresh_display().await;
    }

    /// Records a task's final result and status, then pushes an immediate update.
    pub async fn complete_task(&self, task_id: &str, result: TaskResult) {
        let mut tasks = self.tasks.write().await;
        if let Some(task_info) = tasks.get_mut(task_id) {
            task_info.status = result.status.clone();
            task_info.end_time = Some(Instant::now());
            task_info.result = Some(result);
        }
        drop(tasks);
        self.force_refresh_display().await;
    }

    /// Returns a copy of the output buffered so far for `task_id`, if the task is known.
    pub async fn get_current_output(&self, task_id: &str) -> Option<String> {
        let tasks = self.tasks.read().await;
        tasks
            .get(task_id)
            .map(|task_info| task_info.current_output.clone())
    }

    /// Forwards one line of live task output.
    ///
    /// In `SingleTaskOutput` mode the line is sent immediately as a
    /// `line_output` notification, prefixed with the task's name/type (and
    /// parameters when present). In `MultipleTasksOutput` mode the line is
    /// buffered on the task and the dashboard refreshed, throttled to at most
    /// one refresh per `THROTTLE_INTERVAL_MS`.
    pub async fn send_live_output(&self, task_id: &str, line: &str) {
        match self.display_mode {
            DisplayMode::SingleTaskOutput => {
                let tasks = self.tasks.read().await;
                let task_info = tasks.get(task_id);
                let formatted_line = if let Some(task_info) = task_info {
                    let task_name = get_task_name(task_info);
                    let task_type = task_info.task.task_type.clone();
                    let metadata = format_task_metadata(task_info);
                    if metadata.is_empty() {
                        format!("[{} ({})] {}", task_name, task_type, line)
                    } else {
                        format!("[{} ({}) {}] {}", task_name, task_type, metadata, line)
                    }
                } else {
                    // Unknown task id: pass the line through without a prefix.
                    line.to_string()
                };
                drop(tasks);
                let event = TaskExecutionNotificationEvent::line_output(
                    task_id.to_string(),
                    formatted_line,
                );
                // try_send: dropping a notification when the channel is full is
                // preferable to blocking task execution.
                if let Err(e) =
                    self.notifier
                        .try_send(JsonRpcMessage::Notification(JsonRpcNotification {
                            jsonrpc: "2.0".to_string(),
                            method: "notifications/message".to_string(),
                            params: Some(json!({
                                "data": event.to_notification_data()
                            })),
                        }))
                {
                    tracing::warn!("Failed to send live output notification: {}", e);
                }
            }
            DisplayMode::MultipleTasksOutput => {
                let mut tasks = self.tasks.write().await;
                if let Some(task_info) = tasks.get_mut(task_id) {
                    task_info.current_output.push_str(line);
                    task_info.current_output.push('\n');
                }
                drop(tasks);
                if !self.should_throttle_refresh().await {
                    self.refresh_display().await;
                }
            }
        }
    }

    /// Returns true when a refresh happened within the throttle window.
    /// Otherwise resets the throttle timer and returns false, allowing the
    /// caller to refresh now.
    async fn should_throttle_refresh(&self) -> bool {
        let now = Instant::now();
        let mut last_refresh = self.last_refresh.write().await;
        if now.duration_since(*last_refresh) > Duration::from_millis(THROTTLE_INTERVAL_MS) {
            *last_refresh = now;
            false
        } else {
            true
        }
    }

    /// Emits a `tasks_update` notification carrying aggregate counts plus a
    /// per-task snapshot (status, elapsed time, buffered output, result).
    async fn send_tasks_update(&self) {
        let tasks = self.tasks.read().await;
        let task_list: Vec<_> = tasks.values().collect();
        let (total, pending, running, completed, failed) = count_by_status(&tasks);
        let stats = TaskExecutionStats::new(total, pending, running, completed, failed);
        let event_tasks: Vec<EventTaskInfo> = task_list
            .iter()
            .map(|task_info| {
                let now = Instant::now();
                EventTaskInfo {
                    id: task_info.task.id.clone(),
                    status: task_info.status.clone(),
                    // Finished tasks report their final duration; running
                    // tasks report elapsed-so-far.
                    duration_secs: task_info.start_time.map(|start| {
                        if let Some(end) = task_info.end_time {
                            end.duration_since(start).as_secs_f64()
                        } else {
                            now.duration_since(start).as_secs_f64()
                        }
                    }),
                    current_output: task_info.current_output.clone(),
                    task_type: task_info.task.task_type.clone(),
                    task_name: get_task_name(task_info).to_string(),
                    task_metadata: format_task_metadata(task_info),
                    error: task_info.error().cloned(),
                    result_data: task_info.data().cloned(),
                }
            })
            .collect();
        let event = TaskExecutionNotificationEvent::tasks_update(stats, event_tasks);
        if let Err(e) = self
            .notifier
            .try_send(JsonRpcMessage::Notification(JsonRpcNotification {
                jsonrpc: "2.0".to_string(),
                method: "notifications/message".to_string(),
                params: Some(json!({
                    "data": event.to_notification_data()
                })),
            }))
        {
            tracing::warn!("Failed to send tasks update notification: {}", e);
        }
    }

    /// Refreshes the dashboard; a no-op in single-task mode.
    pub async fn refresh_display(&self) {
        match self.display_mode {
            DisplayMode::MultipleTasksOutput => {
                self.send_tasks_update().await;
            }
            DisplayMode::SingleTaskOutput => {
                // No dashboard display needed for single task output mode
                // Live output is handled via send_live_output method
            }
        }
    }

    // Force refresh without throttling - used for important status changes
    async fn force_refresh_display(&self) {
        match self.display_mode {
            DisplayMode::MultipleTasksOutput => {
                // Reset throttle timer to allow immediate update
                let mut last_refresh = self.last_refresh.write().await;
                *last_refresh = Instant::now() - Duration::from_millis(THROTTLE_INTERVAL_MS + 1);
                drop(last_refresh);
                self.send_tasks_update().await;
            }
            DisplayMode::SingleTaskOutput => {
                // No dashboard display needed for single task output mode
            }
        }
    }

    /// Emits the final `tasks_complete` notification summarizing the batch and
    /// listing each failed task with its error, then pauses briefly so the
    /// notification can be delivered.
    pub async fn send_tasks_complete(&self) {
        let tasks = self.tasks.read().await;
        let (total, _, _, completed, failed) = count_by_status(&tasks);
        let stats = TaskCompletionStats::new(total, completed, failed);
        let failed_tasks: Vec<FailedTaskInfo> = tasks
            .values()
            .filter(|task_info| matches!(task_info.status, TaskStatus::Failed))
            .map(|task_info| FailedTaskInfo {
                id: task_info.task.id.clone(),
                name: get_task_name(task_info).to_string(),
                error: task_info.error().cloned(),
            })
            .collect();
        let event = TaskExecutionNotificationEvent::tasks_complete(stats, failed_tasks);
        if let Err(e) = self
            .notifier
            .try_send(JsonRpcMessage::Notification(JsonRpcNotification {
                jsonrpc: "2.0".to_string(),
                method: "notifications/message".to_string(),
                params: Some(json!({
                    "data": event.to_notification_data()
                })),
            }))
        {
            tracing::warn!("Failed to send tasks complete notification: {}", e);
        }
        // Brief delay to ensure completion notification is processed
        sleep(Duration::from_millis(COMPLETION_NOTIFICATION_DELAY_MS)).await;
    }
}

View File

@@ -0,0 +1,145 @@
use serde::{Deserialize, Serialize};
use serde_json::{Map, Value};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use tokio::sync::mpsc;
use crate::agents::sub_recipe_execution_tool::task_execution_tracker::TaskExecutionTracker;
/// Strategy for running a batch of tasks.
///
/// Serialized in lowercase ("sequential" / "parallel"); `Sequential` is the default.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
#[serde(rename_all = "lowercase")]
pub enum ExecutionMode {
    #[default]
    Sequential,
    Parallel,
}
/// A unit of work handed to the executor.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Task {
    // Unique identifier for this task.
    pub id: String,
    // Discriminator: "sub_recipe" or a text-instruction task (see accessors in impl Task).
    pub task_type: String,
    // Type-specific data: a "sub_recipe" object or a "text_instruction" string.
    pub payload: Value,
}
impl Task {
    /// The `sub_recipe` object from the payload, when this is a sub-recipe task.
    pub fn get_sub_recipe(&self) -> Option<&Map<String, Value>> {
        if self.task_type != "sub_recipe" {
            return None;
        }
        self.payload.get("sub_recipe")?.as_object()
    }

    /// Whether repeated runs of this sub-recipe must be serialized.
    /// Defaults to false when the flag is absent or not a boolean.
    pub fn get_sequential_when_repeated(&self) -> bool {
        match self.get_sub_recipe().and_then(|sr| sr.get("sequential_when_repeated")) {
            Some(v) => v.as_bool().unwrap_or(false),
            None => false,
        }
    }

    /// Parameters passed to the `goose run` command for this sub-recipe.
    pub fn get_command_parameters(&self) -> Option<&Map<String, Value>> {
        self.get_sub_recipe()?.get("command_parameters")?.as_object()
    }

    /// The sub-recipe's declared name.
    pub fn get_sub_recipe_name(&self) -> Option<&str> {
        self.get_sub_recipe()?.get("name")?.as_str()
    }

    /// Filesystem path of the sub-recipe file.
    pub fn get_sub_recipe_path(&self) -> Option<&str> {
        self.get_sub_recipe()?.get("recipe_path")?.as_str()
    }

    /// The free-text instruction, for non-sub-recipe tasks only.
    pub fn get_text_instruction(&self) -> Option<&str> {
        if self.task_type == "sub_recipe" {
            return None;
        }
        self.payload.get("text_instruction")?.as_str()
    }
}
/// Outcome of a single task run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskResult {
    pub task_id: String,
    pub status: TaskStatus,
    /// Output captured on success (omitted from serialization when absent).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data: Option<Value>,
    /// Error description when the task failed (omitted when absent).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
}
/// Lifecycle state of a task: queued, in flight, or terminal (completed/failed).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TaskStatus {
    Pending,
    Running,
    Completed,
    Failed,
}
impl std::fmt::Display for TaskStatus {
    /// Renders the status as its variant name (e.g. "Pending", "Failed").
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            TaskStatus::Pending => "Pending",
            TaskStatus::Running => "Running",
            TaskStatus::Completed => "Completed",
            TaskStatus::Failed => "Failed",
        };
        f.write_str(label)
    }
}
/// Live bookkeeping for one task while the batch runs.
#[derive(Debug, Clone)]
pub struct TaskInfo {
    pub task: Task,
    pub status: TaskStatus,
    // Set when the task starts running.
    pub start_time: Option<tokio::time::Instant>,
    // Set when the task reaches a terminal status.
    pub end_time: Option<tokio::time::Instant>,
    // Final result, present once the task completes or fails.
    pub result: Option<TaskResult>,
    // Accumulated output lines (used by the multi-task dashboard).
    pub current_output: String,
}
impl TaskInfo {
    /// Error message recorded on this task's result, if any.
    pub fn error(&self) -> Option<&String> {
        self.result.as_ref()?.error.as_ref()
    }

    /// Result payload recorded on this task's result, if any.
    pub fn data(&self) -> Option<&Value> {
        self.result.as_ref()?.data.as_ref()
    }
}
/// State shared between the task dispatcher and its worker tasks.
pub struct SharedState {
    // Work queue: workers lock the Mutex to take the next task.
    pub task_receiver: Arc<tokio::sync::Mutex<mpsc::Receiver<Task>>>,
    // Workers report each finished task's result through this channel.
    pub result_sender: mpsc::Sender<TaskResult>,
    // Count of currently running workers.
    pub active_workers: Arc<AtomicUsize>,
    // Progress reporter shared with every worker.
    pub task_execution_tracker: Arc<TaskExecutionTracker>,
}
impl SharedState {
    /// Records that a worker has started.
    pub fn increment_active_workers(&self) {
        self.active_workers.fetch_add(1, Ordering::SeqCst);
    }

    /// Records that a worker has exited.
    pub fn decrement_active_workers(&self) {
        self.active_workers.fetch_sub(1, Ordering::SeqCst);
    }
}
/// Aggregate counters for one batch execution.
#[derive(Debug, Serialize)]
pub struct ExecutionStats {
    pub total_tasks: usize,
    pub completed: usize,
    pub failed: usize,
    // Wall-clock time for the whole batch, in milliseconds.
    pub execution_time_ms: u128,
}
/// Top-level response returned to the caller after a batch finishes:
/// overall status, per-task results, and aggregate stats.
#[derive(Debug, Serialize)]
pub struct ExecutionResponse {
    pub status: String,
    pub results: Vec<TaskResult>,
    pub stats: ExecutionStats,
}

View File

@@ -1,76 +1,98 @@
use serde_json::Value;
use std::process::Stdio;
use std::time::Duration;
use std::sync::Arc;
use tokio::io::{AsyncBufReadExt, BufReader};
use tokio::process::Command;
use tokio::time::timeout;
use crate::agents::sub_recipe_execution_tool::types::{Task, TaskResult};
use crate::agents::sub_recipe_execution_tool::task_execution_tracker::TaskExecutionTracker;
use crate::agents::sub_recipe_execution_tool::task_types::{Task, TaskResult, TaskStatus};
// Process a single task based on its type
pub async fn process_task(task: &Task, timeout_seconds: u64) -> TaskResult {
let task_clone = task.clone();
let timeout_duration = Duration::from_secs(timeout_seconds);
// Execute with timeout
match timeout(timeout_duration, execute_task(task_clone)).await {
Ok(Ok(data)) => TaskResult {
pub async fn process_task(
task: &Task,
task_execution_tracker: Arc<TaskExecutionTracker>,
) -> TaskResult {
match get_task_result(task.clone(), task_execution_tracker).await {
Ok(data) => TaskResult {
task_id: task.id.clone(),
status: "success".to_string(),
status: TaskStatus::Completed,
data: Some(data),
error: None,
},
Ok(Err(error)) => TaskResult {
Err(error) => TaskResult {
task_id: task.id.clone(),
status: "failed".to_string(),
status: TaskStatus::Failed,
data: None,
error: Some(error),
},
Err(_) => TaskResult {
task_id: task.id.clone(),
status: "failed".to_string(),
data: None,
error: Some("Task timeout".to_string()),
},
}
}
async fn execute_task(task: Task) -> Result<Value, String> {
async fn get_task_result(
task: Task,
task_execution_tracker: Arc<TaskExecutionTracker>,
) -> Result<Value, String> {
let (command, output_identifier) = build_command(&task)?;
let (stdout_output, stderr_output, success) = run_command(
command,
&output_identifier,
&task.id,
task_execution_tracker,
)
.await?;
if success {
process_output(stdout_output)
} else {
Err(format!("Command failed:\n{}", stderr_output))
}
}
fn build_command(task: &Task) -> Result<(Command, String), String> {
let task_error = |field: &str| format!("Task {}: Missing {}", task.id, field);
let mut output_identifier = task.id.clone();
let mut command = if task.task_type == "sub_recipe" {
let sub_recipe = task.payload.get("sub_recipe").unwrap();
let sub_recipe_name = sub_recipe.get("name").unwrap().as_str().unwrap();
let path = sub_recipe.get("recipe_path").unwrap().as_str().unwrap();
let command_parameters = sub_recipe.get("command_parameters").unwrap();
let sub_recipe_name = task
.get_sub_recipe_name()
.ok_or_else(|| task_error("sub_recipe name"))?;
let path = task
.get_sub_recipe_path()
.ok_or_else(|| task_error("sub_recipe path"))?;
let command_parameters = task
.get_command_parameters()
.ok_or_else(|| task_error("command_parameters"))?;
output_identifier = format!("sub-recipe {}", sub_recipe_name);
let mut cmd = Command::new("goose");
cmd.arg("run").arg("--recipe").arg(path);
if let Some(params_map) = command_parameters.as_object() {
for (key, value) in params_map {
let key_str = key.to_string();
let value_str = value.as_str().unwrap_or(&value.to_string()).to_string();
cmd.arg("--params")
.arg(format!("{}={}", key_str, value_str));
}
cmd.arg("run").arg("--recipe").arg(path).arg("--no-session");
for (key, value) in command_parameters {
let key_str = key.to_string();
let value_str = value.as_str().unwrap_or(&value.to_string()).to_string();
cmd.arg("--params")
.arg(format!("{}={}", key_str, value_str));
}
cmd
} else {
let text = task
.payload
.get("text_instruction")
.unwrap()
.as_str()
.unwrap();
.get_text_instruction()
.ok_or_else(|| task_error("text_instruction"))?;
let mut cmd = Command::new("goose");
cmd.arg("run").arg("--text").arg(text);
cmd
};
// Configure to capture stdout
command.stdout(Stdio::piped());
command.stderr(Stdio::piped());
Ok((command, output_identifier))
}
// Spawn the child process
async fn run_command(
mut command: Command,
output_identifier: &str,
task_id: &str,
task_execution_tracker: Arc<TaskExecutionTracker>,
) -> Result<(String, String, bool), String> {
let mut child = command
.spawn()
.map_err(|e| format!("Failed to spawn goose: {}", e))?;
@@ -78,30 +100,20 @@ async fn execute_task(task: Task) -> Result<Value, String> {
let stdout = child.stdout.take().expect("Failed to capture stdout");
let stderr = child.stderr.take().expect("Failed to capture stderr");
let mut stdout_reader = BufReader::new(stdout).lines();
let mut stderr_reader = BufReader::new(stderr).lines();
// Spawn background tasks to read from stdout and stderr
let output_identifier_clone = output_identifier.clone();
let stdout_task = tokio::spawn(async move {
let mut buffer = String::new();
while let Ok(Some(line)) = stdout_reader.next_line().await {
println!("[{}] {}", output_identifier_clone, line);
buffer.push_str(&line);
buffer.push('\n');
}
buffer
});
let stderr_task = tokio::spawn(async move {
let mut buffer = String::new();
while let Ok(Some(line)) = stderr_reader.next_line().await {
eprintln!("[stderr for {}] {}", output_identifier, line);
buffer.push_str(&line);
buffer.push('\n');
}
buffer
});
let stdout_task = spawn_output_reader(
stdout,
output_identifier,
false,
task_id,
task_execution_tracker.clone(),
);
let stderr_task = spawn_output_reader(
stderr,
output_identifier,
true,
task_id,
task_execution_tracker.clone(),
);
let status = child
.wait()
@@ -111,9 +123,63 @@ async fn execute_task(task: Task) -> Result<Value, String> {
let stdout_output = stdout_task.await.unwrap();
let stderr_output = stderr_task.await.unwrap();
if status.success() {
Ok(Value::String(stdout_output))
Ok((stdout_output, stderr_output, status.success()))
}
/// Spawns a background task that drains one of the child process's output pipes.
///
/// Every line is appended to a buffer that the spawned task returns when the
/// pipe closes. Stdout lines are additionally streamed to the execution
/// tracker for live display; stderr lines are only logged, labeled with the
/// human-readable `output_identifier`.
fn spawn_output_reader(
    reader: impl tokio::io::AsyncRead + Unpin + Send + 'static,
    output_identifier: &str,
    is_stderr: bool,
    task_id: &str,
    task_execution_tracker: Arc<TaskExecutionTracker>,
) -> tokio::task::JoinHandle<String> {
    // Own the strings so the spawned 'static task can capture them.
    let output_identifier = output_identifier.to_string();
    let task_id = task_id.to_string();
    tokio::spawn(async move {
        let mut buffer = String::new();
        let mut lines = BufReader::new(reader).lines();
        while let Ok(Some(line)) = lines.next_line().await {
            buffer.push_str(&line);
            buffer.push('\n');
            if !is_stderr {
                task_execution_tracker
                    .send_live_output(&task_id, &line)
                    .await;
            } else {
                tracing::warn!("Task stderr [{}]: {}", output_identifier, line);
            }
        }
        buffer
    })
}
/// Extracts the span from the first '{' to the last '}' of `line`, returning
/// it only if that span parses as valid JSON.
///
/// Returns `None` when there is no such span, the braces are out of order,
/// or the span is not valid JSON. This is a heuristic: only the single
/// outermost brace-delimited span is tried.
///
/// Fix: removed a stray `Err(format!("Command failed:\n{}", stderr_output))`
/// statement inside the `else` arm; it referenced `stderr_output`, which is
/// not in scope here, and an `Option`-returning function cannot yield `Err`.
fn extract_json_from_line(line: &str) -> Option<String> {
    let start = line.find('{')?;
    let end = line.rfind('}')?;
    if start >= end {
        return None;
    }
    let potential_json = &line[start..=end];
    if serde_json::from_str::<Value>(potential_json).is_ok() {
        Some(potential_json.to_string())
    } else {
        None
    }
}
/// Produces the task's result value from its captured stdout.
///
/// If the last non-blank stdout line contains a valid JSON object span, that
/// JSON string is returned; otherwise the full stdout text is returned as-is.
fn process_output(stdout_output: String) -> Result<Value, String> {
    let final_line = stdout_output
        .lines()
        .rev()
        .find(|line| !line.trim().is_empty())
        .unwrap_or("");
    match extract_json_from_line(final_line) {
        Some(json_string) => Ok(Value::String(json_string)),
        None => Ok(Value::String(stdout_output)),
    }
}

View File

@@ -0,0 +1,86 @@
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use crate::agents::sub_recipe_execution_tool::task_types::Task;
/// Cheaply-cloneable registry of tasks, shared across async contexts.
/// Clones share the same underlying map via `Arc`.
#[derive(Debug, Clone)]
pub struct TasksManager {
    // Task id -> task definition.
    tasks: Arc<RwLock<HashMap<String, Task>>>,
}
impl Default for TasksManager {
    // Delegates to `new` so `TasksManager::default()` behaves identically.
    fn default() -> Self {
        Self::new()
    }
}
impl TasksManager {
pub fn new() -> Self {
Self {
tasks: Arc::new(RwLock::new(HashMap::new())),
}
}
pub async fn save_tasks(&self, tasks: Vec<Task>) {
let mut task_map = self.tasks.write().await;
for task in tasks {
task_map.insert(task.id.clone(), task);
}
}
pub async fn get_task(&self, task_id: &str) -> Option<Task> {
let tasks = self.tasks.read().await;
tasks.get(task_id).cloned()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // Builds a minimal sub-recipe task with the given id and recipe name.
    fn create_test_task(id: &str, sub_recipe_name: &str) -> Task {
        Task {
            id: id.to_string(),
            task_type: "sub_recipe".to_string(),
            payload: json!({
                "sub_recipe": {
                    "name": sub_recipe_name,
                    "command_parameters": {},
                    "recipe_path": "/test/path"
                }
            }),
        }
    }

    // A saved task can be retrieved again by its id.
    #[tokio::test]
    async fn test_save_and_get_task() {
        let manager = TasksManager::new();
        let tasks = vec![create_test_task("task1", "weather")];
        manager.save_tasks(tasks).await;
        let retrieved = manager.get_task("task1").await;
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().id, "task1");
    }

    // Saving a batch stores every task independently by id.
    #[tokio::test]
    async fn test_save_multiple_tasks() {
        let manager = TasksManager::new();
        let tasks = vec![
            create_test_task("task1", "weather"),
            create_test_task("task2", "news"),
        ];
        manager.save_tasks(tasks).await;
        let task1 = manager.get_task("task1").await;
        let task2 = manager.get_task("task2").await;
        assert!(task1.is_some());
        assert!(task2.is_some());
        assert_eq!(task1.unwrap().id, "task1");
        assert_eq!(task2.unwrap().id, "task2");
    }
}

View File

@@ -1,69 +0,0 @@
use serde::{Deserialize, Serialize};
use serde_json::Value;
// Task definition that LLMs will send
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Task {
pub id: String,
pub task_type: String,
pub payload: Value,
}
// Result for each task
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskResult {
pub task_id: String,
pub status: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub data: Option<Value>,
#[serde(skip_serializing_if = "Option::is_none")]
pub error: Option<String>,
}
// Configuration for the parallel executor
#[derive(Debug, Clone, Deserialize)]
pub struct Config {
#[serde(default = "default_max_workers")]
pub max_workers: usize,
#[serde(default = "default_timeout")]
pub timeout_seconds: u64,
#[serde(default = "default_initial_workers")]
pub initial_workers: usize,
}
impl Default for Config {
fn default() -> Self {
Self {
max_workers: default_max_workers(),
timeout_seconds: default_timeout(),
initial_workers: default_initial_workers(),
}
}
}
fn default_max_workers() -> usize {
10
}
fn default_timeout() -> u64 {
300
}
fn default_initial_workers() -> usize {
2
}
// Stats for the execution
#[derive(Debug, Serialize)]
pub struct ExecutionStats {
pub total_tasks: usize,
pub completed: usize,
pub failed: usize,
pub execution_time_ms: u128,
}
// Main response structure
#[derive(Debug, Serialize)]
pub struct ExecutionResponse {
pub status: String,
pub results: Vec<TaskResult>,
pub stats: ExecutionStats,
}

View File

@@ -0,0 +1,27 @@
use std::collections::HashMap;
use crate::agents::sub_recipe_execution_tool::task_types::{TaskInfo, TaskStatus};
/// Human-readable label for a task: the sub-recipe name when the task declares
/// one, otherwise the task's own id.
pub fn get_task_name(task_info: &TaskInfo) -> &str {
    match task_info.task.get_sub_recipe_name() {
        Some(name) => name,
        None => task_info.task.id.as_str(),
    }
}
/// Tallies the tasks by status.
///
/// Returns `(total, pending, running, completed, failed)`; the four status
/// counts always sum to `total`.
pub fn count_by_status(tasks: &HashMap<String, TaskInfo>) -> (usize, usize, usize, usize, usize) {
    let mut pending = 0;
    let mut running = 0;
    let mut completed = 0;
    let mut failed = 0;
    for info in tasks.values() {
        match info.status {
            TaskStatus::Pending => pending += 1,
            TaskStatus::Running => running += 1,
            TaskStatus::Completed => completed += 1,
            TaskStatus::Failed => failed += 1,
        }
    }
    (tasks.len(), pending, running, completed, failed)
}
#[cfg(test)]
mod tests;

View File

@@ -0,0 +1,154 @@
use crate::agents::sub_recipe_execution_tool::task_types::{Task, TaskInfo, TaskStatus};
use crate::agents::sub_recipe_execution_tool::utils::{count_by_status, get_task_name};
use serde_json::json;
use std::collections::HashMap;
// Builds a TaskInfo in its initial state: no timing, no result, empty output.
fn create_task_info_with_defaults(task: Task, status: TaskStatus) -> TaskInfo {
    TaskInfo {
        task,
        status,
        start_time: None,
        end_time: None,
        result: None,
        current_output: String::new(),
    }
}
mod test_get_task_name {
    use super::*;

    // Sub-recipe tasks display the recipe's "name" field.
    #[test]
    fn test_extracts_sub_recipe_name() {
        let sub_recipe_task = Task {
            id: "task_1".to_string(),
            task_type: "sub_recipe".to_string(),
            payload: json!({
                "sub_recipe": {
                    "name": "my_recipe",
                    "recipe_path": "/path/to/recipe"
                }
            }),
        };
        let task_info = create_task_info_with_defaults(sub_recipe_task, TaskStatus::Pending);
        assert_eq!(get_task_name(&task_info), "my_recipe");
    }

    // Non-sub-recipe tasks have no recipe name, so the id is used.
    #[test]
    fn falls_back_to_task_id_for_text_instruction() {
        let text_task = Task {
            id: "task_2".to_string(),
            task_type: "text_instruction".to_string(),
            payload: json!({"text_instruction": "do something"}),
        };
        let task_info = create_task_info_with_defaults(text_task, TaskStatus::Pending);
        assert_eq!(get_task_name(&task_info), "task_2");
    }

    // A sub-recipe payload without a "name" field falls back to the id.
    #[test]
    fn falls_back_to_task_id_when_sub_recipe_name_missing() {
        let malformed_task = Task {
            id: "task_3".to_string(),
            task_type: "sub_recipe".to_string(),
            payload: json!({
                "sub_recipe": {
                    "recipe_path": "/path/to/recipe"
                    // missing "name" field
                }
            }),
        };
        let task_info = create_task_info_with_defaults(malformed_task, TaskStatus::Pending);
        assert_eq!(get_task_name(&task_info), "task_3");
    }

    // A sub-recipe task with no "sub_recipe" payload at all also falls back.
    #[test]
    fn falls_back_to_task_id_when_sub_recipe_missing() {
        let malformed_task = Task {
            id: "task_4".to_string(),
            task_type: "sub_recipe".to_string(),
            payload: json!({}), // missing "sub_recipe" field
        };
        let task_info = create_task_info_with_defaults(malformed_task, TaskStatus::Pending);
        assert_eq!(get_task_name(&task_info), "task_4");
    }
}
mod count_by_status {
    use super::*;

    // Minimal TaskInfo with only an id and a status, for counting tests.
    fn create_test_task(id: &str, status: TaskStatus) -> TaskInfo {
        let task = Task {
            id: id.to_string(),
            task_type: "test".to_string(),
            payload: json!({}),
        };
        create_task_info_with_defaults(task, status)
    }

    // An empty map yields all-zero counts.
    #[test]
    fn counts_empty_map() {
        let tasks = HashMap::new();
        let (total, pending, running, completed, failed) = count_by_status(&tasks);
        assert_eq!(
            (total, pending, running, completed, failed),
            (0, 0, 0, 0, 0)
        );
    }

    // Tasks sharing one status are all counted in that bucket.
    #[test]
    fn counts_single_status() {
        let mut tasks = HashMap::new();
        tasks.insert(
            "task1".to_string(),
            create_test_task("task1", TaskStatus::Pending),
        );
        tasks.insert(
            "task2".to_string(),
            create_test_task("task2", TaskStatus::Pending),
        );
        let (total, pending, running, completed, failed) = count_by_status(&tasks);
        assert_eq!(
            (total, pending, running, completed, failed),
            (2, 2, 0, 0, 0)
        );
    }

    // Mixed statuses are bucketed independently and sum to the total.
    #[test]
    fn counts_mixed_statuses() {
        let mut tasks = HashMap::new();
        tasks.insert(
            "task1".to_string(),
            create_test_task("task1", TaskStatus::Pending),
        );
        tasks.insert(
            "task2".to_string(),
            create_test_task("task2", TaskStatus::Running),
        );
        tasks.insert(
            "task3".to_string(),
            create_test_task("task3", TaskStatus::Completed),
        );
        tasks.insert(
            "task4".to_string(),
            create_test_task("task4", TaskStatus::Failed),
        );
        tasks.insert(
            "task5".to_string(),
            create_test_task("task5", TaskStatus::Completed),
        );
        let (total, pending, running, completed, failed) = count_by_status(&tasks);
        assert_eq!(
            (total, pending, running, completed, failed),
            (5, 1, 1, 2, 1)
        );
    }
}

View File

@@ -1,133 +1,30 @@
use crate::agents::sub_recipe_execution_tool::task_types::{SharedState, Task};
use crate::agents::sub_recipe_execution_tool::tasks::process_task;
use crate::agents::sub_recipe_execution_tool::types::{Task, TaskResult};
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio::time::{sleep, Duration};
#[cfg(test)]
mod tests {
use super::*;
use crate::agents::sub_recipe_execution_tool::types::Task;
#[tokio::test]
async fn test_spawn_worker_returns_handle() {
// Create a simple shared state for testing
let (task_tx, task_rx) = mpsc::channel::<Task>(1);
let (result_tx, _result_rx) = mpsc::channel::<TaskResult>(1);
let shared_state = Arc::new(SharedState {
task_receiver: Arc::new(tokio::sync::Mutex::new(task_rx)),
result_sender: result_tx,
active_workers: Arc::new(AtomicUsize::new(0)),
should_stop: Arc::new(AtomicBool::new(false)),
completed_tasks: Arc::new(AtomicUsize::new(0)),
});
// Test that spawn_worker returns a JoinHandle
let handle = spawn_worker(shared_state.clone(), 0, 5);
// Verify it's a JoinHandle by checking we can abort it
assert!(!handle.is_finished());
// Signal stop and close the channel to let the worker exit
shared_state.should_stop.store(true, Ordering::SeqCst);
drop(task_tx); // Close the channel
// Wait for the worker to finish
let result = handle.await;
assert!(result.is_ok());
}
async fn receive_task(state: &SharedState) -> Option<Task> {
let mut receiver = state.task_receiver.lock().await;
receiver.recv().await
}
pub struct SharedState {
pub task_receiver: Arc<tokio::sync::Mutex<mpsc::Receiver<Task>>>,
pub result_sender: mpsc::Sender<TaskResult>,
pub active_workers: Arc<AtomicUsize>,
pub should_stop: Arc<AtomicBool>,
pub completed_tasks: Arc<AtomicUsize>,
}
// Spawn a worker task
pub fn spawn_worker(
state: Arc<SharedState>,
worker_id: usize,
timeout_seconds: u64,
) -> tokio::task::JoinHandle<()> {
state.active_workers.fetch_add(1, Ordering::SeqCst);
pub fn spawn_worker(state: Arc<SharedState>, worker_id: usize) -> tokio::task::JoinHandle<()> {
state.increment_active_workers();
tokio::spawn(async move {
worker_loop(state, worker_id, timeout_seconds).await;
worker_loop(state, worker_id).await;
})
}
async fn worker_loop(state: Arc<SharedState>, _worker_id: usize, timeout_seconds: u64) {
loop {
// Try to receive a task
let task = {
let mut receiver = state.task_receiver.lock().await;
receiver.recv().await
};
async fn worker_loop(state: Arc<SharedState>, _worker_id: usize) {
while let Some(task) = receive_task(&state).await {
state.task_execution_tracker.start_task(&task.id).await;
let result = process_task(&task, state.task_execution_tracker.clone()).await;
match task {
Some(task) => {
// Process the task
let result = process_task(&task, timeout_seconds).await;
// Send result
let _ = state.result_sender.send(result).await;
// Update completed count
state.completed_tasks.fetch_add(1, Ordering::SeqCst);
}
None => {
// Channel closed, exit worker
break;
}
}
// Check if we should stop
if state.should_stop.load(Ordering::SeqCst) {
if let Err(e) = state.result_sender.send(result).await {
tracing::error!("Worker failed to send result: {}", e);
break;
}
}
// Worker is exiting
state.active_workers.fetch_sub(1, Ordering::SeqCst);
}
// Scaling controller that monitors queue and spawns workers
pub async fn run_scaler(
state: Arc<SharedState>,
task_count: usize,
max_workers: usize,
timeout_seconds: u64,
) {
let mut worker_count = 0;
loop {
sleep(Duration::from_millis(100)).await;
let active = state.active_workers.load(Ordering::SeqCst);
let completed = state.completed_tasks.load(Ordering::SeqCst);
let pending = task_count.saturating_sub(completed);
// Simple scaling logic: spawn worker if many pending tasks and under limit
if pending > active * 2 && active < max_workers && worker_count < max_workers {
let _handle = spawn_worker(state.clone(), worker_count, timeout_seconds);
worker_count += 1;
}
// If all tasks completed, signal stop
if completed >= task_count {
state.should_stop.store(true, Ordering::SeqCst);
break;
}
// If no active workers and tasks remaining, spawn one
if active == 0 && pending > 0 {
let _handle = spawn_worker(state.clone(), worker_count, timeout_seconds);
worker_count += 1;
}
}
state.decrement_active_workers();
}

View File

@@ -7,6 +7,7 @@ use crate::{
recipe_tools::sub_recipe_tools::{
create_sub_recipe_task, create_sub_recipe_task_tool, SUB_RECIPE_TASK_TOOL_NAME_PREFIX,
},
sub_recipe_execution_tool::tasks_manager::TasksManager,
tool_execution::ToolCallResult,
},
recipe::SubRecipe,
@@ -34,12 +35,6 @@ impl SubRecipeManager {
pub fn add_sub_recipe_tools(&mut self, sub_recipes_to_add: Vec<SubRecipe>) {
for sub_recipe in sub_recipes_to_add {
// let sub_recipe_key = format!(
// "{}_{}",
// SUB_RECIPE_TOOL_NAME_PREFIX,
// sub_recipe.name.clone()
// );
// let tool = create_sub_recipe_tool(&sub_recipe);
let sub_recipe_key = format!(
"{}_{}",
SUB_RECIPE_TASK_TOOL_NAME_PREFIX,
@@ -59,43 +54,22 @@ impl SubRecipeManager {
&self,
tool_name: &str,
params: Value,
tasks_manager: &TasksManager,
) -> ToolCallResult {
let result = self.call_sub_recipe_tool(tool_name, params).await;
let result = self
.call_sub_recipe_tool(tool_name, params, tasks_manager)
.await;
match result {
Ok(call_result) => ToolCallResult::from(Ok(call_result)),
Err(e) => ToolCallResult::from(Err(ToolError::ExecutionError(e.to_string()))),
}
}
// async fn call_sub_recipe_tool(
// &self,
// tool_name: &str,
// params: Value,
// ) -> Result<Vec<Content>, ToolError> {
// let sub_recipe = self.sub_recipes.get(tool_name).ok_or_else(|| {
// let sub_recipe_name = tool_name
// .strip_prefix(SUB_RECIPE_TOOL_NAME_PREFIX)
// .and_then(|s| s.strip_prefix("_"))
// .ok_or_else(|| {
// ToolError::InvalidParameters(format!(
// "Invalid sub-recipe tool name format: {}",
// tool_name
// ))
// })
// .unwrap();
// ToolError::InvalidParameters(format!("Sub-recipe '{}' not found", sub_recipe_name))
// })?;
// let output = run_sub_recipe(sub_recipe, params).await.map_err(|e| {
// ToolError::ExecutionError(format!("Sub-recipe execution failed: {}", e))
// })?;
// Ok(vec![Content::text(output)])
// }
async fn call_sub_recipe_tool(
&self,
tool_name: &str,
params: Value,
tasks_manager: &TasksManager,
) -> Result<Vec<Content>, ToolError> {
let sub_recipe = self.sub_recipes.get(tool_name).ok_or_else(|| {
let sub_recipe_name = tool_name
@@ -111,11 +85,10 @@ impl SubRecipeManager {
ToolError::InvalidParameters(format!("Sub-recipe '{}' not found", sub_recipe_name))
})?;
let output = create_sub_recipe_task(sub_recipe, params)
let output = create_sub_recipe_task(sub_recipe, params, tasks_manager)
.await
.map_err(|e| {
ToolError::ExecutionError(format!("Sub-recipe execution failed: {}", e))
ToolError::ExecutionError(format!("Sub-recipe task createion failed: {}", e))
})?;
Ok(vec![Content::text(output)])
}

View File

@@ -137,6 +137,13 @@ pub struct SubRecipe {
pub path: String,
#[serde(default, deserialize_with = "deserialize_value_map_as_string")]
pub values: Option<HashMap<String, String>>,
#[serde(default)]
pub sequential_when_repeated: bool,
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Execution {
#[serde(default)]
pub parallel: bool,
}
fn deserialize_value_map_as_string<'de, D>(