From a9132291779a4265e7740e30077a0c3543f4da3c Mon Sep 17 00:00:00 2001 From: german-garcia-intellitech Date: Fri, 11 Apr 2025 12:17:18 -0400 Subject: [PATCH] Added Exponential Backoff to Bedrock provider (#2139) Co-authored-by: German Garcia --- crates/goose/src/providers/bedrock.rs | 172 ++++++++++++++++---------- 1 file changed, 110 insertions(+), 62 deletions(-) diff --git a/crates/goose/src/providers/bedrock.rs b/crates/goose/src/providers/bedrock.rs index ca96b1f4..1e047d2c 100644 --- a/crates/goose/src/providers/bedrock.rs +++ b/crates/goose/src/providers/bedrock.rs @@ -1,4 +1,5 @@ use std::collections::HashMap; +use std::time::Duration; use anyhow::Result; use async_trait::async_trait; @@ -7,6 +8,7 @@ use aws_sdk_bedrockruntime::operation::converse::ConverseError; use aws_sdk_bedrockruntime::{types as bedrock, Client}; use mcp_core::Tool; use serde_json::Value; +use tokio::time::sleep; use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage}; use super::errors::ProviderError; @@ -121,70 +123,116 @@ impl Provider for BedrockProvider { request = request.tool_config(to_bedrock_tool_config(tools)?); } - let response = request.send().await; + // Retry configuration + const MAX_RETRIES: u32 = 10; + const INITIAL_BACKOFF_MS: u64 = 20_000; // 20 seconds + const MAX_BACKOFF_MS: u64 = 120_000; // 120 seconds (2 minutes) - let response = match response { - Ok(response) => response, - Err(err) => { - return Err(match err.into_service_error() { - ConverseError::AccessDeniedException(err) => { - ProviderError::Authentication(format!("Failed to call Bedrock: {:?}", err)) + let mut attempts = 0; + let mut backoff_ms = INITIAL_BACKOFF_MS; + + loop { + attempts += 1; + + match request.clone().send().await { + Ok(response) => { + // Successful response, process it and return + return match response.output { + Some(bedrock::ConverseOutput::Message(message)) => { + let usage = response + .usage + .as_ref() + .map(from_bedrock_usage) + .unwrap_or_default(); + + let message = from_bedrock_message(&message)?; + + // Add debug trace with input context + let debug_payload = serde_json::json!({ + "system": system, + "messages": messages, + "tools": tools + }); + emit_debug_trace( + &self.model, + &debug_payload, + &serde_json::to_value(&message).unwrap_or_default(), + &usage, + ); + + let provider_usage = ProviderUsage::new(model_name.to_string(), usage); + Ok((message, provider_usage)) + } + _ => Err(ProviderError::RequestFailed( + "No output from Bedrock".to_string(), + )), + }; + } + Err(err) => { + match err.into_service_error() { + ConverseError::ThrottlingException(throttle_err) => { + if attempts > MAX_RETRIES { + // We've exhausted our retries + tracing::error!( + "Failed after {MAX_RETRIES} retries: {:?}", + throttle_err + ); + return Err(ProviderError::RateLimitExceeded(format!( + "Failed to call Bedrock after {MAX_RETRIES} retries: {:?}", + throttle_err + ))); + } + + // Log retry attempt + tracing::warn!( + "Bedrock throttling error (attempt {}/{}), retrying in {} ms: {:?}", + attempts, + MAX_RETRIES, + backoff_ms, + throttle_err + ); + + // Wait before retry with exponential backoff + sleep(Duration::from_millis(backoff_ms)).await; + + // Calculate next backoff with exponential growth, capped at max + backoff_ms = (backoff_ms * 2).min(MAX_BACKOFF_MS); + + // Continue to the next retry attempt + continue; + } + ConverseError::AccessDeniedException(err) => { + return Err(ProviderError::Authentication(format!( + "Failed to call Bedrock: {:?}", + err + ))); + } + ConverseError::ValidationException(err) + if err + .message() + .unwrap_or_default() + .contains("Input is too long for requested model.") => + { + return Err(ProviderError::ContextLengthExceeded(format!( + "Failed to call Bedrock: {:?}", + err + ))); + } + ConverseError::ModelErrorException(err) => { + return Err(ProviderError::ExecutionError(format!( + "Failed to call Bedrock: {:?}", + err + ))); + } + err => { + return Err(ProviderError::ServerError(format!( + "Failed to call Bedrock: {:?}", + err + ))); + } } - ConverseError::ThrottlingException(err) => ProviderError::RateLimitExceeded( - format!("Failed to call Bedrock: {:?}", err), - ), - ConverseError::ValidationException(err) - if err - .message() - .unwrap_or_default() - .contains("Input is too long for requested model.") => - { - ProviderError::ContextLengthExceeded(format!( - "Failed to call Bedrock: {:?}", - err - )) - } - ConverseError::ModelErrorException(err) => { - ProviderError::ExecutionError(format!("Failed to call Bedrock: {:?}", err)) - } - err => { - ProviderError::ServerError(format!("Failed to call Bedrock: {:?}", err,)) - } - }); + } } - }; - - let message = match response.output { - Some(bedrock::ConverseOutput::Message(message)) => message, - _ => { - return Err(ProviderError::RequestFailed( - "No output from Bedrock".to_string(), - )) - } - }; - - let usage = response - .usage - .as_ref() - .map(from_bedrock_usage) - .unwrap_or_default(); - - let message = from_bedrock_message(&message)?; - - // Add debug trace with input context - let debug_payload = serde_json::json!({ - "system": system, - "messages": messages, - "tools": tools - }); - emit_debug_trace( - &self.model, - &debug_payload, - &serde_json::to_value(&message).unwrap_or_default(), - &usage, - ); - - let provider_usage = ProviderUsage::new(model_name.to_string(), usage); - Ok((message, provider_usage)) + } } }