feat: add GCP Vertex AI platform as provider (#1364)

Signed-off-by: Uddhav Kambli <uddhav@kambli.net>
Author: Uddhav Kambli
Date: 2025-03-03 11:46:11 -05:00 (committed by GitHub)
Parent: d1a365789e
Commit: 68b8c5d19d
13 changed files with 2269 additions and 24 deletions

Cargo.lock (generated)

@@ -230,9 +230,9 @@ dependencies = [
[[package]]
name = "async-trait"
-version = "0.1.85"
+version = "0.1.86"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3f934833b4b7233644e5848f235df3f57ed8c80f1528a26c3dfa13d2147fa056"
+checksum = "644dd749086bf3771a2fbc5f256fdb982d53f011c7d5d560304eafeecebce79d"
dependencies = [
"proc-macro2",
"quote",
@@ -1699,6 +1699,12 @@ version = "0.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "77c90badedccf4105eca100756a0b1289e191f6fcbdadd3cee1d2f614f97da8f"
[[package]]
name = "downcast"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1435fa1053d8b2fbbe9be7e97eca7f33d37b28409959813daefc1446a14247f1"
[[package]]
name = "dyn-clone"
version = "1.0.17"
@@ -1883,6 +1889,12 @@ dependencies = [
"percent-encoding",
]
[[package]]
name = "fragile"
version = "2.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c2141d6d6c8512188a7891b4b01590a45f6dac67afb4f255c4124dbb86d4eaa"
[[package]]
name = "futures"
version = "0.3.31"
@@ -2149,10 +2161,12 @@ dependencies = [
"futures",
"include_dir",
"indoc",
"jsonwebtoken",
"keyring",
"lazy_static",
"mcp-client",
"mcp-core",
"mockall",
"nanoid",
"once_cell",
"paste",
@@ -2571,7 +2585,7 @@ dependencies = [
"http 1.2.0",
"hyper 1.6.0",
"hyper-util",
-"rustls 0.23.21",
+"rustls 0.23.23",
"rustls-native-certs 0.8.1",
"rustls-pki-types",
"tokio",
@@ -3028,6 +3042,21 @@ dependencies = [
"serde",
]
[[package]]
name = "jsonwebtoken"
version = "9.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a87cc7a48537badeae96744432de36f4be2b4a34a05a5ef32e9dd8a1c169dde"
dependencies = [
"base64 0.22.1",
"js-sys",
"pem",
"ring",
"serde",
"serde_json",
"simple_asn1",
]
[[package]]
name = "keyring"
version = "3.6.1"
@@ -3367,6 +3396,32 @@ dependencies = [
"windows-sys 0.52.0",
]
[[package]]
name = "mockall"
version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39a6bfcc6c8c7eed5ee98b9c3e33adc726054389233e201c95dab2d41a3839d2"
dependencies = [
"cfg-if",
"downcast",
"fragile",
"mockall_derive",
"predicates",
"predicates-tree",
]
[[package]]
name = "mockall_derive"
version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "25ca3004c2efe9011bd4e461bd8256445052b9615405b4f7ea43fc8ca5c20898"
dependencies = [
"cfg-if",
"proc-macro2",
"quote",
"syn 2.0.96",
]
[[package]]
name = "monostate"
version = "0.1.13"
@@ -3739,6 +3794,16 @@ version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df94ce210e5bc13cb6651479fa48d14f601d9858cfe0467f43ae157023b938d3"
[[package]]
name = "pem"
version = "3.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38af38e8470ac9dee3ce1bae1af9c1671fffc44ddfd8bd1d0a3445bf349a8ef3"
dependencies = [
"base64 0.22.1",
"serde",
]
[[package]]
name = "percent-encoding"
version = "2.3.1"
@@ -3941,6 +4006,32 @@ dependencies = [
"zerocopy",
]
[[package]]
name = "predicates"
version = "3.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a5d19ee57562043d37e82899fade9a22ebab7be9cef5026b07fda9cdd4293573"
dependencies = [
"anstyle",
"predicates-core",
]
[[package]]
name = "predicates-core"
version = "1.0.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "727e462b119fe9c93fd0eb1429a5f7647394014cf3c04ab2c0350eeb09095ffa"
[[package]]
name = "predicates-tree"
version = "1.0.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72dd2d6d381dfb73a193c7fca536518d7caee39fc8503f74e7dc0be0531b425c"
dependencies = [
"predicates-core",
"termtree",
]
[[package]]
name = "prettyplease"
version = "0.2.29"
@@ -4063,7 +4154,7 @@ dependencies = [
"quinn-proto",
"quinn-udp",
"rustc-hash 2.1.0",
-"rustls 0.23.21",
+"rustls 0.23.23",
"socket2",
"thiserror 2.0.11",
"tokio",
@@ -4081,7 +4172,7 @@ dependencies = [
"rand",
"ring",
"rustc-hash 2.1.0",
-"rustls 0.23.21",
+"rustls 0.23.23",
"rustls-pki-types",
"slab",
"thiserror 2.0.11",
@@ -4393,7 +4484,7 @@ dependencies = [
"percent-encoding",
"pin-project-lite",
"quinn",
-"rustls 0.23.21",
+"rustls 0.23.23",
"rustls-pemfile 2.2.0",
"rustls-pki-types",
"serde",
@@ -4514,9 +4605,9 @@ dependencies = [
[[package]]
name = "rustls"
-version = "0.23.21"
+version = "0.23.23"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8f287924602bf649d949c63dc8ac8b235fa5387d394020705b80c4eb597ce5b8"
+checksum = "47796c98c480fce5406ef69d1c76378375492c3b0a0de587be0c1d9feb12f395"
dependencies = [
"once_cell",
"ring",
@@ -4977,6 +5068,18 @@ dependencies = [
"quote",
]
[[package]]
name = "simple_asn1"
version = "0.6.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "297f631f50729c8c99b84667867963997ec0b50f32b2a7dbcab828ef0541e8bb"
dependencies = [
"num-bigint",
"num-traits",
"thiserror 2.0.11",
"time",
]
[[package]]
name = "siphasher"
version = "1.0.1"
@@ -5268,6 +5371,12 @@ dependencies = [
"windows-sys 0.59.0",
]
[[package]]
name = "termtree"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8f50febec83f5ee1df3015341d8bd429f2d1cc62bcba7ea2076759d315084683"
[[package]]
name = "test-case"
version = "3.3.1"
@@ -5537,7 +5646,7 @@ version = "0.26.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5f6d0975eaace0cf0fcadee4e4aaa5da15b5c079146f2cffb67c113be122bf37"
dependencies = [
-"rustls 0.23.21",
+"rustls 0.23.23",
"tokio",
]
@@ -6767,7 +6876,7 @@ dependencies = [
"hyper-util",
"log",
"percent-encoding",
-"rustls 0.23.21",
+"rustls 0.23.23",
"rustls-pemfile 2.2.0",
"seahash",
"serde",


@@ -17,6 +17,12 @@
"models": ["goose"],
"required_keys": ["DATABRICKS_HOST"]
},
"gcp_vertex_ai": {
"name": "GCP Vertex AI",
"description": "Use Vertex AI platform models",
"models": ["claude-3-5-haiku@20241022", "claude-3-5-sonnet@20240620", "claude-3-5-sonnet-v2@20241022", "claude-3-7-sonnet@20250219", "gemini-1.5-pro-002", "gemini-2.0-flash-001", "gemini-2.0-pro-exp-02-05"],
"required_keys": ["GCP_PROJECT_ID", "GCP_LOCATION"]
},
"google": {
"name": "Google",
"description": "Lorem ipsum",


@@ -66,6 +66,9 @@ aws-config = { version = "1.1.7", features = ["behavior-version-latest"] }
aws-smithy-types = "1.2.12"
aws-sdk-bedrockruntime = "1.72.0"
# For GCP Vertex AI provider auth
jsonwebtoken = "9.3.1"
[target.'cfg(target_os = "windows")'.dependencies]
winapi = { version = "0.3", features = ["wincred"] }
@@ -73,6 +76,9 @@ winapi = { version = "0.3", features = ["wincred"] }
criterion = "0.5"
tempfile = "3.15.0"
serial_test = "3.2.0"
mockall = "0.13.1"
wiremock = "0.6.0"
tokio = { version = "1.0", features = ["full"] }
[[example]]
name = "agent"


@@ -4,6 +4,7 @@ use super::{
base::{Provider, ProviderMetadata},
bedrock::BedrockProvider,
databricks::DatabricksProvider,
gcpvertexai::GcpVertexAIProvider,
google::GoogleProvider,
groq::GroqProvider,
ollama::OllamaProvider,
@@ -19,6 +20,7 @@ pub fn providers() -> Vec<ProviderMetadata> {
AzureProvider::metadata(),
BedrockProvider::metadata(),
DatabricksProvider::metadata(),
GcpVertexAIProvider::metadata(),
GoogleProvider::metadata(),
GroqProvider::metadata(),
OllamaProvider::metadata(),
@@ -37,6 +39,7 @@ pub fn create(name: &str, model: ModelConfig) -> Result<Box<dyn Provider + Send
"groq" => Ok(Box::new(GroqProvider::from_env(model)?)),
"ollama" => Ok(Box::new(OllamaProvider::from_env(model)?)),
"openrouter" => Ok(Box::new(OpenRouterProvider::from_env(model)?)),
"gcp_vertex_ai" => Ok(Box::new(GcpVertexAIProvider::from_env(model)?)),
"google" => Ok(Box::new(GoogleProvider::from_env(model)?)),
_ => Err(anyhow::anyhow!("Unknown provider: {}", name)),
}
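With the factory entries above, the provider can be constructed by the same `gcp_vertex_ai` key used in the config metadata. A minimal sketch, assuming `create` is re-exported from `goose::providers` (the `factory` module itself is private) and that `GCP_PROJECT_ID`/`GCP_LOCATION` plus Google Cloud credentials are already available to goose's config:

```rust
use goose::model::ModelConfig;
use goose::providers::base::Provider;
use goose::providers::create; // assumed re-export of factory::create

fn main() -> anyhow::Result<()> {
    // GCP_PROJECT_ID and GCP_LOCATION must already be present in goose's config
    // or environment; building the provider also initializes GCP auth, so valid
    // Google Cloud credentials must be reachable.
    let model = ModelConfig::new("claude-3-5-sonnet-v2@20241022".to_string());
    let provider = create("gcp_vertex_ai", model)?;
    // The boxed trait object exposes the shared Provider API (complete, get_model_config, ...).
    let _config = provider.get_model_config();
    Ok(())
}
```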


@@ -0,0 +1,369 @@
use super::{anthropic, google};
use crate::message::Message;
use crate::model::ModelConfig;
use crate::providers::base::Usage;
use anyhow::{Context, Result};
use mcp_core::tool::Tool;
use serde_json::Value;
use std::fmt;
/// Sensible default values of Google Cloud Platform (GCP) locations for model deployment.
///
/// Each variant corresponds to a specific GCP region where models can be hosted.
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
pub enum GcpLocation {
/// Represents the us-central1 region in Iowa
Iowa,
/// Represents the us-east5 region in Ohio
Ohio,
}
impl fmt::Display for GcpLocation {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Iowa => write!(f, "us-central1"),
Self::Ohio => write!(f, "us-east5"),
}
}
}
impl TryFrom<&str> for GcpLocation {
type Error = ModelError;
fn try_from(s: &str) -> Result<Self, Self::Error> {
match s {
"us-central1" => Ok(Self::Iowa),
"us-east5" => Ok(Self::Ohio),
_ => Err(ModelError::UnsupportedLocation(s.to_string())),
}
}
}
/// Represents errors that can occur during model operations.
///
/// This enum encompasses various error conditions that might arise when working
/// with GCP Vertex AI models, including unsupported models, invalid requests,
/// and unsupported locations.
#[derive(Debug, thiserror::Error)]
pub enum ModelError {
/// Error when an unsupported Vertex AI model is specified
#[error("Unsupported Vertex AI model: {0}")]
UnsupportedModel(String),
/// Error when the request structure is invalid
#[error("Invalid request structure: {0}")]
InvalidRequest(String),
/// Error when an unsupported GCP location is specified
#[error("Unsupported GCP location: {0}")]
UnsupportedLocation(String),
}
/// Represents available GCP Vertex AI models for Goose.
///
/// This enum encompasses different model families and their versions
/// that are supported in the GCP Vertex AI platform.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GcpVertexAIModel {
/// Claude model family with specific versions
Claude(ClaudeVersion),
/// Gemini model family with specific versions
Gemini(GeminiVersion),
}
/// Represents available versions of the Claude model for Goose.
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
pub enum ClaudeVersion {
/// Claude 3.5 Sonnet initial version
Sonnet35,
/// Claude 3.5 Sonnet version 2
Sonnet35V2,
/// Claude 3.7 Sonnet
Sonnet37,
/// Claude 3.5 Haiku
Haiku35,
}
/// Represents available versions of the Gemini model for Goose.
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
pub enum GeminiVersion {
/// Gemini 1.5 Pro version
Pro15,
/// Gemini 2.0 Flash version
Flash20,
/// Gemini 2.0 Pro Experimental version
Pro20Exp,
}
impl fmt::Display for GcpVertexAIModel {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let model_id = match self {
Self::Claude(version) => match version {
ClaudeVersion::Sonnet35 => "claude-3-5-sonnet@20240620",
ClaudeVersion::Sonnet35V2 => "claude-3-5-sonnet-v2@20241022",
ClaudeVersion::Sonnet37 => "claude-3-7-sonnet@20250219",
ClaudeVersion::Haiku35 => "claude-3-5-haiku@20241022",
},
Self::Gemini(version) => match version {
GeminiVersion::Pro15 => "gemini-1.5-pro-002",
GeminiVersion::Flash20 => "gemini-2.0-flash-001",
GeminiVersion::Pro20Exp => "gemini-2.0-pro-exp-02-05",
},
};
write!(f, "{model_id}")
}
}
impl GcpVertexAIModel {
/// Returns the default GCP location for the model.
///
/// Each model family has a well-known location:
/// - Claude models default to Ohio (us-east5)
/// - Gemini models default to Iowa (us-central1)
pub fn known_location(&self) -> GcpLocation {
match self {
Self::Claude(_) => GcpLocation::Ohio,
Self::Gemini(_) => GcpLocation::Iowa,
}
}
}
impl TryFrom<&str> for GcpVertexAIModel {
type Error = ModelError;
fn try_from(s: &str) -> Result<Self, Self::Error> {
match s {
"claude-3-5-sonnet@20240620" => Ok(Self::Claude(ClaudeVersion::Sonnet35)),
"claude-3-5-sonnet-v2@20241022" => Ok(Self::Claude(ClaudeVersion::Sonnet35V2)),
"claude-3-7-sonnet@20250219" => Ok(Self::Claude(ClaudeVersion::Sonnet37)),
"claude-3-5-haiku@20241022" => Ok(Self::Claude(ClaudeVersion::Haiku35)),
"gemini-1.5-pro-002" => Ok(Self::Gemini(GeminiVersion::Pro15)),
"gemini-2.0-flash-001" => Ok(Self::Gemini(GeminiVersion::Flash20)),
"gemini-2.0-pro-exp-02-05" => Ok(Self::Gemini(GeminiVersion::Pro20Exp)),
_ => Err(ModelError::UnsupportedModel(s.to_string())),
}
}
}
/// Holds context information for a model request since the Vertex AI platform
/// supports multiple model families.
///
/// This structure maintains information about the model being used
/// and provides utility methods for handling model-specific operations.
#[derive(Debug, Clone)]
pub struct RequestContext {
/// The GCP Vertex AI model being used
pub model: GcpVertexAIModel,
}
impl RequestContext {
/// Creates a new RequestContext from a model ID string.
///
/// # Arguments
/// * `model_id` - The string identifier of the model
///
/// # Returns
/// * `Result<Self>` - A new RequestContext if the model ID is valid
pub fn new(model_id: &str) -> Result<Self> {
Ok(Self {
model: GcpVertexAIModel::try_from(model_id)
.with_context(|| format!("Failed to parse model ID: {model_id}"))?,
})
}
/// Returns the provider associated with the model.
pub fn provider(&self) -> ModelProvider {
match self.model {
GcpVertexAIModel::Claude(_) => ModelProvider::Anthropic,
GcpVertexAIModel::Gemini(_) => ModelProvider::Google,
}
}
}
/// Represents available model providers.
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
pub enum ModelProvider {
/// Anthropic provider (Claude models)
Anthropic,
/// Google provider (Gemini models)
Google,
}
impl ModelProvider {
/// Returns the string representation of the provider.
pub fn as_str(&self) -> &'static str {
match self {
Self::Anthropic => "anthropic",
Self::Google => "google",
}
}
}
/// Creates an Anthropic-specific Vertex AI request payload.
///
/// # Arguments
/// * `model_config` - Configuration for the model
/// * `system` - System prompt
/// * `messages` - Array of messages
/// * `tools` - Array of available tools
///
/// # Returns
/// * `Result<Value>` - JSON request payload for Anthropic API
fn create_anthropic_request(
model_config: &ModelConfig,
system: &str,
messages: &[Message],
tools: &[Tool],
) -> Result<Value> {
let mut request = anthropic::create_request(model_config, system, messages, tools)?;
let obj = request
.as_object_mut()
.ok_or_else(|| ModelError::InvalidRequest("Request is not a JSON object".to_string()))?;
// Note: We don't need to specify the model in the request body
// The model is determined by the endpoint URL in GCP Vertex AI
obj.remove("model");
obj.insert(
"anthropic_version".to_string(),
Value::String("vertex-2023-10-16".to_string()),
);
Ok(request)
}
/// Creates a Gemini-specific Vertex AI request payload.
///
/// # Arguments
/// * `model_config` - Configuration for the model
/// * `system` - System prompt
/// * `messages` - Array of messages
/// * `tools` - Array of available tools
///
/// # Returns
/// * `Result<Value>` - JSON request payload for Google API
fn create_google_request(
model_config: &ModelConfig,
system: &str,
messages: &[Message],
tools: &[Tool],
) -> Result<Value> {
google::create_request(model_config, system, messages, tools)
}
/// Creates a provider-specific request payload and context.
///
/// # Arguments
/// * `model_config` - Configuration for the model
/// * `system` - System prompt
/// * `messages` - Array of messages
/// * `tools` - Array of available tools
///
/// # Returns
/// * `Result<(Value, RequestContext)>` - Tuple of request payload and context
pub fn create_request(
model_config: &ModelConfig,
system: &str,
messages: &[Message],
tools: &[Tool],
) -> Result<(Value, RequestContext)> {
let context = RequestContext::new(&model_config.model_name)?;
let request = match &context.model {
GcpVertexAIModel::Claude(_) => {
create_anthropic_request(model_config, system, messages, tools)?
}
GcpVertexAIModel::Gemini(_) => {
create_google_request(model_config, system, messages, tools)?
}
};
Ok((request, context))
}
/// Converts a provider response to a Message.
///
/// # Arguments
/// * `response` - The raw response from the provider
/// * `request_context` - Context information about the request
///
/// # Returns
/// * `Result<Message>` - Converted message
pub fn response_to_message(response: Value, request_context: RequestContext) -> Result<Message> {
match request_context.provider() {
ModelProvider::Anthropic => anthropic::response_to_message(response),
ModelProvider::Google => google::response_to_message(response),
}
}
/// Extracts token usage information from the response data.
///
/// # Arguments
/// * `data` - The response data containing usage information
/// * `request_context` - Context information about the request
///
/// # Returns
/// * `Result<Usage>` - Usage statistics
pub fn get_usage(data: &Value, request_context: &RequestContext) -> Result<Usage> {
match request_context.provider() {
ModelProvider::Anthropic => anthropic::get_usage(data),
ModelProvider::Google => google::get_usage(data),
}
}
#[cfg(test)]
mod tests {
use super::*;
use anyhow::Result;
#[test]
fn test_model_parsing() -> Result<()> {
let valid_models = [
"claude-3-5-sonnet@20240620",
"claude-3-5-sonnet-v2@20241022",
"claude-3-7-sonnet@20250219",
"claude-3-5-haiku@20241022",
"gemini-1.5-pro-002",
"gemini-2.0-flash-001",
"gemini-2.0-pro-exp-02-05",
];
for model_id in valid_models {
let model = GcpVertexAIModel::try_from(model_id)?;
assert_eq!(model.to_string(), model_id);
}
assert!(GcpVertexAIModel::try_from("unsupported-model").is_err());
Ok(())
}
#[test]
fn test_default_locations() -> Result<()> {
let test_cases = [
("claude-3-5-sonnet@20240620", GcpLocation::Ohio),
("claude-3-5-sonnet-v2@20241022", GcpLocation::Ohio),
("claude-3-7-sonnet@20250219", GcpLocation::Ohio),
("claude-3-5-haiku@20241022", GcpLocation::Ohio),
("gemini-1.5-pro-002", GcpLocation::Iowa),
("gemini-2.0-flash-001", GcpLocation::Iowa),
("gemini-2.0-pro-exp-02-05", GcpLocation::Iowa),
];
for (model_id, expected_location) in test_cases {
let model = GcpVertexAIModel::try_from(model_id)?;
assert_eq!(
model.known_location(),
expected_location,
"Model {model_id} should have default location {expected_location:?}",
);
let context = RequestContext::new(model_id)?;
assert_eq!(
context.model.known_location(),
expected_location,
"RequestContext for {model_id} should have default location {expected_location:?}",
);
}
Ok(())
}
}
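Taken together, the types in this file map a plain model-ID string onto a model family, a default region, and a request format. A small illustrative sketch using only the items defined above:

```rust
use goose::providers::formats::gcpvertexai::{GcpVertexAIModel, RequestContext};

fn main() -> anyhow::Result<()> {
    // Claude models default to us-east5, Gemini models to us-central1.
    let model = GcpVertexAIModel::try_from("claude-3-7-sonnet@20250219")?;
    assert_eq!(model.to_string(), "claude-3-7-sonnet@20250219");
    assert_eq!(model.known_location().to_string(), "us-east5");

    // The request context decides whether the Anthropic or Google payload format is used.
    let context = RequestContext::new("gemini-2.0-flash-001")?;
    assert_eq!(context.provider().as_str(), "google");
    Ok(())
}
```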


@@ -1,4 +1,5 @@
pub mod anthropic;
pub mod bedrock;
pub mod gcpvertexai;
pub mod google;
pub mod openai;

File diff suppressed because it is too large
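The suppressed file is presumably the new `gcpauth` module (added as `mod gcpauth;` in `providers/mod.rs` below and imported by the provider as `GcpAuth`), which produces the tokens used for Bearer auth. Since its contents aren't shown, the following is only a rough sketch of the kind of flow the new `jsonwebtoken` dependency enables: signing a service-account JWT and exchanging it for an OAuth2 access token. The actual module may instead rely on Application Default Credentials; all names below are illustrative.

```rust
use anyhow::Result;
use jsonwebtoken::{encode, Algorithm, EncodingKey, Header};
use serde::{Deserialize, Serialize};

/// Claims for Google's OAuth2 JWT-bearer grant (illustrative).
#[derive(Serialize)]
struct Claims<'a> {
    iss: &'a str,   // service account email
    scope: &'a str, // requested scope
    aud: &'a str,   // token endpoint
    iat: u64,
    exp: u64,
}

#[derive(Deserialize)]
struct TokenResponse {
    access_token: String,
}

/// Sign a service-account assertion and exchange it for an access token.
async fn fetch_access_token(sa_email: &str, private_key_pem: &[u8]) -> Result<String> {
    let now = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)?
        .as_secs();
    let claims = Claims {
        iss: sa_email,
        scope: "https://www.googleapis.com/auth/cloud-platform",
        aud: "https://oauth2.googleapis.com/token",
        iat: now,
        exp: now + 3600,
    };
    // RS256-sign the assertion with the service account's private key.
    let assertion = encode(
        &Header::new(Algorithm::RS256),
        &claims,
        &EncodingKey::from_rsa_pem(private_key_pem)?,
    )?;
    // Exchange the signed assertion for a Bearer token.
    let resp: TokenResponse = reqwest::Client::new()
        .post("https://oauth2.googleapis.com/token")
        .form(&[
            ("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer"),
            ("assertion", assertion.as_str()),
        ])
        .send()
        .await?
        .json()
        .await?;
    Ok(resp.access_token)
}
```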


@@ -0,0 +1,595 @@
use std::time::Duration;
use anyhow::Result;
use async_trait::async_trait;
use reqwest::{Client, StatusCode};
use serde_json::Value;
use tokio::time::sleep;
use url::Url;
use crate::message::Message;
use crate::model::ModelConfig;
use crate::providers::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage};
use crate::providers::errors::ProviderError;
use crate::providers::formats::gcpvertexai::{
create_request, get_usage, response_to_message, ClaudeVersion, GcpVertexAIModel, GeminiVersion,
ModelProvider, RequestContext,
};
use crate::providers::formats::gcpvertexai::GcpLocation::Iowa;
use crate::providers::gcpauth::GcpAuth;
use crate::providers::utils::emit_debug_trace;
use mcp_core::tool::Tool;
/// Base URL for GCP Vertex AI documentation
const GCP_VERTEX_AI_DOC_URL: &str = "https://cloud.google.com/vertex-ai";
/// Default timeout for API requests in seconds
const DEFAULT_TIMEOUT_SECS: u64 = 600;
/// Default initial interval for retry (in milliseconds)
const DEFAULT_INITIAL_RETRY_INTERVAL_MS: u64 = 5000;
/// Default maximum number of retries
const DEFAULT_MAX_RETRIES: usize = 6;
/// Default retry backoff multiplier
const DEFAULT_BACKOFF_MULTIPLIER: f64 = 2.0;
/// Default maximum interval for retry (in milliseconds)
const DEFAULT_MAX_RETRY_INTERVAL_MS: u64 = 320_000;
/// Represents errors specific to GCP Vertex AI operations.
#[derive(Debug, thiserror::Error)]
enum GcpVertexAIError {
/// Error when URL construction fails
#[error("Invalid URL configuration: {0}")]
InvalidUrl(String),
/// Error during GCP authentication
#[error("Authentication error: {0}")]
AuthError(String),
}
/// Retry configuration for handling rate limit errors
#[derive(Debug, Clone)]
struct RetryConfig {
/// Maximum number of retry attempts
max_retries: usize,
/// Initial interval between retries in milliseconds
initial_interval_ms: u64,
/// Multiplier for backoff (exponential)
backoff_multiplier: f64,
/// Maximum interval between retries in milliseconds
max_interval_ms: u64,
}
impl Default for RetryConfig {
fn default() -> Self {
Self {
max_retries: DEFAULT_MAX_RETRIES,
initial_interval_ms: DEFAULT_INITIAL_RETRY_INTERVAL_MS,
backoff_multiplier: DEFAULT_BACKOFF_MULTIPLIER,
max_interval_ms: DEFAULT_MAX_RETRY_INTERVAL_MS,
}
}
}
impl RetryConfig {
/// Calculate the delay for a specific retry attempt (with jitter)
fn delay_for_attempt(&self, attempt: usize) -> Duration {
if attempt == 0 {
return Duration::from_millis(0);
}
// Calculate exponential backoff
let exponent = (attempt - 1) as u32;
let base_delay_ms = (self.initial_interval_ms as f64
* self.backoff_multiplier.powi(exponent as i32)) as u64;
// Apply max limit
let capped_delay_ms = std::cmp::min(base_delay_ms, self.max_interval_ms);
// Add jitter (+/-20% randomness) to avoid thundering herd problem
let jitter_factor = 0.8 + (rand::random::<f64>() * 0.4); // Between 0.8 and 1.2
let jittered_delay_ms = (capped_delay_ms as f64 * jitter_factor) as u64;
Duration::from_millis(jittered_delay_ms)
}
}
/// Provider implementation for Google Cloud Platform's Vertex AI service.
///
/// This provider enables interaction with various AI models hosted on GCP Vertex AI,
/// including Claude and Gemini model families. It handles authentication, request routing,
/// and response processing for the Vertex AI API endpoints.
#[derive(Debug, serde::Serialize)]
pub struct GcpVertexAIProvider {
/// HTTP client for making API requests
#[serde(skip)]
client: Client,
/// GCP authentication handler
#[serde(skip)]
auth: GcpAuth,
/// Base URL for the Vertex AI API
host: String,
/// GCP project identifier
project_id: String,
/// GCP region for model deployment
location: String,
/// Configuration for the specific model being used
model: ModelConfig,
/// Retry configuration for handling rate limit errors
#[serde(skip)]
retry_config: RetryConfig,
}
impl GcpVertexAIProvider {
/// Creates a new provider instance from environment configuration.
///
/// This is a convenience method that initializes the provider using
/// environment variables and default settings.
///
/// # Arguments
/// * `model` - Configuration for the model to be used
pub fn from_env(model: ModelConfig) -> Result<Self> {
Self::new(model)
}
/// Creates a new provider instance with the specified model configuration.
///
/// # Arguments
/// * `model` - Configuration for the model to be used
pub fn new(model: ModelConfig) -> Result<Self> {
futures::executor::block_on(Self::new_async(model))
}
/// Async implementation of new provider instance creation.
///
/// # Arguments
/// * `model` - Configuration for the model to be used
async fn new_async(model: ModelConfig) -> Result<Self> {
let config = crate::config::Config::global();
let project_id = config.get("GCP_PROJECT_ID")?;
let location = Self::determine_location(config)?;
let host = format!("https://{}-aiplatform.googleapis.com", location);
let client = Client::builder()
.timeout(Duration::from_secs(DEFAULT_TIMEOUT_SECS))
.build()?;
let auth = GcpAuth::new().await?;
// Load optional retry configuration from environment
let retry_config = Self::load_retry_config(config);
Ok(Self {
client,
auth,
host,
project_id,
location,
model,
retry_config,
})
}
/// Loads retry configuration from environment variables or uses defaults.
fn load_retry_config(config: &crate::config::Config) -> RetryConfig {
let max_retries = config
.get("GCP_MAX_RETRIES")
.ok()
.and_then(|v: String| v.parse::<usize>().ok())
.unwrap_or(DEFAULT_MAX_RETRIES);
let initial_interval_ms = config
.get("GCP_INITIAL_RETRY_INTERVAL_MS")
.ok()
.and_then(|v: String| v.parse::<u64>().ok())
.unwrap_or(DEFAULT_INITIAL_RETRY_INTERVAL_MS);
let backoff_multiplier = config
.get("GCP_BACKOFF_MULTIPLIER")
.ok()
.and_then(|v: String| v.parse::<f64>().ok())
.unwrap_or(DEFAULT_BACKOFF_MULTIPLIER);
let max_interval_ms = config
.get("GCP_MAX_RETRY_INTERVAL_MS")
.ok()
.and_then(|v: String| v.parse::<u64>().ok())
.unwrap_or(DEFAULT_MAX_RETRY_INTERVAL_MS);
RetryConfig {
max_retries,
initial_interval_ms,
backoff_multiplier,
max_interval_ms,
}
}
/// Determines the appropriate GCP location for model deployment.
///
/// Location is determined in the following order:
/// 1. Custom location from GCP_LOCATION environment variable
/// 2. Global default location (Iowa)
fn determine_location(config: &crate::config::Config) -> Result<String> {
Ok(config
.get("GCP_LOCATION")
.ok()
.filter(|location: &String| !location.trim().is_empty())
.unwrap_or_else(|| Iowa.to_string()))
}
/// Retrieves an authentication token for API requests.
async fn get_auth_header(&self) -> Result<String, GcpVertexAIError> {
self.auth
.get_token()
.await
.map(|token| format!("Bearer {}", token.token_value))
.map_err(|e| GcpVertexAIError::AuthError(e.to_string()))
}
/// Constructs the appropriate API endpoint URL for a given provider.
///
/// # Arguments
/// * `provider` - The model provider (Anthropic or Google)
/// * `location` - The GCP location for model deployment
fn build_request_url(
&self,
provider: ModelProvider,
location: &str,
) -> Result<Url, GcpVertexAIError> {
// Create host URL for the specified location
let host_url = if self.location == location {
self.host.clone()
} else {
// Only allocate a new string if location differs
self.host.replace(&self.location, location)
};
let base_url =
Url::parse(&host_url).map_err(|e| GcpVertexAIError::InvalidUrl(e.to_string()))?;
// Determine endpoint based on provider type
let endpoint = match provider {
ModelProvider::Anthropic => "streamRawPredict",
ModelProvider::Google => "generateContent",
};
// Construct path for URL
let path = format!(
"v1/projects/{}/locations/{}/publishers/{}/models/{}:{}",
self.project_id,
location,
provider.as_str(),
self.model.model_name,
endpoint
);
base_url
.join(&path)
.map_err(|e| GcpVertexAIError::InvalidUrl(e.to_string()))
}
/// Makes an authenticated POST request to the Vertex AI API at a specific location.
/// Includes retry logic for 429 Too Many Requests errors.
///
/// # Arguments
/// * `payload` - The request payload to send
/// * `context` - Request context containing model information
/// * `location` - The GCP location for the request
async fn post_with_location(
&self,
payload: &Value,
context: &RequestContext,
location: &str,
) -> Result<Value, ProviderError> {
let url = self
.build_request_url(context.provider(), location)
.map_err(|e| ProviderError::RequestFailed(e.to_string()))?;
// Initialize retry counter
let mut attempts = 0;
let mut last_error = None;
loop {
// Check if we've exceeded max retries
if attempts > 0 && attempts > self.retry_config.max_retries {
let error_msg = format!(
"Exceeded maximum retry attempts ({}) for rate limiting (429)",
self.retry_config.max_retries
);
tracing::error!("{}", error_msg);
return Err(last_error.unwrap_or(ProviderError::RateLimitExceeded(error_msg)));
}
// Get a fresh auth token for each attempt
let auth_header = self
.get_auth_header()
.await
.map_err(|e| ProviderError::Authentication(e.to_string()))?;
// Make the request
let response = self
.client
.post(url.clone())
.json(payload)
.header("Authorization", auth_header)
.send()
.await
.map_err(|e| ProviderError::RequestFailed(e.to_string()))?;
let status = response.status();
// If not a 429, process normally
if status != StatusCode::TOO_MANY_REQUESTS {
let response_json = response.json::<Value>().await.map_err(|e| {
ProviderError::RequestFailed(format!("Failed to parse response: {e}"))
})?;
return match status {
StatusCode::OK => Ok(response_json),
StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => {
tracing::debug!(
"Authentication failed. Status: {status}, Payload: {payload:?}"
);
Err(ProviderError::Authentication(format!(
"Authentication failed: {response_json:?}"
)))
}
_ => {
tracing::debug!(
"Request failed. Status: {status}, Response: {response_json:?}"
);
Err(ProviderError::RequestFailed(format!(
"Request failed with status {status}: {response_json:?}"
)))
}
};
}
// Handle 429 Too Many Requests
attempts += 1;
// Try to parse response for more detailed error info
let cite_gcp_vertex_429 =
"See https://cloud.google.com/vertex-ai/generative-ai/docs/error-code-429";
let response_text = response.text().await.unwrap_or_default();
let quota_error = if response_text.contains("Exceeded the Provisioned Throughput") {
format!("Exceeded the Provisioned Throughput: {cite_gcp_vertex_429}.")
} else {
format!("Pay-as-you-go resource exhausted: {cite_gcp_vertex_429}.")
};
tracing::warn!(
"Rate limit exceeded (attempt {}/{}): {}. Retrying after backoff...",
attempts,
self.retry_config.max_retries,
quota_error
);
// Store the error in case we need to return it after max retries
last_error = Some(ProviderError::RateLimitExceeded(quota_error));
// Calculate and apply the backoff delay
let delay = self.retry_config.delay_for_attempt(attempts);
tracing::info!("Backing off for {:?} before retry", delay);
sleep(delay).await;
}
}
/// Makes an authenticated POST request to the Vertex AI API with fallback for invalid locations.
///
/// # Arguments
/// * `payload` - The request payload to send
/// * `context` - Request context containing model information
async fn post(&self, payload: Value, context: &RequestContext) -> Result<Value, ProviderError> {
// Try with user-specified location first
let result = self
.post_with_location(&payload, context, &self.location)
.await;
// If location is already the known location for the model or request succeeded, return result
if self.location == context.model.known_location().to_string() || result.is_ok() {
return result;
}
// Check if we should retry with the model's known location
match &result {
Err(ProviderError::RequestFailed(msg)) => {
let model_name = context.model.to_string();
let configured_location = &self.location;
let known_location = context.model.known_location().to_string();
tracing::error!(
"Trying known location {known_location} for {model_name} instead of {configured_location}: {msg}"
);
self.post_with_location(&payload, context, &known_location)
.await
}
// For any other error, return the original result
_ => result,
}
}
}
impl Default for GcpVertexAIProvider {
fn default() -> Self {
let model = ModelConfig::new(Self::metadata().default_model);
Self::new(model).expect("Failed to initialize VertexAI provider")
}
}
#[async_trait]
impl Provider for GcpVertexAIProvider {
/// Returns metadata about the GCP Vertex AI provider.
fn metadata() -> ProviderMetadata
where
Self: Sized,
{
let known_models = vec![
GcpVertexAIModel::Claude(ClaudeVersion::Sonnet35),
GcpVertexAIModel::Claude(ClaudeVersion::Sonnet35V2),
GcpVertexAIModel::Claude(ClaudeVersion::Sonnet37),
GcpVertexAIModel::Claude(ClaudeVersion::Haiku35),
GcpVertexAIModel::Gemini(GeminiVersion::Pro15),
GcpVertexAIModel::Gemini(GeminiVersion::Flash20),
GcpVertexAIModel::Gemini(GeminiVersion::Pro20Exp),
]
.into_iter()
.map(|model| model.to_string())
.collect();
ProviderMetadata::new(
"gcp_vertex_ai",
"GCP Vertex AI",
"Access variety of AI models such as Claude, Gemini through Vertex AI",
GcpVertexAIModel::Gemini(GeminiVersion::Flash20)
.to_string()
.as_str(),
known_models,
GCP_VERTEX_AI_DOC_URL,
vec![
ConfigKey::new("GCP_PROJECT_ID", true, false, None),
ConfigKey::new("GCP_LOCATION", true, false, Some(Iowa.to_string().as_str())),
ConfigKey::new(
"GCP_MAX_RETRIES",
false,
false,
Some(&DEFAULT_MAX_RETRIES.to_string()),
),
ConfigKey::new(
"GCP_INITIAL_RETRY_INTERVAL_MS",
false,
false,
Some(&DEFAULT_INITIAL_RETRY_INTERVAL_MS.to_string()),
),
ConfigKey::new(
"GCP_BACKOFF_MULTIPLIER",
false,
false,
Some(&DEFAULT_BACKOFF_MULTIPLIER.to_string()),
),
ConfigKey::new(
"GCP_MAX_RETRY_INTERVAL_MS",
false,
false,
Some(&DEFAULT_MAX_RETRY_INTERVAL_MS.to_string()),
),
],
)
}
/// Completes a model interaction by sending a request and processing the response.
///
/// # Arguments
/// * `system` - System prompt or context
/// * `messages` - Array of previous messages in the conversation
/// * `tools` - Array of available tools for the model
#[tracing::instrument(
skip(self, system, messages, tools),
fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
)]
async fn complete(
&self,
system: &str,
messages: &[Message],
tools: &[Tool],
) -> Result<(Message, ProviderUsage), ProviderError> {
// Create request and context
let (request, context) = create_request(&self.model, system, messages, tools)?;
// Send request and process response
let response = self.post(request.clone(), &context).await?;
let usage = get_usage(&response, &context)?;
emit_debug_trace(self, &request, &response, &usage);
// Convert response to message
let message = response_to_message(response, context)?;
let provider_usage = ProviderUsage::new(self.model.model_name.clone(), usage);
Ok((message, provider_usage))
}
/// Returns the current model configuration.
fn get_model_config(&self) -> ModelConfig {
self.model.clone()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_retry_config_delay_calculation() {
let config = RetryConfig {
max_retries: 5,
initial_interval_ms: 1000,
backoff_multiplier: 2.0,
max_interval_ms: 32000,
};
// First attempt has no delay
let delay0 = config.delay_for_attempt(0);
assert_eq!(delay0.as_millis(), 0);
// First retry should be around initial_interval with jitter
let delay1 = config.delay_for_attempt(1);
assert!(delay1.as_millis() >= 800 && delay1.as_millis() <= 1200);
// Second retry should be around initial_interval * multiplier^1 with jitter
let delay2 = config.delay_for_attempt(2);
assert!(delay2.as_millis() >= 1600 && delay2.as_millis() <= 2400);
// Check that max interval is respected
let delay10 = config.delay_for_attempt(10);
assert!(delay10.as_millis() <= 38400); // max_interval_ms * 1.2 (max jitter)
}
#[test]
fn test_model_provider_conversion() {
assert_eq!(ModelProvider::Anthropic.as_str(), "anthropic");
assert_eq!(ModelProvider::Google.as_str(), "google");
}
#[test]
fn test_url_construction() {
use url::Url;
let model_config = ModelConfig::new("claude-3-5-sonnet-v2@20241022".to_string());
let context = RequestContext::new(&model_config.model_name).unwrap();
let api_model_id = context.model.to_string();
let host = "https://us-east5-aiplatform.googleapis.com";
let project_id = "test-project";
let location = "us-east5";
let path = format!(
"v1/projects/{}/locations/{}/publishers/{}/models/{}:{}",
project_id,
location,
ModelProvider::Anthropic.as_str(),
api_model_id,
"streamRawPredict"
);
let url = Url::parse(host).unwrap().join(&path).unwrap();
assert!(url.as_str().contains("publishers/anthropic"));
assert!(url.as_str().contains("projects/test-project"));
assert!(url.as_str().contains("locations/us-east5"));
}
#[test]
fn test_provider_metadata() {
let metadata = GcpVertexAIProvider::metadata();
assert!(metadata
.known_models
.contains(&"claude-3-5-sonnet-v2@20241022".to_string()));
assert!(metadata
.known_models
.contains(&"gemini-1.5-pro-002".to_string()));
// Should contain the original 2 config keys plus 4 new retry-related ones
assert_eq!(metadata.config_keys.len(), 6);
}
}
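End to end, the provider behaves like any other goose `Provider`. A minimal hedged sketch of driving it directly (real project configuration and Google Cloud credentials are required; real callers pass prior `Message`s and available `Tool`s instead of empty slices):

```rust
use goose::model::ModelConfig;
use goose::providers::base::Provider;
use goose::providers::gcpvertexai::GcpVertexAIProvider;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // Requires GCP_PROJECT_ID / GCP_LOCATION plus reachable Google Cloud credentials.
    let model = ModelConfig::new("gemini-2.0-flash-001".to_string());
    let provider = GcpVertexAIProvider::from_env(model)?;

    let (_reply, _usage) = provider
        .complete("You are a helpful assistant.", &[], &[])
        .await?;
    // `_reply` is the assistant Message; `_usage` carries the model name and token counts.
    Ok(())
}
```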


@@ -6,6 +6,8 @@ pub mod databricks;
pub mod errors;
mod factory;
pub mod formats;
mod gcpauth;
pub mod gcpvertexai;
pub mod google;
pub mod groq;
pub mod oauth;


@@ -6,12 +6,12 @@ use goose::agents::AgentFactory;
use goose::message::Message;
use goose::model::ModelConfig;
use goose::providers::base::Provider;
-use goose::providers::{anthropic::AnthropicProvider, databricks::DatabricksProvider};
use goose::providers::{
-    azure::AzureProvider, bedrock::BedrockProvider, ollama::OllamaProvider, openai::OpenAiProvider,
+    anthropic::AnthropicProvider, azure::AzureProvider, bedrock::BedrockProvider,
+    databricks::DatabricksProvider, gcpvertexai::GcpVertexAIProvider, google::GoogleProvider,
+    groq::GroqProvider, ollama::OllamaProvider, openai::OpenAiProvider,
    openrouter::OpenRouterProvider,
};
-use goose::providers::{google::GoogleProvider, groq::GroqProvider};
#[derive(Debug, PartialEq)]
enum ProviderType {
@@ -20,6 +20,7 @@ enum ProviderType {
Anthropic,
Bedrock,
Databricks,
GcpVertexAI,
Google,
Groq,
Ollama,
@@ -42,6 +43,7 @@ impl ProviderType {
ProviderType::Groq => &["GROQ_API_KEY"],
ProviderType::Ollama => &[],
ProviderType::OpenRouter => &["OPENROUTER_API_KEY"],
ProviderType::GcpVertexAI => &["GCP_PROJECT_ID", "GCP_LOCATION"],
}
}
@@ -70,6 +72,7 @@ impl ProviderType {
ProviderType::Anthropic => Box::new(AnthropicProvider::from_env(model_config)?),
ProviderType::Bedrock => Box::new(BedrockProvider::from_env(model_config)?),
ProviderType::Databricks => Box::new(DatabricksProvider::from_env(model_config)?),
ProviderType::GcpVertexAI => Box::new(GcpVertexAIProvider::from_env(model_config)?),
ProviderType::Google => Box::new(GoogleProvider::from_env(model_config)?),
ProviderType::Groq => Box::new(GroqProvider::from_env(model_config)?),
ProviderType::Ollama => Box::new(OllamaProvider::from_env(model_config)?),
@@ -290,4 +293,14 @@ mod tests {
})
.await
}
#[tokio::test]
async fn test_truncate_agent_with_gcpvertexai() -> Result<()> {
run_test_with_config(TestConfig {
provider_type: ProviderType::GcpVertexAI,
model: "claude-3-5-sonnet-v2@20241022",
context_window: 200_000,
})
.await
}
}


@@ -17,17 +17,18 @@ Goose relies heavily on tool calling capabilities and currently works best with
## Available Providers

| Provider | Description | Parameters |
|----------|-------------|------------|
| [Amazon Bedrock](https://aws.amazon.com/bedrock/) | Offers a variety of foundation models, including Claude, Jurassic-2, and others. **Environment variables must be set in advance, not configured through `goose configure`** | `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, `AWS_REGION` |
| [Anthropic](https://www.anthropic.com/) | Offers Claude, an advanced AI model for natural language tasks. | `ANTHROPIC_API_KEY` |
| [Azure OpenAI](https://learn.microsoft.com/en-us/azure/ai-services/openai/) | Access Azure-hosted OpenAI models, including GPT-4 and GPT-3.5. | `AZURE_OPENAI_API_KEY`, `AZURE_OPENAI_ENDPOINT`, `AZURE_OPENAI_DEPLOYMENT_NAME` |
| [Databricks](https://www.databricks.com/) | Unified data analytics and AI platform for building and deploying models. | `DATABRICKS_HOST`, `DATABRICKS_TOKEN` |
| [Gemini](https://ai.google.dev/gemini-api/docs) | Advanced LLMs by Google with multimodal capabilities (text, images). | `GOOGLE_API_KEY` |
| [GCP Vertex AI](https://cloud.google.com/vertex-ai) | Google Cloud's Vertex AI platform, supporting Gemini and Claude models. **Credentials must be configured in advance. Follow the instructions at https://cloud.google.com/vertex-ai/docs/authentication.** | `GCP_PROJECT_ID`, `GCP_LOCATION` and optional `GCP_MAX_RETRIES` (6), `GCP_INITIAL_RETRY_INTERVAL_MS` (5000), `GCP_BACKOFF_MULTIPLIER` (2.0), `GCP_MAX_RETRY_INTERVAL_MS` (320_000). |
| [Groq](https://groq.com/) | High-performance inference hardware and tools for LLMs. | `GROQ_API_KEY` |
| [Ollama](https://ollama.com/) | Local model runner supporting Qwen, Llama, DeepSeek, and other open-source models. **Because this provider runs locally, you must first [download and run a model](/docs/getting-started/providers#local-llms-ollama).** | `OLLAMA_HOST` |
| [OpenAI](https://platform.openai.com/api-keys) | Provides gpt-4o, o1, and other advanced language models. Also supports OpenAI-compatible endpoints (e.g., self-hosted LLaMA, vLLM, KServe). **o1-mini and o1-preview are not supported because Goose uses tool calling.** | `OPENAI_API_KEY`, `OPENAI_HOST` (optional), `OPENAI_ORGANIZATION` (optional), `OPENAI_PROJECT` (optional) |
| [OpenRouter](https://openrouter.ai/) | API gateway for unified access to various models with features like rate-limiting management. | `OPENROUTER_API_KEY` |


@@ -11,6 +11,8 @@ export function isSecretKey(keyName: string): boolean {
'OPENAI_BASE_PATH',
'AZURE_OPENAI_ENDPOINT',
'AZURE_OPENAI_DEPLOYMENT_NAME',
'GCP_PROJECT_ID',
'GCP_LOCATION',
];
return !nonSecretKeys.includes(keyName);
}


@@ -19,6 +19,13 @@ export const goose_models: Model[] = [
{ id: 17, name: 'qwen2.5', provider: 'Ollama' },
{ id: 18, name: 'anthropic/claude-3.5-sonnet', provider: 'OpenRouter' },
{ id: 19, name: 'gpt-4o', provider: 'Azure OpenAI' },
{ id: 20, name: 'claude-3-7-sonnet@20250219', provider: 'GCP Vertex AI' },
{ id: 21, name: 'claude-3-5-sonnet-v2@20241022', provider: 'GCP Vertex AI' },
{ id: 22, name: 'claude-3-5-sonnet@20240620', provider: 'GCP Vertex AI' },
{ id: 23, name: 'claude-3-5-haiku@20241022', provider: 'GCP Vertex AI' },
{ id: 24, name: 'gemini-2.0-pro-exp-02-05', provider: 'GCP Vertex AI' },
{ id: 25, name: 'gemini-2.0-flash-001', provider: 'GCP Vertex AI' },
{ id: 26, name: 'gemini-1.5-pro-002', provider: 'GCP Vertex AI' },
];
export const openai_models = ['gpt-4o-mini', 'gpt-4o', 'gpt-4-turbo', 'o1'];
@@ -47,6 +54,16 @@ export const openrouter_models = ['anthropic/claude-3.5-sonnet'];
export const azure_openai_models = ['gpt-4o'];
export const gcp_vertex_ai_models = [
'claude-3-7-sonnet@20250219',
'claude-3-5-sonnet-v2@20241022',
'claude-3-5-sonnet@20240620',
'claude-3-5-haiku@20241022',
'gemini-1.5-pro-002',
'gemini-2.0-flash-001',
'gemini-2.0-pro-exp-02-05',
];
export const default_models = {
openai: 'gpt-4o',
anthropic: 'claude-3-5-sonnet-latest',
@@ -56,6 +73,7 @@ export const default_models = {
openrouter: 'anthropic/claude-3.5-sonnet',
ollama: 'qwen2.5',
azure_openai: 'gpt-4o',
gcp_vertex_ai: 'gemini-2.0-flash-001',
};
export function getDefaultModel(key: string): string | undefined {
@@ -73,12 +91,14 @@ export const required_keys = {
Google: ['GOOGLE_API_KEY'],
OpenRouter: ['OPENROUTER_API_KEY'],
'Azure OpenAI': ['AZURE_OPENAI_API_KEY', 'AZURE_OPENAI_ENDPOINT', 'AZURE_OPENAI_DEPLOYMENT_NAME'],
'GCP Vertex AI': ['GCP_PROJECT_ID', 'GCP_LOCATION'],
};
export const default_key_value = {
OPENAI_HOST: 'https://api.openai.com',
OPENAI_BASE_PATH: 'v1/chat/completions',
OLLAMA_HOST: 'localhost',
GCP_LOCATION: 'us-central1',
};
export const supported_providers = [
@@ -90,6 +110,7 @@ export const supported_providers = [
'Ollama',
'OpenRouter',
'Azure OpenAI',
'GCP Vertex AI',
];
export const model_docs_link = [
@@ -103,6 +124,7 @@ export const model_docs_link = [
},
{ name: 'OpenRouter', href: 'https://openrouter.ai/models' },
{ name: 'Ollama', href: 'https://ollama.com/library' },
{ name: 'GCP Vertex AI', href: 'https://cloud.google.com/vertex-ai' },
];
export const provider_aliases = [
@@ -114,4 +136,5 @@ export const provider_aliases = [
{ provider: 'OpenRouter', alias: 'openrouter' },
{ provider: 'Google', alias: 'google' },
{ provider: 'Azure OpenAI', alias: 'azure_openai' },
{ provider: 'GCP Vertex AI', alias: 'gcp_vertex_ai' },
];