mirror of
https://github.com/aljazceru/goose.git
synced 2026-02-18 13:04:38 +01:00
Use command line to run sub agent and sub recipe (in sequence or parallel) (#3190)
This commit is contained in:
@@ -9,6 +9,9 @@ use futures::{stream, FutureExt, Stream, StreamExt, TryStreamExt};
|
||||
use mcp_core::protocol::JsonRpcMessage;
|
||||
|
||||
use crate::agents::final_output_tool::{FINAL_OUTPUT_CONTINUATION_MESSAGE, FINAL_OUTPUT_TOOL_NAME};
|
||||
use crate::agents::sub_recipe_execution_tool::sub_recipe_execute_task_tool::{
|
||||
self, SUB_RECIPE_EXECUTE_TASK_TOOL_NAME,
|
||||
};
|
||||
use crate::agents::sub_recipe_manager::SubRecipeManager;
|
||||
use crate::config::{Config, ExtensionConfigManager, PermissionManager};
|
||||
use crate::message::Message;
|
||||
@@ -286,11 +289,12 @@ impl Agent {
|
||||
|
||||
let extension_manager = self.extension_manager.read().await;
|
||||
let sub_recipe_manager = self.sub_recipe_manager.lock().await;
|
||||
|
||||
let result: ToolCallResult = if sub_recipe_manager.is_sub_recipe_tool(&tool_call.name) {
|
||||
sub_recipe_manager
|
||||
.dispatch_sub_recipe_tool_call(&tool_call.name, tool_call.arguments.clone())
|
||||
.await
|
||||
} else if tool_call.name == SUB_RECIPE_EXECUTE_TASK_TOOL_NAME {
|
||||
sub_recipe_execute_task_tool::run_tasks(tool_call.arguments.clone()).await
|
||||
} else if tool_call.name == PLATFORM_READ_RESOURCE_TOOL_NAME {
|
||||
// Check if the tool is read_resource and handle it separately
|
||||
ToolCallResult::from(
|
||||
@@ -574,6 +578,8 @@ impl Agent {
|
||||
if let Some(final_output_tool) = self.final_output_tool.lock().await.as_ref() {
|
||||
prefixed_tools.push(final_output_tool.tool());
|
||||
}
|
||||
prefixed_tools
|
||||
.push(sub_recipe_execute_task_tool::create_sub_recipe_execute_task_tool());
|
||||
}
|
||||
|
||||
prefixed_tools
|
||||
|
||||
@@ -11,6 +11,7 @@ mod reply_parts;
|
||||
mod router_tool_selector;
|
||||
mod router_tools;
|
||||
mod schedule_tool;
|
||||
pub mod sub_recipe_execution_tool;
|
||||
pub mod sub_recipe_manager;
|
||||
pub mod subagent;
|
||||
pub mod subagent_handler;
|
||||
|
||||
@@ -3,25 +3,20 @@ use std::{collections::HashMap, fs};
|
||||
use anyhow::Result;
|
||||
use mcp_core::tool::{Tool, ToolAnnotations};
|
||||
use serde_json::{json, Map, Value};
|
||||
use tokio::io::{AsyncBufReadExt, BufReader};
|
||||
use tokio::process::Command;
|
||||
|
||||
use crate::agents::sub_recipe_execution_tool::lib::Task;
|
||||
use crate::recipe::{Recipe, RecipeParameter, RecipeParameterRequirement, SubRecipe};
|
||||
|
||||
pub const SUB_RECIPE_TOOL_NAME_PREFIX: &str = "subrecipe__run_";
|
||||
pub const SUB_RECIPE_TASK_TOOL_NAME_PREFIX: &str = "subrecipe__create_task";
|
||||
|
||||
pub fn create_sub_recipe_tool(sub_recipe: &SubRecipe) -> Tool {
|
||||
pub fn create_sub_recipe_task_tool(sub_recipe: &SubRecipe) -> Tool {
|
||||
let input_schema = get_input_schema(sub_recipe).unwrap();
|
||||
Tool::new(
|
||||
format!("{}_{}", SUB_RECIPE_TOOL_NAME_PREFIX, sub_recipe.name),
|
||||
"Run a sub recipe.
|
||||
Use this tool when you need to run a sub-recipe.
|
||||
The sub recipe will be run with the provided parameters
|
||||
and return the output of the sub recipe."
|
||||
.to_string(),
|
||||
format!("{}_{}", SUB_RECIPE_TASK_TOOL_NAME_PREFIX, sub_recipe.name),
|
||||
"Before running this sub recipe, you should first create a task with this tool and then pass the task to the task executor".to_string(),
|
||||
input_schema,
|
||||
Some(ToolAnnotations {
|
||||
title: Some(format!("run sub recipe {}", sub_recipe.name)),
|
||||
title: Some(format!("create sub recipe task {}", sub_recipe.name)),
|
||||
read_only_hint: false,
|
||||
destructive_hint: true,
|
||||
idempotent_hint: false,
|
||||
@@ -99,68 +94,23 @@ fn prepare_command_params(
|
||||
Ok(sub_recipe_params)
|
||||
}
|
||||
|
||||
pub async fn run_sub_recipe(sub_recipe: &SubRecipe, params: Value) -> Result<String> {
|
||||
pub async fn create_sub_recipe_task(sub_recipe: &SubRecipe, params: Value) -> Result<String> {
|
||||
let command_params = prepare_command_params(sub_recipe, params)?;
|
||||
|
||||
let mut command = Command::new("goose");
|
||||
command.arg("run").arg("--recipe").arg(&sub_recipe.path);
|
||||
|
||||
for (key, value) in command_params {
|
||||
command.arg("--params").arg(format!("{}={}", key, value));
|
||||
}
|
||||
|
||||
command.stdout(std::process::Stdio::piped());
|
||||
command.stderr(std::process::Stdio::piped());
|
||||
|
||||
let mut child = command
|
||||
.spawn()
|
||||
.map_err(|e| anyhow::anyhow!("Failed to spawn: {}", e))?;
|
||||
|
||||
let stdout = child.stdout.take().expect("Failed to capture stdout");
|
||||
let stderr = child.stderr.take().expect("Failed to capture stderr");
|
||||
|
||||
let mut stdout_reader = BufReader::new(stdout).lines();
|
||||
let mut stderr_reader = BufReader::new(stderr).lines();
|
||||
let stdout_sub_recipe_name = sub_recipe.name.clone();
|
||||
let stderr_sub_recipe_name = sub_recipe.name.clone();
|
||||
|
||||
// Spawn background tasks to read from stdout and stderr
|
||||
let stdout_task = tokio::spawn(async move {
|
||||
let mut buffer = String::new();
|
||||
while let Ok(Some(line)) = stdout_reader.next_line().await {
|
||||
println!("[sub-recipe {}] {}", stdout_sub_recipe_name, line);
|
||||
buffer.push_str(&line);
|
||||
buffer.push('\n');
|
||||
let payload = json!({
|
||||
"sub_recipe": {
|
||||
"name": sub_recipe.name.clone(),
|
||||
"command_parameters": command_params,
|
||||
"recipe_path": sub_recipe.path.clone(),
|
||||
}
|
||||
buffer
|
||||
});
|
||||
|
||||
let stderr_task = tokio::spawn(async move {
|
||||
let mut buffer = String::new();
|
||||
while let Ok(Some(line)) = stderr_reader.next_line().await {
|
||||
eprintln!(
|
||||
"[stderr for sub-recipe {}] {}",
|
||||
stderr_sub_recipe_name, line
|
||||
);
|
||||
buffer.push_str(&line);
|
||||
buffer.push('\n');
|
||||
}
|
||||
buffer
|
||||
});
|
||||
|
||||
let status = child
|
||||
.wait()
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("Failed to wait for process: {}", e))?;
|
||||
|
||||
let stdout_output = stdout_task.await.unwrap();
|
||||
let stderr_output = stderr_task.await.unwrap();
|
||||
|
||||
if status.success() {
|
||||
Ok(stdout_output)
|
||||
} else {
|
||||
Err(anyhow::anyhow!("Command failed:\n{}", stderr_output))
|
||||
}
|
||||
let task = Task {
|
||||
id: uuid::Uuid::new_v4().to_string(),
|
||||
task_type: "sub_recipe".to_string(),
|
||||
payload,
|
||||
};
|
||||
let task_json = serde_json::to_string(&task)
|
||||
.map_err(|e| anyhow::anyhow!("Failed to serialize Task: {}", e))?;
|
||||
Ok(task_json)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
103
crates/goose/src/agents/sub_recipe_execution_tool/executor.rs
Normal file
103
crates/goose/src/agents/sub_recipe_execution_tool/executor.rs
Normal file
@@ -0,0 +1,103 @@
|
||||
use std::sync::atomic::{AtomicBool, AtomicUsize};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::time::Instant;
|
||||
|
||||
use crate::agents::sub_recipe_execution_tool::lib::{
|
||||
Config, ExecutionResponse, ExecutionStats, Task, TaskResult,
|
||||
};
|
||||
use crate::agents::sub_recipe_execution_tool::tasks::process_task;
|
||||
use crate::agents::sub_recipe_execution_tool::workers::{run_scaler, spawn_worker, SharedState};
|
||||
|
||||
pub async fn execute_single_task(task: &Task, config: Config) -> ExecutionResponse {
|
||||
let start_time = Instant::now();
|
||||
let result = process_task(task, config.timeout_seconds).await;
|
||||
|
||||
let execution_time = start_time.elapsed().as_millis();
|
||||
let completed = if result.status == "success" { 1 } else { 0 };
|
||||
let failed = if result.status == "failed" { 1 } else { 0 };
|
||||
|
||||
ExecutionResponse {
|
||||
status: "completed".to_string(),
|
||||
results: vec![result],
|
||||
stats: ExecutionStats {
|
||||
total_tasks: 1,
|
||||
completed,
|
||||
failed,
|
||||
execution_time_ms: execution_time,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Main parallel execution function
|
||||
pub async fn parallel_execute(tasks: Vec<Task>, config: Config) -> ExecutionResponse {
|
||||
let start_time = Instant::now();
|
||||
let task_count = tasks.len();
|
||||
|
||||
// Create channels
|
||||
let (task_tx, task_rx) = mpsc::channel::<Task>(task_count);
|
||||
let (result_tx, mut result_rx) = mpsc::channel::<TaskResult>(task_count);
|
||||
|
||||
// Initialize shared state
|
||||
let shared_state = Arc::new(SharedState {
|
||||
task_receiver: Arc::new(tokio::sync::Mutex::new(task_rx)),
|
||||
result_sender: result_tx,
|
||||
active_workers: Arc::new(AtomicUsize::new(0)),
|
||||
should_stop: Arc::new(AtomicBool::new(false)),
|
||||
completed_tasks: Arc::new(AtomicUsize::new(0)),
|
||||
});
|
||||
|
||||
// Send all tasks to the queue
|
||||
for task in tasks.clone() {
|
||||
let _ = task_tx.send(task).await;
|
||||
}
|
||||
// Close sender so workers know when queue is empty
|
||||
drop(task_tx);
|
||||
|
||||
// Start initial workers
|
||||
let mut worker_handles = Vec::new();
|
||||
for i in 0..config.initial_workers {
|
||||
let handle = spawn_worker(shared_state.clone(), i, config.timeout_seconds);
|
||||
worker_handles.push(handle);
|
||||
}
|
||||
|
||||
// Start the scaler
|
||||
let scaler_state = shared_state.clone();
|
||||
let scaler_handle = tokio::spawn(async move {
|
||||
run_scaler(
|
||||
scaler_state,
|
||||
task_count,
|
||||
config.max_workers,
|
||||
config.timeout_seconds,
|
||||
)
|
||||
.await;
|
||||
});
|
||||
|
||||
// Collect results
|
||||
let mut results = Vec::new();
|
||||
while let Some(result) = result_rx.recv().await {
|
||||
results.push(result);
|
||||
if results.len() >= task_count {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for scaler to finish
|
||||
let _ = scaler_handle.await;
|
||||
|
||||
// Calculate stats
|
||||
let execution_time = start_time.elapsed().as_millis();
|
||||
let completed = results.iter().filter(|r| r.status == "success").count();
|
||||
let failed = results.iter().filter(|r| r.status == "failed").count();
|
||||
|
||||
ExecutionResponse {
|
||||
status: "completed".to_string(),
|
||||
results,
|
||||
stats: ExecutionStats {
|
||||
total_tasks: task_count,
|
||||
completed,
|
||||
failed,
|
||||
execution_time_ms: execution_time,
|
||||
},
|
||||
}
|
||||
}
|
||||
38
crates/goose/src/agents/sub_recipe_execution_tool/lib.rs
Normal file
38
crates/goose/src/agents/sub_recipe_execution_tool/lib.rs
Normal file
@@ -0,0 +1,38 @@
|
||||
use crate::agents::sub_recipe_execution_tool::executor::execute_single_task;
|
||||
pub use crate::agents::sub_recipe_execution_tool::executor::parallel_execute;
|
||||
pub use crate::agents::sub_recipe_execution_tool::types::{
|
||||
Config, ExecutionResponse, ExecutionStats, Task, TaskResult,
|
||||
};
|
||||
|
||||
use serde_json::Value;
|
||||
|
||||
pub async fn execute_tasks(input: Value, execution_mode: &str) -> Result<Value, String> {
|
||||
let tasks: Vec<Task> =
|
||||
serde_json::from_value(input.get("tasks").ok_or("Missing tasks field")?.clone())
|
||||
.map_err(|e| format!("Failed to parse tasks: {}", e))?;
|
||||
|
||||
let config: Config = if let Some(config_value) = input.get("config") {
|
||||
serde_json::from_value(config_value.clone())
|
||||
.map_err(|e| format!("Failed to parse config: {}", e))?
|
||||
} else {
|
||||
Config::default()
|
||||
};
|
||||
let task_count = tasks.len();
|
||||
match execution_mode {
|
||||
"sequential" => {
|
||||
if task_count == 1 {
|
||||
let response = execute_single_task(&tasks[0], config).await;
|
||||
serde_json::to_value(response)
|
||||
.map_err(|e| format!("Failed to serialize response: {}", e))
|
||||
} else {
|
||||
Err("Sequential execution mode requires exactly one task".to_string())
|
||||
}
|
||||
}
|
||||
"parallel" => {
|
||||
let response = parallel_execute(tasks, config).await;
|
||||
serde_json::to_value(response)
|
||||
.map_err(|e| format!("Failed to serialize response: {}", e))
|
||||
}
|
||||
_ => Err("Invalid execution mode".to_string()),
|
||||
}
|
||||
}
|
||||
6
crates/goose/src/agents/sub_recipe_execution_tool/mod.rs
Normal file
6
crates/goose/src/agents/sub_recipe_execution_tool/mod.rs
Normal file
@@ -0,0 +1,6 @@
|
||||
mod executor;
|
||||
pub mod lib;
|
||||
pub mod sub_recipe_execute_task_tool;
|
||||
mod tasks;
|
||||
mod types;
|
||||
mod workers;
|
||||
@@ -0,0 +1,124 @@
|
||||
use mcp_core::{tool::ToolAnnotations, Content, Tool, ToolError};
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::agents::{
|
||||
sub_recipe_execution_tool::lib::execute_tasks, tool_execution::ToolCallResult,
|
||||
};
|
||||
|
||||
pub const SUB_RECIPE_EXECUTE_TASK_TOOL_NAME: &str = "sub_recipe__execute_task";
|
||||
pub fn create_sub_recipe_execute_task_tool() -> Tool {
|
||||
Tool::new(
|
||||
SUB_RECIPE_EXECUTE_TASK_TOOL_NAME,
|
||||
"Only use this tool when you execute sub recipe task.
|
||||
EXECUTION STRATEGY:
|
||||
- DEFAULT: Execute tasks sequentially (one at a time) unless user explicitly requests parallel execution
|
||||
- PARALLEL: Only when user explicitly uses keywords like 'parallel', 'simultaneously', 'at the same time', 'concurrently'
|
||||
|
||||
IMPLEMENTATION:
|
||||
- Sequential execution: Call this tool multiple times, passing exactly ONE task per call
|
||||
- Parallel execution: Call this tool once, passing an ARRAY of all tasks
|
||||
|
||||
EXAMPLES:
|
||||
- User: 'get weather and tell me a joke' → Sequential (2 separate tool calls, 1 task each)
|
||||
- User: 'get weather and joke in parallel' → Parallel (1 tool call with array of 2 tasks)
|
||||
- User: 'run these simultaneously' → Parallel (1 tool call with task array)
|
||||
- User: 'do task A then task B' → Sequential (2 separate tool calls)",
|
||||
serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"execution_mode": {
|
||||
"type": "string",
|
||||
"enum": ["sequential", "parallel"],
|
||||
"default": "sequential",
|
||||
"description": "Execution strategy for multiple tasks. Use 'sequential' (default) unless user explicitly requests parallel execution with words like 'parallel', 'simultaneously', 'at the same time', or 'concurrently'."
|
||||
},
|
||||
"tasks": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"id": {
|
||||
"type": "string",
|
||||
"description": "Unique identifier for the task"
|
||||
},
|
||||
"task_type": {
|
||||
"type": "string",
|
||||
"enum": ["sub_recipe", "text_instruction"],
|
||||
"default": "sub_recipe",
|
||||
"description": "the type of task to execute, can be one of: sub_recipe, text_instruction"
|
||||
},
|
||||
"payload": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"sub_recipe": {
|
||||
"type": "object",
|
||||
"description": "sub recipe to execute",
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": "string",
|
||||
"description": "name of the sub recipe to execute"
|
||||
},
|
||||
"recipe_path": {
|
||||
"type": "string",
|
||||
"description": "path of the sub recipe file"
|
||||
},
|
||||
"command_parameters": {
|
||||
"type": "object",
|
||||
"description": "parameters to pass to run recipe command with sub recipe file"
|
||||
}
|
||||
}
|
||||
},
|
||||
"text_instruction": {
|
||||
"type": "string",
|
||||
"description": "text instruction to execute"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": ["id", "payload"]
|
||||
},
|
||||
"description": "The tasks to run in parallel"
|
||||
},
|
||||
"config": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"timeout_seconds": {
|
||||
"type": "number"
|
||||
},
|
||||
"max_workers": {
|
||||
"type": "number"
|
||||
},
|
||||
"initial_workers": {
|
||||
"type": "number"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": ["tasks"]
|
||||
}),
|
||||
Some(ToolAnnotations {
|
||||
title: Some("Run tasks in parallel".to_string()),
|
||||
read_only_hint: false,
|
||||
destructive_hint: true,
|
||||
idempotent_hint: false,
|
||||
open_world_hint: true,
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
pub async fn run_tasks(execute_data: Value) -> ToolCallResult {
|
||||
let execute_data_clone = execute_data.clone();
|
||||
let default_execution_mode_value = Value::String("sequential".to_string());
|
||||
let execution_mode = execute_data_clone
|
||||
.get("execution_mode")
|
||||
.unwrap_or(&default_execution_mode_value)
|
||||
.as_str()
|
||||
.unwrap_or("sequential");
|
||||
match execute_tasks(execute_data, execution_mode).await {
|
||||
Ok(result) => {
|
||||
let output = serde_json::to_string(&result).unwrap();
|
||||
ToolCallResult::from(Ok(vec![Content::text(output)]))
|
||||
}
|
||||
Err(e) => ToolCallResult::from(Err(ToolError::ExecutionError(e.to_string()))),
|
||||
}
|
||||
}
|
||||
119
crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs
Normal file
119
crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs
Normal file
@@ -0,0 +1,119 @@
|
||||
use serde_json::Value;
|
||||
use std::process::Stdio;
|
||||
use std::time::Duration;
|
||||
use tokio::io::{AsyncBufReadExt, BufReader};
|
||||
use tokio::process::Command;
|
||||
use tokio::time::timeout;
|
||||
|
||||
use crate::agents::sub_recipe_execution_tool::types::{Task, TaskResult};
|
||||
|
||||
// Process a single task based on its type
|
||||
pub async fn process_task(task: &Task, timeout_seconds: u64) -> TaskResult {
|
||||
let task_clone = task.clone();
|
||||
let timeout_duration = Duration::from_secs(timeout_seconds);
|
||||
|
||||
// Execute with timeout
|
||||
match timeout(timeout_duration, execute_task(task_clone)).await {
|
||||
Ok(Ok(data)) => TaskResult {
|
||||
task_id: task.id.clone(),
|
||||
status: "success".to_string(),
|
||||
data: Some(data),
|
||||
error: None,
|
||||
},
|
||||
Ok(Err(error)) => TaskResult {
|
||||
task_id: task.id.clone(),
|
||||
status: "failed".to_string(),
|
||||
data: None,
|
||||
error: Some(error),
|
||||
},
|
||||
Err(_) => TaskResult {
|
||||
task_id: task.id.clone(),
|
||||
status: "failed".to_string(),
|
||||
data: None,
|
||||
error: Some("Task timeout".to_string()),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
async fn execute_task(task: Task) -> Result<Value, String> {
|
||||
let mut output_identifier = task.id.clone();
|
||||
let mut command = if task.task_type == "sub_recipe" {
|
||||
let sub_recipe = task.payload.get("sub_recipe").unwrap();
|
||||
let sub_recipe_name = sub_recipe.get("name").unwrap().as_str().unwrap();
|
||||
let path = sub_recipe.get("recipe_path").unwrap().as_str().unwrap();
|
||||
let command_parameters = sub_recipe.get("command_parameters").unwrap();
|
||||
output_identifier = format!("sub-recipe {}", sub_recipe_name);
|
||||
let mut cmd = Command::new("goose");
|
||||
cmd.arg("run").arg("--recipe").arg(path);
|
||||
if let Some(params_map) = command_parameters.as_object() {
|
||||
for (key, value) in params_map {
|
||||
let key_str = key.to_string();
|
||||
let value_str = value.as_str().unwrap_or(&value.to_string()).to_string();
|
||||
cmd.arg("--params")
|
||||
.arg(format!("{}={}", key_str, value_str));
|
||||
}
|
||||
}
|
||||
cmd
|
||||
} else {
|
||||
let text = task
|
||||
.payload
|
||||
.get("text_instruction")
|
||||
.unwrap()
|
||||
.as_str()
|
||||
.unwrap();
|
||||
let mut cmd = Command::new("goose");
|
||||
cmd.arg("run").arg("--text").arg(text);
|
||||
cmd
|
||||
};
|
||||
|
||||
// Configure to capture stdout
|
||||
command.stdout(Stdio::piped());
|
||||
command.stderr(Stdio::piped());
|
||||
|
||||
// Spawn the child process
|
||||
let mut child = command
|
||||
.spawn()
|
||||
.map_err(|e| format!("Failed to spawn goose: {}", e))?;
|
||||
|
||||
let stdout = child.stdout.take().expect("Failed to capture stdout");
|
||||
let stderr = child.stderr.take().expect("Failed to capture stderr");
|
||||
|
||||
let mut stdout_reader = BufReader::new(stdout).lines();
|
||||
let mut stderr_reader = BufReader::new(stderr).lines();
|
||||
|
||||
// Spawn background tasks to read from stdout and stderr
|
||||
let output_identifier_clone = output_identifier.clone();
|
||||
let stdout_task = tokio::spawn(async move {
|
||||
let mut buffer = String::new();
|
||||
while let Ok(Some(line)) = stdout_reader.next_line().await {
|
||||
println!("[{}] {}", output_identifier_clone, line);
|
||||
buffer.push_str(&line);
|
||||
buffer.push('\n');
|
||||
}
|
||||
buffer
|
||||
});
|
||||
|
||||
let stderr_task = tokio::spawn(async move {
|
||||
let mut buffer = String::new();
|
||||
while let Ok(Some(line)) = stderr_reader.next_line().await {
|
||||
eprintln!("[stderr for {}] {}", output_identifier, line);
|
||||
buffer.push_str(&line);
|
||||
buffer.push('\n');
|
||||
}
|
||||
buffer
|
||||
});
|
||||
|
||||
let status = child
|
||||
.wait()
|
||||
.await
|
||||
.map_err(|e| format!("Failed to wait for process: {}", e))?;
|
||||
|
||||
let stdout_output = stdout_task.await.unwrap();
|
||||
let stderr_output = stderr_task.await.unwrap();
|
||||
|
||||
if status.success() {
|
||||
Ok(Value::String(stdout_output))
|
||||
} else {
|
||||
Err(format!("Command failed:\n{}", stderr_output))
|
||||
}
|
||||
}
|
||||
69
crates/goose/src/agents/sub_recipe_execution_tool/types.rs
Normal file
69
crates/goose/src/agents/sub_recipe_execution_tool/types.rs
Normal file
@@ -0,0 +1,69 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
// Task definition that LLMs will send
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Task {
|
||||
pub id: String,
|
||||
pub task_type: String,
|
||||
pub payload: Value,
|
||||
}
|
||||
|
||||
// Result for each task
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TaskResult {
|
||||
pub task_id: String,
|
||||
pub status: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub data: Option<Value>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub error: Option<String>,
|
||||
}
|
||||
|
||||
// Configuration for the parallel executor
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct Config {
|
||||
#[serde(default = "default_max_workers")]
|
||||
pub max_workers: usize,
|
||||
#[serde(default = "default_timeout")]
|
||||
pub timeout_seconds: u64,
|
||||
#[serde(default = "default_initial_workers")]
|
||||
pub initial_workers: usize,
|
||||
}
|
||||
|
||||
impl Default for Config {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
max_workers: default_max_workers(),
|
||||
timeout_seconds: default_timeout(),
|
||||
initial_workers: default_initial_workers(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn default_max_workers() -> usize {
|
||||
10
|
||||
}
|
||||
fn default_timeout() -> u64 {
|
||||
300
|
||||
}
|
||||
fn default_initial_workers() -> usize {
|
||||
2
|
||||
}
|
||||
|
||||
// Stats for the execution
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct ExecutionStats {
|
||||
pub total_tasks: usize,
|
||||
pub completed: usize,
|
||||
pub failed: usize,
|
||||
pub execution_time_ms: u128,
|
||||
}
|
||||
|
||||
// Main response structure
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct ExecutionResponse {
|
||||
pub status: String,
|
||||
pub results: Vec<TaskResult>,
|
||||
pub stats: ExecutionStats,
|
||||
}
|
||||
133
crates/goose/src/agents/sub_recipe_execution_tool/workers.rs
Normal file
133
crates/goose/src/agents/sub_recipe_execution_tool/workers.rs
Normal file
@@ -0,0 +1,133 @@
|
||||
use crate::agents::sub_recipe_execution_tool::tasks::process_task;
|
||||
use crate::agents::sub_recipe_execution_tool::types::{Task, TaskResult};
|
||||
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::time::{sleep, Duration};
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::agents::sub_recipe_execution_tool::types::Task;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_spawn_worker_returns_handle() {
|
||||
// Create a simple shared state for testing
|
||||
let (task_tx, task_rx) = mpsc::channel::<Task>(1);
|
||||
let (result_tx, _result_rx) = mpsc::channel::<TaskResult>(1);
|
||||
|
||||
let shared_state = Arc::new(SharedState {
|
||||
task_receiver: Arc::new(tokio::sync::Mutex::new(task_rx)),
|
||||
result_sender: result_tx,
|
||||
active_workers: Arc::new(AtomicUsize::new(0)),
|
||||
should_stop: Arc::new(AtomicBool::new(false)),
|
||||
completed_tasks: Arc::new(AtomicUsize::new(0)),
|
||||
});
|
||||
|
||||
// Test that spawn_worker returns a JoinHandle
|
||||
let handle = spawn_worker(shared_state.clone(), 0, 5);
|
||||
|
||||
// Verify it's a JoinHandle by checking we can abort it
|
||||
assert!(!handle.is_finished());
|
||||
|
||||
// Signal stop and close the channel to let the worker exit
|
||||
shared_state.should_stop.store(true, Ordering::SeqCst);
|
||||
drop(task_tx); // Close the channel
|
||||
|
||||
// Wait for the worker to finish
|
||||
let result = handle.await;
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
}
|
||||
|
||||
pub struct SharedState {
|
||||
pub task_receiver: Arc<tokio::sync::Mutex<mpsc::Receiver<Task>>>,
|
||||
pub result_sender: mpsc::Sender<TaskResult>,
|
||||
pub active_workers: Arc<AtomicUsize>,
|
||||
pub should_stop: Arc<AtomicBool>,
|
||||
pub completed_tasks: Arc<AtomicUsize>,
|
||||
}
|
||||
|
||||
// Spawn a worker task
|
||||
pub fn spawn_worker(
|
||||
state: Arc<SharedState>,
|
||||
worker_id: usize,
|
||||
timeout_seconds: u64,
|
||||
) -> tokio::task::JoinHandle<()> {
|
||||
state.active_workers.fetch_add(1, Ordering::SeqCst);
|
||||
|
||||
tokio::spawn(async move {
|
||||
worker_loop(state, worker_id, timeout_seconds).await;
|
||||
})
|
||||
}
|
||||
|
||||
async fn worker_loop(state: Arc<SharedState>, _worker_id: usize, timeout_seconds: u64) {
|
||||
loop {
|
||||
// Try to receive a task
|
||||
let task = {
|
||||
let mut receiver = state.task_receiver.lock().await;
|
||||
receiver.recv().await
|
||||
};
|
||||
|
||||
match task {
|
||||
Some(task) => {
|
||||
// Process the task
|
||||
let result = process_task(&task, timeout_seconds).await;
|
||||
|
||||
// Send result
|
||||
let _ = state.result_sender.send(result).await;
|
||||
|
||||
// Update completed count
|
||||
state.completed_tasks.fetch_add(1, Ordering::SeqCst);
|
||||
}
|
||||
None => {
|
||||
// Channel closed, exit worker
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Check if we should stop
|
||||
if state.should_stop.load(Ordering::SeqCst) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Worker is exiting
|
||||
state.active_workers.fetch_sub(1, Ordering::SeqCst);
|
||||
}
|
||||
|
||||
// Scaling controller that monitors queue and spawns workers
|
||||
pub async fn run_scaler(
|
||||
state: Arc<SharedState>,
|
||||
task_count: usize,
|
||||
max_workers: usize,
|
||||
timeout_seconds: u64,
|
||||
) {
|
||||
let mut worker_count = 0;
|
||||
|
||||
loop {
|
||||
sleep(Duration::from_millis(100)).await;
|
||||
|
||||
let active = state.active_workers.load(Ordering::SeqCst);
|
||||
let completed = state.completed_tasks.load(Ordering::SeqCst);
|
||||
let pending = task_count.saturating_sub(completed);
|
||||
|
||||
// Simple scaling logic: spawn worker if many pending tasks and under limit
|
||||
if pending > active * 2 && active < max_workers && worker_count < max_workers {
|
||||
let _handle = spawn_worker(state.clone(), worker_count, timeout_seconds);
|
||||
worker_count += 1;
|
||||
}
|
||||
|
||||
// If all tasks completed, signal stop
|
||||
if completed >= task_count {
|
||||
state.should_stop.store(true, Ordering::SeqCst);
|
||||
break;
|
||||
}
|
||||
|
||||
// If no active workers and tasks remaining, spawn one
|
||||
if active == 0 && pending > 0 {
|
||||
let _handle = spawn_worker(state.clone(), worker_count, timeout_seconds);
|
||||
worker_count += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -5,7 +5,7 @@ use std::collections::HashMap;
|
||||
use crate::{
|
||||
agents::{
|
||||
recipe_tools::sub_recipe_tools::{
|
||||
create_sub_recipe_tool, run_sub_recipe, SUB_RECIPE_TOOL_NAME_PREFIX,
|
||||
create_sub_recipe_task, create_sub_recipe_task_tool, SUB_RECIPE_TASK_TOOL_NAME_PREFIX,
|
||||
},
|
||||
tool_execution::ToolCallResult,
|
||||
},
|
||||
@@ -34,12 +34,18 @@ impl SubRecipeManager {
|
||||
|
||||
pub fn add_sub_recipe_tools(&mut self, sub_recipes_to_add: Vec<SubRecipe>) {
|
||||
for sub_recipe in sub_recipes_to_add {
|
||||
// let sub_recipe_key = format!(
|
||||
// "{}_{}",
|
||||
// SUB_RECIPE_TOOL_NAME_PREFIX,
|
||||
// sub_recipe.name.clone()
|
||||
// );
|
||||
// let tool = create_sub_recipe_tool(&sub_recipe);
|
||||
let sub_recipe_key = format!(
|
||||
"{}_{}",
|
||||
SUB_RECIPE_TOOL_NAME_PREFIX,
|
||||
SUB_RECIPE_TASK_TOOL_NAME_PREFIX,
|
||||
sub_recipe.name.clone()
|
||||
);
|
||||
let tool = create_sub_recipe_tool(&sub_recipe);
|
||||
let tool = create_sub_recipe_task_tool(&sub_recipe);
|
||||
self.sub_recipe_tools.insert(sub_recipe_key.clone(), tool);
|
||||
self.sub_recipes.insert(sub_recipe_key.clone(), sub_recipe);
|
||||
}
|
||||
@@ -61,6 +67,31 @@ impl SubRecipeManager {
|
||||
}
|
||||
}
|
||||
|
||||
// async fn call_sub_recipe_tool(
|
||||
// &self,
|
||||
// tool_name: &str,
|
||||
// params: Value,
|
||||
// ) -> Result<Vec<Content>, ToolError> {
|
||||
// let sub_recipe = self.sub_recipes.get(tool_name).ok_or_else(|| {
|
||||
// let sub_recipe_name = tool_name
|
||||
// .strip_prefix(SUB_RECIPE_TOOL_NAME_PREFIX)
|
||||
// .and_then(|s| s.strip_prefix("_"))
|
||||
// .ok_or_else(|| {
|
||||
// ToolError::InvalidParameters(format!(
|
||||
// "Invalid sub-recipe tool name format: {}",
|
||||
// tool_name
|
||||
// ))
|
||||
// })
|
||||
// .unwrap();
|
||||
|
||||
// ToolError::InvalidParameters(format!("Sub-recipe '{}' not found", sub_recipe_name))
|
||||
// })?;
|
||||
|
||||
// let output = run_sub_recipe(sub_recipe, params).await.map_err(|e| {
|
||||
// ToolError::ExecutionError(format!("Sub-recipe execution failed: {}", e))
|
||||
// })?;
|
||||
// Ok(vec![Content::text(output)])
|
||||
// }
|
||||
async fn call_sub_recipe_tool(
|
||||
&self,
|
||||
tool_name: &str,
|
||||
@@ -68,7 +99,7 @@ impl SubRecipeManager {
|
||||
) -> Result<Vec<Content>, ToolError> {
|
||||
let sub_recipe = self.sub_recipes.get(tool_name).ok_or_else(|| {
|
||||
let sub_recipe_name = tool_name
|
||||
.strip_prefix(SUB_RECIPE_TOOL_NAME_PREFIX)
|
||||
.strip_prefix(SUB_RECIPE_TASK_TOOL_NAME_PREFIX)
|
||||
.and_then(|s| s.strip_prefix("_"))
|
||||
.ok_or_else(|| {
|
||||
ToolError::InvalidParameters(format!(
|
||||
@@ -81,9 +112,11 @@ impl SubRecipeManager {
|
||||
ToolError::InvalidParameters(format!("Sub-recipe '{}' not found", sub_recipe_name))
|
||||
})?;
|
||||
|
||||
let output = run_sub_recipe(sub_recipe, params).await.map_err(|e| {
|
||||
ToolError::ExecutionError(format!("Sub-recipe execution failed: {}", e))
|
||||
})?;
|
||||
let output = create_sub_recipe_task(sub_recipe, params)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
ToolError::ExecutionError(format!("Sub-recipe execution failed: {}", e))
|
||||
})?;
|
||||
Ok(vec![Content::text(output)])
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user