feat: Azure credential chain logging (#2413)

Co-authored-by: Michael Paquette <mpaquette@pax8.com>
Co-authored-by: Alice Hau <ahau@squareup.com>
This commit is contained in:
Alice Hau
2025-05-01 14:09:37 -04:00
committed by GitHub
parent 37ed707dd6
commit c99d5beb07
7 changed files with 324 additions and 20 deletions

View File

@@ -49,9 +49,9 @@
},
"azure_openai": {
"name": "Azure OpenAI",
"description": "Connect to Azure OpenAI Service",
"description": "Connect to Azure OpenAI Service. If no API key is provided, Azure credential chain will be used.",
"models": ["gpt-4o", "gpt-4o-mini"],
"required_keys": ["AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOYMENT_NAME"]
"required_keys": ["AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOYMENT_NAME"]
},
"aws_bedrock": {
"name": "AWS Bedrock",

View File

@@ -1,9 +1,12 @@
use anyhow::Result;
use async_trait::async_trait;
use reqwest::Client;
use serde::Serialize;
use serde_json::Value;
use std::time::Duration;
use tokio::time::sleep;
use super::azureauth::AzureAuth;
use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage};
use super::errors::ProviderError;
use super::formats::openai::{create_request, get_usage, response_to_message};
@@ -18,17 +21,36 @@ pub const AZURE_DOC_URL: &str =
pub const AZURE_DEFAULT_API_VERSION: &str = "2024-10-21";
pub const AZURE_OPENAI_KNOWN_MODELS: &[&str] = &["gpt-4o", "gpt-4o-mini", "gpt-4"];
#[derive(Debug, serde::Serialize)]
// Default retry configuration
const DEFAULT_MAX_RETRIES: usize = 5;
const DEFAULT_INITIAL_RETRY_INTERVAL_MS: u64 = 1000; // Start with 1 second
const DEFAULT_MAX_RETRY_INTERVAL_MS: u64 = 32000; // Max 32 seconds
const DEFAULT_BACKOFF_MULTIPLIER: f64 = 2.0;
#[derive(Debug)]
pub struct AzureProvider {
#[serde(skip)]
client: Client,
auth: AzureAuth,
endpoint: String,
api_key: String,
deployment_name: String,
api_version: String,
model: ModelConfig,
}
/// Hand-written `Serialize` impl: only the non-sensitive configuration fields
/// (`endpoint`, `deployment_name`, `api_version`) are emitted. The HTTP client
/// and the auth handler (which may hold an API key) are deliberately omitted.
impl Serialize for AzureProvider {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        use serde::ser::SerializeStruct;
        let mut out = serializer.serialize_struct("AzureProvider", 3)?;
        // Emit the safe-to-log fields in a fixed order.
        for (name, value) in [
            ("endpoint", &self.endpoint),
            ("deployment_name", &self.deployment_name),
            ("api_version", &self.api_version),
        ] {
            out.serialize_field(name, value)?;
        }
        out.end()
    }
}
impl Default for AzureProvider {
fn default() -> Self {
let model = ModelConfig::new(AzureProvider::metadata().default_model);
@@ -39,13 +61,16 @@ impl Default for AzureProvider {
impl AzureProvider {
pub fn from_env(model: ModelConfig) -> Result<Self> {
let config = crate::config::Config::global();
let api_key: String = config.get_secret("AZURE_OPENAI_API_KEY")?;
let endpoint: String = config.get_param("AZURE_OPENAI_ENDPOINT")?;
let deployment_name: String = config.get_param("AZURE_OPENAI_DEPLOYMENT_NAME")?;
let api_version: String = config
.get_param("AZURE_OPENAI_API_VERSION")
.unwrap_or_else(|_| AZURE_DEFAULT_API_VERSION.to_string());
// Try to get API key first, if not found use Azure credential chain
let api_key = config.get_secret("AZURE_OPENAI_API_KEY").ok();
let auth = AzureAuth::new(api_key)?;
let client = Client::builder()
.timeout(Duration::from_secs(600))
.build()?;
@@ -53,7 +78,7 @@ impl AzureProvider {
Ok(Self {
client,
endpoint,
api_key,
auth,
deployment_name,
api_version,
model,
@@ -81,15 +106,110 @@ impl AzureProvider {
base_url.set_path(&new_path);
base_url.set_query(Some(&format!("api-version={}", self.api_version)));
let response: reqwest::Response = self
.client
.post(base_url)
.header("api-key", &self.api_key)
.json(&payload)
.send()
.await?;
let mut attempts = 0;
let mut last_error = None;
let mut current_delay = DEFAULT_INITIAL_RETRY_INTERVAL_MS;
handle_response_openai_compat(response).await
loop {
// Check if we've exceeded max retries
if attempts > DEFAULT_MAX_RETRIES {
let error_msg = format!(
"Exceeded maximum retry attempts ({}) for rate limiting",
DEFAULT_MAX_RETRIES
);
tracing::error!("{}", error_msg);
return Err(last_error.unwrap_or(ProviderError::RateLimitExceeded(error_msg)));
}
// Get a fresh auth token for each attempt
let auth_token = self.auth.get_token().await.map_err(|e| {
tracing::error!("Authentication error: {:?}", e);
ProviderError::RequestFailed(format!("Failed to get authentication token: {}", e))
})?;
let mut request_builder = self.client.post(base_url.clone());
let token_value = auth_token.token_value.clone();
// Set the correct header based on authentication type
match self.auth.credential_type() {
super::azureauth::AzureCredentials::ApiKey(_) => {
request_builder = request_builder.header("api-key", token_value.clone());
}
super::azureauth::AzureCredentials::DefaultCredential => {
request_builder = request_builder
.header("Authorization", format!("Bearer {}", token_value.clone()));
}
}
let response_result = request_builder.json(&payload).send().await;
match response_result {
Ok(response) => match handle_response_openai_compat(response).await {
Ok(result) => {
return Ok(result);
}
Err(ProviderError::RateLimitExceeded(msg)) => {
attempts += 1;
last_error = Some(ProviderError::RateLimitExceeded(msg.clone()));
let retry_after =
if let Some(secs) = msg.to_lowercase().find("try again in ") {
msg[secs..]
.split_whitespace()
.nth(3)
.and_then(|s| s.parse::<u64>().ok())
.unwrap_or(0)
} else {
0
};
let delay = if retry_after > 0 {
Duration::from_secs(retry_after)
} else {
let delay = current_delay.min(DEFAULT_MAX_RETRY_INTERVAL_MS);
current_delay =
(current_delay as f64 * DEFAULT_BACKOFF_MULTIPLIER) as u64;
Duration::from_millis(delay)
};
sleep(delay).await;
continue;
}
Err(e) => {
tracing::error!(
"Error response from Azure OpenAI (attempt {}): {:?}",
attempts + 1,
e
);
return Err(e);
}
},
Err(e) => {
tracing::error!(
"Request failed (attempt {}): {:?}\nIs timeout: {}\nIs connect: {}\nIs request: {}",
attempts + 1,
e,
e.is_timeout(),
e.is_connect(),
e.is_request(),
);
// For timeout errors, we should retry
if e.is_timeout() {
attempts += 1;
let delay = current_delay.min(DEFAULT_MAX_RETRY_INTERVAL_MS);
current_delay = (current_delay as f64 * DEFAULT_BACKOFF_MULTIPLIER) as u64;
sleep(Duration::from_millis(delay)).await;
continue;
}
return Err(ProviderError::RequestFailed(format!(
"Request failed: {}",
e
)));
}
}
}
}
}
@@ -99,12 +219,11 @@ impl Provider for AzureProvider {
ProviderMetadata::new(
"azure_openai",
"Azure OpenAI",
"Models through Azure OpenAI Service",
"Models through Azure OpenAI Service (uses Azure credential chain by default)",
"gpt-4o",
AZURE_OPENAI_KNOWN_MODELS.to_vec(),
AZURE_DOC_URL,
vec![
ConfigKey::new("AZURE_OPENAI_API_KEY", true, true, None),
ConfigKey::new("AZURE_OPENAI_ENDPOINT", true, false, None),
ConfigKey::new("AZURE_OPENAI_DEPLOYMENT_NAME", true, false, None),
ConfigKey::new("AZURE_OPENAI_API_VERSION", true, false, Some("2024-10-21")),

View File

@@ -0,0 +1,170 @@
use chrono;
use serde::Deserialize;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
/// Represents errors that can occur during Azure authentication.
#[derive(Debug, thiserror::Error)]
pub enum AuthError {
    /// Error when loading credentials from the filesystem or environment.
    // NOTE(review): this variant is not constructed anywhere in this file —
    // confirm external callers use it before relying on it.
    #[error("Failed to load credentials: {0}")]
    Credentials(String),
    /// Error during token exchange (Azure CLI invocation failed, returned a
    /// non-zero status, or produced an unparseable response).
    #[error("Token exchange failed: {0}")]
    TokenExchange(String),
}
/// Represents an authentication token with its type and value.
#[derive(Debug, Clone)]
pub struct AuthToken {
    /// The type of the token (e.g., "Bearer").
    // NOTE(review): always set to "Bearer" in this file, even for API keys;
    // callers choose the request header from `AzureAuth::credential_type`
    // instead — confirm this field is actually consumed anywhere.
    pub token_type: String,
    /// The actual token value: either the raw API key or an access token
    /// obtained via the Azure CLI.
    pub token_value: String,
}
/// Represents the types of Azure credentials supported.
#[derive(Debug, Clone)]
pub enum AzureCredentials {
    /// API key based authentication; holds the raw API key.
    ApiKey(String),
    /// Azure credential chain based authentication; the token is resolved by
    /// shelling out to the Azure CLI (see `get_default_credential_token`).
    DefaultCredential,
}
/// Holds a cached token and the instant at which it should be treated as expired.
#[derive(Debug, Clone)]
struct CachedToken {
    // The cached authentication token.
    token: AuthToken,
    // Monotonic deadline; stored with a 30-second safety margin already
    // subtracted (see `get_default_credential_token`).
    expires_at: Instant,
}
/// Response from `az account get-access-token` (Azure CLI token output).
// NOTE(review): the mixed field naming is intentional — the CLI emits
// camelCase `accessToken`/`tokenType` alongside a snake_case numeric
// `expires_on` (Unix timestamp). Confirm the minimum supported Azure CLI
// version actually includes the numeric `expires_on` field.
#[derive(Debug, Clone, Deserialize)]
struct TokenResponse {
    #[serde(rename = "accessToken")]
    access_token: String,
    #[serde(rename = "tokenType")]
    token_type: String,
    #[serde(rename = "expires_on")]
    expires_on: u64,
}
/// Azure authentication handler that manages credentials and token caching.
#[derive(Debug)]
pub struct AzureAuth {
    // Which authentication mode is in use (API key vs. Azure credential chain).
    credentials: AzureCredentials,
    // Cache for CLI-issued tokens, shared across concurrent `get_token` calls
    // via `Arc<RwLock<...>>`. Unused for the `ApiKey` mode.
    cached_token: Arc<RwLock<Option<CachedToken>>>,
}
impl AzureAuth {
    /// Creates a new Azure authentication handler.
    ///
    /// Selects the authentication mode from the provided value:
    /// * `Some(key)` — API key authentication
    /// * `None` — Azure credential chain (token obtained via the Azure CLI)
    ///
    /// # Returns
    /// * `Result<Self, AuthError>` - A new AzureAuth instance. As written this
    ///   cannot fail; the `Result` leaves room for fallible initialization.
    pub fn new(api_key: Option<String>) -> Result<Self, AuthError> {
        let credentials = match api_key {
            Some(key) => AzureCredentials::ApiKey(key),
            None => AzureCredentials::DefaultCredential,
        };
        Ok(Self {
            credentials,
            cached_token: Arc::new(RwLock::new(None)),
        })
    }

    /// Returns the type of credentials being used.
    pub fn credential_type(&self) -> &AzureCredentials {
        &self.credentials
    }

    /// Retrieves a valid authentication token.
    ///
    /// * API key auth: returns the key directly (no caching needed).
    /// * Azure credential chain: returns the cached CLI token if still valid,
    ///   otherwise fetches a fresh one (double-checked locking for thread safety).
    ///
    /// # Returns
    /// * `Result<AuthToken, AuthError>` - A valid authentication token or an error
    pub async fn get_token(&self) -> Result<AuthToken, AuthError> {
        match &self.credentials {
            AzureCredentials::ApiKey(key) => Ok(AuthToken {
                // NOTE(review): "Bearer" is a placeholder here — the api-key path
                // sends an `api-key` header chosen via `credential_type()`, so
                // this value appears unused for API keys; confirm with callers.
                token_type: "Bearer".to_string(),
                token_value: key.clone(),
            }),
            AzureCredentials::DefaultCredential => self.get_default_credential_token().await,
        }
    }

    /// Returns a cached token when valid, otherwise shells out to
    /// `az account get-access-token` for a fresh one and caches it.
    async fn get_default_credential_token(&self) -> Result<AuthToken, AuthError> {
        // Fast path: read lock only, so concurrent callers don't serialize.
        if let Some(cached) = self.cached_token.read().await.as_ref() {
            if cached.expires_at > Instant::now() {
                return Ok(cached.token.clone());
            }
        }
        // Take write lock only if needed
        let mut token_guard = self.cached_token.write().await;
        // Double-check expiration after acquiring write lock: another task may
        // have refreshed the token while we waited for the lock.
        if let Some(cached) = token_guard.as_ref() {
            if cached.expires_at > Instant::now() {
                return Ok(cached.token.clone());
            }
        }
        // Slow path: invoke the Azure CLI for a token scoped to the Cognitive
        // Services resource.
        let output = tokio::process::Command::new("az")
            .args([
                "account",
                "get-access-token",
                "--resource",
                "https://cognitiveservices.azure.com",
            ])
            .output()
            .await
            .map_err(|e| AuthError::TokenExchange(format!("Failed to execute Azure CLI: {}", e)))?;
        if !output.status.success() {
            // Surface the CLI's stderr (e.g. "please run 'az login'") verbatim.
            return Err(AuthError::TokenExchange(
                String::from_utf8_lossy(&output.stderr).to_string(),
            ));
        }
        let token_response: TokenResponse = serde_json::from_slice(&output.stdout)
            .map_err(|e| AuthError::TokenExchange(format!("Invalid token response: {}", e)))?;
        let auth_token = AuthToken {
            token_type: token_response.token_type,
            token_value: token_response.access_token,
        };
        // Convert the absolute Unix expiry into a monotonic deadline, shaving
        // 30s as a safety margin so an about-to-expire token is never returned.
        // saturating_sub keeps this well-defined when the token is already
        // expired or clocks are skewed (the deadline collapses to "now",
        // forcing a refresh on the next call).
        let expires_at = Instant::now()
            + Duration::from_secs(
                token_response
                    .expires_on
                    .saturating_sub(chrono::Utc::now().timestamp() as u64)
                    .saturating_sub(30),
            );
        *token_guard = Some(CachedToken {
            token: auth_token.clone(),
            expires_at,
        });
        Ok(auth_token)
    }
}

View File

@@ -1,5 +1,6 @@
pub mod anthropic;
pub mod azure;
pub mod azureauth;
pub mod base;
pub mod bedrock;
pub mod databricks;

View File

@@ -21,7 +21,7 @@ Goose relies heavily on tool calling capabilities and currently works best with
|-----------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| [Amazon Bedrock](https://aws.amazon.com/bedrock/) | Offers a variety of foundation models, including Claude, Jurassic-2, and others. **AWS environment variables must be set in advance, not configured through `goose configure`** | `AWS_PROFILE`, or `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, `AWS_REGION`, ... |
| [Anthropic](https://www.anthropic.com/) | Offers Claude, an advanced AI model for natural language tasks. | `ANTHROPIC_API_KEY`, `ANTHROPIC_HOST` (optional) |
| [Azure OpenAI](https://learn.microsoft.com/en-us/azure/ai-services/openai/) | Access Azure-hosted OpenAI models, including GPT-4 and GPT-3.5. | `AZURE_OPENAI_API_KEY`, `AZURE_OPENAI_ENDPOINT`, `AZURE_OPENAI_DEPLOYMENT_NAME` |
| [Azure OpenAI](https://learn.microsoft.com/en-us/azure/ai-services/openai/) | Access Azure-hosted OpenAI models, including GPT-4 and GPT-3.5. Supports both API key and Azure credential chain authentication. | `AZURE_OPENAI_ENDPOINT`, `AZURE_OPENAI_DEPLOYMENT_NAME`, `AZURE_OPENAI_API_KEY` (optional) |
| [Databricks](https://www.databricks.com/) | Unified data analytics and AI platform for building and deploying models. | `DATABRICKS_HOST`, `DATABRICKS_TOKEN` |
| [Gemini](https://ai.google.dev/gemini-api/docs) | Advanced LLMs by Google with multimodal capabilities (text, images). | `GOOGLE_API_KEY` |
| [GCP Vertex AI](https://cloud.google.com/vertex-ai) | Google Cloud's Vertex AI platform, supporting Gemini and Claude models. **Credentials must be configured in advance. Follow the instructions at https://cloud.google.com/vertex-ai/docs/authentication.** | `GCP_PROJECT_ID`, `GCP_LOCATION` and optional `GCP_MAX_RETRIES` (6), `GCP_INITIAL_RETRY_INTERVAL_MS` (5000), `GCP_BACKOFF_MULTIPLIER` (2.0), `GCP_MAX_RETRY_INTERVAL_MS` (320_000). |
@@ -456,6 +456,20 @@ ollama run michaelneale/deepseek-r1-goose
</TabItem>
</Tabs>
## Azure OpenAI Credential Chain
Goose supports two authentication methods for Azure OpenAI:
1. **API Key Authentication** - Uses the `AZURE_OPENAI_API_KEY` value for direct authentication
2. **Azure Credential Chain** - Uses Azure CLI credentials automatically without requiring an API key
To use the Azure Credential Chain:
- Ensure you're logged in with `az login`
- Have appropriate Azure role assignments for the Azure OpenAI service
- Configure with `goose configure` and select Azure OpenAI, leaving the API key field empty
This method simplifies authentication and enhances security for enterprise environments.
---
If you have any questions or need help with a specific provider, feel free to reach out to us on [Discord](https://discord.gg/block-opensource) or on the [Goose repo](https://github.com/block/goose).

View File

@@ -64,7 +64,6 @@ export const required_keys = {
Google: ['GOOGLE_API_KEY'],
OpenRouter: ['OPENROUTER_API_KEY'],
'Azure OpenAI': [
'AZURE_OPENAI_API_KEY',
'AZURE_OPENAI_ENDPOINT',
'AZURE_OPENAI_DEPLOYMENT_NAME',
'AZURE_OPENAI_API_VERSION',

View File

@@ -173,11 +173,12 @@ export const PROVIDER_REGISTRY: ProviderRegistry[] = [
details: {
id: 'azure_openai',
name: 'Azure OpenAI',
description: 'Access Azure OpenAI models',
description: 'Access Azure OpenAI models using API key or Azure credentials. If no API key is provided, Azure credential chain will be used.',
parameters: [
{
name: 'AZURE_OPENAI_API_KEY',
is_secret: true,
required: false,
},
{
name: 'AZURE_OPENAI_ENDPOINT',