diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index 3adee023..c52807d5 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -837,24 +837,6 @@ impl Agent { let num_tool_requests = frontend_requests.len() + remaining_requests.len(); if num_tool_requests == 0 { - if let Some(final_output_tool) = self.final_output_tool.lock().await.as_ref() { - if final_output_tool.final_output.is_none() { - tracing::warn!("Final output tool has not been called yet. Continuing agent loop."); - let message = Message::assistant().with_text(FINAL_OUTPUT_CONTINUATION_MESSAGE); - messages.push(message.clone()); - yield AgentEvent::Message(message); - continue; - } else { - let message = Message::assistant().with_text(final_output_tool.final_output.clone().unwrap()); - messages.push(message.clone()); - yield AgentEvent::Message(message); - // Set added_message to true and continue to end the current iteration - added_message = true; - push_message(&mut messages, response); - continue; - } - } - // If there's no final output tool and no tool requests, continue the loop continue; } @@ -1039,10 +1021,14 @@ impl Agent { if let Some(final_output_tool) = self.final_output_tool.lock().await.as_ref() { if final_output_tool.final_output.is_none() { tracing::warn!("Final output tool has not been called yet. 
Continuing agent loop."); - yield AgentEvent::Message(Message::user().with_text(FINAL_OUTPUT_CONTINUATION_MESSAGE)); + let message = Message::user().with_text(FINAL_OUTPUT_CONTINUATION_MESSAGE); + messages.push(message.clone()); + yield AgentEvent::Message(message); continue; } else { - yield AgentEvent::Message(Message::assistant().with_text(final_output_tool.final_output.clone().unwrap())); + let message = Message::assistant().with_text(final_output_tool.final_output.clone().unwrap()); + messages.push(message.clone()); + yield AgentEvent::Message(message); } } break; diff --git a/crates/goose/src/agents/final_output_tool.rs b/crates/goose/src/agents/final_output_tool.rs index 3ada87c8..7059feb0 100644 --- a/crates/goose/src/agents/final_output_tool.rs +++ b/crates/goose/src/agents/final_output_tool.rs @@ -9,7 +9,7 @@ use serde_json::Value; pub const FINAL_OUTPUT_TOOL_NAME: &str = "recipe__final_output"; pub const FINAL_OUTPUT_CONTINUATION_MESSAGE: &str = - "I see I MUST call the `final_output` tool NOW with the final output for the user."; + "You MUST call the `final_output` tool NOW with the final output for the user."; pub struct FinalOutputTool { pub response: Response, diff --git a/crates/goose/tests/agent.rs b/crates/goose/tests/agent.rs index 3f7cff0c..837a433c 100644 --- a/crates/goose/tests/agent.rs +++ b/crates/goose/tests/agent.rs @@ -535,7 +535,11 @@ mod schedule_tool_tests { #[cfg(test)] mod final_output_tool_tests { use super::*; - use goose::agents::final_output_tool::FINAL_OUTPUT_TOOL_NAME; + use futures::stream; + use goose::agents::final_output_tool::{ + FINAL_OUTPUT_CONTINUATION_MESSAGE, FINAL_OUTPUT_TOOL_NAME, + }; + use goose::providers::base::MessageStream; use goose::recipe::Response; #[tokio::test] @@ -637,6 +641,124 @@ mod final_output_tool_tests { Ok(()) } + + #[tokio::test] + async fn test_when_final_output_not_called_in_reply() -> Result<()> { + use async_trait::async_trait; + use goose::model::ModelConfig; + use 
goose::providers::base::{Provider, ProviderUsage}; + use goose::providers::errors::ProviderError; + use mcp_core::tool::Tool; + + #[derive(Clone)] + struct MockProvider { + model_config: ModelConfig, + } + + #[async_trait] + impl Provider for MockProvider { + fn metadata() -> goose::providers::base::ProviderMetadata { + goose::providers::base::ProviderMetadata::empty() + } + + fn get_model_config(&self) -> ModelConfig { + self.model_config.clone() + } + + fn supports_streaming(&self) -> bool { + true + } + + async fn stream( + &self, + _system: &str, + _messages: &[Message], + _tools: &[Tool], + ) -> Result<MessageStream, ProviderError> { + let deltas = vec![ + Ok((Some(Message::assistant().with_text("Hello")), None)), + Ok((Some(Message::assistant().with_text("Hi!")), None)), + Ok(( + Some(Message::assistant().with_text("What is the final output?")), + None, + )), + ]; + + let stream = stream::iter(deltas.into_iter()); + Ok(Box::pin(stream)) + } + + async fn complete( + &self, + _system: &str, + _messages: &[Message], + _tools: &[Tool], + ) -> Result<(Message, ProviderUsage), ProviderError> { + Err(ProviderError::NotImplemented("Not implemented".to_string())) + } + } + + let agent = Agent::new(); + + let model_config = ModelConfig::new("test-model".to_string()); + let mock_provider = Arc::new(MockProvider { model_config }); + agent.update_provider(mock_provider).await?; + + let response = Response { + json_schema: Some(serde_json::json!({ + "type": "object", + "properties": { + "result": {"type": "string"} + }, + "required": ["result"] + })), + }; + agent.add_final_output_tool(response).await; + + // Simulate the reply stream being called.
+ let reply_stream = agent.reply(&vec![], None).await?; + tokio::pin!(reply_stream); + + let mut responses = Vec::new(); + let mut count = 0; + while let Some(response_result) = reply_stream.next().await { + match response_result { + Ok(AgentEvent::Message(response)) => { + responses.push(response); + count += 1; + if count >= 4 { + // Limit to 4 messages to avoid infinite loop due to mock provider + break; + } + } + Ok(_) => {} + Err(e) => return Err(e), + } + } + + assert!(!responses.is_empty(), "Should have received responses"); + println!("Responses: {:?}", responses); + let last_message = responses.last().unwrap(); + + // Check that the first 3 messages do not have FINAL_OUTPUT_CONTINUATION_MESSAGE + for (i, response) in responses.iter().take(3).enumerate() { + let message_text = response.as_concat_text(); + assert_ne!( + message_text, + FINAL_OUTPUT_CONTINUATION_MESSAGE, + "Message {} should not be the continuation message, got: '{}'", + i + 1, + message_text + ); + } + + // Check that the last message after the llm stream is the message directing the agent to continue + assert_eq!(last_message.role, mcp_core::role::Role::User); + let message_text = last_message.as_concat_text(); + assert_eq!(message_text, FINAL_OUTPUT_CONTINUATION_MESSAGE); + + Ok(()) + } } #[cfg(test)]