diff --git a/crates/goose-cli/src/commands/configure.rs b/crates/goose-cli/src/commands/configure.rs index 687d018c..ec6ca71d 100644 --- a/crates/goose-cli/src/commands/configure.rs +++ b/crates/goose-cli/src/commands/configure.rs @@ -983,11 +983,11 @@ pub async fn configure_tool_permissions_dialog() -> Result<(), Box> { .get_param("GOOSE_MODEL") .expect("No model configured. Please set model first"); let model_config = goose::model::ModelConfig::new(model.clone()); - let provider = - goose::providers::create(&provider_name, model_config).expect("Failed to create provider"); // Create the agent - let mut agent = Agent::new(provider); + let agent = Agent::new(); + let new_provider = create(&provider_name, model_config)?; + agent.update_provider(new_provider).await?; if let Ok(Some(config)) = ExtensionConfigManager::get_config_by_name(&selected_extension_name) { agent .add_extension(config.clone()) diff --git a/crates/goose-cli/src/session/builder.rs b/crates/goose-cli/src/session/builder.rs index 2e8769d3..9c9092b0 100644 --- a/crates/goose-cli/src/session/builder.rs +++ b/crates/goose-cli/src/session/builder.rs @@ -2,6 +2,7 @@ use console::style; use goose::agents::extension::ExtensionError; use goose::agents::Agent; use goose::config::{Config, ExtensionConfig, ExtensionConfigManager}; +use goose::providers::create; use goose::session; use goose::session::Identifier; use mcp_client::transport::Error as McpClientError; @@ -46,11 +47,11 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> Session { .get_param("GOOSE_MODEL") .expect("No model configured. Run 'goose configure' first"); let model_config = goose::model::ModelConfig::new(model.clone()); - let provider = - goose::providers::create(&provider_name, model_config).expect("Failed to create provider"); // Create the agent - let mut agent = Agent::new(provider); + let agent: Agent = Agent::new(); + let new_provider = create(&provider_name, model_config).unwrap(); + let _ = agent.update_provider(new_provider).await; // Handle session file resolution and resuming let session_file = if session_config.resume { diff --git a/crates/goose-cli/src/session/mod.rs b/crates/goose-cli/src/session/mod.rs index 0dfc0899..368274ce 100644 --- a/crates/goose-cli/src/session/mod.rs +++ b/crates/goose-cli/src/session/mod.rs @@ -285,7 +285,7 @@ impl Session { async fn process_message(&mut self, message: String) -> Result<()> { self.messages.push(Message::user().with_text(&message)); // Get the provider from the agent for description generation - let provider = self.agent.provider(); + let provider = self.agent.provider().await?; // Persist messages with provider for automatic description generation session::persist_messages(&self.session_file, &self.messages, Some(provider)).await?; @@ -357,7 +357,7 @@ impl Session { self.messages.push(Message::user().with_text(&content)); // Get the provider from the agent for description generation - let provider = self.agent.provider(); + let provider = self.agent.provider().await?; // Persist messages with provider for automatic description generation session::persist_messages( @@ -526,7 +526,7 @@ impl Session { output::render_message(&plan_response, self.debug); output::hide_thinking(); let planner_response_type = - classify_planner_response(plan_response.as_concat_text(), self.agent.provider()) + classify_planner_response(plan_response.as_concat_text(), self.agent.provider().await?) .await?; match planner_response_type { diff --git a/crates/goose-ffi/src/lib.rs b/crates/goose-ffi/src/lib.rs index 06a9db5a..bd2237d7 100644 --- a/crates/goose-ffi/src/lib.rs +++ b/crates/goose-ffi/src/lib.rs @@ -178,7 +178,10 @@ pub unsafe extern "C" fn goose_agent_new(config: *const ProviderConfigFFI) -> Ag // Create Databricks provider with required parameters match DatabricksProvider::from_params(host, api_key, model_config) { Ok(provider) => { - let agent = Agent::new(Arc::new(provider)); + let agent = Agent::new(); + get_runtime().block_on(async { + let _ = agent.update_provider(Arc::new(provider)).await; + }); Box::into_raw(Box::new(agent)) } Err(e) => { diff --git a/crates/goose-server/src/commands/agent.rs b/crates/goose-server/src/commands/agent.rs index 32787623..bf0a2897 100644 --- a/crates/goose-server/src/commands/agent.rs +++ b/crates/goose-server/src/commands/agent.rs @@ -1,6 +1,9 @@ +use std::sync::Arc; + use crate::configuration; use crate::state; use anyhow::Result; +use goose::agents::Agent; use tower_http::cors::{Any, CorsLayer}; use tracing::info; @@ -15,8 +18,10 @@ pub async fn run() -> Result<()> { let secret_key = std::env::var("GOOSE_SERVER__SECRET_KEY").unwrap_or_else(|_| "test".to_string()); - // Create app state - agent will start as None - let state = state::AppState::new(secret_key.clone()).await?; + let new_agent = Agent::new(); + + // Create app state with agent + let state = state::AppState::new(Arc::new(new_agent), secret_key.clone()).await; // Create router with CORS support let cors = CorsLayer::new() diff --git a/crates/goose-server/src/routes/agent.rs b/crates/goose-server/src/routes/agent.rs index 03696fcc..354c22c6 100644 --- a/crates/goose-server/src/routes/agent.rs +++ b/crates/goose-server/src/routes/agent.rs @@ -8,14 +8,15 @@ use axum::{ }; use goose::config::Config; use goose::config::PermissionManager; -use goose::{agents::Agent, model::ModelConfig, providers}; +use goose::model::ModelConfig; +use goose::providers::create; use goose::{ agents::{extension::ToolInfo, extension_manager::get_parameter_names}, config::permission::PermissionLevel, }; use serde::{Deserialize, Serialize}; use std::collections::HashMap; -use std::env; +use std::sync::Arc; #[derive(Serialize)] struct VersionsResponse { @@ -33,17 +34,6 @@ struct ExtendPromptResponse { success: bool, } -#[derive(Deserialize)] -struct CreateAgentRequest { - provider: String, - model: Option, -} - -#[derive(Serialize)] -struct CreateAgentResponse { - version: String, -} - #[derive(Deserialize)] struct ProviderFile { name: String, @@ -66,6 +56,12 @@ struct ProviderList { details: ProviderDetails, } +#[derive(Deserialize)] +struct UpdateProviderRequest { + provider: String, + model: Option, +} + #[derive(Deserialize)] pub struct GetToolsQuery { extension_name: Option, @@ -82,53 +78,18 @@ async fn get_versions() -> Json { } async fn extend_prompt( - State(state): State, + State(state): State>, headers: HeaderMap, Json(payload): Json, ) -> Result, StatusCode> { verify_secret_key(&headers, &state)?; - let mut agent = state.agent.write().await; - if let Some(ref mut agent) = *agent { - agent.extend_system_prompt(payload.extension).await; - Ok(Json(ExtendPromptResponse { success: true })) - } else { - Err(StatusCode::NOT_FOUND) - } -} - -#[axum::debug_handler] -async fn create_agent( - State(state): State, - headers: HeaderMap, - Json(payload): Json, -) -> Result, StatusCode> { - verify_secret_key(&headers, &state)?; - - // Set the environment variable for the model if provided - if let Some(model) = &payload.model { - let env_var_key = format!("{}_MODEL", payload.provider.to_uppercase()); - env::set_var(env_var_key.clone(), model); - println!("Set environment variable: {}={}", env_var_key, model); - } - - let config = Config::global(); - let model = payload.model.unwrap_or_else(|| { - config - .get_param("GOOSE_MODEL") - .expect("Did not find a model on payload or in env") - }); - let model_config = ModelConfig::new(model); - let provider = - providers::create(&payload.provider, model_config).expect("Failed to create provider"); - - let version = String::from("goose"); - let new_agent = Agent::new(provider); - - let mut agent = state.agent.write().await; - *agent = Some(new_agent); - - Ok(Json(CreateAgentResponse { version })) + let agent = state + .get_agent() + .await + .map_err(|_| StatusCode::PRECONDITION_FAILED)?; + agent.extend_system_prompt(payload.extension.clone()).await; + Ok(Json(ExtendPromptResponse { success: true })) } async fn list_providers() -> Json> { @@ -168,7 +129,7 @@ async fn list_providers() -> Json> { ) )] async fn get_tools( - State(state): State, + State(state): State>, headers: HeaderMap, Query(query): Query, ) -> Result>, StatusCode> { @@ -176,8 +137,10 @@ async fn get_tools( let config = Config::global(); let goose_mode = config.get_param("GOOSE_MODE").unwrap_or("auto".to_string()); - let agent = state.agent.read().await; - let agent = agent.as_ref().ok_or(StatusCode::PRECONDITION_REQUIRED)?; + let agent = state + .get_agent() + .await + .map_err(|_| StatusCode::PRECONDITION_FAILED)?; let permission_manager = PermissionManager::default(); let mut tools: Vec = agent @@ -210,12 +173,56 @@ async fn get_tools( Ok(Json(tools)) } -pub fn routes(state: AppState) -> Router { +#[utoipa::path( + post, + path = "/agent/update_provider", + responses( + (status = 200, description = "Update provider completed", body = String), + (status = 500, description = "Internal server error") + ) +)] +async fn update_agent_provider( + State(state): State>, + headers: HeaderMap, + Json(payload): Json, +) -> Result { + // Verify secret key + let secret_key = headers + .get("X-Secret-Key") + .and_then(|value| value.to_str().ok()) + .ok_or(StatusCode::UNAUTHORIZED)?; + + if secret_key != state.secret_key { + return Err(StatusCode::UNAUTHORIZED); + } + + let agent = state + .get_agent() + .await + .map_err(|_| StatusCode::PRECONDITION_FAILED)?; + + let config = Config::global(); + let model = payload.model.unwrap_or_else(|| { + config + .get_param("GOOSE_MODEL") + .expect("Did not find a model on payload or in env to update provider with") + }); + let model_config = ModelConfig::new(model); + let new_provider = create(&payload.provider, model_config).unwrap(); + agent + .update_provider(new_provider) + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + + Ok(StatusCode::OK) +} + +pub fn routes(state: Arc) -> Router { Router::new() .route("/agent/versions", get(get_versions)) .route("/agent/providers", get(list_providers)) .route("/agent/prompt", post(extend_prompt)) .route("/agent/tools", get(get_tools)) - .route("/agent", post(create_agent)) + .route("/agent/update_provider", post(update_agent_provider)) .with_state(state) } diff --git a/crates/goose-server/src/routes/config_management.rs b/crates/goose-server/src/routes/config_management.rs index 87c953e7..14a5b08e 100644 --- a/crates/goose-server/src/routes/config_management.rs +++ b/crates/goose-server/src/routes/config_management.rs @@ -18,7 +18,7 @@ use once_cell::sync::Lazy; use serde::{Deserialize, Serialize}; use serde_json::Value; use serde_yaml; -use std::collections::HashMap; +use std::{collections::HashMap, sync::Arc}; use utoipa::ToSchema; #[derive(Serialize, ToSchema)] @@ -89,7 +89,7 @@ pub struct UpsertPermissionsQuery { ) )] pub async fn upsert_config( - State(state): State, + State(state): State>, headers: HeaderMap, Json(query): Json, ) -> Result, StatusCode> { @@ -116,7 +116,7 @@ pub async fn upsert_config( ) )] pub async fn remove_config( - State(state): State, + State(state): State>, headers: HeaderMap, Json(query): Json, ) -> Result, StatusCode> { @@ -148,7 +148,7 @@ pub async fn remove_config( ) )] pub async fn read_config( - State(state): State, + State(state): State>, headers: HeaderMap, Json(query): Json, ) -> Result, StatusCode> { @@ -180,7 +180,7 @@ pub async fn read_config( ) )] pub async fn get_extensions( - State(state): State, + State(state): State>, headers: HeaderMap, ) -> Result, StatusCode> { verify_secret_key(&headers, &state)?; @@ -213,7 +213,7 @@ pub async fn get_extensions( ) )] pub async fn add_extension( - State(state): State, + State(state): State>, headers: HeaderMap, Json(extension_query): Json, ) -> Result, StatusCode> { @@ -251,7 +251,7 @@ pub async fn add_extension( ) )] pub async fn remove_extension( - State(state): State, + State(state): State>, headers: HeaderMap, axum::extract::Path(name): axum::extract::Path, ) -> Result, StatusCode> { @@ -272,7 +272,7 @@ pub async fn remove_extension( ) )] pub async fn read_all_config( - State(state): State, + State(state): State>, headers: HeaderMap, ) -> Result, StatusCode> { // Use the helper function to verify the secret key @@ -297,7 +297,7 @@ pub async fn read_all_config( ) )] pub async fn providers( - State(state): State, + State(state): State>, headers: HeaderMap, ) -> Result>, StatusCode> { verify_secret_key(&headers, &state)?; @@ -332,7 +332,7 @@ pub async fn providers( ) )] pub async fn init_config( - State(state): State, + State(state): State>, headers: HeaderMap, ) -> Result, StatusCode> { verify_secret_key(&headers, &state)?; @@ -402,7 +402,7 @@ pub async fn init_config( ) )] pub async fn upsert_permissions( - State(state): State, + State(state): State>, headers: HeaderMap, Json(query): Json, ) -> Result, StatusCode> { @@ -435,7 +435,7 @@ pub static APP_STRATEGY: Lazy = Lazy::new(|| AppStrategyArgs { ) )] pub async fn backup_config( - State(state): State, + State(state): State>, headers: HeaderMap, ) -> Result, StatusCode> { verify_secret_key(&headers, &state)?; @@ -466,7 +466,7 @@ pub async fn backup_config( } } -pub fn routes(state: AppState) -> Router { +pub fn routes(state: Arc) -> Router { Router::new() .route("/config", get(read_all_config)) .route("/config/upsert", post(upsert_config)) diff --git a/crates/goose-server/src/routes/configs.rs b/crates/goose-server/src/routes/configs.rs index ee918009..b2513b2e 100644 --- a/crates/goose-server/src/routes/configs.rs +++ b/crates/goose-server/src/routes/configs.rs @@ -10,7 +10,7 @@ use http::{HeaderMap, StatusCode}; use once_cell::sync::Lazy; use serde::{Deserialize, Serialize}; use serde_json::Value; -use std::collections::HashMap; +use std::{collections::HashMap, sync::Arc}; #[derive(Serialize)] struct ConfigResponse { @@ -26,7 +26,7 @@ struct ConfigRequest { } async fn store_config( - State(state): State, + State(state): State>, headers: HeaderMap, Json(request): Json, ) -> Result, StatusCode> { @@ -148,7 +148,7 @@ pub struct GetConfigResponse { } pub async fn get_config( - State(state): State, + State(state): State>, headers: HeaderMap, Query(query): Query, ) -> Result, StatusCode> { @@ -174,7 +174,7 @@ struct DeleteConfigRequest { } async fn delete_config( - State(state): State, + State(state): State>, headers: HeaderMap, Json(request): Json, ) -> Result { @@ -193,7 +193,7 @@ async fn delete_config( } } -pub fn routes(state: AppState) -> Router { +pub fn routes(state: Arc) -> Router { Router::new() .route("/configs/providers", post(check_provider_configs)) .route("/configs/get", get(get_config)) diff --git a/crates/goose-server/src/routes/context.rs b/crates/goose-server/src/routes/context.rs index 3415ac75..46a526d0 100644 --- a/crates/goose-server/src/routes/context.rs +++ b/crates/goose-server/src/routes/context.rs @@ -8,6 +8,7 @@ use axum::{ }; use goose::message::Message; use serde::{Deserialize, Serialize}; +use std::sync::Arc; // Direct message serialization for context mgmt request #[derive(Debug, Deserialize)] @@ -26,15 +27,16 @@ pub struct ContextManageResponse { } async fn manage_context( - State(state): State, + State(state): State>, headers: HeaderMap, Json(request): Json, ) -> Result, StatusCode> { verify_secret_key(&headers, &state)?; - // Get a lock on the shared agent - let agent = state.agent.read().await; - let agent = agent.as_ref().ok_or(StatusCode::PRECONDITION_REQUIRED)?; + let agent = state + .get_agent() + .await + .map_err(|_| StatusCode::PRECONDITION_FAILED)?; let mut processed_messages: Vec = vec![]; let mut token_counts: Vec = vec![]; @@ -57,7 +59,7 @@ async fn manage_context( } // Configure routes for this module -pub fn routes(state: AppState) -> Router { +pub fn routes(state: Arc) -> Router { Router::new() .route("/context/manage", post(manage_context)) .with_state(state) diff --git a/crates/goose-server/src/routes/extension.rs b/crates/goose-server/src/routes/extension.rs index cc12e4e9..b45466b6 100644 --- a/crates/goose-server/src/routes/extension.rs +++ b/crates/goose-server/src/routes/extension.rs @@ -1,5 +1,6 @@ use std::env; use std::path::Path; +use std::sync::Arc; use std::sync::OnceLock; use super::utils::verify_secret_key; @@ -79,7 +80,7 @@ struct ExtensionResponse { /// Handler for adding a new extension configuration. async fn add_extension( - State(state): State, + State(state): State>, headers: HeaderMap, raw: axum::extract::Json, ) -> Result, StatusCode> { @@ -228,9 +229,11 @@ async fn add_extension( }, }; - // Acquire a lock on the agent and attempt to add the extension. - let mut agent = state.agent.write().await; - let agent = agent.as_mut().ok_or(StatusCode::PRECONDITION_REQUIRED)?; + // Get a reference to the agent + let agent = state + .get_agent() + .await + .map_err(|_| StatusCode::PRECONDITION_FAILED)?; let response = agent.add_extension(extension_config).await; // Respond with the result. @@ -254,15 +257,17 @@ async fn add_extension( /// Handler for removing an extension by name async fn remove_extension( - State(state): State, + State(state): State>, headers: HeaderMap, Json(name): Json, ) -> Result, StatusCode> { verify_secret_key(&headers, &state)?; - // Acquire a lock on the agent and attempt to remove the extension - let mut agent = state.agent.write().await; - let agent = agent.as_mut().ok_or(StatusCode::PRECONDITION_REQUIRED)?; + // Get a reference to the agent + let agent = state + .get_agent() + .await + .map_err(|_| StatusCode::PRECONDITION_FAILED)?; agent.remove_extension(&name).await; Ok(Json(ExtensionResponse { @@ -272,7 +277,7 @@ async fn remove_extension( } /// Registers the extension management routes with the Axum router. -pub fn routes(state: AppState) -> Router { +pub fn routes(state: Arc) -> Router { Router::new() .route("/extensions/add", post(add_extension)) .route("/extensions/remove", post(remove_extension)) diff --git a/crates/goose-server/src/routes/mod.rs b/crates/goose-server/src/routes/mod.rs index 75878447..b561bccd 100644 --- a/crates/goose-server/src/routes/mod.rs +++ b/crates/goose-server/src/routes/mod.rs @@ -9,10 +9,12 @@ pub mod recipe; pub mod reply; pub mod session; pub mod utils; +use std::sync::Arc; + use axum::Router; // Function to configure all routes -pub fn configure(state: crate::state::AppState) -> Router { +pub fn configure(state: Arc) -> Router { Router::new() .merge(health::routes()) .merge(reply::routes(state.clone())) @@ -22,5 +24,5 @@ pub fn configure(state: crate::state::AppState) -> Router { .merge(configs::routes(state.clone())) .merge(config_management::routes(state.clone())) .merge(recipe::routes(state.clone())) - .merge(session::routes(state)) + .merge(session::routes(state.clone())) } diff --git a/crates/goose-server/src/routes/recipe.rs b/crates/goose-server/src/routes/recipe.rs index 6e9598d6..d2f2df7b 100644 --- a/crates/goose-server/src/routes/recipe.rs +++ b/crates/goose-server/src/routes/recipe.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use axum::{extract::State, http::StatusCode, routing::post, Json, Router}; use goose::message::Message; use goose::recipe::Recipe; @@ -34,17 +36,17 @@ pub struct CreateRecipeResponse { /// Create a Recipe configuration from the current state of an agent async fn create_recipe( - State(state): State, + State(state): State>, Json(request): Json, ) -> Result, (StatusCode, Json)> { - let agent = state.agent.read().await; - let agent = agent.as_ref().ok_or_else(|| { - let error_response = CreateRecipeResponse { - recipe: None, - error: Some("Agent not initialized".to_string()), - }; - (StatusCode::PRECONDITION_REQUIRED, Json(error_response)) - })?; + let error_response = CreateRecipeResponse { + recipe: None, + error: Some("Missing agent".to_string()), + }; + let agent = state + .get_agent() + .await + .map_err(|_| (StatusCode::PRECONDITION_FAILED, Json(error_response)))?; // Create base recipe from agent state and messages let recipe_result = agent.create_recipe(request.messages).await; @@ -82,7 +84,7 @@ async fn create_recipe( } } -pub fn routes(state: AppState) -> Router { +pub fn routes(state: Arc) -> Router { Router::new() .route("/recipe/create", post(create_recipe)) .with_state(state) diff --git a/crates/goose-server/src/routes/reply.rs b/crates/goose-server/src/routes/reply.rs index 7781a4fb..764b9065 100644 --- a/crates/goose-server/src/routes/reply.rs +++ b/crates/goose-server/src/routes/reply.rs @@ -26,6 +26,7 @@ use std::{ convert::Infallible, path::PathBuf, pin::Pin, + sync::Arc, task::{Context, Poll}, time::Duration, }; @@ -101,7 +102,7 @@ async fn stream_event( } async fn handler( - State(state): State, + State(state): State>, headers: HeaderMap, Json(request): Json, ) -> Result { @@ -119,15 +120,34 @@ async fn handler( .session_id .unwrap_or_else(session::generate_session_id); - // Get a lock on the shared agent - let agent = state.agent.clone(); - // Spawn task to handle streaming tokio::spawn(async move { - let agent = agent.read().await; - let agent = match agent.as_ref() { - Some(agent) => agent, - None => { + let agent = state.get_agent().await; + let agent = match agent { + Ok(agent) => { + let provider = agent.provider().await; + match provider { + Ok(_) => agent, + Err(_) => { + let _ = stream_event( + MessageEvent::Error { + error: "No provider configured".to_string(), + }, + &tx, + ) + .await; + let _ = stream_event( + MessageEvent::Finish { + reason: "error".to_string(), + }, + &tx, + ) + .await; + return; + } + } + } + Err(_) => { let _ = stream_event( MessageEvent::Error { error: "No agent configured".to_string(), @@ -147,7 +167,7 @@ async fn handler( }; // Get the provider first, before starting the reply stream - let provider = agent.provider(); + let provider = agent.provider().await; let mut stream = match agent .reply( @@ -204,7 +224,7 @@ async fn handler( // Store messages and generate description in background let session_path = session_path.clone(); let messages = all_messages.clone(); - let provider = provider.clone(); + let provider = Arc::clone(provider.as_ref().unwrap()); tokio::spawn(async move { if let Err(e) = session::persist_messages(&session_path, &messages, Some(provider)).await { tracing::error!("Failed to store session history: {:?}", e); @@ -262,7 +282,7 @@ struct AskResponse { // Simple ask an AI for a response, non streaming async fn ask_handler( - State(state): State, + State(state): State>, headers: HeaderMap, Json(request): Json, ) -> Result, StatusCode> { @@ -275,12 +295,13 @@ async fn ask_handler( .session_id .unwrap_or_else(session::generate_session_id); - let agent = state.agent.clone(); - let agent = agent.write().await; - let agent = agent.as_ref().ok_or(StatusCode::NOT_FOUND)?; + let agent = state + .get_agent() + .await + .map_err(|_| StatusCode::PRECONDITION_FAILED)?; // Get the provider first, before starting the reply stream - let provider = agent.provider(); + let provider = agent.provider().await; // Create a single message for the prompt let messages = vec![Message::user().with_text(request.prompt)]; @@ -339,7 +360,7 @@ async fn ask_handler( // Store messages and generate description in background let session_path = session_path.clone(); let messages = all_messages.clone(); - let provider = provider.clone(); + let provider = Arc::clone(provider.as_ref().unwrap()); tokio::spawn(async move { if let Err(e) = session::persist_messages(&session_path, &messages, Some(provider)).await { tracing::error!("Failed to store session history: {:?}", e); @@ -374,15 +395,16 @@ fn default_principal_type() -> PrincipalType { ) )] pub async fn confirm_permission( - State(state): State, + State(state): State>, headers: HeaderMap, Json(request): Json, ) -> Result, StatusCode> { verify_secret_key(&headers, &state)?; - let agent = state.agent.clone(); - let agent = agent.read().await; - let agent = agent.as_ref().ok_or(StatusCode::NOT_FOUND)?; + let agent = state + .get_agent() + .await + .map_err(|_| StatusCode::PRECONDITION_FAILED)?; let permission = match request.action.as_str() { "always_allow" => Permission::AlwaysAllow, @@ -410,7 +432,7 @@ struct ToolResultRequest { } async fn submit_tool_result( - State(state): State, + State(state): State>, headers: HeaderMap, raw: axum::extract::Json, ) -> Result, StatusCode> { @@ -435,14 +457,16 @@ async fn submit_tool_result( } }; - let agent = state.agent.read().await; - let agent = agent.as_ref().ok_or(StatusCode::NOT_FOUND)?; + let agent = state + .get_agent() + .await + .map_err(|_| StatusCode::PRECONDITION_FAILED)?; agent.handle_tool_result(payload.id, payload.result).await; Ok(Json(json!({"status": "ok"}))) } // Configure routes for this module -pub fn routes(state: AppState) -> Router { +pub fn routes(state: Arc) -> Router { Router::new() .route("/reply", post(handler)) .route("/ask", post(ask_handler)) @@ -496,9 +520,7 @@ mod tests { mod integration_tests { use super::*; use axum::{body::Body, http::Request}; - use std::collections::HashMap; use std::sync::Arc; - use tokio::sync::{Mutex, RwLock}; use tower::ServiceExt; // This test requires tokio runtime @@ -509,12 +531,9 @@ mod tests { let mock_provider = Arc::new(MockProvider { model_config: mock_model_config, }); - let agent = Agent::new(mock_provider); - let state = AppState { - config: Arc::new(Mutex::new(HashMap::new())), - agent: Arc::new(RwLock::new(Some(agent))), - secret_key: "test-secret".to_string(), - }; + let agent = Agent::new(); + let _ = agent.update_provider(mock_provider).await; + let state = AppState::new(Arc::new(agent), "test-secret".to_string()).await; // Build router let app = routes(state); diff --git a/crates/goose-server/src/routes/session.rs b/crates/goose-server/src/routes/session.rs index e887e95e..0c69dc5f 100644 --- a/crates/goose-server/src/routes/session.rs +++ b/crates/goose-server/src/routes/session.rs @@ -1,4 +1,6 @@ use super::utils::verify_secret_key; +use std::sync::Arc; + use crate::state::AppState; use axum::{ extract::{Path, State}, @@ -25,7 +27,7 @@ struct SessionHistoryResponse { // List all available sessions async fn list_sessions( - State(state): State, + State(state): State>, headers: HeaderMap, ) -> Result, StatusCode> { verify_secret_key(&headers, &state)?; @@ -38,7 +40,7 @@ async fn list_sessions( // Get a specific session's history async fn get_session_history( - State(state): State, + State(state): State>, headers: HeaderMap, Path(session_id): Path, ) -> Result, StatusCode> { @@ -65,7 +67,7 @@ async fn get_session_history( } // Configure routes for this module -pub fn routes(state: AppState) -> Router { +pub fn routes(state: Arc) -> Router { Router::new() .route("/sessions", get(list_sessions)) .route("/sessions/:session_id", get(get_session_history)) diff --git a/crates/goose-server/src/state.rs b/crates/goose-server/src/state.rs index c98a2711..b7f45f19 100644 --- a/crates/goose-server/src/state.rs +++ b/crates/goose-server/src/state.rs @@ -1,25 +1,34 @@ -use anyhow::Result; use goose::agents::Agent; -use serde_json::Value; -use std::collections::HashMap; use std::sync::Arc; -use tokio::sync::{Mutex, RwLock}; +/// Shared reference to an Agent that can be cloned cheaply +/// without cloning the underlying Agent object +pub type AgentRef = Arc; + +/// Thread-safe container for an optional Agent reference +/// Outer Arc: Allows multiple route handlers to access the same Mutex +/// - Mutex provides exclusive access for updates +/// - Option allows for the case where no agent exists yet +/// /// Shared application state -#[allow(dead_code)] #[derive(Clone)] pub struct AppState { - pub agent: Arc>>, + // agent: SharedAgentStore, + agent: Option, pub secret_key: String, - pub config: Arc>>, } impl AppState { - pub async fn new(secret_key: String) -> Result { - Ok(Self { - agent: Arc::new(RwLock::new(None)), + pub async fn new(agent: AgentRef, secret_key: String) -> Arc { + Arc::new(Self { + agent: Some(agent.clone()), secret_key, - config: Arc::new(Mutex::new(HashMap::new())), }) } + + pub async fn get_agent(&self) -> Result, anyhow::Error> { + self.agent + .clone() + .ok_or_else(|| anyhow::anyhow!("Agent needs to be created first.")) + } } diff --git a/crates/goose/examples/agent.rs b/crates/goose/examples/agent.rs index 75db4742..bc3badac 100644 --- a/crates/goose/examples/agent.rs +++ b/crates/goose/examples/agent.rs @@ -15,7 +15,8 @@ async fn main() { let provider = Arc::new(DatabricksProvider::default()); // Setup an agent with the developer extension - let mut agent = Agent::new(provider); + let agent = Agent::new(); + let _ = agent.update_provider(provider).await; let config = ExtensionConfig::stdio( "developer", diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index 9dc5e2bb..2ae19998 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -35,12 +35,11 @@ use super::tool_execution::{ToolFuture, CHAT_MODE_TOOL_SKIPPED_RESPONSE, DECLINE /// The main goose Agent pub struct Agent { - pub(super) provider: Arc, + pub(super) provider: Mutex>>, pub(super) extension_manager: Mutex, - pub(super) frontend_tools: HashMap, - pub(super) frontend_instructions: Option, - pub(super) prompt_manager: PromptManager, - // Channels for tool results and confirmations + pub(super) frontend_tools: Mutex>, + pub(super) frontend_instructions: Mutex>, + pub(super) prompt_manager: Mutex, pub(super) confirmation_tx: mpsc::Sender<(String, PermissionConfirmation)>, pub(super) confirmation_rx: Mutex>, pub(super) tool_result_tx: mpsc::Sender<(String, ToolResult>)>, @@ -48,41 +47,52 @@ pub struct Agent { } impl Agent { - pub fn new(provider: Arc) -> Self { + pub fn new() -> Self { // Create channels with buffer size 32 (adjust if needed) let (confirm_tx, confirm_rx) = mpsc::channel(32); let (tool_tx, tool_rx) = mpsc::channel(32); Self { - provider, + provider: Mutex::new(None), extension_manager: Mutex::new(ExtensionManager::new()), - frontend_tools: HashMap::new(), - frontend_instructions: None, - prompt_manager: PromptManager::new(), + frontend_tools: Mutex::new(HashMap::new()), + frontend_instructions: Mutex::new(None), + prompt_manager: Mutex::new(PromptManager::new()), confirmation_tx: confirm_tx, confirmation_rx: Mutex::new(confirm_rx), tool_result_tx: tool_tx, tool_result_rx: Arc::new(Mutex::new(tool_rx)), } } +} +impl Default for Agent { + fn default() -> Self { + Self::new() + } +} + +impl Agent { /// Get a reference count clone to the provider - pub fn provider(&self) -> Arc { - Arc::clone(&self.provider) + pub async fn provider(&self) -> Result, anyhow::Error> { + match &*self.provider.lock().await { + Some(provider) => Ok(Arc::clone(provider)), + None => Err(anyhow!("Provider not set")), + } } /// Check if a tool is a frontend tool - pub fn is_frontend_tool(&self, name: &str) -> bool { - self.frontend_tools.contains_key(name) + pub async fn is_frontend_tool(&self, name: &str) -> bool { + self.frontend_tools.lock().await.contains_key(name) } /// Get a reference to a frontend tool - pub fn get_frontend_tool(&self, name: &str) -> Option<&FrontendTool> { - self.frontend_tools.get(name) + pub async fn get_frontend_tool(&self, name: &str) -> Option { + self.frontend_tools.lock().await.get(name).cloned() } /// Get all tools from all clients with proper prefixing - pub async fn get_prefixed_tools(&mut self) -> ExtensionResult> { + pub async fn get_prefixed_tools(&self) -> ExtensionResult> { let mut tools = self .extension_manager .lock() @@ -91,7 +101,8 @@ impl Agent { .await?; // Add frontend tools directly - they don't need prefixing since they're already uniquely named - for frontend_tool in self.frontend_tools.values() { + let frontend_tools = self.frontend_tools.lock().await; + for frontend_tool in frontend_tools.values() { tools.push(frontend_tool.tool.clone()); } @@ -135,7 +146,7 @@ impl Agent { .await } else if tool_call.name == PLATFORM_SEARCH_AVAILABLE_EXTENSIONS_TOOL_NAME { extension_manager.search_available_extensions().await - } else if self.is_frontend_tool(&tool_call.name) { + } else if self.is_frontend_tool(&tool_call.name).await { // For frontend tools, return an error indicating we need frontend execution Err(ToolError::ExecutionError( "Frontend tool execution required".to_string(), @@ -212,7 +223,7 @@ impl Agent { (request_id, result) } - pub async fn add_extension(&mut self, extension: ExtensionConfig) -> ExtensionResult<()> { + pub async fn add_extension(&self, extension: ExtensionConfig) -> ExtensionResult<()> { match &extension { ExtensionConfig::Frontend { name: _, @@ -221,19 +232,21 @@ impl Agent { bundled: _, } => { // For frontend tools, just store them in the frontend_tools map + let mut frontend_tools = self.frontend_tools.lock().await; for tool in tools { let frontend_tool = FrontendTool { name: tool.name.clone(), tool: tool.clone(), }; - self.frontend_tools.insert(tool.name.clone(), frontend_tool); + frontend_tools.insert(tool.name.clone(), frontend_tool); } // Store instructions if provided, using "frontend" as the key + let mut frontend_instructions = self.frontend_instructions.lock().await; if let Some(instructions) = instructions { - self.frontend_instructions = Some(instructions.clone()); + *frontend_instructions = Some(instructions.clone()); } else { // Default frontend instructions if none provided - self.frontend_instructions = Some( + *frontend_instructions = Some( "The following tools are provided directly by the frontend and will be executed by the frontend when called.".to_string(), ); } @@ -269,7 +282,7 @@ impl Agent { prefixed_tools } - pub async fn remove_extension(&mut self, name: &str) { + pub async fn remove_extension(&self, name: &str) { let mut extension_manager = self.extension_manager.lock().await; extension_manager .remove_extension(name) @@ -329,7 +342,7 @@ impl Agent { let _ = reply_span.enter(); loop { match Self::generate_response_from_provider( - self.provider(), + self.provider().await?, &system_prompt, &messages, &tools, @@ -345,7 +358,7 @@ impl Agent { let (frontend_requests, remaining_requests, filtered_response) = - self.categorize_tool_requests(&response); + self.categorize_tool_requests(&response).await; // Yield the assistant's response with frontend tool requests filtered out @@ -396,8 +409,7 @@ impl Agent { tools_with_readonly_annotation.clone(), tools_without_annotation.clone(), &mut permission_manager, - self.provider(), - ).await; + self.provider().await?).await; // Handle pre-approved and read-only tools in parallel let mut tool_futures: Vec = Vec::new(); @@ -492,13 +504,21 @@ impl Agent { } /// Extend the system prompt with one line of additional instruction - pub async fn extend_system_prompt(&mut self, instruction: String) { - self.prompt_manager.add_system_prompt_extra(instruction); + pub async fn extend_system_prompt(&self, instruction: String) { + let mut prompt_manager = self.prompt_manager.lock().await; + prompt_manager.add_system_prompt_extra(instruction); + } + + /// Update the provider used by this agent + pub async fn update_provider(&self, provider: Arc) -> Result<()> { + *self.provider.lock().await = Some(provider); + Ok(()) } /// Override the system prompt with a custom template - pub async fn override_system_prompt(&mut self, template: String) { - self.prompt_manager.set_system_prompt_override(template); + pub async fn override_system_prompt(&self, template: String) { + let mut prompt_manager = self.prompt_manager.lock().await; + prompt_manager.set_system_prompt_override(template); } pub async fn list_extension_prompts(&self) -> HashMap> { @@ -563,23 +583,29 @@ impl Agent { let extensions_info = extension_manager.get_extensions_info().await; // Get model name from provider - let model_config = self.provider.get_model_config(); + let provider = self.provider().await?; + let model_config = provider.get_model_config(); let model_name = &model_config.model_name; - let system_prompt = self.prompt_manager.build_system_prompt( + let prompt_manager = self.prompt_manager.lock().await; + let system_prompt = prompt_manager.build_system_prompt( extensions_info, - self.frontend_instructions.clone(), + self.frontend_instructions.lock().await.clone(), extension_manager.suggest_disable_extensions_prompt().await, Some(model_name), ); - let recipe_prompt = self.prompt_manager.get_recipe_prompt().await; + let recipe_prompt = prompt_manager.get_recipe_prompt().await; let tools = extension_manager.get_prefixed_tools(None).await?; messages.push(Message::user().with_text(recipe_prompt)); let (result, _usage) = self .provider + .lock() + .await + .as_ref() + .unwrap() .complete(&system_prompt, &messages, &tools) .await?; diff --git a/crates/goose/src/agents/context.rs b/crates/goose/src/agents/context.rs index b75c4a55..e52a8a09 100644 --- a/crates/goose/src/agents/context.rs +++ b/crates/goose/src/agents/context.rs @@ -15,7 +15,7 @@ impl Agent { &self, messages: &[Message], // last message is a user msg that led to assistant message with_context_length_exceeded ) -> Result<(Vec, Vec), anyhow::Error> { - let provider = self.provider.clone(); + let provider = self.provider().await?; let token_counter = TokenCounter::new(provider.get_model_config().tokenizer_name()); let target_context_limit = estimate_target_context_limit(provider); let token_counts = get_messages_token_counts(&token_counter, messages); @@ -41,7 +41,7 @@ impl Agent { &self, messages: &[Message], // last message is a user msg that led to assistant message with_context_length_exceeded ) -> Result<(Vec, Vec), anyhow::Error> { - let provider = self.provider.clone(); + let provider = self.provider().await?; let token_counter = TokenCounter::new(provider.get_model_config().tokenizer_name()); let target_context_limit = estimate_target_context_limit(provider.clone()); diff --git a/crates/goose/src/agents/reply_parts.rs b/crates/goose/src/agents/reply_parts.rs index 421a7b18..5d07635d 100644 --- a/crates/goose/src/agents/reply_parts.rs +++ b/crates/goose/src/agents/reply_parts.rs @@ -22,7 +22,8 @@ impl Agent { let mut tools = self.list_tools(None).await; // Add frontend tools - for frontend_tool in self.frontend_tools.values() { + let frontend_tools = self.frontend_tools.lock().await; + for frontend_tool in frontend_tools.values() { tools.push(frontend_tool.tool.clone()); } @@ -31,19 +32,21 @@ impl Agent { let extensions_info = extension_manager.get_extensions_info().await; // Get model name from provider - let model_config = self.provider.get_model_config(); + let provider = self.provider().await?; + let model_config = provider.get_model_config(); let model_name = &model_config.model_name; - let mut system_prompt = self.prompt_manager.build_system_prompt( + let prompt_manager = self.prompt_manager.lock().await; + let mut system_prompt = prompt_manager.build_system_prompt( extensions_info, - self.frontend_instructions.clone(), + self.frontend_instructions.lock().await.clone(), extension_manager.suggest_disable_extensions_prompt().await, Some(model_name), ); // Handle toolshim if enabled let mut toolshim_tools = vec![]; - if self.provider.get_model_config().toolshim { + if model_config.toolshim { // If tool interpretation is enabled, modify the system prompt system_prompt = modify_system_prompt_for_tool_json(&system_prompt, &tools); // Make a copy of tools before emptying @@ -115,7 +118,7 @@ impl Agent { /// - frontend_requests: Tool requests that should be handled by the frontend /// - other_requests: All other tool requests (including requests to enable extensions) /// - filtered_message: The original message with frontend tool requests removed - pub(crate) fn categorize_tool_requests( + pub(crate) async fn categorize_tool_requests( &self, response: &Message, ) -> (Vec, Vec, Message) { @@ -133,20 +136,25 @@ impl Agent { .collect(); // Create a filtered message with frontend tool requests removed - let filtered_content = response - .content - .iter() - .filter(|c| { - if let MessageContent::ToolRequest(req) = c { - // Only filter out frontend tool requests + let mut filtered_content = Vec::new(); + + // Process each content item one by one + for content in &response.content { + let should_include = match content { + MessageContent::ToolRequest(req) => { if let Ok(tool_call) = &req.tool_call { - return !self.is_frontend_tool(&tool_call.name); + !self.is_frontend_tool(&tool_call.name).await + } else { + true } } - true - }) - .cloned() - .collect(); + _ => true, + }; + + if should_include { + filtered_content.push(content.clone()); + } + } let filtered_message = Message { role: response.role.clone(), @@ -160,7 +168,7 @@ impl Agent { for request in tool_requests { if let Ok(tool_call) = &request.tool_call { - if self.is_frontend_tool(&tool_call.name) { + if self.is_frontend_tool(&tool_call.name).await { frontend_requests.push(request); } else { other_requests.push(request); diff --git a/crates/goose/src/agents/tool_execution.rs b/crates/goose/src/agents/tool_execution.rs index ae718a36..beffae66 100644 --- a/crates/goose/src/agents/tool_execution.rs +++ b/crates/goose/src/agents/tool_execution.rs @@ -87,7 +87,7 @@ impl Agent { try_stream! { for request in tool_requests { if let Ok(tool_call) = request.tool_call.clone() { - if self.is_frontend_tool(&tool_call.name) { + if self.is_frontend_tool(&tool_call.name).await { // Send frontend tool request and wait for response yield Message::assistant().with_frontend_tool_request( request.id.clone(), diff --git a/crates/goose/src/config/base.rs b/crates/goose/src/config/base.rs index cdaf3c65..65887a9a 100644 --- a/crates/goose/src/config/base.rs +++ b/crates/goose/src/config/base.rs @@ -739,10 +739,10 @@ mod tests { thread::sleep(Duration::from_millis(i * 10)); let extension_key = format!("extension_{}", i); - let mut values = config.load_values()?; - values.insert( - extension_key.clone(), + // Use set_param which handles concurrent access properly + config.set_param( + &extension_key, serde_json::json!({ "name": format!("test_extension_{}", i), "version": format!("1.0.{}", i), @@ -752,10 +752,7 @@ mod tests { "option2": i } }), - ); - - // Write all values atomically - config.save_values(values)?; + )?; Ok(()) }); handles.push(handle); diff --git a/crates/goose/tests/agent.rs b/crates/goose/tests/agent.rs index ee46735f..c4365fd2 100644 --- a/crates/goose/tests/agent.rs +++ b/crates/goose/tests/agent.rs @@ -110,7 +110,8 @@ async fn run_truncate_test( .with_temperature(Some(0.0)); let provider = provider_type.create_provider(model_config)?; - let agent = Agent::new(provider); + let agent = Agent::new(); + agent.update_provider(provider).await?; let repeat_count = context_window + 10_000; let large_message_content = "hello ".repeat(repeat_count); let messages = vec![ diff --git a/examples/frontend_tools.py b/examples/frontend_tools.py index 0e032dc2..d8824f3a 100644 --- a/examples/frontend_tools.py +++ b/examples/frontend_tools.py @@ -33,12 +33,28 @@ CALCULATOR_TOOL = { }, } +# Enable Extension tool definition +ENABLE_EXTENSION_TOOL = { + "name": "enable_extension", + "description": "Enable extensions to help complete tasks. Enable an extension by providing the extension name.", + "inputSchema": { + "type": "object", + "required": ["extension_name"], + "properties": { + "extension_name": { + "type": "string", + "description": "The name of the extension to enable", + }, + }, + }, +} + # Frontend extension configuration FRONTEND_CONFIG = { "name": "pythonclient", "type": "frontend", - "tools": [CALCULATOR_TOOL], - "instructions": "A calculator extension that can perform basic arithmetic operations.", + "tools": [CALCULATOR_TOOL, ENABLE_EXTENSION_TOOL], + "instructions": "A calculator extension that can perform basic arithmetic operations. Use enable extension tool to add extesions such as fetch, pdf reader, etc.", } @@ -47,7 +63,7 @@ async def setup_agent() -> None: async with httpx.AsyncClient() as client: # First create the agent response = await client.post( - f"{GOOSE_URL}/agent", + f"{GOOSE_URL}/agent/update_provider", json={"provider": "databricks", "model": "goose"}, headers={"X-Secret-Key": SECRET_KEY}, ) @@ -101,6 +117,55 @@ def execute_calculator(args: Dict[str, Any]) -> List[Dict[str, Any]]: } ] +def get_tools() -> Dict[str, Any]: + with httpx.Client() as client: + response = client.get( + f"{GOOSE_URL}/agent/tools", + headers={"X-Secret-Key": SECRET_KEY}, + ) + response.raise_for_status() + return response.json() + + +def execute_enable_extension(args: Dict[str, Any]) -> List[Dict[str, Any]]: + """ + Execute the enable_extension tool. + This function fetches available extensions, finds the one with the provided extension_name, + and posts its configuration to the /extensions/add endpoint. + """ + extension = args + extension_name = extension.get("name") + + # Post the extension configuration to enable it + with httpx.Client() as client: + payload = { + "type": extension.get("type"), + "name": extension.get("name"), + "cmd": extension.get("cmd"), + "args": extension.get("args"), + "envs": extension.get("envs", {}), + "timeout": extension.get("timeout"), + "bundled": extension.get("bundled"), + } + add_response = client.post( + f"{GOOSE_URL}/extensions/add", + json=payload, + headers={"Content-Type": "application/json", "X-Secret-Key": SECRET_KEY}, + ) + if add_response.status_code != 200: + error_text = add_response.text + return [{ + "type": "text", + "text": f"Error: Failed to enable extension: {error_text}", + "annotations": None, + }] + + return [{ + "type": "text", + "text": f"Successfully enabled extension: {extension_name}", + "annotations": None, + }] + def submit_tool_result(tool_id: str, result: List[Dict[str, Any]]) -> None: """Submit the tool execution result back to Goose. @@ -129,7 +194,7 @@ async def chat_loop() -> None: session_id = "test-session" # Use a client with a longer timeout for streaming - async with httpx.AsyncClient(timeout=30.0) as client: + async with httpx.AsyncClient(timeout=60.0) as client: # Get user input user_message = input("\nYou: ") if user_message.lower() in ["exit", "quit"]: @@ -152,7 +217,7 @@ async def chat_loop() -> None: # Process the stream of responses async with client.stream( "POST", - f"{GOOSE_URL}/reply", + f"{GOOSE_URL}/reply", # lock json=payload, headers={ "X-Secret-Key": SECRET_KEY, @@ -185,9 +250,26 @@ async def chat_loop() -> None: elif content["type"] == "frontendToolRequest": # Execute the tool and submit results tool_call = content["toolCall"]["value"] - print(f"Calculator: {tool_call}") - # Execute the tool - result = execute_calculator(tool_call["arguments"]) + print(f"\nTool Request: {tool_call}") + + if tool_call['name'] == "calculator": + print(f"Calculator: {tool_call}") + # Execute the tool + result = execute_calculator(tool_call["arguments"]) + elif tool_call['name'] == "enable_extension": + # to trigger this tool, use the instruction "use enable_extension tool with "fetch" extension name" + print(f"Enabling fetch extension") + result = execute_enable_extension(args={ + "type": "stdio", + "name": "fetch", + "cmd": "uvx", + "args": ["mcp-server-fetch"], + "timeout": 300, + "bundled": False + }) + listed_tools = get_tools() + print(f"\nTools after enabling extension: {listed_tools}") + # Submit the result submit_tool_result(content["id"], result) diff --git a/ui/desktop/src/agent/index.ts b/ui/desktop/src/agent/index.ts index 0e80b957..5e980570 100644 --- a/ui/desktop/src/agent/index.ts +++ b/ui/desktop/src/agent/index.ts @@ -6,7 +6,7 @@ interface initializeAgentProps { } export async function initializeAgent({ model, provider }: initializeAgentProps) { - const response = await fetch(getApiUrl('/agent'), { + const response = await fetch(getApiUrl('/agent/update_provider'), { method: 'POST', headers: { 'Content-Type': 'application/json',