From c99d5beb07b9bb94a786cfbbd62416bb6bb661d8 Mon Sep 17 00:00:00 2001 From: Alice Hau Date: Thu, 1 May 2025 14:09:37 -0400 Subject: [PATCH] feat: Azure credential chain logging (#2413) Co-authored-by: Michael Paquette Co-authored-by: Alice Hau --- .../src/routes/providers_and_keys.json | 4 +- crates/goose/src/providers/azure.rs | 149 +++++++++++++-- crates/goose/src/providers/azureauth.rs | 170 ++++++++++++++++++ crates/goose/src/providers/mod.rs | 1 + .../docs/getting-started/providers.md | 16 +- .../settings/models/hardcoded_stuff.tsx | 1 - .../providers/ProviderRegistry.tsx | 3 +- 7 files changed, 324 insertions(+), 20 deletions(-) create mode 100644 crates/goose/src/providers/azureauth.rs diff --git a/crates/goose-server/src/routes/providers_and_keys.json b/crates/goose-server/src/routes/providers_and_keys.json index f689e16c..830cf665 100644 --- a/crates/goose-server/src/routes/providers_and_keys.json +++ b/crates/goose-server/src/routes/providers_and_keys.json @@ -49,9 +49,9 @@ }, "azure_openai": { "name": "Azure OpenAI", - "description": "Connect to Azure OpenAI Service", + "description": "Connect to Azure OpenAI Service. 
If no API key is provided, Azure credential chain will be used.", "models": ["gpt-4o", "gpt-4o-mini"], - "required_keys": ["AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOYMENT_NAME"] + "required_keys": ["AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOYMENT_NAME"] }, "aws_bedrock": { "name": "AWS Bedrock", diff --git a/crates/goose/src/providers/azure.rs b/crates/goose/src/providers/azure.rs index 0ef9203e..51a31c06 100644 --- a/crates/goose/src/providers/azure.rs +++ b/crates/goose/src/providers/azure.rs @@ -1,9 +1,12 @@ use anyhow::Result; use async_trait::async_trait; use reqwest::Client; +use serde::Serialize; use serde_json::Value; use std::time::Duration; +use tokio::time::sleep; +use super::azureauth::AzureAuth; use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage}; use super::errors::ProviderError; use super::formats::openai::{create_request, get_usage, response_to_message}; @@ -18,17 +21,36 @@ pub const AZURE_DOC_URL: &str = pub const AZURE_DEFAULT_API_VERSION: &str = "2024-10-21"; pub const AZURE_OPENAI_KNOWN_MODELS: &[&str] = &["gpt-4o", "gpt-4o-mini", "gpt-4"]; -#[derive(Debug, serde::Serialize)] +// Default retry configuration +const DEFAULT_MAX_RETRIES: usize = 5; +const DEFAULT_INITIAL_RETRY_INTERVAL_MS: u64 = 1000; // Start with 1 second +const DEFAULT_MAX_RETRY_INTERVAL_MS: u64 = 32000; // Max 32 seconds +const DEFAULT_BACKOFF_MULTIPLIER: f64 = 2.0; + +#[derive(Debug)] pub struct AzureProvider { - #[serde(skip)] client: Client, + auth: AzureAuth, endpoint: String, - api_key: String, deployment_name: String, api_version: String, model: ModelConfig, } +impl Serialize for AzureProvider { + fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut state = serializer.serialize_struct("AzureProvider", 3)?; + state.serialize_field("endpoint", &self.endpoint)?; + state.serialize_field("deployment_name", &self.deployment_name)?; + 
state.serialize_field("api_version", &self.api_version)?; + state.end() + } +} + impl Default for AzureProvider { fn default() -> Self { let model = ModelConfig::new(AzureProvider::metadata().default_model); @@ -39,13 +61,16 @@ impl AzureProvider { pub fn from_env(model: ModelConfig) -> Result<Self> { let config = crate::config::Config::global(); - let api_key: String = config.get_secret("AZURE_OPENAI_API_KEY")?; let endpoint: String = config.get_param("AZURE_OPENAI_ENDPOINT")?; let deployment_name: String = config.get_param("AZURE_OPENAI_DEPLOYMENT_NAME")?; let api_version: String = config .get_param("AZURE_OPENAI_API_VERSION") .unwrap_or_else(|_| AZURE_DEFAULT_API_VERSION.to_string()); + // Try to get API key first, if not found use Azure credential chain + let api_key = config.get_secret("AZURE_OPENAI_API_KEY").ok(); + let auth = AzureAuth::new(api_key)?; + let client = Client::builder() .timeout(Duration::from_secs(600)) .build()?; @@ -53,7 +78,7 @@ impl AzureProvider { Ok(Self { client, endpoint, - api_key, + auth, deployment_name, api_version, model, @@ -81,15 +106,110 @@ impl AzureProvider { base_url.set_path(&new_path); base_url.set_query(Some(&format!("api-version={}", self.api_version))); - let response: reqwest::Response = self - .client - .post(base_url) - .header("api-key", &self.api_key) - .json(&payload) - .send() - .await?; + let mut attempts = 0; + let mut last_error = None; + let mut current_delay = DEFAULT_INITIAL_RETRY_INTERVAL_MS; - handle_response_openai_compat(response).await + loop { + // Check if we've exceeded max retries + if attempts > DEFAULT_MAX_RETRIES { + let error_msg = format!( + "Exceeded maximum retry attempts ({}) for rate limiting", + DEFAULT_MAX_RETRIES + ); + tracing::error!("{}", error_msg); + return Err(last_error.unwrap_or(ProviderError::RateLimitExceeded(error_msg))); + } + + // Get a fresh auth token for each attempt + let auth_token = self.auth.get_token().await.map_err(|e| { + 
tracing::error!("Authentication error: {:?}", e); + ProviderError::RequestFailed(format!("Failed to get authentication token: {}", e)) + })?; + + let mut request_builder = self.client.post(base_url.clone()); + let token_value = auth_token.token_value.clone(); + + // Set the correct header based on authentication type + match self.auth.credential_type() { + super::azureauth::AzureCredentials::ApiKey(_) => { + request_builder = request_builder.header("api-key", token_value.clone()); + } + super::azureauth::AzureCredentials::DefaultCredential => { + request_builder = request_builder + .header("Authorization", format!("Bearer {}", token_value.clone())); + } + } + + let response_result = request_builder.json(&payload).send().await; + + match response_result { + Ok(response) => match handle_response_openai_compat(response).await { + Ok(result) => { + return Ok(result); + } + Err(ProviderError::RateLimitExceeded(msg)) => { + attempts += 1; + last_error = Some(ProviderError::RateLimitExceeded(msg.clone())); + + let retry_after = + if let Some(secs) = msg.to_lowercase().find("try again in ") { + msg[secs..] 
+ .split_whitespace() + .nth(3) + .and_then(|s| s.parse::<u64>().ok()) + .unwrap_or(0) + } else { + 0 + }; + + let delay = if retry_after > 0 { + Duration::from_secs(retry_after) + } else { + let delay = current_delay.min(DEFAULT_MAX_RETRY_INTERVAL_MS); + current_delay = + (current_delay as f64 * DEFAULT_BACKOFF_MULTIPLIER) as u64; + Duration::from_millis(delay) + }; + + sleep(delay).await; + continue; + } + Err(e) => { + tracing::error!( + "Error response from Azure OpenAI (attempt {}): {:?}", + attempts + 1, + e + ); + return Err(e); + } + }, + Err(e) => { + tracing::error!( + "Request failed (attempt {}): {:?}\nIs timeout: {}\nIs connect: {}\nIs request: {}", + attempts + 1, + e, + e.is_timeout(), + e.is_connect(), + e.is_request(), + ); + + // For timeout errors, we should retry + if e.is_timeout() { + attempts += 1; + let delay = current_delay.min(DEFAULT_MAX_RETRY_INTERVAL_MS); + current_delay = (current_delay as f64 * DEFAULT_BACKOFF_MULTIPLIER) as u64; + sleep(Duration::from_millis(delay)).await; + continue; + } + + return Err(ProviderError::RequestFailed(format!( + "Request failed: {}", + e + ))); + } + } + } } } @@ -99,12 +219,11 @@ impl Provider for AzureProvider { ProviderMetadata::new( "azure_openai", "Azure OpenAI", - "Models through Azure OpenAI Service", + "Models through Azure OpenAI Service (uses Azure credential chain by default)", "gpt-4o", AZURE_OPENAI_KNOWN_MODELS.to_vec(), AZURE_DOC_URL, vec![ - ConfigKey::new("AZURE_OPENAI_API_KEY", true, true, None), ConfigKey::new("AZURE_OPENAI_ENDPOINT", true, false, None), ConfigKey::new("AZURE_OPENAI_DEPLOYMENT_NAME", true, false, None), ConfigKey::new("AZURE_OPENAI_API_VERSION", true, false, Some("2024-10-21")), diff --git a/crates/goose/src/providers/azureauth.rs b/crates/goose/src/providers/azureauth.rs new file mode 100644 index 00000000..be7e39f4 --- /dev/null +++ b/crates/goose/src/providers/azureauth.rs @@ -0,0 +1,170 @@ +use chrono; +use serde::Deserialize; +use std::sync::Arc; +use 
std::time::{Duration, Instant}; +use tokio::sync::RwLock; + +/// Represents errors that can occur during Azure authentication. +#[derive(Debug, thiserror::Error)] +pub enum AuthError { + /// Error when loading credentials from the filesystem or environment + #[error("Failed to load credentials: {0}")] + Credentials(String), + + /// Error during token exchange + #[error("Token exchange failed: {0}")] + TokenExchange(String), +} + +/// Represents an authentication token with its type and value. +#[derive(Debug, Clone)] +pub struct AuthToken { + /// The type of the token (e.g., "Bearer") + pub token_type: String, + /// The actual token value + pub token_value: String, +} + +/// Represents the types of Azure credentials supported. +#[derive(Debug, Clone)] +pub enum AzureCredentials { + /// API key based authentication + ApiKey(String), + /// Azure credential chain based authentication + DefaultCredential, +} + +/// Holds a cached token and its expiration time. +#[derive(Debug, Clone)] +struct CachedToken { + token: AuthToken, + expires_at: Instant, +} + +/// Response from Azure token endpoint +#[derive(Debug, Clone, Deserialize)] +struct TokenResponse { + #[serde(rename = "accessToken")] + access_token: String, + #[serde(rename = "tokenType")] + token_type: String, + #[serde(rename = "expires_on")] + expires_on: u64, +} + +/// Azure authentication handler that manages credentials and token caching. +#[derive(Debug)] +pub struct AzureAuth { + credentials: AzureCredentials, + cached_token: Arc<RwLock<Option<CachedToken>>>, +} + +impl AzureAuth { + /// Creates a new Azure authentication handler. + /// + /// Initializes the authentication handler by: + /// 1. Loading credentials from environment + /// 2. Setting up an HTTP client for token requests + /// 3. 
Initializing the token cache + /// + /// # Returns + /// * `Result<Self, AuthError>` - A new AzureAuth instance or an error if initialization fails + pub fn new(api_key: Option<String>) -> Result<Self, AuthError> { + let credentials = match api_key { + Some(key) => AzureCredentials::ApiKey(key), + None => AzureCredentials::DefaultCredential, + }; + + Ok(Self { + credentials, + cached_token: Arc::new(RwLock::new(None)), + }) + } + + /// Returns the type of credentials being used. + pub fn credential_type(&self) -> &AzureCredentials { + &self.credentials + } + + /// Retrieves a valid authentication token. + /// + /// This method implements an efficient token management strategy: + /// 1. For API key auth, returns the API key directly + /// 2. For Azure credential chain: + /// a. Checks the cache for a valid token + /// b. Returns the cached token if not expired + /// c. Obtains a new token if needed or expired + /// d. Uses double-checked locking for thread safety + /// + /// # Returns + /// * `Result<AuthToken, AuthError>` - A valid authentication token or an error + pub async fn get_token(&self) -> Result<AuthToken, AuthError> { + match &self.credentials { + AzureCredentials::ApiKey(key) => Ok(AuthToken { + token_type: "Bearer".to_string(), + token_value: key.clone(), + }), + AzureCredentials::DefaultCredential => self.get_default_credential_token().await, + } + } + + async fn get_default_credential_token(&self) -> Result<AuthToken, AuthError> { + // Try read lock first for better concurrency + if let Some(cached) = self.cached_token.read().await.as_ref() { + if cached.expires_at > Instant::now() { + return Ok(cached.token.clone()); + } + } + + // Take write lock only if needed + let mut token_guard = self.cached_token.write().await; + + // Double-check expiration after acquiring write lock + if let Some(cached) = token_guard.as_ref() { + if cached.expires_at > Instant::now() { + return Ok(cached.token.clone()); + } + } + + // Get new token using Azure CLI credential + let output = tokio::process::Command::new("az") + .args([ + "account", + "get-access-token", + 
"--resource", + "https://cognitiveservices.azure.com", + ]) + .output() + .await + .map_err(|e| AuthError::TokenExchange(format!("Failed to execute Azure CLI: {}", e)))?; + + if !output.status.success() { + return Err(AuthError::TokenExchange( + String::from_utf8_lossy(&output.stderr).to_string(), + )); + } + + let token_response: TokenResponse = serde_json::from_slice(&output.stdout) + .map_err(|e| AuthError::TokenExchange(format!("Invalid token response: {}", e)))?; + + let auth_token = AuthToken { + token_type: token_response.token_type, + token_value: token_response.access_token, + }; + + let expires_at = Instant::now() + + Duration::from_secs( + token_response + .expires_on + .saturating_sub(chrono::Utc::now().timestamp() as u64) + .saturating_sub(30), + ); + + *token_guard = Some(CachedToken { + token: auth_token.clone(), + expires_at, + }); + + Ok(auth_token) + } +} diff --git a/crates/goose/src/providers/mod.rs b/crates/goose/src/providers/mod.rs index 0577bc66..82dfdfe3 100644 --- a/crates/goose/src/providers/mod.rs +++ b/crates/goose/src/providers/mod.rs @@ -1,5 +1,6 @@ pub mod anthropic; pub mod azure; +pub mod azureauth; pub mod base; pub mod bedrock; pub mod databricks; diff --git a/documentation/docs/getting-started/providers.md b/documentation/docs/getting-started/providers.md index 92f2e3c4..189357c4 100644 --- a/documentation/docs/getting-started/providers.md +++ b/documentation/docs/getting-started/providers.md @@ -21,7 +21,7 @@ Goose relies heavily on tool calling capabilities and currently works best with 
|-----------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| | [Amazon Bedrock](https://aws.amazon.com/bedrock/) | Offers a variety of foundation models, including Claude, Jurassic-2, and others. **AWS environment variables must be set in advance, not configured through `goose configure`** | `AWS_PROFILE`, or `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, `AWS_REGION`, ... | | [Anthropic](https://www.anthropic.com/) | Offers Claude, an advanced AI model for natural language tasks. | `ANTHROPIC_API_KEY`, `ANTHROPIC_HOST` (optional) | -| [Azure OpenAI](https://learn.microsoft.com/en-us/azure/ai-services/openai/) | Access Azure-hosted OpenAI models, including GPT-4 and GPT-3.5. | `AZURE_OPENAI_API_KEY`, `AZURE_OPENAI_ENDPOINT`, `AZURE_OPENAI_DEPLOYMENT_NAME` | +| [Azure OpenAI](https://learn.microsoft.com/en-us/azure/ai-services/openai/) | Access Azure-hosted OpenAI models, including GPT-4 and GPT-3.5. Supports both API key and Azure credential chain authentication. | `AZURE_OPENAI_ENDPOINT`, `AZURE_OPENAI_DEPLOYMENT_NAME`, `AZURE_OPENAI_API_KEY` (optional) | | [Databricks](https://www.databricks.com/) | Unified data analytics and AI platform for building and deploying models. | `DATABRICKS_HOST`, `DATABRICKS_TOKEN` | | [Gemini](https://ai.google.dev/gemini-api/docs) | Advanced LLMs by Google with multimodal capabilities (text, images). | `GOOGLE_API_KEY` | | [GCP Vertex AI](https://cloud.google.com/vertex-ai) | Google Cloud's Vertex AI platform, supporting Gemini and Claude models. **Credentials must be configured in advance. 
Follow the instructions at https://cloud.google.com/vertex-ai/docs/authentication.** | `GCP_PROJECT_ID`, `GCP_LOCATION` and optional `GCP_MAX_RETRIES` (6), `GCP_INITIAL_RETRY_INTERVAL_MS` (5000), `GCP_BACKOFF_MULTIPLIER` (2.0), `GCP_MAX_RETRY_INTERVAL_MS` (320_000). | @@ -456,6 +456,20 @@ ollama run michaelneale/deepseek-r1-goose +## Azure OpenAI Credential Chain + +Goose supports two authentication methods for Azure OpenAI: + +1. **API Key Authentication** - Uses the `AZURE_OPENAI_API_KEY` for direct authentication +2. **Azure Credential Chain** - Uses Azure CLI credentials automatically without requiring an API key + +To use the Azure Credential Chain: +- Ensure you're logged in with `az login` +- Have appropriate Azure role assignments for the Azure OpenAI service +- Configure with `goose configure` and select Azure OpenAI, leaving the API key field empty + +This method simplifies authentication and enhances security for enterprise environments. + --- If you have any questions or need help with a specific provider, feel free to reach out to us on [Discord](https://discord.gg/block-opensource) or on the [Goose repo](https://github.com/block/goose). 
diff --git a/ui/desktop/src/components/settings/models/hardcoded_stuff.tsx b/ui/desktop/src/components/settings/models/hardcoded_stuff.tsx index c300e780..d8168e6e 100644 --- a/ui/desktop/src/components/settings/models/hardcoded_stuff.tsx +++ b/ui/desktop/src/components/settings/models/hardcoded_stuff.tsx @@ -64,7 +64,6 @@ export const required_keys = { Google: ['GOOGLE_API_KEY'], OpenRouter: ['OPENROUTER_API_KEY'], 'Azure OpenAI': [ - 'AZURE_OPENAI_API_KEY', 'AZURE_OPENAI_ENDPOINT', 'AZURE_OPENAI_DEPLOYMENT_NAME', 'AZURE_OPENAI_API_VERSION', diff --git a/ui/desktop/src/components/settings_v2/providers/ProviderRegistry.tsx b/ui/desktop/src/components/settings_v2/providers/ProviderRegistry.tsx index e2232d67..6ef080d5 100644 --- a/ui/desktop/src/components/settings_v2/providers/ProviderRegistry.tsx +++ b/ui/desktop/src/components/settings_v2/providers/ProviderRegistry.tsx @@ -173,11 +173,12 @@ export const PROVIDER_REGISTRY: ProviderRegistry[] = [ details: { id: 'azure_openai', name: 'Azure OpenAI', - description: 'Access Azure OpenAI models', + description: 'Access Azure OpenAI models using API key or Azure credentials. If no API key is provided, Azure credential chain will be used.', parameters: [ { name: 'AZURE_OPENAI_API_KEY', is_secret: true, + required: false, }, { name: 'AZURE_OPENAI_ENDPOINT',