mirror of
https://github.com/aljazceru/goose.git
synced 2025-12-19 07:04:21 +01:00
fix: Fixes structured output after streaming change to agent (#3448)
This commit is contained in:
@@ -837,24 +837,6 @@ impl Agent {
|
|||||||
|
|
||||||
let num_tool_requests = frontend_requests.len() + remaining_requests.len();
|
let num_tool_requests = frontend_requests.len() + remaining_requests.len();
|
||||||
if num_tool_requests == 0 {
|
if num_tool_requests == 0 {
|
||||||
if let Some(final_output_tool) = self.final_output_tool.lock().await.as_ref() {
|
|
||||||
if final_output_tool.final_output.is_none() {
|
|
||||||
tracing::warn!("Final output tool has not been called yet. Continuing agent loop.");
|
|
||||||
let message = Message::assistant().with_text(FINAL_OUTPUT_CONTINUATION_MESSAGE);
|
|
||||||
messages.push(message.clone());
|
|
||||||
yield AgentEvent::Message(message);
|
|
||||||
continue;
|
|
||||||
} else {
|
|
||||||
let message = Message::assistant().with_text(final_output_tool.final_output.clone().unwrap());
|
|
||||||
messages.push(message.clone());
|
|
||||||
yield AgentEvent::Message(message);
|
|
||||||
// Set added_message to true and continue to end the current iteration
|
|
||||||
added_message = true;
|
|
||||||
push_message(&mut messages, response);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// If there's no final output tool and no tool requests, continue the loop
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1039,10 +1021,14 @@ impl Agent {
|
|||||||
if let Some(final_output_tool) = self.final_output_tool.lock().await.as_ref() {
|
if let Some(final_output_tool) = self.final_output_tool.lock().await.as_ref() {
|
||||||
if final_output_tool.final_output.is_none() {
|
if final_output_tool.final_output.is_none() {
|
||||||
tracing::warn!("Final output tool has not been called yet. Continuing agent loop.");
|
tracing::warn!("Final output tool has not been called yet. Continuing agent loop.");
|
||||||
yield AgentEvent::Message(Message::user().with_text(FINAL_OUTPUT_CONTINUATION_MESSAGE));
|
let message = Message::user().with_text(FINAL_OUTPUT_CONTINUATION_MESSAGE);
|
||||||
|
messages.push(message.clone());
|
||||||
|
yield AgentEvent::Message(message);
|
||||||
continue;
|
continue;
|
||||||
} else {
|
} else {
|
||||||
yield AgentEvent::Message(Message::assistant().with_text(final_output_tool.final_output.clone().unwrap()));
|
let message = Message::assistant().with_text(final_output_tool.final_output.clone().unwrap());
|
||||||
|
messages.push(message.clone());
|
||||||
|
yield AgentEvent::Message(message);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ use serde_json::Value;
|
|||||||
|
|
||||||
pub const FINAL_OUTPUT_TOOL_NAME: &str = "recipe__final_output";
|
pub const FINAL_OUTPUT_TOOL_NAME: &str = "recipe__final_output";
|
||||||
pub const FINAL_OUTPUT_CONTINUATION_MESSAGE: &str =
|
pub const FINAL_OUTPUT_CONTINUATION_MESSAGE: &str =
|
||||||
"I see I MUST call the `final_output` tool NOW with the final output for the user.";
|
"You MUST call the `final_output` tool NOW with the final output for the user.";
|
||||||
|
|
||||||
pub struct FinalOutputTool {
|
pub struct FinalOutputTool {
|
||||||
pub response: Response,
|
pub response: Response,
|
||||||
|
|||||||
@@ -535,7 +535,11 @@ mod schedule_tool_tests {
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod final_output_tool_tests {
|
mod final_output_tool_tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use goose::agents::final_output_tool::FINAL_OUTPUT_TOOL_NAME;
|
use futures::stream;
|
||||||
|
use goose::agents::final_output_tool::{
|
||||||
|
FINAL_OUTPUT_CONTINUATION_MESSAGE, FINAL_OUTPUT_TOOL_NAME,
|
||||||
|
};
|
||||||
|
use goose::providers::base::MessageStream;
|
||||||
use goose::recipe::Response;
|
use goose::recipe::Response;
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
@@ -637,6 +641,124 @@ mod final_output_tool_tests {
|
|||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_when_final_output_not_called_in_reply() -> Result<()> {
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use goose::model::ModelConfig;
|
||||||
|
use goose::providers::base::{Provider, ProviderUsage};
|
||||||
|
use goose::providers::errors::ProviderError;
|
||||||
|
use mcp_core::tool::Tool;
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
struct MockProvider {
|
||||||
|
model_config: ModelConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Provider for MockProvider {
|
||||||
|
fn metadata() -> goose::providers::base::ProviderMetadata {
|
||||||
|
goose::providers::base::ProviderMetadata::empty()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_model_config(&self) -> ModelConfig {
|
||||||
|
self.model_config.clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn supports_streaming(&self) -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn stream(
|
||||||
|
&self,
|
||||||
|
_system: &str,
|
||||||
|
_messages: &[Message],
|
||||||
|
_tools: &[Tool],
|
||||||
|
) -> Result<MessageStream, ProviderError> {
|
||||||
|
let deltas = vec![
|
||||||
|
Ok((Some(Message::assistant().with_text("Hello")), None)),
|
||||||
|
Ok((Some(Message::assistant().with_text("Hi!")), None)),
|
||||||
|
Ok((
|
||||||
|
Some(Message::assistant().with_text("What is the final output?")),
|
||||||
|
None,
|
||||||
|
)),
|
||||||
|
];
|
||||||
|
|
||||||
|
let stream = stream::iter(deltas.into_iter());
|
||||||
|
Ok(Box::pin(stream))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn complete(
|
||||||
|
&self,
|
||||||
|
_system: &str,
|
||||||
|
_messages: &[Message],
|
||||||
|
_tools: &[Tool],
|
||||||
|
) -> Result<(Message, ProviderUsage), ProviderError> {
|
||||||
|
Err(ProviderError::NotImplemented("Not implemented".to_string()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let agent = Agent::new();
|
||||||
|
|
||||||
|
let model_config = ModelConfig::new("test-model".to_string());
|
||||||
|
let mock_provider = Arc::new(MockProvider { model_config });
|
||||||
|
agent.update_provider(mock_provider).await?;
|
||||||
|
|
||||||
|
let response = Response {
|
||||||
|
json_schema: Some(serde_json::json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"result": {"type": "string"}
|
||||||
|
},
|
||||||
|
"required": ["result"]
|
||||||
|
})),
|
||||||
|
};
|
||||||
|
agent.add_final_output_tool(response).await;
|
||||||
|
|
||||||
|
// Simulate the reply stream being called.
|
||||||
|
let reply_stream = agent.reply(&vec![], None).await?;
|
||||||
|
tokio::pin!(reply_stream);
|
||||||
|
|
||||||
|
let mut responses = Vec::new();
|
||||||
|
let mut count = 0;
|
||||||
|
while let Some(response_result) = reply_stream.next().await {
|
||||||
|
match response_result {
|
||||||
|
Ok(AgentEvent::Message(response)) => {
|
||||||
|
responses.push(response);
|
||||||
|
count += 1;
|
||||||
|
if count >= 4 {
|
||||||
|
// Limit to 4 messages to avoid infinite loop due to mock provider
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(_) => {}
|
||||||
|
Err(e) => return Err(e),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
assert!(!responses.is_empty(), "Should have received responses");
|
||||||
|
println!("Responses: {:?}", responses);
|
||||||
|
let last_message = responses.last().unwrap();
|
||||||
|
|
||||||
|
// Check that the first 3 messages do not have FINAL_OUTPUT_CONTINUATION_MESSAGE
|
||||||
|
for (i, response) in responses.iter().take(3).enumerate() {
|
||||||
|
let message_text = response.as_concat_text();
|
||||||
|
assert_ne!(
|
||||||
|
message_text,
|
||||||
|
FINAL_OUTPUT_CONTINUATION_MESSAGE,
|
||||||
|
"Message {} should not be the continuation message, got: '{}'",
|
||||||
|
i + 1,
|
||||||
|
message_text
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that the last message after the llm stream is the message directing the agent to continue
|
||||||
|
assert_eq!(last_message.role, mcp_core::role::Role::User);
|
||||||
|
let message_text = last_message.as_concat_text();
|
||||||
|
assert_eq!(message_text, FINAL_OUTPUT_CONTINUATION_MESSAGE);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|||||||
Reference in New Issue
Block a user