Mirror of https://github.com/aljazceru/goose.git (synced 2025-12-18 06:34:26 +01:00)
feat: Azure credential chain logging (#2413)
Co-authored-by: Michael Paquette <mpaquette@pax8.com>
Co-authored-by: Alice Hau <ahau@squareup.com>
@@ -49,9 +49,9 @@
     },
     "azure_openai": {
         "name": "Azure OpenAI",
-        "description": "Connect to Azure OpenAI Service",
+        "description": "Connect to Azure OpenAI Service. If no API key is provided, Azure credential chain will be used.",
         "models": ["gpt-4o", "gpt-4o-mini"],
-        "required_keys": ["AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOYMENT_NAME"]
+        "required_keys": ["AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOYMENT_NAME"]
     },
     "aws_bedrock": {
         "name": "AWS Bedrock",
@@ -1,9 +1,12 @@
 use anyhow::Result;
 use async_trait::async_trait;
 use reqwest::Client;
+use serde::Serialize;
 use serde_json::Value;
 use std::time::Duration;
+use tokio::time::sleep;
 
+use super::azureauth::AzureAuth;
 use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage};
 use super::errors::ProviderError;
 use super::formats::openai::{create_request, get_usage, response_to_message};
@@ -18,17 +21,36 @@ pub const AZURE_DOC_URL: &str =
 pub const AZURE_DEFAULT_API_VERSION: &str = "2024-10-21";
 pub const AZURE_OPENAI_KNOWN_MODELS: &[&str] = &["gpt-4o", "gpt-4o-mini", "gpt-4"];
 
-#[derive(Debug, serde::Serialize)]
+// Default retry configuration
+const DEFAULT_MAX_RETRIES: usize = 5;
+const DEFAULT_INITIAL_RETRY_INTERVAL_MS: u64 = 1000; // Start with 1 second
+const DEFAULT_MAX_RETRY_INTERVAL_MS: u64 = 32000; // Max 32 seconds
+const DEFAULT_BACKOFF_MULTIPLIER: f64 = 2.0;
+
+#[derive(Debug)]
 pub struct AzureProvider {
-    #[serde(skip)]
     client: Client,
+    auth: AzureAuth,
     endpoint: String,
-    api_key: String,
     deployment_name: String,
     api_version: String,
     model: ModelConfig,
 }
 
+impl Serialize for AzureProvider {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: serde::Serializer,
+    {
+        use serde::ser::SerializeStruct;
+        let mut state = serializer.serialize_struct("AzureProvider", 3)?;
+        state.serialize_field("endpoint", &self.endpoint)?;
+        state.serialize_field("deployment_name", &self.deployment_name)?;
+        state.serialize_field("api_version", &self.api_version)?;
+        state.end()
+    }
+}
+
 impl Default for AzureProvider {
     fn default() -> Self {
         let model = ModelConfig::new(AzureProvider::metadata().default_model);
@@ -39,13 +61,16 @@ impl Default for AzureProvider {
 impl AzureProvider {
     pub fn from_env(model: ModelConfig) -> Result<Self> {
         let config = crate::config::Config::global();
-        let api_key: String = config.get_secret("AZURE_OPENAI_API_KEY")?;
         let endpoint: String = config.get_param("AZURE_OPENAI_ENDPOINT")?;
         let deployment_name: String = config.get_param("AZURE_OPENAI_DEPLOYMENT_NAME")?;
         let api_version: String = config
             .get_param("AZURE_OPENAI_API_VERSION")
             .unwrap_or_else(|_| AZURE_DEFAULT_API_VERSION.to_string());
 
+        // Try to get API key first, if not found use Azure credential chain
+        let api_key = config.get_secret("AZURE_OPENAI_API_KEY").ok();
+        let auth = AzureAuth::new(api_key)?;
+
         let client = Client::builder()
             .timeout(Duration::from_secs(600))
             .build()?;
@@ -53,7 +78,7 @@ impl AzureProvider {
         Ok(Self {
             client,
             endpoint,
-            api_key,
+            auth,
             deployment_name,
             api_version,
             model,
@@ -81,15 +106,110 @@ impl AzureProvider {
         base_url.set_path(&new_path);
         base_url.set_query(Some(&format!("api-version={}", self.api_version)));
 
-        let response: reqwest::Response = self
-            .client
-            .post(base_url)
-            .header("api-key", &self.api_key)
-            .json(&payload)
-            .send()
-            .await?;
+        let mut attempts = 0;
+        let mut last_error = None;
+        let mut current_delay = DEFAULT_INITIAL_RETRY_INTERVAL_MS;
 
-        handle_response_openai_compat(response).await
+        loop {
+            // Check if we've exceeded max retries
+            if attempts > DEFAULT_MAX_RETRIES {
+                let error_msg = format!(
+                    "Exceeded maximum retry attempts ({}) for rate limiting",
+                    DEFAULT_MAX_RETRIES
+                );
+                tracing::error!("{}", error_msg);
+                return Err(last_error.unwrap_or(ProviderError::RateLimitExceeded(error_msg)));
+            }
+
+            // Get a fresh auth token for each attempt
+            let auth_token = self.auth.get_token().await.map_err(|e| {
+                tracing::error!("Authentication error: {:?}", e);
+                ProviderError::RequestFailed(format!("Failed to get authentication token: {}", e))
+            })?;
+
+            let mut request_builder = self.client.post(base_url.clone());
+            let token_value = auth_token.token_value.clone();
+
+            // Set the correct header based on authentication type
+            match self.auth.credential_type() {
+                super::azureauth::AzureCredentials::ApiKey(_) => {
+                    request_builder = request_builder.header("api-key", token_value.clone());
+                }
+                super::azureauth::AzureCredentials::DefaultCredential => {
+                    request_builder = request_builder
+                        .header("Authorization", format!("Bearer {}", token_value.clone()));
+                }
+            }
+
+            let response_result = request_builder.json(&payload).send().await;
+
+            match response_result {
+                Ok(response) => match handle_response_openai_compat(response).await {
+                    Ok(result) => {
+                        return Ok(result);
+                    }
+                    Err(ProviderError::RateLimitExceeded(msg)) => {
+                        attempts += 1;
+                        last_error = Some(ProviderError::RateLimitExceeded(msg.clone()));
+
+                        let retry_after =
+                            if let Some(secs) = msg.to_lowercase().find("try again in ") {
+                                msg[secs..]
+                                    .split_whitespace()
+                                    .nth(3)
+                                    .and_then(|s| s.parse::<u64>().ok())
+                                    .unwrap_or(0)
+                            } else {
+                                0
+                            };
+
+                        let delay = if retry_after > 0 {
+                            Duration::from_secs(retry_after)
+                        } else {
+                            let delay = current_delay.min(DEFAULT_MAX_RETRY_INTERVAL_MS);
+                            current_delay =
+                                (current_delay as f64 * DEFAULT_BACKOFF_MULTIPLIER) as u64;
+                            Duration::from_millis(delay)
+                        };
+
+                        sleep(delay).await;
+                        continue;
+                    }
+                    Err(e) => {
+                        tracing::error!(
+                            "Error response from Azure OpenAI (attempt {}): {:?}",
+                            attempts + 1,
+                            e
+                        );
+                        return Err(e);
+                    }
+                },
+                Err(e) => {
+                    tracing::error!(
+                        "Request failed (attempt {}): {:?}\nIs timeout: {}\nIs connect: {}\nIs request: {}",
+                        attempts + 1,
+                        e,
+                        e.is_timeout(),
+                        e.is_connect(),
+                        e.is_request(),
+                    );
+
+                    // For timeout errors, we should retry
+                    if e.is_timeout() {
+                        attempts += 1;
+                        let delay = current_delay.min(DEFAULT_MAX_RETRY_INTERVAL_MS);
+                        current_delay = (current_delay as f64 * DEFAULT_BACKOFF_MULTIPLIER) as u64;
+                        sleep(Duration::from_millis(delay)).await;
+                        continue;
+                    }
+
+                    return Err(ProviderError::RequestFailed(format!(
+                        "Request failed: {}",
+                        e
+                    )));
+                }
+            }
+        }
     }
 }
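The retry behaviour above is governed by the DEFAULT_* constants introduced earlier in this diff. As a rough standalone sketch (not part of the change itself), the sleep sequence the loop falls back to when the rate-limit message carries no "try again in N seconds" hint doubles from 1 second up to the 32-second cap; because the guard is `attempts > DEFAULT_MAX_RETRIES`, up to six rate-limited responses are tolerated before the loop gives up.

fn backoff_delays_ms(attempts: usize) -> Vec<u64> {
    let mut current: u64 = 1_000; // DEFAULT_INITIAL_RETRY_INTERVAL_MS
    let mut delays = Vec::new();
    for _ in 0..attempts {
        delays.push(current.min(32_000)); // capped at DEFAULT_MAX_RETRY_INTERVAL_MS
        current = (current as f64 * 2.0) as u64; // DEFAULT_BACKOFF_MULTIPLIER
    }
    delays
}
// backoff_delays_ms(6) == [1000, 2000, 4000, 8000, 16000, 32000]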
@@ -99,12 +219,11 @@ impl Provider for AzureProvider {
         ProviderMetadata::new(
             "azure_openai",
             "Azure OpenAI",
-            "Models through Azure OpenAI Service",
+            "Models through Azure OpenAI Service (uses Azure credential chain by default)",
             "gpt-4o",
             AZURE_OPENAI_KNOWN_MODELS.to_vec(),
             AZURE_DOC_URL,
             vec![
-                ConfigKey::new("AZURE_OPENAI_API_KEY", true, true, None),
                 ConfigKey::new("AZURE_OPENAI_ENDPOINT", true, false, None),
                 ConfigKey::new("AZURE_OPENAI_DEPLOYMENT_NAME", true, false, None),
                 ConfigKey::new("AZURE_OPENAI_API_VERSION", true, false, Some("2024-10-21")),
crates/goose/src/providers/azureauth.rs (new file, 170 lines)
@@ -0,0 +1,170 @@
use chrono;
use serde::Deserialize;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;

/// Represents errors that can occur during Azure authentication.
#[derive(Debug, thiserror::Error)]
pub enum AuthError {
    /// Error when loading credentials from the filesystem or environment
    #[error("Failed to load credentials: {0}")]
    Credentials(String),

    /// Error during token exchange
    #[error("Token exchange failed: {0}")]
    TokenExchange(String),
}

/// Represents an authentication token with its type and value.
#[derive(Debug, Clone)]
pub struct AuthToken {
    /// The type of the token (e.g., "Bearer")
    pub token_type: String,
    /// The actual token value
    pub token_value: String,
}

/// Represents the types of Azure credentials supported.
#[derive(Debug, Clone)]
pub enum AzureCredentials {
    /// API key based authentication
    ApiKey(String),
    /// Azure credential chain based authentication
    DefaultCredential,
}

/// Holds a cached token and its expiration time.
#[derive(Debug, Clone)]
struct CachedToken {
    token: AuthToken,
    expires_at: Instant,
}

/// Response from Azure token endpoint
#[derive(Debug, Clone, Deserialize)]
struct TokenResponse {
    #[serde(rename = "accessToken")]
    access_token: String,
    #[serde(rename = "tokenType")]
    token_type: String,
    #[serde(rename = "expires_on")]
    expires_on: u64,
}

/// Azure authentication handler that manages credentials and token caching.
#[derive(Debug)]
pub struct AzureAuth {
    credentials: AzureCredentials,
    cached_token: Arc<RwLock<Option<CachedToken>>>,
}

impl AzureAuth {
    /// Creates a new Azure authentication handler.
    ///
    /// Initializes the authentication handler by:
    /// 1. Loading credentials from environment
    /// 2. Setting up an HTTP client for token requests
    /// 3. Initializing the token cache
    ///
    /// # Returns
    /// * `Result<Self, AuthError>` - A new AzureAuth instance or an error if initialization fails
    pub fn new(api_key: Option<String>) -> Result<Self, AuthError> {
        let credentials = match api_key {
            Some(key) => AzureCredentials::ApiKey(key),
            None => AzureCredentials::DefaultCredential,
        };

        Ok(Self {
            credentials,
            cached_token: Arc::new(RwLock::new(None)),
        })
    }

    /// Returns the type of credentials being used.
    pub fn credential_type(&self) -> &AzureCredentials {
        &self.credentials
    }

    /// Retrieves a valid authentication token.
    ///
    /// This method implements an efficient token management strategy:
    /// 1. For API key auth, returns the API key directly
    /// 2. For Azure credential chain:
    ///    a. Checks the cache for a valid token
    ///    b. Returns the cached token if not expired
    ///    c. Obtains a new token if needed or expired
    ///    d. Uses double-checked locking for thread safety
    ///
    /// # Returns
    /// * `Result<AuthToken, AuthError>` - A valid authentication token or an error
    pub async fn get_token(&self) -> Result<AuthToken, AuthError> {
        match &self.credentials {
            AzureCredentials::ApiKey(key) => Ok(AuthToken {
                token_type: "Bearer".to_string(),
                token_value: key.clone(),
            }),
            AzureCredentials::DefaultCredential => self.get_default_credential_token().await,
        }
    }

    async fn get_default_credential_token(&self) -> Result<AuthToken, AuthError> {
        // Try read lock first for better concurrency
        if let Some(cached) = self.cached_token.read().await.as_ref() {
            if cached.expires_at > Instant::now() {
                return Ok(cached.token.clone());
            }
        }

        // Take write lock only if needed
        let mut token_guard = self.cached_token.write().await;

        // Double-check expiration after acquiring write lock
        if let Some(cached) = token_guard.as_ref() {
            if cached.expires_at > Instant::now() {
                return Ok(cached.token.clone());
            }
        }

        // Get new token using Azure CLI credential
        let output = tokio::process::Command::new("az")
            .args([
                "account",
                "get-access-token",
                "--resource",
                "https://cognitiveservices.azure.com",
            ])
            .output()
            .await
            .map_err(|e| AuthError::TokenExchange(format!("Failed to execute Azure CLI: {}", e)))?;

        if !output.status.success() {
            return Err(AuthError::TokenExchange(
                String::from_utf8_lossy(&output.stderr).to_string(),
            ));
        }

        let token_response: TokenResponse = serde_json::from_slice(&output.stdout)
            .map_err(|e| AuthError::TokenExchange(format!("Invalid token response: {}", e)))?;

        let auth_token = AuthToken {
            token_type: token_response.token_type,
            token_value: token_response.access_token,
        };

        let expires_at = Instant::now()
            + Duration::from_secs(
                token_response
                    .expires_on
                    .saturating_sub(chrono::Utc::now().timestamp() as u64)
                    .saturating_sub(30),
            );

        *token_guard = Some(CachedToken {
            token: auth_token.clone(),
            expires_at,
        });

        Ok(auth_token)
    }
}
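A minimal usage sketch for the new module (illustrative only; it assumes the module is reachable as goose::providers::azureauth, a tokio runtime, and, for the credential-chain path, an active `az login` session):

use goose::providers::azureauth::{AzureAuth, AzureCredentials};

async fn demo() -> Result<(), Box<dyn std::error::Error>> {
    // No API key -> the Azure credential chain (Azure CLI) path is selected.
    let auth = AzureAuth::new(None)?;
    assert!(matches!(auth.credential_type(), AzureCredentials::DefaultCredential));

    // The first call shells out to `az account get-access-token`; later calls
    // reuse the cached token until roughly 30 seconds before it expires.
    let token = auth.get_token().await?;
    println!("{} token ({} bytes)", token.token_type, token.token_value.len());
    Ok(())
}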
@@ -1,5 +1,6 @@
 pub mod anthropic;
 pub mod azure;
+pub mod azureauth;
 pub mod base;
 pub mod bedrock;
 pub mod databricks;
@@ -21,7 +21,7 @@ Goose relies heavily on tool calling capabilities and currently works best with
 |-----------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
 | [Amazon Bedrock](https://aws.amazon.com/bedrock/) | Offers a variety of foundation models, including Claude, Jurassic-2, and others. **AWS environment variables must be set in advance, not configured through `goose configure`** | `AWS_PROFILE`, or `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, `AWS_REGION`, ... |
 | [Anthropic](https://www.anthropic.com/) | Offers Claude, an advanced AI model for natural language tasks. | `ANTHROPIC_API_KEY`, `ANTHROPIC_HOST` (optional) |
-| [Azure OpenAI](https://learn.microsoft.com/en-us/azure/ai-services/openai/) | Access Azure-hosted OpenAI models, including GPT-4 and GPT-3.5. | `AZURE_OPENAI_API_KEY`, `AZURE_OPENAI_ENDPOINT`, `AZURE_OPENAI_DEPLOYMENT_NAME` |
+| [Azure OpenAI](https://learn.microsoft.com/en-us/azure/ai-services/openai/) | Access Azure-hosted OpenAI models, including GPT-4 and GPT-3.5. Supports both API key and Azure credential chain authentication. | `AZURE_OPENAI_ENDPOINT`, `AZURE_OPENAI_DEPLOYMENT_NAME`, `AZURE_OPENAI_API_KEY` (optional) |
 | [Databricks](https://www.databricks.com/) | Unified data analytics and AI platform for building and deploying models. | `DATABRICKS_HOST`, `DATABRICKS_TOKEN` |
 | [Gemini](https://ai.google.dev/gemini-api/docs) | Advanced LLMs by Google with multimodal capabilities (text, images). | `GOOGLE_API_KEY` |
 | [GCP Vertex AI](https://cloud.google.com/vertex-ai) | Google Cloud's Vertex AI platform, supporting Gemini and Claude models. **Credentials must be configured in advance. Follow the instructions at https://cloud.google.com/vertex-ai/docs/authentication.** | `GCP_PROJECT_ID`, `GCP_LOCATION` and optional `GCP_MAX_RETRIES` (6), `GCP_INITIAL_RETRY_INTERVAL_MS` (5000), `GCP_BACKOFF_MULTIPLIER` (2.0), `GCP_MAX_RETRY_INTERVAL_MS` (320_000). |
@@ -456,6 +456,20 @@ ollama run michaelneale/deepseek-r1-goose
 </TabItem>
 </Tabs>
 
+## Azure OpenAI Credential Chain
+
+Goose supports two authentication methods for Azure OpenAI:
+
+1. **API Key Authentication** - Uses the `AZURE_OPENAI_API_KEY` for direct authentication
+2. **Azure Credential Chain** - Uses Azure CLI credentials automatically without requiring an API key
+
+To use the Azure Credential Chain:
+- Ensure you're logged in with `az login`
+- Have appropriate Azure role assignments for the Azure OpenAI service
+- Configure with `goose configure` and select Azure OpenAI, leaving the API key field empty
+
+This method simplifies authentication and enhances security for enterprise environments.
+
 ---
 
 If you have any questions or need help with a specific provider, feel free to reach out to us on [Discord](https://discord.gg/block-opensource) or on the [Goose repo](https://github.com/block/goose).
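The credential-chain fallback documented above reduces to the following condensed sketch (based on the provider changes earlier in this diff; `choose_auth_header` is a hypothetical helper, not a goose API, and the import path is assumed):

use goose::providers::azureauth::{AzureAuth, AzureCredentials};

// Which HTTP auth header a request will carry, given an optional AZURE_OPENAI_API_KEY.
fn choose_auth_header(api_key: Option<String>) -> Result<&'static str, Box<dyn std::error::Error>> {
    let auth = AzureAuth::new(api_key)?;
    Ok(match auth.credential_type() {
        AzureCredentials::ApiKey(_) => "api-key",                // key-based auth
        AzureCredentials::DefaultCredential => "Authorization",  // Bearer token from the Azure CLI
    })
}
// choose_auth_header(None)? == "Authorization"   (credential chain)
// choose_auth_header(Some("<key>".into()))? == "api-key"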
@@ -64,7 +64,6 @@ export const required_keys = {
   Google: ['GOOGLE_API_KEY'],
   OpenRouter: ['OPENROUTER_API_KEY'],
   'Azure OpenAI': [
-    'AZURE_OPENAI_API_KEY',
     'AZURE_OPENAI_ENDPOINT',
     'AZURE_OPENAI_DEPLOYMENT_NAME',
     'AZURE_OPENAI_API_VERSION',
@@ -173,11 +173,12 @@ export const PROVIDER_REGISTRY: ProviderRegistry[] = [
     details: {
       id: 'azure_openai',
       name: 'Azure OpenAI',
-      description: 'Access Azure OpenAI models',
+      description: 'Access Azure OpenAI models using API key or Azure credentials. If no API key is provided, Azure credential chain will be used.',
       parameters: [
         {
           name: 'AZURE_OPENAI_API_KEY',
           is_secret: true,
           required: false,
         },
         {
           name: 'AZURE_OPENAI_ENDPOINT',