mirror of
https://github.com/aljazceru/goose.git
synced 2025-12-18 14:44:21 +01:00
Added Exponential Backoff to Bedrock provider (#2139)
Co-authored-by: German Garcia <german.garcia@kalosys.net>
This commit is contained in:
committed by
GitHub
parent
26a8bcafbe
commit
a913229177
@@ -1,4 +1,5 @@
|
|||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
@@ -7,6 +8,7 @@ use aws_sdk_bedrockruntime::operation::converse::ConverseError;
|
|||||||
use aws_sdk_bedrockruntime::{types as bedrock, Client};
|
use aws_sdk_bedrockruntime::{types as bedrock, Client};
|
||||||
use mcp_core::Tool;
|
use mcp_core::Tool;
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
|
use tokio::time::sleep;
|
||||||
|
|
||||||
use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage};
|
use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage};
|
||||||
use super::errors::ProviderError;
|
use super::errors::ProviderError;
|
||||||
@@ -121,70 +123,116 @@ impl Provider for BedrockProvider {
|
|||||||
request = request.tool_config(to_bedrock_tool_config(tools)?);
|
request = request.tool_config(to_bedrock_tool_config(tools)?);
|
||||||
}
|
}
|
||||||
|
|
||||||
let response = request.send().await;
|
// Retry configuration
|
||||||
|
const MAX_RETRIES: u32 = 10;
|
||||||
|
const INITIAL_BACKOFF_MS: u64 = 20_000; // 20 seconds
|
||||||
|
const MAX_BACKOFF_MS: u64 = 120_000; // 120 seconds (2 minutes)
|
||||||
|
|
||||||
let response = match response {
|
let mut attempts = 0;
|
||||||
Ok(response) => response,
|
let mut backoff_ms = INITIAL_BACKOFF_MS;
|
||||||
Err(err) => {
|
|
||||||
return Err(match err.into_service_error() {
|
loop {
|
||||||
ConverseError::AccessDeniedException(err) => {
|
attempts += 1;
|
||||||
ProviderError::Authentication(format!("Failed to call Bedrock: {:?}", err))
|
|
||||||
|
match request.clone().send().await {
|
||||||
|
Ok(response) => {
|
||||||
|
// Successful response, process it and return
|
||||||
|
return match response.output {
|
||||||
|
Some(bedrock::ConverseOutput::Message(message)) => {
|
||||||
|
let usage = response
|
||||||
|
.usage
|
||||||
|
.as_ref()
|
||||||
|
.map(from_bedrock_usage)
|
||||||
|
.unwrap_or_default();
|
||||||
|
|
||||||
|
let message = from_bedrock_message(&message)?;
|
||||||
|
|
||||||
|
// Add debug trace with input context
|
||||||
|
let debug_payload = serde_json::json!({
|
||||||
|
"system": system,
|
||||||
|
"messages": messages,
|
||||||
|
"tools": tools
|
||||||
|
});
|
||||||
|
emit_debug_trace(
|
||||||
|
&self.model,
|
||||||
|
&debug_payload,
|
||||||
|
&serde_json::to_value(&message).unwrap_or_default(),
|
||||||
|
&usage,
|
||||||
|
);
|
||||||
|
|
||||||
|
let provider_usage = ProviderUsage::new(model_name.to_string(), usage);
|
||||||
|
Ok((message, provider_usage))
|
||||||
|
}
|
||||||
|
_ => Err(ProviderError::RequestFailed(
|
||||||
|
"No output from Bedrock".to_string(),
|
||||||
|
)),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
Err(err) => {
|
||||||
|
match err.into_service_error() {
|
||||||
|
ConverseError::ThrottlingException(throttle_err) => {
|
||||||
|
if attempts > MAX_RETRIES {
|
||||||
|
// We've exhausted our retries
|
||||||
|
tracing::error!(
|
||||||
|
"Failed after {MAX_RETRIES} retries: {:?}",
|
||||||
|
throttle_err
|
||||||
|
);
|
||||||
|
return Err(ProviderError::RateLimitExceeded(format!(
|
||||||
|
"Failed to call Bedrock after {MAX_RETRIES} retries: {:?}",
|
||||||
|
throttle_err
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Log retry attempt
|
||||||
|
tracing::warn!(
|
||||||
|
"Bedrock throttling error (attempt {}/{}), retrying in {} ms: {:?}",
|
||||||
|
attempts,
|
||||||
|
MAX_RETRIES,
|
||||||
|
backoff_ms,
|
||||||
|
throttle_err
|
||||||
|
);
|
||||||
|
|
||||||
|
// Wait before retry with exponential backoff
|
||||||
|
sleep(Duration::from_millis(backoff_ms)).await;
|
||||||
|
|
||||||
|
// Calculate next backoff with exponential growth, capped at max
|
||||||
|
backoff_ms = (backoff_ms * 2).min(MAX_BACKOFF_MS);
|
||||||
|
|
||||||
|
// Continue to the next retry attempt
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
ConverseError::AccessDeniedException(err) => {
|
||||||
|
return Err(ProviderError::Authentication(format!(
|
||||||
|
"Failed to call Bedrock: {:?}",
|
||||||
|
err
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
ConverseError::ValidationException(err)
|
||||||
|
if err
|
||||||
|
.message()
|
||||||
|
.unwrap_or_default()
|
||||||
|
.contains("Input is too long for requested model.") =>
|
||||||
|
{
|
||||||
|
return Err(ProviderError::ContextLengthExceeded(format!(
|
||||||
|
"Failed to call Bedrock: {:?}",
|
||||||
|
err
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
ConverseError::ModelErrorException(err) => {
|
||||||
|
return Err(ProviderError::ExecutionError(format!(
|
||||||
|
"Failed to call Bedrock: {:?}",
|
||||||
|
err
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
err => {
|
||||||
|
return Err(ProviderError::ServerError(format!(
|
||||||
|
"Failed to call Bedrock: {:?}",
|
||||||
|
err
|
||||||
|
)));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
ConverseError::ThrottlingException(err) => ProviderError::RateLimitExceeded(
|
}
|
||||||
format!("Failed to call Bedrock: {:?}", err),
|
|
||||||
),
|
|
||||||
ConverseError::ValidationException(err)
|
|
||||||
if err
|
|
||||||
.message()
|
|
||||||
.unwrap_or_default()
|
|
||||||
.contains("Input is too long for requested model.") =>
|
|
||||||
{
|
|
||||||
ProviderError::ContextLengthExceeded(format!(
|
|
||||||
"Failed to call Bedrock: {:?}",
|
|
||||||
err
|
|
||||||
))
|
|
||||||
}
|
|
||||||
ConverseError::ModelErrorException(err) => {
|
|
||||||
ProviderError::ExecutionError(format!("Failed to call Bedrock: {:?}", err))
|
|
||||||
}
|
|
||||||
err => {
|
|
||||||
ProviderError::ServerError(format!("Failed to call Bedrock: {:?}", err,))
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
};
|
}
|
||||||
|
|
||||||
let message = match response.output {
|
|
||||||
Some(bedrock::ConverseOutput::Message(message)) => message,
|
|
||||||
_ => {
|
|
||||||
return Err(ProviderError::RequestFailed(
|
|
||||||
"No output from Bedrock".to_string(),
|
|
||||||
))
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let usage = response
|
|
||||||
.usage
|
|
||||||
.as_ref()
|
|
||||||
.map(from_bedrock_usage)
|
|
||||||
.unwrap_or_default();
|
|
||||||
|
|
||||||
let message = from_bedrock_message(&message)?;
|
|
||||||
|
|
||||||
// Add debug trace with input context
|
|
||||||
let debug_payload = serde_json::json!({
|
|
||||||
"system": system,
|
|
||||||
"messages": messages,
|
|
||||||
"tools": tools
|
|
||||||
});
|
|
||||||
emit_debug_trace(
|
|
||||||
&self.model,
|
|
||||||
&debug_payload,
|
|
||||||
&serde_json::to_value(&message).unwrap_or_default(),
|
|
||||||
&usage,
|
|
||||||
);
|
|
||||||
|
|
||||||
let provider_usage = ProviderUsage::new(model_name.to_string(), usage);
|
|
||||||
Ok((message, provider_usage))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user