diff --git a/crates/goose-cli/src/cli.rs b/crates/goose-cli/src/cli.rs index 261470d7..ec54873d 100644 --- a/crates/goose-cli/src/cli.rs +++ b/crates/goose-cli/src/cli.rs @@ -220,6 +220,15 @@ enum Command { )] debug: bool, + /// Maximum number of consecutive identical tool calls allowed + #[arg( + long = "max-tool-repetitions", + value_name = "NUMBER", + help = "Maximum number of consecutive identical tool calls allowed", + long_help = "Set a limit on how many times the same tool can be called consecutively with identical parameters. Helps prevent infinite loops." + )] + max_tool_repetitions: Option, + /// Add stdio extensions with environment variables and commands #[arg( long = "with-extension", @@ -324,6 +333,15 @@ enum Command { )] no_session: bool, + /// Maximum number of consecutive identical tool calls allowed + #[arg( + long = "max-tool-repetitions", + value_name = "NUMBER", + help = "Maximum number of consecutive identical tool calls allowed", + long_help = "Set a limit on how many times the same tool can be called consecutively with identical parameters. Helps prevent infinite loops." + )] + max_tool_repetitions: Option, + /// Identifier for this run session #[command(flatten)] identifier: Option, @@ -446,6 +464,7 @@ pub async fn cli() -> Result<()> { resume, history, debug, + max_tool_repetitions, extensions, remote_extensions, builtins, @@ -475,6 +494,7 @@ pub async fn cli() -> Result<()> { extensions_override: None, additional_system_prompt: None, debug, + max_tool_repetitions, }) .await; setup_logging( @@ -511,6 +531,7 @@ pub async fn cli() -> Result<()> { resume, no_session, debug, + max_tool_repetitions, extensions, remote_extensions, builtins, @@ -576,6 +597,7 @@ pub async fn cli() -> Result<()> { extensions_override: input_config.extensions_override, additional_system_prompt: input_config.additional_system_prompt, debug, + max_tool_repetitions, }) .await; @@ -647,6 +669,7 @@ pub async fn cli() -> Result<()> { extensions_override: None, additional_system_prompt: None, debug: false, + max_tool_repetitions: None, }) .await; setup_logging( diff --git a/crates/goose-cli/src/commands/bench.rs b/crates/goose-cli/src/commands/bench.rs index a6ef200c..ee9514e2 100644 --- a/crates/goose-cli/src/commands/bench.rs +++ b/crates/goose-cli/src/commands/bench.rs @@ -41,6 +41,7 @@ pub async fn agent_generator( extensions_override: None, additional_system_prompt: None, debug: false, + max_tool_repetitions: None, }) .await; diff --git a/crates/goose-cli/src/session/builder.rs b/crates/goose-cli/src/session/builder.rs index a5c77368..747862de 100644 --- a/crates/goose-cli/src/session/builder.rs +++ b/crates/goose-cli/src/session/builder.rs @@ -35,6 +35,8 @@ pub struct SessionBuilderConfig { pub additional_system_prompt: Option, /// Enable debug printing pub debug: bool, + /// Maximum number of consecutive identical tool calls allowed + pub max_tool_repetitions: Option, } pub async fn build_session(session_config: SessionBuilderConfig) -> Session { @@ -55,6 +57,11 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> Session { let new_provider = create(&provider_name, model_config).unwrap(); let _ = agent.update_provider(new_provider).await; + // Configure tool monitoring if max_tool_repetitions is set + if let Some(max_repetitions) = session_config.max_tool_repetitions { + agent.configure_tool_monitor(Some(max_repetitions)).await; + } + // Handle session file resolution and resuming let session_file = if session_config.no_session { // Use a temporary path that won't be written to diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index 2ae19998..d7662eb6 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -12,6 +12,7 @@ use crate::permission::PermissionConfirmation; use crate::providers::base::Provider; use crate::providers::errors::ProviderError; use crate::recipe::{Author, Recipe}; +use crate::tool_monitor::{ToolCall, ToolMonitor}; use regex::Regex; use serde_json::Value; use tokio::sync::{mpsc, Mutex}; @@ -44,6 +45,7 @@ pub struct Agent { pub(super) confirmation_rx: Mutex>, pub(super) tool_result_tx: mpsc::Sender<(String, ToolResult>)>, pub(super) tool_result_rx: ToolResultReceiver, + pub(super) tool_monitor: Mutex>, } impl Agent { @@ -62,6 +64,23 @@ impl Agent { confirmation_rx: Mutex::new(confirm_rx), tool_result_tx: tool_tx, tool_result_rx: Arc::new(Mutex::new(tool_rx)), + tool_monitor: Mutex::new(None), + } + } + + pub async fn configure_tool_monitor(&self, max_repetitions: Option) { + let mut tool_monitor = self.tool_monitor.lock().await; + *tool_monitor = Some(ToolMonitor::new(max_repetitions)); + } + + pub async fn get_tool_stats(&self) -> Option> { + let tool_monitor = self.tool_monitor.lock().await; + tool_monitor.as_ref().map(|monitor| monitor.get_stats()) + } + + pub async fn reset_tool_monitor(&self) { + if let Some(monitor) = self.tool_monitor.lock().await.as_mut() { + monitor.reset(); } } } @@ -116,6 +135,20 @@ impl Agent { tool_call: mcp_core::tool::ToolCall, request_id: String, ) -> (String, Result, ToolError>) { + // Check if this tool call should be allowed based on repetition monitoring + if let Some(monitor) = self.tool_monitor.lock().await.as_mut() { + let tool_call_info = ToolCall::new(tool_call.name.clone(), tool_call.arguments.clone()); + + if !monitor.check_tool_call(tool_call_info) { + return ( + request_id, + Err(ToolError::ExecutionError( + "Tool call rejected: exceeded maximum allowed repetitions".to_string(), + )), + ); + } + } + if tool_call.name == PLATFORM_MANAGE_EXTENSIONS_TOOL_NAME { let extension_name = tool_call .arguments diff --git a/crates/goose/src/lib.rs b/crates/goose/src/lib.rs index e0023e12..d1389adb 100644 --- a/crates/goose/src/lib.rs +++ b/crates/goose/src/lib.rs @@ -9,4 +9,5 @@ pub mod providers; pub mod recipe; pub mod session; pub mod token_counter; +pub mod tool_monitor; pub mod tracing; diff --git a/crates/goose/src/tool_monitor.rs b/crates/goose/src/tool_monitor.rs new file mode 100644 index 00000000..68720f70 --- /dev/null +++ b/crates/goose/src/tool_monitor.rs @@ -0,0 +1,74 @@ +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolCall { + name: String, + parameters: serde_json::Value, +} + +impl ToolCall { + pub fn new(name: String, parameters: serde_json::Value) -> Self { + Self { name, parameters } + } + + fn matches(&self, other: &ToolCall) -> bool { + self.name == other.name && self.parameters == other.parameters + } +} + +#[derive(Debug)] +pub struct ToolMonitor { + max_repetitions: Option, + last_call: Option, + repeat_count: u32, + call_counts: HashMap, +} + +impl ToolMonitor { + pub fn new(max_repetitions: Option) -> Self { + Self { + max_repetitions, + last_call: None, + repeat_count: 0, + call_counts: HashMap::new(), + } + } + + pub fn check_tool_call(&mut self, tool_call: ToolCall) -> bool { + let total_calls = self.call_counts.entry(tool_call.name.clone()).or_insert(0); + *total_calls += 1; + + if self.max_repetitions.is_none() { + self.last_call = Some(tool_call); + self.repeat_count = 1; + return true; + } + + if let Some(last) = &self.last_call { + if last.matches(&tool_call) { + self.repeat_count += 1; + if self.repeat_count > self.max_repetitions.unwrap() { + return false; + } + } else { + self.repeat_count = 1; + } + } else { + self.repeat_count = 1; + } + + self.last_call = Some(tool_call); + true + } + + pub fn get_stats(&self) -> HashMap { + self.call_counts.clone() + } + + pub fn reset(&mut self) { + self.last_call = None; + self.repeat_count = 0; + self.call_counts.clear(); + } +}