mirror of
https://github.com/aljazceru/goose.git
synced 2025-12-18 14:44:21 +01:00
Added Exponential Backoff to Bedrock provider (#2139)
Co-authored-by: German Garcia <german.garcia@kalosys.net>
This commit is contained in:
committed by
GitHub
parent
26a8bcafbe
commit
a913229177
@@ -1,4 +1,5 @@
|
||||
use std::collections::HashMap;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
@@ -7,6 +8,7 @@ use aws_sdk_bedrockruntime::operation::converse::ConverseError;
|
||||
use aws_sdk_bedrockruntime::{types as bedrock, Client};
|
||||
use mcp_core::Tool;
|
||||
use serde_json::Value;
|
||||
use tokio::time::sleep;
|
||||
|
||||
use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage};
|
||||
use super::errors::ProviderError;
|
||||
@@ -121,48 +123,22 @@ impl Provider for BedrockProvider {
|
||||
request = request.tool_config(to_bedrock_tool_config(tools)?);
|
||||
}
|
||||
|
||||
let response = request.send().await;
|
||||
// Retry configuration
|
||||
const MAX_RETRIES: u32 = 10;
|
||||
const INITIAL_BACKOFF_MS: u64 = 20_000; // 20 seconds
|
||||
const MAX_BACKOFF_MS: u64 = 120_000; // 120 seconds (2 minutes)
|
||||
|
||||
let response = match response {
|
||||
Ok(response) => response,
|
||||
Err(err) => {
|
||||
return Err(match err.into_service_error() {
|
||||
ConverseError::AccessDeniedException(err) => {
|
||||
ProviderError::Authentication(format!("Failed to call Bedrock: {:?}", err))
|
||||
}
|
||||
ConverseError::ThrottlingException(err) => ProviderError::RateLimitExceeded(
|
||||
format!("Failed to call Bedrock: {:?}", err),
|
||||
),
|
||||
ConverseError::ValidationException(err)
|
||||
if err
|
||||
.message()
|
||||
.unwrap_or_default()
|
||||
.contains("Input is too long for requested model.") =>
|
||||
{
|
||||
ProviderError::ContextLengthExceeded(format!(
|
||||
"Failed to call Bedrock: {:?}",
|
||||
err
|
||||
))
|
||||
}
|
||||
ConverseError::ModelErrorException(err) => {
|
||||
ProviderError::ExecutionError(format!("Failed to call Bedrock: {:?}", err))
|
||||
}
|
||||
err => {
|
||||
ProviderError::ServerError(format!("Failed to call Bedrock: {:?}", err,))
|
||||
}
|
||||
});
|
||||
}
|
||||
};
|
||||
let mut attempts = 0;
|
||||
let mut backoff_ms = INITIAL_BACKOFF_MS;
|
||||
|
||||
let message = match response.output {
|
||||
Some(bedrock::ConverseOutput::Message(message)) => message,
|
||||
_ => {
|
||||
return Err(ProviderError::RequestFailed(
|
||||
"No output from Bedrock".to_string(),
|
||||
))
|
||||
}
|
||||
};
|
||||
loop {
|
||||
attempts += 1;
|
||||
|
||||
match request.clone().send().await {
|
||||
Ok(response) => {
|
||||
// Successful response, process it and return
|
||||
return match response.output {
|
||||
Some(bedrock::ConverseOutput::Message(message)) => {
|
||||
let usage = response
|
||||
.usage
|
||||
.as_ref()
|
||||
@@ -187,4 +163,76 @@ impl Provider for BedrockProvider {
|
||||
let provider_usage = ProviderUsage::new(model_name.to_string(), usage);
|
||||
Ok((message, provider_usage))
|
||||
}
|
||||
_ => Err(ProviderError::RequestFailed(
|
||||
"No output from Bedrock".to_string(),
|
||||
)),
|
||||
};
|
||||
}
|
||||
Err(err) => {
|
||||
match err.into_service_error() {
|
||||
ConverseError::ThrottlingException(throttle_err) => {
|
||||
if attempts > MAX_RETRIES {
|
||||
// We've exhausted our retries
|
||||
tracing::error!(
|
||||
"Failed after {MAX_RETRIES} retries: {:?}",
|
||||
throttle_err
|
||||
);
|
||||
return Err(ProviderError::RateLimitExceeded(format!(
|
||||
"Failed to call Bedrock after {MAX_RETRIES} retries: {:?}",
|
||||
throttle_err
|
||||
)));
|
||||
}
|
||||
|
||||
// Log retry attempt
|
||||
tracing::warn!(
|
||||
"Bedrock throttling error (attempt {}/{}), retrying in {} ms: {:?}",
|
||||
attempts,
|
||||
MAX_RETRIES,
|
||||
backoff_ms,
|
||||
throttle_err
|
||||
);
|
||||
|
||||
// Wait before retry with exponential backoff
|
||||
sleep(Duration::from_millis(backoff_ms)).await;
|
||||
|
||||
// Calculate next backoff with exponential growth, capped at max
|
||||
backoff_ms = (backoff_ms * 2).min(MAX_BACKOFF_MS);
|
||||
|
||||
// Continue to the next retry attempt
|
||||
continue;
|
||||
}
|
||||
ConverseError::AccessDeniedException(err) => {
|
||||
return Err(ProviderError::Authentication(format!(
|
||||
"Failed to call Bedrock: {:?}",
|
||||
err
|
||||
)));
|
||||
}
|
||||
ConverseError::ValidationException(err)
|
||||
if err
|
||||
.message()
|
||||
.unwrap_or_default()
|
||||
.contains("Input is too long for requested model.") =>
|
||||
{
|
||||
return Err(ProviderError::ContextLengthExceeded(format!(
|
||||
"Failed to call Bedrock: {:?}",
|
||||
err
|
||||
)));
|
||||
}
|
||||
ConverseError::ModelErrorException(err) => {
|
||||
return Err(ProviderError::ExecutionError(format!(
|
||||
"Failed to call Bedrock: {:?}",
|
||||
err
|
||||
)));
|
||||
}
|
||||
err => {
|
||||
return Err(ProviderError::ServerError(format!(
|
||||
"Failed to call Bedrock: {:?}",
|
||||
err
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user