From 5e8a8bae3d19d7defc139cc41c027c730a2c6b61 Mon Sep 17 00:00:00 2001 From: TechnoHouse <13776377+deephbz@users.noreply.github.com> Date: Wed, 12 Feb 2025 15:01:32 +0800 Subject: [PATCH] Support modifying AZURE_OPENAI_API_VERSION (#1042) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Jean-FrançoisMillet --- crates/goose/src/providers/azure.rs | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/crates/goose/src/providers/azure.rs b/crates/goose/src/providers/azure.rs index db826a9d..a35259fb 100644 --- a/crates/goose/src/providers/azure.rs +++ b/crates/goose/src/providers/azure.rs @@ -15,7 +15,7 @@ use mcp_core::tool::Tool; pub const AZURE_DEFAULT_MODEL: &str = "gpt-4o"; pub const AZURE_DOC_URL: &str = "https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models"; -pub const AZURE_API_VERSION: &str = "2024-10-21"; +pub const AZURE_DEFAULT_API_VERSION: &str = "2024-10-21"; pub const AZURE_OPENAI_KNOWN_MODELS: &[&str] = &["gpt-4o", "gpt-4o-mini", "gpt-4"]; #[derive(Debug, serde::Serialize)] @@ -25,6 +25,7 @@ pub struct AzureProvider { endpoint: String, api_key: String, deployment_name: String, + api_version: String, model: ModelConfig, } @@ -41,6 +42,9 @@ impl AzureProvider { let api_key: String = config.get_secret("AZURE_OPENAI_API_KEY")?; let endpoint: String = config.get("AZURE_OPENAI_ENDPOINT")?; let deployment_name: String = config.get("AZURE_OPENAI_DEPLOYMENT_NAME")?; + let api_version: String = config + .get("AZURE_OPENAI_API_VERSION") + .unwrap_or_else(|_| AZURE_DEFAULT_API_VERSION.to_string()); let client = Client::builder() .timeout(Duration::from_secs(600)) @@ -51,6 +55,7 @@ impl AzureProvider { endpoint, api_key, deployment_name, + api_version, model, }) } @@ -63,7 +68,7 @@ impl AzureProvider { "openai/deployments/{}/chat/completions", self.deployment_name )); - base_url.set_query(Some(&format!("api-version={}", AZURE_API_VERSION))); + base_url.set_query(Some(&format!("api-version={}", self.api_version))); let response: reqwest::Response = self .client @@ -99,6 +104,12 @@ impl Provider for AzureProvider { false, Some("Name of your Azure OpenAI deployment"), ), + ConfigKey::new( + "AZURE_OPENAI_API_VERSION", + false, + false, + Some("Azure OpenAI API version, default: 2024-10-21"), + ), ], ) }