Added Exponential Backoff to Bedrock provider (#2139)

Co-authored-by: German Garcia <german.garcia@kalosys.net>
This commit is contained in:
german-garcia-intellitech
2025-04-11 12:17:18 -04:00
committed by GitHub
parent 26a8bcafbe
commit a913229177

View File

@@ -1,4 +1,5 @@
use std::collections::HashMap;
use std::time::Duration;
use anyhow::Result;
use async_trait::async_trait;
@@ -7,6 +8,7 @@ use aws_sdk_bedrockruntime::operation::converse::ConverseError;
use aws_sdk_bedrockruntime::{types as bedrock, Client};
use mcp_core::Tool;
use serde_json::Value;
use tokio::time::sleep;
use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage};
use super::errors::ProviderError;
@@ -121,70 +123,116 @@ impl Provider for BedrockProvider {
request = request.tool_config(to_bedrock_tool_config(tools)?);
}
let response = request.send().await;
// Retry configuration
const MAX_RETRIES: u32 = 10;
const INITIAL_BACKOFF_MS: u64 = 20_000; // 20 seconds
const MAX_BACKOFF_MS: u64 = 120_000; // 120 seconds (2 minutes)
let response = match response {
Ok(response) => response,
Err(err) => {
return Err(match err.into_service_error() {
ConverseError::AccessDeniedException(err) => {
ProviderError::Authentication(format!("Failed to call Bedrock: {:?}", err))
let mut attempts = 0;
let mut backoff_ms = INITIAL_BACKOFF_MS;
loop {
attempts += 1;
match request.clone().send().await {
Ok(response) => {
// Successful response, process it and return
return match response.output {
Some(bedrock::ConverseOutput::Message(message)) => {
let usage = response
.usage
.as_ref()
.map(from_bedrock_usage)
.unwrap_or_default();
let message = from_bedrock_message(&message)?;
// Add debug trace with input context
let debug_payload = serde_json::json!({
"system": system,
"messages": messages,
"tools": tools
});
emit_debug_trace(
&self.model,
&debug_payload,
&serde_json::to_value(&message).unwrap_or_default(),
&usage,
);
let provider_usage = ProviderUsage::new(model_name.to_string(), usage);
Ok((message, provider_usage))
}
_ => Err(ProviderError::RequestFailed(
"No output from Bedrock".to_string(),
)),
};
}
Err(err) => {
match err.into_service_error() {
ConverseError::ThrottlingException(throttle_err) => {
if attempts > MAX_RETRIES {
// We've exhausted our retries
tracing::error!(
"Failed after {MAX_RETRIES} retries: {:?}",
throttle_err
);
return Err(ProviderError::RateLimitExceeded(format!(
"Failed to call Bedrock after {MAX_RETRIES} retries: {:?}",
throttle_err
)));
}
// Log retry attempt
tracing::warn!(
"Bedrock throttling error (attempt {}/{}), retrying in {} ms: {:?}",
attempts,
MAX_RETRIES,
backoff_ms,
throttle_err
);
// Wait before retry with exponential backoff
sleep(Duration::from_millis(backoff_ms)).await;
// Calculate next backoff with exponential growth, capped at max
backoff_ms = (backoff_ms * 2).min(MAX_BACKOFF_MS);
// Continue to the next retry attempt
continue;
}
ConverseError::AccessDeniedException(err) => {
return Err(ProviderError::Authentication(format!(
"Failed to call Bedrock: {:?}",
err
)));
}
ConverseError::ValidationException(err)
if err
.message()
.unwrap_or_default()
.contains("Input is too long for requested model.") =>
{
return Err(ProviderError::ContextLengthExceeded(format!(
"Failed to call Bedrock: {:?}",
err
)));
}
ConverseError::ModelErrorException(err) => {
return Err(ProviderError::ExecutionError(format!(
"Failed to call Bedrock: {:?}",
err
)));
}
err => {
return Err(ProviderError::ServerError(format!(
"Failed to call Bedrock: {:?}",
err
)));
}
}
ConverseError::ThrottlingException(err) => ProviderError::RateLimitExceeded(
format!("Failed to call Bedrock: {:?}", err),
),
ConverseError::ValidationException(err)
if err
.message()
.unwrap_or_default()
.contains("Input is too long for requested model.") =>
{
ProviderError::ContextLengthExceeded(format!(
"Failed to call Bedrock: {:?}",
err
))
}
ConverseError::ModelErrorException(err) => {
ProviderError::ExecutionError(format!("Failed to call Bedrock: {:?}", err))
}
err => {
ProviderError::ServerError(format!("Failed to call Bedrock: {:?}", err,))
}
});
}
}
};
let message = match response.output {
Some(bedrock::ConverseOutput::Message(message)) => message,
_ => {
return Err(ProviderError::RequestFailed(
"No output from Bedrock".to_string(),
))
}
};
let usage = response
.usage
.as_ref()
.map(from_bedrock_usage)
.unwrap_or_default();
let message = from_bedrock_message(&message)?;
// Add debug trace with input context
let debug_payload = serde_json::json!({
"system": system,
"messages": messages,
"tools": tools
});
emit_debug_trace(
&self.model,
&debug_payload,
&serde_json::to_value(&message).unwrap_or_default(),
&usage,
);
let provider_usage = ProviderUsage::new(model_name.to_string(), usage);
Ok((message, provider_usage))
}
}
}