mirror of
https://github.com/aljazceru/goose.git
synced 2025-12-18 14:44:21 +01:00
chore: add tests and update url construction for azure provider (#988)
This commit is contained in:
@@ -44,7 +44,7 @@
|
||||
"azure_openai": {
|
||||
"name": "Azure OpenAI",
|
||||
"description": "Connect to Azure OpenAI Service",
|
||||
"models": ["gpt-4o", "gpt-4o-mini", "o1", "o1-mini"],
|
||||
"models": ["gpt-4o", "gpt-4o-mini"],
|
||||
"required_keys": ["AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOYMENT_NAME"]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,14 +16,7 @@ pub const AZURE_DEFAULT_MODEL: &str = "gpt-4o";
|
||||
pub const AZURE_DOC_URL: &str =
|
||||
"https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models";
|
||||
pub const AZURE_API_VERSION: &str = "2024-10-21";
|
||||
pub const AZURE_OPENAI_KNOWN_MODELS: &[&str] = &[
|
||||
"gpt-4o",
|
||||
"gpt-4o-mini",
|
||||
"o1",
|
||||
"o1-mini",
|
||||
"o1-preview",
|
||||
"gpt-4",
|
||||
];
|
||||
pub const AZURE_OPENAI_KNOWN_MODELS: &[&str] = &["gpt-4o", "gpt-4o-mini", "gpt-4"];
|
||||
|
||||
#[derive(Debug, serde::Serialize)]
|
||||
pub struct AzureProvider {
|
||||
@@ -63,16 +56,18 @@ impl AzureProvider {
|
||||
}
|
||||
|
||||
async fn post(&self, payload: Value) -> Result<Value, ProviderError> {
|
||||
let url = format!(
|
||||
"{}/openai/deployments/{}/chat/completions?api-version={}",
|
||||
self.endpoint.trim_end_matches('/'),
|
||||
self.deployment_name,
|
||||
AZURE_API_VERSION
|
||||
);
|
||||
let mut base_url = url::Url::parse(&self.endpoint)
|
||||
.map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?;
|
||||
|
||||
base_url.set_path(&format!(
|
||||
"openai/deployments/{}/chat/completions",
|
||||
self.deployment_name
|
||||
));
|
||||
base_url.set_query(Some(&format!("api-version={}", AZURE_API_VERSION)));
|
||||
|
||||
let response: reqwest::Response = self
|
||||
.client
|
||||
.post(&url)
|
||||
.post(base_url)
|
||||
.header("api-key", &self.api_key)
|
||||
.json(&payload)
|
||||
.send()
|
||||
|
||||
@@ -3,7 +3,7 @@ use dotenv::dotenv;
|
||||
use goose::message::{Message, MessageContent};
|
||||
use goose::providers::base::Provider;
|
||||
use goose::providers::errors::ProviderError;
|
||||
use goose::providers::{anthropic, databricks, google, groq, ollama, openai, openrouter};
|
||||
use goose::providers::{anthropic, azure, databricks, google, groq, ollama, openai, openrouter};
|
||||
use mcp_core::content::Content;
|
||||
use mcp_core::tool::Tool;
|
||||
use std::collections::HashMap;
|
||||
@@ -359,6 +359,21 @@ async fn test_openai_provider() -> Result<()> {
|
||||
.await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_azure_provider() -> Result<()> {
|
||||
test_provider(
|
||||
"Azure",
|
||||
&[
|
||||
"AZURE_OPENAI_API_KEY",
|
||||
"AZURE_OPENAI_ENDPOINT",
|
||||
"AZURE_OPENAI_DEPLOYMENT_NAME",
|
||||
],
|
||||
None,
|
||||
azure::AzureProvider::default,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_databricks_provider() -> Result<()> {
|
||||
test_provider(
|
||||
|
||||
@@ -7,13 +7,15 @@ use goose::message::Message;
|
||||
use goose::model::ModelConfig;
|
||||
use goose::providers::base::Provider;
|
||||
use goose::providers::{anthropic::AnthropicProvider, databricks::DatabricksProvider};
|
||||
use goose::providers::{google::GoogleProvider, groq::GroqProvider};
|
||||
use goose::providers::{
|
||||
ollama::OllamaProvider, openai::OpenAiProvider, openrouter::OpenRouterProvider,
|
||||
azure::AzureProvider, ollama::OllamaProvider, openai::OpenAiProvider,
|
||||
openrouter::OpenRouterProvider,
|
||||
};
|
||||
use goose::providers::{google::GoogleProvider, groq::GroqProvider};
|
||||
|
||||
#[derive(Debug)]
|
||||
enum ProviderType {
|
||||
Azure,
|
||||
OpenAi,
|
||||
Anthropic,
|
||||
Databricks,
|
||||
@@ -26,6 +28,11 @@ enum ProviderType {
|
||||
impl ProviderType {
|
||||
fn required_env(&self) -> &'static [&'static str] {
|
||||
match self {
|
||||
ProviderType::Azure => &[
|
||||
"AZURE_OPENAI_API_KEY",
|
||||
"AZURE_OPENAI_ENDPOINT",
|
||||
"AZURE_OPENAI_DEPLOYMENT_NAME",
|
||||
],
|
||||
ProviderType::OpenAi => &["OPENAI_API_KEY"],
|
||||
ProviderType::Anthropic => &["ANTHROPIC_API_KEY"],
|
||||
ProviderType::Databricks => &["DATABRICKS_HOST"],
|
||||
@@ -56,6 +63,7 @@ impl ProviderType {
|
||||
|
||||
fn create_provider(&self, model_config: ModelConfig) -> Result<Box<dyn Provider>> {
|
||||
Ok(match self {
|
||||
ProviderType::Azure => Box::new(AzureProvider::from_env(model_config)?),
|
||||
ProviderType::OpenAi => Box::new(OpenAiProvider::from_env(model_config)?),
|
||||
ProviderType::Anthropic => Box::new(AnthropicProvider::from_env(model_config)?),
|
||||
ProviderType::Databricks => Box::new(DatabricksProvider::from_env(model_config)?),
|
||||
@@ -172,6 +180,16 @@ mod tests {
|
||||
.await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_truncate_agent_with_azure() -> Result<()> {
|
||||
run_test_with_config(TestConfig {
|
||||
provider_type: ProviderType::Azure,
|
||||
model: "gpt-4o-mini",
|
||||
context_window: 128_000,
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_truncate_agent_with_anthropic() -> Result<()> {
|
||||
run_test_with_config(TestConfig {
|
||||
|
||||
Reference in New Issue
Block a user