feat: update config endpoints for use with providers (#1563)

This commit is contained in:
Lily Delalande
2025-03-10 09:51:54 -07:00
committed by GitHub
parent 3b36591cb5
commit 5df2875c1c
43 changed files with 945 additions and 428 deletions

1
Cargo.lock generated
View File

@@ -2312,6 +2312,7 @@ dependencies = [
"tracing", "tracing",
"tracing-subscriber", "tracing-subscriber",
"url", "url",
"utoipa",
"uuid", "uuid",
"webbrowser", "webbrowser",
"winapi", "winapi",

View File

@@ -140,10 +140,10 @@ pub async fn run_benchmark(
let config = Config::global(); let config = Config::global();
let goose_model: String = config let goose_model: String = config
.get("GOOSE_MODEL") .get_param("GOOSE_MODEL")
.expect("No model configured. Run 'goose configure' first"); .expect("No model configured. Run 'goose configure' first");
let provider_name: String = config let provider_name: String = config
.get("GOOSE_PROVIDER") .get_param("GOOSE_PROVIDER")
.expect("No provider configured. Run 'goose configure' first"); .expect("No provider configured. Run 'goose configure' first");
let mut results = BenchmarkResults::new(provider_name.clone()); let mut results = BenchmarkResults::new(provider_name.clone());

View File

@@ -184,7 +184,7 @@ pub async fn configure_provider_dialog() -> Result<bool, Box<dyn Error>> {
.collect(); .collect();
// Get current default provider if it exists // Get current default provider if it exists
let current_provider: Option<String> = config.get("GOOSE_PROVIDER").ok(); let current_provider: Option<String> = config.get_param("GOOSE_PROVIDER").ok();
let default_provider = current_provider.unwrap_or_default(); let default_provider = current_provider.unwrap_or_default();
// Select provider // Select provider
@@ -219,7 +219,7 @@ pub async fn configure_provider_dialog() -> Result<bool, Box<dyn Error>> {
if key.secret { if key.secret {
config.set_secret(&key.name, Value::String(env_value))?; config.set_secret(&key.name, Value::String(env_value))?;
} else { } else {
config.set(&key.name, Value::String(env_value))?; config.set_param(&key.name, Value::String(env_value))?;
} }
let _ = cliclack::log::info(format!("Saved {} to config file", key.name)); let _ = cliclack::log::info(format!("Saved {} to config file", key.name));
} }
@@ -229,7 +229,7 @@ pub async fn configure_provider_dialog() -> Result<bool, Box<dyn Error>> {
let existing: Result<String, _> = if key.secret { let existing: Result<String, _> = if key.secret {
config.get_secret(&key.name) config.get_secret(&key.name)
} else { } else {
config.get(&key.name) config.get_param(&key.name)
}; };
match existing { match existing {
@@ -252,7 +252,7 @@ pub async fn configure_provider_dialog() -> Result<bool, Box<dyn Error>> {
if key.secret { if key.secret {
config.set_secret(&key.name, Value::String(new_value))?; config.set_secret(&key.name, Value::String(new_value))?;
} else { } else {
config.set(&key.name, Value::String(new_value))?; config.set_param(&key.name, Value::String(new_value))?;
} }
} }
} }
@@ -278,7 +278,7 @@ pub async fn configure_provider_dialog() -> Result<bool, Box<dyn Error>> {
if key.secret { if key.secret {
config.set_secret(&key.name, Value::String(value))?; config.set_secret(&key.name, Value::String(value))?;
} else { } else {
config.set(&key.name, Value::String(value))?; config.set_param(&key.name, Value::String(value))?;
} }
} }
} }
@@ -325,8 +325,8 @@ pub async fn configure_provider_dialog() -> Result<bool, Box<dyn Error>> {
match result { match result {
Ok((_message, _usage)) => { Ok((_message, _usage)) => {
// Update config with new values only if the test succeeds // Update config with new values only if the test succeeds
config.set("GOOSE_PROVIDER", Value::String(provider_name.to_string()))?; config.set_param("GOOSE_PROVIDER", Value::String(provider_name.to_string()))?;
config.set("GOOSE_MODEL", Value::String(model.clone()))?; config.set_param("GOOSE_MODEL", Value::String(model.clone()))?;
cliclack::outro("Configuration saved successfully")?; cliclack::outro("Configuration saved successfully")?;
Ok(true) Ok(true)
} }
@@ -708,15 +708,15 @@ pub fn configure_goose_mode_dialog() -> Result<(), Box<dyn Error>> {
match mode { match mode {
"auto" => { "auto" => {
config.set("GOOSE_MODE", Value::String("auto".to_string()))?; config.set_param("GOOSE_MODE", Value::String("auto".to_string()))?;
cliclack::outro("Set to Auto Mode - full file modification enabled")?; cliclack::outro("Set to Auto Mode - full file modification enabled")?;
} }
"approve" => { "approve" => {
config.set("GOOSE_MODE", Value::String("approve".to_string()))?; config.set_param("GOOSE_MODE", Value::String("approve".to_string()))?;
cliclack::outro("Set to Approve Mode - modifications require approval")?; cliclack::outro("Set to Approve Mode - modifications require approval")?;
} }
"chat" => { "chat" => {
config.set("GOOSE_MODE", Value::String("chat".to_string()))?; config.set_param("GOOSE_MODE", Value::String("chat".to_string()))?;
cliclack::outro("Set to Chat Mode - no tools or modifications enabled")?; cliclack::outro("Set to Chat Mode - no tools or modifications enabled")?;
} }
_ => unreachable!(), _ => unreachable!(),
@@ -738,15 +738,15 @@ pub fn configure_tool_output_dialog() -> Result<(), Box<dyn Error>> {
match tool_log_level { match tool_log_level {
"high" => { "high" => {
config.set("GOOSE_CLI_MIN_PRIORITY", Value::from(0.8))?; config.set_param("GOOSE_CLI_MIN_PRIORITY", Value::from(0.8))?;
cliclack::outro("Showing tool output of high importance only.")?; cliclack::outro("Showing tool output of high importance only.")?;
} }
"medium" => { "medium" => {
config.set("GOOSE_CLI_MIN_PRIORITY", Value::from(0.2))?; config.set_param("GOOSE_CLI_MIN_PRIORITY", Value::from(0.2))?;
cliclack::outro("Showing tool output of medium importance.")?; cliclack::outro("Showing tool output of medium importance.")?;
} }
"all" => { "all" => {
config.set("GOOSE_CLI_MIN_PRIORITY", Value::from(0.0))?; config.set_param("GOOSE_CLI_MIN_PRIORITY", Value::from(0.0))?;
cliclack::outro("Showing all tool output.")?; cliclack::outro("Showing all tool output.")?;
} }
_ => unreachable!(), _ => unreachable!(),

View File

@@ -21,11 +21,11 @@ pub async fn build_session(
let config = Config::global(); let config = Config::global();
let provider_name: String = config let provider_name: String = config
.get("GOOSE_PROVIDER") .get_param("GOOSE_PROVIDER")
.expect("No provider configured. Run 'goose configure' first"); .expect("No provider configured. Run 'goose configure' first");
let model: String = config let model: String = config
.get("GOOSE_MODEL") .get_param("GOOSE_MODEL")
.expect("No model configured. Run 'goose configure' first"); .expect("No model configured. Run 'goose configure' first");
let model_config = goose::model::ModelConfig::new(model.clone()); let model_config = goose::model::ModelConfig::new(model.clone());
let provider = let provider =
@@ -137,7 +137,7 @@ pub async fn build_session(
.await; .await;
// Only override system prompt if a system override exists // Only override system prompt if a system override exists
let system_prompt_file: Option<String> = config.get("GOOSE_SYSTEM_PROMPT_FILE_PATH").ok(); let system_prompt_file: Option<String> = config.get_param("GOOSE_SYSTEM_PROMPT_FILE_PATH").ok();
if let Some(ref path) = system_prompt_file { if let Some(ref path) = system_prompt_file {
let override_prompt = let override_prompt =
std::fs::read_to_string(path).expect("Failed to read system prompt file"); std::fs::read_to_string(path).expect("Failed to read system prompt file");

View File

@@ -343,7 +343,7 @@ impl Session {
} }
config config
.set("GOOSE_MODE", Value::String(mode.to_string())) .set_param("GOOSE_MODE", Value::String(mode.to_string()))
.unwrap(); .unwrap();
println!("Goose mode set to '{}'", mode); println!("Goose mode set to '{}'", mode);
continue; continue;

View File

@@ -150,7 +150,7 @@ fn render_tool_response(resp: &ToolResponse, theme: Theme, debug: bool) {
} }
let min_priority = config let min_priority = config
.get::<f32>("GOOSE_CLI_MIN_PRIORITY") .get_param::<f32>("GOOSE_CLI_MIN_PRIORITY")
.ok() .ok()
.unwrap_or(0.0); .unwrap_or(0.0);

View File

@@ -1,5 +1,8 @@
use utoipa::OpenApi; use utoipa::OpenApi;
use goose::providers::base::ConfigKey;
use goose::providers::base::ProviderMetadata;
#[allow(dead_code)] // Used by utoipa for OpenAPI generation #[allow(dead_code)] // Used by utoipa for OpenAPI generation
#[derive(OpenApi)] #[derive(OpenApi)]
#[openapi( #[openapi(
@@ -10,13 +13,19 @@ use utoipa::OpenApi;
super::routes::config_management::add_extension, super::routes::config_management::add_extension,
super::routes::config_management::remove_extension, super::routes::config_management::remove_extension,
super::routes::config_management::update_extension, super::routes::config_management::update_extension,
super::routes::config_management::read_all_config super::routes::config_management::read_all_config,
super::routes::config_management::providers
), ),
components(schemas( components(schemas(
super::routes::config_management::UpsertConfigQuery, super::routes::config_management::UpsertConfigQuery,
super::routes::config_management::ConfigKeyQuery, super::routes::config_management::ConfigKeyQuery,
super::routes::config_management::ExtensionQuery, super::routes::config_management::ExtensionQuery,
super::routes::config_management::ConfigResponse super::routes::config_management::ConfigResponse,
super::routes::config_management::ProvidersResponse,
super::routes::config_management::ProvidersResponse,
super::routes::config_management::ProviderDetails,
ProviderMetadata,
ConfigKey
)) ))
)] )]
pub struct ApiDoc; pub struct ApiDoc;

View File

@@ -121,7 +121,7 @@ async fn create_agent(
let config = Config::global(); let config = Config::global();
let model = payload.model.unwrap_or_else(|| { let model = payload.model.unwrap_or_else(|| {
config config
.get("GOOSE_MODEL") .get_param("GOOSE_MODEL")
.expect("Did not find a model on payload or in env") .expect("Did not find a model on payload or in env")
}); });
let model_config = ModelConfig::new(model); let model_config = ModelConfig::new(model);

View File

@@ -5,25 +5,42 @@ use axum::{
Json, Router, Json, Router,
}; };
use goose::config::Config; use goose::config::Config;
use http::StatusCode; use goose::providers::base::ProviderMetadata;
use goose::providers::providers as get_providers;
use http::{HeaderMap, StatusCode};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::Value; use serde_json::Value;
use std::{collections::HashMap, sync::Arc}; use std::collections::HashMap;
use tokio::sync::Mutex; use std::env;
use utoipa::ToSchema; use utoipa::ToSchema;
use crate::state::AppState; use crate::state::AppState;
fn verify_secret_key(headers: &HeaderMap, state: &AppState) -> Result<StatusCode, StatusCode> {
// Verify secret key
let secret_key = headers
.get("X-Secret-Key")
.and_then(|value| value.to_str().ok())
.ok_or(StatusCode::UNAUTHORIZED)?;
if secret_key != state.secret_key {
Err(StatusCode::UNAUTHORIZED)
} else {
Ok(StatusCode::OK)
}
}
#[derive(Deserialize, ToSchema)] #[derive(Deserialize, ToSchema)]
pub struct UpsertConfigQuery { pub struct UpsertConfigQuery {
pub key: String, pub key: String,
pub value: Value, pub value: Value,
pub is_secret: Option<bool>, pub is_secret: bool,
} }
#[derive(Deserialize, ToSchema)] #[derive(Deserialize, ToSchema)]
pub struct ConfigKeyQuery { pub struct ConfigKeyQuery {
pub key: String, pub key: String,
pub is_secret: bool,
} }
#[derive(Deserialize, ToSchema)] #[derive(Deserialize, ToSchema)]
@@ -37,6 +54,22 @@ pub struct ConfigResponse {
pub config: HashMap<String, Value>, pub config: HashMap<String, Value>,
} }
// Define a new structure to encapsulate the provider details along with configuration status
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct ProviderDetails {
/// Unique identifier and name of the provider
pub name: String,
/// Metadata about the provider
pub metadata: ProviderMetadata,
/// Indicates whether the provider is fully configured
pub is_configured: bool,
}
#[derive(Serialize, ToSchema)]
pub struct ProvidersResponse {
pub providers: Vec<ProviderDetails>,
}
#[utoipa::path( #[utoipa::path(
post, post,
path = "/config/upsert", path = "/config/upsert",
@@ -47,16 +80,15 @@ pub struct ConfigResponse {
) )
)] )]
pub async fn upsert_config( pub async fn upsert_config(
State(_state): State<Arc<Mutex<HashMap<String, Value>>>>, State(state): State<AppState>,
headers: HeaderMap,
Json(query): Json<UpsertConfigQuery>, Json(query): Json<UpsertConfigQuery>,
) -> Result<Json<Value>, StatusCode> { ) -> Result<Json<Value>, StatusCode> {
let config = Config::global(); // Use the helper function to verify the secret key
verify_secret_key(&headers, &state)?;
let result = if query.is_secret.unwrap_or(false) { let config = Config::global();
config.set_secret(&query.key, query.value) let result = config.set(&query.key, query.value, query.is_secret);
} else {
config.set(&query.key, query.value)
};
match result { match result {
Ok(_) => Ok(Json(Value::String(format!("Upserted key {}", query.key)))), Ok(_) => Ok(Json(Value::String(format!("Upserted key {}", query.key)))),
@@ -75,9 +107,13 @@ pub async fn upsert_config(
) )
)] )]
pub async fn remove_config( pub async fn remove_config(
State(_state): State<Arc<Mutex<HashMap<String, Value>>>>, State(state): State<AppState>,
headers: HeaderMap,
Json(query): Json<ConfigKeyQuery>, Json(query): Json<ConfigKeyQuery>,
) -> Result<Json<String>, StatusCode> { ) -> Result<Json<String>, StatusCode> {
// Use the helper function to verify the secret key
verify_secret_key(&headers, &state)?;
let config = Config::global(); let config = Config::global();
match config.delete(&query.key) { match config.delete(&query.key) {
@@ -96,13 +132,25 @@ pub async fn remove_config(
) )
)] )]
pub async fn read_config( pub async fn read_config(
State(_state): State<Arc<Mutex<HashMap<String, Value>>>>, State(state): State<AppState>,
headers: HeaderMap,
Json(query): Json<ConfigKeyQuery>, Json(query): Json<ConfigKeyQuery>,
) -> Result<Json<Value>, StatusCode> { ) -> Result<Json<Value>, StatusCode> {
verify_secret_key(&headers, &state)?;
let config = Config::global(); let config = Config::global();
match config.get::<Value>(&query.key) { match config.get(&query.key, query.is_secret) {
Ok(value) => Ok(Json(value)), // Always get the actual value
Ok(value) => {
if query.is_secret {
// If it's marked as secret, return a boolean indicating presence
Ok(Json(Value::Bool(true)))
} else {
// Return the actual value if not secret
Ok(Json(value))
}
}
Err(_) => Err(StatusCode::NOT_FOUND), Err(_) => Err(StatusCode::NOT_FOUND),
} }
} }
@@ -118,20 +166,25 @@ pub async fn read_config(
) )
)] )]
pub async fn add_extension( pub async fn add_extension(
State(_state): State<Arc<Mutex<HashMap<String, Value>>>>, State(state): State<AppState>,
headers: HeaderMap,
Json(extension): Json<ExtensionQuery>, Json(extension): Json<ExtensionQuery>,
) -> Result<Json<String>, StatusCode> { ) -> Result<Json<String>, StatusCode> {
// Use the helper function to verify the secret key
verify_secret_key(&headers, &state)?;
let config = Config::global(); let config = Config::global();
// Get current extensions or initialize empty map // Get current extensions or initialize empty map
let mut extensions: HashMap<String, Value> = let mut extensions: HashMap<String, Value> = config
config.get("extensions").unwrap_or_else(|_| HashMap::new()); .get_param("extensions")
.unwrap_or_else(|_| HashMap::new());
// Add new extension // Add new extension
extensions.insert(extension.name.clone(), extension.config); extensions.insert(extension.name.clone(), extension.config);
// Save updated extensions // Save updated extensions
match config.set( match config.set_param(
"extensions", "extensions",
Value::Object(extensions.into_iter().collect()), Value::Object(extensions.into_iter().collect()),
) { ) {
@@ -151,13 +204,17 @@ pub async fn add_extension(
) )
)] )]
pub async fn remove_extension( pub async fn remove_extension(
State(_state): State<Arc<Mutex<HashMap<String, Value>>>>, State(state): State<AppState>,
headers: HeaderMap,
Json(query): Json<ConfigKeyQuery>, Json(query): Json<ConfigKeyQuery>,
) -> Result<Json<String>, StatusCode> { ) -> Result<Json<String>, StatusCode> {
// Use the helper function to verify the secret key
verify_secret_key(&headers, &state)?;
let config = Config::global(); let config = Config::global();
// Get current extensions // Get current extensions
let mut extensions: HashMap<String, Value> = match config.get("extensions") { let mut extensions: HashMap<String, Value> = match config.get_param("extensions") {
Ok(exts) => exts, Ok(exts) => exts,
Err(_) => return Err(StatusCode::NOT_FOUND), Err(_) => return Err(StatusCode::NOT_FOUND),
}; };
@@ -165,7 +222,7 @@ pub async fn remove_extension(
// Remove extension if it exists // Remove extension if it exists
if extensions.remove(&query.key).is_some() { if extensions.remove(&query.key).is_some() {
// Save updated extensions // Save updated extensions
match config.set( match config.set_param(
"extensions", "extensions",
Value::Object(extensions.into_iter().collect()), Value::Object(extensions.into_iter().collect()),
) { ) {
@@ -185,8 +242,12 @@ pub async fn remove_extension(
) )
)] )]
pub async fn read_all_config( pub async fn read_all_config(
State(_state): State<Arc<Mutex<HashMap<String, Value>>>>, State(state): State<AppState>,
headers: HeaderMap,
) -> Result<Json<ConfigResponse>, StatusCode> { ) -> Result<Json<ConfigResponse>, StatusCode> {
// Use the helper function to verify the secret key
verify_secret_key(&headers, &state)?;
let config = Config::global(); let config = Config::global();
// Load values from config file // Load values from config file
@@ -206,13 +267,17 @@ pub async fn read_all_config(
) )
)] )]
pub async fn update_extension( pub async fn update_extension(
State(_state): State<Arc<Mutex<HashMap<String, Value>>>>, State(state): State<AppState>,
headers: HeaderMap,
Json(extension): Json<ExtensionQuery>, Json(extension): Json<ExtensionQuery>,
) -> Result<Json<String>, StatusCode> { ) -> Result<Json<String>, StatusCode> {
// Use the helper function to verify the secret key
verify_secret_key(&headers, &state)?;
let config = Config::global(); let config = Config::global();
// Get current extensions // Get current extensions
let mut extensions: HashMap<String, Value> = match config.get("extensions") { let mut extensions: HashMap<String, Value> = match config.get_param("extensions") {
Ok(exts) => exts, Ok(exts) => exts,
Err(_) => return Err(StatusCode::NOT_FOUND), Err(_) => return Err(StatusCode::NOT_FOUND),
}; };
@@ -226,7 +291,7 @@ pub async fn update_extension(
extensions.insert(extension.name.clone(), extension.config); extensions.insert(extension.name.clone(), extension.config);
// Save updated extensions // Save updated extensions
match config.set( match config.set_param(
"extensions", "extensions",
Value::Object(extensions.into_iter().collect()), Value::Object(extensions.into_iter().collect()),
) { ) {
@@ -235,6 +300,66 @@ pub async fn update_extension(
} }
} }
// Modified providers function using the new response type
#[utoipa::path(
get,
path = "/config/providers",
responses(
(status = 200, description = "All configuration values retrieved successfully", body = [ProviderDetails])
)
)]
pub async fn providers(
State(state): State<AppState>,
headers: HeaderMap,
) -> Result<Json<Vec<ProviderDetails>>, StatusCode> {
verify_secret_key(&headers, &state)?;
// Fetch the list of providers, which are likely stored in the AppState or can be retrieved via a function call
let providers_metadata = get_providers();
// Construct the response by checking configuration status for each provider
let providers_response: Vec<ProviderDetails> = providers_metadata
.into_iter()
.map(|metadata| {
// Check if the provider is configured (this will depend on how you track configuration status)
let is_configured = check_provider_configured(&metadata);
ProviderDetails {
name: metadata.name.clone(),
metadata,
is_configured,
}
})
.collect();
Ok(Json(providers_response))
}
fn check_provider_configured(metadata: &ProviderMetadata) -> bool {
let config = Config::global();
// Check all required keys for the provider
for key in &metadata.config_keys {
if key.required {
let key_name = &key.name;
// First, check if the key is set in the environment
let is_set_in_env = env::var(key_name).is_ok();
// If not set in environment, check the config file based on whether it's a secret or not
let is_set_in_config = config.get(key_name, key.secret).is_ok();
// If the key is neither in the environment nor in the config, the provider is not configured
if !is_set_in_env && !is_set_in_config {
return false;
}
}
}
// If all required keys are accounted for, the provider is considered configured
true
}
pub fn routes(state: AppState) -> Router { pub fn routes(state: AppState) -> Router {
Router::new() Router::new()
.route("/config", get(read_all_config)) .route("/config", get(read_all_config))
@@ -244,5 +369,6 @@ pub fn routes(state: AppState) -> Router {
.route("/config/extension", post(add_extension)) .route("/config/extension", post(add_extension))
.route("/config/extension", put(update_extension)) .route("/config/extension", put(update_extension))
.route("/config/extension", delete(remove_extension)) .route("/config/extension", delete(remove_extension))
.with_state(state.config) .route("/config/providers", get(providers))
.with_state(state)
} }

View File

@@ -43,7 +43,7 @@ async fn store_config(
let result = if request.is_secret { let result = if request.is_secret {
config.set_secret(&request.key, Value::String(request.value)) config.set_secret(&request.key, Value::String(request.value))
} else { } else {
config.set(&request.key, Value::String(request.value)) config.set_param(&request.key, Value::String(request.value))
}; };
match result { match result {
Ok(_) => Ok(Json(ConfigResponse { error: false })), Ok(_) => Ok(Json(ConfigResponse { error: false })),
@@ -87,7 +87,7 @@ static PROVIDER_ENV_REQUIREMENTS: Lazy<HashMap<String, ProviderConfig>> = Lazy::
fn check_key_status(config: &Config, key: &str) -> (bool, Option<String>) { fn check_key_status(config: &Config, key: &str) -> (bool, Option<String>) {
if let Ok(_value) = std::env::var(key) { if let Ok(_value) = std::env::var(key) {
(true, Some("env".to_string())) (true, Some("env".to_string()))
} else if config.get::<String>(key).is_ok() { } else if config.get_param::<String>(key).is_ok() {
(true, Some("yaml".to_string())) (true, Some("yaml".to_string()))
} else if config.get_secret::<String>(key).is_ok() { } else if config.get_secret::<String>(key).is_ok() {
(true, Some("keyring".to_string())) (true, Some("keyring".to_string()))
@@ -171,7 +171,7 @@ pub async fn get_config(
// Fetch the configuration value. Right now we don't allow get a secret. // Fetch the configuration value. Right now we don't allow get a secret.
let config = Config::global(); let config = Config::global();
let value = if let Ok(config_value) = config.get::<String>(&query.key) { let value = if let Ok(config_value) = config.get_param::<String>(&query.key) {
Some(config_value) Some(config_value)
} else if let Ok(env_value) = std::env::var(&query.key) { } else if let Ok(env_value) = std::env::var(&query.key) {
Some(env_value) Some(env_value)

View File

@@ -0,0 +1,92 @@
use serde::{Deserialize, Serialize};
use std::error::Error;
use goose::config::Config;
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum KeyLocation {
Environment,
ConfigFile,
Keychain,
NotFound
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KeyInfo {
pub name: String,
pub is_set: bool,
pub location: KeyLocation,
pub is_secret: bool,
pub value: Option<String>, // Only populated for non-secret keys that are set
}
/// Inspects a configuration key to determine if it's set, its location, and value (for non-secret keys)
pub fn inspect_key(
key_name: &str,
is_secret: bool,
) -> Result<KeyInfo, Box<dyn Error>> {
let config = Config::global();
// Check environment variable first
let env_value = std::env::var(key_name).ok();
if let Some(value) = env_value {
return Ok(KeyInfo {
name: key_name.to_string(),
is_set: true,
location: KeyLocation::Environment,
is_secret,
// Only include value for non-secret keys
value: if !is_secret { Some(value) } else { None },
});
}
// Check config store
let config_result = if is_secret {
config.get_secret(key_name).map(|v| (v, true))
} else {
config.get(key_name).map(|v| (v, false))
};
match config_result {
Ok((value, is_secret_actual)) => {
// Determine location based on whether it's a secret value
let location = if is_secret_actual {
KeyLocation::Keychain
} else {
KeyLocation::ConfigFile
};
Ok(KeyInfo {
name: key_name.to_string(),
is_set: true,
location,
is_secret: is_secret_actual,
// Only include value for non-secret keys
value: if !is_secret_actual { Some(value) } else { None },
})
},
Err(_) => {
Ok(KeyInfo {
name: key_name.to_string(),
is_set: false,
location: KeyLocation::NotFound,
is_secret,
value: None,
})
}
}
}
/// Inspects multiple keys at once
pub fn inspect_keys(
keys: &[(String, bool)], // (name, is_secret) pairs
) -> Result<Vec<KeyInfo>, Box<dyn Error>> {
let mut results = Vec::new();
for (key_name, is_secret) in keys {
let info = inspect_key(key_name, *is_secret)?;
results.push(info);
}
Ok(results)
}

View File

@@ -60,6 +60,7 @@ serde_yaml = "0.9.34"
once_cell = "1.20.2" once_cell = "1.20.2"
etcetera = "0.8.0" etcetera = "0.8.0"
rand = "0.8.5" rand = "0.8.5"
utoipa = { version = "4.1" }
# For Bedrock provider # For Bedrock provider
aws-config = { version = "1.1.7", features = ["behavior-version-latest"] } aws-config = { version = "1.1.7", features = ["behavior-version-latest"] }

View File

@@ -351,7 +351,7 @@ impl Capabilities {
let mut system_prompt_extensions = self.system_prompt_extensions.clone(); let mut system_prompt_extensions = self.system_prompt_extensions.clone();
let config = Config::global(); let config = Config::global();
let goose_mode = config.get("GOOSE_MODE").unwrap_or("auto".to_string()); let goose_mode = config.get_param("GOOSE_MODE").unwrap_or("auto".to_string());
if goose_mode == "chat" { if goose_mode == "chat" {
system_prompt_extensions.push( system_prompt_extensions.push(
"Right now you are in the chat only mode, no access to any tool use and system." "Right now you are in the chat only mode, no access to any tool use and system."

View File

@@ -50,7 +50,7 @@ impl AgentFactory {
pub fn configured_version() -> String { pub fn configured_version() -> String {
let config = Config::global(); let config = Config::global();
config config
.get::<String>("GOOSE_AGENT") .get_param::<String>("GOOSE_AGENT")
.unwrap_or_else(|_| Self::default_version().to_string()) .unwrap_or_else(|_| Self::default_version().to_string())
} }

View File

@@ -177,7 +177,7 @@ impl Agent for SummarizeAgent {
// Load settings from config // Load settings from config
let config = Config::global(); let config = Config::global();
let goose_mode = config.get("GOOSE_MODE").unwrap_or("auto".to_string()); let goose_mode = config.get_param("GOOSE_MODE").unwrap_or("auto".to_string());
// we add in the 2 resource tools if any extensions support resources // we add in the 2 resource tools if any extensions support resources
// TODO: make sure there is no collision with another extension's tool name // TODO: make sure there is no collision with another extension's tool name

View File

@@ -171,7 +171,7 @@ impl Agent for TruncateAgent {
// Load settings from config // Load settings from config
let config = Config::global(); let config = Config::global();
let goose_mode = config.get("GOOSE_MODE").unwrap_or("auto".to_string()); let goose_mode = config.get_param("GOOSE_MODE").unwrap_or("auto".to_string());
// we add in the 2 resource tools if any extensions support resources // we add in the 2 resource tools if any extensions support resources
// TODO: make sure there is no collision with another extension's tool name // TODO: make sure there is no collision with another extension's tool name

View File

@@ -78,7 +78,7 @@ impl From<keyring::Error> for ConfigError {
/// ///
/// // Get a string value /// // Get a string value
/// let config = Config::global(); /// let config = Config::global();
/// let api_key: String = config.get("OPENAI_API_KEY").unwrap(); /// let api_key: String = config.get_param("OPENAI_API_KEY").unwrap();
/// ///
/// // Get a complex type /// // Get a complex type
/// #[derive(Deserialize)] /// #[derive(Deserialize)]
@@ -87,7 +87,7 @@ impl From<keyring::Error> for ConfigError {
/// port: u16, /// port: u16,
/// } /// }
/// ///
/// let server_config: ServerConfig = config.get("server").unwrap(); /// let server_config: ServerConfig = config.get_param("server").unwrap();
/// ``` /// ```
/// ///
/// # Naming Convention /// # Naming Convention
@@ -204,7 +204,25 @@ impl Config {
} }
} }
/// Get a configuration value. // check all possible places for a parameter
pub fn get(&self, key: &str, is_secret: bool) -> Result<Value, ConfigError> {
if is_secret {
self.get_secret(key)
} else {
self.get_param(key)
}
}
// save a parameter in the appropriate location based on if it's secret or not
pub fn set(&self, key: &str, value: Value, is_secret: bool) -> Result<(), ConfigError> {
if is_secret {
self.set_secret(key, value)
} else {
self.set_param(key, value)
}
}
/// Get a configuration value (non-secret).
/// ///
/// This will attempt to get the value from: /// This will attempt to get the value from:
/// 1. Environment variable with the exact key name /// 1. Environment variable with the exact key name
@@ -220,7 +238,7 @@ impl Config {
/// - The key doesn't exist in either environment or config file /// - The key doesn't exist in either environment or config file
/// - The value cannot be deserialized into the requested type /// - The value cannot be deserialized into the requested type
/// - There is an error reading the config file /// - There is an error reading the config file
pub fn get<T: for<'de> Deserialize<'de>>(&self, key: &str) -> Result<T, ConfigError> { pub fn get_param<T: for<'de> Deserialize<'de>>(&self, key: &str) -> Result<T, ConfigError> {
// First check environment variables (convert to uppercase) // First check environment variables (convert to uppercase)
let env_key = key.to_uppercase(); let env_key = key.to_uppercase();
if let Ok(val) = env::var(&env_key) { if let Ok(val) = env::var(&env_key) {
@@ -239,7 +257,7 @@ impl Config {
.and_then(|v| Ok(serde_json::from_value(v.clone())?)) .and_then(|v| Ok(serde_json::from_value(v.clone())?))
} }
/// Set a configuration value in the config file. /// Set a configuration value in the config file (non-secret).
/// ///
/// This will immediately write the value to the config file. The value /// This will immediately write the value to the config file. The value
/// can be any type that can be serialized to JSON/YAML. /// can be any type that can be serialized to JSON/YAML.
@@ -252,7 +270,7 @@ impl Config {
/// Returns a ConfigError if: /// Returns a ConfigError if:
/// - There is an error reading or writing the config file /// - There is an error reading or writing the config file
/// - There is an error serializing the value /// - There is an error serializing the value
pub fn set(&self, key: &str, value: Value) -> Result<(), ConfigError> { pub fn set_param(&self, key: &str, value: Value) -> Result<(), ConfigError> {
let mut values = self.load_values()?; let mut values = self.load_values()?;
values.insert(key.to_string(), value); values.insert(key.to_string(), value);
@@ -377,15 +395,15 @@ mod tests {
let config = Config::new(temp_file.path(), TEST_KEYRING_SERVICE)?; let config = Config::new(temp_file.path(), TEST_KEYRING_SERVICE)?;
// Set a simple string value // Set a simple string value
config.set("test_key", Value::String("test_value".to_string()))?; config.set_param("test_key", Value::String("test_value".to_string()))?;
// Test simple string retrieval // Test simple string retrieval
let value: String = config.get("test_key")?; let value: String = config.get_param("test_key")?;
assert_eq!(value, "test_value"); assert_eq!(value, "test_value");
// Test with environment variable override // Test with environment variable override
std::env::set_var("TEST_KEY", "env_value"); std::env::set_var("TEST_KEY", "env_value");
let value: String = config.get("test_key")?; let value: String = config.get_param("test_key")?;
assert_eq!(value, "env_value"); assert_eq!(value, "env_value");
Ok(()) Ok(())
@@ -403,7 +421,7 @@ mod tests {
let config = Config::new(temp_file.path(), TEST_KEYRING_SERVICE)?; let config = Config::new(temp_file.path(), TEST_KEYRING_SERVICE)?;
// Set a complex value // Set a complex value
config.set( config.set_param(
"complex_key", "complex_key",
serde_json::json!({ serde_json::json!({
"field1": "hello", "field1": "hello",
@@ -411,7 +429,7 @@ mod tests {
}), }),
)?; )?;
let value: TestStruct = config.get("complex_key")?; let value: TestStruct = config.get_param("complex_key")?;
assert_eq!(value.field1, "hello"); assert_eq!(value.field1, "hello");
assert_eq!(value.field2, 42); assert_eq!(value.field2, 42);
@@ -423,7 +441,7 @@ mod tests {
let temp_file = NamedTempFile::new().unwrap(); let temp_file = NamedTempFile::new().unwrap();
let config = Config::new(temp_file.path(), TEST_KEYRING_SERVICE).unwrap(); let config = Config::new(temp_file.path(), TEST_KEYRING_SERVICE).unwrap();
let result: Result<String, ConfigError> = config.get("nonexistent_key"); let result: Result<String, ConfigError> = config.get_param("nonexistent_key");
assert!(matches!(result, Err(ConfigError::NotFound(_)))); assert!(matches!(result, Err(ConfigError::NotFound(_))));
} }
@@ -432,8 +450,8 @@ mod tests {
let temp_file = NamedTempFile::new().unwrap(); let temp_file = NamedTempFile::new().unwrap();
let config = Config::new(temp_file.path(), TEST_KEYRING_SERVICE)?; let config = Config::new(temp_file.path(), TEST_KEYRING_SERVICE)?;
config.set("key1", Value::String("value1".to_string()))?; config.set_param("key1", Value::String("value1".to_string()))?;
config.set("key2", Value::Number(42.into()))?; config.set_param("key2", Value::Number(42.into()))?;
// Read the file directly to check YAML formatting // Read the file directly to check YAML formatting
let content = std::fs::read_to_string(temp_file.path())?; let content = std::fs::read_to_string(temp_file.path())?;
@@ -448,14 +466,14 @@ mod tests {
let temp_file = NamedTempFile::new().unwrap(); let temp_file = NamedTempFile::new().unwrap();
let config = Config::new(temp_file.path(), TEST_KEYRING_SERVICE)?; let config = Config::new(temp_file.path(), TEST_KEYRING_SERVICE)?;
config.set("key", Value::String("value".to_string()))?; config.set_param("key", Value::String("value".to_string()))?;
let value: String = config.get("key")?; let value: String = config.get_param("key")?;
assert_eq!(value, "value"); assert_eq!(value, "value");
config.delete("key")?; config.delete("key")?;
let result: Result<String, ConfigError> = config.get("key"); let result: Result<String, ConfigError> = config.get_param("key");
assert!(matches!(result, Err(ConfigError::NotFound(_)))); assert!(matches!(result, Err(ConfigError::NotFound(_))));
Ok(()) Ok(())

View File

@@ -18,7 +18,8 @@ impl ExperimentManager {
/// - Removes experiments not in `ALL_EXPERIMENTS`. /// - Removes experiments not in `ALL_EXPERIMENTS`.
pub fn get_all() -> Result<Vec<(String, bool)>> { pub fn get_all() -> Result<Vec<(String, bool)>> {
let config = Config::global(); let config = Config::global();
let mut experiments: HashMap<String, bool> = config.get("experiments").unwrap_or_default(); let mut experiments: HashMap<String, bool> =
config.get_param("experiments").unwrap_or_default();
Self::refresh_experiments(&mut experiments); Self::refresh_experiments(&mut experiments);
Ok(experiments.into_iter().collect()) Ok(experiments.into_iter().collect())
@@ -27,12 +28,13 @@ impl ExperimentManager {
/// Enable or disable an experiment /// Enable or disable an experiment
pub fn set_enabled(name: &str, enabled: bool) -> Result<()> { pub fn set_enabled(name: &str, enabled: bool) -> Result<()> {
let config = Config::global(); let config = Config::global();
let mut experiments: HashMap<String, bool> = let mut experiments: HashMap<String, bool> = config
config.get("experiments").unwrap_or_else(|_| HashMap::new()); .get_param("experiments")
.unwrap_or_else(|_| HashMap::new());
Self::refresh_experiments(&mut experiments); Self::refresh_experiments(&mut experiments);
experiments.insert(name.to_string(), enabled); experiments.insert(name.to_string(), enabled);
config.set("experiments", serde_json::to_value(experiments)?)?; config.set_param("experiments", serde_json::to_value(experiments)?)?;
Ok(()) Ok(())
} }

View File

@@ -23,7 +23,7 @@ impl ExtensionManager {
let config = Config::global(); let config = Config::global();
// Try to get the extension entry // Try to get the extension entry
let extensions: HashMap<String, ExtensionEntry> = match config.get("extensions") { let extensions: HashMap<String, ExtensionEntry> = match config.get_param("extensions") {
Ok(exts) => exts, Ok(exts) => exts,
Err(super::ConfigError::NotFound(_)) => { Err(super::ConfigError::NotFound(_)) => {
// Initialize with default developer extension // Initialize with default developer extension
@@ -37,7 +37,7 @@ impl ExtensionManager {
}, },
}, },
)]); )]);
config.set("extensions", serde_json::to_value(&defaults)?)?; config.set_param("extensions", serde_json::to_value(&defaults)?)?;
defaults defaults
} }
Err(e) => return Err(e.into()), Err(e) => return Err(e.into()),
@@ -56,11 +56,12 @@ impl ExtensionManager {
pub fn set(entry: ExtensionEntry) -> Result<()> { pub fn set(entry: ExtensionEntry) -> Result<()> {
let config = Config::global(); let config = Config::global();
let mut extensions: HashMap<String, ExtensionEntry> = let mut extensions: HashMap<String, ExtensionEntry> = config
config.get("extensions").unwrap_or_else(|_| HashMap::new()); .get_param("extensions")
.unwrap_or_else(|_| HashMap::new());
extensions.insert(entry.config.name().parse()?, entry); extensions.insert(entry.config.name().parse()?, entry);
config.set("extensions", serde_json::to_value(extensions)?)?; config.set_param("extensions", serde_json::to_value(extensions)?)?;
Ok(()) Ok(())
} }
@@ -68,11 +69,12 @@ impl ExtensionManager {
pub fn remove(name: &str) -> Result<()> { pub fn remove(name: &str) -> Result<()> {
let config = Config::global(); let config = Config::global();
let mut extensions: HashMap<String, ExtensionEntry> = let mut extensions: HashMap<String, ExtensionEntry> = config
config.get("extensions").unwrap_or_else(|_| HashMap::new()); .get_param("extensions")
.unwrap_or_else(|_| HashMap::new());
extensions.remove(name); extensions.remove(name);
config.set("extensions", serde_json::to_value(extensions)?)?; config.set_param("extensions", serde_json::to_value(extensions)?)?;
Ok(()) Ok(())
} }
@@ -80,12 +82,13 @@ impl ExtensionManager {
pub fn set_enabled(name: &str, enabled: bool) -> Result<()> { pub fn set_enabled(name: &str, enabled: bool) -> Result<()> {
let config = Config::global(); let config = Config::global();
let mut extensions: HashMap<String, ExtensionEntry> = let mut extensions: HashMap<String, ExtensionEntry> = config
config.get("extensions").unwrap_or_else(|_| HashMap::new()); .get_param("extensions")
.unwrap_or_else(|_| HashMap::new());
if let Some(entry) = extensions.get_mut(name) { if let Some(entry) = extensions.get_mut(name) {
entry.enabled = enabled; entry.enabled = enabled;
config.set("extensions", serde_json::to_value(extensions)?)?; config.set_param("extensions", serde_json::to_value(extensions)?)?;
} }
Ok(()) Ok(())
} }
@@ -94,7 +97,7 @@ impl ExtensionManager {
pub fn get_all() -> Result<Vec<ExtensionEntry>> { pub fn get_all() -> Result<Vec<ExtensionEntry>> {
let config = Config::global(); let config = Config::global();
let extensions: HashMap<String, ExtensionEntry> = let extensions: HashMap<String, ExtensionEntry> =
config.get("extensions").unwrap_or_default(); config.get_param("extensions").unwrap_or_default();
Ok(Vec::from_iter(extensions.values().cloned())) Ok(Vec::from_iter(extensions.values().cloned()))
} }
@@ -102,15 +105,16 @@ impl ExtensionManager {
pub fn get_all_names() -> Result<Vec<String>> { pub fn get_all_names() -> Result<Vec<String>> {
let config = Config::global(); let config = Config::global();
Ok(config Ok(config
.get("extensions") .get_param("extensions")
.unwrap_or_else(|_| get_keys(Default::default()))) .unwrap_or_else(|_| get_keys(Default::default())))
} }
/// Check if an extension is enabled /// Check if an extension is enabled
pub fn is_enabled(name: &str) -> Result<bool> { pub fn is_enabled(name: &str) -> Result<bool> {
let config = Config::global(); let config = Config::global();
let extensions: HashMap<String, ExtensionEntry> = let extensions: HashMap<String, ExtensionEntry> = config
config.get("extensions").unwrap_or_else(|_| HashMap::new()); .get_param("extensions")
.unwrap_or_else(|_| HashMap::new());
Ok(extensions.get(name).map(|e| e.enabled).unwrap_or(false)) Ok(extensions.get(name).map(|e| e.enabled).unwrap_or(false))
} }

View File

@@ -45,7 +45,7 @@ impl AnthropicProvider {
let config = crate::config::Config::global(); let config = crate::config::Config::global();
let api_key: String = config.get_secret("ANTHROPIC_API_KEY")?; let api_key: String = config.get_secret("ANTHROPIC_API_KEY")?;
let host: String = config let host: String = config
.get("ANTHROPIC_HOST") .get_param("ANTHROPIC_HOST")
.unwrap_or_else(|_| "https://api.anthropic.com".to_string()); .unwrap_or_else(|_| "https://api.anthropic.com".to_string());
let client = Client::builder() let client = Client::builder()

View File

@@ -40,10 +40,10 @@ impl AzureProvider {
pub fn from_env(model: ModelConfig) -> Result<Self> { pub fn from_env(model: ModelConfig) -> Result<Self> {
let config = crate::config::Config::global(); let config = crate::config::Config::global();
let api_key: String = config.get_secret("AZURE_OPENAI_API_KEY")?; let api_key: String = config.get_secret("AZURE_OPENAI_API_KEY")?;
let endpoint: String = config.get("AZURE_OPENAI_ENDPOINT")?; let endpoint: String = config.get_param("AZURE_OPENAI_ENDPOINT")?;
let deployment_name: String = config.get("AZURE_OPENAI_DEPLOYMENT_NAME")?; let deployment_name: String = config.get_param("AZURE_OPENAI_DEPLOYMENT_NAME")?;
let api_version: String = config let api_version: String = config
.get("AZURE_OPENAI_API_VERSION") .get_param("AZURE_OPENAI_API_VERSION")
.unwrap_or_else(|_| AZURE_DEFAULT_API_VERSION.to_string()); .unwrap_or_else(|_| AZURE_DEFAULT_API_VERSION.to_string());
let client = Client::builder() let client = Client::builder()
@@ -109,18 +109,8 @@ impl Provider for AzureProvider {
vec![ vec![
ConfigKey::new("AZURE_OPENAI_API_KEY", true, true, None), ConfigKey::new("AZURE_OPENAI_API_KEY", true, true, None),
ConfigKey::new("AZURE_OPENAI_ENDPOINT", true, false, None), ConfigKey::new("AZURE_OPENAI_ENDPOINT", true, false, None),
ConfigKey::new( ConfigKey::new("AZURE_OPENAI_DEPLOYMENT_NAME", true, false, None),
"AZURE_OPENAI_DEPLOYMENT_NAME", ConfigKey::new("AZURE_OPENAI_API_VERSION", false, false, Some("2024-10-21")),
true,
false,
Some("Name of your Azure OpenAI deployment"),
),
ConfigKey::new(
"AZURE_OPENAI_API_VERSION",
false,
false,
Some("Azure OpenAI API version, default: 2024-10-21"),
),
], ],
) )
} }

View File

@@ -5,9 +5,10 @@ use super::errors::ProviderError;
use crate::message::Message; use crate::message::Message;
use crate::model::ModelConfig; use crate::model::ModelConfig;
use mcp_core::tool::Tool; use mcp_core::tool::Tool;
use utoipa::ToSchema;
/// Metadata about a provider's configuration requirements and capabilities /// Metadata about a provider's configuration requirements and capabilities
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct ProviderMetadata { pub struct ProviderMetadata {
/// The unique identifier for this provider /// The unique identifier for this provider
pub name: String, pub name: String,
@@ -60,7 +61,7 @@ impl ProviderMetadata {
} }
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct ConfigKey { pub struct ConfigKey {
pub name: String, pub name: String,
pub required: bool, pub required: bool,

View File

@@ -83,7 +83,7 @@ impl DatabricksProvider {
// For compatibility for now we check both config and secret for databricks host // For compatibility for now we check both config and secret for databricks host
// but it is not actually a secret value // but it is not actually a secret value
let mut host: Result<String, ConfigError> = config.get("DATABRICKS_HOST"); let mut host: Result<String, ConfigError> = config.get_param("DATABRICKS_HOST");
if host.is_err() { if host.is_err() {
host = config.get_secret("DATABRICKS_HOST") host = config.get_secret("DATABRICKS_HOST")

View File

@@ -146,7 +146,7 @@ impl GcpVertexAIProvider {
/// * `model` - Configuration for the model to be used /// * `model` - Configuration for the model to be used
async fn new_async(model: ModelConfig) -> Result<Self> { async fn new_async(model: ModelConfig) -> Result<Self> {
let config = crate::config::Config::global(); let config = crate::config::Config::global();
let project_id = config.get("GCP_PROJECT_ID")?; let project_id = config.get_param("GCP_PROJECT_ID")?;
let location = Self::determine_location(config)?; let location = Self::determine_location(config)?;
let host = format!("https://{}-aiplatform.googleapis.com", location); let host = format!("https://{}-aiplatform.googleapis.com", location);
@@ -173,25 +173,25 @@ impl GcpVertexAIProvider {
/// Loads retry configuration from environment variables or uses defaults. /// Loads retry configuration from environment variables or uses defaults.
fn load_retry_config(config: &crate::config::Config) -> RetryConfig { fn load_retry_config(config: &crate::config::Config) -> RetryConfig {
let max_retries = config let max_retries = config
.get("GCP_MAX_RETRIES") .get_param("GCP_MAX_RETRIES")
.ok() .ok()
.and_then(|v: String| v.parse::<usize>().ok()) .and_then(|v: String| v.parse::<usize>().ok())
.unwrap_or(DEFAULT_MAX_RETRIES); .unwrap_or(DEFAULT_MAX_RETRIES);
let initial_interval_ms = config let initial_interval_ms = config
.get("GCP_INITIAL_RETRY_INTERVAL_MS") .get_param("GCP_INITIAL_RETRY_INTERVAL_MS")
.ok() .ok()
.and_then(|v: String| v.parse::<u64>().ok()) .and_then(|v: String| v.parse::<u64>().ok())
.unwrap_or(DEFAULT_INITIAL_RETRY_INTERVAL_MS); .unwrap_or(DEFAULT_INITIAL_RETRY_INTERVAL_MS);
let backoff_multiplier = config let backoff_multiplier = config
.get("GCP_BACKOFF_MULTIPLIER") .get_param("GCP_BACKOFF_MULTIPLIER")
.ok() .ok()
.and_then(|v: String| v.parse::<f64>().ok()) .and_then(|v: String| v.parse::<f64>().ok())
.unwrap_or(DEFAULT_BACKOFF_MULTIPLIER); .unwrap_or(DEFAULT_BACKOFF_MULTIPLIER);
let max_interval_ms = config let max_interval_ms = config
.get("GCP_MAX_RETRY_INTERVAL_MS") .get_param("GCP_MAX_RETRY_INTERVAL_MS")
.ok() .ok()
.and_then(|v: String| v.parse::<u64>().ok()) .and_then(|v: String| v.parse::<u64>().ok())
.unwrap_or(DEFAULT_MAX_RETRY_INTERVAL_MS); .unwrap_or(DEFAULT_MAX_RETRY_INTERVAL_MS);
@@ -211,7 +211,7 @@ impl GcpVertexAIProvider {
/// 2. Global default location (Iowa) /// 2. Global default location (Iowa)
fn determine_location(config: &crate::config::Config) -> Result<String> { fn determine_location(config: &crate::config::Config) -> Result<String> {
Ok(config Ok(config
.get("GCP_LOCATION") .get_param("GCP_LOCATION")
.ok() .ok()
.filter(|location: &String| !location.trim().is_empty()) .filter(|location: &String| !location.trim().is_empty())
.unwrap_or_else(|| Iowa.to_string())) .unwrap_or_else(|| Iowa.to_string()))

View File

@@ -50,7 +50,7 @@ impl GoogleProvider {
let config = crate::config::Config::global(); let config = crate::config::Config::global();
let api_key: String = config.get_secret("GOOGLE_API_KEY")?; let api_key: String = config.get_secret("GOOGLE_API_KEY")?;
let host: String = config let host: String = config
.get("GOOGLE_HOST") .get_param("GOOGLE_HOST")
.unwrap_or_else(|_| GOOGLE_API_HOST.to_string()); .unwrap_or_else(|_| GOOGLE_API_HOST.to_string());
let client = Client::builder() let client = Client::builder()

View File

@@ -39,7 +39,7 @@ impl GroqProvider {
let config = crate::config::Config::global(); let config = crate::config::Config::global();
let api_key: String = config.get_secret("GROQ_API_KEY")?; let api_key: String = config.get_secret("GROQ_API_KEY")?;
let host: String = config let host: String = config
.get("GROQ_HOST") .get_param("GROQ_HOST")
.unwrap_or_else(|_| GROQ_API_HOST.to_string()); .unwrap_or_else(|_| GROQ_API_HOST.to_string());
let client = Client::builder() let client = Client::builder()

View File

@@ -39,7 +39,7 @@ impl OllamaProvider {
pub fn from_env(model: ModelConfig) -> Result<Self> { pub fn from_env(model: ModelConfig) -> Result<Self> {
let config = crate::config::Config::global(); let config = crate::config::Config::global();
let host: String = config let host: String = config
.get("OLLAMA_HOST") .get_param("OLLAMA_HOST")
.unwrap_or_else(|_| OLLAMA_HOST.to_string()); .unwrap_or_else(|_| OLLAMA_HOST.to_string());
let client = Client::builder() let client = Client::builder()

View File

@@ -47,13 +47,13 @@ impl OpenAiProvider {
let config = crate::config::Config::global(); let config = crate::config::Config::global();
let api_key: String = config.get_secret("OPENAI_API_KEY")?; let api_key: String = config.get_secret("OPENAI_API_KEY")?;
let host: String = config let host: String = config
.get("OPENAI_HOST") .get_param("OPENAI_HOST")
.unwrap_or_else(|_| "https://api.openai.com".to_string()); .unwrap_or_else(|_| "https://api.openai.com".to_string());
let base_path: String = config let base_path: String = config
.get("OPENAI_BASE_PATH") .get_param("OPENAI_BASE_PATH")
.unwrap_or_else(|_| "v1/chat/completions".to_string()); .unwrap_or_else(|_| "v1/chat/completions".to_string());
let organization: Option<String> = config.get("OPENAI_ORGANIZATION").ok(); let organization: Option<String> = config.get_param("OPENAI_ORGANIZATION").ok();
let project: Option<String> = config.get("OPENAI_PROJECT").ok(); let project: Option<String> = config.get_param("OPENAI_PROJECT").ok();
let client = Client::builder() let client = Client::builder()
.timeout(Duration::from_secs(600)) .timeout(Duration::from_secs(600))
.build()?; .build()?;

View File

@@ -44,7 +44,7 @@ impl OpenRouterProvider {
let config = crate::config::Config::global(); let config = crate::config::Config::global();
let api_key: String = config.get_secret("OPENROUTER_API_KEY")?; let api_key: String = config.get_secret("OPENROUTER_API_KEY")?;
let host: String = config let host: String = config
.get("OPENROUTER_HOST") .get_param("OPENROUTER_HOST")
.unwrap_or_else(|_| "https://openrouter.ai".to_string()); .unwrap_or_else(|_| "https://openrouter.ai".to_string());
let client = Client::builder() let client = Client::builder()

View File

@@ -137,6 +137,29 @@
} }
} }
}, },
"/config/providers": {
"get": {
"tags": [
"super::routes::config_management"
],
"operationId": "providers",
"responses": {
"200": {
"description": "All configuration values retrieved successfully",
"content": {
"application/json": {
"schema": {
"type": "array",
"items": {
"$ref": "#/components/schemas/ProviderDetails"
}
}
}
}
}
}
}
},
"/config/read": { "/config/read": {
"get": { "get": {
"tags": [ "tags": [
@@ -240,12 +263,39 @@
}, },
"components": { "components": {
"schemas": { "schemas": {
"ConfigKey": {
"type": "object",
"required": [
"name",
"required",
"secret"
],
"properties": {
"default": {
"type": "string",
"nullable": true
},
"name": {
"type": "string"
},
"required": {
"type": "boolean"
},
"secret": {
"type": "boolean"
}
}
},
"ConfigKeyQuery": { "ConfigKeyQuery": {
"type": "object", "type": "object",
"required": [ "required": [
"key" "key",
"is_secret"
], ],
"properties": { "properties": {
"is_secret": {
"type": "boolean"
},
"key": { "key": {
"type": "string" "type": "string"
} }
@@ -276,16 +326,100 @@
} }
} }
}, },
"ProviderDetails": {
"type": "object",
"required": [
"name",
"metadata",
"is_configured"
],
"properties": {
"is_configured": {
"type": "boolean",
"description": "Indicates whether the provider is fully configured"
},
"metadata": {
"$ref": "#/components/schemas/ProviderMetadata"
},
"name": {
"type": "string",
"description": "Unique identifier and name of the provider"
}
}
},
"ProviderMetadata": {
"type": "object",
"description": "Metadata about a provider's configuration requirements and capabilities",
"required": [
"name",
"display_name",
"description",
"default_model",
"known_models",
"model_doc_link",
"config_keys"
],
"properties": {
"config_keys": {
"type": "array",
"items": {
"$ref": "#/components/schemas/ConfigKey"
},
"description": "Required configuration keys"
},
"default_model": {
"type": "string",
"description": "The default/recommended model for this provider"
},
"description": {
"type": "string",
"description": "Description of the provider's capabilities"
},
"display_name": {
"type": "string",
"description": "Display name for the provider in UIs"
},
"known_models": {
"type": "array",
"items": {
"type": "string"
},
"description": "A list of currently known models\nTODO: eventually query the apis directly"
},
"model_doc_link": {
"type": "string",
"description": "Link to the docs where models can be found"
},
"name": {
"type": "string",
"description": "The unique identifier for this provider"
}
}
},
"ProvidersResponse": {
"type": "object",
"required": [
"providers"
],
"properties": {
"providers": {
"type": "array",
"items": {
"$ref": "#/components/schemas/ProviderDetails"
}
}
}
},
"UpsertConfigQuery": { "UpsertConfigQuery": {
"type": "object", "type": "object",
"required": [ "required": [
"key", "key",
"value" "value",
"is_secret"
], ],
"properties": { "properties": {
"is_secret": { "is_secret": {
"type": "boolean", "type": "boolean"
"nullable": true
}, },
"key": { "key": {
"type": "string" "type": "string"

View File

@@ -1,7 +1,7 @@
// This file is auto-generated by @hey-api/openapi-ts // This file is auto-generated by @hey-api/openapi-ts
import type { Options as ClientOptions, TDataShape, Client } from '@hey-api/client-fetch'; import type { Options as ClientOptions, TDataShape, Client } from '@hey-api/client-fetch';
import type { ReadAllConfigData, ReadAllConfigResponse, RemoveExtensionData, RemoveExtensionResponse, AddExtensionData, AddExtensionResponse, UpdateExtensionData, UpdateExtensionResponse, ReadConfigData, RemoveConfigData, RemoveConfigResponse, UpsertConfigData, UpsertConfigResponse } from './types.gen'; import type { ReadAllConfigData, ReadAllConfigResponse, RemoveExtensionData, RemoveExtensionResponse, AddExtensionData, AddExtensionResponse, UpdateExtensionData, UpdateExtensionResponse, ProvidersData, ProvidersResponse2, ReadConfigData, RemoveConfigData, RemoveConfigResponse, UpsertConfigData, UpsertConfigResponse } from './types.gen';
import { client as _heyApiClient } from './client.gen'; import { client as _heyApiClient } from './client.gen';
export type Options<TData extends TDataShape = TDataShape, ThrowOnError extends boolean = boolean> = ClientOptions<TData, ThrowOnError> & { export type Options<TData extends TDataShape = TDataShape, ThrowOnError extends boolean = boolean> = ClientOptions<TData, ThrowOnError> & {
@@ -58,6 +58,13 @@ export const updateExtension = <ThrowOnError extends boolean = false>(options: O
}); });
}; };
export const providers = <ThrowOnError extends boolean = false>(options?: Options<ProvidersData, ThrowOnError>) => {
return (options?.client ?? _heyApiClient).get<ProvidersResponse2, unknown, ThrowOnError>({
url: '/config/providers',
...options
});
};
export const readConfig = <ThrowOnError extends boolean = false>(options: Options<ReadConfigData, ThrowOnError>) => { export const readConfig = <ThrowOnError extends boolean = false>(options: Options<ReadConfigData, ThrowOnError>) => {
return (options.client ?? _heyApiClient).get<unknown, unknown, ThrowOnError>({ return (options.client ?? _heyApiClient).get<unknown, unknown, ThrowOnError>({
url: '/config/read', url: '/config/read',

View File

@@ -1,6 +1,14 @@
// This file is auto-generated by @hey-api/openapi-ts // This file is auto-generated by @hey-api/openapi-ts
export type ConfigKey = {
default?: string | null;
name: string;
required: boolean;
secret: boolean;
};
export type ConfigKeyQuery = { export type ConfigKeyQuery = {
is_secret: boolean;
key: string; key: string;
}; };
@@ -13,8 +21,59 @@ export type ExtensionQuery = {
name: string; name: string;
}; };
export type ProviderDetails = {
/**
* Indicates whether the provider is fully configured
*/
is_configured: boolean;
metadata: ProviderMetadata;
/**
* Unique identifier and name of the provider
*/
name: string;
};
/**
* Metadata about a provider's configuration requirements and capabilities
*/
export type ProviderMetadata = {
/**
* Required configuration keys
*/
config_keys: Array<ConfigKey>;
/**
* The default/recommended model for this provider
*/
default_model: string;
/**
* Description of the provider's capabilities
*/
description: string;
/**
* Display name for the provider in UIs
*/
display_name: string;
/**
* A list of currently known models
* TODO: eventually query the apis directly
*/
known_models: Array<string>;
/**
* Link to the docs where models can be found
*/
model_doc_link: string;
/**
* The unique identifier for this provider
*/
name: string;
};
export type ProvidersResponse = {
providers: Array<ProviderDetails>;
};
export type UpsertConfigQuery = { export type UpsertConfigQuery = {
is_secret?: boolean | null; is_secret: boolean;
key: string; key: string;
value: unknown; value: unknown;
}; };
@@ -116,6 +175,22 @@ export type UpdateExtensionResponses = {
export type UpdateExtensionResponse = UpdateExtensionResponses[keyof UpdateExtensionResponses]; export type UpdateExtensionResponse = UpdateExtensionResponses[keyof UpdateExtensionResponses];
export type ProvidersData = {
body?: never;
path?: never;
query?: never;
url: '/config/providers';
};
export type ProvidersResponses = {
/**
* All configuration values retrieved successfully
*/
200: Array<ProviderDetails>;
};
export type ProvidersResponse2 = ProvidersResponses[keyof ProvidersResponses];
export type ReadConfigData = { export type ReadConfigData = {
body: ConfigKeyQuery; body: ConfigKeyQuery;
path?: never; path?: never;

View File

@@ -1,4 +1,4 @@
import React, { createContext, useContext, useState, useEffect } from 'react'; import React, { createContext, useContext, useState, useEffect, useMemo } from 'react';
import { import {
readAllConfig, readAllConfig,
readConfig, readConfig,
@@ -7,8 +7,16 @@ import {
addExtension as apiAddExtension, addExtension as apiAddExtension,
removeExtension as apiRemoveExtension, removeExtension as apiRemoveExtension,
updateExtension as apiUpdateExtension, updateExtension as apiUpdateExtension,
providers,
} from '../api'; } from '../api';
import { client } from '../api/client.gen'; import { client } from '../api/client.gen';
import type {
ConfigResponse,
UpsertConfigQuery,
ConfigKeyQuery,
ExtensionQuery,
ProviderDetails,
} from '../api/types.gen';
// Initialize client configuration // Initialize client configuration
client.setConfig({ client.setConfig({
@@ -20,13 +28,15 @@ client.setConfig({
}); });
interface ConfigContextType { interface ConfigContextType {
config: Record<string, any>; config: ConfigResponse['config'];
upsert: (key: string, value: any, isSecret?: boolean) => Promise<void>; providersList: ProviderDetails[];
read: (key: string) => Promise<any>; upsert: (key: string, value: unknown, is_secret: boolean) => Promise<void>;
remove: (key: string) => Promise<void>; read: (key: string, is_secret: boolean) => Promise<unknown>;
addExtension: (name: string, config: any) => Promise<void>; remove: (key: string, is_secret: boolean) => Promise<void>;
updateExtension: (name: string, config: any) => Promise<void>; addExtension: (name: string, config: unknown) => Promise<void>;
updateExtension: (name: string, config: unknown) => Promise<void>;
removeExtension: (name: string) => Promise<void>; removeExtension: (name: string) => Promise<void>;
getProviders: (b: boolean) => Promise<ProviderDetails[]>;
} }
interface ConfigProviderProps { interface ConfigProviderProps {
@@ -36,13 +46,23 @@ interface ConfigProviderProps {
const ConfigContext = createContext<ConfigContextType | undefined>(undefined); const ConfigContext = createContext<ConfigContextType | undefined>(undefined);
export const ConfigProvider: React.FC<ConfigProviderProps> = ({ children }) => { export const ConfigProvider: React.FC<ConfigProviderProps> = ({ children }) => {
const [config, setConfig] = useState<Record<string, any>>({}); const [config, setConfig] = useState<ConfigResponse['config']>({});
const [providersList, setProvidersList] = useState<ProviderDetails[]>([]);
useEffect(() => { useEffect(() => {
// Load all configuration data on mount // Load all configuration data and providers on mount
(async () => { (async () => {
const response = await readAllConfig(); // Load config
setConfig(response.data.config || {}); const configResponse = await readAllConfig();
setConfig(configResponse.data.config || {});
// Load providers
try {
const providersResponse = await providers();
setProvidersList(providersResponse.data);
} catch (error) {
console.error('Failed to load providers:', error);
}
})(); })();
}, []); }, []);
@@ -51,58 +71,86 @@ export const ConfigProvider: React.FC<ConfigProviderProps> = ({ children }) => {
setConfig(response.data.config || {}); setConfig(response.data.config || {});
}; };
const upsert = async (key: string, value: any, isSecret?: boolean) => { const upsert = async (key: string, value: unknown, isSecret?: boolean) => {
const query: UpsertConfigQuery = {
key,
value,
is_secret: isSecret || null,
};
await upsertConfig({ await upsertConfig({
body: { body: query,
key,
value,
is_secret: isSecret,
},
}); });
await reloadConfig(); await reloadConfig();
}; };
const read = async (key: string) => { const read = async (key: string, is_secret: boolean = false) => {
return await readConfig({ const query: ConfigKeyQuery = { key: key, is_secret: is_secret };
body: { key }, const response = await readConfig({
body: query,
}); });
return response.data;
}; };
const remove = async (key: string) => { const remove = async (key: string, is_secret: boolean) => {
const query: ConfigKeyQuery = { key: key, is_secret: is_secret };
await removeConfig({ await removeConfig({
body: { key }, body: query,
}); });
await reloadConfig(); await reloadConfig();
}; };
const addExtension = async (name: string, config: any) => { const addExtension = async (name: string, config: unknown) => {
const query: ExtensionQuery = { name, config };
await apiAddExtension({ await apiAddExtension({
body: { name, config }, body: query,
}); });
await reloadConfig(); await reloadConfig();
}; };
const removeExtension = async (name: string) => { const removeExtension = async (name: string) => {
const query: ConfigKeyQuery = { key: name, is_secret: false };
await apiRemoveExtension({ await apiRemoveExtension({
body: { key: name }, body: query,
}); });
await reloadConfig(); await reloadConfig();
}; };
const updateExtension = async (name: string, config: any) => { const updateExtension = async (name: string, config: unknown) => {
const query: ExtensionQuery = { name, config };
await apiUpdateExtension({ await apiUpdateExtension({
body: { name, config }, body: query,
}); });
await reloadConfig(); await reloadConfig();
}; };
return ( const getProviders = async (forceRefresh = false): Promise<ProviderDetails[]> => {
<ConfigContext.Provider if (forceRefresh || providersList.length === 0) {
value={{ config, upsert, read, remove, addExtension, updateExtension, removeExtension }} // If a refresh is forced or we don't have providers yet
> const response = await providers();
{children} setProvidersList(response.data);
</ConfigContext.Provider> return response.data;
); }
// Otherwise return the cached providers
return providersList;
};
const contextValue = useMemo(
() => ({
config,
providersList,
upsert,
read,
remove,
addExtension,
updateExtension,
removeExtension,
getProviders,
}),
[config, providersList]
); // Functions don't need to be dependencies as they don't change
return <ConfigContext.Provider value={contextValue}>{children}</ConfigContext.Provider>;
}; };
export const useConfig = () => { export const useConfig = () => {

View File

@@ -1,70 +1,104 @@
import React from 'react'; import React, { memo, useMemo, useCallback } from 'react';
import { ProviderCard } from './subcomponents/ProviderCard'; import { ProviderCard } from './subcomponents/ProviderCard';
import ProviderState from './interfaces/ProviderState';
import OnRefresh from './callbacks/RefreshActiveProviders'; import OnRefresh from './callbacks/RefreshActiveProviders';
import { ProviderModalProvider, useProviderModal } from './modal/ProviderModalProvider'; import { ProviderModalProvider, useProviderModal } from './modal/ProviderModalProvider';
import ProviderConfigurationModal from './modal/ProviderConfiguationModal'; import ProviderConfigurationModal from './modal/ProviderConfiguationModal';
import { ProviderDetails } from '../../../api';
function GridLayout({ children }: { children: React.ReactNode }) { const GridLayout = memo(function GridLayout({ children }: { children: React.ReactNode }) {
return ( return (
<div className="grid grid-cols-[repeat(auto-fill,_minmax(140px,_1fr))] gap-3 [&_*]:z-20"> <div className="grid grid-cols-[repeat(auto-fill,_minmax(140px,_1fr))] gap-3 [&_*]:z-20">
{children} {children}
</div> </div>
); );
} });
function ProviderCards({ // Memoize the ProviderCards component
const ProviderCards = memo(function ProviderCards({
providers, providers,
isOnboarding, isOnboarding,
}: { }: {
providers: ProviderState[]; providers: ProviderDetails[];
isOnboarding: boolean; isOnboarding: boolean;
}) { }) {
const { openModal } = useProviderModal(); const { openModal } = useProviderModal();
const configureProviderViaModal = (provider: ProviderState) => { // Memoize these functions so they don't get recreated on every render
openModal(provider, { const configureProviderViaModal = useCallback(
onSubmit: (values: any) => { (provider: ProviderDetails) => {
console.log(`Configuring ${provider.name}:`, values); openModal(provider, {
// Your logic to save the configuration onSubmit: (values: any) => {
}, // Your logic to save the configuration
formProps: {}, },
}); formProps: {},
}; });
},
const handleLaunch = () => { [openModal]
OnRefresh();
};
return (
<>
{providers.map((provider) => (
<ProviderCard
key={provider.name}
provider={provider}
onConfigure={() => configureProviderViaModal(provider)}
onLaunch={handleLaunch}
isOnboarding={isOnboarding}
/>
))}
</>
); );
}
export default function ProviderGrid({ const handleLaunch = useCallback(() => {
OnRefresh();
}, []);
// Use useMemo to memoize the cards array
const providerCards = useMemo(() => {
return providers.map((provider) => (
<ProviderCard
key={provider.name}
provider={provider}
onConfigure={() => configureProviderViaModal(provider)}
onLaunch={handleLaunch}
isOnboarding={isOnboarding}
/>
));
}, [providers, isOnboarding, configureProviderViaModal, handleLaunch]);
return <>{providerCards}</>;
});
// Fix the ProviderModalProvider
export const OptimizedProviderModalProvider = memo(function OptimizedProviderModalProvider({
children,
}: {
children: React.ReactNode;
}) {
const contextValue = useMemo(
() => ({
isOpen: false,
currentProvider: null,
modalProps: {},
openModal: (provider, additionalProps = {}) => {
// Implementation
},
closeModal: () => {
// Implementation
},
}),
[]
);
return <ProviderModalProvider>{children}</ProviderModalProvider>;
});
export default memo(function ProviderGrid({
providers, providers,
isOnboarding, isOnboarding,
}: { }: {
providers: ProviderState[]; providers: ProviderDetails[];
isOnboarding: boolean; isOnboarding: boolean;
}) { }) {
console.log('(1) Provider Grid -- is this the onboarding page?', isOnboarding); // Remove the console.log
return ( console.log('provider grid');
<GridLayout> // Memoize the modal provider and its children to avoid recreating on every render
const modalProviderContent = useMemo(
() => (
<ProviderModalProvider> <ProviderModalProvider>
<ProviderCards providers={providers} isOnboarding={isOnboarding} /> <ProviderCards providers={providers} isOnboarding={isOnboarding} />
<ProviderConfigurationModal /> <ProviderConfigurationModal />
</ProviderModalProvider> </ProviderModalProvider>
</GridLayout> ),
[providers, isOnboarding]
); );
}
return <GridLayout>{modalProviderContent}</GridLayout>;
});

View File

@@ -1,61 +1,45 @@
import React from 'react'; import React, { useEffect, useState } from 'react';
import { ScrollArea } from '../../ui/scroll-area'; import { ScrollArea } from '../../ui/scroll-area';
import BackButton from '../../ui/BackButton'; import BackButton from '../../ui/BackButton';
import ProviderGrid from './ProviderGrid'; import ProviderGrid from './ProviderGrid';
import ProviderState from './interfaces/ProviderState'; import { useConfig } from '../../ConfigContext';
import { ProviderDetails } from '../../../api/types.gen';
const fakeProviderState: ProviderState[] = [
{
id: 'openai',
name: 'OpenAI',
isConfigured: true,
metadata: null,
},
{
id: 'anthropic',
name: 'Anthropic',
isConfigured: false,
metadata: null,
},
{
id: 'groq',
name: 'Groq',
isConfigured: false,
metadata: null,
},
{
id: 'google',
name: 'Google',
isConfigured: false,
metadata: null,
},
{
id: 'openrouter',
name: 'OpenRouter',
isConfigured: false,
metadata: null,
},
{
id: 'databricks',
name: 'Databricks',
isConfigured: false,
metadata: null,
},
{
id: 'ollama',
name: 'Ollama',
isConfigured: false,
metadata: { location: null },
},
{
id: 'gcp_vertex_ai',
name: 'GCP Vertex AI',
isConfigured: true,
metadata: { location: null },
},
];
export default function ProviderSettings({ onClose }: { onClose: () => void }) { export default function ProviderSettings({ onClose }: { onClose: () => void }) {
const { getProviders } = useConfig();
const [loading, setLoading] = useState(true);
const [providers, setProviders] = useState<ProviderDetails[]>([]);
// Load providers only once when component mounts
useEffect(() => {
let isMounted = true;
const loadProviders = async () => {
try {
// Force refresh to ensure we have the latest data
const result = await getProviders(true);
// Only update state if component is still mounted
if (isMounted && result) {
setProviders(result);
}
} catch (error) {
console.error('Failed to load providers:', error);
} finally {
if (isMounted) {
setLoading(false);
}
}
};
loadProviders();
// Cleanup function to prevent state updates on unmounted component
return () => {
isMounted = false;
};
}, []); // Empty dependency array ensures this only runs once
console.log(providers);
return ( return (
<div className="h-screen w-full"> <div className="h-screen w-full">
<div className="relative flex items-center h-[36px] w-full bg-bgSubtle"></div> <div className="relative flex items-center h-[36px] w-full bg-bgSubtle"></div>
@@ -66,7 +50,7 @@ export default function ProviderSettings({ onClose }: { onClose: () => void }) {
<h1 className="text-3xl font-medium text-textStandard mt-1">Configure</h1> <h1 className="text-3xl font-medium text-textStandard mt-1">Configure</h1>
</div> </div>
<div className=" py-8 pt-[20px]"> <div className="py-8 pt-[20px]">
<div className="flex justify-between items-center mb-6 border-b border-borderSubtle px-8"> <div className="flex justify-between items-center mb-6 border-b border-borderSubtle px-8">
<h2 className="text-xl font-medium text-textStandard">Providers</h2> <h2 className="text-xl font-medium text-textStandard">Providers</h2>
</div> </div>
@@ -74,7 +58,11 @@ export default function ProviderSettings({ onClose }: { onClose: () => void }) {
{/* Content Area */} {/* Content Area */}
<div className="max-w-5xl pt-4 px-8"> <div className="max-w-5xl pt-4 px-8">
<div className="relative z-10"> <div className="relative z-10">
<ProviderGrid providers={fakeProviderState} isOnboarding={false} /> {loading ? (
<div>Loading providers...</div>
) : (
<ProviderGrid providers={providers} isOnboarding={false} />
)}
</div> </div>
</div> </div>
</div> </div>

View File

@@ -5,10 +5,18 @@ import DefaultProviderSetupForm from './subcomponents/forms/DefaultProviderSetup
import ProviderSetupActions from './subcomponents/ProviderSetupActions'; import ProviderSetupActions from './subcomponents/ProviderSetupActions';
import ProviderLogo from './subcomponents/ProviderLogo'; import ProviderLogo from './subcomponents/ProviderLogo';
import { useProviderModal } from './ProviderModalProvider'; import { useProviderModal } from './ProviderModalProvider';
import { toast } from 'react-toastify';
import { PROVIDER_REGISTRY } from '../ProviderRegistry';
import { SecureStorageNotice } from './subcomponents/SecureStorageNotice'; import { SecureStorageNotice } from './subcomponents/SecureStorageNotice';
import DefaultSubmitHandler from './subcomponents/handlers/DefaultSubmitHandler'; import DefaultSubmitHandler from './subcomponents/handlers/DefaultSubmitHandler';
import OllamaSubmitHandler from './subcomponents/handlers/OllamaSubmitHandler';
import OllamaForm from './subcomponents/forms/OllamaForm';
const customSubmitHandler = {
provider_name: OllamaSubmitHandler, // example
};
const customForms = {
provider_name: OllamaForm, // example
};
export default function ProviderConfigurationModal() { export default function ProviderConfigurationModal() {
const { isOpen, currentProvider, modalProps, closeModal } = useProviderModal(); const { isOpen, currentProvider, modalProps, closeModal } = useProviderModal();
@@ -32,23 +40,11 @@ export default function ProviderConfigurationModal() {
if (!isOpen || !currentProvider) return null; if (!isOpen || !currentProvider) return null;
const headerText = `Configure ${currentProvider.name}`; const headerText = `Configure ${currentProvider.metadata.display_name}`;
const descriptionText = `Add your API key(s) for this provider to integrate into Goose`; const descriptionText = `Add your API key(s) for this provider to integrate into Goose`;
// Find the provider in the registry to get the details with customForm const SubmitHandler = customSubmitHandler[currentProvider.name] || DefaultSubmitHandler;
const providerEntry = PROVIDER_REGISTRY.find((p) => p.name === currentProvider.name); const FormComponent = customForms[currentProvider.name] || DefaultProviderSetupForm;
// Get the custom submit handler from the provider details
const customSubmitHandler = providerEntry?.details?.customSubmit;
// Use custom submit handler otherwise use default
const SubmitHandler = customSubmitHandler || DefaultSubmitHandler;
// Get the custom form component from the provider details
const CustomForm = providerEntry?.details?.customForm;
// Use custom form component if available, otherwise use default
const FormComponent = CustomForm || DefaultProviderSetupForm;
const handleSubmitForm = (e) => { const handleSubmitForm = (e) => {
e.preventDefault(); e.preventDefault();
@@ -74,7 +70,7 @@ export default function ProviderConfigurationModal() {
<Modal> <Modal>
<div className="space-y-1"> <div className="space-y-1">
{/* Logo area - centered above title */} {/* Logo area - centered above title */}
<ProviderLogo providerName={currentProvider.id} /> <ProviderLogo providerName={currentProvider.name} />
{/* Title and some information - centered */} {/* Title and some information - centered */}
<ProviderSetupHeader title={headerText} body={descriptionText} /> <ProviderSetupHeader title={headerText} body={descriptionText} />
</div> </div>
@@ -87,7 +83,7 @@ export default function ProviderConfigurationModal() {
{...(modalProps.formProps || {})} // Spread any custom form props {...(modalProps.formProps || {})} // Spread any custom form props
/> />
{providerEntry?.details?.parameters && providerEntry.details.parameters.length > 0 && ( {currentProvider.metadata.config_keys && currentProvider.metadata.config_keys.length > 0 && (
<SecureStorageNotice /> <SecureStorageNotice />
)} )}
<ProviderSetupActions onCancel={handleCancel} onSubmit={handleSubmitForm} /> <ProviderSetupActions onCancel={handleCancel} onSubmit={handleSubmitForm} />

View File

@@ -1,55 +1,60 @@
import React, { createContext, useContext, useState } from 'react'; import React, { createContext, useContext, useState, useMemo, useCallback } from 'react';
import ProviderState from '../interfaces/ProviderState'; import { ProviderDetails } from '../../../../api';
interface ProviderModalContextType { interface ProviderModalContextType {
isOpen: boolean; isOpen: boolean;
currentProvider: ProviderState | null; currentProvider: ProviderDetails | null;
modalProps: any; modalProps: any;
openModal: (provider: ProviderState, additionalProps: any) => void; openModal: (provider: ProviderDetails, additionalProps: any) => void;
closeModal: () => void; closeModal: () => void;
} }
const ProviderModalContext = createContext({ const defaultContext: ProviderModalContextType = {
isOpen: false, isOpen: false,
currentProvider: null, currentProvider: null,
modalProps: {}, modalProps: {},
openModal: (provider, additionalProps) => {}, openModal: () => {},
closeModal: () => {}, closeModal: () => {},
}); };
const ProviderModalContext = createContext<ProviderModalContextType>(defaultContext);
export const useProviderModal = () => useContext<ProviderModalContextType>(ProviderModalContext); export const useProviderModal = () => useContext<ProviderModalContextType>(ProviderModalContext);
export const ProviderModalProvider = ({ children }) => { export const ProviderModalProvider: React.FC<{ children: React.ReactNode }> = ({ children }) => {
const [isOpen, setIsOpen] = useState(false); const [isOpen, setIsOpen] = useState(false);
const [currentProvider, setCurrentProvider] = useState(null); const [currentProvider, setCurrentProvider] = useState<ProviderDetails | null>(null);
const [modalProps, setModalProps] = useState({}); const [modalProps, setModalProps] = useState({});
const openModal = (provider, additionalProps = {}) => { // Use useCallback to prevent function recreation on each render
const openModal = useCallback((provider: ProviderDetails, additionalProps = {}) => {
setCurrentProvider(provider); setCurrentProvider(provider);
setModalProps(additionalProps); setModalProps(additionalProps);
setIsOpen(true); setIsOpen(true);
}; }, []);
const closeModal = () => { const closeModal = useCallback(() => {
setIsOpen(false); setIsOpen(false);
// Use a small timeout to prevent UI flicker // Use a small timeout to prevent UI flicker
setTimeout(() => { setTimeout(() => {
setCurrentProvider(null); setCurrentProvider(null);
setModalProps({}); setModalProps({});
}, 200); }, 200);
}; }, []);
// Memoize the context value to prevent unnecessary re-renders
const contextValue = useMemo(
() => ({
isOpen,
currentProvider,
modalProps,
openModal,
closeModal,
}),
[isOpen, currentProvider, modalProps, openModal, closeModal]
);
return ( return (
<ProviderModalContext.Provider <ProviderModalContext.Provider value={contextValue}>{children}</ProviderModalContext.Provider>
value={{
isOpen,
currentProvider,
modalProps,
openModal,
closeModal,
}}
>
{children}
</ProviderModalContext.Provider>
); );
}; };

View File

@@ -20,7 +20,11 @@ const providerLogos = {
default: DefaultLogo, default: DefaultLogo,
}; };
export default function ProviderLogo({ providerName }) { interface ProviderLogoProps {
providerName: string;
}
export default function ProviderLogo({ providerName }: ProviderLogoProps) {
// Convert provider name to lowercase and fetch the logo // Convert provider name to lowercase and fetch the logo
const logoKey = providerName.toLowerCase(); const logoKey = providerName.toLowerCase();
const logo = providerLogos[logoKey] || DefaultLogo; const logo = providerLogos[logoKey] || DefaultLogo;

View File

@@ -1,16 +1,29 @@
import React, { useEffect } from 'react'; import React, { useEffect, useMemo } from 'react';
import { Input } from '../../../../../ui/input'; import { Input } from '../../../../../ui/input';
import { PROVIDER_REGISTRY } from '../../../ProviderRegistry';
export default function DefaultProviderSetupForm({ configValues, setConfigValues, provider }) { interface DefaultProviderSetupFormProps {
const providerEntry = PROVIDER_REGISTRY.find((p) => p.name === provider.name); configValues: Record<string, any>;
const parameters = providerEntry?.details?.parameters || []; setConfigValues: React.Dispatch<React.SetStateAction<Record<string, any>>>;
provider: any;
}
export default function DefaultProviderSetupForm({
configValues,
setConfigValues,
provider,
}: DefaultProviderSetupFormProps) {
const parameters = provider.metadata.config_keys || [];
// Initialize default values when the component mounts or provider changes // Initialize default values when the component mounts or provider changes
useEffect(() => { useEffect(() => {
const defaultValues = {}; const defaultValues = {};
parameters.forEach((parameter) => { parameters.forEach((parameter) => {
if (parameter.default !== undefined && !configValues[parameter.name]) { if (
parameter.required &&
parameter.default !== undefined &&
parameter.default !== null &&
!configValues[parameter.name]
) {
defaultValues[parameter.name] = parameter.default; defaultValues[parameter.name] = parameter.default;
} }
}); });
@@ -24,25 +37,51 @@ export default function DefaultProviderSetupForm({ configValues, setConfigValues
} }
}, [provider.name, parameters, setConfigValues, configValues]); }, [provider.name, parameters, setConfigValues, configValues]);
// Filter parameters to only show required ones
const requiredParameters = useMemo(() => {
return parameters.filter((param) => param.required === true);
}, [parameters]);
// Helper function to generate appropriate placeholder text
const getPlaceholder = (parameter) => {
// If default is defined and not null, show it
if (parameter.default !== undefined && parameter.default !== null) {
return `Default: ${parameter.default}`;
}
// Otherwise, use the parameter name as a hint
return parameter.name.toUpperCase();
};
return ( return (
<div className="mt-4 space-y-4"> <div className="mt-4 space-y-4">
{parameters.map((parameter) => ( {requiredParameters.length === 0 ? (
<div key={parameter.name}> <div className="text-center text-gray-500">
<Input No required configuration for this provider.
type={parameter.is_secret ? 'password' : 'text'}
value={configValues[parameter.name] || ''}
onChange={(e) =>
setConfigValues((prev) => ({
...prev,
[parameter.name]: e.target.value,
}))
}
placeholder={parameter.name}
className="w-full h-14 px-4 font-regular rounded-lg border shadow-none border-gray-300 bg-white text-lg placeholder:text-gray-400 font-regular text-gray-900"
required
/>
</div> </div>
))} ) : (
requiredParameters.map((parameter) => (
<div key={parameter.name}>
<label className="block text-sm font-medium text-gray-700 mb-1">
{parameter.name}
<span className="text-red-500 ml-1">*</span>
</label>
<Input
type={parameter.secret ? 'password' : 'text'}
value={configValues[parameter.name] || ''}
onChange={(e) =>
setConfigValues((prev) => ({
...prev,
[parameter.name]: e.target.value,
}))
}
placeholder={getPlaceholder(parameter)}
className="w-full h-14 px-4 font-regular rounded-lg border shadow-none border-gray-300 bg-white text-lg placeholder:text-gray-400 font-regular text-gray-900"
required={true}
/>
</div>
))
)}
</div> </div>
); );
} }

View File

@@ -1,10 +1,6 @@
import React from 'react'; import React, { memo } from 'react';
import { ExclamationButton, GreenCheckButton } from './buttons/CardButtons'; import { GreenCheckButton } from './buttons/CardButtons';
import { import { ConfiguredProviderTooltipMessage, ProviderDescription } from './utils/StringUtils';
ConfiguredProviderTooltipMessage,
OllamaNotConfiguredTooltipMessage,
ProviderDescription,
} from './utils/StringUtils';
interface CardHeaderProps { interface CardHeaderProps {
name: string; name: string;
@@ -13,9 +9,10 @@ interface CardHeaderProps {
} }
// Make CardTitle a proper React component // Make CardTitle a proper React component
function CardTitle({ name }: { name: string }) { const CardTitle = memo(({ name }: { name: string }) => {
return <h3 className="text-base font-medium text-textStandard truncate mr-2">{name}</h3>; return <h3 className="text-base font-medium text-textStandard truncate mr-2">{name}</h3>;
} });
CardTitle.displayName = 'CardTitle';
// Properly type ProviderNameAndStatus props // Properly type ProviderNameAndStatus props
interface ProviderNameAndStatusProps { interface ProviderNameAndStatusProps {
@@ -23,9 +20,8 @@ interface ProviderNameAndStatusProps {
isConfigured: boolean; isConfigured: boolean;
} }
function ProviderNameAndStatus({ name, isConfigured }: ProviderNameAndStatusProps) { const ProviderNameAndStatus = memo(({ name, isConfigured }: ProviderNameAndStatusProps) => {
console.log(`Provider Name: ${name}, Is Configured: ${isConfigured}`); // Remove the console.log completely
return ( return (
<div className="flex items-center justify-between w-full"> <div className="flex items-center justify-between w-full">
<CardTitle name={name} /> <CardTitle name={name} />
@@ -34,14 +30,18 @@ function ProviderNameAndStatus({ name, isConfigured }: ProviderNameAndStatusProp
{isConfigured && <GreenCheckButton tooltip={ConfiguredProviderTooltipMessage(name)} />} {isConfigured && <GreenCheckButton tooltip={ConfiguredProviderTooltipMessage(name)} />}
</div> </div>
); );
} });
ProviderNameAndStatus.displayName = 'ProviderNameAndStatus';
// Add a container div to the CardHeader // Add a container div to the CardHeader
export default function CardHeader({ name, description, isConfigured }: CardHeaderProps) { const CardHeader = memo(function CardHeader({ name, description, isConfigured }: CardHeaderProps) {
return ( return (
<> <>
<ProviderNameAndStatus name={name} isConfigured={isConfigured} /> <ProviderNameAndStatus name={name} isConfigured={isConfigured} />
<ProviderDescription description={description} /> <ProviderDescription description={description} />
</> </>
); );
} });
CardHeader.displayName = 'CardHeader';
export default CardHeader;

View File

@@ -1,74 +1,42 @@
import React from 'react'; import React, { memo, useMemo } from 'react';
import CardContainer from './CardContainer'; import CardContainer from './CardContainer';
import CardHeader from './CardHeader'; import CardHeader from './CardHeader';
import ProviderState from '../interfaces/ProviderState';
import CardBody from './CardBody'; import CardBody from './CardBody';
import { PROVIDER_REGISTRY } from '../ProviderRegistry';
import DefaultCardButtons from './buttons/DefaultCardButtons'; import DefaultCardButtons from './buttons/DefaultCardButtons';
import { ProviderDetails, ProviderMetadata } from '../../../../api';
type ProviderCardProps = { type ProviderCardProps = {
provider: ProviderState; provider: ProviderDetails;
onConfigure: () => void; onConfigure: () => void;
onLaunch: () => void; onLaunch: () => void;
isOnboarding: boolean; isOnboarding: boolean;
}; };
// export function ProviderCard({ provider, buttonCallbacks, isOnboarding }: ProviderCardProps) { export const ProviderCard = memo(function ProviderCard({
// const providerEntry = PROVIDER_REGISTRY.find((p) => p.name === provider.name); provider,
// onConfigure,
// // Add safety check onLaunch,
// if (!providerEntry) { isOnboarding,
// console.error(`Provider ${provider.name} not found in registry`); }: ProviderCardProps) {
// return null; // Safely access metadata with null checks
// } const providerMetadata: ProviderMetadata | null = provider?.metadata || null;
//
// const providerDetails = providerEntry.details;
// // Add another safety check
// if (!providerDetails) {
// console.error(`Provider ${provider.name} has no details`);
// return null;
// }
// console.log('provider details', providerDetails);
//
// try {
// const actions = providerDetails.getActions(provider, buttonCallbacks, isOnboarding);
//
// return (
// <CardContainer
// header={
// <CardHeader
// name={providerDetails.name}
// description={providerDetails.description}
// isConfigured={provider.isConfigured}
// />
// }
// body={<CardBody actions={actions} />}
// />
// );
// } catch (error) {
// console.error(`Error rendering provider card for ${provider.name}:`, error);
// return null;
// }
// }
export function ProviderCard({ provider, onConfigure, onLaunch, isOnboarding }: ProviderCardProps) { // Instead of useEffect for logging, use useMemo to memoize the metadata
const providerEntry = PROVIDER_REGISTRY.find((p) => p.name === provider.name); const metadata = useMemo(() => providerMetadata, [provider]);
// Add safety check // Remove the logging completely
if (!providerEntry?.details) {
console.error(`Provider ${provider.name} not found in registry or has no details`); if (!metadata) {
return null; return <div>ProviderCard error: No metadata provided</div>;
} }
const providerDetails = providerEntry.details;
return ( return (
<CardContainer <CardContainer
header={ header={
<CardHeader <CardHeader
name={providerDetails.name} name={metadata.display_name || provider?.name || 'Unknown Provider'}
description={providerDetails.description} description={metadata.description || ''}
isConfigured={provider.isConfigured} isConfigured={provider?.is_configured || false}
/> />
} }
body={ body={
@@ -83,4 +51,4 @@ export function ProviderCard({ provider, onConfigure, onLaunch, isOnboarding }:
} }
/> />
); );
} });

View File

@@ -1,13 +1,13 @@
import React from 'react'; import React from 'react';
import { ConfigureSettingsButton, RocketButton } from './CardButtons'; import { ConfigureSettingsButton, RocketButton } from './CardButtons';
import ProviderState from '@/src/components/settings_v2/providers/interfaces/ProviderState'; import { ProviderDetails } from '../../../../../api';
// can define other optional callbacks as needed // can define other optional callbacks as needed
interface CardButtonsProps { interface CardButtonsProps {
provider: ProviderState; provider: ProviderDetails;
isOnboardingPage: boolean; isOnboardingPage: boolean;
onConfigure: (provider: ProviderState) => void; onConfigure: (provider: ProviderDetails) => void;
onLaunch: (provider: ProviderState) => void; onLaunch: (provider: ProviderDetails) => void;
} }
function getDefaultTooltipMessages(name: string, actionType: string) { function getDefaultTooltipMessages(name: string, actionType: string) {
@@ -32,7 +32,7 @@ export default function DefaultCardButtons({
return ( return (
<> <>
{/*Set up an unconfigured provider */} {/*Set up an unconfigured provider */}
{!provider.isConfigured && ( {!provider.is_configured && (
<ConfigureSettingsButton <ConfigureSettingsButton
tooltip={getDefaultTooltipMessages(provider.name, 'add')} tooltip={getDefaultTooltipMessages(provider.name, 'add')}
onClick={(e) => { onClick={(e) => {
@@ -42,7 +42,7 @@ export default function DefaultCardButtons({
/> />
)} )}
{/*show edit tooltip instead when hovering over button for configured providers*/} {/*show edit tooltip instead when hovering over button for configured providers*/}
{provider.isConfigured && !isOnboardingPage && ( {provider.is_configured && !isOnboardingPage && (
<ConfigureSettingsButton <ConfigureSettingsButton
tooltip={getDefaultTooltipMessages(provider.name, 'edit')} tooltip={getDefaultTooltipMessages(provider.name, 'edit')}
onClick={(e) => { onClick={(e) => {
@@ -52,7 +52,7 @@ export default function DefaultCardButtons({
/> />
)} )}
{/*show Launch button for configured providers on onboarding page*/} {/*show Launch button for configured providers on onboarding page*/}
{provider.isConfigured && isOnboardingPage && ( {provider.is_configured && isOnboardingPage && (
<RocketButton <RocketButton
onClick={(e) => { onClick={(e) => {
e.stopPropagation(); e.stopPropagation();

View File

@@ -14,30 +14,5 @@ export function snakeToTitleCase(snake: string): string {
export function patchConsoleLogging() { export function patchConsoleLogging() {
// Intercept console methods // Intercept console methods
const originalConsole = { return;
log: console.log,
error: console.error,
warn: console.warn,
info: console.info,
};
console.log = (...args: any[]) => {
window.electron.logInfo(`[LOG] ${args.join(' ')}`);
originalConsole.log(...args);
};
console.error = (...args: any[]) => {
window.electron.logInfo(`[ERROR] ${args.join(' ')}`);
originalConsole.error(...args);
};
console.warn = (...args: any[]) => {
window.electron.logInfo(`[WARN] ${args.join(' ')}`);
originalConsole.warn(...args);
};
console.info = (...args: any[]) => {
window.electron.logInfo(`[INFO] ${args.join(' ')}`);
originalConsole.info(...args);
};
} }