chore: add tests and update url construction for azure provider (#988)

This commit is contained in:
Alice Hau
2025-01-31 14:03:46 -05:00
committed by GitHub
parent 16df22f817
commit e851a6d64c
4 changed files with 47 additions and 19 deletions

View File

@@ -44,7 +44,7 @@
"azure_openai": {
"name": "Azure OpenAI",
"description": "Connect to Azure OpenAI Service",
"models": ["gpt-4o", "gpt-4o-mini", "o1", "o1-mini"],
"models": ["gpt-4o", "gpt-4o-mini"],
"required_keys": ["AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOYMENT_NAME"]
}
}

View File

@@ -16,14 +16,7 @@ pub const AZURE_DEFAULT_MODEL: &str = "gpt-4o";
pub const AZURE_DOC_URL: &str =
"https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models";
pub const AZURE_API_VERSION: &str = "2024-10-21";
pub const AZURE_OPENAI_KNOWN_MODELS: &[&str] = &[
"gpt-4o",
"gpt-4o-mini",
"o1",
"o1-mini",
"o1-preview",
"gpt-4",
];
pub const AZURE_OPENAI_KNOWN_MODELS: &[&str] = &["gpt-4o", "gpt-4o-mini", "gpt-4"];
#[derive(Debug, serde::Serialize)]
pub struct AzureProvider {
@@ -63,16 +56,18 @@ impl AzureProvider {
}
async fn post(&self, payload: Value) -> Result<Value, ProviderError> {
let url = format!(
"{}/openai/deployments/{}/chat/completions?api-version={}",
self.endpoint.trim_end_matches('/'),
self.deployment_name,
AZURE_API_VERSION
);
let mut base_url = url::Url::parse(&self.endpoint)
.map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?;
base_url.set_path(&format!(
"openai/deployments/{}/chat/completions",
self.deployment_name
));
base_url.set_query(Some(&format!("api-version={}", AZURE_API_VERSION)));
let response: reqwest::Response = self
.client
.post(&url)
.post(base_url)
.header("api-key", &self.api_key)
.json(&payload)
.send()

View File

@@ -3,7 +3,7 @@ use dotenv::dotenv;
use goose::message::{Message, MessageContent};
use goose::providers::base::Provider;
use goose::providers::errors::ProviderError;
use goose::providers::{anthropic, databricks, google, groq, ollama, openai, openrouter};
use goose::providers::{anthropic, azure, databricks, google, groq, ollama, openai, openrouter};
use mcp_core::content::Content;
use mcp_core::tool::Tool;
use std::collections::HashMap;
@@ -359,6 +359,21 @@ async fn test_openai_provider() -> Result<()> {
.await
}
#[tokio::test]
async fn test_azure_provider() -> Result<()> {
test_provider(
"Azure",
&[
"AZURE_OPENAI_API_KEY",
"AZURE_OPENAI_ENDPOINT",
"AZURE_OPENAI_DEPLOYMENT_NAME",
],
None,
azure::AzureProvider::default,
)
.await
}
#[tokio::test]
async fn test_databricks_provider() -> Result<()> {
test_provider(

View File

@@ -7,13 +7,15 @@ use goose::message::Message;
use goose::model::ModelConfig;
use goose::providers::base::Provider;
use goose::providers::{anthropic::AnthropicProvider, databricks::DatabricksProvider};
use goose::providers::{google::GoogleProvider, groq::GroqProvider};
use goose::providers::{
ollama::OllamaProvider, openai::OpenAiProvider, openrouter::OpenRouterProvider,
azure::AzureProvider, ollama::OllamaProvider, openai::OpenAiProvider,
openrouter::OpenRouterProvider,
};
use goose::providers::{google::GoogleProvider, groq::GroqProvider};
#[derive(Debug)]
enum ProviderType {
Azure,
OpenAi,
Anthropic,
Databricks,
@@ -26,6 +28,11 @@ enum ProviderType {
impl ProviderType {
fn required_env(&self) -> &'static [&'static str] {
match self {
ProviderType::Azure => &[
"AZURE_OPENAI_API_KEY",
"AZURE_OPENAI_ENDPOINT",
"AZURE_OPENAI_DEPLOYMENT_NAME",
],
ProviderType::OpenAi => &["OPENAI_API_KEY"],
ProviderType::Anthropic => &["ANTHROPIC_API_KEY"],
ProviderType::Databricks => &["DATABRICKS_HOST"],
@@ -56,6 +63,7 @@ impl ProviderType {
fn create_provider(&self, model_config: ModelConfig) -> Result<Box<dyn Provider>> {
Ok(match self {
ProviderType::Azure => Box::new(AzureProvider::from_env(model_config)?),
ProviderType::OpenAi => Box::new(OpenAiProvider::from_env(model_config)?),
ProviderType::Anthropic => Box::new(AnthropicProvider::from_env(model_config)?),
ProviderType::Databricks => Box::new(DatabricksProvider::from_env(model_config)?),
@@ -172,6 +180,16 @@ mod tests {
.await
}
#[tokio::test]
async fn test_truncate_agent_with_azure() -> Result<()> {
run_test_with_config(TestConfig {
provider_type: ProviderType::Azure,
model: "gpt-4o-mini",
context_window: 128_000,
})
.await
}
#[tokio::test]
async fn test_truncate_agent_with_anthropic() -> Result<()> {
run_test_with_config(TestConfig {