diff --git a/crates/goose-server/src/routes/providers_and_keys.json b/crates/goose-server/src/routes/providers_and_keys.json
index a936893e..ddcaf525 100644
--- a/crates/goose-server/src/routes/providers_and_keys.json
+++ b/crates/goose-server/src/routes/providers_and_keys.json
@@ -44,7 +44,7 @@
     "azure_openai": {
         "name": "Azure OpenAI",
         "description": "Connect to Azure OpenAI Service",
-        "models": ["gpt-4o", "gpt-4o-mini", "o1", "o1-mini"],
+        "models": ["gpt-4o", "gpt-4o-mini"],
         "required_keys": ["AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOYMENT_NAME"]
     }
 }
diff --git a/crates/goose/src/providers/azure.rs b/crates/goose/src/providers/azure.rs
index c17f0804..eef2b46c 100644
--- a/crates/goose/src/providers/azure.rs
+++ b/crates/goose/src/providers/azure.rs
@@ -16,14 +16,7 @@ pub const AZURE_DEFAULT_MODEL: &str = "gpt-4o";
 pub const AZURE_DOC_URL: &str =
     "https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models";
 pub const AZURE_API_VERSION: &str = "2024-10-21";
-pub const AZURE_OPENAI_KNOWN_MODELS: &[&str] = &[
-    "gpt-4o",
-    "gpt-4o-mini",
-    "o1",
-    "o1-mini",
-    "o1-preview",
-    "gpt-4",
-];
+pub const AZURE_OPENAI_KNOWN_MODELS: &[&str] = &["gpt-4o", "gpt-4o-mini", "gpt-4"];
 
 #[derive(Debug, serde::Serialize)]
 pub struct AzureProvider {
@@ -63,16 +56,18 @@ impl AzureProvider {
     }
 
     async fn post(&self, payload: Value) -> Result<Value, ProviderError> {
-        let url = format!(
-            "{}/openai/deployments/{}/chat/completions?api-version={}",
-            self.endpoint.trim_end_matches('/'),
-            self.deployment_name,
-            AZURE_API_VERSION
-        );
+        let mut base_url = url::Url::parse(&self.endpoint)
+            .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?;
+
+        base_url.set_path(&format!(
+            "openai/deployments/{}/chat/completions",
+            self.deployment_name
+        ));
+        base_url.set_query(Some(&format!("api-version={}", AZURE_API_VERSION)));
 
         let response: reqwest::Response = self
             .client
-            .post(&url)
+            .post(base_url)
             .header("api-key", &self.api_key)
             .json(&payload)
             .send()
diff --git a/crates/goose/tests/providers.rs b/crates/goose/tests/providers.rs
index 4ef98125..6a5f4b9d 100644
--- a/crates/goose/tests/providers.rs
+++ b/crates/goose/tests/providers.rs
@@ -3,7 +3,7 @@ use dotenv::dotenv;
 use goose::message::{Message, MessageContent};
 use goose::providers::base::Provider;
 use goose::providers::errors::ProviderError;
-use goose::providers::{anthropic, databricks, google, groq, ollama, openai, openrouter};
+use goose::providers::{anthropic, azure, databricks, google, groq, ollama, openai, openrouter};
 use mcp_core::content::Content;
 use mcp_core::tool::Tool;
 use std::collections::HashMap;
@@ -359,6 +359,21 @@ async fn test_openai_provider() -> Result<()> {
     .await
 }
 
+#[tokio::test]
+async fn test_azure_provider() -> Result<()> {
+    test_provider(
+        "Azure",
+        &[
+            "AZURE_OPENAI_API_KEY",
+            "AZURE_OPENAI_ENDPOINT",
+            "AZURE_OPENAI_DEPLOYMENT_NAME",
+        ],
+        None,
+        azure::AzureProvider::default,
+    )
+    .await
+}
+
 #[tokio::test]
 async fn test_databricks_provider() -> Result<()> {
     test_provider(
diff --git a/crates/goose/tests/truncate_agent.rs b/crates/goose/tests/truncate_agent.rs
index 991bba79..d3702d5f 100644
--- a/crates/goose/tests/truncate_agent.rs
+++ b/crates/goose/tests/truncate_agent.rs
@@ -7,13 +7,15 @@ use goose::message::Message;
 use goose::model::ModelConfig;
 use goose::providers::base::Provider;
 use goose::providers::{anthropic::AnthropicProvider, databricks::DatabricksProvider};
-use goose::providers::{google::GoogleProvider, groq::GroqProvider};
 use goose::providers::{
-    ollama::OllamaProvider, openai::OpenAiProvider, openrouter::OpenRouterProvider,
+    azure::AzureProvider, ollama::OllamaProvider, openai::OpenAiProvider,
+    openrouter::OpenRouterProvider,
 };
+use goose::providers::{google::GoogleProvider, groq::GroqProvider};
 
 #[derive(Debug)]
 enum ProviderType {
+    Azure,
     OpenAi,
     Anthropic,
     Databricks,
@@ -26,6 +28,11 @@ enum ProviderType {
 impl ProviderType {
     fn required_env(&self) -> &'static [&'static str] {
         match self {
+            ProviderType::Azure => &[
+                "AZURE_OPENAI_API_KEY",
+                "AZURE_OPENAI_ENDPOINT",
+                "AZURE_OPENAI_DEPLOYMENT_NAME",
+            ],
             ProviderType::OpenAi => &["OPENAI_API_KEY"],
             ProviderType::Anthropic => &["ANTHROPIC_API_KEY"],
             ProviderType::Databricks => &["DATABRICKS_HOST"],
@@ -56,6 +63,7 @@ impl ProviderType {
 
     fn create_provider(&self, model_config: ModelConfig) -> Result<Box<dyn Provider>> {
         Ok(match self {
+            ProviderType::Azure => Box::new(AzureProvider::from_env(model_config)?),
             ProviderType::OpenAi => Box::new(OpenAiProvider::from_env(model_config)?),
             ProviderType::Anthropic => Box::new(AnthropicProvider::from_env(model_config)?),
             ProviderType::Databricks => Box::new(DatabricksProvider::from_env(model_config)?),
@@ -172,6 +180,16 @@ mod tests {
         .await
     }
 
+    #[tokio::test]
+    async fn test_truncate_agent_with_azure() -> Result<()> {
+        run_test_with_config(TestConfig {
+            provider_type: ProviderType::Azure,
+            model: "gpt-4o-mini",
+            context_window: 128_000,
+        })
+        .await
+    }
+
     #[tokio::test]
     async fn test_truncate_agent_with_anthropic() -> Result<()> {
         run_test_with_config(TestConfig {