From 5f6c85d7bdaa38c825622b8cada60e1f79ba0eae Mon Sep 17 00:00:00 2001 From: Alice Hau <110418948+ahau-square@users.noreply.github.com> Date: Fri, 31 Jan 2025 09:30:36 -0500 Subject: [PATCH] feat: add azure openai provider (#960) --- .../src/routes/providers_and_keys.json | 8 +- crates/goose/src/providers/azure.rs | 141 +++++++++++++++++ crates/goose/src/providers/factory.rs | 3 + crates/goose/src/providers/mod.rs | 1 + .../settings/ProviderSetupModal.tsx | 45 +++--- .../components/settings/api_keys/utils.tsx | 15 +- .../settings/models/hardcoded_stuff.tsx | 7 + .../providers/ConfigureProvidersGrid.tsx | 143 ++++++++++-------- .../welcome_screen/ProviderGrid.tsx | 2 +- ui/desktop/src/utils/providerUtils.ts | 2 +- 10 files changed, 274 insertions(+), 93 deletions(-) create mode 100644 crates/goose/src/providers/azure.rs diff --git a/crates/goose-server/src/routes/providers_and_keys.json b/crates/goose-server/src/routes/providers_and_keys.json index ae21d035..a936893e 100644 --- a/crates/goose-server/src/routes/providers_and_keys.json +++ b/crates/goose-server/src/routes/providers_and_keys.json @@ -40,5 +40,11 @@ "description": "Lorem ipsum", "models": [], "required_keys": ["OPENROUTER_API_KEY"] + }, + "azure_openai": { + "name": "Azure OpenAI", + "description": "Connect to Azure OpenAI Service", + "models": ["gpt-4o", "gpt-4o-mini", "o1", "o1-mini"], + "required_keys": ["AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOYMENT_NAME"] } -} \ No newline at end of file +} diff --git a/crates/goose/src/providers/azure.rs b/crates/goose/src/providers/azure.rs new file mode 100644 index 00000000..c17f0804 --- /dev/null +++ b/crates/goose/src/providers/azure.rs @@ -0,0 +1,141 @@ +use anyhow::Result; +use async_trait::async_trait; +use reqwest::Client; +use serde_json::Value; +use std::time::Duration; + +use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage}; +use super::errors::ProviderError; +use super::formats::openai::{create_request, get_usage, response_to_message}; +use super::utils::{emit_debug_trace, get_model, handle_response_openai_compat, ImageFormat}; +use crate::message::Message; +use crate::model::ModelConfig; +use mcp_core::tool::Tool; + +pub const AZURE_DEFAULT_MODEL: &str = "gpt-4o"; +pub const AZURE_DOC_URL: &str = + "https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models"; +pub const AZURE_API_VERSION: &str = "2024-10-21"; +pub const AZURE_OPENAI_KNOWN_MODELS: &[&str] = &[ + "gpt-4o", + "gpt-4o-mini", + "o1", + "o1-mini", + "o1-preview", + "gpt-4", +]; + +#[derive(Debug, serde::Serialize)] +pub struct AzureProvider { + #[serde(skip)] + client: Client, + endpoint: String, + api_key: String, + deployment_name: String, + model: ModelConfig, +} + +impl Default for AzureProvider { + fn default() -> Self { + let model = ModelConfig::new(AzureProvider::metadata().default_model); + AzureProvider::from_env(model).expect("Failed to initialize Azure OpenAI provider") + } +} + +impl AzureProvider { + pub fn from_env(model: ModelConfig) -> Result { + let config = crate::config::Config::global(); + let api_key: String = config.get_secret("AZURE_OPENAI_API_KEY")?; + let endpoint: String = config.get("AZURE_OPENAI_ENDPOINT")?; + let deployment_name: String = config.get("AZURE_OPENAI_DEPLOYMENT_NAME")?; + + let client = Client::builder() + .timeout(Duration::from_secs(600)) + .build()?; + + Ok(Self { + client, + endpoint, + api_key, + deployment_name, + model, + }) + } + + async fn post(&self, payload: Value) -> Result { + let url = format!( + "{}/openai/deployments/{}/chat/completions?api-version={}", + self.endpoint.trim_end_matches('/'), + self.deployment_name, + AZURE_API_VERSION + ); + + let response: reqwest::Response = self + .client + .post(&url) + .header("api-key", &self.api_key) + .json(&payload) + .send() + .await?; + + handle_response_openai_compat(response).await + } +} + +#[async_trait] +impl Provider for AzureProvider { + fn metadata() -> ProviderMetadata { + ProviderMetadata::new( + "azure_openai", + "Azure OpenAI", + "Models through Azure OpenAI Service", + "gpt-4o", + AZURE_OPENAI_KNOWN_MODELS + .iter() + .map(|s| s.to_string()) + .collect(), + AZURE_DOC_URL, + vec![ + ConfigKey::new("AZURE_OPENAI_API_KEY", true, true, None), + ConfigKey::new("AZURE_OPENAI_ENDPOINT", true, false, None), + ConfigKey::new( + "AZURE_OPENAI_DEPLOYMENT_NAME", + true, + false, + Some("Name of your Azure OpenAI deployment"), + ), + ], + ) + } + + fn get_model_config(&self) -> ModelConfig { + self.model.clone() + } + + #[tracing::instrument( + skip(self, system, messages, tools), + fields(model_config, input, output, input_tokens, output_tokens, total_tokens) + )] + async fn complete( + &self, + system: &str, + messages: &[Message], + tools: &[Tool], + ) -> Result<(Message, ProviderUsage), ProviderError> { + let payload = create_request(&self.model, system, messages, tools, &ImageFormat::OpenAi)?; + let response = self.post(payload.clone()).await?; + + let message = response_to_message(response.clone())?; + let usage = match get_usage(&response) { + Ok(usage) => usage, + Err(ProviderError::UsageError(e)) => { + tracing::warn!("Failed to get usage data: {}", e); + Usage::default() + } + Err(e) => return Err(e), + }; + let model = get_model(&response); + emit_debug_trace(self, &payload, &response, &usage); + Ok((message, ProviderUsage::new(model, usage))) + } +} diff --git a/crates/goose/src/providers/factory.rs b/crates/goose/src/providers/factory.rs index 99a257ae..ed169aa7 100644 --- a/crates/goose/src/providers/factory.rs +++ b/crates/goose/src/providers/factory.rs @@ -1,5 +1,6 @@ use super::{ anthropic::AnthropicProvider, + azure::AzureProvider, base::{Provider, ProviderMetadata}, databricks::DatabricksProvider, google::GoogleProvider, @@ -14,6 +15,7 @@ use anyhow::Result; pub fn providers() -> Vec { vec![ AnthropicProvider::metadata(), + AzureProvider::metadata(), DatabricksProvider::metadata(), GoogleProvider::metadata(), GroqProvider::metadata(), @@ -27,6 +29,7 @@ pub fn create(name: &str, model: ModelConfig) -> Result Ok(Box::new(OpenAiProvider::from_env(model)?)), "anthropic" => Ok(Box::new(AnthropicProvider::from_env(model)?)), + "azure_openai" => Ok(Box::new(AzureProvider::from_env(model)?)), "databricks" => Ok(Box::new(DatabricksProvider::from_env(model)?)), "groq" => Ok(Box::new(GroqProvider::from_env(model)?)), "ollama" => Ok(Box::new(OllamaProvider::from_env(model)?)), diff --git a/crates/goose/src/providers/mod.rs b/crates/goose/src/providers/mod.rs index 52448a58..de622576 100644 --- a/crates/goose/src/providers/mod.rs +++ b/crates/goose/src/providers/mod.rs @@ -1,4 +1,5 @@ pub mod anthropic; +pub mod azure; pub mod base; pub mod databricks; pub mod errors; diff --git a/ui/desktop/src/components/settings/ProviderSetupModal.tsx b/ui/desktop/src/components/settings/ProviderSetupModal.tsx index 7c493639..006d0952 100644 --- a/ui/desktop/src/components/settings/ProviderSetupModal.tsx +++ b/ui/desktop/src/components/settings/ProviderSetupModal.tsx @@ -12,7 +12,7 @@ interface ProviderSetupModalProps { model: string; endpoint: string; title?: string; - onSubmit: (apiKey: string) => void; + onSubmit: (configValues: { [key: string]: string }) => void; onCancel: () => void; } @@ -24,14 +24,14 @@ export function ProviderSetupModal({ onSubmit, onCancel, }: ProviderSetupModalProps) { - const [apiKey, setApiKey] = React.useState(''); - const keyName = required_keys[provider]?.[0] || 'API Key'; - const headerText = `Setup ${provider}`; + const [configValues, setConfigValues] = React.useState<{ [key: string]: string }>({}); + const requiredKeys = required_keys[provider] || ['API Key']; + const headerText = title || `Setup ${provider}`; + const handleSubmit = (e: React.FormEvent) => { e.preventDefault(); - onSubmit(apiKey); + onSubmit(configValues); }; - const inputType = isSecretKey(keyName) ? 'password' : 'text'; return (
@@ -48,20 +48,27 @@ export function ProviderSetupModal({ {/* Form */}
-
-
- setApiKey(e.target.value)} - placeholder={keyName} - className="w-full h-14 px-4 font-regular rounded-lg border shadow-none border-gray-300 bg-white text-lg placeholder:text-gray-400 font-regular text-gray-900" - required - /> -
- - {`Your API key or host will be stored securely in the keychain and used only for making requests to ${provider}`} +
+ {requiredKeys.map((keyName) => ( +
+ + setConfigValues((prev) => ({ + ...prev, + [keyName]: e.target.value, + })) + } + placeholder={keyName} + className="w-full h-14 px-4 font-regular rounded-lg border shadow-none border-gray-300 bg-white text-lg placeholder:text-gray-400 font-regular text-gray-900" + required + />
+ ))} +
+ + {`Your configuration values will be stored securely in the keychain and used only for making requests to ${provider}`}
diff --git a/ui/desktop/src/components/settings/api_keys/utils.tsx b/ui/desktop/src/components/settings/api_keys/utils.tsx index c8291812..dc8805eb 100644 --- a/ui/desktop/src/components/settings/api_keys/utils.tsx +++ b/ui/desktop/src/components/settings/api_keys/utils.tsx @@ -2,8 +2,14 @@ import { Provider, ProviderResponse } from './types'; import { getApiUrl, getSecretKey } from '../../../config'; export function isSecretKey(keyName: string): boolean { - // Ollama and Databricks use host name right now and it should not be stored as secret. - return keyName != 'DATABRICKS_HOST' && keyName != 'OLLAMA_HOST'; + // Endpoints and hosts should not be stored as secrets + const nonSecretKeys = [ + 'DATABRICKS_HOST', + 'OLLAMA_HOST', + 'AZURE_OPENAI_ENDPOINT', + 'AZURE_OPENAI_DEPLOYMENT_NAME', + ]; + return !nonSecretKeys.includes(keyName); } export async function getActiveProviders(): Promise { @@ -16,9 +22,8 @@ export async function getActiveProviders(): Promise { .filter((provider) => { const apiKeyStatus = Object.values(provider.config_status || {}); // Get all key statuses - // Include providers if: - // - They have at least one key set (`is_set: true`) - return apiKeyStatus.some((key) => key.is_set); + // Include providers if all required keys are set + return apiKeyStatus.length > 0 && apiKeyStatus.every((key) => key.is_set); }) .map((provider) => provider.name || 'Unknown Provider'); // Extract provider name diff --git a/ui/desktop/src/components/settings/models/hardcoded_stuff.tsx b/ui/desktop/src/components/settings/models/hardcoded_stuff.tsx index 9dcba01a..282499e7 100644 --- a/ui/desktop/src/components/settings/models/hardcoded_stuff.tsx +++ b/ui/desktop/src/components/settings/models/hardcoded_stuff.tsx @@ -17,6 +17,7 @@ export const goose_models: Model[] = [ { id: 15, name: 'llama-3.3-70b-versatile', provider: 'Groq' }, { id: 16, name: 'qwen2.5', provider: 'Ollama' }, { id: 17, name: 'anthropic/claude-3.5-sonnet', provider: 'OpenRouter' }, + { id: 18, name: 'gpt-4o', provider: 'Azure OpenAI' }, ]; export const openai_models = ['gpt-4o-mini', 'gpt-4o', 'gpt-4-turbo', 'o1']; @@ -42,6 +43,8 @@ export const ollama_mdoels = ['qwen2.5']; export const openrouter_models = ['anthropic/claude-3.5-sonnet']; +export const azure_openai_models = ['gpt-4o']; + export const default_models = { openai: 'gpt-4o', anthropic: 'claude-3-5-sonnet-latest', @@ -50,6 +53,7 @@ export const default_models = { groq: 'llama-3.3-70b-versatile', openrouter: 'anthropic/claude-3.5-sonnet', ollama: 'qwen2.5', + azure_openai: 'gpt-4o', }; export function getDefaultModel(key: string): string | undefined { @@ -66,6 +70,7 @@ export const required_keys = { Ollama: ['OLLAMA_HOST'], Google: ['GOOGLE_API_KEY'], OpenRouter: ['OPENROUTER_API_KEY'], + 'Azure OpenAI': ['AZURE_OPENAI_API_KEY', 'AZURE_OPENAI_ENDPOINT', 'AZURE_OPENAI_DEPLOYMENT_NAME'], }; export const supported_providers = [ @@ -76,6 +81,7 @@ export const supported_providers = [ 'Google', 'Ollama', 'OpenRouter', + 'Azure OpenAI', ]; export const model_docs_link = [ @@ -99,4 +105,5 @@ export const provider_aliases = [ { provider: 'Databricks', alias: 'databricks' }, { provider: 'OpenRouter', alias: 'openrouter' }, { provider: 'Google', alias: 'google' }, + { provider: 'Azure OpenAI', alias: 'azure_openai' }, ]; diff --git a/ui/desktop/src/components/settings/providers/ConfigureProvidersGrid.tsx b/ui/desktop/src/components/settings/providers/ConfigureProvidersGrid.tsx index 01d6bd95..339adc6a 100644 --- a/ui/desktop/src/components/settings/providers/ConfigureProvidersGrid.tsx +++ b/ui/desktop/src/components/settings/providers/ConfigureProvidersGrid.tsx @@ -74,68 +74,77 @@ export function ConfigureProvidersGrid() { setShowSetupModal(true); }; - const handleModalSubmit = async (apiKey: string) => { + const handleModalSubmit = async (configValues: { [key: string]: string }) => { if (!selectedForSetup) return; const provider = providers.find((p) => p.id === selectedForSetup)?.name; if (!provider) return; - const keyName = required_keys[provider]?.[0]; - if (!keyName) { - console.error(`No key found for provider ${provider}`); + const requiredKeys = required_keys[provider]; + if (!requiredKeys || requiredKeys.length === 0) { + console.error(`No keys found for provider ${provider}`); return; } - const isSecret = isSecretKey(keyName); - try { - // Delete existing key if provider is already configured + // Delete existing keys if provider is already configured const isUpdate = providers.find((p) => p.id === selectedForSetup)?.isConfigured; if (isUpdate) { - const deleteResponse = await fetch(getApiUrl('/configs/delete'), { - method: 'DELETE', + for (const keyName of requiredKeys) { + const isSecret = isSecretKey(keyName); + const deleteResponse = await fetch(getApiUrl('/configs/delete'), { + method: 'DELETE', + headers: { + 'Content-Type': 'application/json', + 'X-Secret-Key': getSecretKey(), + }, + body: JSON.stringify({ + key: keyName, + isSecret, + }), + }); + + if (!deleteResponse.ok) { + const errorText = await deleteResponse.text(); + console.error('Delete response error:', errorText); + throw new Error(`Failed to delete old key: ${keyName}`); + } + } + } + + // Store new keys + for (const keyName of requiredKeys) { + const value = configValues[keyName]; + if (!value) { + console.error(`Missing value for required key: ${keyName}`); + continue; + } + + const isSecret = isSecretKey(keyName); + const storeResponse = await fetch(getApiUrl('/configs/store'), { + method: 'POST', headers: { 'Content-Type': 'application/json', 'X-Secret-Key': getSecretKey(), }, - body: JSON.stringify({ - key: keyName, + body: JSON.stringify({ + key: keyName, + value: value, isSecret, }), }); - if (!deleteResponse.ok) { - const errorText = await deleteResponse.text(); - console.error('Delete response error:', errorText); - throw new Error('Failed to delete old key'); + if (!storeResponse.ok) { + const errorText = await storeResponse.text(); + console.error('Store response error:', errorText); + throw new Error(`Failed to store new key: ${keyName}`); } } - // Store new key - const storeResponse = await fetch(getApiUrl('/configs/store'), { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - 'X-Secret-Key': getSecretKey(), - }, - body: JSON.stringify({ - key: keyName, - value: apiKey.trim(), - isSecret, - }), - }); - - if (!storeResponse.ok) { - const errorText = await storeResponse.text(); - console.error('Store response error:', errorText); - throw new Error('Failed to store new key'); - } - - const toastInfo = isSecret ? 'API key' : 'host'; toast.success( isUpdate - ? `Successfully updated ${toastInfo} for ${provider}` - : `Successfully added ${toastInfo} for ${provider}` + ? `Successfully updated configuration for ${provider}` + : `Successfully added configuration for ${provider}` ); const updatedKeys = await getActiveProviders(); @@ -147,7 +156,7 @@ export function ConfigureProvidersGrid() { } catch (error) { console.error('Error handling modal submit:', error); toast.error( - `Failed to ${providers.find((p) => p.id === selectedForSetup)?.isConfigured ? 'update' : 'add'} API key for ${provider}` + `Failed to ${providers.find((p) => p.id === selectedForSetup)?.isConfigured ? 'update' : 'add'} configuration for ${provider}` ); } }; @@ -160,50 +169,52 @@ export function ConfigureProvidersGrid() { const confirmDelete = async () => { if (!providerToDelete) return; - const keyName = required_keys[providerToDelete.name]?.[0]; - if (!keyName) { - console.error(`No key found for provider ${providerToDelete.name}`); + const requiredKeys = required_keys[providerToDelete.name]; + if (!requiredKeys || requiredKeys.length === 0) { + console.error(`No keys found for provider ${providerToDelete.name}`); return; } - const isSecret = isSecretKey(keyName); - const toastInfo = isSecret ? 'API key' : 'host'; try { // Check if the selected provider is currently active if (currentModel?.provider === providerToDelete.name) { toast.error( - `Cannot delete the ${toastInfo} for ${providerToDelete.name} because it's the provider of the current model (${currentModel.name}). Please switch to a different model first.` + `Cannot delete the configuration for ${providerToDelete.name} because it's the provider of the current model (${currentModel.name}). Please switch to a different model first.` ); setIsConfirmationOpen(false); return; } - const deleteResponse = await fetch(getApiUrl('/configs/delete'), { - method: 'DELETE', - headers: { - 'Content-Type': 'application/json', - 'X-Secret-Key': getSecretKey(), - }, - body: JSON.stringify({ - key: keyName, - isSecret, - }), - }); + // Delete all keys for the provider + for (const keyName of requiredKeys) { + const isSecret = isSecretKey(keyName); + const deleteResponse = await fetch(getApiUrl('/configs/delete'), { + method: 'DELETE', + headers: { + 'Content-Type': 'application/json', + 'X-Secret-Key': getSecretKey(), + }, + body: JSON.stringify({ + key: keyName, + isSecret, + }), + }); - if (!deleteResponse.ok) { - const errorText = await deleteResponse.text(); - console.error('Delete response error:', errorText); - throw new Error('Failed to delete key'); + if (!deleteResponse.ok) { + const errorText = await deleteResponse.text(); + console.error('Delete response error:', errorText); + throw new Error(`Failed to delete key: ${keyName}`); + } } - console.log('Key deleted successfully.'); - toast.success(`Successfully deleted ${toastInfo} for ${providerToDelete.name}`); + console.log('Configuration deleted successfully.'); + toast.success(`Successfully deleted configuration for ${providerToDelete.name}`); const updatedKeys = await getActiveProviders(); setActiveKeys(updatedKeys); } catch (error) { - console.error('Error deleting key:', error); - toast.error(`Unable to delete ${toastInfo} for ${providerToDelete.name}`); + console.error('Error deleting configuration:', error); + toast.error(`Unable to delete configuration for ${providerToDelete.name}`); } setIsConfirmationOpen(false); }; @@ -228,7 +239,7 @@ export function ConfigureProvidersGrid() { endpoint="Example Endpoint" title={ modalMode === 'edit' - ? `Edit ${providers.find((p) => p.id === selectedForSetup)?.name} API Key` + ? `Edit ${providers.find((p) => p.id === selectedForSetup)?.name} Configuration` : undefined } onSubmit={handleModalSubmit} @@ -242,7 +253,7 @@ export function ConfigureProvidersGrid() { {isConfirmationOpen && providerToDelete && ( setIsConfirmationOpen(false)} /> diff --git a/ui/desktop/src/components/welcome_screen/ProviderGrid.tsx b/ui/desktop/src/components/welcome_screen/ProviderGrid.tsx index f2f4f61d..60992f74 100644 --- a/ui/desktop/src/components/welcome_screen/ProviderGrid.tsx +++ b/ui/desktop/src/components/welcome_screen/ProviderGrid.tsx @@ -59,7 +59,7 @@ export function ProviderGrid({ onSubmit }: ProviderGridProps) { localStorage.setItem('GOOSE_PROVIDER', providerId); toast.success( - `Selected ${provider.name} provider. Starting Goose with default model: ${getDefaultModel(provider.name.toLowerCase())}.` + `Selected ${provider.name} provider. Starting Goose with default model: ${getDefaultModel(provider.name.toLowerCase().replace(/ /g, '_'))}.` ); onSubmit?.(); diff --git a/ui/desktop/src/utils/providerUtils.ts b/ui/desktop/src/utils/providerUtils.ts index 074758fd..fd66b6b1 100644 --- a/ui/desktop/src/utils/providerUtils.ts +++ b/ui/desktop/src/utils/providerUtils.ts @@ -76,7 +76,7 @@ const addAgent = async (provider: string, model: string) => { export const initializeSystem = async (provider: string, model: string) => { try { console.log('initializing agent with provider', provider, 'model', model); - await addAgent(provider.toLowerCase(), model); + await addAgent(provider.toLowerCase().replace(/ /g, '_'), model); loadAndAddStoredExtensions().catch((error) => { console.error('Failed to load and add stored extension configs:', error);