fix: improve AWS credential handling in Bedrock provider, make keyring optional (#1886)

This commit is contained in:
Kalvin C
2025-03-27 16:28:20 -07:00
committed by GitHub
parent 6bd4f36331
commit 2fa8c3e8ab

View File

@@ -1,8 +1,12 @@
use std::collections::HashMap;
use anyhow::Result;
use async_trait::async_trait;
use aws_sdk_bedrockruntime::config::ProvideCredentials;
use aws_sdk_bedrockruntime::operation::converse::ConverseError;
use aws_sdk_bedrockruntime::{types as bedrock, Client};
use mcp_core::Tool;
use serde_json::Value;
use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage};
use super::errors::ProviderError;
@@ -34,12 +38,30 @@ pub struct BedrockProvider {
impl BedrockProvider {
pub fn from_env(model: ModelConfig) -> Result<Self> {
let config = crate::config::Config::global();
for (key, value) in config.load_secrets()?.iter() {
if key.starts_with("AWS_") && value.is_string() {
std::env::set_var(key, value.as_str().unwrap());
// Attempt to load config and secrets to get AWS_ prefixed keys
// to re-export them into the environment for aws_config::load_from_env()
let set_aws_env_vars = |res: Result<HashMap<String, Value>, _>| {
if let Ok(map) = res {
map.into_iter()
.filter(|(key, _)| key.starts_with("AWS_"))
.filter_map(|(key, value)| value.as_str().map(|s| (key, s.to_string())))
.for_each(|(key, s)| std::env::set_var(key, s));
}
}
};
set_aws_env_vars(config.load_values());
set_aws_env_vars(config.load_secrets());
let sdk_config = futures::executor::block_on(aws_config::load_from_env());
// validate credentials or return error back up
futures::executor::block_on(
sdk_config
.credentials_provider()
.unwrap()
.provide_credentials(),
)?;
let client = Client::new(&sdk_config);
Ok(Self { client, model })
@@ -63,7 +85,7 @@ impl Provider for BedrockProvider {
BEDROCK_DEFAULT_MODEL,
BEDROCK_KNOWN_MODELS.iter().map(|s| s.to_string()).collect(),
BEDROCK_DOC_LINK,
vec![ConfigKey::new("AWS_PROFILE", true, false, Some("us-west-2"))],
vec![ConfigKey::new("AWS_PROFILE", true, false, Some("default"))],
)
}