feat: Support VertexAI for Claude (#1138)

This commit is contained in:
Yuku Kotani
2025-02-13 02:17:15 +09:00
committed by GitHub
parent 724e5ac9a9
commit 58c9eeb6d6
10 changed files with 293 additions and 0 deletions

View File

@@ -46,5 +46,11 @@
"description": "Connect to Azure OpenAI Service",
"models": ["gpt-4o", "gpt-4o-mini"],
"required_keys": ["AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOYMENT_NAME"]
},
"vertex_ai": {
"name": "Vertex AI",
"description": "Access a variety of AI models through Vertex AI",
"models": ["claude-3-5-sonnet-v2@20241022", "claude-3-5-sonnet@20240620"],
"required_keys": ["VERTEXAI_PROJECT_ID", "VERTEXAI_REGION"]
}
}

View File

@@ -17,6 +17,7 @@ mcp-core = { path = "../mcp-core" }
anyhow = "1.0"
thiserror = "1.0"
futures = "0.3"
gcp-sdk-auth = "0.1.1"
reqwest = { version = "0.12.9", features = [
"rustls-tls",
"json",

View File

@@ -9,6 +9,7 @@ use super::{
ollama::OllamaProvider,
openai::OpenAiProvider,
openrouter::OpenRouterProvider,
vertexai::VertexAIProvider,
};
use crate::model::ModelConfig;
use anyhow::Result;
@@ -24,6 +25,7 @@ pub fn providers() -> Vec<ProviderMetadata> {
OllamaProvider::metadata(),
OpenAiProvider::metadata(),
OpenRouterProvider::metadata(),
VertexAIProvider::metadata(),
]
}
@@ -38,6 +40,7 @@ pub fn create(name: &str, model: ModelConfig) -> Result<Box<dyn Provider + Send
"ollama" => Ok(Box::new(OllamaProvider::from_env(model)?)),
"openrouter" => Ok(Box::new(OpenRouterProvider::from_env(model)?)),
"google" => Ok(Box::new(GoogleProvider::from_env(model)?)),
"vertex_ai" => Ok(Box::new(VertexAIProvider::from_env(model)?)),
_ => Err(anyhow::anyhow!("Unknown provider: {}", name)),
}
}

View File

@@ -2,3 +2,4 @@ pub mod anthropic;
pub mod bedrock;
pub mod google;
pub mod openai;
pub mod vertexai;

View File

@@ -0,0 +1,67 @@
use crate::message::Message;
use crate::model::ModelConfig;
use crate::providers::base::Usage;
use anyhow::Result;
use mcp_core::tool::Tool;
use serde_json::Value;
use super::anthropic;
/// Build a Vertex AI request payload for the given model configuration.
///
/// Vertex AI currently only serves Anthropic (Claude) models through this
/// provider, so any `claude-*` model name is routed to the Anthropic request
/// format; anything else is rejected with an error.
pub fn create_request(
    model_config: &ModelConfig,
    system: &str,
    messages: &[Message],
    tools: &[Tool],
) -> Result<Value> {
    // Accept any Claude model (e.g. "claude-3-5-sonnet-v2@20241022") via a
    // prefix check rather than a hard-coded allowlist, so newly released
    // Claude versions work without a code change. Behavior is unchanged for
    // the previously listed models.
    if model_config.model_name.starts_with("claude-") {
        create_anthropic_request(model_config, system, messages, tools)
    } else {
        Err(anyhow::anyhow!("Vertex AI only supports Anthropic models"))
    }
}
/// Convert a standard Anthropic request body into the Vertex AI variant.
///
/// The Vertex AI for Claude API differs slightly from the first-party
/// Anthropic API: the model is addressed through the endpoint URL (not the
/// request body), and an `anthropic_version` field is required.
/// ref: https://docs.anthropic.com/en/api/claude-on-vertex-ai
pub fn create_anthropic_request(
    model_config: &ModelConfig,
    system: &str,
    messages: &[Message],
    tools: &[Tool],
) -> Result<Value> {
    let mut request = anthropic::create_request(model_config, system, messages, tools)?;
    // Borrow the object map once instead of re-unwrapping per mutation;
    // `anthropic::create_request` always produces a JSON object.
    let body = request
        .as_object_mut()
        .expect("anthropic request must be a JSON object");
    body.remove("model");
    body.insert(
        "anthropic_version".to_string(),
        Value::String("vertex-2023-10-16".to_string()),
    );
    Ok(request)
}
/// Parse a Vertex AI (Claude) response body into a `Message`.
///
/// The response shape matches the Anthropic API, so parsing is delegated to
/// the Anthropic format module unchanged.
pub fn response_to_message(response: Value) -> Result<Message> {
    anthropic::response_to_message(response)
}
/// Extract token usage from a Vertex AI (Claude) response.
///
/// Usage reporting also follows the Anthropic response format, so this
/// delegates to the Anthropic format module unchanged.
pub fn get_usage(data: &Value) -> Result<Usage> {
    anthropic::get_usage(data)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The Vertex AI request must drop the body-level `model` field and carry
    /// the fixed `anthropic_version` instead.
    #[test]
    fn test_create_request() {
        let model_config = ModelConfig::new("claude-3-5-sonnet-v2@20241022".to_string());
        let system = "You are a helpful assistant.";
        let messages = vec![Message::user().with_text("Hello, how are you?")];
        let tools = vec![];

        // Pass `system` directly: it is already a `&str`, so `&system` would
        // only work through deref coercion of `&&str`.
        let request = create_request(&model_config, system, &messages, &tools).unwrap();

        // Assert the concrete version value, not just key presence.
        assert_eq!(
            request.get("anthropic_version").and_then(Value::as_str),
            Some("vertex-2023-10-16")
        );
        assert!(request.get("model").is_none());
    }
}

View File

@@ -13,5 +13,6 @@ pub mod ollama;
pub mod openai;
pub mod openrouter;
pub mod utils;
pub mod vertexai;
pub use factory::{create, providers};

View File

@@ -0,0 +1,189 @@
use std::time::Duration;
use anyhow::Result;
use async_trait::async_trait;
use gcp_sdk_auth::credentials::create_access_token_credential;
use reqwest::Client;
use serde_json::Value;
use crate::message::Message;
use crate::model::ModelConfig;
use crate::providers::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage};
use crate::providers::errors::ProviderError;
use crate::providers::formats::vertexai::{create_request, get_usage, response_to_message};
use crate::providers::utils::emit_debug_trace;
use mcp_core::tool::Tool;
/// Default Claude model used when no model is configured.
pub const VERTEXAI_DEFAULT_MODEL: &str = "claude-3-5-sonnet-v2@20241022";
/// Models known to work with this provider (Anthropic models only for now).
pub const VERTEXAI_KNOWN_MODELS: &[&str] = &[
    "claude-3-5-sonnet-v2@20241022",
    "claude-3-5-sonnet@20240620",
];
/// Documentation link surfaced in the provider metadata.
pub const VERTEXAI_DOC_URL: &str = "https://cloud.google.com/vertex-ai";
/// Fallback region when `VERTEXAI_REGION` is not configured.
pub const VERTEXAI_DEFAULT_REGION: &str = "us-east5";
/// Provider backed by Google Vertex AI, serving Claude models via the
/// `streamRawPredict` endpoint.
#[derive(Debug, serde::Serialize)]
pub struct VertexAIProvider {
    // HTTP client is runtime state, not configuration — excluded from
    // serialization.
    #[serde(skip)]
    client: Client,
    // Base API host, e.g. "https://us-east5-aiplatform.googleapis.com".
    host: String,
    // GCP project that owns the Vertex AI resources.
    project_id: String,
    // Vertex AI region used in the request path.
    region: String,
    // Model configuration (model name, context limits, etc.).
    model: ModelConfig,
}
impl VertexAIProvider {
    /// Build a provider from global configuration.
    ///
    /// Required: `VERTEXAI_PROJECT_ID`. Optional: `VERTEXAI_REGION` (defaults
    /// to `us-east5`) and `VERTEXAI_API_HOST` (defaults to the regional
    /// `*-aiplatform.googleapis.com` endpoint).
    ///
    /// # Errors
    /// Returns an error when the project id is missing or the HTTP client
    /// cannot be constructed.
    pub fn from_env(model: ModelConfig) -> Result<Self> {
        let config = crate::config::Config::global();
        let project_id = config.get("VERTEXAI_PROJECT_ID")?;
        let region = config
            .get("VERTEXAI_REGION")
            .unwrap_or_else(|_| VERTEXAI_DEFAULT_REGION.to_string());
        let host = config
            .get("VERTEXAI_API_HOST")
            .unwrap_or_else(|_| format!("https://{}-aiplatform.googleapis.com", region));

        // Generous timeout: large-context completions can take a long time.
        let client = Client::builder()
            .timeout(Duration::from_secs(600))
            .build()?;

        Ok(VertexAIProvider {
            client,
            host,
            project_id,
            region,
            model,
        })
    }

    /// POST `payload` to the model's `streamRawPredict` endpoint and return
    /// the parsed JSON response.
    ///
    /// Authenticates on every call via Google Application Default
    /// Credentials. Maps 401/403 to `ProviderError::Authentication` and any
    /// other non-200 status to `ProviderError::RequestFailed`.
    async fn post(&self, payload: Value) -> Result<Value, ProviderError> {
        let base_url = url::Url::parse(&self.host)
            .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?;
        let path = format!(
            "v1/projects/{}/locations/{}/publishers/{}/models/{}:streamRawPredict",
            self.project_id,
            self.region,
            self.get_model_provider(),
            self.model.model_name
        );
        let url = base_url.join(&path).map_err(|e| {
            ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}"))
        })?;

        // Resolve an OAuth2 access token through Application Default
        // Credentials (gcloud login, service account, workload identity, ...).
        let creds = create_access_token_credential().await.map_err(|e| {
            ProviderError::RequestFailed(format!("Failed to create access token credential: {}", e))
        })?;
        let token = creds.get_token().await.map_err(|e| {
            ProviderError::RequestFailed(format!("Failed to get access token: {}", e))
        })?;

        let response = self
            .client
            .post(url)
            .json(&payload)
            .header("Authorization", format!("Bearer {}", token.token))
            .send()
            .await
            .map_err(|e| ProviderError::RequestFailed(format!("Request failed: {}", e)))?;

        let status = response.status();
        // NOTE(review): error bodies are assumed to be JSON as well; a
        // non-JSON error page would surface as a parse failure instead of a
        // status-based error — confirm against the Vertex AI error format.
        let response_json = response.json::<Value>().await.map_err(|e| {
            ProviderError::RequestFailed(format!("Failed to parse response: {}", e))
        })?;

        match status {
            reqwest::StatusCode::OK => Ok(response_json),
            reqwest::StatusCode::UNAUTHORIZED | reqwest::StatusCode::FORBIDDEN => {
                // Pass format args to the macro directly; wrapping them in a
                // separate `format!` defeats tracing's own formatting.
                tracing::debug!(
                    "Provider request failed with status: {}. Payload: {:?}",
                    status,
                    payload
                );
                Err(ProviderError::Authentication(format!(
                    "Authentication failed: {:?}",
                    response_json
                )))
            }
            _ => {
                tracing::debug!("Request failed with status {}: {:?}", status, response_json);
                Err(ProviderError::RequestFailed(format!(
                    "Request failed with status {}: {:?}",
                    status, response_json
                )))
            }
        }
    }

    /// Publisher segment of the Vertex AI model path.
    fn get_model_provider(&self) -> String {
        // TODO: switch this by model_name once non-Anthropic models are
        // supported.
        "anthropic".to_string()
    }
}
impl Default for VertexAIProvider {
    /// Construct a provider with the metadata's default model.
    ///
    /// NOTE: panics (via `expect`) when required configuration such as
    /// `VERTEXAI_PROJECT_ID` is missing; callers needing fallible
    /// construction should use `from_env` directly.
    fn default() -> Self {
        let model = ModelConfig::new(Self::metadata().default_model);
        VertexAIProvider::from_env(model).expect("Failed to initialize VertexAI provider")
    }
}
#[async_trait]
impl Provider for VertexAIProvider {
    /// Static metadata: registry id, display name, default and known models,
    /// docs link, and required configuration keys.
    fn metadata() -> ProviderMetadata
    where
        Self: Sized,
    {
        ProviderMetadata::new(
            "vertex_ai",
            "Vertex AI",
            "Access variety of AI models such as Claude through Vertex AI",
            VERTEXAI_DEFAULT_MODEL,
            VERTEXAI_KNOWN_MODELS
                .iter()
                .map(|&s| s.to_string())
                .collect(),
            VERTEXAI_DOC_URL,
            vec![
                ConfigKey::new("VERTEXAI_PROJECT_ID", true, false, None),
                ConfigKey::new(
                    "VERTEXAI_REGION",
                    true,
                    false,
                    Some(VERTEXAI_DEFAULT_REGION),
                ),
            ],
        )
    }

    /// Send one completion request and return the model's reply plus token
    /// usage.
    #[tracing::instrument(
        skip(self, system, messages, tools),
        fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
    )]
    async fn complete(
        &self,
        system: &str,
        messages: &[Message],
        tools: &[Tool],
    ) -> Result<(Message, ProviderUsage), ProviderError> {
        let request = create_request(&self.model, system, messages, tools)?;
        // `post` consumes its payload, so keep a copy for the debug trace.
        let response = self.post(request.clone()).await?;
        let usage = get_usage(&response)?;
        emit_debug_trace(self, &request, &response, &usage);
        // All borrows of `response` are finished, so consume it here instead
        // of cloning the full JSON body.
        let message = response_to_message(response)?;
        let provider_usage = ProviderUsage::new(self.model.model_name.clone(), usage);
        Ok((message, provider_usage))
    }

    fn get_model_config(&self) -> ModelConfig {
        self.model.clone()
    }
}

View File

@@ -6,6 +6,7 @@ use goose::agents::AgentFactory;
use goose::message::Message;
use goose::model::ModelConfig;
use goose::providers::base::Provider;
use goose::providers::vertexai::VertexAIProvider;
use goose::providers::{anthropic::AnthropicProvider, databricks::DatabricksProvider};
use goose::providers::{
azure::AzureProvider, bedrock::BedrockProvider, ollama::OllamaProvider, openai::OpenAiProvider,
@@ -24,6 +25,7 @@ enum ProviderType {
Groq,
Ollama,
OpenRouter,
VertexAI,
}
impl ProviderType {
@@ -42,6 +44,7 @@ impl ProviderType {
ProviderType::Groq => &["GROQ_API_KEY"],
ProviderType::Ollama => &[],
ProviderType::OpenRouter => &["OPENROUTER_API_KEY"],
ProviderType::VertexAI => &["VERTEXAI_PROJECT_ID", "VERTEXAI_REGION"],
}
}
@@ -74,6 +77,7 @@ impl ProviderType {
ProviderType::Groq => Box::new(GroqProvider::from_env(model_config)?),
ProviderType::Ollama => Box::new(OllamaProvider::from_env(model_config)?),
ProviderType::OpenRouter => Box::new(OpenRouterProvider::from_env(model_config)?),
ProviderType::VertexAI => Box::new(VertexAIProvider::from_env(model_config)?),
})
}
}
@@ -290,4 +294,14 @@ mod tests {
})
.await
}
// End-to-end truncate-agent run against Vertex AI Claude; presumably gated on
// VERTEXAI_PROJECT_ID/VERTEXAI_REGION being set — confirm against
// `run_test_with_config`.
#[tokio::test]
async fn test_truncate_agent_with_vertexai() -> Result<()> {
    run_test_with_config(TestConfig {
        provider_type: ProviderType::VertexAI,
        model: "claude-3-5-sonnet-v2@20241022",
        context_window: 200_000,
    })
    .await
}
}

View File

@@ -8,6 +8,8 @@ export function isSecretKey(keyName: string): boolean {
'OLLAMA_HOST',
'AZURE_OPENAI_ENDPOINT',
'AZURE_OPENAI_DEPLOYMENT_NAME',
'VERTEXAI_PROJECT_ID',
'VERTEXAI_REGION',
];
return !nonSecretKeys.includes(keyName);
}

View File

@@ -19,6 +19,8 @@ export const goose_models: Model[] = [
{ id: 17, name: 'qwen2.5', provider: 'Ollama' },
{ id: 18, name: 'anthropic/claude-3.5-sonnet', provider: 'OpenRouter' },
{ id: 19, name: 'gpt-4o', provider: 'Azure OpenAI' },
{ id: 20, name: 'claude-3-5-sonnet-v2@20241022', provider: 'Vertex AI' },
{ id: 21, name: 'claude-3-5-sonnet@20240620', provider: 'Vertex AI' },
];
export const openai_models = ['gpt-4o-mini', 'gpt-4o', 'gpt-4-turbo', 'o1'];
@@ -47,6 +49,8 @@ export const openrouter_models = ['anthropic/claude-3.5-sonnet'];
export const azure_openai_models = ['gpt-4o'];
export const vertexai_models = ['claude-3-5-sonnet-v2@20241022', 'claude-3-5-sonnet@20240620'];
export const default_models = {
openai: 'gpt-4o',
anthropic: 'claude-3-5-sonnet-latest',
@@ -56,6 +60,7 @@ export const default_models = {
openrouter: 'anthropic/claude-3.5-sonnet',
ollama: 'qwen2.5',
azure_openai: 'gpt-4o',
vertex_ai: 'claude-3-5-sonnet-v2@20241022',
};
export function getDefaultModel(key: string): string | undefined {
@@ -73,6 +78,7 @@ export const required_keys = {
Google: ['GOOGLE_API_KEY'],
OpenRouter: ['OPENROUTER_API_KEY'],
'Azure OpenAI': ['AZURE_OPENAI_API_KEY', 'AZURE_OPENAI_ENDPOINT', 'AZURE_OPENAI_DEPLOYMENT_NAME'],
'Vertex AI': ['VERTEXAI_PROJECT_ID', 'VERTEXAI_REGION'],
};
export const supported_providers = [
@@ -84,6 +90,7 @@ export const supported_providers = [
'Ollama',
'OpenRouter',
'Azure OpenAI',
'Vertex AI',
];
export const model_docs_link = [
@@ -97,6 +104,7 @@ export const model_docs_link = [
},
{ name: 'OpenRouter', href: 'https://openrouter.ai/models' },
{ name: 'Ollama', href: 'https://ollama.com/library' },
{ name: 'Vertex AI', href: 'https://cloud.google.com/vertex-ai' },
];
export const provider_aliases = [
@@ -108,4 +116,5 @@ export const provider_aliases = [
{ provider: 'OpenRouter', alias: 'openrouter' },
{ provider: 'Google', alias: 'google' },
{ provider: 'Azure OpenAI', alias: 'azure_openai' },
{ provider: 'Vertex AI', alias: 'vertex_ai' },
];