feat: add azure openai provider (#960)

This commit is contained in:
Alice Hau
2025-01-31 09:30:36 -05:00
committed by GitHub
parent 092d8711a9
commit 5f6c85d7bd
10 changed files with 274 additions and 93 deletions

View File

@@ -40,5 +40,11 @@
"description": "Lorem ipsum",
"models": [],
"required_keys": ["OPENROUTER_API_KEY"]
},
"azure_openai": {
"name": "Azure OpenAI",
"description": "Connect to Azure OpenAI Service",
"models": ["gpt-4o", "gpt-4o-mini", "o1", "o1-mini"],
"required_keys": ["AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOYMENT_NAME"]
}
}
}

View File

@@ -0,0 +1,141 @@
use anyhow::Result;
use async_trait::async_trait;
use reqwest::Client;
use serde_json::Value;
use std::time::Duration;
use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage};
use super::errors::ProviderError;
use super::formats::openai::{create_request, get_usage, response_to_message};
use super::utils::{emit_debug_trace, get_model, handle_response_openai_compat, ImageFormat};
use crate::message::Message;
use crate::model::ModelConfig;
use mcp_core::tool::Tool;
pub const AZURE_DEFAULT_MODEL: &str = "gpt-4o";
pub const AZURE_DOC_URL: &str =
"https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models";
pub const AZURE_API_VERSION: &str = "2024-10-21";
pub const AZURE_OPENAI_KNOWN_MODELS: &[&str] = &[
"gpt-4o",
"gpt-4o-mini",
"o1",
"o1-mini",
"o1-preview",
"gpt-4",
];
#[derive(Debug, serde::Serialize)]
pub struct AzureProvider {
#[serde(skip)]
client: Client,
endpoint: String,
api_key: String,
deployment_name: String,
model: ModelConfig,
}
impl Default for AzureProvider {
fn default() -> Self {
let model = ModelConfig::new(AzureProvider::metadata().default_model);
AzureProvider::from_env(model).expect("Failed to initialize Azure OpenAI provider")
}
}
impl AzureProvider {
pub fn from_env(model: ModelConfig) -> Result<Self> {
let config = crate::config::Config::global();
let api_key: String = config.get_secret("AZURE_OPENAI_API_KEY")?;
let endpoint: String = config.get("AZURE_OPENAI_ENDPOINT")?;
let deployment_name: String = config.get("AZURE_OPENAI_DEPLOYMENT_NAME")?;
let client = Client::builder()
.timeout(Duration::from_secs(600))
.build()?;
Ok(Self {
client,
endpoint,
api_key,
deployment_name,
model,
})
}
async fn post(&self, payload: Value) -> Result<Value, ProviderError> {
let url = format!(
"{}/openai/deployments/{}/chat/completions?api-version={}",
self.endpoint.trim_end_matches('/'),
self.deployment_name,
AZURE_API_VERSION
);
let response: reqwest::Response = self
.client
.post(&url)
.header("api-key", &self.api_key)
.json(&payload)
.send()
.await?;
handle_response_openai_compat(response).await
}
}
#[async_trait]
impl Provider for AzureProvider {
fn metadata() -> ProviderMetadata {
ProviderMetadata::new(
"azure_openai",
"Azure OpenAI",
"Models through Azure OpenAI Service",
"gpt-4o",
AZURE_OPENAI_KNOWN_MODELS
.iter()
.map(|s| s.to_string())
.collect(),
AZURE_DOC_URL,
vec![
ConfigKey::new("AZURE_OPENAI_API_KEY", true, true, None),
ConfigKey::new("AZURE_OPENAI_ENDPOINT", true, false, None),
ConfigKey::new(
"AZURE_OPENAI_DEPLOYMENT_NAME",
true,
false,
Some("Name of your Azure OpenAI deployment"),
),
],
)
}
fn get_model_config(&self) -> ModelConfig {
self.model.clone()
}
#[tracing::instrument(
skip(self, system, messages, tools),
fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
)]
async fn complete(
&self,
system: &str,
messages: &[Message],
tools: &[Tool],
) -> Result<(Message, ProviderUsage), ProviderError> {
let payload = create_request(&self.model, system, messages, tools, &ImageFormat::OpenAi)?;
let response = self.post(payload.clone()).await?;
let message = response_to_message(response.clone())?;
let usage = match get_usage(&response) {
Ok(usage) => usage,
Err(ProviderError::UsageError(e)) => {
tracing::warn!("Failed to get usage data: {}", e);
Usage::default()
}
Err(e) => return Err(e),
};
let model = get_model(&response);
emit_debug_trace(self, &payload, &response, &usage);
Ok((message, ProviderUsage::new(model, usage)))
}
}

View File

@@ -1,5 +1,6 @@
use super::{
anthropic::AnthropicProvider,
azure::AzureProvider,
base::{Provider, ProviderMetadata},
databricks::DatabricksProvider,
google::GoogleProvider,
@@ -14,6 +15,7 @@ use anyhow::Result;
pub fn providers() -> Vec<ProviderMetadata> {
vec![
AnthropicProvider::metadata(),
AzureProvider::metadata(),
DatabricksProvider::metadata(),
GoogleProvider::metadata(),
GroqProvider::metadata(),
@@ -27,6 +29,7 @@ pub fn create(name: &str, model: ModelConfig) -> Result<Box<dyn Provider + Send
match name {
"openai" => Ok(Box::new(OpenAiProvider::from_env(model)?)),
"anthropic" => Ok(Box::new(AnthropicProvider::from_env(model)?)),
"azure_openai" => Ok(Box::new(AzureProvider::from_env(model)?)),
"databricks" => Ok(Box::new(DatabricksProvider::from_env(model)?)),
"groq" => Ok(Box::new(GroqProvider::from_env(model)?)),
"ollama" => Ok(Box::new(OllamaProvider::from_env(model)?)),

View File

@@ -1,4 +1,5 @@
pub mod anthropic;
pub mod azure;
pub mod base;
pub mod databricks;
pub mod errors;

View File

@@ -12,7 +12,7 @@ interface ProviderSetupModalProps {
model: string;
endpoint: string;
title?: string;
onSubmit: (apiKey: string) => void;
onSubmit: (configValues: { [key: string]: string }) => void;
onCancel: () => void;
}
@@ -24,14 +24,14 @@ export function ProviderSetupModal({
onSubmit,
onCancel,
}: ProviderSetupModalProps) {
const [apiKey, setApiKey] = React.useState('');
const keyName = required_keys[provider]?.[0] || 'API Key';
const headerText = `Setup ${provider}`;
const [configValues, setConfigValues] = React.useState<{ [key: string]: string }>({});
const requiredKeys = required_keys[provider] || ['API Key'];
const headerText = title || `Setup ${provider}`;
const handleSubmit = (e: React.FormEvent) => {
e.preventDefault();
onSubmit(apiKey);
onSubmit(configValues);
};
const inputType = isSecretKey(keyName) ? 'password' : 'text';
return (
<div className="fixed inset-0 bg-black/20 dark:bg-white/20 backdrop-blur-sm transition-colors animate-[fadein_200ms_ease-in_forwards]">
@@ -48,20 +48,27 @@ export function ProviderSetupModal({
{/* Form */}
<form onSubmit={handleSubmit}>
<div className="mt-[24px]">
<div>
<Input
type={inputType}
value={apiKey}
onChange={(e) => setApiKey(e.target.value)}
placeholder={keyName}
className="w-full h-14 px-4 font-regular rounded-lg border shadow-none border-gray-300 bg-white text-lg placeholder:text-gray-400 font-regular text-gray-900"
required
/>
<div className="flex mt-4 text-gray-600 dark:text-gray-300">
<Lock className="w-6 h-6" />
<span className="text-sm font-light ml-4 mt-[2px]">{`Your API key or host will be stored securely in the keychain and used only for making requests to ${provider}`}</span>
<div className="mt-[24px] space-y-4">
{requiredKeys.map((keyName) => (
<div key={keyName}>
<Input
type={isSecretKey(keyName) ? 'password' : 'text'}
value={configValues[keyName] || ''}
onChange={(e) =>
setConfigValues((prev) => ({
...prev,
[keyName]: e.target.value,
}))
}
placeholder={keyName}
className="w-full h-14 px-4 font-regular rounded-lg border shadow-none border-gray-300 bg-white text-lg placeholder:text-gray-400 font-regular text-gray-900"
required
/>
</div>
))}
<div className="flex text-gray-600 dark:text-gray-300">
<Lock className="w-6 h-6" />
<span className="text-sm font-light ml-4 mt-[2px]">{`Your configuration values will be stored securely in the keychain and used only for making requests to ${provider}`}</span>
</div>
</div>

View File

@@ -2,8 +2,14 @@ import { Provider, ProviderResponse } from './types';
import { getApiUrl, getSecretKey } from '../../../config';
export function isSecretKey(keyName: string): boolean {
// Ollama and Databricks use host name right now and it should not be stored as secret.
return keyName != 'DATABRICKS_HOST' && keyName != 'OLLAMA_HOST';
// Endpoints and hosts should not be stored as secrets
const nonSecretKeys = [
'DATABRICKS_HOST',
'OLLAMA_HOST',
'AZURE_OPENAI_ENDPOINT',
'AZURE_OPENAI_DEPLOYMENT_NAME',
];
return !nonSecretKeys.includes(keyName);
}
export async function getActiveProviders(): Promise<string[]> {
@@ -16,9 +22,8 @@ export async function getActiveProviders(): Promise<string[]> {
.filter((provider) => {
const apiKeyStatus = Object.values(provider.config_status || {}); // Get all key statuses
// Include providers if:
// - They have at least one key set (`is_set: true`)
return apiKeyStatus.some((key) => key.is_set);
// Include providers if all required keys are set
return apiKeyStatus.length > 0 && apiKeyStatus.every((key) => key.is_set);
})
.map((provider) => provider.name || 'Unknown Provider'); // Extract provider name

View File

@@ -17,6 +17,7 @@ export const goose_models: Model[] = [
{ id: 15, name: 'llama-3.3-70b-versatile', provider: 'Groq' },
{ id: 16, name: 'qwen2.5', provider: 'Ollama' },
{ id: 17, name: 'anthropic/claude-3.5-sonnet', provider: 'OpenRouter' },
{ id: 18, name: 'gpt-4o', provider: 'Azure OpenAI' },
];
export const openai_models = ['gpt-4o-mini', 'gpt-4o', 'gpt-4-turbo', 'o1'];
@@ -42,6 +43,8 @@ export const ollama_mdoels = ['qwen2.5'];
export const openrouter_models = ['anthropic/claude-3.5-sonnet'];
export const azure_openai_models = ['gpt-4o'];
export const default_models = {
openai: 'gpt-4o',
anthropic: 'claude-3-5-sonnet-latest',
@@ -50,6 +53,7 @@ export const default_models = {
groq: 'llama-3.3-70b-versatile',
openrouter: 'anthropic/claude-3.5-sonnet',
ollama: 'qwen2.5',
azure_openai: 'gpt-4o',
};
export function getDefaultModel(key: string): string | undefined {
@@ -66,6 +70,7 @@ export const required_keys = {
Ollama: ['OLLAMA_HOST'],
Google: ['GOOGLE_API_KEY'],
OpenRouter: ['OPENROUTER_API_KEY'],
'Azure OpenAI': ['AZURE_OPENAI_API_KEY', 'AZURE_OPENAI_ENDPOINT', 'AZURE_OPENAI_DEPLOYMENT_NAME'],
};
export const supported_providers = [
@@ -76,6 +81,7 @@ export const supported_providers = [
'Google',
'Ollama',
'OpenRouter',
'Azure OpenAI',
];
export const model_docs_link = [
@@ -99,4 +105,5 @@ export const provider_aliases = [
{ provider: 'Databricks', alias: 'databricks' },
{ provider: 'OpenRouter', alias: 'openrouter' },
{ provider: 'Google', alias: 'google' },
{ provider: 'Azure OpenAI', alias: 'azure_openai' },
];

View File

@@ -74,68 +74,77 @@ export function ConfigureProvidersGrid() {
setShowSetupModal(true);
};
const handleModalSubmit = async (apiKey: string) => {
const handleModalSubmit = async (configValues: { [key: string]: string }) => {
if (!selectedForSetup) return;
const provider = providers.find((p) => p.id === selectedForSetup)?.name;
if (!provider) return;
const keyName = required_keys[provider]?.[0];
if (!keyName) {
console.error(`No key found for provider ${provider}`);
const requiredKeys = required_keys[provider];
if (!requiredKeys || requiredKeys.length === 0) {
console.error(`No keys found for provider ${provider}`);
return;
}
const isSecret = isSecretKey(keyName);
try {
// Delete existing key if provider is already configured
// Delete existing keys if provider is already configured
const isUpdate = providers.find((p) => p.id === selectedForSetup)?.isConfigured;
if (isUpdate) {
const deleteResponse = await fetch(getApiUrl('/configs/delete'), {
method: 'DELETE',
for (const keyName of requiredKeys) {
const isSecret = isSecretKey(keyName);
const deleteResponse = await fetch(getApiUrl('/configs/delete'), {
method: 'DELETE',
headers: {
'Content-Type': 'application/json',
'X-Secret-Key': getSecretKey(),
},
body: JSON.stringify({
key: keyName,
isSecret,
}),
});
if (!deleteResponse.ok) {
const errorText = await deleteResponse.text();
console.error('Delete response error:', errorText);
throw new Error(`Failed to delete old key: ${keyName}`);
}
}
}
// Store new keys
for (const keyName of requiredKeys) {
const value = configValues[keyName];
if (!value) {
console.error(`Missing value for required key: ${keyName}`);
continue;
}
const isSecret = isSecretKey(keyName);
const storeResponse = await fetch(getApiUrl('/configs/store'), {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'X-Secret-Key': getSecretKey(),
},
body: JSON.stringify({
key: keyName,
body: JSON.stringify({
key: keyName,
value: value,
isSecret,
}),
});
if (!deleteResponse.ok) {
const errorText = await deleteResponse.text();
console.error('Delete response error:', errorText);
throw new Error('Failed to delete old key');
if (!storeResponse.ok) {
const errorText = await storeResponse.text();
console.error('Store response error:', errorText);
throw new Error(`Failed to store new key: ${keyName}`);
}
}
// Store new key
const storeResponse = await fetch(getApiUrl('/configs/store'), {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'X-Secret-Key': getSecretKey(),
},
body: JSON.stringify({
key: keyName,
value: apiKey.trim(),
isSecret,
}),
});
if (!storeResponse.ok) {
const errorText = await storeResponse.text();
console.error('Store response error:', errorText);
throw new Error('Failed to store new key');
}
const toastInfo = isSecret ? 'API key' : 'host';
toast.success(
isUpdate
? `Successfully updated ${toastInfo} for ${provider}`
: `Successfully added ${toastInfo} for ${provider}`
? `Successfully updated configuration for ${provider}`
: `Successfully added configuration for ${provider}`
);
const updatedKeys = await getActiveProviders();
@@ -147,7 +156,7 @@ export function ConfigureProvidersGrid() {
} catch (error) {
console.error('Error handling modal submit:', error);
toast.error(
`Failed to ${providers.find((p) => p.id === selectedForSetup)?.isConfigured ? 'update' : 'add'} API key for ${provider}`
`Failed to ${providers.find((p) => p.id === selectedForSetup)?.isConfigured ? 'update' : 'add'} configuration for ${provider}`
);
}
};
@@ -160,50 +169,52 @@ export function ConfigureProvidersGrid() {
const confirmDelete = async () => {
if (!providerToDelete) return;
const keyName = required_keys[providerToDelete.name]?.[0];
if (!keyName) {
console.error(`No key found for provider ${providerToDelete.name}`);
const requiredKeys = required_keys[providerToDelete.name];
if (!requiredKeys || requiredKeys.length === 0) {
console.error(`No keys found for provider ${providerToDelete.name}`);
return;
}
const isSecret = isSecretKey(keyName);
const toastInfo = isSecret ? 'API key' : 'host';
try {
// Check if the selected provider is currently active
if (currentModel?.provider === providerToDelete.name) {
toast.error(
`Cannot delete the ${toastInfo} for ${providerToDelete.name} because it's the provider of the current model (${currentModel.name}). Please switch to a different model first.`
`Cannot delete the configuration for ${providerToDelete.name} because it's the provider of the current model (${currentModel.name}). Please switch to a different model first.`
);
setIsConfirmationOpen(false);
return;
}
const deleteResponse = await fetch(getApiUrl('/configs/delete'), {
method: 'DELETE',
headers: {
'Content-Type': 'application/json',
'X-Secret-Key': getSecretKey(),
},
body: JSON.stringify({
key: keyName,
isSecret,
}),
});
// Delete all keys for the provider
for (const keyName of requiredKeys) {
const isSecret = isSecretKey(keyName);
const deleteResponse = await fetch(getApiUrl('/configs/delete'), {
method: 'DELETE',
headers: {
'Content-Type': 'application/json',
'X-Secret-Key': getSecretKey(),
},
body: JSON.stringify({
key: keyName,
isSecret,
}),
});
if (!deleteResponse.ok) {
const errorText = await deleteResponse.text();
console.error('Delete response error:', errorText);
throw new Error('Failed to delete key');
if (!deleteResponse.ok) {
const errorText = await deleteResponse.text();
console.error('Delete response error:', errorText);
throw new Error(`Failed to delete key: ${keyName}`);
}
}
console.log('Key deleted successfully.');
toast.success(`Successfully deleted ${toastInfo} for ${providerToDelete.name}`);
console.log('Configuration deleted successfully.');
toast.success(`Successfully deleted configuration for ${providerToDelete.name}`);
const updatedKeys = await getActiveProviders();
setActiveKeys(updatedKeys);
} catch (error) {
console.error('Error deleting key:', error);
toast.error(`Unable to delete ${toastInfo} for ${providerToDelete.name}`);
console.error('Error deleting configuration:', error);
toast.error(`Unable to delete configuration for ${providerToDelete.name}`);
}
setIsConfirmationOpen(false);
};
@@ -228,7 +239,7 @@ export function ConfigureProvidersGrid() {
endpoint="Example Endpoint"
title={
modalMode === 'edit'
? `Edit ${providers.find((p) => p.id === selectedForSetup)?.name} API Key`
? `Edit ${providers.find((p) => p.id === selectedForSetup)?.name} Configuration`
: undefined
}
onSubmit={handleModalSubmit}
@@ -242,7 +253,7 @@ export function ConfigureProvidersGrid() {
{isConfirmationOpen && providerToDelete && (
<ConfirmationModal
message={`Are you sure you want to delete the API key or host for ${providerToDelete.name}? This action cannot be undone.`}
message={`Are you sure you want to delete the configuration for ${providerToDelete.name}? This action cannot be undone.`}
onConfirm={confirmDelete}
onCancel={() => setIsConfirmationOpen(false)}
/>

View File

@@ -59,7 +59,7 @@ export function ProviderGrid({ onSubmit }: ProviderGridProps) {
localStorage.setItem('GOOSE_PROVIDER', providerId);
toast.success(
`Selected ${provider.name} provider. Starting Goose with default model: ${getDefaultModel(provider.name.toLowerCase())}.`
`Selected ${provider.name} provider. Starting Goose with default model: ${getDefaultModel(provider.name.toLowerCase().replace(/ /g, '_'))}.`
);
onSubmit?.();

View File

@@ -76,7 +76,7 @@ const addAgent = async (provider: string, model: string) => {
export const initializeSystem = async (provider: string, model: string) => {
try {
console.log('initializing agent with provider', provider, 'model', model);
await addAgent(provider.toLowerCase(), model);
await addAgent(provider.toLowerCase().replace(/ /g, '_'), model);
loadAndAddStoredExtensions().catch((error) => {
console.error('Failed to load and add stored extension configs:', error);