mirror of
https://github.com/aljazceru/goose.git
synced 2025-12-19 07:04:21 +01:00
fix: Fixes structured output after streaming change to agent (#3448)
This commit is contained in:
@@ -837,24 +837,6 @@ impl Agent {
|
||||
|
||||
let num_tool_requests = frontend_requests.len() + remaining_requests.len();
|
||||
if num_tool_requests == 0 {
|
||||
if let Some(final_output_tool) = self.final_output_tool.lock().await.as_ref() {
|
||||
if final_output_tool.final_output.is_none() {
|
||||
tracing::warn!("Final output tool has not been called yet. Continuing agent loop.");
|
||||
let message = Message::assistant().with_text(FINAL_OUTPUT_CONTINUATION_MESSAGE);
|
||||
messages.push(message.clone());
|
||||
yield AgentEvent::Message(message);
|
||||
continue;
|
||||
} else {
|
||||
let message = Message::assistant().with_text(final_output_tool.final_output.clone().unwrap());
|
||||
messages.push(message.clone());
|
||||
yield AgentEvent::Message(message);
|
||||
// Set added_message to true and continue to end the current iteration
|
||||
added_message = true;
|
||||
push_message(&mut messages, response);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
// If there's no final output tool and no tool requests, continue the loop
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -1039,10 +1021,14 @@ impl Agent {
|
||||
if let Some(final_output_tool) = self.final_output_tool.lock().await.as_ref() {
|
||||
if final_output_tool.final_output.is_none() {
|
||||
tracing::warn!("Final output tool has not been called yet. Continuing agent loop.");
|
||||
yield AgentEvent::Message(Message::user().with_text(FINAL_OUTPUT_CONTINUATION_MESSAGE));
|
||||
let message = Message::user().with_text(FINAL_OUTPUT_CONTINUATION_MESSAGE);
|
||||
messages.push(message.clone());
|
||||
yield AgentEvent::Message(message);
|
||||
continue;
|
||||
} else {
|
||||
yield AgentEvent::Message(Message::assistant().with_text(final_output_tool.final_output.clone().unwrap()));
|
||||
let message = Message::assistant().with_text(final_output_tool.final_output.clone().unwrap());
|
||||
messages.push(message.clone());
|
||||
yield AgentEvent::Message(message);
|
||||
}
|
||||
}
|
||||
break;
|
||||
|
||||
@@ -9,7 +9,7 @@ use serde_json::Value;
|
||||
|
||||
pub const FINAL_OUTPUT_TOOL_NAME: &str = "recipe__final_output";
|
||||
pub const FINAL_OUTPUT_CONTINUATION_MESSAGE: &str =
|
||||
"I see I MUST call the `final_output` tool NOW with the final output for the user.";
|
||||
"You MUST call the `final_output` tool NOW with the final output for the user.";
|
||||
|
||||
pub struct FinalOutputTool {
|
||||
pub response: Response,
|
||||
|
||||
@@ -535,7 +535,11 @@ mod schedule_tool_tests {
|
||||
#[cfg(test)]
|
||||
mod final_output_tool_tests {
|
||||
use super::*;
|
||||
use goose::agents::final_output_tool::FINAL_OUTPUT_TOOL_NAME;
|
||||
use futures::stream;
|
||||
use goose::agents::final_output_tool::{
|
||||
FINAL_OUTPUT_CONTINUATION_MESSAGE, FINAL_OUTPUT_TOOL_NAME,
|
||||
};
|
||||
use goose::providers::base::MessageStream;
|
||||
use goose::recipe::Response;
|
||||
|
||||
#[tokio::test]
|
||||
@@ -637,6 +641,124 @@ mod final_output_tool_tests {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_when_final_output_not_called_in_reply() -> Result<()> {
|
||||
use async_trait::async_trait;
|
||||
use goose::model::ModelConfig;
|
||||
use goose::providers::base::{Provider, ProviderUsage};
|
||||
use goose::providers::errors::ProviderError;
|
||||
use mcp_core::tool::Tool;
|
||||
|
||||
#[derive(Clone)]
|
||||
struct MockProvider {
|
||||
model_config: ModelConfig,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for MockProvider {
|
||||
fn metadata() -> goose::providers::base::ProviderMetadata {
|
||||
goose::providers::base::ProviderMetadata::empty()
|
||||
}
|
||||
|
||||
fn get_model_config(&self) -> ModelConfig {
|
||||
self.model_config.clone()
|
||||
}
|
||||
|
||||
fn supports_streaming(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
async fn stream(
|
||||
&self,
|
||||
_system: &str,
|
||||
_messages: &[Message],
|
||||
_tools: &[Tool],
|
||||
) -> Result<MessageStream, ProviderError> {
|
||||
let deltas = vec![
|
||||
Ok((Some(Message::assistant().with_text("Hello")), None)),
|
||||
Ok((Some(Message::assistant().with_text("Hi!")), None)),
|
||||
Ok((
|
||||
Some(Message::assistant().with_text("What is the final output?")),
|
||||
None,
|
||||
)),
|
||||
];
|
||||
|
||||
let stream = stream::iter(deltas.into_iter());
|
||||
Ok(Box::pin(stream))
|
||||
}
|
||||
|
||||
async fn complete(
|
||||
&self,
|
||||
_system: &str,
|
||||
_messages: &[Message],
|
||||
_tools: &[Tool],
|
||||
) -> Result<(Message, ProviderUsage), ProviderError> {
|
||||
Err(ProviderError::NotImplemented("Not implemented".to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
let agent = Agent::new();
|
||||
|
||||
let model_config = ModelConfig::new("test-model".to_string());
|
||||
let mock_provider = Arc::new(MockProvider { model_config });
|
||||
agent.update_provider(mock_provider).await?;
|
||||
|
||||
let response = Response {
|
||||
json_schema: Some(serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"result": {"type": "string"}
|
||||
},
|
||||
"required": ["result"]
|
||||
})),
|
||||
};
|
||||
agent.add_final_output_tool(response).await;
|
||||
|
||||
// Simulate the reply stream being called.
|
||||
let reply_stream = agent.reply(&vec![], None).await?;
|
||||
tokio::pin!(reply_stream);
|
||||
|
||||
let mut responses = Vec::new();
|
||||
let mut count = 0;
|
||||
while let Some(response_result) = reply_stream.next().await {
|
||||
match response_result {
|
||||
Ok(AgentEvent::Message(response)) => {
|
||||
responses.push(response);
|
||||
count += 1;
|
||||
if count >= 4 {
|
||||
// Limit to 4 messages to avoid infinite loop due to mock provider
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok(_) => {}
|
||||
Err(e) => return Err(e),
|
||||
}
|
||||
}
|
||||
|
||||
assert!(!responses.is_empty(), "Should have received responses");
|
||||
println!("Responses: {:?}", responses);
|
||||
let last_message = responses.last().unwrap();
|
||||
|
||||
// Check that the first 3 messages do not have FINAL_OUTPUT_CONTINUATION_MESSAGE
|
||||
for (i, response) in responses.iter().take(3).enumerate() {
|
||||
let message_text = response.as_concat_text();
|
||||
assert_ne!(
|
||||
message_text,
|
||||
FINAL_OUTPUT_CONTINUATION_MESSAGE,
|
||||
"Message {} should not be the continuation message, got: '{}'",
|
||||
i + 1,
|
||||
message_text
|
||||
);
|
||||
}
|
||||
|
||||
// Check that the last message after the llm stream is the message directing the agent to continue
|
||||
assert_eq!(last_message.role, mcp_core::role::Role::User);
|
||||
let message_text = last_message.as_concat_text();
|
||||
assert_eq!(message_text, FINAL_OUTPUT_CONTINUATION_MESSAGE);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
Reference in New Issue
Block a user