mirror of
https://github.com/aljazceru/goose.git
synced 2025-12-19 15:14:21 +01:00
chore: add tests and update url construction for azure provider (#988)
This commit is contained in:
@@ -44,7 +44,7 @@
|
|||||||
"azure_openai": {
|
"azure_openai": {
|
||||||
"name": "Azure OpenAI",
|
"name": "Azure OpenAI",
|
||||||
"description": "Connect to Azure OpenAI Service",
|
"description": "Connect to Azure OpenAI Service",
|
||||||
"models": ["gpt-4o", "gpt-4o-mini", "o1", "o1-mini"],
|
"models": ["gpt-4o", "gpt-4o-mini"],
|
||||||
"required_keys": ["AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOYMENT_NAME"]
|
"required_keys": ["AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOYMENT_NAME"]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -16,14 +16,7 @@ pub const AZURE_DEFAULT_MODEL: &str = "gpt-4o";
|
|||||||
pub const AZURE_DOC_URL: &str =
|
pub const AZURE_DOC_URL: &str =
|
||||||
"https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models";
|
"https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models";
|
||||||
pub const AZURE_API_VERSION: &str = "2024-10-21";
|
pub const AZURE_API_VERSION: &str = "2024-10-21";
|
||||||
pub const AZURE_OPENAI_KNOWN_MODELS: &[&str] = &[
|
pub const AZURE_OPENAI_KNOWN_MODELS: &[&str] = &["gpt-4o", "gpt-4o-mini", "gpt-4"];
|
||||||
"gpt-4o",
|
|
||||||
"gpt-4o-mini",
|
|
||||||
"o1",
|
|
||||||
"o1-mini",
|
|
||||||
"o1-preview",
|
|
||||||
"gpt-4",
|
|
||||||
];
|
|
||||||
|
|
||||||
#[derive(Debug, serde::Serialize)]
|
#[derive(Debug, serde::Serialize)]
|
||||||
pub struct AzureProvider {
|
pub struct AzureProvider {
|
||||||
@@ -63,16 +56,18 @@ impl AzureProvider {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async fn post(&self, payload: Value) -> Result<Value, ProviderError> {
|
async fn post(&self, payload: Value) -> Result<Value, ProviderError> {
|
||||||
let url = format!(
|
let mut base_url = url::Url::parse(&self.endpoint)
|
||||||
"{}/openai/deployments/{}/chat/completions?api-version={}",
|
.map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?;
|
||||||
self.endpoint.trim_end_matches('/'),
|
|
||||||
self.deployment_name,
|
base_url.set_path(&format!(
|
||||||
AZURE_API_VERSION
|
"openai/deployments/{}/chat/completions",
|
||||||
);
|
self.deployment_name
|
||||||
|
));
|
||||||
|
base_url.set_query(Some(&format!("api-version={}", AZURE_API_VERSION)));
|
||||||
|
|
||||||
let response: reqwest::Response = self
|
let response: reqwest::Response = self
|
||||||
.client
|
.client
|
||||||
.post(&url)
|
.post(base_url)
|
||||||
.header("api-key", &self.api_key)
|
.header("api-key", &self.api_key)
|
||||||
.json(&payload)
|
.json(&payload)
|
||||||
.send()
|
.send()
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ use dotenv::dotenv;
|
|||||||
use goose::message::{Message, MessageContent};
|
use goose::message::{Message, MessageContent};
|
||||||
use goose::providers::base::Provider;
|
use goose::providers::base::Provider;
|
||||||
use goose::providers::errors::ProviderError;
|
use goose::providers::errors::ProviderError;
|
||||||
use goose::providers::{anthropic, databricks, google, groq, ollama, openai, openrouter};
|
use goose::providers::{anthropic, azure, databricks, google, groq, ollama, openai, openrouter};
|
||||||
use mcp_core::content::Content;
|
use mcp_core::content::Content;
|
||||||
use mcp_core::tool::Tool;
|
use mcp_core::tool::Tool;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
@@ -359,6 +359,21 @@ async fn test_openai_provider() -> Result<()> {
|
|||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_azure_provider() -> Result<()> {
|
||||||
|
test_provider(
|
||||||
|
"Azure",
|
||||||
|
&[
|
||||||
|
"AZURE_OPENAI_API_KEY",
|
||||||
|
"AZURE_OPENAI_ENDPOINT",
|
||||||
|
"AZURE_OPENAI_DEPLOYMENT_NAME",
|
||||||
|
],
|
||||||
|
None,
|
||||||
|
azure::AzureProvider::default,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_databricks_provider() -> Result<()> {
|
async fn test_databricks_provider() -> Result<()> {
|
||||||
test_provider(
|
test_provider(
|
||||||
|
|||||||
@@ -7,13 +7,15 @@ use goose::message::Message;
|
|||||||
use goose::model::ModelConfig;
|
use goose::model::ModelConfig;
|
||||||
use goose::providers::base::Provider;
|
use goose::providers::base::Provider;
|
||||||
use goose::providers::{anthropic::AnthropicProvider, databricks::DatabricksProvider};
|
use goose::providers::{anthropic::AnthropicProvider, databricks::DatabricksProvider};
|
||||||
use goose::providers::{google::GoogleProvider, groq::GroqProvider};
|
|
||||||
use goose::providers::{
|
use goose::providers::{
|
||||||
ollama::OllamaProvider, openai::OpenAiProvider, openrouter::OpenRouterProvider,
|
azure::AzureProvider, ollama::OllamaProvider, openai::OpenAiProvider,
|
||||||
|
openrouter::OpenRouterProvider,
|
||||||
};
|
};
|
||||||
|
use goose::providers::{google::GoogleProvider, groq::GroqProvider};
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
enum ProviderType {
|
enum ProviderType {
|
||||||
|
Azure,
|
||||||
OpenAi,
|
OpenAi,
|
||||||
Anthropic,
|
Anthropic,
|
||||||
Databricks,
|
Databricks,
|
||||||
@@ -26,6 +28,11 @@ enum ProviderType {
|
|||||||
impl ProviderType {
|
impl ProviderType {
|
||||||
fn required_env(&self) -> &'static [&'static str] {
|
fn required_env(&self) -> &'static [&'static str] {
|
||||||
match self {
|
match self {
|
||||||
|
ProviderType::Azure => &[
|
||||||
|
"AZURE_OPENAI_API_KEY",
|
||||||
|
"AZURE_OPENAI_ENDPOINT",
|
||||||
|
"AZURE_OPENAI_DEPLOYMENT_NAME",
|
||||||
|
],
|
||||||
ProviderType::OpenAi => &["OPENAI_API_KEY"],
|
ProviderType::OpenAi => &["OPENAI_API_KEY"],
|
||||||
ProviderType::Anthropic => &["ANTHROPIC_API_KEY"],
|
ProviderType::Anthropic => &["ANTHROPIC_API_KEY"],
|
||||||
ProviderType::Databricks => &["DATABRICKS_HOST"],
|
ProviderType::Databricks => &["DATABRICKS_HOST"],
|
||||||
@@ -56,6 +63,7 @@ impl ProviderType {
|
|||||||
|
|
||||||
fn create_provider(&self, model_config: ModelConfig) -> Result<Box<dyn Provider>> {
|
fn create_provider(&self, model_config: ModelConfig) -> Result<Box<dyn Provider>> {
|
||||||
Ok(match self {
|
Ok(match self {
|
||||||
|
ProviderType::Azure => Box::new(AzureProvider::from_env(model_config)?),
|
||||||
ProviderType::OpenAi => Box::new(OpenAiProvider::from_env(model_config)?),
|
ProviderType::OpenAi => Box::new(OpenAiProvider::from_env(model_config)?),
|
||||||
ProviderType::Anthropic => Box::new(AnthropicProvider::from_env(model_config)?),
|
ProviderType::Anthropic => Box::new(AnthropicProvider::from_env(model_config)?),
|
||||||
ProviderType::Databricks => Box::new(DatabricksProvider::from_env(model_config)?),
|
ProviderType::Databricks => Box::new(DatabricksProvider::from_env(model_config)?),
|
||||||
@@ -172,6 +180,16 @@ mod tests {
|
|||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_truncate_agent_with_azure() -> Result<()> {
|
||||||
|
run_test_with_config(TestConfig {
|
||||||
|
provider_type: ProviderType::Azure,
|
||||||
|
model: "gpt-4o-mini",
|
||||||
|
context_window: 128_000,
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_truncate_agent_with_anthropic() -> Result<()> {
|
async fn test_truncate_agent_with_anthropic() -> Result<()> {
|
||||||
run_test_with_config(TestConfig {
|
run_test_with_config(TestConfig {
|
||||||
|
|||||||
Reference in New Issue
Block a user