feat: add tool repetition monitoring to prevent infinite loops (#2527)

This commit is contained in:
Max Novich
2025-05-14 14:46:37 -07:00
committed by GitHub
parent 70020f1b45
commit 4e1b091d91
6 changed files with 139 additions and 0 deletions

View File

@@ -220,6 +220,15 @@ enum Command {
)]
debug: bool,
/// Maximum number of consecutive identical tool calls allowed
#[arg(
long = "max-tool-repetitions",
value_name = "NUMBER",
help = "Maximum number of consecutive identical tool calls allowed",
long_help = "Set a limit on how many times the same tool can be called consecutively with identical parameters. Helps prevent infinite loops."
)]
max_tool_repetitions: Option<u32>,
/// Add stdio extensions with environment variables and commands
#[arg(
long = "with-extension",
@@ -324,6 +333,15 @@ enum Command {
)]
no_session: bool,
/// Maximum number of consecutive identical tool calls allowed
#[arg(
long = "max-tool-repetitions",
value_name = "NUMBER",
help = "Maximum number of consecutive identical tool calls allowed",
long_help = "Set a limit on how many times the same tool can be called consecutively with identical parameters. Helps prevent infinite loops."
)]
max_tool_repetitions: Option<u32>,
/// Identifier for this run session
#[command(flatten)]
identifier: Option<Identifier>,
@@ -446,6 +464,7 @@ pub async fn cli() -> Result<()> {
resume,
history,
debug,
max_tool_repetitions,
extensions,
remote_extensions,
builtins,
@@ -475,6 +494,7 @@ pub async fn cli() -> Result<()> {
extensions_override: None,
additional_system_prompt: None,
debug,
max_tool_repetitions,
})
.await;
setup_logging(
@@ -511,6 +531,7 @@ pub async fn cli() -> Result<()> {
resume,
no_session,
debug,
max_tool_repetitions,
extensions,
remote_extensions,
builtins,
@@ -576,6 +597,7 @@ pub async fn cli() -> Result<()> {
extensions_override: input_config.extensions_override,
additional_system_prompt: input_config.additional_system_prompt,
debug,
max_tool_repetitions,
})
.await;
@@ -647,6 +669,7 @@ pub async fn cli() -> Result<()> {
extensions_override: None,
additional_system_prompt: None,
debug: false,
max_tool_repetitions: None,
})
.await;
setup_logging(

View File

@@ -41,6 +41,7 @@ pub async fn agent_generator(
extensions_override: None,
additional_system_prompt: None,
debug: false,
max_tool_repetitions: None,
})
.await;

View File

@@ -35,6 +35,8 @@ pub struct SessionBuilderConfig {
pub additional_system_prompt: Option<String>,
/// Enable debug printing
pub debug: bool,
/// Maximum number of consecutive identical tool calls allowed
pub max_tool_repetitions: Option<u32>,
}
pub async fn build_session(session_config: SessionBuilderConfig) -> Session {
@@ -55,6 +57,11 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> Session {
let new_provider = create(&provider_name, model_config).unwrap();
let _ = agent.update_provider(new_provider).await;
// Configure tool monitoring if max_tool_repetitions is set
if let Some(max_repetitions) = session_config.max_tool_repetitions {
agent.configure_tool_monitor(Some(max_repetitions)).await;
}
// Handle session file resolution and resuming
let session_file = if session_config.no_session {
// Use a temporary path that won't be written to

View File

@@ -12,6 +12,7 @@ use crate::permission::PermissionConfirmation;
use crate::providers::base::Provider;
use crate::providers::errors::ProviderError;
use crate::recipe::{Author, Recipe};
use crate::tool_monitor::{ToolCall, ToolMonitor};
use regex::Regex;
use serde_json::Value;
use tokio::sync::{mpsc, Mutex};
@@ -44,6 +45,7 @@ pub struct Agent {
pub(super) confirmation_rx: Mutex<mpsc::Receiver<(String, PermissionConfirmation)>>,
pub(super) tool_result_tx: mpsc::Sender<(String, ToolResult<Vec<Content>>)>,
pub(super) tool_result_rx: ToolResultReceiver,
pub(super) tool_monitor: Mutex<Option<ToolMonitor>>,
}
impl Agent {
@@ -62,6 +64,23 @@ impl Agent {
confirmation_rx: Mutex::new(confirm_rx),
tool_result_tx: tool_tx,
tool_result_rx: Arc::new(Mutex::new(tool_rx)),
tool_monitor: Mutex::new(None),
}
}
pub async fn configure_tool_monitor(&self, max_repetitions: Option<u32>) {
let mut tool_monitor = self.tool_monitor.lock().await;
*tool_monitor = Some(ToolMonitor::new(max_repetitions));
}
pub async fn get_tool_stats(&self) -> Option<HashMap<String, u32>> {
let tool_monitor = self.tool_monitor.lock().await;
tool_monitor.as_ref().map(|monitor| monitor.get_stats())
}
pub async fn reset_tool_monitor(&self) {
if let Some(monitor) = self.tool_monitor.lock().await.as_mut() {
monitor.reset();
}
}
}
@@ -116,6 +135,20 @@ impl Agent {
tool_call: mcp_core::tool::ToolCall,
request_id: String,
) -> (String, Result<Vec<Content>, ToolError>) {
// Check if this tool call should be allowed based on repetition monitoring
if let Some(monitor) = self.tool_monitor.lock().await.as_mut() {
let tool_call_info = ToolCall::new(tool_call.name.clone(), tool_call.arguments.clone());
if !monitor.check_tool_call(tool_call_info) {
return (
request_id,
Err(ToolError::ExecutionError(
"Tool call rejected: exceeded maximum allowed repetitions".to_string(),
)),
);
}
}
if tool_call.name == PLATFORM_MANAGE_EXTENSIONS_TOOL_NAME {
let extension_name = tool_call
.arguments

View File

@@ -9,4 +9,5 @@ pub mod providers;
pub mod recipe;
pub mod session;
pub mod token_counter;
pub mod tool_monitor;
pub mod tracing;

View File

@@ -0,0 +1,74 @@
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
name: String,
parameters: serde_json::Value,
}
impl ToolCall {
pub fn new(name: String, parameters: serde_json::Value) -> Self {
Self { name, parameters }
}
fn matches(&self, other: &ToolCall) -> bool {
self.name == other.name && self.parameters == other.parameters
}
}
#[derive(Debug)]
pub struct ToolMonitor {
max_repetitions: Option<u32>,
last_call: Option<ToolCall>,
repeat_count: u32,
call_counts: HashMap<String, u32>,
}
impl ToolMonitor {
pub fn new(max_repetitions: Option<u32>) -> Self {
Self {
max_repetitions,
last_call: None,
repeat_count: 0,
call_counts: HashMap::new(),
}
}
pub fn check_tool_call(&mut self, tool_call: ToolCall) -> bool {
let total_calls = self.call_counts.entry(tool_call.name.clone()).or_insert(0);
*total_calls += 1;
if self.max_repetitions.is_none() {
self.last_call = Some(tool_call);
self.repeat_count = 1;
return true;
}
if let Some(last) = &self.last_call {
if last.matches(&tool_call) {
self.repeat_count += 1;
if self.repeat_count > self.max_repetitions.unwrap() {
return false;
}
} else {
self.repeat_count = 1;
}
} else {
self.repeat_count = 1;
}
self.last_call = Some(tool_call);
true
}
pub fn get_stats(&self) -> HashMap<String, u32> {
self.call_counts.clone()
}
pub fn reset(&mut self) {
self.last_call = None;
self.repeat_count = 0;
self.call_counts.clear();
}
}