mirror of https://github.com/aljazceru/goose.git (synced 2025-12-18 22:54:24 +01:00)

feat: add GCP Vertex AI platform as provider (#1364)

Signed-off-by: Uddhav Kambli <uddhav@kambli.net>

Cargo.lock (generated): 129 changed lines
@@ -230,9 +230,9 @@ dependencies = [

 [[package]]
 name = "async-trait"
-version = "0.1.85"
+version = "0.1.86"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3f934833b4b7233644e5848f235df3f57ed8c80f1528a26c3dfa13d2147fa056"
+checksum = "644dd749086bf3771a2fbc5f256fdb982d53f011c7d5d560304eafeecebce79d"
 dependencies = [
  "proc-macro2",
  "quote",
@@ -1699,6 +1699,12 @@ version = "0.15.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "77c90badedccf4105eca100756a0b1289e191f6fcbdadd3cee1d2f614f97da8f"

+[[package]]
+name = "downcast"
+version = "0.11.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1435fa1053d8b2fbbe9be7e97eca7f33d37b28409959813daefc1446a14247f1"
+
 [[package]]
 name = "dyn-clone"
 version = "1.0.17"
@@ -1883,6 +1889,12 @@ dependencies = [
  "percent-encoding",
 ]

+[[package]]
+name = "fragile"
+version = "2.0.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6c2141d6d6c8512188a7891b4b01590a45f6dac67afb4f255c4124dbb86d4eaa"
+
 [[package]]
 name = "futures"
 version = "0.3.31"
@@ -2149,10 +2161,12 @@ dependencies = [
  "futures",
  "include_dir",
  "indoc",
+ "jsonwebtoken",
  "keyring",
  "lazy_static",
  "mcp-client",
  "mcp-core",
+ "mockall",
  "nanoid",
  "once_cell",
  "paste",
@@ -2571,7 +2585,7 @@ dependencies = [
  "http 1.2.0",
  "hyper 1.6.0",
  "hyper-util",
- "rustls 0.23.21",
+ "rustls 0.23.23",
  "rustls-native-certs 0.8.1",
  "rustls-pki-types",
  "tokio",
@@ -3028,6 +3042,21 @@ dependencies = [
  "serde",
 ]

+[[package]]
+name = "jsonwebtoken"
+version = "9.3.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5a87cc7a48537badeae96744432de36f4be2b4a34a05a5ef32e9dd8a1c169dde"
+dependencies = [
+ "base64 0.22.1",
+ "js-sys",
+ "pem",
+ "ring",
+ "serde",
+ "serde_json",
+ "simple_asn1",
+]
+
 [[package]]
 name = "keyring"
 version = "3.6.1"
@@ -3367,6 +3396,32 @@ dependencies = [
  "windows-sys 0.52.0",
 ]

+[[package]]
+name = "mockall"
+version = "0.13.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "39a6bfcc6c8c7eed5ee98b9c3e33adc726054389233e201c95dab2d41a3839d2"
+dependencies = [
+ "cfg-if",
+ "downcast",
+ "fragile",
+ "mockall_derive",
+ "predicates",
+ "predicates-tree",
+]
+
+[[package]]
+name = "mockall_derive"
+version = "0.13.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "25ca3004c2efe9011bd4e461bd8256445052b9615405b4f7ea43fc8ca5c20898"
+dependencies = [
+ "cfg-if",
+ "proc-macro2",
+ "quote",
+ "syn 2.0.96",
+]
+
 [[package]]
 name = "monostate"
 version = "0.1.13"
@@ -3739,6 +3794,16 @@ version = "0.2.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "df94ce210e5bc13cb6651479fa48d14f601d9858cfe0467f43ae157023b938d3"

+[[package]]
+name = "pem"
+version = "3.0.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "38af38e8470ac9dee3ce1bae1af9c1671fffc44ddfd8bd1d0a3445bf349a8ef3"
+dependencies = [
+ "base64 0.22.1",
+ "serde",
+]
+
 [[package]]
 name = "percent-encoding"
 version = "2.3.1"
@@ -3941,6 +4006,32 @@ dependencies = [
  "zerocopy",
 ]

+[[package]]
+name = "predicates"
+version = "3.1.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a5d19ee57562043d37e82899fade9a22ebab7be9cef5026b07fda9cdd4293573"
+dependencies = [
+ "anstyle",
+ "predicates-core",
+]
+
+[[package]]
+name = "predicates-core"
+version = "1.0.9"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "727e462b119fe9c93fd0eb1429a5f7647394014cf3c04ab2c0350eeb09095ffa"
+
+[[package]]
+name = "predicates-tree"
+version = "1.0.12"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "72dd2d6d381dfb73a193c7fca536518d7caee39fc8503f74e7dc0be0531b425c"
+dependencies = [
+ "predicates-core",
+ "termtree",
+]
+
 [[package]]
 name = "prettyplease"
 version = "0.2.29"
@@ -4063,7 +4154,7 @@ dependencies = [
  "quinn-proto",
  "quinn-udp",
  "rustc-hash 2.1.0",
- "rustls 0.23.21",
+ "rustls 0.23.23",
  "socket2",
  "thiserror 2.0.11",
  "tokio",
@@ -4081,7 +4172,7 @@ dependencies = [
  "rand",
  "ring",
  "rustc-hash 2.1.0",
- "rustls 0.23.21",
+ "rustls 0.23.23",
  "rustls-pki-types",
  "slab",
  "thiserror 2.0.11",
@@ -4393,7 +4484,7 @@ dependencies = [
  "percent-encoding",
  "pin-project-lite",
  "quinn",
- "rustls 0.23.21",
+ "rustls 0.23.23",
  "rustls-pemfile 2.2.0",
  "rustls-pki-types",
  "serde",
@@ -4514,9 +4605,9 @@ dependencies = [

 [[package]]
 name = "rustls"
-version = "0.23.21"
+version = "0.23.23"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8f287924602bf649d949c63dc8ac8b235fa5387d394020705b80c4eb597ce5b8"
+checksum = "47796c98c480fce5406ef69d1c76378375492c3b0a0de587be0c1d9feb12f395"
 dependencies = [
  "once_cell",
  "ring",
@@ -4977,6 +5068,18 @@ dependencies = [
  "quote",
 ]

+[[package]]
+name = "simple_asn1"
+version = "0.6.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "297f631f50729c8c99b84667867963997ec0b50f32b2a7dbcab828ef0541e8bb"
+dependencies = [
+ "num-bigint",
+ "num-traits",
+ "thiserror 2.0.11",
+ "time",
+]
+
 [[package]]
 name = "siphasher"
 version = "1.0.1"
@@ -5268,6 +5371,12 @@ dependencies = [
  "windows-sys 0.59.0",
 ]

+[[package]]
+name = "termtree"
+version = "0.5.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8f50febec83f5ee1df3015341d8bd429f2d1cc62bcba7ea2076759d315084683"
+
 [[package]]
 name = "test-case"
 version = "3.3.1"
@@ -5537,7 +5646,7 @@ version = "0.26.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "5f6d0975eaace0cf0fcadee4e4aaa5da15b5c079146f2cffb67c113be122bf37"
 dependencies = [
- "rustls 0.23.21",
+ "rustls 0.23.23",
  "tokio",
 ]

@@ -6767,7 +6876,7 @@ dependencies = [
  "hyper-util",
  "log",
  "percent-encoding",
- "rustls 0.23.21",
+ "rustls 0.23.23",
  "rustls-pemfile 2.2.0",
  "seahash",
  "serde",
(filename not captured in mirror: provider metadata JSON)
@@ -17,6 +17,12 @@
     "models": ["goose"],
     "required_keys": ["DATABRICKS_HOST"]
   },
+  "gcp_vertex_ai": {
+    "name": "GCP Vertex AI",
+    "description": "Use Vertex AI platform models",
+    "models": ["claude-3-5-haiku@20241022", "claude-3-5-sonnet@20240620", "claude-3-5-sonnet-v2@20241022", "claude-3-7-sonnet@20250219", "gemini-1.5-pro-002", "gemini-2.0-flash-001", "gemini-2.0-pro-exp-02-05"],
+    "required_keys": ["GCP_PROJECT_ID", "GCP_LOCATION"]
+  },
   "google": {
     "name": "Google",
     "description": "Lorem ipsum",
(filename not captured in mirror: Cargo manifest)
@@ -66,6 +66,9 @@ aws-config = { version = "1.1.7", features = ["behavior-version-latest"] }
 aws-smithy-types = "1.2.12"
 aws-sdk-bedrockruntime = "1.72.0"

+# For GCP Vertex AI provider auth
+jsonwebtoken = "9.3.1"
+
 [target.'cfg(target_os = "windows")'.dependencies]
 winapi = { version = "0.3", features = ["wincred"] }

@@ -73,6 +76,9 @@ winapi = { version = "0.3", features = ["wincred"] }
 criterion = "0.5"
 tempfile = "3.15.0"
 serial_test = "3.2.0"
+mockall = "0.13.1"
+wiremock = "0.6.0"
+tokio = { version = "1.0", features = ["full"] }

 [[example]]
 name = "agent"
(filename not captured in mirror: provider factory)
@@ -4,6 +4,7 @@ use super::{
     base::{Provider, ProviderMetadata},
     bedrock::BedrockProvider,
     databricks::DatabricksProvider,
+    gcpvertexai::GcpVertexAIProvider,
     google::GoogleProvider,
     groq::GroqProvider,
     ollama::OllamaProvider,
@@ -19,6 +20,7 @@ pub fn providers() -> Vec<ProviderMetadata> {
         AzureProvider::metadata(),
         BedrockProvider::metadata(),
         DatabricksProvider::metadata(),
+        GcpVertexAIProvider::metadata(),
         GoogleProvider::metadata(),
         GroqProvider::metadata(),
         OllamaProvider::metadata(),
@@ -37,6 +39,7 @@ pub fn create(name: &str, model: ModelConfig) -> Result<Box<dyn Provider + Send
         "groq" => Ok(Box::new(GroqProvider::from_env(model)?)),
         "ollama" => Ok(Box::new(OllamaProvider::from_env(model)?)),
         "openrouter" => Ok(Box::new(OpenRouterProvider::from_env(model)?)),
+        "gcp_vertex_ai" => Ok(Box::new(GcpVertexAIProvider::from_env(model)?)),
         "google" => Ok(Box::new(GoogleProvider::from_env(model)?)),
         _ => Err(anyhow::anyhow!("Unknown provider: {}", name)),
     }
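With the factory wiring above, the new provider is reachable by its "gcp_vertex_ai" name. A minimal usage sketch, assuming `create` is re-exported from the private factory module as `goose::providers::create` (the re-export itself is not shown in this diff):

use goose::model::ModelConfig;
use goose::providers::create;

fn main() -> anyhow::Result<()> {
    // Requires GCP_PROJECT_ID (and optionally GCP_LOCATION) to be configured.
    let model = ModelConfig::new("gemini-2.0-flash-001".to_string());
    // Routed through the "gcp_vertex_ai" match arm added above.
    let provider = create("gcp_vertex_ai", model)?;
    let _config = provider.get_model_config();
    Ok(())
}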
crates/goose/src/providers/formats/gcpvertexai.rs (new file, 369 lines)
@@ -0,0 +1,369 @@
use super::{anthropic, google};
use crate::message::Message;
use crate::model::ModelConfig;
use crate::providers::base::Usage;
use anyhow::{Context, Result};
use mcp_core::tool::Tool;
use serde_json::Value;

use std::fmt;

/// Sensible default values of Google Cloud Platform (GCP) locations for model deployment.
///
/// Each variant corresponds to a specific GCP region where models can be hosted.
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
pub enum GcpLocation {
    /// Represents the us-central1 region in Iowa
    Iowa,
    /// Represents the us-east5 region in Ohio
    Ohio,
}

impl fmt::Display for GcpLocation {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Iowa => write!(f, "us-central1"),
            Self::Ohio => write!(f, "us-east5"),
        }
    }
}

impl TryFrom<&str> for GcpLocation {
    type Error = ModelError;

    fn try_from(s: &str) -> Result<Self, Self::Error> {
        match s {
            "us-central1" => Ok(Self::Iowa),
            "us-east5" => Ok(Self::Ohio),
            _ => Err(ModelError::UnsupportedLocation(s.to_string())),
        }
    }
}

/// Represents errors that can occur during model operations.
///
/// This enum encompasses various error conditions that might arise when working
/// with GCP Vertex AI models, including unsupported models, invalid requests,
/// and unsupported locations.
#[derive(Debug, thiserror::Error)]
pub enum ModelError {
    /// Error when an unsupported Vertex AI model is specified
    #[error("Unsupported Vertex AI model: {0}")]
    UnsupportedModel(String),
    /// Error when the request structure is invalid
    #[error("Invalid request structure: {0}")]
    InvalidRequest(String),
    /// Error when an unsupported GCP location is specified
    #[error("Unsupported GCP location: {0}")]
    UnsupportedLocation(String),
}

/// Represents available GCP Vertex AI models for Goose.
///
/// This enum encompasses different model families and their versions
/// that are supported in the GCP Vertex AI platform.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GcpVertexAIModel {
    /// Claude model family with specific versions
    Claude(ClaudeVersion),
    /// Gemini model family with specific versions
    Gemini(GeminiVersion),
}

/// Represents available versions of the Claude model for Goose.
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
pub enum ClaudeVersion {
    /// Claude 3.5 Sonnet initial version
    Sonnet35,
    /// Claude 3.5 Sonnet version 2
    Sonnet35V2,
    /// Claude 3.7 Sonnet
    Sonnet37,
    /// Claude 3.5 Haiku
    Haiku35,
}

/// Represents available versions of the Gemini model for Goose.
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
pub enum GeminiVersion {
    /// Gemini 1.5 Pro version
    Pro15,
    /// Gemini 2.0 Flash version
    Flash20,
    /// Gemini 2.0 Pro Experimental version
    Pro20Exp,
}

impl fmt::Display for GcpVertexAIModel {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let model_id = match self {
            Self::Claude(version) => match version {
                ClaudeVersion::Sonnet35 => "claude-3-5-sonnet@20240620",
                ClaudeVersion::Sonnet35V2 => "claude-3-5-sonnet-v2@20241022",
                ClaudeVersion::Sonnet37 => "claude-3-7-sonnet@20250219",
                ClaudeVersion::Haiku35 => "claude-3-5-haiku@20241022",
            },
            Self::Gemini(version) => match version {
                GeminiVersion::Pro15 => "gemini-1.5-pro-002",
                GeminiVersion::Flash20 => "gemini-2.0-flash-001",
                GeminiVersion::Pro20Exp => "gemini-2.0-pro-exp-02-05",
            },
        };
        write!(f, "{model_id}")
    }
}

impl GcpVertexAIModel {
    /// Returns the default GCP location for the model.
    ///
    /// Each model family has a well-known location:
    /// - Claude models default to Ohio (us-east5)
    /// - Gemini models default to Iowa (us-central1)
    pub fn known_location(&self) -> GcpLocation {
        match self {
            Self::Claude(_) => GcpLocation::Ohio,
            Self::Gemini(_) => GcpLocation::Iowa,
        }
    }
}

impl TryFrom<&str> for GcpVertexAIModel {
    type Error = ModelError;

    fn try_from(s: &str) -> Result<Self, Self::Error> {
        match s {
            "claude-3-5-sonnet@20240620" => Ok(Self::Claude(ClaudeVersion::Sonnet35)),
            "claude-3-5-sonnet-v2@20241022" => Ok(Self::Claude(ClaudeVersion::Sonnet35V2)),
            "claude-3-7-sonnet@20250219" => Ok(Self::Claude(ClaudeVersion::Sonnet37)),
            "claude-3-5-haiku@20241022" => Ok(Self::Claude(ClaudeVersion::Haiku35)),
            "gemini-1.5-pro-002" => Ok(Self::Gemini(GeminiVersion::Pro15)),
            "gemini-2.0-flash-001" => Ok(Self::Gemini(GeminiVersion::Flash20)),
            "gemini-2.0-pro-exp-02-05" => Ok(Self::Gemini(GeminiVersion::Pro20Exp)),
            _ => Err(ModelError::UnsupportedModel(s.to_string())),
        }
    }
}

/// Holds context information for a model request since the Vertex AI platform
/// supports multiple model families.
///
/// This structure maintains information about the model being used
/// and provides utility methods for handling model-specific operations.
#[derive(Debug, Clone)]
pub struct RequestContext {
    /// The GCP Vertex AI model being used
    pub model: GcpVertexAIModel,
}

impl RequestContext {
    /// Creates a new RequestContext from a model ID string.
    ///
    /// # Arguments
    /// * `model_id` - The string identifier of the model
    ///
    /// # Returns
    /// * `Result<Self>` - A new RequestContext if the model ID is valid
    pub fn new(model_id: &str) -> Result<Self> {
        Ok(Self {
            model: GcpVertexAIModel::try_from(model_id)
                .with_context(|| format!("Failed to parse model ID: {model_id}"))?,
        })
    }

    /// Returns the provider associated with the model.
    pub fn provider(&self) -> ModelProvider {
        match self.model {
            GcpVertexAIModel::Claude(_) => ModelProvider::Anthropic,
            GcpVertexAIModel::Gemini(_) => ModelProvider::Google,
        }
    }
}

/// Represents available model providers.
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
pub enum ModelProvider {
    /// Anthropic provider (Claude models)
    Anthropic,
    /// Google provider (Gemini models)
    Google,
}

impl ModelProvider {
    /// Returns the string representation of the provider.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Anthropic => "anthropic",
            Self::Google => "google",
        }
    }
}

/// Creates an Anthropic-specific Vertex AI request payload.
///
/// # Arguments
/// * `model_config` - Configuration for the model
/// * `system` - System prompt
/// * `messages` - Array of messages
/// * `tools` - Array of available tools
///
/// # Returns
/// * `Result<Value>` - JSON request payload for Anthropic API
fn create_anthropic_request(
    model_config: &ModelConfig,
    system: &str,
    messages: &[Message],
    tools: &[Tool],
) -> Result<Value> {
    let mut request = anthropic::create_request(model_config, system, messages, tools)?;

    let obj = request
        .as_object_mut()
        .ok_or_else(|| ModelError::InvalidRequest("Request is not a JSON object".to_string()))?;

    // Note: We don't need to specify the model in the request body
    // The model is determined by the endpoint URL in GCP Vertex AI
    obj.remove("model");
    obj.insert(
        "anthropic_version".to_string(),
        Value::String("vertex-2023-10-16".to_string()),
    );

    Ok(request)
}

/// Creates a Gemini-specific Vertex AI request payload.
///
/// # Arguments
/// * `model_config` - Configuration for the model
/// * `system` - System prompt
/// * `messages` - Array of messages
/// * `tools` - Array of available tools
///
/// # Returns
/// * `Result<Value>` - JSON request payload for Google API
fn create_google_request(
    model_config: &ModelConfig,
    system: &str,
    messages: &[Message],
    tools: &[Tool],
) -> Result<Value> {
    google::create_request(model_config, system, messages, tools)
}

/// Creates a provider-specific request payload and context.
///
/// # Arguments
/// * `model_config` - Configuration for the model
/// * `system` - System prompt
/// * `messages` - Array of messages
/// * `tools` - Array of available tools
///
/// # Returns
/// * `Result<(Value, RequestContext)>` - Tuple of request payload and context
pub fn create_request(
    model_config: &ModelConfig,
    system: &str,
    messages: &[Message],
    tools: &[Tool],
) -> Result<(Value, RequestContext)> {
    let context = RequestContext::new(&model_config.model_name)?;

    let request = match &context.model {
        GcpVertexAIModel::Claude(_) => {
            create_anthropic_request(model_config, system, messages, tools)?
        }
        GcpVertexAIModel::Gemini(_) => {
            create_google_request(model_config, system, messages, tools)?
        }
    };

    Ok((request, context))
}

/// Converts a provider response to a Message.
///
/// # Arguments
/// * `response` - The raw response from the provider
/// * `request_context` - Context information about the request
///
/// # Returns
/// * `Result<Message>` - Converted message
pub fn response_to_message(response: Value, request_context: RequestContext) -> Result<Message> {
    match request_context.provider() {
        ModelProvider::Anthropic => anthropic::response_to_message(response),
        ModelProvider::Google => google::response_to_message(response),
    }
}

/// Extracts token usage information from the response data.
///
/// # Arguments
/// * `data` - The response data containing usage information
/// * `request_context` - Context information about the request
///
/// # Returns
/// * `Result<Usage>` - Usage statistics
pub fn get_usage(data: &Value, request_context: &RequestContext) -> Result<Usage> {
    match request_context.provider() {
        ModelProvider::Anthropic => anthropic::get_usage(data),
        ModelProvider::Google => google::get_usage(data),
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result;

    #[test]
    fn test_model_parsing() -> Result<()> {
        let valid_models = [
            "claude-3-5-sonnet@20240620",
            "claude-3-5-sonnet-v2@20241022",
            "claude-3-7-sonnet@20250219",
            "claude-3-5-haiku@20241022",
            "gemini-1.5-pro-002",
            "gemini-2.0-flash-001",
            "gemini-2.0-pro-exp-02-05",
        ];

        for model_id in valid_models {
            let model = GcpVertexAIModel::try_from(model_id)?;
            assert_eq!(model.to_string(), model_id);
        }

        assert!(GcpVertexAIModel::try_from("unsupported-model").is_err());
        Ok(())
    }

    #[test]
    fn test_default_locations() -> Result<()> {
        let test_cases = [
            ("claude-3-5-sonnet@20240620", GcpLocation::Ohio),
            ("claude-3-5-sonnet-v2@20241022", GcpLocation::Ohio),
            ("claude-3-7-sonnet@20250219", GcpLocation::Ohio),
            ("claude-3-5-haiku@20241022", GcpLocation::Ohio),
            ("gemini-1.5-pro-002", GcpLocation::Iowa),
            ("gemini-2.0-flash-001", GcpLocation::Iowa),
            ("gemini-2.0-pro-exp-02-05", GcpLocation::Iowa),
        ];

        for (model_id, expected_location) in test_cases {
            let model = GcpVertexAIModel::try_from(model_id)?;
            assert_eq!(
                model.known_location(),
                expected_location,
                "Model {model_id} should have default location {expected_location:?}",
            );

            let context = RequestContext::new(model_id)?;
            assert_eq!(
                context.model.known_location(),
                expected_location,
                "RequestContext for {model_id} should have default location {expected_location:?}",
            );
        }

        Ok(())
    }
}
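The round-trip between model ID strings and typed variants above is what the rest of the provider leans on. A small standalone illustration, a sketch using only this module's public API and mirroring its own tests:

use goose::providers::formats::gcpvertexai::{GcpVertexAIModel, ModelProvider, RequestContext};

fn main() -> anyhow::Result<()> {
    // Parsing a Vertex AI model ID yields a typed variant...
    let model = GcpVertexAIModel::try_from("claude-3-7-sonnet@20250219")?;
    // ...and Display renders it back to the identical ID.
    assert_eq!(model.to_string(), "claude-3-7-sonnet@20250219");
    // Claude models report us-east5 as their well-known region.
    assert_eq!(model.known_location().to_string(), "us-east5");

    // RequestContext then selects the wire format: Anthropic for Claude IDs.
    let context = RequestContext::new("claude-3-7-sonnet@20250219")?;
    assert_eq!(context.provider(), ModelProvider::Anthropic);
    Ok(())
}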
(filename not captured in mirror: formats module declarations)
@@ -1,4 +1,5 @@
 pub mod anthropic;
 pub mod bedrock;
+pub mod gcpvertexai;
 pub mod google;
 pub mod openai;
crates/goose/src/providers/gcpauth.rs (new file, 1115 lines)
File diff suppressed because it is too large.
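Since the gcpauth.rs diff is collapsed here, the authentication flow itself is not visible in this view. For orientation only: service-account auth on GCP conventionally signs an RS256 JWT assertion and exchanges it at the OAuth token endpoint, which is the job the new jsonwebtoken dependency points at. A minimal signing sketch under that assumption; this is illustrative, not the contents of gcpauth.rs:

use jsonwebtoken::{encode, Algorithm, EncodingKey, Header};
use serde::Serialize;

// Assumed claim shape for the OAuth 2.0 JWT-bearer grant; the field values
// come from a downloaded service-account key file.
#[derive(Serialize)]
struct Claims {
    iss: String,   // service account email
    scope: String, // e.g. the cloud-platform scope
    aud: String,   // token endpoint URL
    iat: u64,
    exp: u64,
}

fn sign_assertion(sa_email: &str, private_key_pem: &[u8], now: u64) -> anyhow::Result<String> {
    let claims = Claims {
        iss: sa_email.to_string(),
        scope: "https://www.googleapis.com/auth/cloud-platform".to_string(),
        aud: "https://oauth2.googleapis.com/token".to_string(),
        iat: now,
        exp: now + 3600, // one-hour validity, the usual GCP maximum
    };
    // GCP expects RS256 over the service account's private key.
    Ok(encode(
        &Header::new(Algorithm::RS256),
        &claims,
        &EncodingKey::from_rsa_pem(private_key_pem)?,
    )?)
}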
crates/goose/src/providers/gcpvertexai.rs (new file, 595 lines)
@@ -0,0 +1,595 @@
use std::time::Duration;

use anyhow::Result;
use async_trait::async_trait;
use reqwest::{Client, StatusCode};
use serde_json::Value;
use tokio::time::sleep;
use url::Url;

use crate::message::Message;
use crate::model::ModelConfig;
use crate::providers::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage};

use crate::providers::errors::ProviderError;
use crate::providers::formats::gcpvertexai::{
    create_request, get_usage, response_to_message, ClaudeVersion, GcpVertexAIModel, GeminiVersion,
    ModelProvider, RequestContext,
};

use crate::providers::formats::gcpvertexai::GcpLocation::Iowa;
use crate::providers::gcpauth::GcpAuth;
use crate::providers::utils::emit_debug_trace;
use mcp_core::tool::Tool;

/// Base URL for GCP Vertex AI documentation
const GCP_VERTEX_AI_DOC_URL: &str = "https://cloud.google.com/vertex-ai";
/// Default timeout for API requests in seconds
const DEFAULT_TIMEOUT_SECS: u64 = 600;
/// Default initial interval for retry (in milliseconds)
const DEFAULT_INITIAL_RETRY_INTERVAL_MS: u64 = 5000;
/// Default maximum number of retries
const DEFAULT_MAX_RETRIES: usize = 6;
/// Default retry backoff multiplier
const DEFAULT_BACKOFF_MULTIPLIER: f64 = 2.0;
/// Default maximum interval for retry (in milliseconds)
const DEFAULT_MAX_RETRY_INTERVAL_MS: u64 = 320_000;

/// Represents errors specific to GCP Vertex AI operations.
#[derive(Debug, thiserror::Error)]
enum GcpVertexAIError {
    /// Error when URL construction fails
    #[error("Invalid URL configuration: {0}")]
    InvalidUrl(String),

    /// Error during GCP authentication
    #[error("Authentication error: {0}")]
    AuthError(String),
}

/// Retry configuration for handling rate limit errors
#[derive(Debug, Clone)]
struct RetryConfig {
    /// Maximum number of retry attempts
    max_retries: usize,
    /// Initial interval between retries in milliseconds
    initial_interval_ms: u64,
    /// Multiplier for backoff (exponential)
    backoff_multiplier: f64,
    /// Maximum interval between retries in milliseconds
    max_interval_ms: u64,
}

impl Default for RetryConfig {
    fn default() -> Self {
        Self {
            max_retries: DEFAULT_MAX_RETRIES,
            initial_interval_ms: DEFAULT_INITIAL_RETRY_INTERVAL_MS,
            backoff_multiplier: DEFAULT_BACKOFF_MULTIPLIER,
            max_interval_ms: DEFAULT_MAX_RETRY_INTERVAL_MS,
        }
    }
}

impl RetryConfig {
    /// Calculate the delay for a specific retry attempt (with jitter)
    fn delay_for_attempt(&self, attempt: usize) -> Duration {
        if attempt == 0 {
            return Duration::from_millis(0);
        }

        // Calculate exponential backoff
        let exponent = (attempt - 1) as u32;
        let base_delay_ms = (self.initial_interval_ms as f64
            * self.backoff_multiplier.powi(exponent as i32)) as u64;

        // Apply max limit
        let capped_delay_ms = std::cmp::min(base_delay_ms, self.max_interval_ms);

        // Add jitter (+/-20% randomness) to avoid thundering herd problem
        let jitter_factor = 0.8 + (rand::random::<f64>() * 0.4); // Between 0.8 and 1.2
        let jittered_delay_ms = (capped_delay_ms as f64 * jitter_factor) as u64;

        Duration::from_millis(jittered_delay_ms)
    }
}

/// Provider implementation for Google Cloud Platform's Vertex AI service.
///
/// This provider enables interaction with various AI models hosted on GCP Vertex AI,
/// including Claude and Gemini model families. It handles authentication, request routing,
/// and response processing for the Vertex AI API endpoints.
#[derive(Debug, serde::Serialize)]
pub struct GcpVertexAIProvider {
    /// HTTP client for making API requests
    #[serde(skip)]
    client: Client,
    /// GCP authentication handler
    #[serde(skip)]
    auth: GcpAuth,
    /// Base URL for the Vertex AI API
    host: String,
    /// GCP project identifier
    project_id: String,
    /// GCP region for model deployment
    location: String,
    /// Configuration for the specific model being used
    model: ModelConfig,
    /// Retry configuration for handling rate limit errors
    #[serde(skip)]
    retry_config: RetryConfig,
}

impl GcpVertexAIProvider {
    /// Creates a new provider instance from environment configuration.
    ///
    /// This is a convenience method that initializes the provider using
    /// environment variables and default settings.
    ///
    /// # Arguments
    /// * `model` - Configuration for the model to be used
    pub fn from_env(model: ModelConfig) -> Result<Self> {
        Self::new(model)
    }

    /// Creates a new provider instance with the specified model configuration.
    ///
    /// # Arguments
    /// * `model` - Configuration for the model to be used
    pub fn new(model: ModelConfig) -> Result<Self> {
        futures::executor::block_on(Self::new_async(model))
    }

    /// Async implementation of new provider instance creation.
    ///
    /// # Arguments
    /// * `model` - Configuration for the model to be used
    async fn new_async(model: ModelConfig) -> Result<Self> {
        let config = crate::config::Config::global();
        let project_id = config.get("GCP_PROJECT_ID")?;
        let location = Self::determine_location(config)?;
        let host = format!("https://{}-aiplatform.googleapis.com", location);

        let client = Client::builder()
            .timeout(Duration::from_secs(DEFAULT_TIMEOUT_SECS))
            .build()?;

        let auth = GcpAuth::new().await?;

        // Load optional retry configuration from environment
        let retry_config = Self::load_retry_config(config);

        Ok(Self {
            client,
            auth,
            host,
            project_id,
            location,
            model,
            retry_config,
        })
    }

    /// Loads retry configuration from environment variables or uses defaults.
    fn load_retry_config(config: &crate::config::Config) -> RetryConfig {
        let max_retries = config
            .get("GCP_MAX_RETRIES")
            .ok()
            .and_then(|v: String| v.parse::<usize>().ok())
            .unwrap_or(DEFAULT_MAX_RETRIES);

        let initial_interval_ms = config
            .get("GCP_INITIAL_RETRY_INTERVAL_MS")
            .ok()
            .and_then(|v: String| v.parse::<u64>().ok())
            .unwrap_or(DEFAULT_INITIAL_RETRY_INTERVAL_MS);

        let backoff_multiplier = config
            .get("GCP_BACKOFF_MULTIPLIER")
            .ok()
            .and_then(|v: String| v.parse::<f64>().ok())
            .unwrap_or(DEFAULT_BACKOFF_MULTIPLIER);

        let max_interval_ms = config
            .get("GCP_MAX_RETRY_INTERVAL_MS")
            .ok()
            .and_then(|v: String| v.parse::<u64>().ok())
            .unwrap_or(DEFAULT_MAX_RETRY_INTERVAL_MS);

        RetryConfig {
            max_retries,
            initial_interval_ms,
            backoff_multiplier,
            max_interval_ms,
        }
    }

    /// Determines the appropriate GCP location for model deployment.
    ///
    /// Location is determined in the following order:
    /// 1. Custom location from GCP_LOCATION environment variable
    /// 2. Global default location (Iowa)
    fn determine_location(config: &crate::config::Config) -> Result<String> {
        Ok(config
            .get("GCP_LOCATION")
            .ok()
            .filter(|location: &String| !location.trim().is_empty())
            .unwrap_or_else(|| Iowa.to_string()))
    }

    /// Retrieves an authentication token for API requests.
    async fn get_auth_header(&self) -> Result<String, GcpVertexAIError> {
        self.auth
            .get_token()
            .await
            .map(|token| format!("Bearer {}", token.token_value))
            .map_err(|e| GcpVertexAIError::AuthError(e.to_string()))
    }

    /// Constructs the appropriate API endpoint URL for a given provider.
    ///
    /// # Arguments
    /// * `provider` - The model provider (Anthropic or Google)
    /// * `location` - The GCP location for model deployment
    fn build_request_url(
        &self,
        provider: ModelProvider,
        location: &str,
    ) -> Result<Url, GcpVertexAIError> {
        // Create host URL for the specified location
        let host_url = if self.location == location {
            self.host.clone()
        } else {
            // Only allocate a new string if location differs
            self.host.replace(&self.location, location)
        };

        let base_url =
            Url::parse(&host_url).map_err(|e| GcpVertexAIError::InvalidUrl(e.to_string()))?;

        // Determine endpoint based on provider type
        let endpoint = match provider {
            ModelProvider::Anthropic => "streamRawPredict",
            ModelProvider::Google => "generateContent",
        };

        // Construct path for URL
        let path = format!(
            "v1/projects/{}/locations/{}/publishers/{}/models/{}:{}",
            self.project_id,
            location,
            provider.as_str(),
            self.model.model_name,
            endpoint
        );

        base_url
            .join(&path)
            .map_err(|e| GcpVertexAIError::InvalidUrl(e.to_string()))
    }

    /// Makes an authenticated POST request to the Vertex AI API at a specific location.
    /// Includes retry logic for 429 Too Many Requests errors.
    ///
    /// # Arguments
    /// * `payload` - The request payload to send
    /// * `context` - Request context containing model information
    /// * `location` - The GCP location for the request
    async fn post_with_location(
        &self,
        payload: &Value,
        context: &RequestContext,
        location: &str,
    ) -> Result<Value, ProviderError> {
        let url = self
            .build_request_url(context.provider(), location)
            .map_err(|e| ProviderError::RequestFailed(e.to_string()))?;

        // Initialize retry counter
        let mut attempts = 0;
        let mut last_error = None;

        loop {
            // Check if we've exceeded max retries
            if attempts > 0 && attempts > self.retry_config.max_retries {
                let error_msg = format!(
                    "Exceeded maximum retry attempts ({}) for rate limiting (429)",
                    self.retry_config.max_retries
                );
                tracing::error!("{}", error_msg);
                return Err(last_error.unwrap_or(ProviderError::RateLimitExceeded(error_msg)));
            }

            // Get a fresh auth token for each attempt
            let auth_header = self
                .get_auth_header()
                .await
                .map_err(|e| ProviderError::Authentication(e.to_string()))?;

            // Make the request
            let response = self
                .client
                .post(url.clone())
                .json(payload)
                .header("Authorization", auth_header)
                .send()
                .await
                .map_err(|e| ProviderError::RequestFailed(e.to_string()))?;

            let status = response.status();

            // If not a 429, process normally
            if status != StatusCode::TOO_MANY_REQUESTS {
                let response_json = response.json::<Value>().await.map_err(|e| {
                    ProviderError::RequestFailed(format!("Failed to parse response: {e}"))
                })?;

                return match status {
                    StatusCode::OK => Ok(response_json),
                    StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => {
                        tracing::debug!(
                            "Authentication failed. Status: {status}, Payload: {payload:?}"
                        );
                        Err(ProviderError::Authentication(format!(
                            "Authentication failed: {response_json:?}"
                        )))
                    }
                    _ => {
                        tracing::debug!(
                            "Request failed. Status: {status}, Response: {response_json:?}"
                        );
                        Err(ProviderError::RequestFailed(format!(
                            "Request failed with status {status}: {response_json:?}"
                        )))
                    }
                };
            }

            // Handle 429 Too Many Requests
            attempts += 1;

            // Try to parse response for more detailed error info
            let cite_gcp_vertex_429 =
                "See https://cloud.google.com/vertex-ai/generative-ai/docs/error-code-429";
            let response_text = response.text().await.unwrap_or_default();
            let quota_error = if response_text.contains("Exceeded the Provisioned Throughput") {
                format!("Exceeded the Provisioned Throughput: {cite_gcp_vertex_429}.")
            } else {
                format!("Pay-as-you-go resource exhausted: {cite_gcp_vertex_429}.")
            };

            tracing::warn!(
                "Rate limit exceeded (attempt {}/{}): {}. Retrying after backoff...",
                attempts,
                self.retry_config.max_retries,
                quota_error
            );

            // Store the error in case we need to return it after max retries
            last_error = Some(ProviderError::RateLimitExceeded(quota_error));

            // Calculate and apply the backoff delay
            let delay = self.retry_config.delay_for_attempt(attempts);
            tracing::info!("Backing off for {:?} before retry", delay);
            sleep(delay).await;
        }
    }

    /// Makes an authenticated POST request to the Vertex AI API with fallback for invalid locations.
    ///
    /// # Arguments
    /// * `payload` - The request payload to send
    /// * `context` - Request context containing model information
    async fn post(&self, payload: Value, context: &RequestContext) -> Result<Value, ProviderError> {
        // Try with user-specified location first
        let result = self
            .post_with_location(&payload, context, &self.location)
            .await;

        // If location is already the known location for the model or request succeeded, return result
        if self.location == context.model.known_location().to_string() || result.is_ok() {
            return result;
        }

        // Check if we should retry with the model's known location
        match &result {
            Err(ProviderError::RequestFailed(msg)) => {
                let model_name = context.model.to_string();
                let configured_location = &self.location;
                let known_location = context.model.known_location().to_string();

                tracing::error!(
                    "Trying known location {known_location} for {model_name} instead of {configured_location}: {msg}"
                );

                self.post_with_location(&payload, context, &known_location)
                    .await
            }
            // For any other error, return the original result
            _ => result,
        }
    }
}

impl Default for GcpVertexAIProvider {
    fn default() -> Self {
        let model = ModelConfig::new(Self::metadata().default_model);
        Self::new(model).expect("Failed to initialize VertexAI provider")
    }
}

#[async_trait]
impl Provider for GcpVertexAIProvider {
    /// Returns metadata about the GCP Vertex AI provider.
    fn metadata() -> ProviderMetadata
    where
        Self: Sized,
    {
        let known_models = vec![
            GcpVertexAIModel::Claude(ClaudeVersion::Sonnet35),
            GcpVertexAIModel::Claude(ClaudeVersion::Sonnet35V2),
            GcpVertexAIModel::Claude(ClaudeVersion::Sonnet37),
            GcpVertexAIModel::Claude(ClaudeVersion::Haiku35),
            GcpVertexAIModel::Gemini(GeminiVersion::Pro15),
            GcpVertexAIModel::Gemini(GeminiVersion::Flash20),
            GcpVertexAIModel::Gemini(GeminiVersion::Pro20Exp),
        ]
        .into_iter()
        .map(|model| model.to_string())
        .collect();

        ProviderMetadata::new(
            "gcp_vertex_ai",
            "GCP Vertex AI",
            "Access variety of AI models such as Claude, Gemini through Vertex AI",
            GcpVertexAIModel::Gemini(GeminiVersion::Flash20)
                .to_string()
                .as_str(),
            known_models,
            GCP_VERTEX_AI_DOC_URL,
            vec![
                ConfigKey::new("GCP_PROJECT_ID", true, false, None),
                ConfigKey::new("GCP_LOCATION", true, false, Some(Iowa.to_string().as_str())),
                ConfigKey::new(
                    "GCP_MAX_RETRIES",
                    false,
                    false,
                    Some(&DEFAULT_MAX_RETRIES.to_string()),
                ),
                ConfigKey::new(
                    "GCP_INITIAL_RETRY_INTERVAL_MS",
                    false,
                    false,
                    Some(&DEFAULT_INITIAL_RETRY_INTERVAL_MS.to_string()),
                ),
                ConfigKey::new(
                    "GCP_BACKOFF_MULTIPLIER",
                    false,
                    false,
                    Some(&DEFAULT_BACKOFF_MULTIPLIER.to_string()),
                ),
                ConfigKey::new(
                    "GCP_MAX_RETRY_INTERVAL_MS",
                    false,
                    false,
                    Some(&DEFAULT_MAX_RETRY_INTERVAL_MS.to_string()),
                ),
            ],
        )
    }

    /// Completes a model interaction by sending a request and processing the response.
    ///
    /// # Arguments
    /// * `system` - System prompt or context
    /// * `messages` - Array of previous messages in the conversation
    /// * `tools` - Array of available tools for the model
    #[tracing::instrument(
        skip(self, system, messages, tools),
        fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
    )]
    async fn complete(
        &self,
        system: &str,
        messages: &[Message],
        tools: &[Tool],
    ) -> Result<(Message, ProviderUsage), ProviderError> {
        // Create request and context
        let (request, context) = create_request(&self.model, system, messages, tools)?;

        // Send request and process response
        let response = self.post(request.clone(), &context).await?;
        let usage = get_usage(&response, &context)?;

        emit_debug_trace(self, &request, &response, &usage);

        // Convert response to message
        let message = response_to_message(response, context)?;
        let provider_usage = ProviderUsage::new(self.model.model_name.clone(), usage);

        Ok((message, provider_usage))
    }

    /// Returns the current model configuration.
    fn get_model_config(&self) -> ModelConfig {
        self.model.clone()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_retry_config_delay_calculation() {
        let config = RetryConfig {
            max_retries: 5,
            initial_interval_ms: 1000,
            backoff_multiplier: 2.0,
            max_interval_ms: 32000,
        };

        // First attempt has no delay
        let delay0 = config.delay_for_attempt(0);
        assert_eq!(delay0.as_millis(), 0);

        // First retry should be around initial_interval with jitter
        let delay1 = config.delay_for_attempt(1);
        assert!(delay1.as_millis() >= 800 && delay1.as_millis() <= 1200);

        // Second retry should be around initial_interval * multiplier^1 with jitter
        let delay2 = config.delay_for_attempt(2);
        assert!(delay2.as_millis() >= 1600 && delay2.as_millis() <= 2400);

        // Check that max interval is respected
        let delay10 = config.delay_for_attempt(10);
        assert!(delay10.as_millis() <= 38400); // max_interval_ms * 1.2 (max jitter)
    }

    #[test]
    fn test_model_provider_conversion() {
        assert_eq!(ModelProvider::Anthropic.as_str(), "anthropic");
        assert_eq!(ModelProvider::Google.as_str(), "google");
    }

    #[test]
    fn test_url_construction() {
        use url::Url;

        let model_config = ModelConfig::new("claude-3-5-sonnet-v2@20241022".to_string());
        let context = RequestContext::new(&model_config.model_name).unwrap();
        let api_model_id = context.model.to_string();

        let host = "https://us-east5-aiplatform.googleapis.com";
        let project_id = "test-project";
        let location = "us-east5";

        let path = format!(
            "v1/projects/{}/locations/{}/publishers/{}/models/{}:{}",
            project_id,
            location,
            ModelProvider::Anthropic.as_str(),
            api_model_id,
            "streamRawPredict"
        );

        let url = Url::parse(host).unwrap().join(&path).unwrap();

        assert!(url.as_str().contains("publishers/anthropic"));
        assert!(url.as_str().contains("projects/test-project"));
        assert!(url.as_str().contains("locations/us-east5"));
    }

    #[test]
    fn test_provider_metadata() {
        let metadata = GcpVertexAIProvider::metadata();
        assert!(metadata
            .known_models
            .contains(&"claude-3-5-sonnet-v2@20241022".to_string()));
        assert!(metadata
            .known_models
            .contains(&"gemini-1.5-pro-002".to_string()));
        // Should contain the original 2 config keys plus 4 new retry-related ones
        assert_eq!(metadata.config_keys.len(), 6);
    }
}
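For intuition about the retry settings above: with the defaults (5 s initial interval, 2.0 multiplier, 320 s cap), the pre-jitter backoff schedule doubles on each attempt. A tiny sketch that mirrors delay_for_attempt without the random +/-20% jitter:

fn base_delay_ms(attempt: u32) -> u64 {
    // Same arithmetic as RetryConfig::delay_for_attempt, minus jitter.
    let (initial, multiplier, cap) = (5_000f64, 2.0f64, 320_000u64);
    if attempt == 0 {
        return 0;
    }
    ((initial * multiplier.powi(attempt as i32 - 1)) as u64).min(cap)
}

fn main() {
    // Prints 5000, 10000, 20000, 40000, 80000, 160000 for the six default retries.
    for attempt in 1..=6 {
        println!("retry {attempt}: {} ms", base_delay_ms(attempt));
    }
}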
(filename not captured in mirror: providers module declarations)
@@ -6,6 +6,8 @@ pub mod databricks;
 pub mod errors;
 mod factory;
 pub mod formats;
+mod gcpauth;
+pub mod gcpvertexai;
 pub mod google;
 pub mod groq;
 pub mod oauth;
(filename not captured in mirror: truncate agent integration tests)
@@ -6,12 +6,12 @@ use goose::agents::AgentFactory;
 use goose::message::Message;
 use goose::model::ModelConfig;
 use goose::providers::base::Provider;
-use goose::providers::{anthropic::AnthropicProvider, databricks::DatabricksProvider};
 use goose::providers::{
-    azure::AzureProvider, bedrock::BedrockProvider, ollama::OllamaProvider, openai::OpenAiProvider,
+    anthropic::AnthropicProvider, azure::AzureProvider, bedrock::BedrockProvider,
+    databricks::DatabricksProvider, gcpvertexai::GcpVertexAIProvider, google::GoogleProvider,
+    groq::GroqProvider, ollama::OllamaProvider, openai::OpenAiProvider,
     openrouter::OpenRouterProvider,
 };
-use goose::providers::{google::GoogleProvider, groq::GroqProvider};

 #[derive(Debug, PartialEq)]
 enum ProviderType {
@@ -20,6 +20,7 @@ enum ProviderType {
     Anthropic,
     Bedrock,
     Databricks,
+    GcpVertexAI,
     Google,
     Groq,
     Ollama,
@@ -42,6 +43,7 @@ impl ProviderType {
             ProviderType::Groq => &["GROQ_API_KEY"],
             ProviderType::Ollama => &[],
             ProviderType::OpenRouter => &["OPENROUTER_API_KEY"],
+            ProviderType::GcpVertexAI => &["GCP_PROJECT_ID", "GCP_LOCATION"],
         }
     }

@@ -70,6 +72,7 @@ impl ProviderType {
             ProviderType::Anthropic => Box::new(AnthropicProvider::from_env(model_config)?),
             ProviderType::Bedrock => Box::new(BedrockProvider::from_env(model_config)?),
             ProviderType::Databricks => Box::new(DatabricksProvider::from_env(model_config)?),
+            ProviderType::GcpVertexAI => Box::new(GcpVertexAIProvider::from_env(model_config)?),
             ProviderType::Google => Box::new(GoogleProvider::from_env(model_config)?),
             ProviderType::Groq => Box::new(GroqProvider::from_env(model_config)?),
             ProviderType::Ollama => Box::new(OllamaProvider::from_env(model_config)?),
@@ -290,4 +293,14 @@ mod tests {
         })
         .await
     }
+
+    #[tokio::test]
+    async fn test_truncate_agent_with_gcpvertexai() -> Result<()> {
+        run_test_with_config(TestConfig {
+            provider_type: ProviderType::GcpVertexAI,
+            model: "claude-3-5-sonnet-v2@20241022",
+            context_window: 200_000,
+        })
+        .await
+    }
 }
@@ -18,12 +18,13 @@ Goose relies heavily on tool calling capabilities and currently works best with
 ## Available Providers
 
 | Provider | Description | Parameters |
 |----------|-------------|------------|
 | [Amazon Bedrock](https://aws.amazon.com/bedrock/) | Offers a variety of foundation models, including Claude, Jurassic-2, and others. **Environment variables must be set in advance, not configured through `goose configure`** | `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, `AWS_REGION` |
 | [Anthropic](https://www.anthropic.com/) | Offers Claude, an advanced AI model for natural language tasks. | `ANTHROPIC_API_KEY` |
 | [Azure OpenAI](https://learn.microsoft.com/en-us/azure/ai-services/openai/) | Access Azure-hosted OpenAI models, including GPT-4 and GPT-3.5. | `AZURE_OPENAI_API_KEY`, `AZURE_OPENAI_ENDPOINT`, `AZURE_OPENAI_DEPLOYMENT_NAME` |
 | [Databricks](https://www.databricks.com/) | Unified data analytics and AI platform for building and deploying models. | `DATABRICKS_HOST`, `DATABRICKS_TOKEN` |
 | [Gemini](https://ai.google.dev/gemini-api/docs) | Advanced LLMs by Google with multimodal capabilities (text, images). | `GOOGLE_API_KEY` |
+| [GCP Vertex AI](https://cloud.google.com/vertex-ai) | Google Cloud's Vertex AI platform, supporting Gemini and Claude models. **Credentials must be configured in advance; follow the instructions at https://cloud.google.com/vertex-ai/docs/authentication.** | `GCP_PROJECT_ID`, `GCP_LOCATION`; optional retry tuning: `GCP_MAX_RETRIES` (default 6), `GCP_INITIAL_RETRY_INTERVAL_MS` (default 5000), `GCP_BACKOFF_MULTIPLIER` (default 2.0), `GCP_MAX_RETRY_INTERVAL_MS` (default 320000) |
 | [Groq](https://groq.com/) | High-performance inference hardware and tools for LLMs. | `GROQ_API_KEY` |
 | [Ollama](https://ollama.com/) | Local model runner supporting Qwen, Llama, DeepSeek, and other open-source models. **Because this provider runs locally, you must first [download and run a model](/docs/getting-started/providers#local-llms-ollama).** | `OLLAMA_HOST` |
 | [OpenAI](https://platform.openai.com/api-keys) | Provides gpt-4o, o1, and other advanced language models. Also supports OpenAI-compatible endpoints (e.g., self-hosted LLaMA, vLLM, KServe). **o1-mini and o1-preview are not supported because Goose uses tool calling.** | `OPENAI_API_KEY`, `OPENAI_HOST` (optional), `OPENAI_ORGANIZATION` (optional), `OPENAI_PROJECT` (optional) |
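The four optional variables in the new row describe a capped exponential backoff. As a sanity check on the documented defaults, here is a small sketch computing the implied delay schedule (the parameter names mirror the env vars; the provider's actual retry loop is not shown in this diff):

```rust
use std::time::Duration;

/// Delay before each retry: initial * multiplier^attempt, capped at max.
fn backoff_schedule(
    max_retries: u32,
    initial_ms: u64,
    multiplier: f64,
    max_ms: u64,
) -> Vec<Duration> {
    (0..max_retries)
        .map(|attempt| {
            let raw = initial_ms as f64 * multiplier.powi(attempt as i32);
            Duration::from_millis((raw as u64).min(max_ms))
        })
        .collect()
}

fn main() {
    // Defaults from the table: 6 retries, 5000 ms initial, 2.0 multiplier,
    // 320000 ms cap -> 5s, 10s, 20s, 40s, 80s, 160s.
    for (i, d) in backoff_schedule(6, 5_000, 2.0, 320_000).iter().enumerate() {
        println!("retry {}: wait {:?}", i + 1, d);
    }
}
```

Note that with these defaults the 320000 ms ceiling is never actually reached; it would only bind if the retry count were raised to a seventh attempt.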
@@ -11,6 +11,8 @@ export function isSecretKey(keyName: string): boolean {
     'OPENAI_BASE_PATH',
     'AZURE_OPENAI_ENDPOINT',
     'AZURE_OPENAI_DEPLOYMENT_NAME',
+    'GCP_PROJECT_ID',
+    'GCP_LOCATION',
   ];
   return !nonSecretKeys.includes(keyName);
 }
@@ -19,6 +19,13 @@ export const goose_models: Model[] = [
   { id: 17, name: 'qwen2.5', provider: 'Ollama' },
   { id: 18, name: 'anthropic/claude-3.5-sonnet', provider: 'OpenRouter' },
   { id: 19, name: 'gpt-4o', provider: 'Azure OpenAI' },
+  { id: 20, name: 'claude-3-7-sonnet@20250219', provider: 'GCP Vertex AI' },
+  { id: 21, name: 'claude-3-5-sonnet-v2@20241022', provider: 'GCP Vertex AI' },
+  { id: 22, name: 'claude-3-5-sonnet@20240620', provider: 'GCP Vertex AI' },
+  { id: 23, name: 'claude-3-5-haiku@20241022', provider: 'GCP Vertex AI' },
+  { id: 24, name: 'gemini-2.0-pro-exp-02-05', provider: 'GCP Vertex AI' },
+  { id: 25, name: 'gemini-2.0-flash-001', provider: 'GCP Vertex AI' },
+  { id: 26, name: 'gemini-1.5-pro-002', provider: 'GCP Vertex AI' },
 ];
 
 export const openai_models = ['gpt-4o-mini', 'gpt-4o', 'gpt-4-turbo', 'o1'];
@@ -47,6 +54,16 @@ export const openrouter_models = ['anthropic/claude-3.5-sonnet'];
 
 export const azure_openai_models = ['gpt-4o'];
 
+export const gcp_vertex_ai_models = [
+  'claude-3-7-sonnet@20250219',
+  'claude-3-5-sonnet-v2@20241022',
+  'claude-3-5-sonnet@20240620',
+  'claude-3-5-haiku@20241022',
+  'gemini-1.5-pro-002',
+  'gemini-2.0-flash-001',
+  'gemini-2.0-pro-exp-02-05',
+];
+
 export const default_models = {
   openai: 'gpt-4o',
   anthropic: 'claude-3-5-sonnet-latest',
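These IDs follow Vertex AI's publisher-model convention: the `claude-*@<date>` models live under the `anthropic` publisher, the `gemini-*` models under `google`. A hedged sketch of how such an ID might map onto a Vertex endpoint URL — the exact path and request verb the provider uses are not shown in this diff:

```rust
/// Guess the Vertex AI publisher from the model ID (illustrative heuristic).
fn publisher_for(model: &str) -> &'static str {
    if model.starts_with("claude") {
        "anthropic"
    } else {
        "google"
    }
}

/// Build a plausible publisher-model URL; real calls append a verb such as
/// `:generateContent` or `:rawPredict`, which differs by publisher.
fn vertex_model_url(project: &str, location: &str, model: &str) -> String {
    format!(
        "https://{location}-aiplatform.googleapis.com/v1/projects/{project}/locations/{location}/publishers/{publisher}/models/{model}",
        publisher = publisher_for(model)
    )
}

fn main() {
    // "my-project" is a placeholder; GCP_LOCATION defaults to us-central1 below.
    println!("{}", vertex_model_url("my-project", "us-central1", "gemini-2.0-flash-001"));
}
```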
@@ -56,6 +73,7 @@ export const default_models = {
   openrouter: 'anthropic/claude-3.5-sonnet',
   ollama: 'qwen2.5',
   azure_openai: 'gpt-4o',
+  gcp_vertex_ai: 'gemini-2.0-flash-001',
 };
 
 export function getDefaultModel(key: string): string | undefined {
@@ -73,12 +91,14 @@ export const required_keys = {
   Google: ['GOOGLE_API_KEY'],
   OpenRouter: ['OPENROUTER_API_KEY'],
   'Azure OpenAI': ['AZURE_OPENAI_API_KEY', 'AZURE_OPENAI_ENDPOINT', 'AZURE_OPENAI_DEPLOYMENT_NAME'],
+  'GCP Vertex AI': ['GCP_PROJECT_ID', 'GCP_LOCATION'],
 };
 
 export const default_key_value = {
   OPENAI_HOST: 'https://api.openai.com',
   OPENAI_BASE_PATH: 'v1/chat/completions',
   OLLAMA_HOST: 'localhost',
+  GCP_LOCATION: 'us-central1',
 };
 
 export const supported_providers = [
@@ -90,6 +110,7 @@ export const supported_providers = [
   'Ollama',
   'OpenRouter',
   'Azure OpenAI',
+  'GCP Vertex AI',
 ];
 
 export const model_docs_link = [
@@ -103,6 +124,7 @@ export const model_docs_link = [
   },
   { name: 'OpenRouter', href: 'https://openrouter.ai/models' },
   { name: 'Ollama', href: 'https://ollama.com/library' },
+  { name: 'GCP Vertex AI', href: 'https://cloud.google.com/vertex-ai' },
 ];
 
 export const provider_aliases = [
@@ -114,4 +136,5 @@ export const provider_aliases = [
   { provider: 'OpenRouter', alias: 'openrouter' },
   { provider: 'Google', alias: 'google' },
   { provider: 'Azure OpenAI', alias: 'azure_openai' },
+  { provider: 'GCP Vertex AI', alias: 'gcp_vertex_ai' },
 ];