fix: Fixes structured output after streaming change to agent (#3448)

This commit is contained in:
Jarrod Sibbison
2025-07-16 20:46:22 +10:00
committed by GitHub
parent 77ea27f5f5
commit eda810040d
3 changed files with 130 additions and 22 deletions

View File

@@ -837,24 +837,6 @@ impl Agent {
let num_tool_requests = frontend_requests.len() + remaining_requests.len(); let num_tool_requests = frontend_requests.len() + remaining_requests.len();
if num_tool_requests == 0 { if num_tool_requests == 0 {
if let Some(final_output_tool) = self.final_output_tool.lock().await.as_ref() {
if final_output_tool.final_output.is_none() {
tracing::warn!("Final output tool has not been called yet. Continuing agent loop.");
let message = Message::assistant().with_text(FINAL_OUTPUT_CONTINUATION_MESSAGE);
messages.push(message.clone());
yield AgentEvent::Message(message);
continue;
} else {
let message = Message::assistant().with_text(final_output_tool.final_output.clone().unwrap());
messages.push(message.clone());
yield AgentEvent::Message(message);
// Set added_message to true and continue to end the current iteration
added_message = true;
push_message(&mut messages, response);
continue;
}
}
// If there's no final output tool and no tool requests, continue the loop
continue; continue;
} }
@@ -1039,10 +1021,14 @@ impl Agent {
if let Some(final_output_tool) = self.final_output_tool.lock().await.as_ref() { if let Some(final_output_tool) = self.final_output_tool.lock().await.as_ref() {
if final_output_tool.final_output.is_none() { if final_output_tool.final_output.is_none() {
tracing::warn!("Final output tool has not been called yet. Continuing agent loop."); tracing::warn!("Final output tool has not been called yet. Continuing agent loop.");
yield AgentEvent::Message(Message::user().with_text(FINAL_OUTPUT_CONTINUATION_MESSAGE)); let message = Message::user().with_text(FINAL_OUTPUT_CONTINUATION_MESSAGE);
messages.push(message.clone());
yield AgentEvent::Message(message);
continue; continue;
} else { } else {
yield AgentEvent::Message(Message::assistant().with_text(final_output_tool.final_output.clone().unwrap())); let message = Message::assistant().with_text(final_output_tool.final_output.clone().unwrap());
messages.push(message.clone());
yield AgentEvent::Message(message);
} }
} }
break; break;

View File

@@ -9,7 +9,7 @@ use serde_json::Value;
pub const FINAL_OUTPUT_TOOL_NAME: &str = "recipe__final_output"; pub const FINAL_OUTPUT_TOOL_NAME: &str = "recipe__final_output";
pub const FINAL_OUTPUT_CONTINUATION_MESSAGE: &str = pub const FINAL_OUTPUT_CONTINUATION_MESSAGE: &str =
"I see I MUST call the `final_output` tool NOW with the final output for the user."; "You MUST call the `final_output` tool NOW with the final output for the user.";
pub struct FinalOutputTool { pub struct FinalOutputTool {
pub response: Response, pub response: Response,

View File

@@ -535,7 +535,11 @@ mod schedule_tool_tests {
#[cfg(test)] #[cfg(test)]
mod final_output_tool_tests { mod final_output_tool_tests {
use super::*; use super::*;
use goose::agents::final_output_tool::FINAL_OUTPUT_TOOL_NAME; use futures::stream;
use goose::agents::final_output_tool::{
FINAL_OUTPUT_CONTINUATION_MESSAGE, FINAL_OUTPUT_TOOL_NAME,
};
use goose::providers::base::MessageStream;
use goose::recipe::Response; use goose::recipe::Response;
#[tokio::test] #[tokio::test]
@@ -637,6 +641,124 @@ mod final_output_tool_tests {
Ok(()) Ok(())
} }
#[tokio::test]
async fn test_when_final_output_not_called_in_reply() -> Result<()> {
use async_trait::async_trait;
use goose::model::ModelConfig;
use goose::providers::base::{Provider, ProviderUsage};
use goose::providers::errors::ProviderError;
use mcp_core::tool::Tool;
#[derive(Clone)]
struct MockProvider {
model_config: ModelConfig,
}
#[async_trait]
impl Provider for MockProvider {
fn metadata() -> goose::providers::base::ProviderMetadata {
goose::providers::base::ProviderMetadata::empty()
}
fn get_model_config(&self) -> ModelConfig {
self.model_config.clone()
}
fn supports_streaming(&self) -> bool {
true
}
async fn stream(
&self,
_system: &str,
_messages: &[Message],
_tools: &[Tool],
) -> Result<MessageStream, ProviderError> {
let deltas = vec![
Ok((Some(Message::assistant().with_text("Hello")), None)),
Ok((Some(Message::assistant().with_text("Hi!")), None)),
Ok((
Some(Message::assistant().with_text("What is the final output?")),
None,
)),
];
let stream = stream::iter(deltas.into_iter());
Ok(Box::pin(stream))
}
async fn complete(
&self,
_system: &str,
_messages: &[Message],
_tools: &[Tool],
) -> Result<(Message, ProviderUsage), ProviderError> {
Err(ProviderError::NotImplemented("Not implemented".to_string()))
}
}
let agent = Agent::new();
let model_config = ModelConfig::new("test-model".to_string());
let mock_provider = Arc::new(MockProvider { model_config });
agent.update_provider(mock_provider).await?;
let response = Response {
json_schema: Some(serde_json::json!({
"type": "object",
"properties": {
"result": {"type": "string"}
},
"required": ["result"]
})),
};
agent.add_final_output_tool(response).await;
// Simulate the reply stream being called.
let reply_stream = agent.reply(&vec![], None).await?;
tokio::pin!(reply_stream);
let mut responses = Vec::new();
let mut count = 0;
while let Some(response_result) = reply_stream.next().await {
match response_result {
Ok(AgentEvent::Message(response)) => {
responses.push(response);
count += 1;
if count >= 4 {
// Limit to 4 messages to avoid infinite loop due to mock provider
break;
}
}
Ok(_) => {}
Err(e) => return Err(e),
}
}
assert!(!responses.is_empty(), "Should have received responses");
println!("Responses: {:?}", responses);
let last_message = responses.last().unwrap();
// Check that the first 3 messages do not have FINAL_OUTPUT_CONTINUATION_MESSAGE
for (i, response) in responses.iter().take(3).enumerate() {
let message_text = response.as_concat_text();
assert_ne!(
message_text,
FINAL_OUTPUT_CONTINUATION_MESSAGE,
"Message {} should not be the continuation message, got: '{}'",
i + 1,
message_text
);
}
// Check that the last message after the llm stream is the message directing the agent to continue
assert_eq!(last_message.role, mcp_core::role::Role::User);
let message_text = last_message.as_concat_text();
assert_eq!(message_text, FINAL_OUTPUT_CONTINUATION_MESSAGE);
Ok(())
}
} }
#[cfg(test)] #[cfg(test)]