chore: add tests and update url construction for azure provider (#988)

This commit is contained in:
Alice Hau
2025-01-31 14:03:46 -05:00
committed by GitHub
parent 16df22f817
commit e851a6d64c
4 changed files with 47 additions and 19 deletions

View File

@@ -44,7 +44,7 @@
   "azure_openai": {
     "name": "Azure OpenAI",
     "description": "Connect to Azure OpenAI Service",
-    "models": ["gpt-4o", "gpt-4o-mini", "o1", "o1-mini"],
+    "models": ["gpt-4o", "gpt-4o-mini"],
     "required_keys": ["AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOYMENT_NAME"]
   }
 }

View File

@@ -16,14 +16,7 @@ pub const AZURE_DEFAULT_MODEL: &str = "gpt-4o";
 pub const AZURE_DOC_URL: &str =
     "https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models";
 pub const AZURE_API_VERSION: &str = "2024-10-21";
-pub const AZURE_OPENAI_KNOWN_MODELS: &[&str] = &[
-    "gpt-4o",
-    "gpt-4o-mini",
-    "o1",
-    "o1-mini",
-    "o1-preview",
-    "gpt-4",
-];
+pub const AZURE_OPENAI_KNOWN_MODELS: &[&str] = &["gpt-4o", "gpt-4o-mini", "gpt-4"];
 
 #[derive(Debug, serde::Serialize)]
 pub struct AzureProvider {
@@ -63,16 +56,18 @@ impl AzureProvider {
     }
 
     async fn post(&self, payload: Value) -> Result<Value, ProviderError> {
-        let url = format!(
-            "{}/openai/deployments/{}/chat/completions?api-version={}",
-            self.endpoint.trim_end_matches('/'),
-            self.deployment_name,
-            AZURE_API_VERSION
-        );
+        let mut base_url = url::Url::parse(&self.endpoint)
+            .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?;
+
+        base_url.set_path(&format!(
+            "openai/deployments/{}/chat/completions",
+            self.deployment_name
+        ));
+        base_url.set_query(Some(&format!("api-version={}", AZURE_API_VERSION)));
 
         let response: reqwest::Response = self
             .client
-            .post(&url)
+            .post(base_url)
             .header("api-key", &self.api_key)
             .json(&payload)
             .send()

View File

@@ -3,7 +3,7 @@ use dotenv::dotenv;
 use goose::message::{Message, MessageContent};
 use goose::providers::base::Provider;
 use goose::providers::errors::ProviderError;
-use goose::providers::{anthropic, databricks, google, groq, ollama, openai, openrouter};
+use goose::providers::{anthropic, azure, databricks, google, groq, ollama, openai, openrouter};
 use mcp_core::content::Content;
 use mcp_core::tool::Tool;
 use std::collections::HashMap;
@@ -359,6 +359,21 @@ async fn test_openai_provider() -> Result<()> {
         .await
 }
 
+#[tokio::test]
+async fn test_azure_provider() -> Result<()> {
+    test_provider(
+        "Azure",
+        &[
+            "AZURE_OPENAI_API_KEY",
+            "AZURE_OPENAI_ENDPOINT",
+            "AZURE_OPENAI_DEPLOYMENT_NAME",
+        ],
+        None,
+        azure::AzureProvider::default,
+    )
+    .await
+}
+
 #[tokio::test]
 async fn test_databricks_provider() -> Result<()> {
     test_provider(

View File

@@ -7,13 +7,15 @@ use goose::message::Message;
 use goose::model::ModelConfig;
 use goose::providers::base::Provider;
 use goose::providers::{anthropic::AnthropicProvider, databricks::DatabricksProvider};
-use goose::providers::{google::GoogleProvider, groq::GroqProvider};
 use goose::providers::{
-    ollama::OllamaProvider, openai::OpenAiProvider, openrouter::OpenRouterProvider,
+    azure::AzureProvider, ollama::OllamaProvider, openai::OpenAiProvider,
+    openrouter::OpenRouterProvider,
 };
+use goose::providers::{google::GoogleProvider, groq::GroqProvider};
 
 #[derive(Debug)]
 enum ProviderType {
+    Azure,
     OpenAi,
     Anthropic,
     Databricks,
@@ -26,6 +28,11 @@ enum ProviderType {
 impl ProviderType {
     fn required_env(&self) -> &'static [&'static str] {
         match self {
+            ProviderType::Azure => &[
+                "AZURE_OPENAI_API_KEY",
+                "AZURE_OPENAI_ENDPOINT",
+                "AZURE_OPENAI_DEPLOYMENT_NAME",
+            ],
             ProviderType::OpenAi => &["OPENAI_API_KEY"],
             ProviderType::Anthropic => &["ANTHROPIC_API_KEY"],
             ProviderType::Databricks => &["DATABRICKS_HOST"],
@@ -56,6 +63,7 @@ impl ProviderType {
     fn create_provider(&self, model_config: ModelConfig) -> Result<Box<dyn Provider>> {
         Ok(match self {
+            ProviderType::Azure => Box::new(AzureProvider::from_env(model_config)?),
             ProviderType::OpenAi => Box::new(OpenAiProvider::from_env(model_config)?),
             ProviderType::Anthropic => Box::new(AnthropicProvider::from_env(model_config)?),
             ProviderType::Databricks => Box::new(DatabricksProvider::from_env(model_config)?),
@@ -172,6 +180,16 @@ mod tests {
         .await
 }
 
+#[tokio::test]
+async fn test_truncate_agent_with_azure() -> Result<()> {
+    run_test_with_config(TestConfig {
+        provider_type: ProviderType::Azure,
+        model: "gpt-4o-mini",
+        context_window: 128_000,
+    })
+    .await
+}
+
 #[tokio::test]
 async fn test_truncate_agent_with_anthropic() -> Result<()> {
     run_test_with_config(TestConfig {