mirror of
https://github.com/aljazceru/goose.git
synced 2025-12-19 07:04:21 +01:00
refactor: implement nested streaming for frontend & regular tool approval requests (#2184)
Co-authored-by: Kalvin Chau <kalvin@block.xyz>
This commit is contained in:
10
Cargo.lock
generated
10
Cargo.lock
generated
@@ -2323,7 +2323,7 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "goose"
|
name = "goose"
|
||||||
version = "1.0.18"
|
version = "1.0.17"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
"async-stream",
|
"async-stream",
|
||||||
@@ -2378,7 +2378,7 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "goose-bench"
|
name = "goose-bench"
|
||||||
version = "1.0.18"
|
version = "1.0.17"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
"async-trait",
|
"async-trait",
|
||||||
@@ -2401,7 +2401,7 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "goose-cli"
|
name = "goose-cli"
|
||||||
version = "1.0.18"
|
version = "1.0.17"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
"async-trait",
|
"async-trait",
|
||||||
@@ -2439,7 +2439,7 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "goose-mcp"
|
name = "goose-mcp"
|
||||||
version = "1.0.18"
|
version = "1.0.17"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
"async-trait",
|
"async-trait",
|
||||||
@@ -2485,7 +2485,7 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "goose-server"
|
name = "goose-server"
|
||||||
version = "1.0.18"
|
version = "1.0.17"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
"async-trait",
|
"async-trait",
|
||||||
|
|||||||
@@ -3,13 +3,12 @@ use std::sync::Arc;
|
|||||||
|
|
||||||
use anyhow::{anyhow, Result};
|
use anyhow::{anyhow, Result};
|
||||||
use futures::stream::BoxStream;
|
use futures::stream::BoxStream;
|
||||||
|
use futures::TryStreamExt;
|
||||||
|
|
||||||
use crate::config::permission::PermissionLevel;
|
|
||||||
use crate::config::{Config, ExtensionConfigManager, PermissionManager};
|
use crate::config::{Config, ExtensionConfigManager, PermissionManager};
|
||||||
use crate::message::{Message, MessageContent, ToolRequest};
|
use crate::message::Message;
|
||||||
use crate::permission::permission_confirmation::PrincipalType;
|
use crate::permission::permission_judge::check_tool_permissions;
|
||||||
use crate::permission::permission_judge::{check_tool_permissions, get_confirmation_message};
|
use crate::permission::PermissionConfirmation;
|
||||||
use crate::permission::{Permission, PermissionConfirmation};
|
|
||||||
use crate::providers::base::Provider;
|
use crate::providers::base::Provider;
|
||||||
use crate::providers::errors::ProviderError;
|
use crate::providers::errors::ProviderError;
|
||||||
use crate::recipe::{Author, Recipe};
|
use crate::recipe::{Author, Recipe};
|
||||||
@@ -33,6 +32,10 @@ use mcp_core::{
|
|||||||
prompt::Prompt, protocol::GetPromptResult, tool::Tool, Content, ToolError, ToolResult,
|
prompt::Prompt, protocol::GetPromptResult, tool::Tool, Content, ToolError, ToolResult,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
use super::tool_execution::{
|
||||||
|
ExtensionInstallResult, ToolFuture, CHAT_MODE_TOOL_SKIPPED_RESPONSE, DECLINED_RESPONSE,
|
||||||
|
};
|
||||||
|
|
||||||
const MAX_TRUNCATION_ATTEMPTS: usize = 3;
|
const MAX_TRUNCATION_ATTEMPTS: usize = 3;
|
||||||
const ESTIMATE_FACTOR_DECAY: f32 = 0.9;
|
const ESTIMATE_FACTOR_DECAY: f32 = 0.9;
|
||||||
|
|
||||||
@@ -105,7 +108,7 @@ impl Agent {
|
|||||||
|
|
||||||
/// Dispatch a single tool call to the appropriate client
|
/// Dispatch a single tool call to the appropriate client
|
||||||
#[instrument(skip(self, tool_call, request_id), fields(input, output))]
|
#[instrument(skip(self, tool_call, request_id), fields(input, output))]
|
||||||
async fn dispatch_tool_call(
|
pub(super) async fn dispatch_tool_call(
|
||||||
&self,
|
&self,
|
||||||
tool_call: mcp_core::tool::ToolCall,
|
tool_call: mcp_core::tool::ToolCall,
|
||||||
request_id: String,
|
request_id: String,
|
||||||
@@ -190,7 +193,7 @@ impl Agent {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn enable_extension(
|
pub(super) async fn enable_extension(
|
||||||
&self,
|
&self,
|
||||||
extension_name: String,
|
extension_name: String,
|
||||||
request_id: String,
|
request_id: String,
|
||||||
@@ -351,57 +354,37 @@ impl Agent {
|
|||||||
// Reset truncation attempt
|
// Reset truncation attempt
|
||||||
truncation_attempt = 0;
|
truncation_attempt = 0;
|
||||||
|
|
||||||
// Yield the assistant's response, but filter out frontend tool requests that we'll process separately
|
// categorize the type of requests we need to handle
|
||||||
let filtered_response = Message {
|
let (frontend_requests,
|
||||||
role: response.role.clone(),
|
remaining_requests,
|
||||||
created: response.created,
|
filtered_response) =
|
||||||
content: response.content.iter().filter(|c| {
|
self.categorize_tool_requests(&response);
|
||||||
if let MessageContent::ToolRequest(req) = c {
|
|
||||||
// Only filter out frontend tool requests
|
|
||||||
if let Ok(tool_call) = &req.tool_call {
|
// Yield the assistant's response with frontend tool requests filtered out
|
||||||
return !self.is_frontend_tool(&tool_call.name);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
true
|
|
||||||
}).cloned().collect(),
|
|
||||||
};
|
|
||||||
yield filtered_response.clone();
|
yield filtered_response.clone();
|
||||||
|
|
||||||
tokio::task::yield_now().await;
|
tokio::task::yield_now().await;
|
||||||
|
|
||||||
// First collect any tool requests
|
let num_tool_requests = frontend_requests.len() + remaining_requests.len();
|
||||||
let tool_requests: Vec<&ToolRequest> = response.content
|
if num_tool_requests == 0 {
|
||||||
.iter()
|
|
||||||
.filter_map(|content| content.as_tool_request())
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
if tool_requests.is_empty() {
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process tool requests depending on goose_mode
|
// Process tool requests depending on frontend tools and then goose_mode
|
||||||
let mut message_tool_response = Message::user();
|
let message_tool_response = Arc::new(Mutex::new(Message::user()));
|
||||||
|
|
||||||
// First handle any frontend tool requests
|
// First handle any frontend tool requests
|
||||||
let mut remaining_requests = Vec::new();
|
let mut frontend_tool_stream = self.handle_frontend_tool_requests(
|
||||||
for request in &tool_requests {
|
&frontend_requests,
|
||||||
if let Ok(tool_call) = request.tool_call.clone() {
|
message_tool_response.clone()
|
||||||
if self.is_frontend_tool(&tool_call.name) {
|
|
||||||
// Send frontend tool request and wait for response
|
|
||||||
yield Message::assistant().with_frontend_tool_request(
|
|
||||||
request.id.clone(),
|
|
||||||
Ok(tool_call.clone())
|
|
||||||
);
|
);
|
||||||
|
|
||||||
if let Some((id, result)) = self.tool_result_rx.lock().await.recv().await {
|
// we have a stream of frontend tools to handle, inside the stream
|
||||||
message_tool_response = message_tool_response.with_tool_response(id, result);
|
// execution is yeield back to this reply loop, and is of the same Message
|
||||||
}
|
// type, so we can yield that back up to be handled
|
||||||
} else {
|
while let Some(msg) = frontend_tool_stream.try_next().await? {
|
||||||
remaining_requests.push(request);
|
yield msg;
|
||||||
}
|
|
||||||
} else {
|
|
||||||
remaining_requests.push(request);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Clone goose_mode once before the match to avoid move issues
|
// Clone goose_mode once before the match to avoid move issues
|
||||||
@@ -409,100 +392,97 @@ impl Agent {
|
|||||||
if mode.as_str() == "chat" {
|
if mode.as_str() == "chat" {
|
||||||
// Skip all tool calls in chat mode
|
// Skip all tool calls in chat mode
|
||||||
for request in remaining_requests {
|
for request in remaining_requests {
|
||||||
message_tool_response = message_tool_response.with_tool_response(
|
let mut response = message_tool_response.lock().await;
|
||||||
|
*response = response.clone().with_tool_response(
|
||||||
request.id.clone(),
|
request.id.clone(),
|
||||||
Ok(vec![Content::text(
|
Ok(vec![Content::text(CHAT_MODE_TOOL_SKIPPED_RESPONSE)]),
|
||||||
"Let the user know the tool call was skipped in Goose chat mode. \
|
|
||||||
DO NOT apologize for skipping the tool call. DO NOT say sorry. \
|
|
||||||
Provide an explanation of what the tool call would do, structured as a \
|
|
||||||
plan for the user. Again, DO NOT apologize. \
|
|
||||||
**Example Plan:**\n \
|
|
||||||
1. **Identify Task Scope** - Determine the purpose and expected outcome.\n \
|
|
||||||
2. **Outline Steps** - Break down the steps.\n \
|
|
||||||
If needed, adjust the explanation based on user preferences or questions."
|
|
||||||
)]),
|
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
// At this point, we have handled the frontend tool requests and know goose_mode != "chat"
|
||||||
|
// What remains is handling the remaining tool requests (enable extension,
|
||||||
|
// regular tool calls) in goose_mode == ["auto", "approve" or "smart_approve"]
|
||||||
let mut permission_manager = PermissionManager::default();
|
let mut permission_manager = PermissionManager::default();
|
||||||
let permission_check_result = check_tool_permissions(remaining_requests.into_iter().copied().collect(),
|
let permission_check_result = check_tool_permissions(&remaining_requests,
|
||||||
&mode,
|
&mode,
|
||||||
tools_with_readonly_annotation.clone(),
|
tools_with_readonly_annotation.clone(),
|
||||||
tools_without_annotation.clone(),
|
tools_without_annotation.clone(),
|
||||||
&mut permission_manager,
|
&mut permission_manager,
|
||||||
self.provider()).await;
|
self.provider()).await;
|
||||||
|
|
||||||
// Handle pre-approved and read-only tools in parallel
|
|
||||||
let mut tool_futures = Vec::new();
|
|
||||||
let mut install_results = Vec::new();
|
|
||||||
|
|
||||||
let denied_content_text = Content::text(
|
// Handle pre-approved and read-only tools in parallel
|
||||||
"The user has declined to run this tool. \
|
let mut tool_futures: Vec<ToolFuture> = Vec::new();
|
||||||
DO NOT attempt to call this tool again. \
|
let mut install_results: Vec<ExtensionInstallResult> = Vec::new();
|
||||||
If there are no alternative methods to proceed, clearly explain the situation and STOP.");
|
let install_results_arc = Arc::new(Mutex::new(install_results));
|
||||||
|
|
||||||
// Skip the confirmation for approved tools
|
// Skip the confirmation for approved tools
|
||||||
for request in &permission_check_result.approved {
|
for request in &permission_check_result.approved {
|
||||||
if let Ok(tool_call) = request.tool_call.clone() {
|
if let Ok(tool_call) = request.tool_call.clone() {
|
||||||
let tool_future = self.dispatch_tool_call(tool_call, request.id.clone());
|
let tool_future = self.dispatch_tool_call(tool_call, request.id.clone());
|
||||||
tool_futures.push(tool_future);
|
tool_futures.push(Box::pin(tool_future));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for request in &permission_check_result.denied {
|
for request in &permission_check_result.denied {
|
||||||
message_tool_response = message_tool_response.with_tool_response(
|
let mut response = message_tool_response.lock().await;
|
||||||
|
*response = response.clone().with_tool_response(
|
||||||
request.id.clone(),
|
request.id.clone(),
|
||||||
Ok(vec![denied_content_text.clone()]),
|
Ok(vec![Content::text(DECLINED_RESPONSE)]),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process tools requiring approval
|
// we need interior mutability in handle_approval_tool_requests
|
||||||
for request in &permission_check_result.needs_approval {
|
let tool_futures_arc = Arc::new(Mutex::new(tool_futures));
|
||||||
if let Ok(tool_call) = request.tool_call.clone() {
|
// Process tools requiring approval (enable extension, regular tool calls)
|
||||||
let (principal_type, confirmation) = get_confirmation_message(&request.id.clone(), tool_call.clone());
|
let mut tool_approval_stream = self.handle_approval_tool_requests(
|
||||||
yield confirmation;
|
&permission_check_result.needs_approval,
|
||||||
|
install_results_arc.clone(),
|
||||||
// Wait for confirmation response through the channel
|
tool_futures_arc.clone(),
|
||||||
let mut rx = self.confirmation_rx.lock().await;
|
&mut permission_manager,
|
||||||
while let Some((req_id, confirmation)) = rx.recv().await {
|
message_tool_response.clone()
|
||||||
if req_id == request.id {
|
|
||||||
if confirmation.permission == Permission::AllowOnce || confirmation.permission == Permission::AlwaysAllow {
|
|
||||||
if principal_type == PrincipalType::Extension {
|
|
||||||
let extension_name = tool_call.arguments.get("extension_name")
|
|
||||||
.and_then(|v| v.as_str())
|
|
||||||
.unwrap_or("")
|
|
||||||
.to_string();
|
|
||||||
let install_result = self.enable_extension(extension_name, request.id.clone()).await;
|
|
||||||
install_results.push(install_result);
|
|
||||||
} else {
|
|
||||||
// Add this tool call to the futures collection
|
|
||||||
let tool_future = self.dispatch_tool_call(tool_call.clone(), request.id.clone());
|
|
||||||
tool_futures.push(tool_future);
|
|
||||||
if confirmation.permission == Permission::AlwaysAllow {
|
|
||||||
permission_manager.update_user_permission(&tool_call.name, PermissionLevel::AlwaysAllow);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// User declined - add declined response
|
|
||||||
message_tool_response = message_tool_response.with_tool_response(
|
|
||||||
request.id.clone(),
|
|
||||||
Ok(vec![denied_content_text.clone()]),
|
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// we have a stream of tool_approval_requests to handle
|
||||||
|
// execution is yeield back to this reply loop, and is of the same Message
|
||||||
|
// type, so we can yield the Message back up to be handled and grab and
|
||||||
|
// confirmations or denials
|
||||||
|
while let Some(msg) = tool_approval_stream.try_next().await? {
|
||||||
|
yield msg;
|
||||||
}
|
}
|
||||||
break; // Exit the loop once the matching `req_id` is found
|
|
||||||
}
|
tool_futures = {
|
||||||
}
|
// Lock the mutex asynchronously.
|
||||||
}
|
let mut futures_lock = tool_futures_arc.lock().await;
|
||||||
}
|
// Drain the vector and collect into a new Vec.
|
||||||
|
futures_lock.drain(..).collect::<Vec<_>>()
|
||||||
|
};
|
||||||
|
|
||||||
|
install_results = {
|
||||||
|
// Lock the mutex asynchronously.
|
||||||
|
let mut results_lock = install_results_arc.lock().await;
|
||||||
|
// Drain the vector and collect into a new Vec.
|
||||||
|
results_lock.drain(..).collect::<Vec<_>>()
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
// Wait for all tool calls to complete
|
// Wait for all tool calls to complete
|
||||||
let results = futures::future::join_all(tool_futures).await;
|
let results = futures::future::join_all(tool_futures).await;
|
||||||
|
for (request_id, output) in results {
|
||||||
|
let mut response = message_tool_response.lock().await;
|
||||||
|
*response = response.clone().with_tool_response(
|
||||||
|
request_id,
|
||||||
|
output,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
// Check if any install results had errors before processing them
|
// Check if any install results had errors before processing them
|
||||||
let all_install_successful = !install_results.iter().any(|(_, result)| result.is_err());
|
let all_install_successful = !install_results.iter().any(|(_, result)| result.is_err());
|
||||||
|
for (request_id, output) in install_results {
|
||||||
for (request_id, output) in results.into_iter().chain(install_results.into_iter()) {
|
let mut response = message_tool_response.lock().await;
|
||||||
message_tool_response = message_tool_response.with_tool_response(request_id, output);
|
*response = response.clone().with_tool_response(
|
||||||
|
request_id,
|
||||||
|
output
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update system prompt and tools if installations were successful
|
// Update system prompt and tools if installations were successful
|
||||||
@@ -511,10 +491,11 @@ impl Agent {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
yield message_tool_response.clone();
|
let final_message_tool_resp = message_tool_response.lock().await.clone();
|
||||||
|
yield final_message_tool_resp.clone();
|
||||||
|
|
||||||
messages.push(response);
|
messages.push(response);
|
||||||
messages.push(message_tool_response);
|
messages.push(final_message_tool_resp);
|
||||||
},
|
},
|
||||||
Err(ProviderError::ContextLengthExceeded(_)) => {
|
Err(ProviderError::ContextLengthExceeded(_)) => {
|
||||||
if truncation_attempt >= MAX_TRUNCATION_ATTEMPTS {
|
if truncation_attempt >= MAX_TRUNCATION_ATTEMPTS {
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ pub mod extension_manager;
|
|||||||
pub mod platform_tools;
|
pub mod platform_tools;
|
||||||
pub mod prompt_manager;
|
pub mod prompt_manager;
|
||||||
mod reply_parts;
|
mod reply_parts;
|
||||||
|
mod tool_execution;
|
||||||
mod types;
|
mod types;
|
||||||
|
|
||||||
pub use agent::Agent;
|
pub use agent::Agent;
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ use std::sync::Arc;
|
|||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
|
|
||||||
use crate::agents::platform_tools;
|
use crate::agents::platform_tools;
|
||||||
use crate::message::Message;
|
use crate::message::{Message, MessageContent, ToolRequest};
|
||||||
use crate::providers::base::{Provider, ProviderUsage};
|
use crate::providers::base::{Provider, ProviderUsage};
|
||||||
use crate::providers::errors::ProviderError;
|
use crate::providers::errors::ProviderError;
|
||||||
use crate::providers::toolshim::{
|
use crate::providers::toolshim::{
|
||||||
@@ -111,6 +111,70 @@ impl Agent {
|
|||||||
Ok((response, usage))
|
Ok((response, usage))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Categorize tool requests from the response into different types
|
||||||
|
/// Returns:
|
||||||
|
/// - frontend_requests: Tool requests that should be handled by the frontend
|
||||||
|
/// - other_requests: All other tool requests (including requests to enable extensions)
|
||||||
|
/// - filtered_message: The original message with frontend tool requests removed
|
||||||
|
pub(crate) fn categorize_tool_requests(
|
||||||
|
&self,
|
||||||
|
response: &Message,
|
||||||
|
) -> (Vec<ToolRequest>, Vec<ToolRequest>, Message) {
|
||||||
|
// First collect all tool requests
|
||||||
|
let tool_requests: Vec<ToolRequest> = response
|
||||||
|
.content
|
||||||
|
.iter()
|
||||||
|
.filter_map(|content| {
|
||||||
|
if let MessageContent::ToolRequest(req) = content {
|
||||||
|
Some(req.clone())
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
// Create a filtered message with frontend tool requests removed
|
||||||
|
let filtered_content = response
|
||||||
|
.content
|
||||||
|
.iter()
|
||||||
|
.filter(|c| {
|
||||||
|
if let MessageContent::ToolRequest(req) = c {
|
||||||
|
// Only filter out frontend tool requests
|
||||||
|
if let Ok(tool_call) = &req.tool_call {
|
||||||
|
return !self.is_frontend_tool(&tool_call.name);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
true
|
||||||
|
})
|
||||||
|
.cloned()
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let filtered_message = Message {
|
||||||
|
role: response.role.clone(),
|
||||||
|
created: response.created,
|
||||||
|
content: filtered_content,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Categorize tool requests
|
||||||
|
let mut frontend_requests = Vec::new();
|
||||||
|
let mut other_requests = Vec::new();
|
||||||
|
|
||||||
|
for request in tool_requests {
|
||||||
|
if let Ok(tool_call) = &request.tool_call {
|
||||||
|
if self.is_frontend_tool(&tool_call.name) {
|
||||||
|
frontend_requests.push(request);
|
||||||
|
} else {
|
||||||
|
other_requests.push(request);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// If there's an error in the tool call, add it to other_requests
|
||||||
|
other_requests.push(request);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
(frontend_requests, other_requests, filtered_message)
|
||||||
|
}
|
||||||
|
|
||||||
/// Update session metrics after a response
|
/// Update session metrics after a response
|
||||||
pub(crate) async fn update_session_metrics(
|
pub(crate) async fn update_session_metrics(
|
||||||
session_config: crate::agents::types::SessionConfig,
|
session_config: crate::agents::types::SessionConfig,
|
||||||
|
|||||||
120
crates/goose/src/agents/tool_execution.rs
Normal file
120
crates/goose/src/agents/tool_execution.rs
Normal file
@@ -0,0 +1,120 @@
|
|||||||
|
use std::future::Future;
|
||||||
|
use std::pin::Pin;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use async_stream::try_stream;
|
||||||
|
use futures::stream::BoxStream;
|
||||||
|
use futures::StreamExt;
|
||||||
|
use tokio::sync::Mutex;
|
||||||
|
|
||||||
|
use crate::config::permission::PermissionLevel;
|
||||||
|
use crate::config::PermissionManager;
|
||||||
|
use crate::message::{Message, ToolRequest};
|
||||||
|
use crate::permission::permission_confirmation::PrincipalType;
|
||||||
|
use crate::permission::permission_judge::get_confirmation_message;
|
||||||
|
use crate::permission::Permission;
|
||||||
|
use mcp_core::{Content, ToolError};
|
||||||
|
|
||||||
|
// Type alias for ToolFutures - used in the agent loop to join all futures together
|
||||||
|
pub(crate) type ToolFuture<'a> =
|
||||||
|
Pin<Box<dyn Future<Output = (String, Result<Vec<Content>, ToolError>)> + Send + 'a>>;
|
||||||
|
pub(crate) type ToolFuturesVec<'a> = Arc<Mutex<Vec<ToolFuture<'a>>>>;
|
||||||
|
// Type alias for extension installation results
|
||||||
|
pub(crate) type ExtensionInstallResult = (String, Result<Vec<Content>, ToolError>);
|
||||||
|
pub(crate) type ExtensionInstallResults = Arc<Mutex<Vec<ExtensionInstallResult>>>;
|
||||||
|
|
||||||
|
use crate::agents::Agent;
|
||||||
|
|
||||||
|
pub const DECLINED_RESPONSE: &str = "The user has declined to run this tool. \
|
||||||
|
DO NOT attempt to call this tool again. \
|
||||||
|
If there are no alternative methods to proceed, clearly explain the situation and STOP.";
|
||||||
|
|
||||||
|
pub const CHAT_MODE_TOOL_SKIPPED_RESPONSE: &str = "Let the user know the tool call was skipped in Goose chat mode. \
|
||||||
|
DO NOT apologize for skipping the tool call. DO NOT say sorry. \
|
||||||
|
Provide an explanation of what the tool call would do, structured as a \
|
||||||
|
plan for the user. Again, DO NOT apologize. \
|
||||||
|
**Example Plan:**\n \
|
||||||
|
1. **Identify Task Scope** - Determine the purpose and expected outcome.\n \
|
||||||
|
2. **Outline Steps** - Break down the steps.\n \
|
||||||
|
If needed, adjust the explanation based on user preferences or questions.";
|
||||||
|
|
||||||
|
impl Agent {
|
||||||
|
pub(crate) fn handle_approval_tool_requests<'a>(
|
||||||
|
&'a self,
|
||||||
|
tool_requests: &'a [ToolRequest],
|
||||||
|
install_results: ExtensionInstallResults,
|
||||||
|
tool_futures: ToolFuturesVec<'a>,
|
||||||
|
permission_manager: &'a mut PermissionManager,
|
||||||
|
message_tool_response: Arc<Mutex<Message>>,
|
||||||
|
) -> BoxStream<'a, anyhow::Result<Message>> {
|
||||||
|
try_stream! {
|
||||||
|
for request in tool_requests {
|
||||||
|
if let Ok(tool_call) = request.tool_call.clone() {
|
||||||
|
let (principal_type, confirmation) = get_confirmation_message(&request.id.clone(), tool_call.clone());
|
||||||
|
yield confirmation;
|
||||||
|
|
||||||
|
let mut rx = self.confirmation_rx.lock().await;
|
||||||
|
while let Some((req_id, confirmation)) = rx.recv().await {
|
||||||
|
if req_id == request.id {
|
||||||
|
if confirmation.permission == Permission::AllowOnce || confirmation.permission == Permission::AlwaysAllow {
|
||||||
|
if principal_type == PrincipalType::Extension {
|
||||||
|
let extension_name = tool_call.arguments.get("extension_name")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.unwrap_or("")
|
||||||
|
.to_string();
|
||||||
|
|
||||||
|
let mut results = install_results.lock().await;
|
||||||
|
let install_result = self.enable_extension(extension_name, request.id.clone()).await;
|
||||||
|
results.push(install_result);
|
||||||
|
} else {
|
||||||
|
// Add this tool call to the futures collection
|
||||||
|
let tool_future = self.dispatch_tool_call(tool_call.clone(), request.id.clone());
|
||||||
|
let mut futures = tool_futures.lock().await;
|
||||||
|
futures.push(Box::pin(tool_future));
|
||||||
|
|
||||||
|
if confirmation.permission == Permission::AlwaysAllow {
|
||||||
|
permission_manager.update_user_permission(&tool_call.name, PermissionLevel::AlwaysAllow);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// User declined - add declined response
|
||||||
|
let mut response = message_tool_response.lock().await;
|
||||||
|
*response = response.clone().with_tool_response(
|
||||||
|
request.id.clone(),
|
||||||
|
Ok(vec![Content::text(DECLINED_RESPONSE)]),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
break; // Exit the loop once the matching `req_id` is found
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}.boxed()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn handle_frontend_tool_requests<'a>(
|
||||||
|
&'a self,
|
||||||
|
tool_requests: &'a [ToolRequest],
|
||||||
|
message_tool_response: Arc<Mutex<Message>>,
|
||||||
|
) -> BoxStream<'a, anyhow::Result<Message>> {
|
||||||
|
try_stream! {
|
||||||
|
for request in tool_requests {
|
||||||
|
if let Ok(tool_call) = request.tool_call.clone() {
|
||||||
|
if self.is_frontend_tool(&tool_call.name) {
|
||||||
|
// Send frontend tool request and wait for response
|
||||||
|
yield Message::assistant().with_frontend_tool_request(
|
||||||
|
request.id.clone(),
|
||||||
|
Ok(tool_call.clone())
|
||||||
|
);
|
||||||
|
|
||||||
|
if let Some((id, result)) = self.tool_result_rx.lock().await.recv().await {
|
||||||
|
let mut response = message_tool_response.lock().await;
|
||||||
|
*response = response.clone().with_tool_response(id, result);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
.boxed()
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -191,7 +191,7 @@ pub struct PermissionCheckResult {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub async fn check_tool_permissions(
|
pub async fn check_tool_permissions(
|
||||||
candidate_requests: Vec<&ToolRequest>,
|
candidate_requests: &[ToolRequest],
|
||||||
mode: &str,
|
mode: &str,
|
||||||
tools_with_readonly_annotation: HashSet<String>,
|
tools_with_readonly_annotation: HashSet<String>,
|
||||||
tools_without_annotation: HashSet<String>,
|
tools_without_annotation: HashSet<String>,
|
||||||
@@ -466,12 +466,12 @@ mod tests {
|
|||||||
}),
|
}),
|
||||||
};
|
};
|
||||||
|
|
||||||
let candidate_requests: Vec<&ToolRequest> =
|
let candidate_requests: Vec<ToolRequest> =
|
||||||
vec![&tool_request_1, &tool_request_2, &enable_extension];
|
vec![tool_request_1, tool_request_2, enable_extension];
|
||||||
|
|
||||||
// Call the function under test
|
// Call the function under test
|
||||||
let result = check_tool_permissions(
|
let result = check_tool_permissions(
|
||||||
candidate_requests,
|
&candidate_requests,
|
||||||
"smart_approve",
|
"smart_approve",
|
||||||
tools_with_readonly_annotation,
|
tools_with_readonly_annotation,
|
||||||
tools_without_annotation,
|
tools_without_annotation,
|
||||||
@@ -534,11 +534,11 @@ mod tests {
|
|||||||
}),
|
}),
|
||||||
};
|
};
|
||||||
|
|
||||||
let candidate_requests: Vec<&ToolRequest> = vec![&tool_request_1, &tool_request_2];
|
let candidate_requests: Vec<ToolRequest> = vec![tool_request_1, tool_request_2];
|
||||||
|
|
||||||
// Call the function under test
|
// Call the function under test
|
||||||
let result = check_tool_permissions(
|
let result = check_tool_permissions(
|
||||||
candidate_requests,
|
&candidate_requests,
|
||||||
"auto",
|
"auto",
|
||||||
tools_with_readonly_annotation,
|
tools_with_readonly_annotation,
|
||||||
tools_without_annotation,
|
tools_without_annotation,
|
||||||
|
|||||||
Reference in New Issue
Block a user