diff --git a/Cargo.lock b/Cargo.lock index cab8ca6c..9f368d7b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -38,6 +38,7 @@ dependencies = [ "const-random", "getrandom 0.2.15", "once_cell", + "serde", "version_check", "zerocopy", ] @@ -1377,6 +1378,12 @@ dependencies = [ "generic-array", ] +[[package]] +name = "borrow-or-share" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3eeab4423108c5d7c744f4d234de88d18d636100093ae04caf4825134b9c3a32" + [[package]] name = "brotli" version = "7.0.0" @@ -2769,6 +2776,15 @@ version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" +[[package]] +name = "email_address" +version = "0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e079f19b08ca6239f47f8ba8509c11cf3ea30095831f7fed61441475edd8c449" +dependencies = [ + "serde", +] + [[package]] name = "encode_unicode" version = "1.0.0" @@ -2981,6 +2997,17 @@ dependencies = [ "miniz_oxide", ] +[[package]] +name = "fluent-uri" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1918b65d96df47d3591bed19c5cca17e3fa5d0707318e4b5ef2eae01764df7e5" +dependencies = [ + "borrow-or-share", + "ref-cast", + "serde", +] + [[package]] name = "fnv" version = "1.0.7" @@ -3044,6 +3071,16 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "fraction" +version = "0.15.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f158e3ff0a1b334408dc9fb811cd99b446986f4d8b741bb08f9df1604085ae7" +dependencies = [ + "lazy_static", + "num", +] + [[package]] name = "fragile" version = "2.0.0" @@ -3423,6 +3460,7 @@ dependencies = [ "futures-util", "include_dir", "indoc 2.0.6", + "jsonschema", "jsonwebtoken", "keyring", "lancedb", @@ -4512,6 +4550,33 @@ dependencies = [ "serde", ] +[[package]] +name = "jsonschema" +version = "0.30.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1b46a0365a611fbf1d2143104dcf910aada96fafd295bab16c60b802bf6fa1d" +dependencies = [ + "ahash", + "base64 0.22.1", + "bytecount", + "email_address", + "fancy-regex", + "fraction", + "idna", + "itoa", + "num-cmp", + "num-traits", + "once_cell", + "percent-encoding", + "referencing", + "regex", + "regex-syntax 0.8.5", + "reqwest 0.12.12", + "serde", + "serde_json", + "uuid-simd", +] + [[package]] name = "jsonwebtoken" version = "9.3.1" @@ -5695,6 +5760,12 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-cmp" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "63335b2e2c34fae2fb0aa2cecfd9f0832a1e24b3b32ecec612c3426d46dc8aaa" + [[package]] name = "num-complex" version = "0.4.6" @@ -6835,6 +6906,40 @@ dependencies = [ "thiserror 2.0.12", ] +[[package]] +name = "ref-cast" +version = "1.0.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4a0ae411dbe946a674d89546582cea4ba2bb8defac896622d6496f14c23ba5cf" +dependencies = [ + "ref-cast-impl", +] + +[[package]] +name = "ref-cast-impl" +version = "1.0.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1165225c21bff1f3bbce98f5a1f889949bc902d3575308cc7b0de30b4f6d27c7" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.99", +] + +[[package]] +name = "referencing" +version = "0.30.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8eff4fa778b5c2a57e85c5f2fe3a709c52f0e60d23146e2151cbef5893f420e" +dependencies = [ + "ahash", + "fluent-uri", + "once_cell", + "parking_lot", + "percent-encoding", + "serde_json", +] + [[package]] name = "regex" version = "1.11.1" @@ -9037,6 +9142,17 @@ dependencies = [ "serde", ] +[[package]] +name = "uuid-simd" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23b082222b4f6619906941c17eb2297fff4c2fb96cb60164170522942a200bd8" 
+dependencies = [ + "outref", + "uuid", + "vsimd", +] + [[package]] name = "v_frame" version = "0.3.8" diff --git a/crates/goose-cli/src/cli.rs b/crates/goose-cli/src/cli.rs index f7fe62c2..dd98f3ec 100644 --- a/crates/goose-cli/src/cli.rs +++ b/crates/goose-cli/src/cli.rs @@ -687,6 +687,7 @@ pub async fn cli() -> Result<()> { interactive: true, quiet: false, sub_recipes: None, + final_output_response: None, }) .await; setup_logging( @@ -736,7 +737,7 @@ pub async fn cli() -> Result<()> { quiet, additional_sub_recipes, }) => { - let (input_config, session_settings, sub_recipes) = match ( + let (input_config, session_settings, sub_recipes, final_output_response) = match ( instructions, input_text, recipe, @@ -755,6 +756,7 @@ pub async fn cli() -> Result<()> { }, None, None, + None, ) } (Some(file), _, _) => { @@ -773,6 +775,7 @@ pub async fn cli() -> Result<()> { }, None, None, + None, ) } (_, Some(text), _) => ( @@ -783,6 +786,7 @@ pub async fn cli() -> Result<()> { }, None, None, + None, ), (_, _, Some(recipe_name)) => { if explain { @@ -822,6 +826,7 @@ pub async fn cli() -> Result<()> { interactive, // Use the interactive flag from the Run command quiet, sub_recipes, + final_output_response, }) .await; @@ -941,6 +946,7 @@ pub async fn cli() -> Result<()> { interactive: true, // Default case is always interactive quiet: false, sub_recipes: None, + final_output_response: None, }) .await; setup_logging( diff --git a/crates/goose-cli/src/commands/bench.rs b/crates/goose-cli/src/commands/bench.rs index 83d48562..ee64588b 100644 --- a/crates/goose-cli/src/commands/bench.rs +++ b/crates/goose-cli/src/commands/bench.rs @@ -47,6 +47,7 @@ pub async fn agent_generator( scheduled_job_id: None, quiet: false, sub_recipes: None, + final_output_response: None, }) .await; diff --git a/crates/goose-cli/src/recipes/extract_from_cli.rs b/crates/goose-cli/src/recipes/extract_from_cli.rs index eb853f99..e84fd564 100644 --- a/crates/goose-cli/src/recipes/extract_from_cli.rs +++ 
b/crates/goose-cli/src/recipes/extract_from_cli.rs @@ -1,15 +1,21 @@ use std::path::PathBuf; use anyhow::Result; -use goose::recipe::SubRecipe; +use goose::recipe::{Response, SubRecipe}; use crate::{cli::InputConfig, recipes::recipe::load_recipe_as_template, session::SessionSettings}; +#[allow(clippy::type_complexity)] pub fn extract_recipe_info_from_cli( recipe_name: String, params: Vec<(String, String)>, additional_sub_recipes: Vec, -) -> Result<(InputConfig, Option, Option>)> { +) -> Result<( + InputConfig, + Option, + Option>, + Option, +)> { let recipe = load_recipe_as_template(&recipe_name, params).unwrap_or_else(|err| { eprintln!("{}: {}", console::style("Error").red().bold(), err); std::process::exit(1); @@ -43,6 +49,7 @@ pub fn extract_recipe_info_from_cli( temperature: s.temperature, }), Some(all_sub_recipes), + recipe.response, )) } @@ -69,7 +76,7 @@ mod tests { let params = vec![("name".to_string(), "my_value".to_string())]; let recipe_name = recipe_path.to_str().unwrap().to_string(); - let (input_config, settings, sub_recipes) = + let (input_config, settings, sub_recipes, response) = extract_recipe_info_from_cli(recipe_name, params, Vec::new()).unwrap(); assert_eq!(input_config.contents, Some("test_prompt".to_string())); @@ -91,6 +98,17 @@ mod tests { assert_eq!(sub_recipes[0].path, "existing_sub_recipe.yaml".to_string()); assert_eq!(sub_recipes[0].name, "existing_sub_recipe".to_string()); assert!(sub_recipes[0].values.is_none()); + assert!(response.is_some()); + let response = response.unwrap(); + assert_eq!( + response.json_schema, + Some(serde_json::json!({ + "type": "object", + "properties": { + "result": {"type": "string"} + } + })) + ); } #[test] @@ -103,7 +121,7 @@ mod tests { "another/sub_recipe2.yaml".to_string(), ]; - let (input_config, settings, sub_recipes) = + let (input_config, settings, sub_recipes, response) = extract_recipe_info_from_cli(recipe_name, params, additional_sub_recipes).unwrap(); assert_eq!(input_config.contents, 
Some("test_prompt".to_string())); @@ -131,6 +149,17 @@ mod tests { assert_eq!(sub_recipes[2].path, "another/sub_recipe2.yaml".to_string()); assert_eq!(sub_recipes[2].name, "sub_recipe2".to_string()); assert!(sub_recipes[2].values.is_none()); + assert!(response.is_some()); + let response = response.unwrap(); + assert_eq!( + response.json_schema, + Some(serde_json::json!({ + "type": "object", + "properties": { + "result": {"type": "string"} + } + })) + ); } fn create_recipe() -> (TempDir, PathBuf) { @@ -151,6 +180,12 @@ settings: sub_recipes: - path: existing_sub_recipe.yaml name: existing_sub_recipe +response: + json_schema: + type: object + properties: + result: + type: string "#; let temp_dir = tempfile::tempdir().unwrap(); let recipe_path: std::path::PathBuf = temp_dir.path().join("test_recipe.yaml"); diff --git a/crates/goose-cli/src/session/builder.rs b/crates/goose-cli/src/session/builder.rs index 06f81650..9aff9999 100644 --- a/crates/goose-cli/src/session/builder.rs +++ b/crates/goose-cli/src/session/builder.rs @@ -3,7 +3,7 @@ use goose::agents::extension::ExtensionError; use goose::agents::Agent; use goose::config::{Config, ExtensionConfig, ExtensionConfigManager}; use goose::providers::create; -use goose::recipe::SubRecipe; +use goose::recipe::{Response, SubRecipe}; use goose::session; use goose::session::Identifier; use mcp_client::transport::Error as McpClientError; @@ -49,6 +49,8 @@ pub struct SessionBuilderConfig { pub quiet: bool, /// Sub-recipes to add to the session pub sub_recipes: Option>, + /// Final output expected response + pub final_output_response: Option, } /// Offers to help debug an extension failure by creating a minimal debugging session @@ -180,6 +182,11 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> Session { if let Some(sub_recipes) = session_config.sub_recipes { agent.add_sub_recipes(sub_recipes).await; } + + if let Some(final_output_response) = session_config.final_output_response { + 
agent.add_final_output_tool(final_output_response).await; + } + let new_provider = match create(&provider_name, model_config) { Ok(provider) => provider, Err(e) => { @@ -520,6 +527,7 @@ mod tests { interactive: true, quiet: false, sub_recipes: None, + final_output_response: None, }; assert_eq!(config.extensions.len(), 1); @@ -549,6 +557,7 @@ mod tests { assert!(config.scheduled_job_id.is_none()); assert!(!config.interactive); assert!(!config.quiet); + assert!(config.final_output_response.is_none()); } #[tokio::test] diff --git a/crates/goose/Cargo.toml b/crates/goose/Cargo.toml index 4a91d428..24623116 100644 --- a/crates/goose/Cargo.toml +++ b/crates/goose/Cargo.toml @@ -38,6 +38,7 @@ tokio = { version = "1.43", features = ["full"] } serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" serde_urlencoded = "0.7" +jsonschema = "0.30.0" uuid = { version = "1.0", features = ["v4"] } regex = "1.11.1" async-trait = "0.1" diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index c2a5dbdb..e87083e2 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -10,6 +10,7 @@ use futures_util::stream; use futures_util::stream::StreamExt; use mcp_core::protocol::JsonRpcMessage; +use crate::agents::final_output_tool::{FINAL_OUTPUT_CONTINUATION_MESSAGE, FINAL_OUTPUT_TOOL_NAME}; use crate::agents::sub_recipe_manager::SubRecipeManager; use crate::config::{Config, ExtensionConfigManager, PermissionManager}; use crate::message::Message; @@ -17,7 +18,7 @@ use crate::permission::permission_judge::check_tool_permissions; use crate::permission::PermissionConfirmation; use crate::providers::base::Provider; use crate::providers::errors::ProviderError; -use crate::recipe::{Author, Recipe, Settings, SubRecipe}; +use crate::recipe::{Author, Recipe, Response, Settings, SubRecipe}; use crate::scheduler_trait::SchedulerTrait; use crate::tool_monitor::{ToolCall, ToolMonitor}; use regex::Regex; @@ -47,6 +48,7 @@ use 
mcp_core::{ use crate::agents::subagent_tools::SUBAGENT_RUN_TASK_TOOL_NAME; +use super::final_output_tool::FinalOutputTool; use super::platform_tools; use super::router_tools; use super::subagent_manager::SubAgentManager; @@ -58,6 +60,7 @@ pub struct Agent { pub(super) provider: Mutex>>, pub(super) extension_manager: RwLock, pub(super) sub_recipe_manager: Mutex, + pub(super) final_output_tool: Mutex>, pub(super) frontend_tools: Mutex>, pub(super) frontend_instructions: Mutex>, pub(super) prompt_manager: Mutex, @@ -131,6 +134,7 @@ impl Agent { provider: Mutex::new(None), extension_manager: RwLock::new(ExtensionManager::new()), sub_recipe_manager: Mutex::new(SubRecipeManager::new()), + final_output_tool: Mutex::new(None), frontend_tools: Mutex::new(HashMap::new()), frontend_instructions: Mutex::new(None), prompt_manager: Mutex::new(PromptManager::new()), @@ -205,6 +209,14 @@ impl Agent { Ok(tools) } + pub async fn add_final_output_tool(&self, response: Response) { + let mut final_output_tool = self.final_output_tool.lock().await; + let created_final_output_tool = FinalOutputTool::new(response); + let final_output_system_prompt = created_final_output_tool.system_prompt(); + *final_output_tool = Some(created_final_output_tool); + self.extend_system_prompt(final_output_system_prompt).await; + } + pub async fn add_sub_recipes(&self, sub_recipes: Vec) { let mut sub_recipe_manager = self.sub_recipe_manager.lock().await; sub_recipe_manager.add_sub_recipe_tools(sub_recipes); @@ -258,6 +270,20 @@ impl Agent { return (request_id, Ok(ToolCallResult::from(result))); } + if tool_call.name == FINAL_OUTPUT_TOOL_NAME { + if let Some(final_output_tool) = self.final_output_tool.lock().await.as_mut() { + let result = final_output_tool.execute_tool_call(tool_call.clone()).await; + return (request_id, Ok(result)); + } else { + return ( + request_id, + Err(ToolError::ExecutionError( + "Final output tool not defined".to_string(), + )), + ); + } + } + let extension_manager = 
self.extension_manager.read().await; let sub_recipe_manager = self.sub_recipe_manager.lock().await; @@ -544,6 +570,10 @@ impl Agent { if extension_name.is_none() { let sub_recipe_manager = self.sub_recipe_manager.lock().await; prefixed_tools.extend(sub_recipe_manager.sub_recipe_tools.values().cloned()); + + if let Some(final_output_tool) = self.final_output_tool.lock().await.as_ref() { + prefixed_tools.push(final_output_tool.tool()); + } } prefixed_tools @@ -766,6 +796,15 @@ impl Agent { let num_tool_requests = frontend_requests.len() + remaining_requests.len(); if num_tool_requests == 0 { + if let Some(final_output_tool) = self.final_output_tool.lock().await.as_ref() { + if final_output_tool.final_output.is_none() { + tracing::warn!("Final output tool has not been called yet. Continuing agent loop."); + yield AgentEvent::Message(Message::user().with_text(FINAL_OUTPUT_CONTINUATION_MESSAGE)); + continue; + } else { + yield AgentEvent::Message(Message::assistant().with_text(final_output_tool.final_output.clone().unwrap())); + } + } break; } @@ -1260,3 +1299,43 @@ impl Agent { Ok(recipe) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::recipe::Response; + + #[tokio::test] + async fn test_add_final_output_tool() -> Result<()> { + let agent = Agent::new(); + + let response = Response { + json_schema: Some(serde_json::json!({ + "type": "object", + "properties": { + "result": {"type": "string"} + } + })), + }; + + agent.add_final_output_tool(response).await; + + let tools = agent.list_tools(None).await; + let final_output_tool = tools.iter().find(|tool| tool.name == "final_output"); + + assert!( + final_output_tool.is_some(), + "Final output tool should be present after adding" + ); + + let prompt_manager = agent.prompt_manager.lock().await; + let system_prompt = + prompt_manager.build_system_prompt(vec![], None, serde_json::Value::Null, None, None); + + let final_output_tool_ref = agent.final_output_tool.lock().await; + let final_output_tool_system_prompt = 
+ final_output_tool_ref.as_ref().unwrap().system_prompt(); + assert!(system_prompt.contains(&final_output_tool_system_prompt)); + Ok(()) + } +} diff --git a/crates/goose/src/agents/final_output_tool.rs b/crates/goose/src/agents/final_output_tool.rs new file mode 100644 index 00000000..9e272998 --- /dev/null +++ b/crates/goose/src/agents/final_output_tool.rs @@ -0,0 +1,261 @@ +use crate::agents::tool_execution::ToolCallResult; +use crate::recipe::Response; +use indoc::formatdoc; +use mcp_core::{ + tool::{Tool, ToolAnnotations}, + Content, ToolCall, ToolError, +}; +use serde_json::Value; + +pub const FINAL_OUTPUT_TOOL_NAME: &str = "final_output"; +pub const FINAL_OUTPUT_CONTINUATION_MESSAGE: &str = + "You MUST call the `final_output` tool with your final output for the user."; + +pub struct FinalOutputTool { + pub response: Response, + /// The final output collected for the user. It will be a single line string for easy script extraction from output. + pub final_output: Option, +} + +impl FinalOutputTool { + pub fn new(response: Response) -> Self { + if response.json_schema.is_none() { + panic!("Cannot create FinalOutputTool: json_schema is required"); + } + let schema = response.json_schema.as_ref().unwrap(); + + if let Some(obj) = schema.as_object() { + if obj.is_empty() { + panic!("Cannot create FinalOutputTool: empty json_schema is not allowed"); + } + } + + jsonschema::meta::validate(schema).unwrap(); + Self { + response, + final_output: None, + } + } + + pub fn tool(&self) -> Tool { + let instructions = formatdoc! {r#" + This tool collects the final output for a user and provides validation for structured JSON final output against a predefined schema. + + This tool MUST be used for the final output to the user. 
+ + Purpose: + - Collects the final output for a user + - Ensures that final outputs conform to the expected JSON structure + - Provides clear validation feedback when outputs don't match the schema + + Usage: + - Call the `final_output` tool with your JSON final output + + The expected JSON schema format is: + + {} + + When validation fails, you'll receive: + - Specific validation errors + - The expected format + "#, serde_json::to_string_pretty(self.response.json_schema.as_ref().unwrap()).unwrap()}; + + Tool::new( + FINAL_OUTPUT_TOOL_NAME.to_string(), + instructions, + self.response.json_schema.as_ref().unwrap().clone(), + Some(ToolAnnotations { + title: Some("Final Output".to_string()), + read_only_hint: false, + destructive_hint: false, + idempotent_hint: true, + open_world_hint: false, + }), + ) + } + + pub fn system_prompt(&self) -> String { + formatdoc! {r#" + # Final Output Instructions + + You MUST use the `final_output` tool to collect the final output for a user. + The final output MUST be a valid JSON object that matches the following expected schema: + + {} + + ---- + "#, serde_json::to_string_pretty(self.response.json_schema.as_ref().unwrap()).unwrap()} + } + + async fn validate_json_output(&self, output: &Value) -> Result { + let compiled_schema = + match jsonschema::validator_for(self.response.json_schema.as_ref().unwrap()) { + Ok(schema) => schema, + Err(e) => { + return Err(format!("Internal error: Failed to compile schema: {}", e)); + } + }; + + let validation_errors: Vec = compiled_schema + .iter_errors(output) + .map(|error| format!("- {}: {}", error.instance_path, error)) + .collect(); + + if validation_errors.is_empty() { + Ok(output.clone()) + } else { + Err(format!( + "Validation failed:\n{}\n\nExpected format:\n{}\n\nPlease correct your output to match the expected JSON schema and try again.", + validation_errors.join("\n"), + serde_json::to_string_pretty(self.response.json_schema.as_ref().unwrap()).unwrap_or_else(|_| "Invalid
schema".to_string()) + )) + } + } + + pub async fn execute_tool_call(&mut self, tool_call: ToolCall) -> ToolCallResult { + match tool_call.name.as_str() { + FINAL_OUTPUT_TOOL_NAME => { + let result = self.validate_json_output(&tool_call.arguments).await; + match result { + Ok(parsed_value) => { + self.final_output = Some(Self::parsed_final_output_string(parsed_value)); + ToolCallResult::from(Ok(vec![Content::text( + "Final output successfully collected.".to_string(), + )])) + } + Err(error) => ToolCallResult::from(Err(ToolError::InvalidParameters(error))), + } + } + _ => ToolCallResult::from(Err(ToolError::NotFound(format!( + "Unknown tool: {}", + tool_call.name + )))), + } + } + + // Formats the parsed JSON as a single line string so it's easy to extract from the output + fn parsed_final_output_string(parsed_json: Value) -> String { + serde_json::to_string(&parsed_json).unwrap() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::recipe::Response; + use serde_json::json; + + fn create_complex_test_schema() -> Value { + json!({ + "type": "object", + "properties": { + "user": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "number"} + }, + "required": ["name", "age"] + }, + "tags": { + "type": "array", + "items": {"type": "string"} + } + }, + "required": ["user", "tags"] + }) + } + + #[test] + #[should_panic(expected = "Cannot create FinalOutputTool: json_schema is required")] + fn test_new_with_missing_schema() { + let response = Response { json_schema: None }; + FinalOutputTool::new(response); + } + + #[test] + #[should_panic(expected = "Cannot create FinalOutputTool: empty json_schema is not allowed")] + fn test_new_with_empty_schema() { + let response = Response { + json_schema: Some(json!({})), + }; + FinalOutputTool::new(response); + } + + #[test] + #[should_panic] + fn test_new_with_invalid_schema() { + let response = Response { + json_schema: Some(json!({ + "type": "invalid_type", + "properties": { +
"message": { + "type": "unknown_type" + } + } + })), + }; + FinalOutputTool::new(response); + } + + #[tokio::test] + async fn test_execute_tool_call_schema_validation_failure() { + let response = Response { + json_schema: Some(json!({ + "type": "object", + "properties": { + "message": { + "type": "string" + }, + "count": { + "type": "number" + } + }, + "required": ["message", "count"] + })), + }; + + let mut tool = FinalOutputTool::new(response); + let tool_call = ToolCall { + name: FINAL_OUTPUT_TOOL_NAME.to_string(), + arguments: json!({ + "message": "Hello" // Missing required "count" field + }), + }; + + let result = tool.execute_tool_call(tool_call).await; + let tool_result = result.result.await; + assert!(tool_result.is_err()); + if let Err(error) = tool_result { + assert!(error.to_string().contains("Validation failed")); + } + } + + #[tokio::test] + async fn test_execute_tool_call_complex_valid_json() { + let response = Response { + json_schema: Some(create_complex_test_schema()), + }; + + let mut tool = FinalOutputTool::new(response); + let tool_call = ToolCall { + name: FINAL_OUTPUT_TOOL_NAME.to_string(), + arguments: json!({ + "user": { + "name": "John", + "age": 30 + }, + "tags": ["developer", "rust"] + }), + }; + + let result = tool.execute_tool_call(tool_call).await; + let tool_result = result.result.await; + assert!(tool_result.is_ok()); + assert!(tool.final_output.is_some()); + + let final_output = tool.final_output.unwrap(); + assert!(serde_json::from_str::(&final_output).is_ok()); + assert!(!final_output.contains('\n')); + } +} diff --git a/crates/goose/src/agents/mod.rs b/crates/goose/src/agents/mod.rs index 6b4a6e9f..098521c0 100644 --- a/crates/goose/src/agents/mod.rs +++ b/crates/goose/src/agents/mod.rs @@ -2,6 +2,7 @@ mod agent; mod context; pub mod extension; pub mod extension_manager; +pub mod final_output_tool; mod large_response_handler; pub mod platform_tools; pub mod prompt_manager; diff --git a/crates/goose/src/recipe/mod.rs 
b/crates/goose/src/recipe/mod.rs index a8604ff7..55ff144f 100644 --- a/crates/goose/src/recipe/mod.rs +++ b/crates/goose/src/recipe/mod.rs @@ -29,6 +29,7 @@ fn default_version() -> String { /// * `activities` - Activity labels that appear when loading the Recipe /// * `author` - Information about the Recipe's creator and metadata /// * `parameters` - Additional parameters for the Recipe +/// * `response` - Response configuration including JSON schema validation /// /// # Example /// @@ -56,6 +57,8 @@ fn default_version() -> String { /// author: None, /// settings: None, /// parameters: None, +/// response: None, +/// sub_recipes: None, /// }; /// #[derive(Serialize, Deserialize, Debug, Clone)] @@ -94,6 +97,9 @@ pub struct Recipe { #[serde(skip_serializing_if = "Option::is_none")] pub parameters: Option>, // any additional parameters for the recipe + #[serde(skip_serializing_if = "Option::is_none")] + pub response: Option, // response configuration including JSON schema + #[serde(skip_serializing_if = "Option::is_none")] pub sub_recipes: Option>, // sub-recipes for the recipe } @@ -119,6 +125,12 @@ pub struct Settings { pub temperature: Option, } +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct Response { + #[serde(skip_serializing_if = "Option::is_none")] + pub json_schema: Option, +} + #[derive(Serialize, Deserialize, Debug, Clone)] pub struct SubRecipe { pub name: String, @@ -216,6 +228,7 @@ pub struct RecipeBuilder { activities: Option>, author: Option, parameters: Option>, + response: Option, sub_recipes: Option>, } @@ -247,6 +260,7 @@ impl Recipe { activities: None, author: None, parameters: None, + response: None, sub_recipes: None, } } @@ -327,6 +341,12 @@ impl RecipeBuilder { self.parameters = Some(parameters); self } + + pub fn response(mut self, response: Response) -> Self { + self.response = Some(response); + self + } + pub fn sub_recipes(mut self, sub_recipes: Vec) -> Self { self.sub_recipes = Some(sub_recipes); self @@ -355,6 +375,7 @@ impl 
RecipeBuilder { activities: self.activities, author: self.author, parameters: self.parameters, + response: self.response, sub_recipes: self.sub_recipes, }) } @@ -390,6 +411,20 @@ mod tests { "description": "A test parameter" } ], + "response": { + "json_schema": { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "age": { + "type": "number" + } + }, + "required": ["name"] + } + }, "sub_recipes": [ { "name": "test_sub_recipe", @@ -425,6 +460,16 @@ mod tests { RecipeParameterRequirement::Required )); + assert!(recipe.response.is_some()); + let response = recipe.response.unwrap(); + assert!(response.json_schema.is_some()); + let json_schema = response.json_schema.unwrap(); + assert_eq!(json_schema["type"], "object"); + assert!(json_schema["properties"].is_object()); + assert_eq!(json_schema["properties"]["name"]["type"], "string"); + assert_eq!(json_schema["properties"]["age"]["type"], "number"); + assert_eq!(json_schema["required"], serde_json::json!(["name"])); + assert!(recipe.sub_recipes.is_some()); let sub_recipes = recipe.sub_recipes.unwrap(); assert_eq!(sub_recipes.len(), 1); @@ -458,6 +503,16 @@ parameters: input_type: string requirement: required description: A test parameter +response: + json_schema: + type: object + properties: + name: + type: string + age: + type: number + required: + - name sub_recipes: - name: test_sub_recipe path: test_sub_recipe.yaml @@ -488,6 +543,16 @@ sub_recipes: RecipeParameterRequirement::Required )); + assert!(recipe.response.is_some()); + let response = recipe.response.unwrap(); + assert!(response.json_schema.is_some()); + let json_schema = response.json_schema.unwrap(); + assert_eq!(json_schema["type"], "object"); + assert!(json_schema["properties"].is_object()); + assert_eq!(json_schema["properties"]["name"]["type"], "string"); + assert_eq!(json_schema["properties"]["age"]["type"], "number"); + assert_eq!(json_schema["required"], serde_json::json!(["name"])); + assert!(recipe.sub_recipes.is_some()); 
let sub_recipes = recipe.sub_recipes.unwrap(); assert_eq!(sub_recipes.len(), 1); diff --git a/crates/goose/src/scheduler.rs b/crates/goose/src/scheduler.rs index 8c256dcd..e73e618e 100644 --- a/crates/goose/src/scheduler.rs +++ b/crates/goose/src/scheduler.rs @@ -1422,6 +1422,7 @@ mod tests { author: None, parameters: None, settings: None, + response: None, sub_recipes: None, }; let mut recipe_file = File::create(&recipe_filename)?; diff --git a/crates/goose/tests/agent.rs b/crates/goose/tests/agent.rs index 18260129..f03473b0 100644 --- a/crates/goose/tests/agent.rs +++ b/crates/goose/tests/agent.rs @@ -531,3 +531,110 @@ mod schedule_tool_tests { } } } + +#[cfg(test)] +mod final_output_tool_tests { + use super::*; + use goose::agents::final_output_tool::FINAL_OUTPUT_TOOL_NAME; + use goose::recipe::Response; + + #[tokio::test] + async fn test_final_output_assistant_message_in_reply() -> Result<()> { + use async_trait::async_trait; + use goose::model::ModelConfig; + use goose::providers::base::{Provider, ProviderUsage, Usage}; + use goose::providers::errors::ProviderError; + use mcp_core::tool::Tool; + + #[derive(Clone)] + struct MockProvider { + model_config: ModelConfig, + } + + #[async_trait] + impl Provider for MockProvider { + fn metadata() -> goose::providers::base::ProviderMetadata { + goose::providers::base::ProviderMetadata::empty() + } + + fn get_model_config(&self) -> ModelConfig { + self.model_config.clone() + } + + async fn complete( + &self, + _system: &str, + _messages: &[Message], + _tools: &[Tool], + ) -> anyhow::Result<(Message, ProviderUsage), ProviderError> { + Ok(( + Message::assistant().with_text("Task completed."), + ProviderUsage::new("mock".to_string(), Usage::default()), + )) + } + } + + let agent = Agent::new(); + + let model_config = ModelConfig::new("test-model".to_string()); + let mock_provider = Arc::new(MockProvider { model_config }); + agent.update_provider(mock_provider).await?; + + let response = Response { + json_schema: 
Some(serde_json::json!({ + "type": "object", + "properties": { + "result": {"type": "string"} + }, + "required": ["result"] + })), + }; + agent.add_final_output_tool(response).await; + + // Simulate a final output tool call occurring. + let tool_call = mcp_core::tool::ToolCall::new( + FINAL_OUTPUT_TOOL_NAME, + serde_json::json!({ + "result": "Test output" + }), + ); + let (_, result) = agent + .dispatch_tool_call(tool_call, "request_id".to_string()) + .await; + + assert!(result.is_ok(), "Tool call should succeed"); + let final_result = result.unwrap().result.await; + assert!(final_result.is_ok(), "Tool execution should succeed"); + + let content = final_result.unwrap(); + let text = content.first().unwrap().as_text().unwrap(); + assert!( + text.contains("Final output successfully collected."), + "Tool result missing expected content: {}", + text + ); + + // Simulate the reply stream continuing after the final output tool call. + let reply_stream = agent.reply(&vec![], None).await?; + tokio::pin!(reply_stream); + + let mut responses = Vec::new(); + while let Some(response_result) = reply_stream.next().await { + match response_result { + Ok(AgentEvent::Message(response)) => responses.push(response), + Ok(_) => {} + Err(e) => return Err(e), + } + } + + assert!(!responses.is_empty(), "Should have received responses"); + let last_message = responses.last().unwrap(); + + // Check that the last message is an assistant message with our final output + assert_eq!(last_message.role, mcp_core::role::Role::Assistant); + let message_text = last_message.as_concat_text(); + assert_eq!(message_text, r#"{"result":"Test output"}"#); + + Ok(()) + } +} diff --git a/documentation/docs/guides/recipes/recipe-reference.md b/documentation/docs/guides/recipes/recipe-reference.md index 4c0f79c3..f8da8175 100644 --- a/documentation/docs/guides/recipes/recipe-reference.md +++ b/documentation/docs/guides/recipes/recipe-reference.md @@ -36,6 +36,7 @@ After creating recipe files, you can use 
[`goose` CLI commands](/docs/guides/goo | `prompt` | String | A template prompt that can include parameter substitutions; required in headless (non-interactive) mode | | `parameters` | Array | List of parameter definitions | | `extensions` | Array | List of extension configurations | +| `response` | Object | Configuration for structured output validation | ## Parameters @@ -106,6 +107,54 @@ extensions: description: "For searching logs using Presidio" ``` +## Structured Output with `response` + +The `response` field enables recipes to enforce a final structured JSON output from Goose. When you specify a `json_schema`, Goose will: + +1. **Validate the output**: Validate the output JSON against your JSON schema using basic JSON Schema validation +2. **Return a final structured output**: Ensure the agent's final output is a response matching your JSON structure + +This **enables automation** by returning consistent, parseable results for scripts and workflows. + +### Basic Structure + +```yaml +response: + json_schema: + type: object + properties: + # Define your fields here, with their type and description + required: + # List required field names +``` + +### Simple Example + +```yaml +version: "1.0.0" +title: "Task Summary" +description: "Summarize completed tasks" +prompt: "Summarize the tasks you completed" +response: + json_schema: + type: object + properties: + summary: + type: string + description: "Brief summary of work done" + tasks_completed: + type: number + description: "Number of tasks finished" + next_steps: + type: array + items: + type: string + description: "Recommended next actions" + required: + - summary + - tasks_completed +``` + ## Template Support Recipes support Jinja-style template syntax in both `instructions` and `prompt` fields: @@ -164,6 +213,22 @@ extensions: timeout: 300 bundled: true description: "Query codesearch directly from goose" + +response: + json_schema: + type: object + properties: + result: + type: string + description: "The main
result of the task" + details: + type: array + items: + type: string + description: "Additional details of steps taken" + required: + - result + - details ``` ## Template Inheritance