mirror of
https://github.com/aljazceru/goose.git
synced 2025-12-19 07:04:21 +01:00
feat: add bedrock provider (#1069)
Co-authored-by: Burak Varlı <unexge@gmail.com>
This commit is contained in:
@@ -61,6 +61,11 @@ once_cell = "1.20.2"
|
|||||||
dirs = "6.0.0"
|
dirs = "6.0.0"
|
||||||
rand = "0.8.5"
|
rand = "0.8.5"
|
||||||
|
|
||||||
|
# For Bedrock provider
|
||||||
|
aws-config = { version = "1.1.7", features = ["behavior-version-latest"] }
|
||||||
|
aws-smithy-types = "1.2.12"
|
||||||
|
aws-sdk-bedrockruntime = "1.72.0"
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
criterion = "0.5"
|
criterion = "0.5"
|
||||||
tempfile = "3.15.0"
|
tempfile = "3.15.0"
|
||||||
|
|||||||
162
crates/goose/src/providers/bedrock.rs
Normal file
162
crates/goose/src/providers/bedrock.rs
Normal file
@@ -0,0 +1,162 @@
|
|||||||
|
use anyhow::Result;
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use aws_sdk_bedrockruntime::operation::converse::ConverseError;
|
||||||
|
use aws_sdk_bedrockruntime::{types as bedrock, Client};
|
||||||
|
use mcp_core::Tool;
|
||||||
|
|
||||||
|
use super::base::{Provider, ProviderMetadata, ProviderUsage};
|
||||||
|
use super::errors::ProviderError;
|
||||||
|
use crate::message::Message;
|
||||||
|
use crate::model::ModelConfig;
|
||||||
|
use crate::providers::utils::emit_debug_trace;
|
||||||
|
|
||||||
|
// Import the migrated helper functions from providers/formats/bedrock.rs
|
||||||
|
use super::formats::bedrock::{
|
||||||
|
from_bedrock_message, from_bedrock_usage, to_bedrock_message, to_bedrock_tool_config,
|
||||||
|
};
|
||||||
|
|
||||||
|
pub const BEDROCK_DOC_LINK: &str =
|
||||||
|
"https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html";
|
||||||
|
|
||||||
|
pub const BEDROCK_DEFAULT_MODEL: &str = "anthropic.claude-3-5-sonnet-20240620-v1:0";
|
||||||
|
pub const BEDROCK_KNOWN_MODELS: &[&str] = &[
|
||||||
|
"anthropic.claude-3-5-sonnet-20240620-v1:0",
|
||||||
|
"anthropic.claude-3-5-sonnet-20241022-v2:0",
|
||||||
|
];
|
||||||
|
|
||||||
|
#[derive(Debug, serde::Serialize)]
|
||||||
|
pub struct BedrockProvider {
|
||||||
|
#[serde(skip)]
|
||||||
|
client: Client,
|
||||||
|
model: ModelConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl BedrockProvider {
|
||||||
|
pub fn from_env(model: ModelConfig) -> Result<Self> {
|
||||||
|
let sdk_config = futures::executor::block_on(aws_config::load_from_env());
|
||||||
|
let client = Client::new(&sdk_config);
|
||||||
|
|
||||||
|
Ok(Self { client, model })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for BedrockProvider {
|
||||||
|
fn default() -> Self {
|
||||||
|
let model = ModelConfig::new(BedrockProvider::metadata().default_model);
|
||||||
|
BedrockProvider::from_env(model).expect("Failed to initialize Bedrock provider")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Provider for BedrockProvider {
|
||||||
|
fn metadata() -> ProviderMetadata {
|
||||||
|
ProviderMetadata::new(
|
||||||
|
"bedrock",
|
||||||
|
"Amazon Bedrock",
|
||||||
|
"Run models through Amazon Bedrock. You may have to set AWS_ACCESS_KEY_ID, AWS_ACCESS_KEY, and AWS_REGION as env vars before configuring.",
|
||||||
|
BEDROCK_DEFAULT_MODEL,
|
||||||
|
BEDROCK_KNOWN_MODELS.iter().map(|s| s.to_string()).collect(),
|
||||||
|
BEDROCK_DOC_LINK,
|
||||||
|
vec![],
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_model_config(&self) -> ModelConfig {
|
||||||
|
self.model.clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tracing::instrument(
|
||||||
|
skip(self, system, messages, tools),
|
||||||
|
fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
|
||||||
|
)]
|
||||||
|
async fn complete(
|
||||||
|
&self,
|
||||||
|
system: &str,
|
||||||
|
messages: &[Message],
|
||||||
|
tools: &[Tool],
|
||||||
|
) -> Result<(Message, ProviderUsage), ProviderError> {
|
||||||
|
let model_name = &self.model.model_name;
|
||||||
|
|
||||||
|
let mut request = self
|
||||||
|
.client
|
||||||
|
.converse()
|
||||||
|
.system(bedrock::SystemContentBlock::Text(system.to_string()))
|
||||||
|
.model_id(model_name.to_string())
|
||||||
|
.set_messages(Some(
|
||||||
|
messages
|
||||||
|
.iter()
|
||||||
|
.map(to_bedrock_message)
|
||||||
|
.collect::<Result<_>>()?,
|
||||||
|
));
|
||||||
|
|
||||||
|
if !tools.is_empty() {
|
||||||
|
request = request.tool_config(to_bedrock_tool_config(tools)?);
|
||||||
|
}
|
||||||
|
|
||||||
|
let response = request.send().await;
|
||||||
|
|
||||||
|
let response = match response {
|
||||||
|
Ok(response) => response,
|
||||||
|
Err(err) => {
|
||||||
|
return Err(match err.into_service_error() {
|
||||||
|
ConverseError::AccessDeniedException(err) => {
|
||||||
|
ProviderError::Authentication(format!("Failed to call Bedrock: {:?}", err))
|
||||||
|
}
|
||||||
|
ConverseError::ThrottlingException(err) => ProviderError::RateLimitExceeded(
|
||||||
|
format!("Failed to call Bedrock: {:?}", err),
|
||||||
|
),
|
||||||
|
ConverseError::ValidationException(err)
|
||||||
|
if err
|
||||||
|
.message()
|
||||||
|
.unwrap_or_default()
|
||||||
|
.contains("Input is too long for requested model.") =>
|
||||||
|
{
|
||||||
|
ProviderError::ContextLengthExceeded(format!(
|
||||||
|
"Failed to call Bedrock: {:?}",
|
||||||
|
err
|
||||||
|
))
|
||||||
|
}
|
||||||
|
ConverseError::ModelErrorException(err) => {
|
||||||
|
ProviderError::ExecutionError(format!("Failed to call Bedrock: {:?}", err))
|
||||||
|
}
|
||||||
|
err => {
|
||||||
|
ProviderError::ServerError(format!("Failed to call Bedrock: {:?}", err,))
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let message = match response.output {
|
||||||
|
Some(bedrock::ConverseOutput::Message(message)) => message,
|
||||||
|
_ => {
|
||||||
|
return Err(ProviderError::RequestFailed(
|
||||||
|
"No output from Bedrock".to_string(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let usage = response
|
||||||
|
.usage
|
||||||
|
.as_ref()
|
||||||
|
.map(from_bedrock_usage)
|
||||||
|
.unwrap_or_default();
|
||||||
|
|
||||||
|
let message = from_bedrock_message(&message)?;
|
||||||
|
|
||||||
|
// Add debug trace with input context
|
||||||
|
let debug_payload = serde_json::json!({
|
||||||
|
"system": system,
|
||||||
|
"messages": messages,
|
||||||
|
"tools": tools
|
||||||
|
});
|
||||||
|
emit_debug_trace(
|
||||||
|
&self.model,
|
||||||
|
&debug_payload,
|
||||||
|
&serde_json::to_value(&message).unwrap_or_default(),
|
||||||
|
&usage,
|
||||||
|
);
|
||||||
|
|
||||||
|
let provider_usage = ProviderUsage::new(model_name.to_string(), usage);
|
||||||
|
Ok((message, provider_usage))
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -2,6 +2,7 @@ use super::{
|
|||||||
anthropic::AnthropicProvider,
|
anthropic::AnthropicProvider,
|
||||||
azure::AzureProvider,
|
azure::AzureProvider,
|
||||||
base::{Provider, ProviderMetadata},
|
base::{Provider, ProviderMetadata},
|
||||||
|
bedrock::BedrockProvider,
|
||||||
databricks::DatabricksProvider,
|
databricks::DatabricksProvider,
|
||||||
google::GoogleProvider,
|
google::GoogleProvider,
|
||||||
groq::GroqProvider,
|
groq::GroqProvider,
|
||||||
@@ -16,6 +17,7 @@ pub fn providers() -> Vec<ProviderMetadata> {
|
|||||||
vec![
|
vec![
|
||||||
AnthropicProvider::metadata(),
|
AnthropicProvider::metadata(),
|
||||||
AzureProvider::metadata(),
|
AzureProvider::metadata(),
|
||||||
|
BedrockProvider::metadata(),
|
||||||
DatabricksProvider::metadata(),
|
DatabricksProvider::metadata(),
|
||||||
GoogleProvider::metadata(),
|
GoogleProvider::metadata(),
|
||||||
GroqProvider::metadata(),
|
GroqProvider::metadata(),
|
||||||
@@ -30,6 +32,7 @@ pub fn create(name: &str, model: ModelConfig) -> Result<Box<dyn Provider + Send
|
|||||||
"openai" => Ok(Box::new(OpenAiProvider::from_env(model)?)),
|
"openai" => Ok(Box::new(OpenAiProvider::from_env(model)?)),
|
||||||
"anthropic" => Ok(Box::new(AnthropicProvider::from_env(model)?)),
|
"anthropic" => Ok(Box::new(AnthropicProvider::from_env(model)?)),
|
||||||
"azure_openai" => Ok(Box::new(AzureProvider::from_env(model)?)),
|
"azure_openai" => Ok(Box::new(AzureProvider::from_env(model)?)),
|
||||||
|
"bedrock" => Ok(Box::new(BedrockProvider::from_env(model)?)),
|
||||||
"databricks" => Ok(Box::new(DatabricksProvider::from_env(model)?)),
|
"databricks" => Ok(Box::new(DatabricksProvider::from_env(model)?)),
|
||||||
"groq" => Ok(Box::new(GroqProvider::from_env(model)?)),
|
"groq" => Ok(Box::new(GroqProvider::from_env(model)?)),
|
||||||
"ollama" => Ok(Box::new(OllamaProvider::from_env(model)?)),
|
"ollama" => Ok(Box::new(OllamaProvider::from_env(model)?)),
|
||||||
|
|||||||
270
crates/goose/src/providers/formats/bedrock.rs
Normal file
270
crates/goose/src/providers/formats/bedrock.rs
Normal file
@@ -0,0 +1,270 @@
|
|||||||
|
use std::collections::HashMap;
|
||||||
|
use std::path::Path;
|
||||||
|
|
||||||
|
use anyhow::{anyhow, bail, Result};
|
||||||
|
use aws_sdk_bedrockruntime::types as bedrock;
|
||||||
|
use aws_smithy_types::{Document, Number};
|
||||||
|
use chrono::Utc;
|
||||||
|
use mcp_core::{Content, ResourceContents, Role, Tool, ToolCall, ToolError, ToolResult};
|
||||||
|
use serde_json::Value;
|
||||||
|
|
||||||
|
use super::super::base::Usage;
|
||||||
|
use crate::message::{Message, MessageContent};
|
||||||
|
|
||||||
|
pub fn to_bedrock_message(message: &Message) -> Result<bedrock::Message> {
|
||||||
|
bedrock::Message::builder()
|
||||||
|
.role(to_bedrock_role(&message.role))
|
||||||
|
.set_content(Some(
|
||||||
|
message
|
||||||
|
.content
|
||||||
|
.iter()
|
||||||
|
.map(to_bedrock_message_content)
|
||||||
|
.collect::<Result<_>>()?,
|
||||||
|
))
|
||||||
|
.build()
|
||||||
|
.map_err(|err| anyhow!("Failed to construct Bedrock message: {}", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn to_bedrock_message_content(content: &MessageContent) -> Result<bedrock::ContentBlock> {
|
||||||
|
Ok(match content {
|
||||||
|
MessageContent::Text(text) => bedrock::ContentBlock::Text(text.text.to_string()),
|
||||||
|
MessageContent::Image(_) => {
|
||||||
|
bail!("Image content is not supported by Bedrock provider yet")
|
||||||
|
}
|
||||||
|
MessageContent::ToolRequest(tool_req) => {
|
||||||
|
let tool_use_id = tool_req.id.to_string();
|
||||||
|
let tool_use = if let Ok(call) = tool_req.tool_call.as_ref() {
|
||||||
|
bedrock::ToolUseBlock::builder()
|
||||||
|
.tool_use_id(tool_use_id)
|
||||||
|
.name(call.name.to_string())
|
||||||
|
.input(to_bedrock_json(&call.arguments))
|
||||||
|
.build()
|
||||||
|
} else {
|
||||||
|
bedrock::ToolUseBlock::builder()
|
||||||
|
.tool_use_id(tool_use_id)
|
||||||
|
.build()
|
||||||
|
}?;
|
||||||
|
bedrock::ContentBlock::ToolUse(tool_use)
|
||||||
|
}
|
||||||
|
MessageContent::ToolResponse(tool_res) => {
|
||||||
|
let content = match &tool_res.tool_result {
|
||||||
|
Ok(content) => Some(
|
||||||
|
content
|
||||||
|
.iter()
|
||||||
|
.map(|c| to_bedrock_tool_result_content_block(&tool_res.id, c))
|
||||||
|
.collect::<Result<_>>()?,
|
||||||
|
),
|
||||||
|
Err(_) => None,
|
||||||
|
};
|
||||||
|
bedrock::ContentBlock::ToolResult(
|
||||||
|
bedrock::ToolResultBlock::builder()
|
||||||
|
.tool_use_id(tool_res.id.to_string())
|
||||||
|
.status(if content.is_some() {
|
||||||
|
bedrock::ToolResultStatus::Success
|
||||||
|
} else {
|
||||||
|
bedrock::ToolResultStatus::Error
|
||||||
|
})
|
||||||
|
.set_content(content)
|
||||||
|
.build()?,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn to_bedrock_tool_result_content_block(
|
||||||
|
tool_use_id: &str,
|
||||||
|
content: &Content,
|
||||||
|
) -> Result<bedrock::ToolResultContentBlock> {
|
||||||
|
Ok(match content {
|
||||||
|
Content::Text(text) => bedrock::ToolResultContentBlock::Text(text.text.to_string()),
|
||||||
|
Content::Image(_) => bail!("Image content is not supported by Bedrock provider yet"),
|
||||||
|
Content::Resource(resource) => bedrock::ToolResultContentBlock::Document(
|
||||||
|
to_bedrock_document(tool_use_id, &resource.resource)?,
|
||||||
|
),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn to_bedrock_role(role: &Role) -> bedrock::ConversationRole {
|
||||||
|
match role {
|
||||||
|
Role::User => bedrock::ConversationRole::User,
|
||||||
|
Role::Assistant => bedrock::ConversationRole::Assistant,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn to_bedrock_tool_config(tools: &[Tool]) -> Result<bedrock::ToolConfiguration> {
|
||||||
|
Ok(bedrock::ToolConfiguration::builder()
|
||||||
|
.set_tools(Some(
|
||||||
|
tools.iter().map(to_bedrock_tool).collect::<Result<_>>()?,
|
||||||
|
))
|
||||||
|
.build()?)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn to_bedrock_tool(tool: &Tool) -> Result<bedrock::Tool> {
|
||||||
|
Ok(bedrock::Tool::ToolSpec(
|
||||||
|
bedrock::ToolSpecification::builder()
|
||||||
|
.name(tool.name.to_string())
|
||||||
|
.description(tool.description.to_string())
|
||||||
|
.input_schema(bedrock::ToolInputSchema::Json(to_bedrock_json(
|
||||||
|
&tool.input_schema,
|
||||||
|
)))
|
||||||
|
.build()?,
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn to_bedrock_json(value: &Value) -> Document {
|
||||||
|
match value {
|
||||||
|
Value::Null => Document::Null,
|
||||||
|
Value::Bool(bool) => Document::Bool(*bool),
|
||||||
|
Value::Number(num) => {
|
||||||
|
if let Some(n) = num.as_u64() {
|
||||||
|
Document::Number(Number::PosInt(n))
|
||||||
|
} else if let Some(n) = num.as_i64() {
|
||||||
|
Document::Number(Number::NegInt(n))
|
||||||
|
} else if let Some(n) = num.as_f64() {
|
||||||
|
Document::Number(Number::Float(n))
|
||||||
|
} else {
|
||||||
|
unreachable!()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Value::String(str) => Document::String(str.to_string()),
|
||||||
|
Value::Array(arr) => Document::Array(arr.iter().map(to_bedrock_json).collect()),
|
||||||
|
Value::Object(obj) => Document::Object(HashMap::from_iter(
|
||||||
|
obj.into_iter()
|
||||||
|
.map(|(key, val)| (key.to_string(), to_bedrock_json(val))),
|
||||||
|
)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn to_bedrock_document(
|
||||||
|
tool_use_id: &str,
|
||||||
|
content: &ResourceContents,
|
||||||
|
) -> Result<bedrock::DocumentBlock> {
|
||||||
|
let (uri, text) = match content {
|
||||||
|
ResourceContents::TextResourceContents { uri, text, .. } => (uri, text),
|
||||||
|
ResourceContents::BlobResourceContents { .. } => {
|
||||||
|
bail!("Blob resource content is not supported by Bedrock provider yet")
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let filename = Path::new(uri)
|
||||||
|
.file_name()
|
||||||
|
.and_then(|n| n.to_str())
|
||||||
|
.unwrap_or(uri);
|
||||||
|
|
||||||
|
let (name, format) = match filename.split_once('.') {
|
||||||
|
Some((name, "txt")) => (name, bedrock::DocumentFormat::Txt),
|
||||||
|
Some((name, "csv")) => (name, bedrock::DocumentFormat::Csv),
|
||||||
|
Some((name, "md")) => (name, bedrock::DocumentFormat::Md),
|
||||||
|
Some((name, "html")) => (name, bedrock::DocumentFormat::Html),
|
||||||
|
Some((name, _)) => (name, bedrock::DocumentFormat::Txt),
|
||||||
|
_ => (filename, bedrock::DocumentFormat::Txt),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Since we can't use the full path (due to character limit and also Bedrock does not accept `/` etc.),
|
||||||
|
// and Bedrock wants document names to be unique, we're adding `tool_use_id` as a prefix to make
|
||||||
|
// document names unique.
|
||||||
|
let name = format!("{tool_use_id}-{name}");
|
||||||
|
|
||||||
|
bedrock::DocumentBlock::builder()
|
||||||
|
.format(format)
|
||||||
|
.name(name)
|
||||||
|
.source(bedrock::DocumentSource::Bytes(text.as_bytes().into()))
|
||||||
|
.build()
|
||||||
|
.map_err(|err| anyhow!("Failed to construct Bedrock document: {}", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn from_bedrock_message(message: &bedrock::Message) -> Result<Message> {
|
||||||
|
let role = from_bedrock_role(message.role())?;
|
||||||
|
let content = message
|
||||||
|
.content()
|
||||||
|
.iter()
|
||||||
|
.map(from_bedrock_content_block)
|
||||||
|
.collect::<Result<Vec<_>>>()?;
|
||||||
|
let created = Utc::now().timestamp();
|
||||||
|
|
||||||
|
Ok(Message {
|
||||||
|
role,
|
||||||
|
content,
|
||||||
|
created,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn from_bedrock_content_block(block: &bedrock::ContentBlock) -> Result<MessageContent> {
|
||||||
|
Ok(match block {
|
||||||
|
bedrock::ContentBlock::Text(text) => MessageContent::text(text),
|
||||||
|
bedrock::ContentBlock::ToolUse(tool_use) => MessageContent::tool_request(
|
||||||
|
tool_use.tool_use_id.to_string(),
|
||||||
|
Ok(ToolCall::new(
|
||||||
|
tool_use.name.to_string(),
|
||||||
|
from_bedrock_json(&tool_use.input)?,
|
||||||
|
)),
|
||||||
|
),
|
||||||
|
bedrock::ContentBlock::ToolResult(tool_res) => MessageContent::tool_response(
|
||||||
|
tool_res.tool_use_id.to_string(),
|
||||||
|
if tool_res.content.is_empty() {
|
||||||
|
Err(ToolError::ExecutionError(
|
||||||
|
"Empty content for tool use from Bedrock".to_string(),
|
||||||
|
))
|
||||||
|
} else {
|
||||||
|
tool_res
|
||||||
|
.content
|
||||||
|
.iter()
|
||||||
|
.map(from_bedrock_tool_result_content_block)
|
||||||
|
.collect::<ToolResult<Vec<_>>>()
|
||||||
|
},
|
||||||
|
),
|
||||||
|
_ => bail!("Unsupported content block type from Bedrock"),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn from_bedrock_tool_result_content_block(
|
||||||
|
content: &bedrock::ToolResultContentBlock,
|
||||||
|
) -> ToolResult<Content> {
|
||||||
|
Ok(match content {
|
||||||
|
bedrock::ToolResultContentBlock::Text(text) => Content::text(text.to_string()),
|
||||||
|
_ => {
|
||||||
|
return Err(ToolError::ExecutionError(
|
||||||
|
"Unsupported tool result from Bedrock".to_string(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn from_bedrock_role(role: &bedrock::ConversationRole) -> Result<Role> {
|
||||||
|
Ok(match role {
|
||||||
|
bedrock::ConversationRole::User => Role::User,
|
||||||
|
bedrock::ConversationRole::Assistant => Role::Assistant,
|
||||||
|
_ => bail!("Unknown role from Bedrock"),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn from_bedrock_usage(usage: &bedrock::TokenUsage) -> Usage {
|
||||||
|
Usage {
|
||||||
|
input_tokens: Some(usage.input_tokens),
|
||||||
|
output_tokens: Some(usage.output_tokens),
|
||||||
|
total_tokens: Some(usage.total_tokens),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn from_bedrock_json(document: &Document) -> Result<Value> {
|
||||||
|
Ok(match document {
|
||||||
|
Document::Null => Value::Null,
|
||||||
|
Document::Bool(bool) => Value::Bool(*bool),
|
||||||
|
Document::Number(num) => match num {
|
||||||
|
Number::PosInt(i) => Value::Number((*i).into()),
|
||||||
|
Number::NegInt(i) => Value::Number((*i).into()),
|
||||||
|
Number::Float(f) => Value::Number(
|
||||||
|
serde_json::Number::from_f64(*f).ok_or(anyhow!("Expected a valid float"))?,
|
||||||
|
),
|
||||||
|
},
|
||||||
|
Document::String(str) => Value::String(str.clone()),
|
||||||
|
Document::Array(arr) => {
|
||||||
|
Value::Array(arr.iter().map(from_bedrock_json).collect::<Result<_>>()?)
|
||||||
|
}
|
||||||
|
Document::Object(obj) => Value::Object(
|
||||||
|
obj.iter()
|
||||||
|
.map(|(key, val)| Ok((key.clone(), from_bedrock_json(val)?)))
|
||||||
|
.collect::<Result<_>>()?,
|
||||||
|
),
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -1,3 +1,4 @@
|
|||||||
pub mod anthropic;
|
pub mod anthropic;
|
||||||
|
pub mod bedrock;
|
||||||
pub mod google;
|
pub mod google;
|
||||||
pub mod openai;
|
pub mod openai;
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
pub mod anthropic;
|
pub mod anthropic;
|
||||||
pub mod azure;
|
pub mod azure;
|
||||||
pub mod base;
|
pub mod base;
|
||||||
|
pub mod bedrock;
|
||||||
pub mod databricks;
|
pub mod databricks;
|
||||||
pub mod errors;
|
pub mod errors;
|
||||||
mod factory;
|
mod factory;
|
||||||
|
|||||||
@@ -3,7 +3,9 @@ use dotenv::dotenv;
|
|||||||
use goose::message::{Message, MessageContent};
|
use goose::message::{Message, MessageContent};
|
||||||
use goose::providers::base::Provider;
|
use goose::providers::base::Provider;
|
||||||
use goose::providers::errors::ProviderError;
|
use goose::providers::errors::ProviderError;
|
||||||
use goose::providers::{anthropic, azure, databricks, google, groq, ollama, openai, openrouter};
|
use goose::providers::{
|
||||||
|
anthropic, azure, bedrock, databricks, google, groq, ollama, openai, openrouter,
|
||||||
|
};
|
||||||
use mcp_core::content::Content;
|
use mcp_core::content::Content;
|
||||||
use mcp_core::tool::Tool;
|
use mcp_core::tool::Tool;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
@@ -374,6 +376,34 @@ async fn test_azure_provider() -> Result<()> {
|
|||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_bedrock_provider_long_term_credentials() -> Result<()> {
|
||||||
|
test_provider(
|
||||||
|
"Bedrock",
|
||||||
|
&["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"],
|
||||||
|
None,
|
||||||
|
bedrock::BedrockProvider::default,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_bedrock_provider_aws_profile_credentials() -> Result<()> {
|
||||||
|
let env_mods = HashMap::from_iter([
|
||||||
|
// Ensure to unset long-term credentials to use AWS Profile provider
|
||||||
|
("AWS_ACCESS_KEY_ID", None),
|
||||||
|
("AWS_SECRET_ACCESS_KEY", None),
|
||||||
|
]);
|
||||||
|
|
||||||
|
test_provider(
|
||||||
|
"Bedrock AWS Profile Credentials",
|
||||||
|
&["AWS_PROFILE"],
|
||||||
|
Some(env_mods),
|
||||||
|
bedrock::BedrockProvider::default,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_databricks_provider() -> Result<()> {
|
async fn test_databricks_provider() -> Result<()> {
|
||||||
test_provider(
|
test_provider(
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ use goose::model::ModelConfig;
|
|||||||
use goose::providers::base::Provider;
|
use goose::providers::base::Provider;
|
||||||
use goose::providers::{anthropic::AnthropicProvider, databricks::DatabricksProvider};
|
use goose::providers::{anthropic::AnthropicProvider, databricks::DatabricksProvider};
|
||||||
use goose::providers::{
|
use goose::providers::{
|
||||||
azure::AzureProvider, ollama::OllamaProvider, openai::OpenAiProvider,
|
azure::AzureProvider, bedrock::BedrockProvider, ollama::OllamaProvider, openai::OpenAiProvider,
|
||||||
openrouter::OpenRouterProvider,
|
openrouter::OpenRouterProvider,
|
||||||
};
|
};
|
||||||
use goose::providers::{google::GoogleProvider, groq::GroqProvider};
|
use goose::providers::{google::GoogleProvider, groq::GroqProvider};
|
||||||
@@ -18,6 +18,7 @@ enum ProviderType {
|
|||||||
Azure,
|
Azure,
|
||||||
OpenAi,
|
OpenAi,
|
||||||
Anthropic,
|
Anthropic,
|
||||||
|
Bedrock,
|
||||||
Databricks,
|
Databricks,
|
||||||
Google,
|
Google,
|
||||||
Groq,
|
Groq,
|
||||||
@@ -35,6 +36,7 @@ impl ProviderType {
|
|||||||
],
|
],
|
||||||
ProviderType::OpenAi => &["OPENAI_API_KEY"],
|
ProviderType::OpenAi => &["OPENAI_API_KEY"],
|
||||||
ProviderType::Anthropic => &["ANTHROPIC_API_KEY"],
|
ProviderType::Anthropic => &["ANTHROPIC_API_KEY"],
|
||||||
|
ProviderType::Bedrock => &["AWS_PROFILE", "AWS_REGION"],
|
||||||
ProviderType::Databricks => &["DATABRICKS_HOST"],
|
ProviderType::Databricks => &["DATABRICKS_HOST"],
|
||||||
ProviderType::Google => &["GOOGLE_API_KEY"],
|
ProviderType::Google => &["GOOGLE_API_KEY"],
|
||||||
ProviderType::Groq => &["GROQ_API_KEY"],
|
ProviderType::Groq => &["GROQ_API_KEY"],
|
||||||
@@ -66,6 +68,7 @@ impl ProviderType {
|
|||||||
ProviderType::Azure => Box::new(AzureProvider::from_env(model_config)?),
|
ProviderType::Azure => Box::new(AzureProvider::from_env(model_config)?),
|
||||||
ProviderType::OpenAi => Box::new(OpenAiProvider::from_env(model_config)?),
|
ProviderType::OpenAi => Box::new(OpenAiProvider::from_env(model_config)?),
|
||||||
ProviderType::Anthropic => Box::new(AnthropicProvider::from_env(model_config)?),
|
ProviderType::Anthropic => Box::new(AnthropicProvider::from_env(model_config)?),
|
||||||
|
ProviderType::Bedrock => Box::new(BedrockProvider::from_env(model_config)?),
|
||||||
ProviderType::Databricks => Box::new(DatabricksProvider::from_env(model_config)?),
|
ProviderType::Databricks => Box::new(DatabricksProvider::from_env(model_config)?),
|
||||||
ProviderType::Google => Box::new(GoogleProvider::from_env(model_config)?),
|
ProviderType::Google => Box::new(GoogleProvider::from_env(model_config)?),
|
||||||
ProviderType::Groq => Box::new(GroqProvider::from_env(model_config)?),
|
ProviderType::Groq => Box::new(GroqProvider::from_env(model_config)?),
|
||||||
@@ -200,6 +203,16 @@ mod tests {
|
|||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_truncate_agent_with_bedrock() -> Result<()> {
|
||||||
|
run_test_with_config(TestConfig {
|
||||||
|
provider_type: ProviderType::Bedrock,
|
||||||
|
model: "anthropic.claude-3-5-sonnet-20241022-v2:0",
|
||||||
|
context_window: 200_000,
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_truncate_agent_with_databricks() -> Result<()> {
|
async fn test_truncate_agent_with_databricks() -> Result<()> {
|
||||||
run_test_with_config(TestConfig {
|
run_test_with_config(TestConfig {
|
||||||
|
|||||||
Reference in New Issue
Block a user