feat: add bedrock provider (#1069)

Co-authored-by: Burak Varlı <unexge@gmail.com>
This commit is contained in:
Alice Hau
2025-02-05 17:05:59 -05:00
committed by GitHub
parent 56fc54509b
commit b4b213ba34
8 changed files with 487 additions and 2 deletions

View File

@@ -61,6 +61,11 @@ once_cell = "1.20.2"
dirs = "6.0.0"
rand = "0.8.5"
# For Bedrock provider
aws-config = { version = "1.1.7", features = ["behavior-version-latest"] }
aws-smithy-types = "1.2.12"
aws-sdk-bedrockruntime = "1.72.0"
[dev-dependencies]
criterion = "0.5"
tempfile = "3.15.0"

View File

@@ -0,0 +1,162 @@
use anyhow::Result;
use async_trait::async_trait;
use aws_sdk_bedrockruntime::operation::converse::ConverseError;
use aws_sdk_bedrockruntime::{types as bedrock, Client};
use mcp_core::Tool;
use super::base::{Provider, ProviderMetadata, ProviderUsage};
use super::errors::ProviderError;
use crate::message::Message;
use crate::model::ModelConfig;
use crate::providers::utils::emit_debug_trace;
// Import the migrated helper functions from providers/formats/bedrock.rs
use super::formats::bedrock::{
from_bedrock_message, from_bedrock_usage, to_bedrock_message, to_bedrock_tool_config,
};
pub const BEDROCK_DOC_LINK: &str =
"https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html";
pub const BEDROCK_DEFAULT_MODEL: &str = "anthropic.claude-3-5-sonnet-20240620-v1:0";
pub const BEDROCK_KNOWN_MODELS: &[&str] = &[
"anthropic.claude-3-5-sonnet-20240620-v1:0",
"anthropic.claude-3-5-sonnet-20241022-v2:0",
];
#[derive(Debug, serde::Serialize)]
pub struct BedrockProvider {
#[serde(skip)]
client: Client,
model: ModelConfig,
}
impl BedrockProvider {
pub fn from_env(model: ModelConfig) -> Result<Self> {
let sdk_config = futures::executor::block_on(aws_config::load_from_env());
let client = Client::new(&sdk_config);
Ok(Self { client, model })
}
}
impl Default for BedrockProvider {
fn default() -> Self {
let model = ModelConfig::new(BedrockProvider::metadata().default_model);
BedrockProvider::from_env(model).expect("Failed to initialize Bedrock provider")
}
}
#[async_trait]
impl Provider for BedrockProvider {
fn metadata() -> ProviderMetadata {
ProviderMetadata::new(
"bedrock",
"Amazon Bedrock",
"Run models through Amazon Bedrock. You may have to set AWS_ACCESS_KEY_ID, AWS_ACCESS_KEY, and AWS_REGION as env vars before configuring.",
BEDROCK_DEFAULT_MODEL,
BEDROCK_KNOWN_MODELS.iter().map(|s| s.to_string()).collect(),
BEDROCK_DOC_LINK,
vec![],
)
}
fn get_model_config(&self) -> ModelConfig {
self.model.clone()
}
#[tracing::instrument(
skip(self, system, messages, tools),
fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
)]
async fn complete(
&self,
system: &str,
messages: &[Message],
tools: &[Tool],
) -> Result<(Message, ProviderUsage), ProviderError> {
let model_name = &self.model.model_name;
let mut request = self
.client
.converse()
.system(bedrock::SystemContentBlock::Text(system.to_string()))
.model_id(model_name.to_string())
.set_messages(Some(
messages
.iter()
.map(to_bedrock_message)
.collect::<Result<_>>()?,
));
if !tools.is_empty() {
request = request.tool_config(to_bedrock_tool_config(tools)?);
}
let response = request.send().await;
let response = match response {
Ok(response) => response,
Err(err) => {
return Err(match err.into_service_error() {
ConverseError::AccessDeniedException(err) => {
ProviderError::Authentication(format!("Failed to call Bedrock: {:?}", err))
}
ConverseError::ThrottlingException(err) => ProviderError::RateLimitExceeded(
format!("Failed to call Bedrock: {:?}", err),
),
ConverseError::ValidationException(err)
if err
.message()
.unwrap_or_default()
.contains("Input is too long for requested model.") =>
{
ProviderError::ContextLengthExceeded(format!(
"Failed to call Bedrock: {:?}",
err
))
}
ConverseError::ModelErrorException(err) => {
ProviderError::ExecutionError(format!("Failed to call Bedrock: {:?}", err))
}
err => {
ProviderError::ServerError(format!("Failed to call Bedrock: {:?}", err,))
}
});
}
};
let message = match response.output {
Some(bedrock::ConverseOutput::Message(message)) => message,
_ => {
return Err(ProviderError::RequestFailed(
"No output from Bedrock".to_string(),
))
}
};
let usage = response
.usage
.as_ref()
.map(from_bedrock_usage)
.unwrap_or_default();
let message = from_bedrock_message(&message)?;
// Add debug trace with input context
let debug_payload = serde_json::json!({
"system": system,
"messages": messages,
"tools": tools
});
emit_debug_trace(
&self.model,
&debug_payload,
&serde_json::to_value(&message).unwrap_or_default(),
&usage,
);
let provider_usage = ProviderUsage::new(model_name.to_string(), usage);
Ok((message, provider_usage))
}
}

View File

@@ -2,6 +2,7 @@ use super::{
anthropic::AnthropicProvider,
azure::AzureProvider,
base::{Provider, ProviderMetadata},
bedrock::BedrockProvider,
databricks::DatabricksProvider,
google::GoogleProvider,
groq::GroqProvider,
@@ -16,6 +17,7 @@ pub fn providers() -> Vec<ProviderMetadata> {
vec![
AnthropicProvider::metadata(),
AzureProvider::metadata(),
BedrockProvider::metadata(),
DatabricksProvider::metadata(),
GoogleProvider::metadata(),
GroqProvider::metadata(),
@@ -30,6 +32,7 @@ pub fn create(name: &str, model: ModelConfig) -> Result<Box<dyn Provider + Send
"openai" => Ok(Box::new(OpenAiProvider::from_env(model)?)),
"anthropic" => Ok(Box::new(AnthropicProvider::from_env(model)?)),
"azure_openai" => Ok(Box::new(AzureProvider::from_env(model)?)),
"bedrock" => Ok(Box::new(BedrockProvider::from_env(model)?)),
"databricks" => Ok(Box::new(DatabricksProvider::from_env(model)?)),
"groq" => Ok(Box::new(GroqProvider::from_env(model)?)),
"ollama" => Ok(Box::new(OllamaProvider::from_env(model)?)),

View File

@@ -0,0 +1,270 @@
use std::collections::HashMap;
use std::path::Path;
use anyhow::{anyhow, bail, Result};
use aws_sdk_bedrockruntime::types as bedrock;
use aws_smithy_types::{Document, Number};
use chrono::Utc;
use mcp_core::{Content, ResourceContents, Role, Tool, ToolCall, ToolError, ToolResult};
use serde_json::Value;
use super::super::base::Usage;
use crate::message::{Message, MessageContent};
pub fn to_bedrock_message(message: &Message) -> Result<bedrock::Message> {
bedrock::Message::builder()
.role(to_bedrock_role(&message.role))
.set_content(Some(
message
.content
.iter()
.map(to_bedrock_message_content)
.collect::<Result<_>>()?,
))
.build()
.map_err(|err| anyhow!("Failed to construct Bedrock message: {}", err))
}
pub fn to_bedrock_message_content(content: &MessageContent) -> Result<bedrock::ContentBlock> {
Ok(match content {
MessageContent::Text(text) => bedrock::ContentBlock::Text(text.text.to_string()),
MessageContent::Image(_) => {
bail!("Image content is not supported by Bedrock provider yet")
}
MessageContent::ToolRequest(tool_req) => {
let tool_use_id = tool_req.id.to_string();
let tool_use = if let Ok(call) = tool_req.tool_call.as_ref() {
bedrock::ToolUseBlock::builder()
.tool_use_id(tool_use_id)
.name(call.name.to_string())
.input(to_bedrock_json(&call.arguments))
.build()
} else {
bedrock::ToolUseBlock::builder()
.tool_use_id(tool_use_id)
.build()
}?;
bedrock::ContentBlock::ToolUse(tool_use)
}
MessageContent::ToolResponse(tool_res) => {
let content = match &tool_res.tool_result {
Ok(content) => Some(
content
.iter()
.map(|c| to_bedrock_tool_result_content_block(&tool_res.id, c))
.collect::<Result<_>>()?,
),
Err(_) => None,
};
bedrock::ContentBlock::ToolResult(
bedrock::ToolResultBlock::builder()
.tool_use_id(tool_res.id.to_string())
.status(if content.is_some() {
bedrock::ToolResultStatus::Success
} else {
bedrock::ToolResultStatus::Error
})
.set_content(content)
.build()?,
)
}
})
}
pub fn to_bedrock_tool_result_content_block(
tool_use_id: &str,
content: &Content,
) -> Result<bedrock::ToolResultContentBlock> {
Ok(match content {
Content::Text(text) => bedrock::ToolResultContentBlock::Text(text.text.to_string()),
Content::Image(_) => bail!("Image content is not supported by Bedrock provider yet"),
Content::Resource(resource) => bedrock::ToolResultContentBlock::Document(
to_bedrock_document(tool_use_id, &resource.resource)?,
),
})
}
pub fn to_bedrock_role(role: &Role) -> bedrock::ConversationRole {
match role {
Role::User => bedrock::ConversationRole::User,
Role::Assistant => bedrock::ConversationRole::Assistant,
}
}
pub fn to_bedrock_tool_config(tools: &[Tool]) -> Result<bedrock::ToolConfiguration> {
Ok(bedrock::ToolConfiguration::builder()
.set_tools(Some(
tools.iter().map(to_bedrock_tool).collect::<Result<_>>()?,
))
.build()?)
}
pub fn to_bedrock_tool(tool: &Tool) -> Result<bedrock::Tool> {
Ok(bedrock::Tool::ToolSpec(
bedrock::ToolSpecification::builder()
.name(tool.name.to_string())
.description(tool.description.to_string())
.input_schema(bedrock::ToolInputSchema::Json(to_bedrock_json(
&tool.input_schema,
)))
.build()?,
))
}
pub fn to_bedrock_json(value: &Value) -> Document {
match value {
Value::Null => Document::Null,
Value::Bool(bool) => Document::Bool(*bool),
Value::Number(num) => {
if let Some(n) = num.as_u64() {
Document::Number(Number::PosInt(n))
} else if let Some(n) = num.as_i64() {
Document::Number(Number::NegInt(n))
} else if let Some(n) = num.as_f64() {
Document::Number(Number::Float(n))
} else {
unreachable!()
}
}
Value::String(str) => Document::String(str.to_string()),
Value::Array(arr) => Document::Array(arr.iter().map(to_bedrock_json).collect()),
Value::Object(obj) => Document::Object(HashMap::from_iter(
obj.into_iter()
.map(|(key, val)| (key.to_string(), to_bedrock_json(val))),
)),
}
}
fn to_bedrock_document(
tool_use_id: &str,
content: &ResourceContents,
) -> Result<bedrock::DocumentBlock> {
let (uri, text) = match content {
ResourceContents::TextResourceContents { uri, text, .. } => (uri, text),
ResourceContents::BlobResourceContents { .. } => {
bail!("Blob resource content is not supported by Bedrock provider yet")
}
};
let filename = Path::new(uri)
.file_name()
.and_then(|n| n.to_str())
.unwrap_or(uri);
let (name, format) = match filename.split_once('.') {
Some((name, "txt")) => (name, bedrock::DocumentFormat::Txt),
Some((name, "csv")) => (name, bedrock::DocumentFormat::Csv),
Some((name, "md")) => (name, bedrock::DocumentFormat::Md),
Some((name, "html")) => (name, bedrock::DocumentFormat::Html),
Some((name, _)) => (name, bedrock::DocumentFormat::Txt),
_ => (filename, bedrock::DocumentFormat::Txt),
};
// Since we can't use the full path (due to character limit and also Bedrock does not accept `/` etc.),
// and Bedrock wants document names to be unique, we're adding `tool_use_id` as a prefix to make
// document names unique.
let name = format!("{tool_use_id}-{name}");
bedrock::DocumentBlock::builder()
.format(format)
.name(name)
.source(bedrock::DocumentSource::Bytes(text.as_bytes().into()))
.build()
.map_err(|err| anyhow!("Failed to construct Bedrock document: {}", err))
}
pub fn from_bedrock_message(message: &bedrock::Message) -> Result<Message> {
let role = from_bedrock_role(message.role())?;
let content = message
.content()
.iter()
.map(from_bedrock_content_block)
.collect::<Result<Vec<_>>>()?;
let created = Utc::now().timestamp();
Ok(Message {
role,
content,
created,
})
}
pub fn from_bedrock_content_block(block: &bedrock::ContentBlock) -> Result<MessageContent> {
Ok(match block {
bedrock::ContentBlock::Text(text) => MessageContent::text(text),
bedrock::ContentBlock::ToolUse(tool_use) => MessageContent::tool_request(
tool_use.tool_use_id.to_string(),
Ok(ToolCall::new(
tool_use.name.to_string(),
from_bedrock_json(&tool_use.input)?,
)),
),
bedrock::ContentBlock::ToolResult(tool_res) => MessageContent::tool_response(
tool_res.tool_use_id.to_string(),
if tool_res.content.is_empty() {
Err(ToolError::ExecutionError(
"Empty content for tool use from Bedrock".to_string(),
))
} else {
tool_res
.content
.iter()
.map(from_bedrock_tool_result_content_block)
.collect::<ToolResult<Vec<_>>>()
},
),
_ => bail!("Unsupported content block type from Bedrock"),
})
}
pub fn from_bedrock_tool_result_content_block(
content: &bedrock::ToolResultContentBlock,
) -> ToolResult<Content> {
Ok(match content {
bedrock::ToolResultContentBlock::Text(text) => Content::text(text.to_string()),
_ => {
return Err(ToolError::ExecutionError(
"Unsupported tool result from Bedrock".to_string(),
))
}
})
}
pub fn from_bedrock_role(role: &bedrock::ConversationRole) -> Result<Role> {
Ok(match role {
bedrock::ConversationRole::User => Role::User,
bedrock::ConversationRole::Assistant => Role::Assistant,
_ => bail!("Unknown role from Bedrock"),
})
}
pub fn from_bedrock_usage(usage: &bedrock::TokenUsage) -> Usage {
Usage {
input_tokens: Some(usage.input_tokens),
output_tokens: Some(usage.output_tokens),
total_tokens: Some(usage.total_tokens),
}
}
pub fn from_bedrock_json(document: &Document) -> Result<Value> {
Ok(match document {
Document::Null => Value::Null,
Document::Bool(bool) => Value::Bool(*bool),
Document::Number(num) => match num {
Number::PosInt(i) => Value::Number((*i).into()),
Number::NegInt(i) => Value::Number((*i).into()),
Number::Float(f) => Value::Number(
serde_json::Number::from_f64(*f).ok_or(anyhow!("Expected a valid float"))?,
),
},
Document::String(str) => Value::String(str.clone()),
Document::Array(arr) => {
Value::Array(arr.iter().map(from_bedrock_json).collect::<Result<_>>()?)
}
Document::Object(obj) => Value::Object(
obj.iter()
.map(|(key, val)| Ok((key.clone(), from_bedrock_json(val)?)))
.collect::<Result<_>>()?,
),
})
}

View File

@@ -1,3 +1,4 @@
pub mod anthropic;
pub mod bedrock;
pub mod google;
pub mod openai;

View File

@@ -1,6 +1,7 @@
pub mod anthropic;
pub mod azure;
pub mod base;
pub mod bedrock;
pub mod databricks;
pub mod errors;
mod factory;

View File

@@ -3,7 +3,9 @@ use dotenv::dotenv;
use goose::message::{Message, MessageContent};
use goose::providers::base::Provider;
use goose::providers::errors::ProviderError;
use goose::providers::{anthropic, azure, databricks, google, groq, ollama, openai, openrouter};
use goose::providers::{
anthropic, azure, bedrock, databricks, google, groq, ollama, openai, openrouter,
};
use mcp_core::content::Content;
use mcp_core::tool::Tool;
use std::collections::HashMap;
@@ -374,6 +376,34 @@ async fn test_azure_provider() -> Result<()> {
.await
}
#[tokio::test]
async fn test_bedrock_provider_long_term_credentials() -> Result<()> {
test_provider(
"Bedrock",
&["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"],
None,
bedrock::BedrockProvider::default,
)
.await
}
#[tokio::test]
async fn test_bedrock_provider_aws_profile_credentials() -> Result<()> {
let env_mods = HashMap::from_iter([
// Ensure to unset long-term credentials to use AWS Profile provider
("AWS_ACCESS_KEY_ID", None),
("AWS_SECRET_ACCESS_KEY", None),
]);
test_provider(
"Bedrock AWS Profile Credentials",
&["AWS_PROFILE"],
Some(env_mods),
bedrock::BedrockProvider::default,
)
.await
}
#[tokio::test]
async fn test_databricks_provider() -> Result<()> {
test_provider(

View File

@@ -8,7 +8,7 @@ use goose::model::ModelConfig;
use goose::providers::base::Provider;
use goose::providers::{anthropic::AnthropicProvider, databricks::DatabricksProvider};
use goose::providers::{
azure::AzureProvider, ollama::OllamaProvider, openai::OpenAiProvider,
azure::AzureProvider, bedrock::BedrockProvider, ollama::OllamaProvider, openai::OpenAiProvider,
openrouter::OpenRouterProvider,
};
use goose::providers::{google::GoogleProvider, groq::GroqProvider};
@@ -18,6 +18,7 @@ enum ProviderType {
Azure,
OpenAi,
Anthropic,
Bedrock,
Databricks,
Google,
Groq,
@@ -35,6 +36,7 @@ impl ProviderType {
],
ProviderType::OpenAi => &["OPENAI_API_KEY"],
ProviderType::Anthropic => &["ANTHROPIC_API_KEY"],
ProviderType::Bedrock => &["AWS_PROFILE", "AWS_REGION"],
ProviderType::Databricks => &["DATABRICKS_HOST"],
ProviderType::Google => &["GOOGLE_API_KEY"],
ProviderType::Groq => &["GROQ_API_KEY"],
@@ -66,6 +68,7 @@ impl ProviderType {
ProviderType::Azure => Box::new(AzureProvider::from_env(model_config)?),
ProviderType::OpenAi => Box::new(OpenAiProvider::from_env(model_config)?),
ProviderType::Anthropic => Box::new(AnthropicProvider::from_env(model_config)?),
ProviderType::Bedrock => Box::new(BedrockProvider::from_env(model_config)?),
ProviderType::Databricks => Box::new(DatabricksProvider::from_env(model_config)?),
ProviderType::Google => Box::new(GoogleProvider::from_env(model_config)?),
ProviderType::Groq => Box::new(GroqProvider::from_env(model_config)?),
@@ -200,6 +203,16 @@ mod tests {
.await
}
#[tokio::test]
async fn test_truncate_agent_with_bedrock() -> Result<()> {
run_test_with_config(TestConfig {
provider_type: ProviderType::Bedrock,
model: "anthropic.claude-3-5-sonnet-20241022-v2:0",
context_window: 200_000,
})
.await
}
#[tokio::test]
async fn test_truncate_agent_with_databricks() -> Result<()> {
run_test_with_config(TestConfig {