From 31567fa22eb78fd9feeaa51631ab92556b9b7f93 Mon Sep 17 00:00:00 2001 From: Aljaz Date: Wed, 28 May 2025 19:25:41 +0200 Subject: [PATCH] Initialize extensions via agent --- crates/goose-api/src/main.rs | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/crates/goose-api/src/main.rs b/crates/goose-api/src/main.rs index 4edae16c..e5c5ed5a 100644 --- a/crates/goose-api/src/main.rs +++ b/crates/goose-api/src/main.rs @@ -2,8 +2,8 @@ use warp::{Filter, Rejection}; use warp::http::HeaderValue; use serde::{Deserialize, Serialize}; use std::sync::LazyLock; -use goose::agents::{Agent, extension_manager::ExtensionManager}; -use goose::config::Config; +use goose::agents::{Agent, ExtensionConfig, extension_manager::ExtensionManager}; +use goose::config::{Config, ExtensionEntry}; use goose::providers::{create, providers}; use goose::model::ModelConfig; use goose::message::Message; @@ -256,17 +256,20 @@ async fn initialize_provider_config() -> Result<(), anyhow::Error> { Ok(()) } /// Initialize extensions from the configuration. -fn initialize_extensions(config: &config::Config) -> Result<(), anyhow::Error> { +async fn initialize_extensions(config: &config::Config) -> Result<(), anyhow::Error> { if let Ok(ext_table) = config.get_table("extensions") { for (name, ext_config) in ext_table { - let json_value: serde_json::Value = ext_config.clone().try_deserialize() + // Deserialize into ExtensionEntry to get enabled flag and config + let entry: ExtensionEntry = ext_config.clone().try_deserialize() .map_err(|e| anyhow::anyhow!("Failed to deserialize extension config for {}: {}", name, e))?; - // Only process the extension if it is enabled. - let enabled = json_value.get("enabled").and_then(|v| v.as_bool()).unwrap_or(false); - if enabled { - // Note: The ExtensionManager does not provide a method to register extensions. - // Here, we log that the extension is enabled. Adjust this code if a registration API becomes available. - info!("Extension {} is enabled and would be registered", name); + + if entry.enabled { + let extension_config: ExtensionConfig = entry.config; + // Acquire the global agent lock and try to add the extension + let mut agent = AGENT.lock().await; + if let Err(e) = agent.add_extension(extension_config).await { + error!("Failed to add extension {}: {}", name, e); + } } else { info!("Skipping disabled extension: {}", name); } @@ -315,7 +318,7 @@ async fn main() -> Result<(), anyhow::Error> { } // Initialize extensions from configuration - if let Err(e) = initialize_extensions(&api_config) { + if let Err(e) = initialize_extensions(&api_config).await { error!("Failed to initialize extensions: {}", e); }