diff --git a/crates/goose/src/providers/azure.rs b/crates/goose/src/providers/azure.rs index db826a9d..a35259fb 100644 --- a/crates/goose/src/providers/azure.rs +++ b/crates/goose/src/providers/azure.rs @@ -15,7 +15,7 @@ use mcp_core::tool::Tool; pub const AZURE_DEFAULT_MODEL: &str = "gpt-4o"; pub const AZURE_DOC_URL: &str = "https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models"; -pub const AZURE_API_VERSION: &str = "2024-10-21"; +pub const AZURE_DEFAULT_API_VERSION: &str = "2024-10-21"; pub const AZURE_OPENAI_KNOWN_MODELS: &[&str] = &["gpt-4o", "gpt-4o-mini", "gpt-4"]; #[derive(Debug, serde::Serialize)] @@ -25,6 +25,7 @@ pub struct AzureProvider { endpoint: String, api_key: String, deployment_name: String, + api_version: String, model: ModelConfig, } @@ -41,6 +42,9 @@ impl AzureProvider { let api_key: String = config.get_secret("AZURE_OPENAI_API_KEY")?; let endpoint: String = config.get("AZURE_OPENAI_ENDPOINT")?; let deployment_name: String = config.get("AZURE_OPENAI_DEPLOYMENT_NAME")?; + let api_version: String = config + .get("AZURE_OPENAI_API_VERSION") + .unwrap_or_else(|_| AZURE_DEFAULT_API_VERSION.to_string()); let client = Client::builder() .timeout(Duration::from_secs(600)) @@ -51,6 +55,7 @@ impl AzureProvider { endpoint, api_key, deployment_name, + api_version, model, }) } @@ -63,7 +68,7 @@ impl AzureProvider { "openai/deployments/{}/chat/completions", self.deployment_name )); - base_url.set_query(Some(&format!("api-version={}", AZURE_API_VERSION))); + base_url.set_query(Some(&format!("api-version={}", self.api_version))); let response: reqwest::Response = self .client @@ -99,6 +104,12 @@ impl Provider for AzureProvider { false, Some("Name of your Azure OpenAI deployment"), ), + ConfigKey::new( + "AZURE_OPENAI_API_VERSION", + false, + false, + Some("Azure OpenAI API version, default: 2024-10-21"), + ), ], ) }