feat: add GCP Vertex AI platform as provider (#1364)
Signed-off-by: Uddhav Kambli <uddhav@kambli.net>
Cargo.lock (generated, 129 changed lines)
@@ -230,9 +230,9 @@ dependencies = [

 [[package]]
 name = "async-trait"
-version = "0.1.85"
+version = "0.1.86"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3f934833b4b7233644e5848f235df3f57ed8c80f1528a26c3dfa13d2147fa056"
+checksum = "644dd749086bf3771a2fbc5f256fdb982d53f011c7d5d560304eafeecebce79d"
 dependencies = [
  "proc-macro2",
  "quote",
@@ -1699,6 +1699,12 @@ version = "0.15.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "77c90badedccf4105eca100756a0b1289e191f6fcbdadd3cee1d2f614f97da8f"

+[[package]]
+name = "downcast"
+version = "0.11.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1435fa1053d8b2fbbe9be7e97eca7f33d37b28409959813daefc1446a14247f1"
+
 [[package]]
 name = "dyn-clone"
 version = "1.0.17"
@@ -1883,6 +1889,12 @@ dependencies = [
  "percent-encoding",
 ]

+[[package]]
+name = "fragile"
+version = "2.0.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6c2141d6d6c8512188a7891b4b01590a45f6dac67afb4f255c4124dbb86d4eaa"
+
 [[package]]
 name = "futures"
 version = "0.3.31"
@@ -2149,10 +2161,12 @@ dependencies = [
  "futures",
  "include_dir",
  "indoc",
+ "jsonwebtoken",
  "keyring",
  "lazy_static",
  "mcp-client",
  "mcp-core",
+ "mockall",
  "nanoid",
  "once_cell",
  "paste",
@@ -2571,7 +2585,7 @@ dependencies = [
  "http 1.2.0",
  "hyper 1.6.0",
  "hyper-util",
- "rustls 0.23.21",
+ "rustls 0.23.23",
  "rustls-native-certs 0.8.1",
  "rustls-pki-types",
  "tokio",
@@ -3028,6 +3042,21 @@ dependencies = [
  "serde",
 ]

+[[package]]
+name = "jsonwebtoken"
+version = "9.3.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5a87cc7a48537badeae96744432de36f4be2b4a34a05a5ef32e9dd8a1c169dde"
+dependencies = [
+ "base64 0.22.1",
+ "js-sys",
+ "pem",
+ "ring",
+ "serde",
+ "serde_json",
+ "simple_asn1",
+]
+
 [[package]]
 name = "keyring"
 version = "3.6.1"
@@ -3367,6 +3396,32 @@ dependencies = [
  "windows-sys 0.52.0",
 ]

+[[package]]
+name = "mockall"
+version = "0.13.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "39a6bfcc6c8c7eed5ee98b9c3e33adc726054389233e201c95dab2d41a3839d2"
+dependencies = [
+ "cfg-if",
+ "downcast",
+ "fragile",
+ "mockall_derive",
+ "predicates",
+ "predicates-tree",
+]
+
+[[package]]
+name = "mockall_derive"
+version = "0.13.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "25ca3004c2efe9011bd4e461bd8256445052b9615405b4f7ea43fc8ca5c20898"
+dependencies = [
+ "cfg-if",
+ "proc-macro2",
+ "quote",
+ "syn 2.0.96",
+]
+
 [[package]]
 name = "monostate"
 version = "0.1.13"
@@ -3739,6 +3794,16 @@ version = "0.2.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "df94ce210e5bc13cb6651479fa48d14f601d9858cfe0467f43ae157023b938d3"

+[[package]]
+name = "pem"
+version = "3.0.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "38af38e8470ac9dee3ce1bae1af9c1671fffc44ddfd8bd1d0a3445bf349a8ef3"
+dependencies = [
+ "base64 0.22.1",
+ "serde",
+]
+
 [[package]]
 name = "percent-encoding"
 version = "2.3.1"
@@ -3941,6 +4006,32 @@ dependencies = [
  "zerocopy",
 ]

+[[package]]
+name = "predicates"
+version = "3.1.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a5d19ee57562043d37e82899fade9a22ebab7be9cef5026b07fda9cdd4293573"
+dependencies = [
+ "anstyle",
+ "predicates-core",
+]
+
+[[package]]
+name = "predicates-core"
+version = "1.0.9"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "727e462b119fe9c93fd0eb1429a5f7647394014cf3c04ab2c0350eeb09095ffa"
+
+[[package]]
+name = "predicates-tree"
+version = "1.0.12"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "72dd2d6d381dfb73a193c7fca536518d7caee39fc8503f74e7dc0be0531b425c"
+dependencies = [
+ "predicates-core",
+ "termtree",
+]
+
 [[package]]
 name = "prettyplease"
 version = "0.2.29"
@@ -4063,7 +4154,7 @@ dependencies = [
  "quinn-proto",
  "quinn-udp",
  "rustc-hash 2.1.0",
- "rustls 0.23.21",
+ "rustls 0.23.23",
  "socket2",
  "thiserror 2.0.11",
  "tokio",
@@ -4081,7 +4172,7 @@ dependencies = [
  "rand",
  "ring",
  "rustc-hash 2.1.0",
- "rustls 0.23.21",
+ "rustls 0.23.23",
  "rustls-pki-types",
  "slab",
  "thiserror 2.0.11",
@@ -4393,7 +4484,7 @@ dependencies = [
  "percent-encoding",
  "pin-project-lite",
  "quinn",
- "rustls 0.23.21",
+ "rustls 0.23.23",
  "rustls-pemfile 2.2.0",
  "rustls-pki-types",
  "serde",
@@ -4514,9 +4605,9 @@ dependencies = [

 [[package]]
 name = "rustls"
-version = "0.23.21"
+version = "0.23.23"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8f287924602bf649d949c63dc8ac8b235fa5387d394020705b80c4eb597ce5b8"
+checksum = "47796c98c480fce5406ef69d1c76378375492c3b0a0de587be0c1d9feb12f395"
 dependencies = [
  "once_cell",
  "ring",
@@ -4977,6 +5068,18 @@ dependencies = [
  "quote",
 ]

+[[package]]
+name = "simple_asn1"
+version = "0.6.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "297f631f50729c8c99b84667867963997ec0b50f32b2a7dbcab828ef0541e8bb"
+dependencies = [
+ "num-bigint",
+ "num-traits",
+ "thiserror 2.0.11",
+ "time",
+]
+
 [[package]]
 name = "siphasher"
 version = "1.0.1"
@@ -5268,6 +5371,12 @@ dependencies = [
  "windows-sys 0.59.0",
 ]

+[[package]]
+name = "termtree"
+version = "0.5.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8f50febec83f5ee1df3015341d8bd429f2d1cc62bcba7ea2076759d315084683"
+
 [[package]]
 name = "test-case"
 version = "3.3.1"
@@ -5537,7 +5646,7 @@ version = "0.26.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "5f6d0975eaace0cf0fcadee4e4aaa5da15b5c079146f2cffb67c113be122bf37"
 dependencies = [
- "rustls 0.23.21",
+ "rustls 0.23.23",
  "tokio",
 ]
@@ -6767,7 +6876,7 @@ dependencies = [
 "hyper-util",
 "log",
 "percent-encoding",
-"rustls 0.23.21",
+"rustls 0.23.23",
 "rustls-pemfile 2.2.0",
 "seahash",
 "serde",
@@ -17,6 +17,12 @@
     "models": ["goose"],
     "required_keys": ["DATABRICKS_HOST"]
   },
+  "gcp_vertex_ai": {
+    "name": "GCP Vertex AI",
+    "description": "Use Vertex AI platform models",
+    "models": ["claude-3-5-haiku@20241022", "claude-3-5-sonnet@20240620", "claude-3-5-sonnet-v2@20241022", "claude-3-7-sonnet@20250219", "gemini-1.5-pro-002", "gemini-2.0-flash-001", "gemini-2.0-pro-exp-02-05"],
+    "required_keys": ["GCP_PROJECT_ID", "GCP_LOCATION"]
+  },
   "google": {
     "name": "Google",
     "description": "Lorem ipsum",
@@ -66,6 +66,9 @@ aws-config = { version = "1.1.7", features = ["behavior-version-latest"] }
 aws-smithy-types = "1.2.12"
 aws-sdk-bedrockruntime = "1.72.0"

+# For GCP Vertex AI provider auth
+jsonwebtoken = "9.3.1"
+
 [target.'cfg(target_os = "windows")'.dependencies]
 winapi = { version = "0.3", features = ["wincred"] }
@@ -73,6 +76,9 @@ winapi = { version = "0.3", features = ["wincred"] }
 criterion = "0.5"
 tempfile = "3.15.0"
 serial_test = "3.2.0"
+mockall = "0.13.1"
+wiremock = "0.6.0"
+tokio = { version = "1.0", features = ["full"] }

 [[example]]
 name = "agent"
@@ -4,6 +4,7 @@ use super::{
     base::{Provider, ProviderMetadata},
     bedrock::BedrockProvider,
     databricks::DatabricksProvider,
+    gcpvertexai::GcpVertexAIProvider,
     google::GoogleProvider,
     groq::GroqProvider,
     ollama::OllamaProvider,
@@ -19,6 +20,7 @@ pub fn providers() -> Vec<ProviderMetadata> {
         AzureProvider::metadata(),
         BedrockProvider::metadata(),
         DatabricksProvider::metadata(),
+        GcpVertexAIProvider::metadata(),
         GoogleProvider::metadata(),
         GroqProvider::metadata(),
         OllamaProvider::metadata(),
@@ -37,6 +39,7 @@ pub fn create(name: &str, model: ModelConfig) -> Result<Box<dyn Provider + Send
         "groq" => Ok(Box::new(GroqProvider::from_env(model)?)),
         "ollama" => Ok(Box::new(OllamaProvider::from_env(model)?)),
         "openrouter" => Ok(Box::new(OpenRouterProvider::from_env(model)?)),
+        "gcp_vertex_ai" => Ok(Box::new(GcpVertexAIProvider::from_env(model)?)),
         "google" => Ok(Box::new(GoogleProvider::from_env(model)?)),
         _ => Err(anyhow::anyhow!("Unknown provider: {}", name)),
     }
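For a sense of how the new registry entry is consumed, here is a minimal sketch. It assumes the goose crate re-exports the factory's create function (the module itself is private), and that GCP_PROJECT_ID, GCP_LOCATION, and GCP credentials are already configured, since provider construction authenticates eagerly:

use anyhow::Result;
use goose::model::ModelConfig;
use goose::providers::create;

fn main() -> Result<()> {
    // "gcp_vertex_ai" matches the new arm in the factory's create() above.
    let model = ModelConfig::new("gemini-2.0-flash-001".to_string());
    let provider = create("gcp_vertex_ai", model)?;

    // The boxed trait object exposes the same Provider API as every other backend.
    println!("using model: {}", provider.get_model_config().model_name);
    Ok(())
}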
crates/goose/src/providers/formats/gcpvertexai.rs (new file, 369 lines)
@@ -0,0 +1,369 @@
use super::{anthropic, google};
use crate::message::Message;
use crate::model::ModelConfig;
use crate::providers::base::Usage;
use anyhow::{Context, Result};
use mcp_core::tool::Tool;
use serde_json::Value;

use std::fmt;

/// Sensible default Google Cloud Platform (GCP) locations for model deployment.
///
/// Each variant corresponds to a specific GCP region where models can be hosted.
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
pub enum GcpLocation {
    /// Represents the us-central1 region in Iowa
    Iowa,
    /// Represents the us-east5 region in Ohio
    Ohio,
}

impl fmt::Display for GcpLocation {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Iowa => write!(f, "us-central1"),
            Self::Ohio => write!(f, "us-east5"),
        }
    }
}

impl TryFrom<&str> for GcpLocation {
    type Error = ModelError;

    fn try_from(s: &str) -> Result<Self, Self::Error> {
        match s {
            "us-central1" => Ok(Self::Iowa),
            "us-east5" => Ok(Self::Ohio),
            _ => Err(ModelError::UnsupportedLocation(s.to_string())),
        }
    }
}

/// Represents errors that can occur during model operations.
///
/// This enum encompasses the various error conditions that might arise when working
/// with GCP Vertex AI models, including unsupported models, invalid requests,
/// and unsupported locations.
#[derive(Debug, thiserror::Error)]
pub enum ModelError {
    /// Error when an unsupported Vertex AI model is specified
    #[error("Unsupported Vertex AI model: {0}")]
    UnsupportedModel(String),
    /// Error when the request structure is invalid
    #[error("Invalid request structure: {0}")]
    InvalidRequest(String),
    /// Error when an unsupported GCP location is specified
    #[error("Unsupported GCP location: {0}")]
    UnsupportedLocation(String),
}

/// Represents the GCP Vertex AI models available to Goose.
///
/// This enum encompasses the different model families and versions
/// that are supported on the GCP Vertex AI platform.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GcpVertexAIModel {
    /// Claude model family with specific versions
    Claude(ClaudeVersion),
    /// Gemini model family with specific versions
    Gemini(GeminiVersion),
}

/// Represents the available versions of the Claude models for Goose.
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
pub enum ClaudeVersion {
    /// Claude 3.5 Sonnet initial version
    Sonnet35,
    /// Claude 3.5 Sonnet version 2
    Sonnet35V2,
    /// Claude 3.7 Sonnet
    Sonnet37,
    /// Claude 3.5 Haiku
    Haiku35,
}

/// Represents the available versions of the Gemini models for Goose.
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
pub enum GeminiVersion {
    /// Gemini 1.5 Pro version
    Pro15,
    /// Gemini 2.0 Flash version
    Flash20,
    /// Gemini 2.0 Pro Experimental version
    Pro20Exp,
}

impl fmt::Display for GcpVertexAIModel {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let model_id = match self {
            Self::Claude(version) => match version {
                ClaudeVersion::Sonnet35 => "claude-3-5-sonnet@20240620",
                ClaudeVersion::Sonnet35V2 => "claude-3-5-sonnet-v2@20241022",
                ClaudeVersion::Sonnet37 => "claude-3-7-sonnet@20250219",
                ClaudeVersion::Haiku35 => "claude-3-5-haiku@20241022",
            },
            Self::Gemini(version) => match version {
                GeminiVersion::Pro15 => "gemini-1.5-pro-002",
                GeminiVersion::Flash20 => "gemini-2.0-flash-001",
                GeminiVersion::Pro20Exp => "gemini-2.0-pro-exp-02-05",
            },
        };
        write!(f, "{model_id}")
    }
}

impl GcpVertexAIModel {
    /// Returns the default GCP location for the model.
    ///
    /// Each model family has a well-known location:
    /// - Claude models default to Ohio (us-east5)
    /// - Gemini models default to Iowa (us-central1)
    pub fn known_location(&self) -> GcpLocation {
        match self {
            Self::Claude(_) => GcpLocation::Ohio,
            Self::Gemini(_) => GcpLocation::Iowa,
        }
    }
}

impl TryFrom<&str> for GcpVertexAIModel {
    type Error = ModelError;

    fn try_from(s: &str) -> Result<Self, Self::Error> {
        match s {
            "claude-3-5-sonnet@20240620" => Ok(Self::Claude(ClaudeVersion::Sonnet35)),
            "claude-3-5-sonnet-v2@20241022" => Ok(Self::Claude(ClaudeVersion::Sonnet35V2)),
            "claude-3-7-sonnet@20250219" => Ok(Self::Claude(ClaudeVersion::Sonnet37)),
            "claude-3-5-haiku@20241022" => Ok(Self::Claude(ClaudeVersion::Haiku35)),
            "gemini-1.5-pro-002" => Ok(Self::Gemini(GeminiVersion::Pro15)),
            "gemini-2.0-flash-001" => Ok(Self::Gemini(GeminiVersion::Flash20)),
            "gemini-2.0-pro-exp-02-05" => Ok(Self::Gemini(GeminiVersion::Pro20Exp)),
            _ => Err(ModelError::UnsupportedModel(s.to_string())),
        }
    }
}

/// Holds context information for a model request, since the Vertex AI platform
/// supports multiple model families.
///
/// This structure maintains information about the model being used
/// and provides utility methods for handling model-specific operations.
#[derive(Debug, Clone)]
pub struct RequestContext {
    /// The GCP Vertex AI model being used
    pub model: GcpVertexAIModel,
}

impl RequestContext {
    /// Creates a new RequestContext from a model ID string.
    ///
    /// # Arguments
    /// * `model_id` - The string identifier of the model
    ///
    /// # Returns
    /// * `Result<Self>` - A new RequestContext if the model ID is valid
    pub fn new(model_id: &str) -> Result<Self> {
        Ok(Self {
            model: GcpVertexAIModel::try_from(model_id)
                .with_context(|| format!("Failed to parse model ID: {model_id}"))?,
        })
    }

    /// Returns the provider associated with the model.
    pub fn provider(&self) -> ModelProvider {
        match self.model {
            GcpVertexAIModel::Claude(_) => ModelProvider::Anthropic,
            GcpVertexAIModel::Gemini(_) => ModelProvider::Google,
        }
    }
}

/// Represents the available model providers.
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
pub enum ModelProvider {
    /// Anthropic provider (Claude models)
    Anthropic,
    /// Google provider (Gemini models)
    Google,
}

impl ModelProvider {
    /// Returns the string representation of the provider.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Anthropic => "anthropic",
            Self::Google => "google",
        }
    }
}

/// Creates an Anthropic-specific Vertex AI request payload.
///
/// # Arguments
/// * `model_config` - Configuration for the model
/// * `system` - System prompt
/// * `messages` - Array of messages
/// * `tools` - Array of available tools
///
/// # Returns
/// * `Result<Value>` - JSON request payload for the Anthropic API
fn create_anthropic_request(
    model_config: &ModelConfig,
    system: &str,
    messages: &[Message],
    tools: &[Tool],
) -> Result<Value> {
    let mut request = anthropic::create_request(model_config, system, messages, tools)?;

    let obj = request
        .as_object_mut()
        .ok_or_else(|| ModelError::InvalidRequest("Request is not a JSON object".to_string()))?;

    // Note: We don't need to specify the model in the request body;
    // the model is determined by the endpoint URL in GCP Vertex AI.
    obj.remove("model");
    obj.insert(
        "anthropic_version".to_string(),
        Value::String("vertex-2023-10-16".to_string()),
    );

    Ok(request)
}

/// Creates a Gemini-specific Vertex AI request payload.
///
/// # Arguments
/// * `model_config` - Configuration for the model
/// * `system` - System prompt
/// * `messages` - Array of messages
/// * `tools` - Array of available tools
///
/// # Returns
/// * `Result<Value>` - JSON request payload for the Google API
fn create_google_request(
    model_config: &ModelConfig,
    system: &str,
    messages: &[Message],
    tools: &[Tool],
) -> Result<Value> {
    google::create_request(model_config, system, messages, tools)
}

/// Creates a provider-specific request payload and context.
///
/// # Arguments
/// * `model_config` - Configuration for the model
/// * `system` - System prompt
/// * `messages` - Array of messages
/// * `tools` - Array of available tools
///
/// # Returns
/// * `Result<(Value, RequestContext)>` - Tuple of request payload and context
pub fn create_request(
    model_config: &ModelConfig,
    system: &str,
    messages: &[Message],
    tools: &[Tool],
) -> Result<(Value, RequestContext)> {
    let context = RequestContext::new(&model_config.model_name)?;

    let request = match &context.model {
        GcpVertexAIModel::Claude(_) => {
            create_anthropic_request(model_config, system, messages, tools)?
        }
        GcpVertexAIModel::Gemini(_) => {
            create_google_request(model_config, system, messages, tools)?
        }
    };

    Ok((request, context))
}

/// Converts a provider response to a Message.
///
/// # Arguments
/// * `response` - The raw response from the provider
/// * `request_context` - Context information about the request
///
/// # Returns
/// * `Result<Message>` - Converted message
pub fn response_to_message(response: Value, request_context: RequestContext) -> Result<Message> {
    match request_context.provider() {
        ModelProvider::Anthropic => anthropic::response_to_message(response),
        ModelProvider::Google => google::response_to_message(response),
    }
}

/// Extracts token usage information from the response data.
///
/// # Arguments
/// * `data` - The response data containing usage information
/// * `request_context` - Context information about the request
///
/// # Returns
/// * `Result<Usage>` - Usage statistics
pub fn get_usage(data: &Value, request_context: &RequestContext) -> Result<Usage> {
    match request_context.provider() {
        ModelProvider::Anthropic => anthropic::get_usage(data),
        ModelProvider::Google => google::get_usage(data),
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result;

    #[test]
    fn test_model_parsing() -> Result<()> {
        let valid_models = [
            "claude-3-5-sonnet@20240620",
            "claude-3-5-sonnet-v2@20241022",
            "claude-3-7-sonnet@20250219",
            "claude-3-5-haiku@20241022",
            "gemini-1.5-pro-002",
            "gemini-2.0-flash-001",
            "gemini-2.0-pro-exp-02-05",
        ];

        for model_id in valid_models {
            let model = GcpVertexAIModel::try_from(model_id)?;
            assert_eq!(model.to_string(), model_id);
        }

        assert!(GcpVertexAIModel::try_from("unsupported-model").is_err());
        Ok(())
    }

    #[test]
    fn test_default_locations() -> Result<()> {
        let test_cases = [
            ("claude-3-5-sonnet@20240620", GcpLocation::Ohio),
            ("claude-3-5-sonnet-v2@20241022", GcpLocation::Ohio),
            ("claude-3-7-sonnet@20250219", GcpLocation::Ohio),
            ("claude-3-5-haiku@20241022", GcpLocation::Ohio),
            ("gemini-1.5-pro-002", GcpLocation::Iowa),
            ("gemini-2.0-flash-001", GcpLocation::Iowa),
            ("gemini-2.0-pro-exp-02-05", GcpLocation::Iowa),
        ];

        for (model_id, expected_location) in test_cases {
            let model = GcpVertexAIModel::try_from(model_id)?;
            assert_eq!(
                model.known_location(),
                expected_location,
                "Model {model_id} should have default location {expected_location:?}",
            );

            let context = RequestContext::new(model_id)?;
            assert_eq!(
                context.model.known_location(),
                expected_location,
                "RequestContext for {model_id} should have default location {expected_location:?}",
            );
        }

        Ok(())
    }
}
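As a quick illustration of the dispatch above, here is a sketch of what create_request produces for a Claude model (assuming anthropic::create_request accepts empty message and tool slices):

use goose::model::ModelConfig;
use goose::providers::formats::gcpvertexai::{create_request, ModelProvider};

fn main() -> anyhow::Result<()> {
    let config = ModelConfig::new("claude-3-5-sonnet-v2@20241022".to_string());
    let (payload, context) = create_request(&config, "You are a helpful assistant.", &[], &[])?;

    // Claude models dispatch to the Anthropic publisher...
    assert_eq!(context.provider(), ModelProvider::Anthropic);
    // ...and the payload carries the Vertex API version instead of a "model"
    // field, since the model is encoded in the endpoint URL.
    assert_eq!(payload["anthropic_version"], "vertex-2023-10-16");
    assert!(payload.get("model").is_none());
    Ok(())
}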
@@ -1,4 +1,5 @@
 pub mod anthropic;
 pub mod bedrock;
+pub mod gcpvertexai;
 pub mod google;
 pub mod openai;
crates/goose/src/providers/gcpauth.rs (new file, 1,115 lines): diff suppressed because it is too large.
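For orientation only (this is a sketch of the standard flow, not the suppressed file's actual contents): the jsonwebtoken dependency added above is typically used to sign an RS256 service-account assertion, which is then exchanged at https://oauth2.googleapis.com/token for an access token. A self-contained sketch of the signing half, with hypothetical names:

use jsonwebtoken::{encode, Algorithm, EncodingKey, Header};
use serde::Serialize;
use std::time::{SystemTime, UNIX_EPOCH};

/// Standard claims for Google's OAuth 2.0 JWT-bearer grant.
#[derive(Serialize)]
struct Claims<'a> {
    iss: &'a str,   // service account email
    scope: &'a str, // e.g. "https://www.googleapis.com/auth/cloud-platform"
    aud: &'a str,   // always "https://oauth2.googleapis.com/token"
    iat: u64,
    exp: u64,
}

// Hypothetical helper; the real gcpauth.rs may structure this differently.
fn sign_assertion(sa_email: &str, private_key_pem: &[u8]) -> anyhow::Result<String> {
    let iat = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
    let claims = Claims {
        iss: sa_email,
        scope: "https://www.googleapis.com/auth/cloud-platform",
        aud: "https://oauth2.googleapis.com/token",
        iat,
        exp: iat + 3600, // Google caps assertion lifetime at one hour
    };
    // Service-account keys are RSA, and RS256 is what Google expects.
    let key = EncodingKey::from_rsa_pem(private_key_pem)?;
    Ok(encode(&Header::new(Algorithm::RS256), &claims, &key)?)
}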
crates/goose/src/providers/gcpvertexai.rs (new file, 595 lines)
@@ -0,0 +1,595 @@
use std::time::Duration;

use anyhow::Result;
use async_trait::async_trait;
use reqwest::{Client, StatusCode};
use serde_json::Value;
use tokio::time::sleep;
use url::Url;

use crate::message::Message;
use crate::model::ModelConfig;
use crate::providers::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage};

use crate::providers::errors::ProviderError;
use crate::providers::formats::gcpvertexai::{
    create_request, get_usage, response_to_message, ClaudeVersion, GcpVertexAIModel, GeminiVersion,
    ModelProvider, RequestContext,
};

use crate::providers::formats::gcpvertexai::GcpLocation::Iowa;
use crate::providers::gcpauth::GcpAuth;
use crate::providers::utils::emit_debug_trace;
use mcp_core::tool::Tool;

/// Base URL for GCP Vertex AI documentation
const GCP_VERTEX_AI_DOC_URL: &str = "https://cloud.google.com/vertex-ai";
/// Default timeout for API requests in seconds
const DEFAULT_TIMEOUT_SECS: u64 = 600;
/// Default initial interval for retry (in milliseconds)
const DEFAULT_INITIAL_RETRY_INTERVAL_MS: u64 = 5000;
/// Default maximum number of retries
const DEFAULT_MAX_RETRIES: usize = 6;
/// Default retry backoff multiplier
const DEFAULT_BACKOFF_MULTIPLIER: f64 = 2.0;
/// Default maximum interval for retry (in milliseconds)
const DEFAULT_MAX_RETRY_INTERVAL_MS: u64 = 320_000;

/// Represents errors specific to GCP Vertex AI operations.
#[derive(Debug, thiserror::Error)]
enum GcpVertexAIError {
    /// Error when URL construction fails
    #[error("Invalid URL configuration: {0}")]
    InvalidUrl(String),

    /// Error during GCP authentication
    #[error("Authentication error: {0}")]
    AuthError(String),
}

/// Retry configuration for handling rate limit errors
#[derive(Debug, Clone)]
struct RetryConfig {
    /// Maximum number of retry attempts
    max_retries: usize,
    /// Initial interval between retries in milliseconds
    initial_interval_ms: u64,
    /// Multiplier for backoff (exponential)
    backoff_multiplier: f64,
    /// Maximum interval between retries in milliseconds
    max_interval_ms: u64,
}

impl Default for RetryConfig {
    fn default() -> Self {
        Self {
            max_retries: DEFAULT_MAX_RETRIES,
            initial_interval_ms: DEFAULT_INITIAL_RETRY_INTERVAL_MS,
            backoff_multiplier: DEFAULT_BACKOFF_MULTIPLIER,
            max_interval_ms: DEFAULT_MAX_RETRY_INTERVAL_MS,
        }
    }
}

impl RetryConfig {
    /// Calculates the delay for a specific retry attempt (with jitter)
    fn delay_for_attempt(&self, attempt: usize) -> Duration {
        if attempt == 0 {
            return Duration::from_millis(0);
        }

        // Calculate exponential backoff
        let exponent = (attempt - 1) as u32;
        let base_delay_ms = (self.initial_interval_ms as f64
            * self.backoff_multiplier.powi(exponent as i32)) as u64;

        // Apply the max limit
        let capped_delay_ms = std::cmp::min(base_delay_ms, self.max_interval_ms);

        // Add jitter (+/-20% randomness) to avoid the thundering herd problem
        let jitter_factor = 0.8 + (rand::random::<f64>() * 0.4); // Between 0.8 and 1.2
        let jittered_delay_ms = (capped_delay_ms as f64 * jitter_factor) as u64;

        Duration::from_millis(jittered_delay_ms)
    }
}

/// Provider implementation for Google Cloud Platform's Vertex AI service.
///
/// This provider enables interaction with various AI models hosted on GCP Vertex AI,
/// including the Claude and Gemini model families. It handles authentication, request routing,
/// and response processing for the Vertex AI API endpoints.
#[derive(Debug, serde::Serialize)]
pub struct GcpVertexAIProvider {
    /// HTTP client for making API requests
    #[serde(skip)]
    client: Client,
    /// GCP authentication handler
    #[serde(skip)]
    auth: GcpAuth,
    /// Base URL for the Vertex AI API
    host: String,
    /// GCP project identifier
    project_id: String,
    /// GCP region for model deployment
    location: String,
    /// Configuration for the specific model being used
    model: ModelConfig,
    /// Retry configuration for handling rate limit errors
    #[serde(skip)]
    retry_config: RetryConfig,
}

impl GcpVertexAIProvider {
    /// Creates a new provider instance from environment configuration.
    ///
    /// This is a convenience method that initializes the provider using
    /// environment variables and default settings.
    ///
    /// # Arguments
    /// * `model` - Configuration for the model to be used
    pub fn from_env(model: ModelConfig) -> Result<Self> {
        Self::new(model)
    }

    /// Creates a new provider instance with the specified model configuration.
    ///
    /// # Arguments
    /// * `model` - Configuration for the model to be used
    pub fn new(model: ModelConfig) -> Result<Self> {
        futures::executor::block_on(Self::new_async(model))
    }

    /// Async implementation of new provider instance creation.
    ///
    /// # Arguments
    /// * `model` - Configuration for the model to be used
    async fn new_async(model: ModelConfig) -> Result<Self> {
        let config = crate::config::Config::global();
        let project_id = config.get("GCP_PROJECT_ID")?;
        let location = Self::determine_location(config)?;
        let host = format!("https://{}-aiplatform.googleapis.com", location);

        let client = Client::builder()
            .timeout(Duration::from_secs(DEFAULT_TIMEOUT_SECS))
            .build()?;

        let auth = GcpAuth::new().await?;

        // Load the optional retry configuration from the environment
        let retry_config = Self::load_retry_config(config);

        Ok(Self {
            client,
            auth,
            host,
            project_id,
            location,
            model,
            retry_config,
        })
    }

    /// Loads the retry configuration from environment variables, or uses defaults.
    fn load_retry_config(config: &crate::config::Config) -> RetryConfig {
        let max_retries = config
            .get("GCP_MAX_RETRIES")
            .ok()
            .and_then(|v: String| v.parse::<usize>().ok())
            .unwrap_or(DEFAULT_MAX_RETRIES);

        let initial_interval_ms = config
            .get("GCP_INITIAL_RETRY_INTERVAL_MS")
            .ok()
            .and_then(|v: String| v.parse::<u64>().ok())
            .unwrap_or(DEFAULT_INITIAL_RETRY_INTERVAL_MS);

        let backoff_multiplier = config
            .get("GCP_BACKOFF_MULTIPLIER")
            .ok()
            .and_then(|v: String| v.parse::<f64>().ok())
            .unwrap_or(DEFAULT_BACKOFF_MULTIPLIER);

        let max_interval_ms = config
            .get("GCP_MAX_RETRY_INTERVAL_MS")
            .ok()
            .and_then(|v: String| v.parse::<u64>().ok())
            .unwrap_or(DEFAULT_MAX_RETRY_INTERVAL_MS);

        RetryConfig {
            max_retries,
            initial_interval_ms,
            backoff_multiplier,
            max_interval_ms,
        }
    }

    /// Determines the appropriate GCP location for model deployment.
    ///
    /// The location is determined in the following order:
    /// 1. Custom location from the GCP_LOCATION environment variable
    /// 2. Global default location (Iowa)
    fn determine_location(config: &crate::config::Config) -> Result<String> {
        Ok(config
            .get("GCP_LOCATION")
            .ok()
            .filter(|location: &String| !location.trim().is_empty())
            .unwrap_or_else(|| Iowa.to_string()))
    }

    /// Retrieves an authentication token for API requests.
    async fn get_auth_header(&self) -> Result<String, GcpVertexAIError> {
        self.auth
            .get_token()
            .await
            .map(|token| format!("Bearer {}", token.token_value))
            .map_err(|e| GcpVertexAIError::AuthError(e.to_string()))
    }

    /// Constructs the appropriate API endpoint URL for a given provider.
    ///
    /// # Arguments
    /// * `provider` - The model provider (Anthropic or Google)
    /// * `location` - The GCP location for model deployment
    fn build_request_url(
        &self,
        provider: ModelProvider,
        location: &str,
    ) -> Result<Url, GcpVertexAIError> {
        // Create the host URL for the specified location
        let host_url = if self.location == location {
            self.host.clone()
        } else {
            // Only allocate a new string if the location differs
            self.host.replace(&self.location, location)
        };

        let base_url =
            Url::parse(&host_url).map_err(|e| GcpVertexAIError::InvalidUrl(e.to_string()))?;

        // Determine the endpoint based on the provider type
        let endpoint = match provider {
            ModelProvider::Anthropic => "streamRawPredict",
            ModelProvider::Google => "generateContent",
        };

        // Construct the URL path
        let path = format!(
            "v1/projects/{}/locations/{}/publishers/{}/models/{}:{}",
            self.project_id,
            location,
            provider.as_str(),
            self.model.model_name,
            endpoint
        );

        base_url
            .join(&path)
            .map_err(|e| GcpVertexAIError::InvalidUrl(e.to_string()))
    }

    /// Makes an authenticated POST request to the Vertex AI API at a specific location.
    /// Includes retry logic for 429 Too Many Requests errors.
    ///
    /// # Arguments
    /// * `payload` - The request payload to send
    /// * `context` - Request context containing model information
    /// * `location` - The GCP location for the request
    async fn post_with_location(
        &self,
        payload: &Value,
        context: &RequestContext,
        location: &str,
    ) -> Result<Value, ProviderError> {
        let url = self
            .build_request_url(context.provider(), location)
            .map_err(|e| ProviderError::RequestFailed(e.to_string()))?;

        // Initialize the retry counter
        let mut attempts = 0;
        let mut last_error = None;

        loop {
            // Check whether we've exceeded the maximum number of retries
            if attempts > 0 && attempts > self.retry_config.max_retries {
                let error_msg = format!(
                    "Exceeded maximum retry attempts ({}) for rate limiting (429)",
                    self.retry_config.max_retries
                );
                tracing::error!("{}", error_msg);
                return Err(last_error.unwrap_or(ProviderError::RateLimitExceeded(error_msg)));
            }

            // Get a fresh auth token for each attempt
            let auth_header = self
                .get_auth_header()
                .await
                .map_err(|e| ProviderError::Authentication(e.to_string()))?;

            // Make the request
            let response = self
                .client
                .post(url.clone())
                .json(payload)
                .header("Authorization", auth_header)
                .send()
                .await
                .map_err(|e| ProviderError::RequestFailed(e.to_string()))?;

            let status = response.status();

            // If it is not a 429, process the response normally
            if status != StatusCode::TOO_MANY_REQUESTS {
                let response_json = response.json::<Value>().await.map_err(|e| {
                    ProviderError::RequestFailed(format!("Failed to parse response: {e}"))
                })?;

                return match status {
                    StatusCode::OK => Ok(response_json),
                    StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => {
                        tracing::debug!(
                            "Authentication failed. Status: {status}, Payload: {payload:?}"
                        );
                        Err(ProviderError::Authentication(format!(
                            "Authentication failed: {response_json:?}"
                        )))
                    }
                    _ => {
                        tracing::debug!(
                            "Request failed. Status: {status}, Response: {response_json:?}"
                        );
                        Err(ProviderError::RequestFailed(format!(
                            "Request failed with status {status}: {response_json:?}"
                        )))
                    }
                };
            }

            // Handle 429 Too Many Requests
            attempts += 1;

            // Try to parse the response for more detailed error info
            let cite_gcp_vertex_429 =
                "See https://cloud.google.com/vertex-ai/generative-ai/docs/error-code-429";
            let response_text = response.text().await.unwrap_or_default();
            let quota_error = if response_text.contains("Exceeded the Provisioned Throughput") {
                format!("Exceeded the Provisioned Throughput: {cite_gcp_vertex_429}.")
            } else {
                format!("Pay-as-you-go resource exhausted: {cite_gcp_vertex_429}.")
            };

            tracing::warn!(
                "Rate limit exceeded (attempt {}/{}): {}. Retrying after backoff...",
                attempts,
                self.retry_config.max_retries,
                quota_error
            );

            // Store the error in case we need to return it after max retries
            last_error = Some(ProviderError::RateLimitExceeded(quota_error));

            // Calculate and apply the backoff delay
            let delay = self.retry_config.delay_for_attempt(attempts);
            tracing::info!("Backing off for {:?} before retry", delay);
            sleep(delay).await;
        }
    }

    /// Makes an authenticated POST request to the Vertex AI API, with fallback for invalid locations.
    ///
    /// # Arguments
    /// * `payload` - The request payload to send
    /// * `context` - Request context containing model information
    async fn post(&self, payload: Value, context: &RequestContext) -> Result<Value, ProviderError> {
        // Try the user-specified location first
        let result = self
            .post_with_location(&payload, context, &self.location)
            .await;

        // If the location is already the model's known location, or the request succeeded, return the result
        if self.location == context.model.known_location().to_string() || result.is_ok() {
            return result;
        }

        // Check whether we should retry with the model's known location
        match &result {
            Err(ProviderError::RequestFailed(msg)) => {
                let model_name = context.model.to_string();
                let configured_location = &self.location;
                let known_location = context.model.known_location().to_string();

                tracing::error!(
                    "Trying known location {known_location} for {model_name} instead of {configured_location}: {msg}"
                );

                self.post_with_location(&payload, context, &known_location)
                    .await
            }
            // For any other error, return the original result
            _ => result,
        }
    }
}

impl Default for GcpVertexAIProvider {
    fn default() -> Self {
        let model = ModelConfig::new(Self::metadata().default_model);
        Self::new(model).expect("Failed to initialize VertexAI provider")
    }
}

#[async_trait]
impl Provider for GcpVertexAIProvider {
    /// Returns metadata about the GCP Vertex AI provider.
    fn metadata() -> ProviderMetadata
    where
        Self: Sized,
    {
        let known_models = vec![
            GcpVertexAIModel::Claude(ClaudeVersion::Sonnet35),
            GcpVertexAIModel::Claude(ClaudeVersion::Sonnet35V2),
            GcpVertexAIModel::Claude(ClaudeVersion::Sonnet37),
            GcpVertexAIModel::Claude(ClaudeVersion::Haiku35),
            GcpVertexAIModel::Gemini(GeminiVersion::Pro15),
            GcpVertexAIModel::Gemini(GeminiVersion::Flash20),
            GcpVertexAIModel::Gemini(GeminiVersion::Pro20Exp),
        ]
        .into_iter()
        .map(|model| model.to_string())
        .collect();

        ProviderMetadata::new(
            "gcp_vertex_ai",
            "GCP Vertex AI",
            "Access a variety of AI models, such as Claude and Gemini, through Vertex AI",
            GcpVertexAIModel::Gemini(GeminiVersion::Flash20)
                .to_string()
                .as_str(),
            known_models,
            GCP_VERTEX_AI_DOC_URL,
            vec![
                ConfigKey::new("GCP_PROJECT_ID", true, false, None),
                ConfigKey::new("GCP_LOCATION", true, false, Some(Iowa.to_string().as_str())),
                ConfigKey::new(
                    "GCP_MAX_RETRIES",
                    false,
                    false,
                    Some(&DEFAULT_MAX_RETRIES.to_string()),
                ),
                ConfigKey::new(
                    "GCP_INITIAL_RETRY_INTERVAL_MS",
                    false,
                    false,
                    Some(&DEFAULT_INITIAL_RETRY_INTERVAL_MS.to_string()),
                ),
                ConfigKey::new(
                    "GCP_BACKOFF_MULTIPLIER",
                    false,
                    false,
                    Some(&DEFAULT_BACKOFF_MULTIPLIER.to_string()),
                ),
                ConfigKey::new(
                    "GCP_MAX_RETRY_INTERVAL_MS",
                    false,
                    false,
                    Some(&DEFAULT_MAX_RETRY_INTERVAL_MS.to_string()),
                ),
            ],
        )
    }

    /// Completes a model interaction by sending a request and processing the response.
    ///
    /// # Arguments
    /// * `system` - System prompt or context
    /// * `messages` - Array of previous messages in the conversation
    /// * `tools` - Array of available tools for the model
    #[tracing::instrument(
        skip(self, system, messages, tools),
        fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
    )]
    async fn complete(
        &self,
        system: &str,
        messages: &[Message],
        tools: &[Tool],
    ) -> Result<(Message, ProviderUsage), ProviderError> {
        // Create the request and context
        let (request, context) = create_request(&self.model, system, messages, tools)?;

        // Send the request and process the response
        let response = self.post(request.clone(), &context).await?;
        let usage = get_usage(&response, &context)?;

        emit_debug_trace(self, &request, &response, &usage);

        // Convert the response to a message
        let message = response_to_message(response, context)?;
        let provider_usage = ProviderUsage::new(self.model.model_name.clone(), usage);

        Ok((message, provider_usage))
    }

    /// Returns the current model configuration.
    fn get_model_config(&self) -> ModelConfig {
        self.model.clone()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_retry_config_delay_calculation() {
        let config = RetryConfig {
            max_retries: 5,
            initial_interval_ms: 1000,
            backoff_multiplier: 2.0,
            max_interval_ms: 32000,
        };

        // The first attempt has no delay
        let delay0 = config.delay_for_attempt(0);
        assert_eq!(delay0.as_millis(), 0);

        // The first retry should be around initial_interval, with jitter
        let delay1 = config.delay_for_attempt(1);
        assert!(delay1.as_millis() >= 800 && delay1.as_millis() <= 1200);

        // The second retry should be around initial_interval * multiplier^1, with jitter
        let delay2 = config.delay_for_attempt(2);
        assert!(delay2.as_millis() >= 1600 && delay2.as_millis() <= 2400);

        // Check that the max interval is respected
        let delay10 = config.delay_for_attempt(10);
        assert!(delay10.as_millis() <= 38400); // max_interval_ms * 1.2 (max jitter)
    }

    #[test]
    fn test_model_provider_conversion() {
        assert_eq!(ModelProvider::Anthropic.as_str(), "anthropic");
        assert_eq!(ModelProvider::Google.as_str(), "google");
    }

    #[test]
    fn test_url_construction() {
        use url::Url;

        let model_config = ModelConfig::new("claude-3-5-sonnet-v2@20241022".to_string());
        let context = RequestContext::new(&model_config.model_name).unwrap();
        let api_model_id = context.model.to_string();

        let host = "https://us-east5-aiplatform.googleapis.com";
        let project_id = "test-project";
        let location = "us-east5";

        let path = format!(
            "v1/projects/{}/locations/{}/publishers/{}/models/{}:{}",
            project_id,
            location,
            ModelProvider::Anthropic.as_str(),
            api_model_id,
            "streamRawPredict"
        );

        let url = Url::parse(host).unwrap().join(&path).unwrap();

        assert!(url.as_str().contains("publishers/anthropic"));
        assert!(url.as_str().contains("projects/test-project"));
        assert!(url.as_str().contains("locations/us-east5"));
    }

    #[test]
    fn test_provider_metadata() {
        let metadata = GcpVertexAIProvider::metadata();
        assert!(metadata
            .known_models
            .contains(&"claude-3-5-sonnet-v2@20241022".to_string()));
        assert!(metadata
            .known_models
            .contains(&"gemini-1.5-pro-002".to_string()));
        // Should contain the original 2 config keys plus the 4 new retry-related ones
        assert_eq!(metadata.config_keys.len(), 6);
    }
}
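For intuition on the retry schedule: the defaults above (5 s initial interval, 2.0 multiplier, 320 s cap, up to 6 retries) give nominal backoff delays of 5, 10, 20, 40, 80, and 160 seconds before jitter. A standalone sketch of the same arithmetic used in delay_for_attempt:

fn main() {
    let initial_ms: u64 = 5_000; // DEFAULT_INITIAL_RETRY_INTERVAL_MS
    let multiplier: f64 = 2.0;   // DEFAULT_BACKOFF_MULTIPLIER
    let cap_ms: u64 = 320_000;   // DEFAULT_MAX_RETRY_INTERVAL_MS

    for attempt in 1..=6 {
        // Same formula as RetryConfig::delay_for_attempt, minus the +/-20% jitter.
        let base_ms = (initial_ms as f64 * multiplier.powi(attempt - 1)) as u64;
        let capped_ms = base_ms.min(cap_ms);
        println!("retry {attempt}: ~{}s nominal (jittered to 80-120%)", capped_ms / 1000);
    }
}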
@@ -6,6 +6,8 @@ pub mod databricks;
 pub mod errors;
 mod factory;
 pub mod formats;
+mod gcpauth;
+pub mod gcpvertexai;
 pub mod google;
 pub mod groq;
 pub mod oauth;
@@ -6,12 +6,12 @@ use goose::agents::AgentFactory;
 use goose::message::Message;
 use goose::model::ModelConfig;
 use goose::providers::base::Provider;
-use goose::providers::{anthropic::AnthropicProvider, databricks::DatabricksProvider};
 use goose::providers::{
-    azure::AzureProvider, bedrock::BedrockProvider, ollama::OllamaProvider, openai::OpenAiProvider,
+    anthropic::AnthropicProvider, azure::AzureProvider, bedrock::BedrockProvider,
+    databricks::DatabricksProvider, gcpvertexai::GcpVertexAIProvider, google::GoogleProvider,
+    groq::GroqProvider, ollama::OllamaProvider, openai::OpenAiProvider,
     openrouter::OpenRouterProvider,
 };
-use goose::providers::{google::GoogleProvider, groq::GroqProvider};

 #[derive(Debug, PartialEq)]
 enum ProviderType {
@@ -20,6 +20,7 @@ enum ProviderType {
     Anthropic,
     Bedrock,
     Databricks,
+    GcpVertexAI,
     Google,
     Groq,
     Ollama,
@@ -42,6 +43,7 @@ impl ProviderType {
             ProviderType::Groq => &["GROQ_API_KEY"],
             ProviderType::Ollama => &[],
             ProviderType::OpenRouter => &["OPENROUTER_API_KEY"],
+            ProviderType::GcpVertexAI => &["GCP_PROJECT_ID", "GCP_LOCATION"],
         }
     }
@@ -70,6 +72,7 @@ impl ProviderType {
             ProviderType::Anthropic => Box::new(AnthropicProvider::from_env(model_config)?),
             ProviderType::Bedrock => Box::new(BedrockProvider::from_env(model_config)?),
             ProviderType::Databricks => Box::new(DatabricksProvider::from_env(model_config)?),
+            ProviderType::GcpVertexAI => Box::new(GcpVertexAIProvider::from_env(model_config)?),
             ProviderType::Google => Box::new(GoogleProvider::from_env(model_config)?),
             ProviderType::Groq => Box::new(GroqProvider::from_env(model_config)?),
             ProviderType::Ollama => Box::new(OllamaProvider::from_env(model_config)?),
@@ -290,4 +293,14 @@ mod tests {
         })
         .await
     }
+
+    #[tokio::test]
+    async fn test_truncate_agent_with_gcpvertexai() -> Result<()> {
+        run_test_with_config(TestConfig {
+            provider_type: ProviderType::GcpVertexAI,
+            model: "claude-3-5-sonnet-v2@20241022",
+            context_window: 200_000,
+        })
+        .await
+    }
 }
@@ -17,17 +17,18 @@ Goose relies heavily on tool calling capabilities and currently works best with

 ## Available Providers

 | Provider | Description | Parameters |
 |----------|-------------|------------|
 | [Amazon Bedrock](https://aws.amazon.com/bedrock/) | Offers a variety of foundation models, including Claude, Jurassic-2, and others. **Environment variables must be set in advance, not configured through `goose configure`** | `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, `AWS_REGION` |
 | [Anthropic](https://www.anthropic.com/) | Offers Claude, an advanced AI model for natural language tasks. | `ANTHROPIC_API_KEY` |
 | [Azure OpenAI](https://learn.microsoft.com/en-us/azure/ai-services/openai/) | Access Azure-hosted OpenAI models, including GPT-4 and GPT-3.5. | `AZURE_OPENAI_API_KEY`, `AZURE_OPENAI_ENDPOINT`, `AZURE_OPENAI_DEPLOYMENT_NAME` |
 | [Databricks](https://www.databricks.com/) | Unified data analytics and AI platform for building and deploying models. | `DATABRICKS_HOST`, `DATABRICKS_TOKEN` |
 | [Gemini](https://ai.google.dev/gemini-api/docs) | Advanced LLMs by Google with multimodal capabilities (text, images). | `GOOGLE_API_KEY` |
+| [GCP Vertex AI](https://cloud.google.com/vertex-ai) | Google Cloud's Vertex AI platform, supporting Gemini and Claude models. **Credentials must be configured in advance; follow the instructions at https://cloud.google.com/vertex-ai/docs/authentication.** | `GCP_PROJECT_ID`, `GCP_LOCATION`, and optionally `GCP_MAX_RETRIES` (default 6), `GCP_INITIAL_RETRY_INTERVAL_MS` (5000), `GCP_BACKOFF_MULTIPLIER` (2.0), `GCP_MAX_RETRY_INTERVAL_MS` (320000) |
 | [Groq](https://groq.com/) | High-performance inference hardware and tools for LLMs. | `GROQ_API_KEY` |
 | [Ollama](https://ollama.com/) | Local model runner supporting Qwen, Llama, DeepSeek, and other open-source models. **Because this provider runs locally, you must first [download and run a model](/docs/getting-started/providers#local-llms-ollama).** | `OLLAMA_HOST` |
 | [OpenAI](https://platform.openai.com/api-keys) | Provides gpt-4o, o1, and other advanced language models. Also supports OpenAI-compatible endpoints (e.g., self-hosted LLaMA, vLLM, KServe). **o1-mini and o1-preview are not supported because Goose uses tool calling.** | `OPENAI_API_KEY`, `OPENAI_HOST` (optional), `OPENAI_ORGANIZATION` (optional), `OPENAI_PROJECT` (optional) |
 | [OpenRouter](https://openrouter.ai/) | API gateway for unified access to various models with features like rate-limiting management. | `OPENROUTER_API_KEY` |
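For orientation, these settings map directly onto the endpoint the provider builds: per the URL construction in gcpvertexai.rs above, a request resolves to `https://{GCP_LOCATION}-aiplatform.googleapis.com/v1/projects/{GCP_PROJECT_ID}/locations/{GCP_LOCATION}/publishers/{anthropic|google}/models/{model}:{endpoint}`, where Claude models use the `anthropic` publisher with `streamRawPredict` and Gemini models use the `google` publisher with `generateContent`.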
@@ -11,6 +11,8 @@ export function isSecretKey(keyName: string): boolean {
     'OPENAI_BASE_PATH',
     'AZURE_OPENAI_ENDPOINT',
     'AZURE_OPENAI_DEPLOYMENT_NAME',
+    'GCP_PROJECT_ID',
+    'GCP_LOCATION',
   ];
   return !nonSecretKeys.includes(keyName);
 }
@@ -19,6 +19,13 @@ export const goose_models: Model[] = [
   { id: 17, name: 'qwen2.5', provider: 'Ollama' },
   { id: 18, name: 'anthropic/claude-3.5-sonnet', provider: 'OpenRouter' },
   { id: 19, name: 'gpt-4o', provider: 'Azure OpenAI' },
+  { id: 20, name: 'claude-3-7-sonnet@20250219', provider: 'GCP Vertex AI' },
+  { id: 21, name: 'claude-3-5-sonnet-v2@20241022', provider: 'GCP Vertex AI' },
+  { id: 22, name: 'claude-3-5-sonnet@20240620', provider: 'GCP Vertex AI' },
+  { id: 23, name: 'claude-3-5-haiku@20241022', provider: 'GCP Vertex AI' },
+  { id: 24, name: 'gemini-2.0-pro-exp-02-05', provider: 'GCP Vertex AI' },
+  { id: 25, name: 'gemini-2.0-flash-001', provider: 'GCP Vertex AI' },
+  { id: 26, name: 'gemini-1.5-pro-002', provider: 'GCP Vertex AI' },
 ];

 export const openai_models = ['gpt-4o-mini', 'gpt-4o', 'gpt-4-turbo', 'o1'];
@@ -47,6 +54,16 @@ export const openrouter_models = ['anthropic/claude-3.5-sonnet'];

 export const azure_openai_models = ['gpt-4o'];

+export const gcp_vertex_ai_models = [
+  'claude-3-7-sonnet@20250219',
+  'claude-3-5-sonnet-v2@20241022',
+  'claude-3-5-sonnet@20240620',
+  'claude-3-5-haiku@20241022',
+  'gemini-1.5-pro-002',
+  'gemini-2.0-flash-001',
+  'gemini-2.0-pro-exp-02-05',
+];
+
 export const default_models = {
   openai: 'gpt-4o',
   anthropic: 'claude-3-5-sonnet-latest',
@@ -56,6 +73,7 @@ export const default_models = {
   openrouter: 'anthropic/claude-3.5-sonnet',
   ollama: 'qwen2.5',
   azure_openai: 'gpt-4o',
+  gcp_vertex_ai: 'gemini-2.0-flash-001',
 };

 export function getDefaultModel(key: string): string | undefined {
@@ -73,12 +91,14 @@ export const required_keys = {
   Google: ['GOOGLE_API_KEY'],
   OpenRouter: ['OPENROUTER_API_KEY'],
   'Azure OpenAI': ['AZURE_OPENAI_API_KEY', 'AZURE_OPENAI_ENDPOINT', 'AZURE_OPENAI_DEPLOYMENT_NAME'],
+  'GCP Vertex AI': ['GCP_PROJECT_ID', 'GCP_LOCATION'],
 };

 export const default_key_value = {
   OPENAI_HOST: 'https://api.openai.com',
   OPENAI_BASE_PATH: 'v1/chat/completions',
   OLLAMA_HOST: 'localhost',
+  GCP_LOCATION: 'us-central1',
 };

 export const supported_providers = [
@@ -90,6 +110,7 @@ export const supported_providers = [
   'Ollama',
   'OpenRouter',
   'Azure OpenAI',
+  'GCP Vertex AI',
 ];

 export const model_docs_link = [
@@ -103,6 +124,7 @@ export const model_docs_link = [
   },
   { name: 'OpenRouter', href: 'https://openrouter.ai/models' },
   { name: 'Ollama', href: 'https://ollama.com/library' },
+  { name: 'GCP Vertex AI', href: 'https://cloud.google.com/vertex-ai' },
 ];

 export const provider_aliases = [
@@ -114,4 +136,5 @@ export const provider_aliases = [
   { provider: 'OpenRouter', alias: 'openrouter' },
   { provider: 'Google', alias: 'google' },
   { provider: 'Azure OpenAI', alias: 'azure_openai' },
+  { provider: 'GCP Vertex AI', alias: 'gcp_vertex_ai' },
 ];