mirror of
https://github.com/aljazceru/goose.git
synced 2025-12-19 07:04:21 +01:00
feat: add tool repetition monitoring to prevent infinite loops (#2527)
This commit is contained in:
@@ -220,6 +220,15 @@ enum Command {
|
||||
)]
|
||||
debug: bool,
|
||||
|
||||
/// Maximum number of consecutive identical tool calls allowed
|
||||
#[arg(
|
||||
long = "max-tool-repetitions",
|
||||
value_name = "NUMBER",
|
||||
help = "Maximum number of consecutive identical tool calls allowed",
|
||||
long_help = "Set a limit on how many times the same tool can be called consecutively with identical parameters. Helps prevent infinite loops."
|
||||
)]
|
||||
max_tool_repetitions: Option<u32>,
|
||||
|
||||
/// Add stdio extensions with environment variables and commands
|
||||
#[arg(
|
||||
long = "with-extension",
|
||||
@@ -324,6 +333,15 @@ enum Command {
|
||||
)]
|
||||
no_session: bool,
|
||||
|
||||
/// Maximum number of consecutive identical tool calls allowed
|
||||
#[arg(
|
||||
long = "max-tool-repetitions",
|
||||
value_name = "NUMBER",
|
||||
help = "Maximum number of consecutive identical tool calls allowed",
|
||||
long_help = "Set a limit on how many times the same tool can be called consecutively with identical parameters. Helps prevent infinite loops."
|
||||
)]
|
||||
max_tool_repetitions: Option<u32>,
|
||||
|
||||
/// Identifier for this run session
|
||||
#[command(flatten)]
|
||||
identifier: Option<Identifier>,
|
||||
@@ -446,6 +464,7 @@ pub async fn cli() -> Result<()> {
|
||||
resume,
|
||||
history,
|
||||
debug,
|
||||
max_tool_repetitions,
|
||||
extensions,
|
||||
remote_extensions,
|
||||
builtins,
|
||||
@@ -475,6 +494,7 @@ pub async fn cli() -> Result<()> {
|
||||
extensions_override: None,
|
||||
additional_system_prompt: None,
|
||||
debug,
|
||||
max_tool_repetitions,
|
||||
})
|
||||
.await;
|
||||
setup_logging(
|
||||
@@ -511,6 +531,7 @@ pub async fn cli() -> Result<()> {
|
||||
resume,
|
||||
no_session,
|
||||
debug,
|
||||
max_tool_repetitions,
|
||||
extensions,
|
||||
remote_extensions,
|
||||
builtins,
|
||||
@@ -576,6 +597,7 @@ pub async fn cli() -> Result<()> {
|
||||
extensions_override: input_config.extensions_override,
|
||||
additional_system_prompt: input_config.additional_system_prompt,
|
||||
debug,
|
||||
max_tool_repetitions,
|
||||
})
|
||||
.await;
|
||||
|
||||
@@ -647,6 +669,7 @@ pub async fn cli() -> Result<()> {
|
||||
extensions_override: None,
|
||||
additional_system_prompt: None,
|
||||
debug: false,
|
||||
max_tool_repetitions: None,
|
||||
})
|
||||
.await;
|
||||
setup_logging(
|
||||
|
||||
@@ -41,6 +41,7 @@ pub async fn agent_generator(
|
||||
extensions_override: None,
|
||||
additional_system_prompt: None,
|
||||
debug: false,
|
||||
max_tool_repetitions: None,
|
||||
})
|
||||
.await;
|
||||
|
||||
|
||||
@@ -35,6 +35,8 @@ pub struct SessionBuilderConfig {
|
||||
pub additional_system_prompt: Option<String>,
|
||||
/// Enable debug printing
|
||||
pub debug: bool,
|
||||
/// Maximum number of consecutive identical tool calls allowed
|
||||
pub max_tool_repetitions: Option<u32>,
|
||||
}
|
||||
|
||||
pub async fn build_session(session_config: SessionBuilderConfig) -> Session {
|
||||
@@ -55,6 +57,11 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> Session {
|
||||
let new_provider = create(&provider_name, model_config).unwrap();
|
||||
let _ = agent.update_provider(new_provider).await;
|
||||
|
||||
// Configure tool monitoring if max_tool_repetitions is set
|
||||
if let Some(max_repetitions) = session_config.max_tool_repetitions {
|
||||
agent.configure_tool_monitor(Some(max_repetitions)).await;
|
||||
}
|
||||
|
||||
// Handle session file resolution and resuming
|
||||
let session_file = if session_config.no_session {
|
||||
// Use a temporary path that won't be written to
|
||||
|
||||
@@ -12,6 +12,7 @@ use crate::permission::PermissionConfirmation;
|
||||
use crate::providers::base::Provider;
|
||||
use crate::providers::errors::ProviderError;
|
||||
use crate::recipe::{Author, Recipe};
|
||||
use crate::tool_monitor::{ToolCall, ToolMonitor};
|
||||
use regex::Regex;
|
||||
use serde_json::Value;
|
||||
use tokio::sync::{mpsc, Mutex};
|
||||
@@ -44,6 +45,7 @@ pub struct Agent {
|
||||
pub(super) confirmation_rx: Mutex<mpsc::Receiver<(String, PermissionConfirmation)>>,
|
||||
pub(super) tool_result_tx: mpsc::Sender<(String, ToolResult<Vec<Content>>)>,
|
||||
pub(super) tool_result_rx: ToolResultReceiver,
|
||||
pub(super) tool_monitor: Mutex<Option<ToolMonitor>>,
|
||||
}
|
||||
|
||||
impl Agent {
|
||||
@@ -62,6 +64,23 @@ impl Agent {
|
||||
confirmation_rx: Mutex::new(confirm_rx),
|
||||
tool_result_tx: tool_tx,
|
||||
tool_result_rx: Arc::new(Mutex::new(tool_rx)),
|
||||
tool_monitor: Mutex::new(None),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn configure_tool_monitor(&self, max_repetitions: Option<u32>) {
|
||||
let mut tool_monitor = self.tool_monitor.lock().await;
|
||||
*tool_monitor = Some(ToolMonitor::new(max_repetitions));
|
||||
}
|
||||
|
||||
pub async fn get_tool_stats(&self) -> Option<HashMap<String, u32>> {
|
||||
let tool_monitor = self.tool_monitor.lock().await;
|
||||
tool_monitor.as_ref().map(|monitor| monitor.get_stats())
|
||||
}
|
||||
|
||||
pub async fn reset_tool_monitor(&self) {
|
||||
if let Some(monitor) = self.tool_monitor.lock().await.as_mut() {
|
||||
monitor.reset();
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -116,6 +135,20 @@ impl Agent {
|
||||
tool_call: mcp_core::tool::ToolCall,
|
||||
request_id: String,
|
||||
) -> (String, Result<Vec<Content>, ToolError>) {
|
||||
// Check if this tool call should be allowed based on repetition monitoring
|
||||
if let Some(monitor) = self.tool_monitor.lock().await.as_mut() {
|
||||
let tool_call_info = ToolCall::new(tool_call.name.clone(), tool_call.arguments.clone());
|
||||
|
||||
if !monitor.check_tool_call(tool_call_info) {
|
||||
return (
|
||||
request_id,
|
||||
Err(ToolError::ExecutionError(
|
||||
"Tool call rejected: exceeded maximum allowed repetitions".to_string(),
|
||||
)),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if tool_call.name == PLATFORM_MANAGE_EXTENSIONS_TOOL_NAME {
|
||||
let extension_name = tool_call
|
||||
.arguments
|
||||
|
||||
@@ -9,4 +9,5 @@ pub mod providers;
|
||||
pub mod recipe;
|
||||
pub mod session;
|
||||
pub mod token_counter;
|
||||
pub mod tool_monitor;
|
||||
pub mod tracing;
|
||||
|
||||
74
crates/goose/src/tool_monitor.rs
Normal file
74
crates/goose/src/tool_monitor.rs
Normal file
@@ -0,0 +1,74 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolCall {
|
||||
name: String,
|
||||
parameters: serde_json::Value,
|
||||
}
|
||||
|
||||
impl ToolCall {
|
||||
pub fn new(name: String, parameters: serde_json::Value) -> Self {
|
||||
Self { name, parameters }
|
||||
}
|
||||
|
||||
fn matches(&self, other: &ToolCall) -> bool {
|
||||
self.name == other.name && self.parameters == other.parameters
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ToolMonitor {
|
||||
max_repetitions: Option<u32>,
|
||||
last_call: Option<ToolCall>,
|
||||
repeat_count: u32,
|
||||
call_counts: HashMap<String, u32>,
|
||||
}
|
||||
|
||||
impl ToolMonitor {
|
||||
pub fn new(max_repetitions: Option<u32>) -> Self {
|
||||
Self {
|
||||
max_repetitions,
|
||||
last_call: None,
|
||||
repeat_count: 0,
|
||||
call_counts: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn check_tool_call(&mut self, tool_call: ToolCall) -> bool {
|
||||
let total_calls = self.call_counts.entry(tool_call.name.clone()).or_insert(0);
|
||||
*total_calls += 1;
|
||||
|
||||
if self.max_repetitions.is_none() {
|
||||
self.last_call = Some(tool_call);
|
||||
self.repeat_count = 1;
|
||||
return true;
|
||||
}
|
||||
|
||||
if let Some(last) = &self.last_call {
|
||||
if last.matches(&tool_call) {
|
||||
self.repeat_count += 1;
|
||||
if self.repeat_count > self.max_repetitions.unwrap() {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
self.repeat_count = 1;
|
||||
}
|
||||
} else {
|
||||
self.repeat_count = 1;
|
||||
}
|
||||
|
||||
self.last_call = Some(tool_call);
|
||||
true
|
||||
}
|
||||
|
||||
pub fn get_stats(&self) -> HashMap<String, u32> {
|
||||
self.call_counts.clone()
|
||||
}
|
||||
|
||||
pub fn reset(&mut self) {
|
||||
self.last_call = None;
|
||||
self.repeat_count = 0;
|
||||
self.call_counts.clear();
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user