From 2e9a2a53be9d29eb864d974f38edf8b630fae3fa Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Fri, 4 Jul 2025 13:13:31 +1000 Subject: [PATCH] Use command line to run sub agent and sub recipe (in sequence or parallel) (#3190) --- crates/goose/src/agents/agent.rs | 8 +- crates/goose/src/agents/mod.rs | 1 + .../agents/recipe_tools/sub_recipe_tools.rs | 90 +++--------- .../sub_recipe_execution_tool/executor.rs | 103 ++++++++++++++ .../agents/sub_recipe_execution_tool/lib.rs | 38 +++++ .../agents/sub_recipe_execution_tool/mod.rs | 6 + .../sub_recipe_execute_task_tool.rs | 124 ++++++++++++++++ .../agents/sub_recipe_execution_tool/tasks.rs | 119 ++++++++++++++++ .../agents/sub_recipe_execution_tool/types.rs | 69 +++++++++ .../sub_recipe_execution_tool/workers.rs | 133 ++++++++++++++++++ crates/goose/src/agents/sub_recipe_manager.rs | 47 ++++++- 11 files changed, 660 insertions(+), 78 deletions(-) create mode 100644 crates/goose/src/agents/sub_recipe_execution_tool/executor.rs create mode 100644 crates/goose/src/agents/sub_recipe_execution_tool/lib.rs create mode 100644 crates/goose/src/agents/sub_recipe_execution_tool/mod.rs create mode 100644 crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs create mode 100644 crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs create mode 100644 crates/goose/src/agents/sub_recipe_execution_tool/types.rs create mode 100644 crates/goose/src/agents/sub_recipe_execution_tool/workers.rs diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index dea971e5..3ba40763 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -9,6 +9,9 @@ use futures::{stream, FutureExt, Stream, StreamExt, TryStreamExt}; use mcp_core::protocol::JsonRpcMessage; use crate::agents::final_output_tool::{FINAL_OUTPUT_CONTINUATION_MESSAGE, FINAL_OUTPUT_TOOL_NAME}; +use crate::agents::sub_recipe_execution_tool::sub_recipe_execute_task_tool::{ + self, SUB_RECIPE_EXECUTE_TASK_TOOL_NAME, +}; use crate::agents::sub_recipe_manager::SubRecipeManager; use crate::config::{Config, ExtensionConfigManager, PermissionManager}; use crate::message::Message; @@ -286,11 +289,12 @@ impl Agent { let extension_manager = self.extension_manager.read().await; let sub_recipe_manager = self.sub_recipe_manager.lock().await; - let result: ToolCallResult = if sub_recipe_manager.is_sub_recipe_tool(&tool_call.name) { sub_recipe_manager .dispatch_sub_recipe_tool_call(&tool_call.name, tool_call.arguments.clone()) .await + } else if tool_call.name == SUB_RECIPE_EXECUTE_TASK_TOOL_NAME { + sub_recipe_execute_task_tool::run_tasks(tool_call.arguments.clone()).await } else if tool_call.name == PLATFORM_READ_RESOURCE_TOOL_NAME { // Check if the tool is read_resource and handle it separately ToolCallResult::from( @@ -574,6 +578,8 @@ impl Agent { if let Some(final_output_tool) = self.final_output_tool.lock().await.as_ref() { prefixed_tools.push(final_output_tool.tool()); } + prefixed_tools + .push(sub_recipe_execute_task_tool::create_sub_recipe_execute_task_tool()); } prefixed_tools diff --git a/crates/goose/src/agents/mod.rs b/crates/goose/src/agents/mod.rs index 098521c0..353e57ac 100644 --- a/crates/goose/src/agents/mod.rs +++ b/crates/goose/src/agents/mod.rs @@ -11,6 +11,7 @@ mod reply_parts; mod router_tool_selector; mod router_tools; mod schedule_tool; +pub mod sub_recipe_execution_tool; pub mod sub_recipe_manager; pub mod subagent; pub mod subagent_handler; diff --git a/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs b/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs index 2fd4f504..928cf8bd 100644 --- a/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs +++ b/crates/goose/src/agents/recipe_tools/sub_recipe_tools.rs @@ -3,25 +3,20 @@ use std::{collections::HashMap, fs}; use anyhow::Result; use mcp_core::tool::{Tool, ToolAnnotations}; use serde_json::{json, Map, Value}; -use tokio::io::{AsyncBufReadExt, BufReader}; -use tokio::process::Command; +use crate::agents::sub_recipe_execution_tool::lib::Task; use crate::recipe::{Recipe, RecipeParameter, RecipeParameterRequirement, SubRecipe}; -pub const SUB_RECIPE_TOOL_NAME_PREFIX: &str = "subrecipe__run_"; +pub const SUB_RECIPE_TASK_TOOL_NAME_PREFIX: &str = "subrecipe__create_task"; -pub fn create_sub_recipe_tool(sub_recipe: &SubRecipe) -> Tool { +pub fn create_sub_recipe_task_tool(sub_recipe: &SubRecipe) -> Tool { let input_schema = get_input_schema(sub_recipe).unwrap(); Tool::new( - format!("{}_{}", SUB_RECIPE_TOOL_NAME_PREFIX, sub_recipe.name), - "Run a sub recipe. - Use this tool when you need to run a sub-recipe. - The sub recipe will be run with the provided parameters - and return the output of the sub recipe." - .to_string(), + format!("{}_{}", SUB_RECIPE_TASK_TOOL_NAME_PREFIX, sub_recipe.name), + "Before running this sub recipe, you should first create a task with this tool and then pass the task to the task executor".to_string(), input_schema, Some(ToolAnnotations { - title: Some(format!("run sub recipe {}", sub_recipe.name)), + title: Some(format!("create sub recipe task {}", sub_recipe.name)), read_only_hint: false, destructive_hint: true, idempotent_hint: false, @@ -99,68 +94,23 @@ fn prepare_command_params( Ok(sub_recipe_params) } -pub async fn run_sub_recipe(sub_recipe: &SubRecipe, params: Value) -> Result { +pub async fn create_sub_recipe_task(sub_recipe: &SubRecipe, params: Value) -> Result { let command_params = prepare_command_params(sub_recipe, params)?; - - let mut command = Command::new("goose"); - command.arg("run").arg("--recipe").arg(&sub_recipe.path); - - for (key, value) in command_params { - command.arg("--params").arg(format!("{}={}", key, value)); - } - - command.stdout(std::process::Stdio::piped()); - command.stderr(std::process::Stdio::piped()); - - let mut child = command - .spawn() - .map_err(|e| anyhow::anyhow!("Failed to spawn: {}", e))?; - - let stdout = child.stdout.take().expect("Failed to capture stdout"); - let stderr = child.stderr.take().expect("Failed to capture stderr"); - - let mut stdout_reader = BufReader::new(stdout).lines(); - let mut stderr_reader = BufReader::new(stderr).lines(); - let stdout_sub_recipe_name = sub_recipe.name.clone(); - let stderr_sub_recipe_name = sub_recipe.name.clone(); - - // Spawn background tasks to read from stdout and stderr - let stdout_task = tokio::spawn(async move { - let mut buffer = String::new(); - while let Ok(Some(line)) = stdout_reader.next_line().await { - println!("[sub-recipe {}] {}", stdout_sub_recipe_name, line); - buffer.push_str(&line); - buffer.push('\n'); + let payload = json!({ + "sub_recipe": { + "name": sub_recipe.name.clone(), + "command_parameters": command_params, + "recipe_path": sub_recipe.path.clone(), } - buffer }); - - let stderr_task = tokio::spawn(async move { - let mut buffer = String::new(); - while let Ok(Some(line)) = stderr_reader.next_line().await { - eprintln!( - "[stderr for sub-recipe {}] {}", - stderr_sub_recipe_name, line - ); - buffer.push_str(&line); - buffer.push('\n'); - } - buffer - }); - - let status = child - .wait() - .await - .map_err(|e| anyhow::anyhow!("Failed to wait for process: {}", e))?; - - let stdout_output = stdout_task.await.unwrap(); - let stderr_output = stderr_task.await.unwrap(); - - if status.success() { - Ok(stdout_output) - } else { - Err(anyhow::anyhow!("Command failed:\n{}", stderr_output)) - } + let task = Task { + id: uuid::Uuid::new_v4().to_string(), + task_type: "sub_recipe".to_string(), + payload, + }; + let task_json = serde_json::to_string(&task) + .map_err(|e| anyhow::anyhow!("Failed to serialize Task: {}", e))?; + Ok(task_json) } #[cfg(test)] diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/executor.rs b/crates/goose/src/agents/sub_recipe_execution_tool/executor.rs new file mode 100644 index 00000000..b796d412 --- /dev/null +++ b/crates/goose/src/agents/sub_recipe_execution_tool/executor.rs @@ -0,0 +1,103 @@ +use std::sync::atomic::{AtomicBool, AtomicUsize}; +use std::sync::Arc; +use tokio::sync::mpsc; +use tokio::time::Instant; + +use crate::agents::sub_recipe_execution_tool::lib::{ + Config, ExecutionResponse, ExecutionStats, Task, TaskResult, +}; +use crate::agents::sub_recipe_execution_tool::tasks::process_task; +use crate::agents::sub_recipe_execution_tool::workers::{run_scaler, spawn_worker, SharedState}; + +pub async fn execute_single_task(task: &Task, config: Config) -> ExecutionResponse { + let start_time = Instant::now(); + let result = process_task(task, config.timeout_seconds).await; + + let execution_time = start_time.elapsed().as_millis(); + let completed = if result.status == "success" { 1 } else { 0 }; + let failed = if result.status == "failed" { 1 } else { 0 }; + + ExecutionResponse { + status: "completed".to_string(), + results: vec![result], + stats: ExecutionStats { + total_tasks: 1, + completed, + failed, + execution_time_ms: execution_time, + }, + } +} + +// Main parallel execution function +pub async fn parallel_execute(tasks: Vec, config: Config) -> ExecutionResponse { + let start_time = Instant::now(); + let task_count = tasks.len(); + + // Create channels + let (task_tx, task_rx) = mpsc::channel::(task_count); + let (result_tx, mut result_rx) = mpsc::channel::(task_count); + + // Initialize shared state + let shared_state = Arc::new(SharedState { + task_receiver: Arc::new(tokio::sync::Mutex::new(task_rx)), + result_sender: result_tx, + active_workers: Arc::new(AtomicUsize::new(0)), + should_stop: Arc::new(AtomicBool::new(false)), + completed_tasks: Arc::new(AtomicUsize::new(0)), + }); + + // Send all tasks to the queue + for task in tasks.clone() { + let _ = task_tx.send(task).await; + } + // Close sender so workers know when queue is empty + drop(task_tx); + + // Start initial workers + let mut worker_handles = Vec::new(); + for i in 0..config.initial_workers { + let handle = spawn_worker(shared_state.clone(), i, config.timeout_seconds); + worker_handles.push(handle); + } + + // Start the scaler + let scaler_state = shared_state.clone(); + let scaler_handle = tokio::spawn(async move { + run_scaler( + scaler_state, + task_count, + config.max_workers, + config.timeout_seconds, + ) + .await; + }); + + // Collect results + let mut results = Vec::new(); + while let Some(result) = result_rx.recv().await { + results.push(result); + if results.len() >= task_count { + break; + } + } + + // Wait for scaler to finish + let _ = scaler_handle.await; + + // Calculate stats + let execution_time = start_time.elapsed().as_millis(); + let completed = results.iter().filter(|r| r.status == "success").count(); + let failed = results.iter().filter(|r| r.status == "failed").count(); + + ExecutionResponse { + status: "completed".to_string(), + results, + stats: ExecutionStats { + total_tasks: task_count, + completed, + failed, + execution_time_ms: execution_time, + }, + } +} diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/lib.rs b/crates/goose/src/agents/sub_recipe_execution_tool/lib.rs new file mode 100644 index 00000000..9df784a4 --- /dev/null +++ b/crates/goose/src/agents/sub_recipe_execution_tool/lib.rs @@ -0,0 +1,38 @@ +use crate::agents::sub_recipe_execution_tool::executor::execute_single_task; +pub use crate::agents::sub_recipe_execution_tool::executor::parallel_execute; +pub use crate::agents::sub_recipe_execution_tool::types::{ + Config, ExecutionResponse, ExecutionStats, Task, TaskResult, +}; + +use serde_json::Value; + +pub async fn execute_tasks(input: Value, execution_mode: &str) -> Result { + let tasks: Vec = + serde_json::from_value(input.get("tasks").ok_or("Missing tasks field")?.clone()) + .map_err(|e| format!("Failed to parse tasks: {}", e))?; + + let config: Config = if let Some(config_value) = input.get("config") { + serde_json::from_value(config_value.clone()) + .map_err(|e| format!("Failed to parse config: {}", e))? + } else { + Config::default() + }; + let task_count = tasks.len(); + match execution_mode { + "sequential" => { + if task_count == 1 { + let response = execute_single_task(&tasks[0], config).await; + serde_json::to_value(response) + .map_err(|e| format!("Failed to serialize response: {}", e)) + } else { + Err("Sequential execution mode requires exactly one task".to_string()) + } + } + "parallel" => { + let response = parallel_execute(tasks, config).await; + serde_json::to_value(response) + .map_err(|e| format!("Failed to serialize response: {}", e)) + } + _ => Err("Invalid execution mode".to_string()), + } +} diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/mod.rs b/crates/goose/src/agents/sub_recipe_execution_tool/mod.rs new file mode 100644 index 00000000..a49791e2 --- /dev/null +++ b/crates/goose/src/agents/sub_recipe_execution_tool/mod.rs @@ -0,0 +1,6 @@ +mod executor; +pub mod lib; +pub mod sub_recipe_execute_task_tool; +mod tasks; +mod types; +mod workers; diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs b/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs new file mode 100644 index 00000000..46738b81 --- /dev/null +++ b/crates/goose/src/agents/sub_recipe_execution_tool/sub_recipe_execute_task_tool.rs @@ -0,0 +1,124 @@ +use mcp_core::{tool::ToolAnnotations, Content, Tool, ToolError}; +use serde_json::Value; + +use crate::agents::{ + sub_recipe_execution_tool::lib::execute_tasks, tool_execution::ToolCallResult, +}; + +pub const SUB_RECIPE_EXECUTE_TASK_TOOL_NAME: &str = "sub_recipe__execute_task"; +pub fn create_sub_recipe_execute_task_tool() -> Tool { + Tool::new( + SUB_RECIPE_EXECUTE_TASK_TOOL_NAME, + "Only use this tool when you execute sub recipe task. +EXECUTION STRATEGY: +- DEFAULT: Execute tasks sequentially (one at a time) unless user explicitly requests parallel execution +- PARALLEL: Only when user explicitly uses keywords like 'parallel', 'simultaneously', 'at the same time', 'concurrently' + +IMPLEMENTATION: +- Sequential execution: Call this tool multiple times, passing exactly ONE task per call +- Parallel execution: Call this tool once, passing an ARRAY of all tasks + +EXAMPLES: +- User: 'get weather and tell me a joke' → Sequential (2 separate tool calls, 1 task each) +- User: 'get weather and joke in parallel' → Parallel (1 tool call with array of 2 tasks) +- User: 'run these simultaneously' → Parallel (1 tool call with task array) +- User: 'do task A then task B' → Sequential (2 separate tool calls)", + serde_json::json!({ + "type": "object", + "properties": { + "execution_mode": { + "type": "string", + "enum": ["sequential", "parallel"], + "default": "sequential", + "description": "Execution strategy for multiple tasks. Use 'sequential' (default) unless user explicitly requests parallel execution with words like 'parallel', 'simultaneously', 'at the same time', or 'concurrently'." + }, + "tasks": { + "type": "array", + "items": { + "type": "object", + "properties": { + "id": { + "type": "string", + "description": "Unique identifier for the task" + }, + "task_type": { + "type": "string", + "enum": ["sub_recipe", "text_instruction"], + "default": "sub_recipe", + "description": "the type of task to execute, can be one of: sub_recipe, text_instruction" + }, + "payload": { + "type": "object", + "properties": { + "sub_recipe": { + "type": "object", + "description": "sub recipe to execute", + "properties": { + "name": { + "type": "string", + "description": "name of the sub recipe to execute" + }, + "recipe_path": { + "type": "string", + "description": "path of the sub recipe file" + }, + "command_parameters": { + "type": "object", + "description": "parameters to pass to run recipe command with sub recipe file" + } + } + }, + "text_instruction": { + "type": "string", + "description": "text instruction to execute" + } + } + } + }, + "required": ["id", "payload"] + }, + "description": "The tasks to run in parallel" + }, + "config": { + "type": "object", + "properties": { + "timeout_seconds": { + "type": "number" + }, + "max_workers": { + "type": "number" + }, + "initial_workers": { + "type": "number" + } + } + } + }, + "required": ["tasks"] + }), + Some(ToolAnnotations { + title: Some("Run tasks in parallel".to_string()), + read_only_hint: false, + destructive_hint: true, + idempotent_hint: false, + open_world_hint: true, + }), + ) +} + +pub async fn run_tasks(execute_data: Value) -> ToolCallResult { + let execute_data_clone = execute_data.clone(); + let default_execution_mode_value = Value::String("sequential".to_string()); + let execution_mode = execute_data_clone + .get("execution_mode") + .unwrap_or(&default_execution_mode_value) + .as_str() + .unwrap_or("sequential"); + match execute_tasks(execute_data, execution_mode).await { + Ok(result) => { + let output = serde_json::to_string(&result).unwrap(); + ToolCallResult::from(Ok(vec![Content::text(output)])) + } + Err(e) => ToolCallResult::from(Err(ToolError::ExecutionError(e.to_string()))), + } +} diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs b/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs new file mode 100644 index 00000000..4e4584aa --- /dev/null +++ b/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs @@ -0,0 +1,119 @@ +use serde_json::Value; +use std::process::Stdio; +use std::time::Duration; +use tokio::io::{AsyncBufReadExt, BufReader}; +use tokio::process::Command; +use tokio::time::timeout; + +use crate::agents::sub_recipe_execution_tool::types::{Task, TaskResult}; + +// Process a single task based on its type +pub async fn process_task(task: &Task, timeout_seconds: u64) -> TaskResult { + let task_clone = task.clone(); + let timeout_duration = Duration::from_secs(timeout_seconds); + + // Execute with timeout + match timeout(timeout_duration, execute_task(task_clone)).await { + Ok(Ok(data)) => TaskResult { + task_id: task.id.clone(), + status: "success".to_string(), + data: Some(data), + error: None, + }, + Ok(Err(error)) => TaskResult { + task_id: task.id.clone(), + status: "failed".to_string(), + data: None, + error: Some(error), + }, + Err(_) => TaskResult { + task_id: task.id.clone(), + status: "failed".to_string(), + data: None, + error: Some("Task timeout".to_string()), + }, + } +} + +async fn execute_task(task: Task) -> Result { + let mut output_identifier = task.id.clone(); + let mut command = if task.task_type == "sub_recipe" { + let sub_recipe = task.payload.get("sub_recipe").unwrap(); + let sub_recipe_name = sub_recipe.get("name").unwrap().as_str().unwrap(); + let path = sub_recipe.get("recipe_path").unwrap().as_str().unwrap(); + let command_parameters = sub_recipe.get("command_parameters").unwrap(); + output_identifier = format!("sub-recipe {}", sub_recipe_name); + let mut cmd = Command::new("goose"); + cmd.arg("run").arg("--recipe").arg(path); + if let Some(params_map) = command_parameters.as_object() { + for (key, value) in params_map { + let key_str = key.to_string(); + let value_str = value.as_str().unwrap_or(&value.to_string()).to_string(); + cmd.arg("--params") + .arg(format!("{}={}", key_str, value_str)); + } + } + cmd + } else { + let text = task + .payload + .get("text_instruction") + .unwrap() + .as_str() + .unwrap(); + let mut cmd = Command::new("goose"); + cmd.arg("run").arg("--text").arg(text); + cmd + }; + + // Configure to capture stdout + command.stdout(Stdio::piped()); + command.stderr(Stdio::piped()); + + // Spawn the child process + let mut child = command + .spawn() + .map_err(|e| format!("Failed to spawn goose: {}", e))?; + + let stdout = child.stdout.take().expect("Failed to capture stdout"); + let stderr = child.stderr.take().expect("Failed to capture stderr"); + + let mut stdout_reader = BufReader::new(stdout).lines(); + let mut stderr_reader = BufReader::new(stderr).lines(); + + // Spawn background tasks to read from stdout and stderr + let output_identifier_clone = output_identifier.clone(); + let stdout_task = tokio::spawn(async move { + let mut buffer = String::new(); + while let Ok(Some(line)) = stdout_reader.next_line().await { + println!("[{}] {}", output_identifier_clone, line); + buffer.push_str(&line); + buffer.push('\n'); + } + buffer + }); + + let stderr_task = tokio::spawn(async move { + let mut buffer = String::new(); + while let Ok(Some(line)) = stderr_reader.next_line().await { + eprintln!("[stderr for {}] {}", output_identifier, line); + buffer.push_str(&line); + buffer.push('\n'); + } + buffer + }); + + let status = child + .wait() + .await + .map_err(|e| format!("Failed to wait for process: {}", e))?; + + let stdout_output = stdout_task.await.unwrap(); + let stderr_output = stderr_task.await.unwrap(); + + if status.success() { + Ok(Value::String(stdout_output)) + } else { + Err(format!("Command failed:\n{}", stderr_output)) + } +} diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/types.rs b/crates/goose/src/agents/sub_recipe_execution_tool/types.rs new file mode 100644 index 00000000..ede71dbf --- /dev/null +++ b/crates/goose/src/agents/sub_recipe_execution_tool/types.rs @@ -0,0 +1,69 @@ +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +// Task definition that LLMs will send +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Task { + pub id: String, + pub task_type: String, + pub payload: Value, +} + +// Result for each task +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TaskResult { + pub task_id: String, + pub status: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub data: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option, +} + +// Configuration for the parallel executor +#[derive(Debug, Clone, Deserialize)] +pub struct Config { + #[serde(default = "default_max_workers")] + pub max_workers: usize, + #[serde(default = "default_timeout")] + pub timeout_seconds: u64, + #[serde(default = "default_initial_workers")] + pub initial_workers: usize, +} + +impl Default for Config { + fn default() -> Self { + Self { + max_workers: default_max_workers(), + timeout_seconds: default_timeout(), + initial_workers: default_initial_workers(), + } + } +} + +fn default_max_workers() -> usize { + 10 +} +fn default_timeout() -> u64 { + 300 +} +fn default_initial_workers() -> usize { + 2 +} + +// Stats for the execution +#[derive(Debug, Serialize)] +pub struct ExecutionStats { + pub total_tasks: usize, + pub completed: usize, + pub failed: usize, + pub execution_time_ms: u128, +} + +// Main response structure +#[derive(Debug, Serialize)] +pub struct ExecutionResponse { + pub status: String, + pub results: Vec, + pub stats: ExecutionStats, +} diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs b/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs new file mode 100644 index 00000000..e48f19c4 --- /dev/null +++ b/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs @@ -0,0 +1,133 @@ +use crate::agents::sub_recipe_execution_tool::tasks::process_task; +use crate::agents::sub_recipe_execution_tool::types::{Task, TaskResult}; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use std::sync::Arc; +use tokio::sync::mpsc; +use tokio::time::{sleep, Duration}; + +#[cfg(test)] +mod tests { + use super::*; + use crate::agents::sub_recipe_execution_tool::types::Task; + + #[tokio::test] + async fn test_spawn_worker_returns_handle() { + // Create a simple shared state for testing + let (task_tx, task_rx) = mpsc::channel::(1); + let (result_tx, _result_rx) = mpsc::channel::(1); + + let shared_state = Arc::new(SharedState { + task_receiver: Arc::new(tokio::sync::Mutex::new(task_rx)), + result_sender: result_tx, + active_workers: Arc::new(AtomicUsize::new(0)), + should_stop: Arc::new(AtomicBool::new(false)), + completed_tasks: Arc::new(AtomicUsize::new(0)), + }); + + // Test that spawn_worker returns a JoinHandle + let handle = spawn_worker(shared_state.clone(), 0, 5); + + // Verify it's a JoinHandle by checking we can abort it + assert!(!handle.is_finished()); + + // Signal stop and close the channel to let the worker exit + shared_state.should_stop.store(true, Ordering::SeqCst); + drop(task_tx); // Close the channel + + // Wait for the worker to finish + let result = handle.await; + assert!(result.is_ok()); + } +} + +pub struct SharedState { + pub task_receiver: Arc>>, + pub result_sender: mpsc::Sender, + pub active_workers: Arc, + pub should_stop: Arc, + pub completed_tasks: Arc, +} + +// Spawn a worker task +pub fn spawn_worker( + state: Arc, + worker_id: usize, + timeout_seconds: u64, +) -> tokio::task::JoinHandle<()> { + state.active_workers.fetch_add(1, Ordering::SeqCst); + + tokio::spawn(async move { + worker_loop(state, worker_id, timeout_seconds).await; + }) +} + +async fn worker_loop(state: Arc, _worker_id: usize, timeout_seconds: u64) { + loop { + // Try to receive a task + let task = { + let mut receiver = state.task_receiver.lock().await; + receiver.recv().await + }; + + match task { + Some(task) => { + // Process the task + let result = process_task(&task, timeout_seconds).await; + + // Send result + let _ = state.result_sender.send(result).await; + + // Update completed count + state.completed_tasks.fetch_add(1, Ordering::SeqCst); + } + None => { + // Channel closed, exit worker + break; + } + } + + // Check if we should stop + if state.should_stop.load(Ordering::SeqCst) { + break; + } + } + + // Worker is exiting + state.active_workers.fetch_sub(1, Ordering::SeqCst); +} + +// Scaling controller that monitors queue and spawns workers +pub async fn run_scaler( + state: Arc, + task_count: usize, + max_workers: usize, + timeout_seconds: u64, +) { + let mut worker_count = 0; + + loop { + sleep(Duration::from_millis(100)).await; + + let active = state.active_workers.load(Ordering::SeqCst); + let completed = state.completed_tasks.load(Ordering::SeqCst); + let pending = task_count.saturating_sub(completed); + + // Simple scaling logic: spawn worker if many pending tasks and under limit + if pending > active * 2 && active < max_workers && worker_count < max_workers { + let _handle = spawn_worker(state.clone(), worker_count, timeout_seconds); + worker_count += 1; + } + + // If all tasks completed, signal stop + if completed >= task_count { + state.should_stop.store(true, Ordering::SeqCst); + break; + } + + // If no active workers and tasks remaining, spawn one + if active == 0 && pending > 0 { + let _handle = spawn_worker(state.clone(), worker_count, timeout_seconds); + worker_count += 1; + } + } +} diff --git a/crates/goose/src/agents/sub_recipe_manager.rs b/crates/goose/src/agents/sub_recipe_manager.rs index 3637c947..2441684b 100644 --- a/crates/goose/src/agents/sub_recipe_manager.rs +++ b/crates/goose/src/agents/sub_recipe_manager.rs @@ -5,7 +5,7 @@ use std::collections::HashMap; use crate::{ agents::{ recipe_tools::sub_recipe_tools::{ - create_sub_recipe_tool, run_sub_recipe, SUB_RECIPE_TOOL_NAME_PREFIX, + create_sub_recipe_task, create_sub_recipe_task_tool, SUB_RECIPE_TASK_TOOL_NAME_PREFIX, }, tool_execution::ToolCallResult, }, @@ -34,12 +34,18 @@ impl SubRecipeManager { pub fn add_sub_recipe_tools(&mut self, sub_recipes_to_add: Vec) { for sub_recipe in sub_recipes_to_add { + // let sub_recipe_key = format!( + // "{}_{}", + // SUB_RECIPE_TOOL_NAME_PREFIX, + // sub_recipe.name.clone() + // ); + // let tool = create_sub_recipe_tool(&sub_recipe); let sub_recipe_key = format!( "{}_{}", - SUB_RECIPE_TOOL_NAME_PREFIX, + SUB_RECIPE_TASK_TOOL_NAME_PREFIX, sub_recipe.name.clone() ); - let tool = create_sub_recipe_tool(&sub_recipe); + let tool = create_sub_recipe_task_tool(&sub_recipe); self.sub_recipe_tools.insert(sub_recipe_key.clone(), tool); self.sub_recipes.insert(sub_recipe_key.clone(), sub_recipe); } @@ -61,6 +67,31 @@ impl SubRecipeManager { } } + // async fn call_sub_recipe_tool( + // &self, + // tool_name: &str, + // params: Value, + // ) -> Result, ToolError> { + // let sub_recipe = self.sub_recipes.get(tool_name).ok_or_else(|| { + // let sub_recipe_name = tool_name + // .strip_prefix(SUB_RECIPE_TOOL_NAME_PREFIX) + // .and_then(|s| s.strip_prefix("_")) + // .ok_or_else(|| { + // ToolError::InvalidParameters(format!( + // "Invalid sub-recipe tool name format: {}", + // tool_name + // )) + // }) + // .unwrap(); + + // ToolError::InvalidParameters(format!("Sub-recipe '{}' not found", sub_recipe_name)) + // })?; + + // let output = run_sub_recipe(sub_recipe, params).await.map_err(|e| { + // ToolError::ExecutionError(format!("Sub-recipe execution failed: {}", e)) + // })?; + // Ok(vec![Content::text(output)]) + // } async fn call_sub_recipe_tool( &self, tool_name: &str, @@ -68,7 +99,7 @@ impl SubRecipeManager { ) -> Result, ToolError> { let sub_recipe = self.sub_recipes.get(tool_name).ok_or_else(|| { let sub_recipe_name = tool_name - .strip_prefix(SUB_RECIPE_TOOL_NAME_PREFIX) + .strip_prefix(SUB_RECIPE_TASK_TOOL_NAME_PREFIX) .and_then(|s| s.strip_prefix("_")) .ok_or_else(|| { ToolError::InvalidParameters(format!( @@ -81,9 +112,11 @@ impl SubRecipeManager { ToolError::InvalidParameters(format!("Sub-recipe '{}' not found", sub_recipe_name)) })?; - let output = run_sub_recipe(sub_recipe, params).await.map_err(|e| { - ToolError::ExecutionError(format!("Sub-recipe execution failed: {}", e)) - })?; + let output = create_sub_recipe_task(sub_recipe, params) + .await + .map_err(|e| { + ToolError::ExecutionError(format!("Sub-recipe execution failed: {}", e)) + })?; Ok(vec![Content::text(output)]) } }