mirror of
https://github.com/aljazceru/goose.git
synced 2025-12-18 14:44:21 +01:00
feat: Support VertexAI for Claude (#1138)
This commit is contained in:
@@ -46,5 +46,11 @@
|
||||
"description": "Connect to Azure OpenAI Service",
|
||||
"models": ["gpt-4o", "gpt-4o-mini"],
|
||||
"required_keys": ["AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOYMENT_NAME"]
|
||||
},
|
||||
"vertex_ai": {
|
||||
"name": "Vertex AI",
|
||||
"description": "Access variety of AI models through Vertex AI",
|
||||
"models": ["claude-3-5-sonnet-v2@20241022", "claude-3-5-sonnet@20240620"],
|
||||
"required_keys": ["VERTEXAI_PROJECT_ID", "VERTEXAI_REGION"]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@ mcp-core = { path = "../mcp-core" }
|
||||
anyhow = "1.0"
|
||||
thiserror = "1.0"
|
||||
futures = "0.3"
|
||||
gcp-sdk-auth = "0.1.1"
|
||||
reqwest = { version = "0.12.9", features = [
|
||||
"rustls-tls",
|
||||
"json",
|
||||
|
||||
@@ -9,6 +9,7 @@ use super::{
|
||||
ollama::OllamaProvider,
|
||||
openai::OpenAiProvider,
|
||||
openrouter::OpenRouterProvider,
|
||||
vertexai::VertexAIProvider,
|
||||
};
|
||||
use crate::model::ModelConfig;
|
||||
use anyhow::Result;
|
||||
@@ -24,6 +25,7 @@ pub fn providers() -> Vec<ProviderMetadata> {
|
||||
OllamaProvider::metadata(),
|
||||
OpenAiProvider::metadata(),
|
||||
OpenRouterProvider::metadata(),
|
||||
VertexAIProvider::metadata(),
|
||||
]
|
||||
}
|
||||
|
||||
@@ -38,6 +40,7 @@ pub fn create(name: &str, model: ModelConfig) -> Result<Box<dyn Provider + Send
|
||||
"ollama" => Ok(Box::new(OllamaProvider::from_env(model)?)),
|
||||
"openrouter" => Ok(Box::new(OpenRouterProvider::from_env(model)?)),
|
||||
"google" => Ok(Box::new(GoogleProvider::from_env(model)?)),
|
||||
"vertex_ai" => Ok(Box::new(VertexAIProvider::from_env(model)?)),
|
||||
_ => Err(anyhow::anyhow!("Unknown provider: {}", name)),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,3 +2,4 @@ pub mod anthropic;
|
||||
pub mod bedrock;
|
||||
pub mod google;
|
||||
pub mod openai;
|
||||
pub mod vertexai;
|
||||
|
||||
67
crates/goose/src/providers/formats/vertexai.rs
Normal file
67
crates/goose/src/providers/formats/vertexai.rs
Normal file
@@ -0,0 +1,67 @@
|
||||
use crate::message::Message;
|
||||
use crate::model::ModelConfig;
|
||||
use crate::providers::base::Usage;
|
||||
use anyhow::Result;
|
||||
use mcp_core::tool::Tool;
|
||||
use serde_json::Value;
|
||||
|
||||
use super::anthropic;
|
||||
|
||||
pub fn create_request(
|
||||
model_config: &ModelConfig,
|
||||
system: &str,
|
||||
messages: &[Message],
|
||||
tools: &[Tool],
|
||||
) -> Result<Value> {
|
||||
match model_config.model_name.as_str() {
|
||||
"claude-3-5-sonnet-v2@20241022" | "claude-3-5-sonnet@20240620" => {
|
||||
create_anthropic_request(model_config, system, messages, tools)
|
||||
}
|
||||
_ => Err(anyhow::anyhow!("Vertex AI only supports Anthropic models")),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create_anthropic_request(
|
||||
model_config: &ModelConfig,
|
||||
system: &str,
|
||||
messages: &[Message],
|
||||
tools: &[Tool],
|
||||
) -> Result<Value> {
|
||||
let mut request = anthropic::create_request(model_config, system, messages, tools)?;
|
||||
|
||||
// the Vertex AI for Claude API has small differences from the Anthropic API
|
||||
// ref: https://docs.anthropic.com/en/api/claude-on-vertex-ai
|
||||
request.as_object_mut().unwrap().remove("model");
|
||||
request.as_object_mut().unwrap().insert(
|
||||
"anthropic_version".to_string(),
|
||||
Value::String("vertex-2023-10-16".to_string()),
|
||||
);
|
||||
|
||||
Ok(request)
|
||||
}
|
||||
|
||||
pub fn response_to_message(response: Value) -> Result<Message> {
|
||||
anthropic::response_to_message(response)
|
||||
}
|
||||
|
||||
pub fn get_usage(data: &Value) -> Result<Usage> {
|
||||
anthropic::get_usage(data)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_create_request() {
|
||||
let model_config = ModelConfig::new("claude-3-5-sonnet-v2@20241022".to_string());
|
||||
let system = "You are a helpful assistant.";
|
||||
let messages = vec![Message::user().with_text("Hello, how are you?")];
|
||||
let tools = vec![];
|
||||
|
||||
let request = create_request(&model_config, &system, &messages, &tools).unwrap();
|
||||
|
||||
assert!(request.get("anthropic_version").is_some());
|
||||
assert!(request.get("model").is_none());
|
||||
}
|
||||
}
|
||||
@@ -13,5 +13,6 @@ pub mod ollama;
|
||||
pub mod openai;
|
||||
pub mod openrouter;
|
||||
pub mod utils;
|
||||
pub mod vertexai;
|
||||
|
||||
pub use factory::{create, providers};
|
||||
|
||||
189
crates/goose/src/providers/vertexai.rs
Normal file
189
crates/goose/src/providers/vertexai.rs
Normal file
@@ -0,0 +1,189 @@
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use gcp_sdk_auth::credentials::create_access_token_credential;
|
||||
use reqwest::Client;
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::message::Message;
|
||||
use crate::model::ModelConfig;
|
||||
use crate::providers::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage};
|
||||
use crate::providers::errors::ProviderError;
|
||||
use crate::providers::formats::vertexai::{create_request, get_usage, response_to_message};
|
||||
use crate::providers::utils::emit_debug_trace;
|
||||
use mcp_core::tool::Tool;
|
||||
|
||||
pub const VERTEXAI_DEFAULT_MODEL: &str = "claude-3-5-sonnet-v2@20241022";
|
||||
pub const VERTEXAI_KNOWN_MODELS: &[&str] = &[
|
||||
"claude-3-5-sonnet-v2@20241022",
|
||||
"claude-3-5-sonnet@20240620",
|
||||
];
|
||||
pub const VERTEXAI_DOC_URL: &str = "https://cloud.google.com/vertex-ai";
|
||||
pub const VERTEXAI_DEFAULT_REGION: &str = "us-east5";
|
||||
|
||||
#[derive(Debug, serde::Serialize)]
|
||||
pub struct VertexAIProvider {
|
||||
#[serde(skip)]
|
||||
client: Client,
|
||||
host: String,
|
||||
project_id: String,
|
||||
region: String,
|
||||
model: ModelConfig,
|
||||
}
|
||||
|
||||
impl VertexAIProvider {
|
||||
pub fn from_env(model: ModelConfig) -> Result<Self> {
|
||||
let config = crate::config::Config::global();
|
||||
|
||||
let project_id = config.get("VERTEXAI_PROJECT_ID")?;
|
||||
let region = config
|
||||
.get("VERTEXAI_REGION")
|
||||
.unwrap_or_else(|_| VERTEXAI_DEFAULT_REGION.to_string());
|
||||
let host = config
|
||||
.get("VERTEXAI_API_HOST")
|
||||
.unwrap_or_else(|_| format!("https://{}-aiplatform.googleapis.com", region));
|
||||
|
||||
let client = Client::builder()
|
||||
.timeout(Duration::from_secs(600))
|
||||
.build()?;
|
||||
|
||||
Ok(VertexAIProvider {
|
||||
client,
|
||||
host,
|
||||
project_id,
|
||||
region,
|
||||
model,
|
||||
})
|
||||
}
|
||||
|
||||
async fn post(&self, payload: Value) -> Result<Value, ProviderError> {
|
||||
let base_url = url::Url::parse(&self.host)
|
||||
.map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?;
|
||||
let path = format!(
|
||||
"v1/projects/{}/locations/{}/publishers/{}/models/{}:streamRawPredict",
|
||||
self.project_id,
|
||||
self.region,
|
||||
self.get_model_provider(),
|
||||
self.model.model_name
|
||||
);
|
||||
let url = base_url.join(&path).map_err(|e| {
|
||||
ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}"))
|
||||
})?;
|
||||
|
||||
let creds = create_access_token_credential().await.map_err(|e| {
|
||||
ProviderError::RequestFailed(format!("Failed to create access token credential: {}", e))
|
||||
})?;
|
||||
let token = creds.get_token().await.map_err(|e| {
|
||||
ProviderError::RequestFailed(format!("Failed to get access token: {}", e))
|
||||
})?;
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(url)
|
||||
.json(&payload)
|
||||
.header("Authorization", format!("Bearer {}", token.token))
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| ProviderError::RequestFailed(format!("Request failed: {}", e)))?;
|
||||
|
||||
let status = response.status();
|
||||
let response_json = response.json::<Value>().await.map_err(|e| {
|
||||
ProviderError::RequestFailed(format!("Failed to parse response: {}", e))
|
||||
})?;
|
||||
|
||||
match status {
|
||||
reqwest::StatusCode::OK => Ok(response_json),
|
||||
reqwest::StatusCode::UNAUTHORIZED | reqwest::StatusCode::FORBIDDEN => {
|
||||
tracing::debug!(
|
||||
"{}",
|
||||
format!(
|
||||
"Provider request failed with status: {}. Payload: {:?}",
|
||||
status, payload
|
||||
)
|
||||
);
|
||||
Err(ProviderError::Authentication(format!(
|
||||
"Authentication failed: {:?}",
|
||||
response_json
|
||||
)))
|
||||
}
|
||||
_ => {
|
||||
tracing::debug!(
|
||||
"{}",
|
||||
format!("Request failed with status {}: {:?}", status, response_json)
|
||||
);
|
||||
Err(ProviderError::RequestFailed(format!(
|
||||
"Request failed with status {}: {:?}",
|
||||
status, response_json
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn get_model_provider(&self) -> String {
|
||||
// TODO: switch this by model_name
|
||||
"anthropic".to_string()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for VertexAIProvider {
|
||||
fn default() -> Self {
|
||||
let model = ModelConfig::new(Self::metadata().default_model);
|
||||
VertexAIProvider::from_env(model).expect("Failed to initialize VertexAI provider")
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for VertexAIProvider {
|
||||
fn metadata() -> ProviderMetadata
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
ProviderMetadata::new(
|
||||
"vertex_ai",
|
||||
"Vertex AI",
|
||||
"Access variety of AI models such as Claude through Vertex AI",
|
||||
VERTEXAI_DEFAULT_MODEL,
|
||||
VERTEXAI_KNOWN_MODELS
|
||||
.iter()
|
||||
.map(|&s| s.to_string())
|
||||
.collect(),
|
||||
VERTEXAI_DOC_URL,
|
||||
vec![
|
||||
ConfigKey::new("VERTEXAI_PROJECT_ID", true, false, None),
|
||||
ConfigKey::new(
|
||||
"VERTEXAI_REGION",
|
||||
true,
|
||||
false,
|
||||
Some(VERTEXAI_DEFAULT_REGION),
|
||||
),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
skip(self, system, messages, tools),
|
||||
fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
|
||||
)]
|
||||
async fn complete(
|
||||
&self,
|
||||
system: &str,
|
||||
messages: &[Message],
|
||||
tools: &[Tool],
|
||||
) -> Result<(Message, ProviderUsage), ProviderError> {
|
||||
let request = create_request(&self.model, system, messages, tools)?;
|
||||
let response = self.post(request.clone()).await?;
|
||||
let usage = get_usage(&response)?;
|
||||
|
||||
emit_debug_trace(self, &request, &response, &usage);
|
||||
|
||||
let message = response_to_message(response.clone())?;
|
||||
let provider_usage = ProviderUsage::new(self.model.model_name.clone(), usage);
|
||||
|
||||
Ok((message, provider_usage))
|
||||
}
|
||||
|
||||
fn get_model_config(&self) -> ModelConfig {
|
||||
self.model.clone()
|
||||
}
|
||||
}
|
||||
@@ -6,6 +6,7 @@ use goose::agents::AgentFactory;
|
||||
use goose::message::Message;
|
||||
use goose::model::ModelConfig;
|
||||
use goose::providers::base::Provider;
|
||||
use goose::providers::vertexai::VertexAIProvider;
|
||||
use goose::providers::{anthropic::AnthropicProvider, databricks::DatabricksProvider};
|
||||
use goose::providers::{
|
||||
azure::AzureProvider, bedrock::BedrockProvider, ollama::OllamaProvider, openai::OpenAiProvider,
|
||||
@@ -24,6 +25,7 @@ enum ProviderType {
|
||||
Groq,
|
||||
Ollama,
|
||||
OpenRouter,
|
||||
VertexAI,
|
||||
}
|
||||
|
||||
impl ProviderType {
|
||||
@@ -42,6 +44,7 @@ impl ProviderType {
|
||||
ProviderType::Groq => &["GROQ_API_KEY"],
|
||||
ProviderType::Ollama => &[],
|
||||
ProviderType::OpenRouter => &["OPENROUTER_API_KEY"],
|
||||
ProviderType::VertexAI => &["VERTEXAI_PROJECT_ID", "VERTEXAI_REGION"],
|
||||
}
|
||||
}
|
||||
|
||||
@@ -74,6 +77,7 @@ impl ProviderType {
|
||||
ProviderType::Groq => Box::new(GroqProvider::from_env(model_config)?),
|
||||
ProviderType::Ollama => Box::new(OllamaProvider::from_env(model_config)?),
|
||||
ProviderType::OpenRouter => Box::new(OpenRouterProvider::from_env(model_config)?),
|
||||
ProviderType::VertexAI => Box::new(VertexAIProvider::from_env(model_config)?),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -290,4 +294,14 @@ mod tests {
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_truncate_agent_with_vertexai() -> Result<()> {
|
||||
run_test_with_config(TestConfig {
|
||||
provider_type: ProviderType::VertexAI,
|
||||
model: "claude-3-5-sonnet-v2@20241022",
|
||||
context_window: 200_000,
|
||||
})
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,6 +8,8 @@ export function isSecretKey(keyName: string): boolean {
|
||||
'OLLAMA_HOST',
|
||||
'AZURE_OPENAI_ENDPOINT',
|
||||
'AZURE_OPENAI_DEPLOYMENT_NAME',
|
||||
'VERTEXAI_PROJECT_ID',
|
||||
'VERTEXAI_REGION',
|
||||
];
|
||||
return !nonSecretKeys.includes(keyName);
|
||||
}
|
||||
|
||||
@@ -19,6 +19,8 @@ export const goose_models: Model[] = [
|
||||
{ id: 17, name: 'qwen2.5', provider: 'Ollama' },
|
||||
{ id: 18, name: 'anthropic/claude-3.5-sonnet', provider: 'OpenRouter' },
|
||||
{ id: 19, name: 'gpt-4o', provider: 'Azure OpenAI' },
|
||||
{ id: 20, name: 'claude-3-5-sonnet-v2@20241022', provider: 'Vertex AI' },
|
||||
{ id: 21, name: 'claude-3-5-sonnet@20240620', provider: 'Vertex AI' },
|
||||
];
|
||||
|
||||
export const openai_models = ['gpt-4o-mini', 'gpt-4o', 'gpt-4-turbo', 'o1'];
|
||||
@@ -47,6 +49,8 @@ export const openrouter_models = ['anthropic/claude-3.5-sonnet'];
|
||||
|
||||
export const azure_openai_models = ['gpt-4o'];
|
||||
|
||||
export const vertexai_models = ['claude-3-5-sonnet-v2@20241022', 'claude-3-5-sonnet@20240620'];
|
||||
|
||||
export const default_models = {
|
||||
openai: 'gpt-4o',
|
||||
anthropic: 'claude-3-5-sonnet-latest',
|
||||
@@ -56,6 +60,7 @@ export const default_models = {
|
||||
openrouter: 'anthropic/claude-3.5-sonnet',
|
||||
ollama: 'qwen2.5',
|
||||
azure_openai: 'gpt-4o',
|
||||
vertex_ai: 'claude-3-5-sonnet-v2@20241022',
|
||||
};
|
||||
|
||||
export function getDefaultModel(key: string): string | undefined {
|
||||
@@ -73,6 +78,7 @@ export const required_keys = {
|
||||
Google: ['GOOGLE_API_KEY'],
|
||||
OpenRouter: ['OPENROUTER_API_KEY'],
|
||||
'Azure OpenAI': ['AZURE_OPENAI_API_KEY', 'AZURE_OPENAI_ENDPOINT', 'AZURE_OPENAI_DEPLOYMENT_NAME'],
|
||||
'Vertex AI': ['VERTEXAI_PROJECT_ID', 'VERTEXAI_REGION'],
|
||||
};
|
||||
|
||||
export const supported_providers = [
|
||||
@@ -84,6 +90,7 @@ export const supported_providers = [
|
||||
'Ollama',
|
||||
'OpenRouter',
|
||||
'Azure OpenAI',
|
||||
'Vertex AI',
|
||||
];
|
||||
|
||||
export const model_docs_link = [
|
||||
@@ -97,6 +104,7 @@ export const model_docs_link = [
|
||||
},
|
||||
{ name: 'OpenRouter', href: 'https://openrouter.ai/models' },
|
||||
{ name: 'Ollama', href: 'https://ollama.com/library' },
|
||||
{ name: 'Vertex AI', href: 'https://cloud.google.com/vertex-ai' },
|
||||
];
|
||||
|
||||
export const provider_aliases = [
|
||||
@@ -108,4 +116,5 @@ export const provider_aliases = [
|
||||
{ provider: 'OpenRouter', alias: 'openrouter' },
|
||||
{ provider: 'Google', alias: 'google' },
|
||||
{ provider: 'Azure OpenAI', alias: 'azure_openai' },
|
||||
{ provider: 'Vertex AI', alias: 'vertex_ai' },
|
||||
];
|
||||
|
||||
Reference in New Issue
Block a user