mirror of
https://github.com/aljazceru/goose.git
synced 2025-12-18 14:44:21 +01:00
feat: add bedrock provider (#1069)
Co-authored-by: Burak Varlı <unexge@gmail.com>
This commit is contained in:
@@ -61,6 +61,11 @@ once_cell = "1.20.2"
|
||||
dirs = "6.0.0"
|
||||
rand = "0.8.5"
|
||||
|
||||
# For Bedrock provider
|
||||
aws-config = { version = "1.1.7", features = ["behavior-version-latest"] }
|
||||
aws-smithy-types = "1.2.12"
|
||||
aws-sdk-bedrockruntime = "1.72.0"
|
||||
|
||||
[dev-dependencies]
|
||||
criterion = "0.5"
|
||||
tempfile = "3.15.0"
|
||||
|
||||
162
crates/goose/src/providers/bedrock.rs
Normal file
162
crates/goose/src/providers/bedrock.rs
Normal file
@@ -0,0 +1,162 @@
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use aws_sdk_bedrockruntime::operation::converse::ConverseError;
|
||||
use aws_sdk_bedrockruntime::{types as bedrock, Client};
|
||||
use mcp_core::Tool;
|
||||
|
||||
use super::base::{Provider, ProviderMetadata, ProviderUsage};
|
||||
use super::errors::ProviderError;
|
||||
use crate::message::Message;
|
||||
use crate::model::ModelConfig;
|
||||
use crate::providers::utils::emit_debug_trace;
|
||||
|
||||
// Import the migrated helper functions from providers/formats/bedrock.rs
|
||||
use super::formats::bedrock::{
|
||||
from_bedrock_message, from_bedrock_usage, to_bedrock_message, to_bedrock_tool_config,
|
||||
};
|
||||
|
||||
pub const BEDROCK_DOC_LINK: &str =
|
||||
"https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html";
|
||||
|
||||
pub const BEDROCK_DEFAULT_MODEL: &str = "anthropic.claude-3-5-sonnet-20240620-v1:0";
|
||||
pub const BEDROCK_KNOWN_MODELS: &[&str] = &[
|
||||
"anthropic.claude-3-5-sonnet-20240620-v1:0",
|
||||
"anthropic.claude-3-5-sonnet-20241022-v2:0",
|
||||
];
|
||||
|
||||
#[derive(Debug, serde::Serialize)]
|
||||
pub struct BedrockProvider {
|
||||
#[serde(skip)]
|
||||
client: Client,
|
||||
model: ModelConfig,
|
||||
}
|
||||
|
||||
impl BedrockProvider {
|
||||
pub fn from_env(model: ModelConfig) -> Result<Self> {
|
||||
let sdk_config = futures::executor::block_on(aws_config::load_from_env());
|
||||
let client = Client::new(&sdk_config);
|
||||
|
||||
Ok(Self { client, model })
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for BedrockProvider {
|
||||
fn default() -> Self {
|
||||
let model = ModelConfig::new(BedrockProvider::metadata().default_model);
|
||||
BedrockProvider::from_env(model).expect("Failed to initialize Bedrock provider")
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for BedrockProvider {
|
||||
fn metadata() -> ProviderMetadata {
|
||||
ProviderMetadata::new(
|
||||
"bedrock",
|
||||
"Amazon Bedrock",
|
||||
"Run models through Amazon Bedrock. You may have to set AWS_ACCESS_KEY_ID, AWS_ACCESS_KEY, and AWS_REGION as env vars before configuring.",
|
||||
BEDROCK_DEFAULT_MODEL,
|
||||
BEDROCK_KNOWN_MODELS.iter().map(|s| s.to_string()).collect(),
|
||||
BEDROCK_DOC_LINK,
|
||||
vec![],
|
||||
)
|
||||
}
|
||||
|
||||
fn get_model_config(&self) -> ModelConfig {
|
||||
self.model.clone()
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
skip(self, system, messages, tools),
|
||||
fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
|
||||
)]
|
||||
async fn complete(
|
||||
&self,
|
||||
system: &str,
|
||||
messages: &[Message],
|
||||
tools: &[Tool],
|
||||
) -> Result<(Message, ProviderUsage), ProviderError> {
|
||||
let model_name = &self.model.model_name;
|
||||
|
||||
let mut request = self
|
||||
.client
|
||||
.converse()
|
||||
.system(bedrock::SystemContentBlock::Text(system.to_string()))
|
||||
.model_id(model_name.to_string())
|
||||
.set_messages(Some(
|
||||
messages
|
||||
.iter()
|
||||
.map(to_bedrock_message)
|
||||
.collect::<Result<_>>()?,
|
||||
));
|
||||
|
||||
if !tools.is_empty() {
|
||||
request = request.tool_config(to_bedrock_tool_config(tools)?);
|
||||
}
|
||||
|
||||
let response = request.send().await;
|
||||
|
||||
let response = match response {
|
||||
Ok(response) => response,
|
||||
Err(err) => {
|
||||
return Err(match err.into_service_error() {
|
||||
ConverseError::AccessDeniedException(err) => {
|
||||
ProviderError::Authentication(format!("Failed to call Bedrock: {:?}", err))
|
||||
}
|
||||
ConverseError::ThrottlingException(err) => ProviderError::RateLimitExceeded(
|
||||
format!("Failed to call Bedrock: {:?}", err),
|
||||
),
|
||||
ConverseError::ValidationException(err)
|
||||
if err
|
||||
.message()
|
||||
.unwrap_or_default()
|
||||
.contains("Input is too long for requested model.") =>
|
||||
{
|
||||
ProviderError::ContextLengthExceeded(format!(
|
||||
"Failed to call Bedrock: {:?}",
|
||||
err
|
||||
))
|
||||
}
|
||||
ConverseError::ModelErrorException(err) => {
|
||||
ProviderError::ExecutionError(format!("Failed to call Bedrock: {:?}", err))
|
||||
}
|
||||
err => {
|
||||
ProviderError::ServerError(format!("Failed to call Bedrock: {:?}", err,))
|
||||
}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
let message = match response.output {
|
||||
Some(bedrock::ConverseOutput::Message(message)) => message,
|
||||
_ => {
|
||||
return Err(ProviderError::RequestFailed(
|
||||
"No output from Bedrock".to_string(),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let usage = response
|
||||
.usage
|
||||
.as_ref()
|
||||
.map(from_bedrock_usage)
|
||||
.unwrap_or_default();
|
||||
|
||||
let message = from_bedrock_message(&message)?;
|
||||
|
||||
// Add debug trace with input context
|
||||
let debug_payload = serde_json::json!({
|
||||
"system": system,
|
||||
"messages": messages,
|
||||
"tools": tools
|
||||
});
|
||||
emit_debug_trace(
|
||||
&self.model,
|
||||
&debug_payload,
|
||||
&serde_json::to_value(&message).unwrap_or_default(),
|
||||
&usage,
|
||||
);
|
||||
|
||||
let provider_usage = ProviderUsage::new(model_name.to_string(), usage);
|
||||
Ok((message, provider_usage))
|
||||
}
|
||||
}
|
||||
@@ -2,6 +2,7 @@ use super::{
|
||||
anthropic::AnthropicProvider,
|
||||
azure::AzureProvider,
|
||||
base::{Provider, ProviderMetadata},
|
||||
bedrock::BedrockProvider,
|
||||
databricks::DatabricksProvider,
|
||||
google::GoogleProvider,
|
||||
groq::GroqProvider,
|
||||
@@ -16,6 +17,7 @@ pub fn providers() -> Vec<ProviderMetadata> {
|
||||
vec![
|
||||
AnthropicProvider::metadata(),
|
||||
AzureProvider::metadata(),
|
||||
BedrockProvider::metadata(),
|
||||
DatabricksProvider::metadata(),
|
||||
GoogleProvider::metadata(),
|
||||
GroqProvider::metadata(),
|
||||
@@ -30,6 +32,7 @@ pub fn create(name: &str, model: ModelConfig) -> Result<Box<dyn Provider + Send
|
||||
"openai" => Ok(Box::new(OpenAiProvider::from_env(model)?)),
|
||||
"anthropic" => Ok(Box::new(AnthropicProvider::from_env(model)?)),
|
||||
"azure_openai" => Ok(Box::new(AzureProvider::from_env(model)?)),
|
||||
"bedrock" => Ok(Box::new(BedrockProvider::from_env(model)?)),
|
||||
"databricks" => Ok(Box::new(DatabricksProvider::from_env(model)?)),
|
||||
"groq" => Ok(Box::new(GroqProvider::from_env(model)?)),
|
||||
"ollama" => Ok(Box::new(OllamaProvider::from_env(model)?)),
|
||||
|
||||
270
crates/goose/src/providers/formats/bedrock.rs
Normal file
270
crates/goose/src/providers/formats/bedrock.rs
Normal file
@@ -0,0 +1,270 @@
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
|
||||
use anyhow::{anyhow, bail, Result};
|
||||
use aws_sdk_bedrockruntime::types as bedrock;
|
||||
use aws_smithy_types::{Document, Number};
|
||||
use chrono::Utc;
|
||||
use mcp_core::{Content, ResourceContents, Role, Tool, ToolCall, ToolError, ToolResult};
|
||||
use serde_json::Value;
|
||||
|
||||
use super::super::base::Usage;
|
||||
use crate::message::{Message, MessageContent};
|
||||
|
||||
pub fn to_bedrock_message(message: &Message) -> Result<bedrock::Message> {
|
||||
bedrock::Message::builder()
|
||||
.role(to_bedrock_role(&message.role))
|
||||
.set_content(Some(
|
||||
message
|
||||
.content
|
||||
.iter()
|
||||
.map(to_bedrock_message_content)
|
||||
.collect::<Result<_>>()?,
|
||||
))
|
||||
.build()
|
||||
.map_err(|err| anyhow!("Failed to construct Bedrock message: {}", err))
|
||||
}
|
||||
|
||||
pub fn to_bedrock_message_content(content: &MessageContent) -> Result<bedrock::ContentBlock> {
|
||||
Ok(match content {
|
||||
MessageContent::Text(text) => bedrock::ContentBlock::Text(text.text.to_string()),
|
||||
MessageContent::Image(_) => {
|
||||
bail!("Image content is not supported by Bedrock provider yet")
|
||||
}
|
||||
MessageContent::ToolRequest(tool_req) => {
|
||||
let tool_use_id = tool_req.id.to_string();
|
||||
let tool_use = if let Ok(call) = tool_req.tool_call.as_ref() {
|
||||
bedrock::ToolUseBlock::builder()
|
||||
.tool_use_id(tool_use_id)
|
||||
.name(call.name.to_string())
|
||||
.input(to_bedrock_json(&call.arguments))
|
||||
.build()
|
||||
} else {
|
||||
bedrock::ToolUseBlock::builder()
|
||||
.tool_use_id(tool_use_id)
|
||||
.build()
|
||||
}?;
|
||||
bedrock::ContentBlock::ToolUse(tool_use)
|
||||
}
|
||||
MessageContent::ToolResponse(tool_res) => {
|
||||
let content = match &tool_res.tool_result {
|
||||
Ok(content) => Some(
|
||||
content
|
||||
.iter()
|
||||
.map(|c| to_bedrock_tool_result_content_block(&tool_res.id, c))
|
||||
.collect::<Result<_>>()?,
|
||||
),
|
||||
Err(_) => None,
|
||||
};
|
||||
bedrock::ContentBlock::ToolResult(
|
||||
bedrock::ToolResultBlock::builder()
|
||||
.tool_use_id(tool_res.id.to_string())
|
||||
.status(if content.is_some() {
|
||||
bedrock::ToolResultStatus::Success
|
||||
} else {
|
||||
bedrock::ToolResultStatus::Error
|
||||
})
|
||||
.set_content(content)
|
||||
.build()?,
|
||||
)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub fn to_bedrock_tool_result_content_block(
|
||||
tool_use_id: &str,
|
||||
content: &Content,
|
||||
) -> Result<bedrock::ToolResultContentBlock> {
|
||||
Ok(match content {
|
||||
Content::Text(text) => bedrock::ToolResultContentBlock::Text(text.text.to_string()),
|
||||
Content::Image(_) => bail!("Image content is not supported by Bedrock provider yet"),
|
||||
Content::Resource(resource) => bedrock::ToolResultContentBlock::Document(
|
||||
to_bedrock_document(tool_use_id, &resource.resource)?,
|
||||
),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn to_bedrock_role(role: &Role) -> bedrock::ConversationRole {
|
||||
match role {
|
||||
Role::User => bedrock::ConversationRole::User,
|
||||
Role::Assistant => bedrock::ConversationRole::Assistant,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_bedrock_tool_config(tools: &[Tool]) -> Result<bedrock::ToolConfiguration> {
|
||||
Ok(bedrock::ToolConfiguration::builder()
|
||||
.set_tools(Some(
|
||||
tools.iter().map(to_bedrock_tool).collect::<Result<_>>()?,
|
||||
))
|
||||
.build()?)
|
||||
}
|
||||
|
||||
pub fn to_bedrock_tool(tool: &Tool) -> Result<bedrock::Tool> {
|
||||
Ok(bedrock::Tool::ToolSpec(
|
||||
bedrock::ToolSpecification::builder()
|
||||
.name(tool.name.to_string())
|
||||
.description(tool.description.to_string())
|
||||
.input_schema(bedrock::ToolInputSchema::Json(to_bedrock_json(
|
||||
&tool.input_schema,
|
||||
)))
|
||||
.build()?,
|
||||
))
|
||||
}
|
||||
|
||||
pub fn to_bedrock_json(value: &Value) -> Document {
|
||||
match value {
|
||||
Value::Null => Document::Null,
|
||||
Value::Bool(bool) => Document::Bool(*bool),
|
||||
Value::Number(num) => {
|
||||
if let Some(n) = num.as_u64() {
|
||||
Document::Number(Number::PosInt(n))
|
||||
} else if let Some(n) = num.as_i64() {
|
||||
Document::Number(Number::NegInt(n))
|
||||
} else if let Some(n) = num.as_f64() {
|
||||
Document::Number(Number::Float(n))
|
||||
} else {
|
||||
unreachable!()
|
||||
}
|
||||
}
|
||||
Value::String(str) => Document::String(str.to_string()),
|
||||
Value::Array(arr) => Document::Array(arr.iter().map(to_bedrock_json).collect()),
|
||||
Value::Object(obj) => Document::Object(HashMap::from_iter(
|
||||
obj.into_iter()
|
||||
.map(|(key, val)| (key.to_string(), to_bedrock_json(val))),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
fn to_bedrock_document(
|
||||
tool_use_id: &str,
|
||||
content: &ResourceContents,
|
||||
) -> Result<bedrock::DocumentBlock> {
|
||||
let (uri, text) = match content {
|
||||
ResourceContents::TextResourceContents { uri, text, .. } => (uri, text),
|
||||
ResourceContents::BlobResourceContents { .. } => {
|
||||
bail!("Blob resource content is not supported by Bedrock provider yet")
|
||||
}
|
||||
};
|
||||
|
||||
let filename = Path::new(uri)
|
||||
.file_name()
|
||||
.and_then(|n| n.to_str())
|
||||
.unwrap_or(uri);
|
||||
|
||||
let (name, format) = match filename.split_once('.') {
|
||||
Some((name, "txt")) => (name, bedrock::DocumentFormat::Txt),
|
||||
Some((name, "csv")) => (name, bedrock::DocumentFormat::Csv),
|
||||
Some((name, "md")) => (name, bedrock::DocumentFormat::Md),
|
||||
Some((name, "html")) => (name, bedrock::DocumentFormat::Html),
|
||||
Some((name, _)) => (name, bedrock::DocumentFormat::Txt),
|
||||
_ => (filename, bedrock::DocumentFormat::Txt),
|
||||
};
|
||||
|
||||
// Since we can't use the full path (due to character limit and also Bedrock does not accept `/` etc.),
|
||||
// and Bedrock wants document names to be unique, we're adding `tool_use_id` as a prefix to make
|
||||
// document names unique.
|
||||
let name = format!("{tool_use_id}-{name}");
|
||||
|
||||
bedrock::DocumentBlock::builder()
|
||||
.format(format)
|
||||
.name(name)
|
||||
.source(bedrock::DocumentSource::Bytes(text.as_bytes().into()))
|
||||
.build()
|
||||
.map_err(|err| anyhow!("Failed to construct Bedrock document: {}", err))
|
||||
}
|
||||
|
||||
pub fn from_bedrock_message(message: &bedrock::Message) -> Result<Message> {
|
||||
let role = from_bedrock_role(message.role())?;
|
||||
let content = message
|
||||
.content()
|
||||
.iter()
|
||||
.map(from_bedrock_content_block)
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let created = Utc::now().timestamp();
|
||||
|
||||
Ok(Message {
|
||||
role,
|
||||
content,
|
||||
created,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn from_bedrock_content_block(block: &bedrock::ContentBlock) -> Result<MessageContent> {
|
||||
Ok(match block {
|
||||
bedrock::ContentBlock::Text(text) => MessageContent::text(text),
|
||||
bedrock::ContentBlock::ToolUse(tool_use) => MessageContent::tool_request(
|
||||
tool_use.tool_use_id.to_string(),
|
||||
Ok(ToolCall::new(
|
||||
tool_use.name.to_string(),
|
||||
from_bedrock_json(&tool_use.input)?,
|
||||
)),
|
||||
),
|
||||
bedrock::ContentBlock::ToolResult(tool_res) => MessageContent::tool_response(
|
||||
tool_res.tool_use_id.to_string(),
|
||||
if tool_res.content.is_empty() {
|
||||
Err(ToolError::ExecutionError(
|
||||
"Empty content for tool use from Bedrock".to_string(),
|
||||
))
|
||||
} else {
|
||||
tool_res
|
||||
.content
|
||||
.iter()
|
||||
.map(from_bedrock_tool_result_content_block)
|
||||
.collect::<ToolResult<Vec<_>>>()
|
||||
},
|
||||
),
|
||||
_ => bail!("Unsupported content block type from Bedrock"),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn from_bedrock_tool_result_content_block(
|
||||
content: &bedrock::ToolResultContentBlock,
|
||||
) -> ToolResult<Content> {
|
||||
Ok(match content {
|
||||
bedrock::ToolResultContentBlock::Text(text) => Content::text(text.to_string()),
|
||||
_ => {
|
||||
return Err(ToolError::ExecutionError(
|
||||
"Unsupported tool result from Bedrock".to_string(),
|
||||
))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub fn from_bedrock_role(role: &bedrock::ConversationRole) -> Result<Role> {
|
||||
Ok(match role {
|
||||
bedrock::ConversationRole::User => Role::User,
|
||||
bedrock::ConversationRole::Assistant => Role::Assistant,
|
||||
_ => bail!("Unknown role from Bedrock"),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn from_bedrock_usage(usage: &bedrock::TokenUsage) -> Usage {
|
||||
Usage {
|
||||
input_tokens: Some(usage.input_tokens),
|
||||
output_tokens: Some(usage.output_tokens),
|
||||
total_tokens: Some(usage.total_tokens),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_bedrock_json(document: &Document) -> Result<Value> {
|
||||
Ok(match document {
|
||||
Document::Null => Value::Null,
|
||||
Document::Bool(bool) => Value::Bool(*bool),
|
||||
Document::Number(num) => match num {
|
||||
Number::PosInt(i) => Value::Number((*i).into()),
|
||||
Number::NegInt(i) => Value::Number((*i).into()),
|
||||
Number::Float(f) => Value::Number(
|
||||
serde_json::Number::from_f64(*f).ok_or(anyhow!("Expected a valid float"))?,
|
||||
),
|
||||
},
|
||||
Document::String(str) => Value::String(str.clone()),
|
||||
Document::Array(arr) => {
|
||||
Value::Array(arr.iter().map(from_bedrock_json).collect::<Result<_>>()?)
|
||||
}
|
||||
Document::Object(obj) => Value::Object(
|
||||
obj.iter()
|
||||
.map(|(key, val)| Ok((key.clone(), from_bedrock_json(val)?)))
|
||||
.collect::<Result<_>>()?,
|
||||
),
|
||||
})
|
||||
}
|
||||
@@ -1,3 +1,4 @@
|
||||
pub mod anthropic;
|
||||
pub mod bedrock;
|
||||
pub mod google;
|
||||
pub mod openai;
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
pub mod anthropic;
|
||||
pub mod azure;
|
||||
pub mod base;
|
||||
pub mod bedrock;
|
||||
pub mod databricks;
|
||||
pub mod errors;
|
||||
mod factory;
|
||||
|
||||
@@ -3,7 +3,9 @@ use dotenv::dotenv;
|
||||
use goose::message::{Message, MessageContent};
|
||||
use goose::providers::base::Provider;
|
||||
use goose::providers::errors::ProviderError;
|
||||
use goose::providers::{anthropic, azure, databricks, google, groq, ollama, openai, openrouter};
|
||||
use goose::providers::{
|
||||
anthropic, azure, bedrock, databricks, google, groq, ollama, openai, openrouter,
|
||||
};
|
||||
use mcp_core::content::Content;
|
||||
use mcp_core::tool::Tool;
|
||||
use std::collections::HashMap;
|
||||
@@ -374,6 +376,34 @@ async fn test_azure_provider() -> Result<()> {
|
||||
.await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_bedrock_provider_long_term_credentials() -> Result<()> {
|
||||
test_provider(
|
||||
"Bedrock",
|
||||
&["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"],
|
||||
None,
|
||||
bedrock::BedrockProvider::default,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_bedrock_provider_aws_profile_credentials() -> Result<()> {
|
||||
let env_mods = HashMap::from_iter([
|
||||
// Ensure to unset long-term credentials to use AWS Profile provider
|
||||
("AWS_ACCESS_KEY_ID", None),
|
||||
("AWS_SECRET_ACCESS_KEY", None),
|
||||
]);
|
||||
|
||||
test_provider(
|
||||
"Bedrock AWS Profile Credentials",
|
||||
&["AWS_PROFILE"],
|
||||
Some(env_mods),
|
||||
bedrock::BedrockProvider::default,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_databricks_provider() -> Result<()> {
|
||||
test_provider(
|
||||
|
||||
@@ -8,7 +8,7 @@ use goose::model::ModelConfig;
|
||||
use goose::providers::base::Provider;
|
||||
use goose::providers::{anthropic::AnthropicProvider, databricks::DatabricksProvider};
|
||||
use goose::providers::{
|
||||
azure::AzureProvider, ollama::OllamaProvider, openai::OpenAiProvider,
|
||||
azure::AzureProvider, bedrock::BedrockProvider, ollama::OllamaProvider, openai::OpenAiProvider,
|
||||
openrouter::OpenRouterProvider,
|
||||
};
|
||||
use goose::providers::{google::GoogleProvider, groq::GroqProvider};
|
||||
@@ -18,6 +18,7 @@ enum ProviderType {
|
||||
Azure,
|
||||
OpenAi,
|
||||
Anthropic,
|
||||
Bedrock,
|
||||
Databricks,
|
||||
Google,
|
||||
Groq,
|
||||
@@ -35,6 +36,7 @@ impl ProviderType {
|
||||
],
|
||||
ProviderType::OpenAi => &["OPENAI_API_KEY"],
|
||||
ProviderType::Anthropic => &["ANTHROPIC_API_KEY"],
|
||||
ProviderType::Bedrock => &["AWS_PROFILE", "AWS_REGION"],
|
||||
ProviderType::Databricks => &["DATABRICKS_HOST"],
|
||||
ProviderType::Google => &["GOOGLE_API_KEY"],
|
||||
ProviderType::Groq => &["GROQ_API_KEY"],
|
||||
@@ -66,6 +68,7 @@ impl ProviderType {
|
||||
ProviderType::Azure => Box::new(AzureProvider::from_env(model_config)?),
|
||||
ProviderType::OpenAi => Box::new(OpenAiProvider::from_env(model_config)?),
|
||||
ProviderType::Anthropic => Box::new(AnthropicProvider::from_env(model_config)?),
|
||||
ProviderType::Bedrock => Box::new(BedrockProvider::from_env(model_config)?),
|
||||
ProviderType::Databricks => Box::new(DatabricksProvider::from_env(model_config)?),
|
||||
ProviderType::Google => Box::new(GoogleProvider::from_env(model_config)?),
|
||||
ProviderType::Groq => Box::new(GroqProvider::from_env(model_config)?),
|
||||
@@ -200,6 +203,16 @@ mod tests {
|
||||
.await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_truncate_agent_with_bedrock() -> Result<()> {
|
||||
run_test_with_config(TestConfig {
|
||||
provider_type: ProviderType::Bedrock,
|
||||
model: "anthropic.claude-3-5-sonnet-20241022-v2:0",
|
||||
context_window: 200_000,
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_truncate_agent_with_databricks() -> Result<()> {
|
||||
run_test_with_config(TestConfig {
|
||||
|
||||
Reference in New Issue
Block a user