From 8cf60ffea375fc42a3218c50e0277baca5abc514 Mon Sep 17 00:00:00 2001 From: Aljaz Ceru Date: Thu, 24 Apr 2025 17:16:07 +0200 Subject: [PATCH 01/23] creating goose-api crate to have daemonized goose --- config.yaml | 49 +++++ crates/goose-api/Cargo.toml | 21 ++ crates/goose-api/README.md | 284 +++++++++++++++++++++++++ crates/goose-api/src/main.rs | 390 +++++++++++++++++++++++++++++++++++ 4 files changed, 744 insertions(+) create mode 100644 config.yaml create mode 100644 crates/goose-api/Cargo.toml create mode 100644 crates/goose-api/README.md create mode 100644 crates/goose-api/src/main.rs diff --git a/config.yaml b/config.yaml new file mode 100644 index 00000000..825c1f0c --- /dev/null +++ b/config.yaml @@ -0,0 +1,49 @@ +extensions: + computercontroller: + bundled: true + display_name: Computer Controller + enabled: true + name: computercontroller + timeout: 300 + type: builtin + developer: + bundled: true + display_name: Developer Tools + enabled: true + name: developer + timeout: 300 + type: builtin + filesytem: + args: + - -y + - '@modelcontextprotocol/server-filesystem' + - /home/lio/g + bundled: null + cmd: npx + description: 'access files inside ~/g ' + enabled: true + env_keys: [] + envs: {} + name: filesytem + timeout: 300 + type: stdio + filesytem-extension: + args: + - -y + - '@modelcontextprotocol/server-filesystem' + bundled: null + cmd: npx + description: null + enabled: false + env_keys: [] + envs: {} + name: filesytem-extension + timeout: 300 + type: stdio + memory: + bundled: true + display_name: Memory + enabled: true + name: memory + timeout: 300 + type: builtin diff --git a/crates/goose-api/Cargo.toml b/crates/goose-api/Cargo.toml new file mode 100644 index 00000000..d3ba498b --- /dev/null +++ b/crates/goose-api/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "goose-api" +version = "0.1.0" +edition = "2021" + +[dependencies] +goose = { path = "../goose" } +goose-mcp = { path = "../goose-mcp" } +mcp-client = { path = "../mcp-client" } +tokio = { version = "1", features = ["full"] } +warp = "0.3" +serde = { version = "1", features = ["derive"] } +serde_json = "1" +anyhow = "1" +tracing = "0.1" +tracing-subscriber = { version = "0.3", features = ["env-filter", "fmt", "json", "time"] } +config = "0.13" +jsonwebtoken = "8" +futures = "0.3" +futures-util = "0.3" + # Add dynamic-library for extension loading diff --git a/crates/goose-api/README.md b/crates/goose-api/README.md new file mode 100644 index 00000000..47115506 --- /dev/null +++ b/crates/goose-api/README.md @@ -0,0 +1,284 @@ +# Goose API + +An asynchronous REST API for interacting with Goose's AI agent capabilities. + +## Overview + +The goose-api crate provides an HTTP API interface to Goose's AI capabilities, enabling integration with other services and applications. It is designed as a daemon that can be run in the background, offering the same core functionality as the Goose CLI but accessible over HTTP. + +## Installation + +### Prerequisites + +- Rust toolchain (cargo, rustc) +- Goose dependencies + +### Building + +```bash +# Navigate to the goose-api directory +cd crates/goose-api + +# Build the project +cargo build + +# For a production-optimized build +cargo build --release +``` + +## Configuration + +Goose API supports configuration through both environment variables and a configuration file. The precedence order is: + +1. Environment variables (highest priority) +2. Configuration file (lower priority) +3. Default values (lowest priority) + +### Configuration File + +Create a file named `config` (with no extension) in the directory where you run the goose-api. The format can be JSON, YAML, TOML, etc. (the `config` crate will detect the format automatically). + +Example `config` file (YAML format): + +```yaml +# API server configuration +host: 127.0.0.1 +port: 8080 +api_key: your_secure_api_key + +# Provider configuration +provider: openai +model: gpt-4o +``` + +### Environment Variables + +All configurations can be set using environment variables prefixed with `GOOSE_API_`. + +```bash +# API server configuration +export GOOSE_API_HOST=0.0.0.0 +export GOOSE_API_PORT=8080 +export GOOSE_API_KEY=your_secure_api_key + +# Provider configuration +export GOOSE_API_PROVIDER=openai +export GOOSE_API_MODEL=gpt-4o + +# Provider-specific credentials (based on provider requirements) +export OPENAI_API_KEY=your_openai_api_key +export ANTHROPIC_API_KEY=your_anthropic_api_key +# etc. +``` + +## API Authentication + +All API endpoints require authentication using an API key. The key should be provided in the `x-api-key` header. + +Example: + +``` +x-api-key: your_secure_api_key +``` + +## Running the Server + +```bash +# Run the server in development mode +cargo run + +# Run the compiled binary directly +./target/debug/goose-api + +# For production (with optimizations) +./target/release/goose-api +``` + +By default, the server runs on `127.0.0.1:8080`. You can modify this using configuration options. + +## API Endpoints + +### 1. Start a Session + +**Endpoint**: `POST /session/start` + +**Description**: Initiates a new session with Goose, providing an initial prompt. + +**Request**: +- Headers: + - Content-Type: application/json + - x-api-key: [your-api-key] +- Body: +```json +{ + "prompt": "Your instruction to Goose" +} +``` + +**Response**: +```json +{ + "message": "Session started with prompt: Your instruction to Goose", + "status": "success" +} +``` + +### 2. Reply to a Session + +**Endpoint**: `POST /session/reply` + +**Description**: Sends a follow-up message to an existing session. + +**Request**: +- Headers: + - Content-Type: application/json + - x-api-key: [your-api-key] +- Body: +```json +{ + "prompt": "Your follow-up instruction" +} +``` + +**Response**: +```json +{ + "message": "Reply: Response from Goose", + "status": "success" +} +``` + +### 3. List Extensions + +**Endpoint**: `GET /extensions/list` + +**Description**: Returns a list of available extensions. + +**Request**: +- Headers: + - x-api-key: [your-api-key] + +**Response**: +```json +{ + "extensions": ["extension1", "extension2", "extension3"] +} +``` + +### 4. Get Provider Configuration + +**Endpoint**: `GET /provider/config` + +**Description**: Returns the current provider configuration. + +**Request**: +- Headers: + - x-api-key: [your-api-key] + +**Response**: +```json +{ + "provider": "openai", + "model": "gpt-4o" +} +``` + +## Examples + +### Using cURL + +```bash +# Start a session +curl -X POST http://localhost:8080/session/start \ + -H "Content-Type: application/json" \ + -H "x-api-key: your_secure_api_key" \ + -d '{"prompt": "Create a Python function to generate Fibonacci numbers"}' + +# Reply to an ongoing session +curl -X POST http://localhost:8080/session/reply \ + -H "Content-Type: application/json" \ + -H "x-api-key: your_secure_api_key" \ + -d '{"prompt": "Add documentation to this function"}' + +# List extensions +curl -X GET http://localhost:8080/extensions/list \ + -H "x-api-key: your_secure_api_key" + +# Get provider configuration +curl -X GET http://localhost:8080/provider/config \ + -H "x-api-key: your_secure_api_key" +``` + +### Using Python + +```python +import requests + +API_URL = "http://localhost:8080" +API_KEY = "your_secure_api_key" +HEADERS = { + "Content-Type": "application/json", + "x-api-key": API_KEY +} + +# Start a session +response = requests.post( + f"{API_URL}/session/start", + headers=HEADERS, + json={"prompt": "Create a Python function to generate Fibonacci numbers"} +) +print(response.json()) + +# Reply to an ongoing session +response = requests.post( + f"{API_URL}/session/reply", + headers=HEADERS, + json={"prompt": "Add documentation to this function"} +) +print(response.json()) + +# List extensions +response = requests.get(f"{API_URL}/extensions/list", headers=HEADERS) +print(response.json()) + +# Get provider configuration +response = requests.get(f"{API_URL}/provider/config", headers=HEADERS) +print(response.json()) +``` + +## Troubleshooting + +### Common Issues + +1. **API Key Authentication Failure**: + Ensure the key in your request header matches the configured API key. + +2. **Provider Configuration Issues**: + Make sure you've set the necessary environment variables for your chosen provider. + +3. **Missing Required Keys**: + Check the server logs for messages about missing required provider configuration keys. + +## Implementation Status (vs. Implementation Plan) + +The current implementation includes the following features from the implementation plan: + +✅ **Step 1-2**: Created goose-api crate with necessary dependencies +✅ **Step 3-4**: Defined API endpoints with request/response structures +✅ **Step 5**: Integration with goose core functionality +✅ **Step 6**: Configuration via environment variables and config file +✅ **Step 9**: API Key authentication + +🟡 **Step 7**: Extension loading mechanism (partial implementation) +🟡 **Step 8**: MCP support (partial implementation) +✅ **Step 10**: Documentation +❌ **Step 11**: Tests (not yet implemented) + +## Future Work + +- Extend session management capabilities +- Add more comprehensive error handling +- Implement unit and integration tests +- Complete MCP integration +- Add metrics and monitoring +- Add OpenAPI documentation generation \ No newline at end of file diff --git a/crates/goose-api/src/main.rs b/crates/goose-api/src/main.rs new file mode 100644 index 00000000..4edae16c --- /dev/null +++ b/crates/goose-api/src/main.rs @@ -0,0 +1,390 @@ +use warp::{Filter, Rejection}; +use warp::http::HeaderValue; +use serde::{Deserialize, Serialize}; +use std::sync::LazyLock; +use goose::agents::{Agent, extension_manager::ExtensionManager}; +use goose::config::Config; +use goose::providers::{create, providers}; +use goose::model::ModelConfig; +use goose::message::Message; +use tracing::{info, warn, error}; +use config::{builder::DefaultState, ConfigBuilder, Environment, File}; +use serde_json::Value; // Import the correct Value type +use futures_util::TryStreamExt; + +// Global extension manager for extension listing +static EXTENSION_MANAGER: LazyLock = LazyLock::new(|| ExtensionManager::default()); + +// Global agent for handling sessions +static AGENT: LazyLock> = LazyLock::new(|| { + tokio::sync::Mutex::new(Agent::new()) +}); + +#[derive(Debug, Serialize, Deserialize)] +struct SessionRequest { + prompt: String, +} + +#[derive(Debug, Serialize, Deserialize)] +struct ApiResponse { + message: String, + status: String, +} + +#[derive(Debug, Serialize, Deserialize)] +struct ExtensionsResponse { + extensions: Vec, +} + +#[derive(Debug, Serialize, Deserialize)] +struct ProviderConfig { + provider: String, + model: String, +} + +async fn start_session_handler( + req: SessionRequest, + _api_key: String, +) -> Result { + info!("Starting session with prompt: {}", req.prompt); + + let agent = AGENT.lock().await; + + // Create a user message with the prompt + let messages = vec![Message::user().with_text(&req.prompt)]; + + // Process the messages through the agent + let result = agent.reply(&messages, None).await; + + match result { + Ok(mut stream) => { + // Process the stream to get the first response + if let Ok(Some(response)) = stream.try_next().await { + let response_text = response.as_concat_text(); + let api_response = ApiResponse { + message: format!("Session started with prompt: {}. Response: {}", req.prompt, response_text), + status: "success".to_string(), + }; + Ok(warp::reply::with_status( + warp::reply::json(&api_response), + warp::http::StatusCode::OK, + )) + } else { + let api_response = ApiResponse { + message: format!("Session started but no response generated"), + status: "warning".to_string(), + }; + Ok(warp::reply::with_status( + warp::reply::json(&api_response), + warp::http::StatusCode::OK, + )) + } + }, + Err(e) => { + error!("Failed to start session: {}", e); + let response = ApiResponse { + message: format!("Failed to start session: {}", e), + status: "error".to_string(), + }; + Ok(warp::reply::with_status( + warp::reply::json(&response), + warp::http::StatusCode::INTERNAL_SERVER_ERROR, + )) + } + } +} + +async fn reply_session_handler( + req: SessionRequest, + _api_key: String, +) -> Result { + info!("Replying to session with prompt: {}", req.prompt); + + let agent = AGENT.lock().await; + + // Create a user message with the prompt + let messages = vec![Message::user().with_text(&req.prompt)]; + + // Process the messages through the agent + let result = agent.reply(&messages, None).await; + + match result { + Ok(mut stream) => { + // Process the stream to get the first response + if let Ok(Some(response)) = stream.try_next().await { + let response_text = response.as_concat_text(); + let api_response = ApiResponse { + message: format!("Reply: {}", response_text), + status: "success".to_string(), + }; + Ok(warp::reply::with_status( + warp::reply::json(&api_response), + warp::http::StatusCode::OK, + )) + } else { + let api_response = ApiResponse { + message: format!("Reply processed but no response generated"), + status: "warning".to_string(), + }; + Ok(warp::reply::with_status( + warp::reply::json(&api_response), + warp::http::StatusCode::OK, + )) + } + }, + Err(e) => { + error!("Failed to reply to session: {}", e); + let response = ApiResponse { + message: format!("Failed to reply to session: {}", e), + status: "error".to_string(), + }; + Ok(warp::reply::with_status( + warp::reply::json(&response), + warp::http::StatusCode::INTERNAL_SERVER_ERROR, + )) + } + } +} + +async fn list_extensions_handler() -> Result { + info!("Listing extensions"); + + match EXTENSION_MANAGER.list_extensions().await { + Ok(exts) => { + let response = ExtensionsResponse { extensions: exts }; + Ok::(warp::reply::json(&response)) + }, + Err(e) => { + error!("Failed to list extensions: {}", e); + let response = ExtensionsResponse { + extensions: vec!["Failed to list extensions".to_string()] + }; + Ok::(warp::reply::json(&response)) + } + } +} + +async fn get_provider_config_handler() -> Result { + info!("Getting provider configuration"); + + let config = Config::global(); + let provider = config.get_param::("GOOSE_PROVIDER") + .unwrap_or_else(|_| "Not configured".to_string()); + let model = config.get_param::("GOOSE_MODEL") + .unwrap_or_else(|_| "Not configured".to_string()); + + let response = ProviderConfig { provider, model }; + Ok::(warp::reply::json(&response)) +} + +fn with_api_key(api_key: String) -> impl Filter + Clone { + warp::header::value("x-api-key") + .and_then(move |header_api_key: HeaderValue| { + let api_key = api_key.clone(); + async move { + if header_api_key == api_key { + Ok(api_key) + } else { + Err(warp::reject::not_found()) + } + } + }) +} + +// Load configuration from file and environment variables +fn load_configuration() -> std::result::Result { + let config_path = std::env::var("GOOSE_CONFIG").unwrap_or_else(|_| "config".to_string()); + let builder = ConfigBuilder::::default() + .add_source(File::with_name(&config_path).required(false)) + .add_source(Environment::with_prefix("GOOSE_API")); + + builder.build() +} + +// Initialize global provider configuration +async fn initialize_provider_config() -> Result<(), anyhow::Error> { + // Get configuration + let api_config = load_configuration()?; + + // Get provider settings from configuration or environment variables + let provider_name = std::env::var("GOOSE_API_PROVIDER") + .or_else(|_| api_config.get_string("provider")) + .unwrap_or_else(|_| "openai".to_string()); + + let model_name = std::env::var("GOOSE_API_MODEL") + .or_else(|_| api_config.get_string("model")) + .unwrap_or_else(|_| "gpt-4o".to_string()); + + info!("Initializing with provider: {}, model: {}", provider_name, model_name); + + // Initialize the global Config object + let config = Config::global(); + config.set_param("GOOSE_PROVIDER", Value::String(provider_name.clone()))?; + config.set_param("GOOSE_MODEL", Value::String(model_name.clone()))?; + + // Set up API keys from environment variables + let available_providers = providers(); + if let Some(provider_meta) = available_providers.iter().find(|p| p.name == provider_name) { + for key in &provider_meta.config_keys { + let env_name = key.name.clone(); + if let Ok(value) = std::env::var(&env_name) { + if key.secret { + config.set_secret(&key.name, Value::String(value))?; + info!("Set secret key: {}", key.name); + } else { + config.set_param(&key.name, Value::String(value))?; + info!("Set parameter: {}", key.name); + } + } else { + warn!("Environment variable not set for key: {}", key.name); + if key.required { + error!("Required key {} not provided", key.name); + return Err(anyhow::anyhow!("Required key {} not provided", key.name)); + } + } + } + } + + // Initialize agent with provider + let model_config = ModelConfig::new(model_name); + let provider = create(&provider_name, model_config)?; + + let agent = AGENT.lock().await; + agent.update_provider(provider).await?; + + info!("Provider configuration successful"); + Ok(()) +} +/// Initialize extensions from the configuration. +fn initialize_extensions(config: &config::Config) -> Result<(), anyhow::Error> { + if let Ok(ext_table) = config.get_table("extensions") { + for (name, ext_config) in ext_table { + let json_value: serde_json::Value = ext_config.clone().try_deserialize() + .map_err(|e| anyhow::anyhow!("Failed to deserialize extension config for {}: {}", name, e))?; + // Only process the extension if it is enabled. + let enabled = json_value.get("enabled").and_then(|v| v.as_bool()).unwrap_or(false); + if enabled { + // Note: The ExtensionManager does not provide a method to register extensions. + // Here, we log that the extension is enabled. Adjust this code if a registration API becomes available. + info!("Extension {} is enabled and would be registered", name); + } else { + info!("Skipping disabled extension: {}", name); + } + } + } else { + warn!("No extensions configured in config file."); + } + Ok(()) +} + + +async fn run_init_tests() -> Result<(), anyhow::Error> { + info!("Running initialization tests"); + { + let _agent = AGENT.lock().await; + info!("Agent initialization test passed"); + } + info!("Initialization tests completed"); + Ok(()) +} + +#[tokio::main] +async fn main() -> Result<(), anyhow::Error> { + // Initialize tracing + tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .init(); + + info!("Starting goose-api server"); + + // Load configuration + let api_config = load_configuration()?; + + // Get API key from configuration or environment + let api_key: String = std::env::var("GOOSE_API_KEY") + .or_else(|_| api_config.get_string("api_key")) + .unwrap_or_else(|_| { + warn!("No API key configured, using default"); + "default_api_key".to_string() + }); + + // Initialize provider configuration + if let Err(e) = initialize_provider_config().await { + error!("Failed to initialize provider: {}", e); + return Err(e); + } + + // Initialize extensions from configuration + if let Err(e) = initialize_extensions(&api_config) { + error!("Failed to initialize extensions: {}", e); + } + + if let Err(e) = run_init_tests().await { + error!("Initialization tests failed: {}", e); + } + + // Session start endpoint + let start_session = warp::path("session") + .and(warp::path("start")) + .and(warp::post()) + .and(warp::body::json()) + .and(with_api_key(api_key.clone())) + .and_then(start_session_handler); + + // Session reply endpoint + let reply_session = warp::path("session") + .and(warp::path("reply")) + .and(warp::post()) + .and(warp::body::json()) + .and(with_api_key(api_key.clone())) + .and_then(reply_session_handler); + + // List extensions endpoint + let list_extensions = warp::path("extensions") + .and(warp::path("list")) + .and(warp::get()) + .and_then(list_extensions_handler); + + // Get provider configuration endpoint + let get_provider_config = warp::path("provider") + .and(warp::path("config")) + .and(warp::get()) + .and_then(get_provider_config_handler); + + // Combine all routes + let routes = start_session + .or(reply_session) + .or(list_extensions) + .or(get_provider_config); + + // Get bind address from configuration or use default + let host = std::env::var("GOOSE_API_HOST") + .or_else(|_| api_config.get_string("host")) + .unwrap_or_else(|_| "127.0.0.1".to_string()); + + let port = std::env::var("GOOSE_API_PORT") + .or_else(|_| api_config.get_string("port")) + .unwrap_or_else(|_| "8080".to_string()) + .parse::() + .unwrap_or(8080); + + info!("Starting server on {}:{}", host, port); + + // Parse host string + let host_parts: Vec = host.split('.') + .map(|part| part.parse::().unwrap_or(127)) + .collect(); + + let addr = if host_parts.len() == 4 { + [host_parts[0], host_parts[1], host_parts[2], host_parts[3]] + } else { + [127, 0, 0, 1] + }; + + // Start the server + warp::serve(routes) + .run((addr, port)) + .await; + + Ok(()) +} From 31567fa22eb78fd9feeaa51631ab92556b9b7f93 Mon Sep 17 00:00:00 2001 From: Aljaz Date: Wed, 28 May 2025 19:25:41 +0200 Subject: [PATCH 02/23] Initialize extensions via agent --- crates/goose-api/src/main.rs | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/crates/goose-api/src/main.rs b/crates/goose-api/src/main.rs index 4edae16c..e5c5ed5a 100644 --- a/crates/goose-api/src/main.rs +++ b/crates/goose-api/src/main.rs @@ -2,8 +2,8 @@ use warp::{Filter, Rejection}; use warp::http::HeaderValue; use serde::{Deserialize, Serialize}; use std::sync::LazyLock; -use goose::agents::{Agent, extension_manager::ExtensionManager}; -use goose::config::Config; +use goose::agents::{Agent, ExtensionConfig, extension_manager::ExtensionManager}; +use goose::config::{Config, ExtensionEntry}; use goose::providers::{create, providers}; use goose::model::ModelConfig; use goose::message::Message; @@ -256,17 +256,20 @@ async fn initialize_provider_config() -> Result<(), anyhow::Error> { Ok(()) } /// Initialize extensions from the configuration. -fn initialize_extensions(config: &config::Config) -> Result<(), anyhow::Error> { +async fn initialize_extensions(config: &config::Config) -> Result<(), anyhow::Error> { if let Ok(ext_table) = config.get_table("extensions") { for (name, ext_config) in ext_table { - let json_value: serde_json::Value = ext_config.clone().try_deserialize() + // Deserialize into ExtensionEntry to get enabled flag and config + let entry: ExtensionEntry = ext_config.clone().try_deserialize() .map_err(|e| anyhow::anyhow!("Failed to deserialize extension config for {}: {}", name, e))?; - // Only process the extension if it is enabled. - let enabled = json_value.get("enabled").and_then(|v| v.as_bool()).unwrap_or(false); - if enabled { - // Note: The ExtensionManager does not provide a method to register extensions. - // Here, we log that the extension is enabled. Adjust this code if a registration API becomes available. - info!("Extension {} is enabled and would be registered", name); + + if entry.enabled { + let extension_config: ExtensionConfig = entry.config; + // Acquire the global agent lock and try to add the extension + let mut agent = AGENT.lock().await; + if let Err(e) = agent.add_extension(extension_config).await { + error!("Failed to add extension {}: {}", name, e); + } } else { info!("Skipping disabled extension: {}", name); } @@ -315,7 +318,7 @@ async fn main() -> Result<(), anyhow::Error> { } // Initialize extensions from configuration - if let Err(e) = initialize_extensions(&api_config) { + if let Err(e) = initialize_extensions(&api_config).await { error!("Failed to initialize extensions: {}", e); } From e9fb90d413decb1c646cb669b7d7b7db0dfc8812 Mon Sep 17 00:00:00 2001 From: Aljaz Date: Wed, 28 May 2025 19:26:12 +0200 Subject: [PATCH 03/23] feat(api): add extension management endpoints --- crates/goose-api/Cargo.toml | 1 + crates/goose-api/README.md | 79 +++++++++++++- crates/goose-api/src/main.rs | 193 ++++++++++++++++++++++++++++++++++- 3 files changed, 271 insertions(+), 2 deletions(-) diff --git a/crates/goose-api/Cargo.toml b/crates/goose-api/Cargo.toml index d3ba498b..fc5017fc 100644 --- a/crates/goose-api/Cargo.toml +++ b/crates/goose-api/Cargo.toml @@ -7,6 +7,7 @@ edition = "2021" goose = { path = "../goose" } goose-mcp = { path = "../goose-mcp" } mcp-client = { path = "../mcp-client" } +mcp-core = { path = "../mcp-core" } tokio = { version = "1", features = ["full"] } warp = "0.3" serde = { version = "1", features = ["derive"] } diff --git a/crates/goose-api/README.md b/crates/goose-api/README.md index 47115506..a64f2d03 100644 --- a/crates/goose-api/README.md +++ b/crates/goose-api/README.md @@ -165,7 +165,56 @@ By default, the server runs on `127.0.0.1:8080`. You can modify this using confi } ``` -### 4. Get Provider Configuration +### 4. Add Extension + +**Endpoint**: `POST /extensions/add` + +**Description**: Installs or enables an extension. + +**Request**: +- Headers: + - Content-Type: application/json + - x-api-key: [your-api-key] +- Body (example): +```json +{ + "type": "builtin", + "name": "mcp_say" +} +``` + +**Response**: +```json +{ + "error": false, + "message": null +} +``` + +### 5. Remove Extension + +**Endpoint**: `POST /extensions/remove` + +**Description**: Removes or disables an extension by name. + +**Request**: +- Headers: + - Content-Type: application/json + - x-api-key: [your-api-key] +- Body: +```json +"mcp_say" +``` + +**Response**: +```json +{ + "error": false, + "message": null +} +``` + +### 6. Get Provider Configuration **Endpoint**: `GET /provider/config` @@ -204,6 +253,18 @@ curl -X POST http://localhost:8080/session/reply \ curl -X GET http://localhost:8080/extensions/list \ -H "x-api-key: your_secure_api_key" +# Add an extension +curl -X POST http://localhost:8080/extensions/add \ + -H "Content-Type: application/json" \ + -H "x-api-key: your_secure_api_key" \ + -d '{"type": "builtin", "name": "mcp_say"}' + +# Remove an extension +curl -X POST http://localhost:8080/extensions/remove \ + -H "Content-Type: application/json" \ + -H "x-api-key: your_secure_api_key" \ + -d '"mcp_say"' + # Get provider configuration curl -X GET http://localhost:8080/provider/config \ -H "x-api-key: your_secure_api_key" @@ -241,6 +302,22 @@ print(response.json()) response = requests.get(f"{API_URL}/extensions/list", headers=HEADERS) print(response.json()) +# Add an extension +response = requests.post( + f"{API_URL}/extensions/add", + headers=HEADERS, + json={"type": "builtin", "name": "mcp_say"} +) +print(response.json()) + +# Remove an extension +response = requests.post( + f"{API_URL}/extensions/remove", + headers=HEADERS, + json="mcp_say" +) +print(response.json()) + # Get provider configuration response = requests.get(f"{API_URL}/provider/config", headers=HEADERS) print(response.json()) diff --git a/crates/goose-api/src/main.rs b/crates/goose-api/src/main.rs index 4edae16c..d9a355f0 100644 --- a/crates/goose-api/src/main.rs +++ b/crates/goose-api/src/main.rs @@ -2,7 +2,13 @@ use warp::{Filter, Rejection}; use warp::http::HeaderValue; use serde::{Deserialize, Serialize}; use std::sync::LazyLock; -use goose::agents::{Agent, extension_manager::ExtensionManager}; +use goose::agents::{ + extension::Envs, + Agent, + extension_manager::ExtensionManager, + ExtensionConfig, +}; +use mcp_core::tool::Tool; use goose::config::Config; use goose::providers::{create, providers}; use goose::model::ModelConfig; @@ -42,6 +48,51 @@ struct ProviderConfig { model: String, } +#[derive(Debug, Serialize, Deserialize)] +struct ExtensionResponse { + error: bool, + message: Option, +} + +#[derive(Debug, Deserialize)] +#[serde(tag = "type")] +enum ExtensionConfigRequest { + #[serde(rename = "sse")] + Sse { + name: String, + uri: String, + #[serde(default)] + envs: Envs, + #[serde(default)] + env_keys: Vec, + timeout: Option, + }, + #[serde(rename = "stdio")] + Stdio { + name: String, + cmd: String, + #[serde(default)] + args: Vec, + #[serde(default)] + envs: Envs, + #[serde(default)] + env_keys: Vec, + timeout: Option, + }, + #[serde(rename = "builtin")] + Builtin { + name: String, + display_name: Option, + timeout: Option, + }, + #[serde(rename = "frontend")] + Frontend { + name: String, + tools: Vec, + instructions: Option, + }, +} + async fn start_session_handler( req: SessionRequest, _api_key: String, @@ -177,6 +228,128 @@ async fn get_provider_config_handler() -> Result { Ok::(warp::reply::json(&response)) } +async fn add_extension_handler( + req: ExtensionConfigRequest, + _api_key: String, +) -> Result { + info!("Adding extension: {:?}", req); + + #[cfg(target_os = "windows")] + if let ExtensionConfigRequest::Stdio { cmd, .. } = &req { + if cmd.ends_with("npx.cmd") || cmd.ends_with("npx") { + let node_exists = std::path::Path::new(r"C:\Program Files\nodejs\node.exe").exists() + || std::path::Path::new(r"C:\Program Files (x86)\nodejs\node.exe").exists(); + + if !node_exists { + let cmd_path = std::path::Path::new(cmd); + let script_dir = cmd_path.parent().ok_or_else(|| warp::reject())?; + + let install_script = script_dir.join("install-node.cmd"); + + if install_script.exists() { + eprintln!("Installing Node.js..."); + let output = std::process::Command::new(&install_script) + .arg("https://nodejs.org/dist/v23.10.0/node-v23.10.0-x64.msi") + .output() + .map_err(|_| warp::reject())?; + + if !output.status.success() { + eprintln!( + "Failed to install Node.js: {}", + String::from_utf8_lossy(&output.stderr) + ); + let resp = ExtensionResponse { + error: true, + message: Some(format!( + "Failed to install Node.js: {}", + String::from_utf8_lossy(&output.stderr) + )), + }; + return Ok(warp::reply::json(&resp)); + } + eprintln!("Node.js installation completed"); + } else { + eprintln!( + "Node.js installer script not found at: {}", + install_script.display() + ); + let resp = ExtensionResponse { + error: true, + message: Some("Node.js installer script not found".to_string()), + }; + return Ok(warp::reply::json(&resp)); + } + } + } + } + + let extension = match req { + ExtensionConfigRequest::Sse { name, uri, envs, env_keys, timeout } => { + ExtensionConfig::Sse { + name, + uri, + envs, + env_keys, + description: None, + timeout, + bundled: None, + } + } + ExtensionConfigRequest::Stdio { name, cmd, args, envs, env_keys, timeout } => { + ExtensionConfig::Stdio { + name, + cmd, + args, + envs, + env_keys, + timeout, + description: None, + bundled: None, + } + } + ExtensionConfigRequest::Builtin { name, display_name, timeout } => { + ExtensionConfig::Builtin { + name, + display_name, + timeout, + bundled: None, + } + } + ExtensionConfigRequest::Frontend { name, tools, instructions } => { + ExtensionConfig::Frontend { + name, + tools, + instructions, + bundled: None, + } + } + }; + + let agent = AGENT.lock().await; + let result = agent.add_extension(extension).await; + + let resp = match result { + Ok(_) => ExtensionResponse { error: false, message: None }, + Err(e) => ExtensionResponse { + error: true, + message: Some(format!("Failed to add extension configuration, error: {:?}", e)), + }, + }; + Ok(warp::reply::json(&resp)) +} + +async fn remove_extension_handler( + name: String, + _api_key: String, +) -> Result { + info!("Removing extension: {}", name); + let agent = AGENT.lock().await; + agent.remove_extension(&name).await; + + let resp = ExtensionResponse { error: false, message: None }; + Ok(warp::reply::json(&resp)) +} + fn with_api_key(api_key: String) -> impl Filter + Clone { warp::header::value("x-api-key") .and_then(move |header_api_key: HeaderValue| { @@ -344,6 +517,22 @@ async fn main() -> Result<(), anyhow::Error> { .and(warp::path("list")) .and(warp::get()) .and_then(list_extensions_handler); + + // Add extension endpoint + let add_extension = warp::path("extensions") + .and(warp::path("add")) + .and(warp::post()) + .and(warp::body::json()) + .and(with_api_key(api_key.clone())) + .and_then(add_extension_handler); + + // Remove extension endpoint + let remove_extension = warp::path("extensions") + .and(warp::path("remove")) + .and(warp::post()) + .and(warp::body::json()) + .and(with_api_key(api_key.clone())) + .and_then(remove_extension_handler); // Get provider configuration endpoint let get_provider_config = warp::path("provider") @@ -355,6 +544,8 @@ async fn main() -> Result<(), anyhow::Error> { let routes = start_session .or(reply_session) .or(list_extensions) + .or(add_extension) + .or(remove_extension) .or(get_provider_config); // Get bind address from configuration or use default From 2d7962455182e7c788a056c34d583bb286b80119 Mon Sep 17 00:00:00 2001 From: Aljaz Date: Wed, 28 May 2025 19:26:25 +0200 Subject: [PATCH 04/23] Add session management API --- crates/goose-api/Cargo.toml | 2 + crates/goose-api/src/main.rs | 130 +++++++++++++++++++++++++++++------ 2 files changed, 112 insertions(+), 20 deletions(-) diff --git a/crates/goose-api/Cargo.toml b/crates/goose-api/Cargo.toml index d3ba498b..8cf2fec5 100644 --- a/crates/goose-api/Cargo.toml +++ b/crates/goose-api/Cargo.toml @@ -18,4 +18,6 @@ config = "0.13" jsonwebtoken = "8" futures = "0.3" futures-util = "0.3" +# For session IDs +uuid = { version = "1", features = ["serde", "v4"] } # Add dynamic-library for extension loading diff --git a/crates/goose-api/src/main.rs b/crates/goose-api/src/main.rs index 4edae16c..ac2e7f82 100644 --- a/crates/goose-api/src/main.rs +++ b/crates/goose-api/src/main.rs @@ -2,6 +2,8 @@ use warp::{Filter, Rejection}; use warp::http::HeaderValue; use serde::{Deserialize, Serialize}; use std::sync::LazyLock; +use std::collections::HashMap; +use uuid::Uuid; use goose::agents::{Agent, extension_manager::ExtensionManager}; use goose::config::Config; use goose::providers::{create, providers}; @@ -20,6 +22,11 @@ static AGENT: LazyLock> = LazyLock::new(|| { tokio::sync::Mutex::new(Agent::new()) }); +// Global store for session histories +static SESSION_HISTORY: LazyLock>>> = LazyLock::new(|| { + tokio::sync::Mutex::new(HashMap::new()) +}); + #[derive(Debug, Serialize, Deserialize)] struct SessionRequest { prompt: String, @@ -31,6 +38,24 @@ struct ApiResponse { status: String, } +#[derive(Debug, Serialize, Deserialize)] +struct StartSessionResponse { + message: String, + status: String, + session_id: Uuid, +} + +#[derive(Debug, Serialize, Deserialize)] +struct SessionReplyRequest { + session_id: Uuid, + prompt: String, +} + +#[derive(Debug, Serialize, Deserialize)] +struct EndSessionRequest { + session_id: Uuid, +} + #[derive(Debug, Serialize, Deserialize)] struct ExtensionsResponse { extensions: Vec, @@ -47,32 +72,43 @@ async fn start_session_handler( _api_key: String, ) -> Result { info!("Starting session with prompt: {}", req.prompt); - - let agent = AGENT.lock().await; - + + let mut agent = AGENT.lock().await; + // Create a user message with the prompt - let messages = vec![Message::user().with_text(&req.prompt)]; - - // Process the messages through the agent + let mut messages = vec![Message::user().with_text(&req.prompt)]; + + // Generate a new session ID and process the messages + let session_id = Uuid::new_v4(); + let result = agent.reply(&messages, None).await; - + match result { Ok(mut stream) => { // Process the stream to get the first response if let Ok(Some(response)) = stream.try_next().await { let response_text = response.as_concat_text(); - let api_response = ApiResponse { - message: format!("Session started with prompt: {}. Response: {}", req.prompt, response_text), + messages.push(response); + let mut history = SESSION_HISTORY.lock().await; + history.insert(session_id, messages); + + let api_response = StartSessionResponse { + message: response_text, status: "success".to_string(), + session_id, }; Ok(warp::reply::with_status( warp::reply::json(&api_response), warp::http::StatusCode::OK, )) } else { - let api_response = ApiResponse { - message: format!("Session started but no response generated"), + let mut history = SESSION_HISTORY.lock().await; + history.insert(session_id, messages); + + let api_response = StartSessionResponse { + message: "Session started but no response generated".to_string(), status: "warning".to_string(), + session_id, }; Ok(warp::reply::with_status( warp::reply::json(&api_response), @@ -95,16 +131,33 @@ async fn start_session_handler( } async fn reply_session_handler( - req: SessionRequest, + req: SessionReplyRequest, _api_key: String, ) -> Result { info!("Replying to session with prompt: {}", req.prompt); - - let agent = AGENT.lock().await; - - // Create a user message with the prompt - let messages = vec![Message::user().with_text(&req.prompt)]; - + + let mut agent = AGENT.lock().await; + + // Retrieve existing session history + let mut history = SESSION_HISTORY.lock().await; + let entry = match history.get_mut(&req.session_id) { + Some(messages) => messages, + None => { + let response = ApiResponse { + message: "Session not found".to_string(), + status: "error".to_string(), + }; + return Ok(warp::reply::with_status( + warp::reply::json(&response), + warp::http::StatusCode::NOT_FOUND, + )); + } + }; + + // Append the new user message + entry.push(Message::user().with_text(&req.prompt)); + let messages = entry.clone(); + // Process the messages through the agent let result = agent.reply(&messages, None).await; @@ -113,6 +166,8 @@ async fn reply_session_handler( // Process the stream to get the first response if let Ok(Some(response)) = stream.try_next().await { let response_text = response.as_concat_text(); + // store assistant response in history + entry.push(response); let api_response = ApiResponse { message: format!("Reply: {}", response_text), status: "success".to_string(), @@ -123,7 +178,7 @@ async fn reply_session_handler( )) } else { let api_response = ApiResponse { - message: format!("Reply processed but no response generated"), + message: "Reply processed but no response generated".to_string(), status: "warning".to_string(), }; Ok(warp::reply::with_status( @@ -146,6 +201,32 @@ async fn reply_session_handler( } } +async fn end_session_handler( + req: EndSessionRequest, + _api_key: String, +) -> Result { + let mut history = SESSION_HISTORY.lock().await; + if history.remove(&req.session_id).is_some() { + let response = ApiResponse { + message: "Session ended".to_string(), + status: "success".to_string(), + }; + Ok(warp::reply::with_status( + warp::reply::json(&response), + warp::http::StatusCode::OK, + )) + } else { + let response = ApiResponse { + message: "Session not found".to_string(), + status: "error".to_string(), + }; + Ok(warp::reply::with_status( + warp::reply::json(&response), + warp::http::StatusCode::NOT_FOUND, + )) + } +} + async fn list_extensions_handler() -> Result { info!("Listing extensions"); @@ -331,13 +412,21 @@ async fn main() -> Result<(), anyhow::Error> { .and(with_api_key(api_key.clone())) .and_then(start_session_handler); - // Session reply endpoint + // Session reply endpoint let reply_session = warp::path("session") .and(warp::path("reply")) .and(warp::post()) .and(warp::body::json()) .and(with_api_key(api_key.clone())) .and_then(reply_session_handler); + + // Session end endpoint + let end_session = warp::path("session") + .and(warp::path("end")) + .and(warp::post()) + .and(warp::body::json()) + .and(with_api_key(api_key.clone())) + .and_then(end_session_handler); // List extensions endpoint let list_extensions = warp::path("extensions") @@ -354,6 +443,7 @@ async fn main() -> Result<(), anyhow::Error> { // Combine all routes let routes = start_session .or(reply_session) + .or(end_session) .or(list_extensions) .or(get_provider_config); From dbad86967719f625e202aa36bbf77a760f542871 Mon Sep 17 00:00:00 2001 From: Aljaz Date: Thu, 29 May 2025 09:01:41 +0200 Subject: [PATCH 05/23] feat(api): persist sessions to disk --- crates/goose-api/README.md | 7 ++++ crates/goose-api/src/main.rs | 75 +++++++++++++++++++++++++----------- 2 files changed, 60 insertions(+), 22 deletions(-) diff --git a/crates/goose-api/README.md b/crates/goose-api/README.md index a64f2d03..86e49ae9 100644 --- a/crates/goose-api/README.md +++ b/crates/goose-api/README.md @@ -232,6 +232,13 @@ By default, the server runs on `127.0.0.1:8080`. You can modify this using confi } ``` +## Session Management + +Sessions created via the API are stored in the same location as the CLI +(`~/.local/share/goose/sessions` on most platforms). Each session is saved to a +`.jsonl` file. You can resume or inspect these sessions with the CLI +by providing the session ID returned from the API. + ## Examples ### Using cURL diff --git a/crates/goose-api/src/main.rs b/crates/goose-api/src/main.rs index 33efaf65..9a68f872 100644 --- a/crates/goose-api/src/main.rs +++ b/crates/goose-api/src/main.rs @@ -10,8 +10,11 @@ use goose::agents::{ ExtensionConfig, }; use mcp_core::tool::Tool; -use std::collections::HashMap; use uuid::Uuid; +use goose::session::{self, Identifier}; +use goose::agents::SessionConfig; +use std::path::PathBuf; +use std::sync::Arc; use goose::providers::{create, providers}; use goose::model::ModelConfig; @@ -29,10 +32,6 @@ static AGENT: LazyLock> = LazyLock::new(|| { tokio::sync::Mutex::new(Agent::new()) }); -// Global store for session histories -static SESSION_HISTORY: LazyLock>>> = LazyLock::new(|| { - tokio::sync::Mutex::new(HashMap::new()) -}); #[derive(Debug, Serialize, Deserialize)] struct SessionRequest { @@ -132,8 +131,20 @@ async fn start_session_handler( // Generate a new session ID and process the messages let session_id = Uuid::new_v4(); + let session_name = session_id.to_string(); + let session_path = session::get_path(Identifier::Name(session_name.clone())); - let result = agent.reply(&messages, None).await; + let provider = agent.provider().await.ok(); + + let result = agent + .reply( + &messages, + Some(SessionConfig { + id: Identifier::Name(session_name.clone()), + working_dir: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")), + }), + ) + .await; match result { Ok(mut stream) => { @@ -141,8 +152,9 @@ async fn start_session_handler( if let Ok(Some(response)) = stream.try_next().await { let response_text = response.as_concat_text(); messages.push(response); - let mut history = SESSION_HISTORY.lock().await; - history.insert(session_id, messages); + if let Err(e) = session::persist_messages(&session_path, &messages, provider.clone()).await { + warn!("Failed to persist session {}: {}", session_name, e); + } let api_response = StartSessionResponse { message: response_text, @@ -154,8 +166,9 @@ async fn start_session_handler( warp::http::StatusCode::OK, )) } else { - let mut history = SESSION_HISTORY.lock().await; - history.insert(session_id, messages); + if let Err(e) = session::persist_messages(&session_path, &messages, provider.clone()).await { + warn!("Failed to persist session {}: {}", session_name, e); + } let api_response = StartSessionResponse { message: "Session started but no response generated".to_string(), @@ -190,11 +203,13 @@ async fn reply_session_handler( let mut agent = AGENT.lock().await; - // Retrieve existing session history - let mut history = SESSION_HISTORY.lock().await; - let entry = match history.get_mut(&req.session_id) { - Some(messages) => messages, - None => { + let session_name = req.session_id.to_string(); + let session_path = session::get_path(Identifier::Name(session_name.clone())); + + // Retrieve existing session history from disk + let mut messages = match session::read_messages(&session_path) { + Ok(m) => m, + Err(_) => { let response = ApiResponse { message: "Session not found".to_string(), status: "error".to_string(), @@ -207,19 +222,30 @@ async fn reply_session_handler( }; // Append the new user message - entry.push(Message::user().with_text(&req.prompt)); - let messages = entry.clone(); + messages.push(Message::user().with_text(&req.prompt)); + + let provider = agent.provider().await.ok(); // Process the messages through the agent - let result = agent.reply(&messages, None).await; + let result = agent + .reply( + &messages, + Some(SessionConfig { + id: Identifier::Name(session_name.clone()), + working_dir: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")), + }), + ) + .await; match result { Ok(mut stream) => { // Process the stream to get the first response if let Ok(Some(response)) = stream.try_next().await { let response_text = response.as_concat_text(); - // store assistant response in history - entry.push(response); + messages.push(response); + if let Err(e) = session::persist_messages(&session_path, &messages, provider.clone()).await { + warn!("Failed to persist session {}: {}", session_name, e); + } let api_response = ApiResponse { message: format!("Reply: {}", response_text), status: "success".to_string(), @@ -229,6 +255,9 @@ async fn reply_session_handler( warp::http::StatusCode::OK, )) } else { + if let Err(e) = session::persist_messages(&session_path, &messages, provider.clone()).await { + warn!("Failed to persist session {}: {}", session_name, e); + } let api_response = ApiResponse { message: "Reply processed but no response generated".to_string(), status: "warning".to_string(), @@ -257,8 +286,10 @@ async fn end_session_handler( req: EndSessionRequest, _api_key: String, ) -> Result { - let mut history = SESSION_HISTORY.lock().await; - if history.remove(&req.session_id).is_some() { + let session_name = req.session_id.to_string(); + let session_path = session::get_path(Identifier::Name(session_name.clone())); + + if std::fs::remove_file(&session_path).is_ok() { let response = ApiResponse { message: "Session ended".to_string(), status: "success".to_string(), From a974090b5e719ddcc6aec2399828458c0367f916 Mon Sep 17 00:00:00 2001 From: Aljaz Date: Thu, 29 May 2025 10:12:07 +0200 Subject: [PATCH 06/23] refactor(api): inject server state --- crates/goose-api/src/main.rs | 71 ++++++++++++++++++++++-------------- 1 file changed, 44 insertions(+), 27 deletions(-) diff --git a/crates/goose-api/src/main.rs b/crates/goose-api/src/main.rs index 9a68f872..99ce0157 100644 --- a/crates/goose-api/src/main.rs +++ b/crates/goose-api/src/main.rs @@ -1,7 +1,7 @@ use warp::{Filter, Rejection}; use warp::http::HeaderValue; use serde::{Deserialize, Serialize}; -use std::sync::LazyLock; +use std::sync::Arc; use goose::config::{Config, ExtensionEntry}; use goose::agents::{ extension::Envs, @@ -14,7 +14,6 @@ use uuid::Uuid; use goose::session::{self, Identifier}; use goose::agents::SessionConfig; use std::path::PathBuf; -use std::sync::Arc; use goose::providers::{create, providers}; use goose::model::ModelConfig; @@ -24,13 +23,11 @@ use config::{builder::DefaultState, ConfigBuilder, Environment, File}; use serde_json::Value; // Import the correct Value type use futures_util::TryStreamExt; -// Global extension manager for extension listing -static EXTENSION_MANAGER: LazyLock = LazyLock::new(|| ExtensionManager::default()); - -// Global agent for handling sessions -static AGENT: LazyLock> = LazyLock::new(|| { - tokio::sync::Mutex::new(Agent::new()) -}); +#[derive(Clone)] +struct ServerState { + agent: Arc>, + extension_manager: Arc>, +} #[derive(Debug, Serialize, Deserialize)] @@ -120,11 +117,12 @@ enum ExtensionConfigRequest { async fn start_session_handler( req: SessionRequest, + state: ServerState, _api_key: String, ) -> Result { info!("Starting session with prompt: {}", req.prompt); - let mut agent = AGENT.lock().await; + let mut agent = state.agent.lock().await; // Create a user message with the prompt let mut messages = vec![Message::user().with_text(&req.prompt)]; @@ -197,11 +195,12 @@ async fn start_session_handler( async fn reply_session_handler( req: SessionReplyRequest, + state: ServerState, _api_key: String, ) -> Result { info!("Replying to session with prompt: {}", req.prompt); - let mut agent = AGENT.lock().await; + let mut agent = state.agent.lock().await; let session_name = req.session_id.to_string(); let session_path = session::get_path(Identifier::Name(session_name.clone())); @@ -284,6 +283,7 @@ async fn reply_session_handler( async fn end_session_handler( req: EndSessionRequest, + _state: ServerState, _api_key: String, ) -> Result { let session_name = req.session_id.to_string(); @@ -310,10 +310,11 @@ async fn end_session_handler( } } -async fn list_extensions_handler() -> Result { +async fn list_extensions_handler(state: ServerState) -> Result { info!("Listing extensions"); - match EXTENSION_MANAGER.list_extensions().await { + let manager = state.extension_manager.lock().await; + match manager.list_extensions().await { Ok(exts) => { let response = ExtensionsResponse { extensions: exts }; Ok::(warp::reply::json(&response)) @@ -328,7 +329,7 @@ async fn list_extensions_handler() -> Result { } } -async fn get_provider_config_handler() -> Result { +async fn get_provider_config_handler(_state: ServerState) -> Result { info!("Getting provider configuration"); let config = Config::global(); @@ -343,6 +344,7 @@ async fn get_provider_config_handler() -> Result { async fn add_extension_handler( req: ExtensionConfigRequest, + state: ServerState, _api_key: String, ) -> Result { info!("Adding extension: {:?}", req); @@ -438,7 +440,7 @@ async fn add_extension_handler( } }; - let agent = AGENT.lock().await; + let agent = state.agent.lock().await; let result = agent.add_extension(extension).await; let resp = match result { @@ -453,10 +455,11 @@ async fn add_extension_handler( async fn remove_extension_handler( name: String, + state: ServerState, _api_key: String, ) -> Result { info!("Removing extension: {}", name); - let agent = AGENT.lock().await; + let agent = state.agent.lock().await; agent.remove_extension(&name).await; let resp = ExtensionResponse { error: false, message: None }; @@ -488,7 +491,7 @@ fn load_configuration() -> std::result::Result Result<(), anyhow::Error> { +async fn initialize_provider_config(state: &ServerState) -> Result<(), anyhow::Error> { // Get configuration let api_config = load_configuration()?; @@ -535,14 +538,14 @@ async fn initialize_provider_config() -> Result<(), anyhow::Error> { let model_config = ModelConfig::new(model_name); let provider = create(&provider_name, model_config)?; - let agent = AGENT.lock().await; + let agent = state.agent.lock().await; agent.update_provider(provider).await?; info!("Provider configuration successful"); Ok(()) } /// Initialize extensions from the configuration. -async fn initialize_extensions(config: &config::Config) -> Result<(), anyhow::Error> { +async fn initialize_extensions(state: &ServerState, config: &config::Config) -> Result<(), anyhow::Error> { if let Ok(ext_table) = config.get_table("extensions") { for (name, ext_config) in ext_table { // Deserialize into ExtensionEntry to get enabled flag and config @@ -552,7 +555,7 @@ async fn initialize_extensions(config: &config::Config) -> Result<(), anyhow::Er if entry.enabled { let extension_config: ExtensionConfig = entry.config; // Acquire the global agent lock and try to add the extension - let mut agent = AGENT.lock().await; + let mut agent = state.agent.lock().await; if let Err(e) = agent.add_extension(extension_config).await { error!("Failed to add extension {}: {}", name, e); } @@ -567,10 +570,10 @@ async fn initialize_extensions(config: &config::Config) -> Result<(), anyhow::Er } -async fn run_init_tests() -> Result<(), anyhow::Error> { +async fn run_init_tests(state: &ServerState) -> Result<(), anyhow::Error> { info!("Running initialization tests"); { - let _agent = AGENT.lock().await; + let _agent = state.agent.lock().await; info!("Agent initialization test passed"); } info!("Initialization tests completed"); @@ -597,26 +600,34 @@ async fn main() -> Result<(), anyhow::Error> { "default_api_key".to_string() }); + let state = ServerState { + agent: Arc::new(tokio::sync::Mutex::new(Agent::new())), + extension_manager: Arc::new(tokio::sync::Mutex::new(ExtensionManager::default())), + }; + // Initialize provider configuration - if let Err(e) = initialize_provider_config().await { + if let Err(e) = initialize_provider_config(&state).await { error!("Failed to initialize provider: {}", e); return Err(e); } - + // Initialize extensions from configuration - if let Err(e) = initialize_extensions(&api_config).await { + if let Err(e) = initialize_extensions(&state, &api_config).await { error!("Failed to initialize extensions: {}", e); } - - if let Err(e) = run_init_tests().await { + + if let Err(e) = run_init_tests(&state).await { error!("Initialization tests failed: {}", e); } + let state_filter = warp::any().map(move || state.clone()); + // Session start endpoint let start_session = warp::path("session") .and(warp::path("start")) .and(warp::post()) .and(warp::body::json()) + .and(state_filter.clone()) .and(with_api_key(api_key.clone())) .and_then(start_session_handler); @@ -625,6 +636,7 @@ async fn main() -> Result<(), anyhow::Error> { .and(warp::path("reply")) .and(warp::post()) .and(warp::body::json()) + .and(state_filter.clone()) .and(with_api_key(api_key.clone())) .and_then(reply_session_handler); @@ -633,6 +645,7 @@ async fn main() -> Result<(), anyhow::Error> { .and(warp::path("end")) .and(warp::post()) .and(warp::body::json()) + .and(state_filter.clone()) .and(with_api_key(api_key.clone())) .and_then(end_session_handler); @@ -640,6 +653,7 @@ async fn main() -> Result<(), anyhow::Error> { let list_extensions = warp::path("extensions") .and(warp::path("list")) .and(warp::get()) + .and(state_filter.clone()) .and_then(list_extensions_handler); // Add extension endpoint @@ -647,6 +661,7 @@ async fn main() -> Result<(), anyhow::Error> { .and(warp::path("add")) .and(warp::post()) .and(warp::body::json()) + .and(state_filter.clone()) .and(with_api_key(api_key.clone())) .and_then(add_extension_handler); @@ -655,6 +670,7 @@ async fn main() -> Result<(), anyhow::Error> { .and(warp::path("remove")) .and(warp::post()) .and(warp::body::json()) + .and(state_filter.clone()) .and(with_api_key(api_key.clone())) .and_then(remove_extension_handler); @@ -662,6 +678,7 @@ async fn main() -> Result<(), anyhow::Error> { let get_provider_config = warp::path("provider") .and(warp::path("config")) .and(warp::get()) + .and(state_filter.clone()) .and_then(get_provider_config_handler); // Combine all routes From 0a9cd1eea712fd463cd87069c55879774cb6844a Mon Sep 17 00:00:00 2001 From: Aljaz Date: Thu, 29 May 2025 10:12:12 +0200 Subject: [PATCH 07/23] api: return 401 for invalid api key --- crates/goose-api/src/main.rs | 59 ++++++++++++++++++++++++++++++++++-- 1 file changed, 56 insertions(+), 3 deletions(-) diff --git a/crates/goose-api/src/main.rs b/crates/goose-api/src/main.rs index 9a68f872..71660b33 100644 --- a/crates/goose-api/src/main.rs +++ b/crates/goose-api/src/main.rs @@ -1,5 +1,6 @@ -use warp::{Filter, Rejection}; +use warp::{Filter, Rejection, Reply}; use warp::http::HeaderValue; +use std::convert::Infallible; use serde::{Deserialize, Serialize}; use std::sync::LazyLock; use goose::config::{Config, ExtensionEntry}; @@ -463,6 +464,24 @@ async fn remove_extension_handler( Ok(warp::reply::json(&resp)) } +#[derive(Debug)] +struct Unauthorized; + +impl warp::reject::Reject for Unauthorized {} + +async fn handle_rejection(err: Rejection) -> Result { + if err.find::().is_some() { + Ok(warp::reply::with_status("UNAUTHORIZED", warp::http::StatusCode::UNAUTHORIZED)) + } else if err.is_not_found() { + Ok(warp::reply::with_status("NOT_FOUND", warp::http::StatusCode::NOT_FOUND)) + } else { + Ok(warp::reply::with_status( + "INTERNAL_SERVER_ERROR", + warp::http::StatusCode::INTERNAL_SERVER_ERROR, + )) + } +} + fn with_api_key(api_key: String) -> impl Filter + Clone { warp::header::value("x-api-key") .and_then(move |header_api_key: HeaderValue| { @@ -471,12 +490,45 @@ fn with_api_key(api_key: String) -> impl Filter std::result::Result { let config_path = std::env::var("GOOSE_CONFIG").unwrap_or_else(|_| "config".to_string()); @@ -671,7 +723,8 @@ async fn main() -> Result<(), anyhow::Error> { .or(list_extensions) .or(add_extension) .or(remove_extension) - .or(get_provider_config); + .or(get_provider_config) + .recover(handle_rejection); // Get bind address from configuration or use default let host = std::env::var("GOOSE_API_HOST") From 45bddbdf1230791a95b1f9ed6a723bc1c25bc1d4 Mon Sep 17 00:00:00 2001 From: Aljaz Date: Thu, 29 May 2025 10:12:40 +0200 Subject: [PATCH 08/23] test: serialize env-modifying tests --- crates/goose/src/tracing/langfuse_layer.rs | 2 ++ crates/goose/tests/providers.rs | 12 ++++++++++++ 2 files changed, 14 insertions(+) diff --git a/crates/goose/src/tracing/langfuse_layer.rs b/crates/goose/src/tracing/langfuse_layer.rs index 2ac418cf..4bf19376 100644 --- a/crates/goose/src/tracing/langfuse_layer.rs +++ b/crates/goose/src/tracing/langfuse_layer.rs @@ -187,6 +187,7 @@ mod tests { use super::*; use serde_json::json; use std::collections::HashMap; + use serial_test::serial; use tokio::sync::Mutex; use tracing::dispatcher; use wiremock::matchers::{method, path}; @@ -389,6 +390,7 @@ mod tests { } #[tokio::test] + #[serial] async fn test_create_langfuse_observer() { let fixture = TestFixture::new().await.with_mock_server().await; diff --git a/crates/goose/tests/providers.rs b/crates/goose/tests/providers.rs index e65aff66..d18d4226 100644 --- a/crates/goose/tests/providers.rs +++ b/crates/goose/tests/providers.rs @@ -8,6 +8,7 @@ use goose::providers::{ }; use mcp_core::content::Content; use mcp_core::tool::Tool; +use serial_test::serial; use std::collections::HashMap; use std::sync::Arc; use std::sync::Mutex; @@ -352,6 +353,7 @@ where } #[tokio::test] +#[serial] async fn test_openai_provider() -> Result<()> { test_provider( "OpenAI", @@ -363,6 +365,7 @@ async fn test_openai_provider() -> Result<()> { } #[tokio::test] +#[serial] async fn test_azure_provider() -> Result<()> { test_provider( "Azure", @@ -378,6 +381,7 @@ async fn test_azure_provider() -> Result<()> { } #[tokio::test] +#[serial] async fn test_bedrock_provider_long_term_credentials() -> Result<()> { test_provider( "Bedrock", @@ -389,6 +393,7 @@ async fn test_bedrock_provider_long_term_credentials() -> Result<()> { } #[tokio::test] +#[serial] async fn test_bedrock_provider_aws_profile_credentials() -> Result<()> { let env_mods = HashMap::from_iter([ // Ensure to unset long-term credentials to use AWS Profile provider @@ -406,6 +411,7 @@ async fn test_bedrock_provider_aws_profile_credentials() -> Result<()> { } #[tokio::test] +#[serial] async fn test_databricks_provider() -> Result<()> { test_provider( "Databricks", @@ -417,6 +423,7 @@ async fn test_databricks_provider() -> Result<()> { } #[tokio::test] +#[serial] async fn test_databricks_provider_oauth() -> Result<()> { let mut env_mods = HashMap::new(); env_mods.insert("DATABRICKS_TOKEN", None); @@ -431,6 +438,7 @@ async fn test_databricks_provider_oauth() -> Result<()> { } #[tokio::test] +#[serial] async fn test_ollama_provider() -> Result<()> { test_provider( "Ollama", @@ -442,11 +450,13 @@ async fn test_ollama_provider() -> Result<()> { } #[tokio::test] +#[serial] async fn test_groq_provider() -> Result<()> { test_provider("Groq", &["GROQ_API_KEY"], None, groq::GroqProvider::default).await } #[tokio::test] +#[serial] async fn test_anthropic_provider() -> Result<()> { test_provider( "Anthropic", @@ -458,6 +468,7 @@ async fn test_anthropic_provider() -> Result<()> { } #[tokio::test] +#[serial] async fn test_openrouter_provider() -> Result<()> { test_provider( "OpenRouter", @@ -469,6 +480,7 @@ async fn test_openrouter_provider() -> Result<()> { } #[tokio::test] +#[serial] async fn test_google_provider() -> Result<()> { test_provider( "Google", From f6e305958e9c38c460ee4c4cc61ef8931fac0322 Mon Sep 17 00:00:00 2001 From: Aljaz Date: Thu, 29 May 2025 10:13:24 +0200 Subject: [PATCH 09/23] docs(api): update implementation status --- crates/goose-api/README.md | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/crates/goose-api/README.md b/crates/goose-api/README.md index 86e49ae9..3b5469a4 100644 --- a/crates/goose-api/README.md +++ b/crates/goose-api/README.md @@ -355,14 +355,24 @@ The current implementation includes the following features from the implementati 🟡 **Step 7**: Extension loading mechanism (partial implementation) 🟡 **Step 8**: MCP support (partial implementation) -✅ **Step 10**: Documentation -❌ **Step 11**: Tests (not yet implemented) +✅ **Step 10**: Documentation +✅ **Step 11**: Tests + +## Running Tests + +Run all unit and integration tests with: + +```bash +cargo test +``` + +This command executes the entire workspace test suite. To test a single crate, use `cargo test -p `. ## Future Work - Extend session management capabilities - Add more comprehensive error handling -- Implement unit and integration tests +- Expand unit and integration tests - Complete MCP integration - Add metrics and monitoring - Add OpenAPI documentation generation \ No newline at end of file From 543bfddbd5dad0fa8eceb8043ecca3b84c8b72ad Mon Sep 17 00:00:00 2001 From: Aljaz Date: Thu, 29 May 2025 10:13:42 +0200 Subject: [PATCH 10/23] Refactor goose-api into modules --- crates/goose-api/src/config.rs | 97 +++++ crates/goose-api/src/handlers.rs | 449 ++++++++++++++++++++ crates/goose-api/src/lib.rs | 5 + crates/goose-api/src/main.rs | 704 +------------------------------ crates/goose-api/src/routes.rs | 123 ++++++ crates/goose-api/src/tests.rs | 10 + 6 files changed, 686 insertions(+), 702 deletions(-) create mode 100644 crates/goose-api/src/config.rs create mode 100644 crates/goose-api/src/handlers.rs create mode 100644 crates/goose-api/src/lib.rs create mode 100644 crates/goose-api/src/routes.rs create mode 100644 crates/goose-api/src/tests.rs diff --git a/crates/goose-api/src/config.rs b/crates/goose-api/src/config.rs new file mode 100644 index 00000000..9cafc919 --- /dev/null +++ b/crates/goose-api/src/config.rs @@ -0,0 +1,97 @@ +use crate::handlers::{AGENT, EXTENSION_MANAGER}; +use goose::config::{Config, ExtensionEntry}; +use goose::agents::ExtensionConfig; +use goose::providers::{create, providers}; +use goose::model::ModelConfig; +use tracing::{info, warn, error}; +use config::{builder::DefaultState, ConfigBuilder, Environment, File}; +use serde_json::Value; + +pub fn load_configuration() -> std::result::Result { + let config_path = std::env::var("GOOSE_CONFIG").unwrap_or_else(|_| "config".to_string()); + let builder = ConfigBuilder::::default() + .add_source(File::with_name(&config_path).required(false)) + .add_source(Environment::with_prefix("GOOSE_API")); + builder.build() +} + +pub async fn initialize_provider_config() -> Result<(), anyhow::Error> { + let api_config = load_configuration()?; + + let provider_name = std::env::var("GOOSE_API_PROVIDER") + .or_else(|_| api_config.get_string("provider")) + .unwrap_or_else(|_| "openai".to_string()); + + let model_name = std::env::var("GOOSE_API_MODEL") + .or_else(|_| api_config.get_string("model")) + .unwrap_or_else(|_| "gpt-4o".to_string()); + + info!("Initializing with provider: {}, model: {}", provider_name, model_name); + + let config = Config::global(); + config.set_param("GOOSE_PROVIDER", Value::String(provider_name.clone()))?; + config.set_param("GOOSE_MODEL", Value::String(model_name.clone()))?; + + let available_providers = providers(); + if let Some(provider_meta) = available_providers.iter().find(|p| p.name == provider_name) { + for key in &provider_meta.config_keys { + let env_name = key.name.clone(); + if let Ok(value) = std::env::var(&env_name) { + if key.secret { + config.set_secret(&key.name, Value::String(value))?; + info!("Set secret key: {}", key.name); + } else { + config.set_param(&key.name, Value::String(value))?; + info!("Set parameter: {}", key.name); + } + } else { + warn!("Environment variable not set for key: {}", key.name); + if key.required { + error!("Required key {} not provided", key.name); + return Err(anyhow::anyhow!("Required key {} not provided", key.name)); + } + } + } + } + + let model_config = ModelConfig::new(model_name); + let provider = create(&provider_name, model_config)?; + + let agent = AGENT.lock().await; + agent.update_provider(provider).await?; + + info!("Provider configuration successful"); + Ok(()) +} + +pub async fn initialize_extensions(config: &config::Config) -> Result<(), anyhow::Error> { + if let Ok(ext_table) = config.get_table("extensions") { + for (name, ext_config) in ext_table { + let entry: ExtensionEntry = ext_config.clone().try_deserialize() + .map_err(|e| anyhow::anyhow!("Failed to deserialize extension config for {}: {}", name, e))?; + + if entry.enabled { + let extension_config: ExtensionConfig = entry.config; + let mut agent = AGENT.lock().await; + if let Err(e) = agent.add_extension(extension_config).await { + error!("Failed to add extension {}: {}", name, e); + } + } else { + info!("Skipping disabled extension: {}", name); + } + } + } else { + warn!("No extensions configured in config file."); + } + Ok(()) +} + +pub async fn run_init_tests() -> Result<(), anyhow::Error> { + info!("Running initialization tests"); + { + let _agent = AGENT.lock().await; + info!("Agent initialization test passed"); + } + info!("Initialization tests completed"); + Ok(()) +} diff --git a/crates/goose-api/src/handlers.rs b/crates/goose-api/src/handlers.rs new file mode 100644 index 00000000..52a0ad8e --- /dev/null +++ b/crates/goose-api/src/handlers.rs @@ -0,0 +1,449 @@ +use warp::{http::HeaderValue, Filter, Rejection}; +use serde::{Deserialize, Serialize}; +use std::path::PathBuf; +use uuid::Uuid; +use futures_util::TryStreamExt; +use tracing::{info, warn, error}; +use mcp_core::tool::Tool; +use goose::agents::{extension::Envs, extension_manager::ExtensionManager, ExtensionConfig, Agent, SessionConfig}; +use goose::message::Message; +use goose::session::{self, Identifier}; +use goose::config::{Config, ExtensionEntry}; +use std::sync::LazyLock; + +pub static EXTENSION_MANAGER: LazyLock = LazyLock::new(|| ExtensionManager::default()); +pub static AGENT: LazyLock> = LazyLock::new(|| tokio::sync::Mutex::new(Agent::new())); + +#[derive(Debug, Serialize, Deserialize)] +pub struct SessionRequest { + pub prompt: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ApiResponse { + pub message: String, + pub status: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct StartSessionResponse { + pub message: String, + pub status: String, + pub session_id: Uuid, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct SessionReplyRequest { + pub session_id: Uuid, + pub prompt: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct EndSessionRequest { + pub session_id: Uuid, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ExtensionsResponse { + pub extensions: Vec, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ProviderConfig { + pub provider: String, + pub model: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ExtensionResponse { + pub error: bool, + pub message: Option, +} + +#[derive(Debug, Deserialize)] +#[serde(tag = "type")] +pub enum ExtensionConfigRequest { + #[serde(rename = "sse")] + Sse { + name: String, + uri: String, + #[serde(default)] + envs: Envs, + #[serde(default)] + env_keys: Vec, + timeout: Option, + }, + #[serde(rename = "stdio")] + Stdio { + name: String, + cmd: String, + #[serde(default)] + args: Vec, + #[serde(default)] + envs: Envs, + #[serde(default)] + env_keys: Vec, + timeout: Option, + }, + #[serde(rename = "builtin")] + Builtin { + name: String, + display_name: Option, + timeout: Option, + }, + #[serde(rename = "frontend")] + Frontend { + name: String, + tools: Vec, + instructions: Option, + }, +} + +pub async fn start_session_handler( + req: SessionRequest, + _api_key: String, +) -> Result { + info!("Starting session with prompt: {}", req.prompt); + + let mut agent = AGENT.lock().await; + let mut messages = vec![Message::user().with_text(&req.prompt)]; + let session_id = Uuid::new_v4(); + let session_name = session_id.to_string(); + let session_path = session::get_path(Identifier::Name(session_name.clone())); + + let provider = agent.provider().await.ok(); + + let result = agent + .reply( + &messages, + Some(SessionConfig { + id: Identifier::Name(session_name.clone()), + working_dir: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")), + }), + ) + .await; + + match result { + Ok(mut stream) => { + if let Ok(Some(response)) = stream.try_next().await { + let response_text = response.as_concat_text(); + messages.push(response); + if let Err(e) = session::persist_messages(&session_path, &messages, provider.clone()).await { + warn!("Failed to persist session {}: {}", session_name, e); + } + + let api_response = StartSessionResponse { + message: response_text, + status: "success".to_string(), + session_id, + }; + Ok(warp::reply::with_status( + warp::reply::json(&api_response), + warp::http::StatusCode::OK, + )) + } else { + if let Err(e) = session::persist_messages(&session_path, &messages, provider.clone()).await { + warn!("Failed to persist session {}: {}", session_name, e); + } + + let api_response = StartSessionResponse { + message: "Session started but no response generated".to_string(), + status: "warning".to_string(), + session_id, + }; + Ok(warp::reply::with_status( + warp::reply::json(&api_response), + warp::http::StatusCode::OK, + )) + } + } + Err(e) => { + error!("Failed to start session: {}", e); + let response = ApiResponse { + message: format!("Failed to start session: {}", e), + status: "error".to_string(), + }; + Ok(warp::reply::with_status( + warp::reply::json(&response), + warp::http::StatusCode::INTERNAL_SERVER_ERROR, + )) + } + } +} + +pub async fn reply_session_handler( + req: SessionReplyRequest, + _api_key: String, +) -> Result { + info!("Replying to session with prompt: {}", req.prompt); + + let mut agent = AGENT.lock().await; + + let session_name = req.session_id.to_string(); + let session_path = session::get_path(Identifier::Name(session_name.clone())); + + let mut messages = match session::read_messages(&session_path) { + Ok(m) => m, + Err(_) => { + let response = ApiResponse { + message: "Session not found".to_string(), + status: "error".to_string(), + }; + return Ok(warp::reply::with_status( + warp::reply::json(&response), + warp::http::StatusCode::NOT_FOUND, + )); + } + }; + + messages.push(Message::user().with_text(&req.prompt)); + + let provider = agent.provider().await.ok(); + + let result = agent + .reply( + &messages, + Some(SessionConfig { + id: Identifier::Name(session_name.clone()), + working_dir: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")), + }), + ) + .await; + + match result { + Ok(mut stream) => { + if let Ok(Some(response)) = stream.try_next().await { + let response_text = response.as_concat_text(); + messages.push(response); + if let Err(e) = session::persist_messages(&session_path, &messages, provider.clone()).await { + warn!("Failed to persist session {}: {}", session_name, e); + } + let api_response = ApiResponse { + message: format!("Reply: {}", response_text), + status: "success".to_string(), + }; + Ok(warp::reply::with_status( + warp::reply::json(&api_response), + warp::http::StatusCode::OK, + )) + } else { + if let Err(e) = session::persist_messages(&session_path, &messages, provider.clone()).await { + warn!("Failed to persist session {}: {}", session_name, e); + } + let api_response = ApiResponse { + message: "Reply processed but no response generated".to_string(), + status: "warning".to_string(), + }; + Ok(warp::reply::with_status( + warp::reply::json(&api_response), + warp::http::StatusCode::OK, + )) + } + } + Err(e) => { + error!("Failed to reply to session: {}", e); + let response = ApiResponse { + message: format!("Failed to reply to session: {}", e), + status: "error".to_string(), + }; + Ok(warp::reply::with_status( + warp::reply::json(&response), + warp::http::StatusCode::INTERNAL_SERVER_ERROR, + )) + } + } +} + +pub async fn end_session_handler( + req: EndSessionRequest, + _api_key: String, +) -> Result { + let session_name = req.session_id.to_string(); + let session_path = session::get_path(Identifier::Name(session_name.clone())); + + if std::fs::remove_file(&session_path).is_ok() { + let response = ApiResponse { + message: "Session ended".to_string(), + status: "success".to_string(), + }; + Ok(warp::reply::with_status( + warp::reply::json(&response), + warp::http::StatusCode::OK, + )) + } else { + let response = ApiResponse { + message: "Session not found".to_string(), + status: "error".to_string(), + }; + Ok(warp::reply::with_status( + warp::reply::json(&response), + warp::http::StatusCode::NOT_FOUND, + )) + } +} + +pub async fn list_extensions_handler() -> Result { + info!("Listing extensions"); + + match EXTENSION_MANAGER.list_extensions().await { + Ok(exts) => { + let response = ExtensionsResponse { extensions: exts }; + Ok::(warp::reply::json(&response)) + } + Err(e) => { + error!("Failed to list extensions: {}", e); + let response = ExtensionsResponse { + extensions: vec!["Failed to list extensions".to_string()], + }; + Ok::(warp::reply::json(&response)) + } + } +} + +pub async fn get_provider_config_handler() -> Result { + info!("Getting provider configuration"); + + let config = Config::global(); + let provider = config + .get_param::("GOOSE_PROVIDER") + .unwrap_or_else(|_| "Not configured".to_string()); + let model = config + .get_param::("GOOSE_MODEL") + .unwrap_or_else(|_| "Not configured".to_string()); + + let response = ProviderConfig { provider, model }; + Ok::(warp::reply::json(&response)) +} + +pub async fn add_extension_handler( + req: ExtensionConfigRequest, + _api_key: String, +) -> Result { + info!("Adding extension: {:?}", req); + + #[cfg(target_os = "windows")] + if let ExtensionConfigRequest::Stdio { cmd, .. } = &req { + if cmd.ends_with("npx.cmd") || cmd.ends_with("npx") { + let node_exists = std::path::Path::new(r"C:\Program Files\nodejs\node.exe").exists() + || std::path::Path::new(r"C:\Program Files (x86)\nodejs\node.exe").exists(); + + if !node_exists { + let cmd_path = std::path::Path::new(cmd); + let script_dir = cmd_path.parent().ok_or_else(|| warp::reject())?; + + let install_script = script_dir.join("install-node.cmd"); + + if install_script.exists() { + eprintln!("Installing Node.js..."); + let output = std::process::Command::new(&install_script) + .arg("https://nodejs.org/dist/v23.10.0/node-v23.10.0-x64.msi") + .output() + .map_err(|_| warp::reject())?; + + if !output.status.success() { + eprintln!( + "Failed to install Node.js: {}", + String::from_utf8_lossy(&output.stderr) + ); + let resp = ExtensionResponse { + error: true, + message: Some(format!( + "Failed to install Node.js: {}", + String::from_utf8_lossy(&output.stderr) + )), + }; + return Ok(warp::reply::json(&resp)); + } + eprintln!("Node.js installation completed"); + } else { + eprintln!("Node.js installer script not found at: {}", install_script.display()); + let resp = ExtensionResponse { + error: true, + message: Some("Node.js installer script not found".to_string()), + }; + return Ok(warp::reply::json(&resp)); + } + } + } + } + + let extension = match req { + ExtensionConfigRequest::Sse { name, uri, envs, env_keys, timeout } => { + ExtensionConfig::Sse { + name, + uri, + envs, + env_keys, + description: None, + timeout, + bundled: None, + } + } + ExtensionConfigRequest::Stdio { name, cmd, args, envs, env_keys, timeout } => { + ExtensionConfig::Stdio { + name, + cmd, + args, + envs, + env_keys, + timeout, + description: None, + bundled: None, + } + } + ExtensionConfigRequest::Builtin { name, display_name, timeout } => { + ExtensionConfig::Builtin { + name, + display_name, + timeout, + bundled: None, + } + } + ExtensionConfigRequest::Frontend { name, tools, instructions } => { + ExtensionConfig::Frontend { + name, + tools, + instructions, + bundled: None, + } + } + }; + + let agent = AGENT.lock().await; + let result = agent.add_extension(extension).await; + + let resp = match result { + Ok(_) => ExtensionResponse { error: false, message: None }, + Err(e) => ExtensionResponse { + error: true, + message: Some(format!("Failed to add extension configuration, error: {:?}", e)), + }, + }; + Ok(warp::reply::json(&resp)) +} + +pub async fn remove_extension_handler( + name: String, + _api_key: String, +) -> Result { + info!("Removing extension: {}", name); + let agent = AGENT.lock().await; + agent.remove_extension(&name).await; + + let resp = ExtensionResponse { error: false, message: None }; + Ok(warp::reply::json(&resp)) +} + +pub fn with_api_key(api_key: String) -> impl Filter + Clone { + warp::header::value("x-api-key") + .and_then(move |header_api_key: HeaderValue| { + let api_key = api_key.clone(); + async move { + if header_api_key == api_key { + Ok(api_key) + } else { + Err(warp::reject::not_found()) + } + } + }) +} diff --git a/crates/goose-api/src/lib.rs b/crates/goose-api/src/lib.rs new file mode 100644 index 00000000..b2037198 --- /dev/null +++ b/crates/goose-api/src/lib.rs @@ -0,0 +1,5 @@ +mod handlers; +mod config; +mod routes; + +pub use routes::{build_routes, run_server}; diff --git a/crates/goose-api/src/main.rs b/crates/goose-api/src/main.rs index 9a68f872..451d0428 100644 --- a/crates/goose-api/src/main.rs +++ b/crates/goose-api/src/main.rs @@ -1,706 +1,6 @@ -use warp::{Filter, Rejection}; -use warp::http::HeaderValue; -use serde::{Deserialize, Serialize}; -use std::sync::LazyLock; -use goose::config::{Config, ExtensionEntry}; -use goose::agents::{ - extension::Envs, - Agent, - extension_manager::ExtensionManager, - ExtensionConfig, -}; -use mcp_core::tool::Tool; -use uuid::Uuid; -use goose::session::{self, Identifier}; -use goose::agents::SessionConfig; -use std::path::PathBuf; -use std::sync::Arc; - -use goose::providers::{create, providers}; -use goose::model::ModelConfig; -use goose::message::Message; -use tracing::{info, warn, error}; -use config::{builder::DefaultState, ConfigBuilder, Environment, File}; -use serde_json::Value; // Import the correct Value type -use futures_util::TryStreamExt; - -// Global extension manager for extension listing -static EXTENSION_MANAGER: LazyLock = LazyLock::new(|| ExtensionManager::default()); - -// Global agent for handling sessions -static AGENT: LazyLock> = LazyLock::new(|| { - tokio::sync::Mutex::new(Agent::new()) -}); - - -#[derive(Debug, Serialize, Deserialize)] -struct SessionRequest { - prompt: String, -} - -#[derive(Debug, Serialize, Deserialize)] -struct ApiResponse { - message: String, - status: String, -} - -#[derive(Debug, Serialize, Deserialize)] -struct StartSessionResponse { - message: String, - status: String, - session_id: Uuid, -} - -#[derive(Debug, Serialize, Deserialize)] -struct SessionReplyRequest { - session_id: Uuid, - prompt: String, -} - -#[derive(Debug, Serialize, Deserialize)] -struct EndSessionRequest { - session_id: Uuid, -} - -#[derive(Debug, Serialize, Deserialize)] -struct ExtensionsResponse { - extensions: Vec, -} - -#[derive(Debug, Serialize, Deserialize)] -struct ProviderConfig { - provider: String, - model: String, -} - -#[derive(Debug, Serialize, Deserialize)] -struct ExtensionResponse { - error: bool, - message: Option, -} - -#[derive(Debug, Deserialize)] -#[serde(tag = "type")] -enum ExtensionConfigRequest { - #[serde(rename = "sse")] - Sse { - name: String, - uri: String, - #[serde(default)] - envs: Envs, - #[serde(default)] - env_keys: Vec, - timeout: Option, - }, - #[serde(rename = "stdio")] - Stdio { - name: String, - cmd: String, - #[serde(default)] - args: Vec, - #[serde(default)] - envs: Envs, - #[serde(default)] - env_keys: Vec, - timeout: Option, - }, - #[serde(rename = "builtin")] - Builtin { - name: String, - display_name: Option, - timeout: Option, - }, - #[serde(rename = "frontend")] - Frontend { - name: String, - tools: Vec, - instructions: Option, - }, -} - -async fn start_session_handler( - req: SessionRequest, - _api_key: String, -) -> Result { - info!("Starting session with prompt: {}", req.prompt); - - let mut agent = AGENT.lock().await; - - // Create a user message with the prompt - let mut messages = vec![Message::user().with_text(&req.prompt)]; - - // Generate a new session ID and process the messages - let session_id = Uuid::new_v4(); - let session_name = session_id.to_string(); - let session_path = session::get_path(Identifier::Name(session_name.clone())); - - let provider = agent.provider().await.ok(); - - let result = agent - .reply( - &messages, - Some(SessionConfig { - id: Identifier::Name(session_name.clone()), - working_dir: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")), - }), - ) - .await; - - match result { - Ok(mut stream) => { - // Process the stream to get the first response - if let Ok(Some(response)) = stream.try_next().await { - let response_text = response.as_concat_text(); - messages.push(response); - if let Err(e) = session::persist_messages(&session_path, &messages, provider.clone()).await { - warn!("Failed to persist session {}: {}", session_name, e); - } - - let api_response = StartSessionResponse { - message: response_text, - status: "success".to_string(), - session_id, - }; - Ok(warp::reply::with_status( - warp::reply::json(&api_response), - warp::http::StatusCode::OK, - )) - } else { - if let Err(e) = session::persist_messages(&session_path, &messages, provider.clone()).await { - warn!("Failed to persist session {}: {}", session_name, e); - } - - let api_response = StartSessionResponse { - message: "Session started but no response generated".to_string(), - status: "warning".to_string(), - session_id, - }; - Ok(warp::reply::with_status( - warp::reply::json(&api_response), - warp::http::StatusCode::OK, - )) - } - }, - Err(e) => { - error!("Failed to start session: {}", e); - let response = ApiResponse { - message: format!("Failed to start session: {}", e), - status: "error".to_string(), - }; - Ok(warp::reply::with_status( - warp::reply::json(&response), - warp::http::StatusCode::INTERNAL_SERVER_ERROR, - )) - } - } -} - -async fn reply_session_handler( - req: SessionReplyRequest, - _api_key: String, -) -> Result { - info!("Replying to session with prompt: {}", req.prompt); - - let mut agent = AGENT.lock().await; - - let session_name = req.session_id.to_string(); - let session_path = session::get_path(Identifier::Name(session_name.clone())); - - // Retrieve existing session history from disk - let mut messages = match session::read_messages(&session_path) { - Ok(m) => m, - Err(_) => { - let response = ApiResponse { - message: "Session not found".to_string(), - status: "error".to_string(), - }; - return Ok(warp::reply::with_status( - warp::reply::json(&response), - warp::http::StatusCode::NOT_FOUND, - )); - } - }; - - // Append the new user message - messages.push(Message::user().with_text(&req.prompt)); - - let provider = agent.provider().await.ok(); - - // Process the messages through the agent - let result = agent - .reply( - &messages, - Some(SessionConfig { - id: Identifier::Name(session_name.clone()), - working_dir: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")), - }), - ) - .await; - - match result { - Ok(mut stream) => { - // Process the stream to get the first response - if let Ok(Some(response)) = stream.try_next().await { - let response_text = response.as_concat_text(); - messages.push(response); - if let Err(e) = session::persist_messages(&session_path, &messages, provider.clone()).await { - warn!("Failed to persist session {}: {}", session_name, e); - } - let api_response = ApiResponse { - message: format!("Reply: {}", response_text), - status: "success".to_string(), - }; - Ok(warp::reply::with_status( - warp::reply::json(&api_response), - warp::http::StatusCode::OK, - )) - } else { - if let Err(e) = session::persist_messages(&session_path, &messages, provider.clone()).await { - warn!("Failed to persist session {}: {}", session_name, e); - } - let api_response = ApiResponse { - message: "Reply processed but no response generated".to_string(), - status: "warning".to_string(), - }; - Ok(warp::reply::with_status( - warp::reply::json(&api_response), - warp::http::StatusCode::OK, - )) - } - }, - Err(e) => { - error!("Failed to reply to session: {}", e); - let response = ApiResponse { - message: format!("Failed to reply to session: {}", e), - status: "error".to_string(), - }; - Ok(warp::reply::with_status( - warp::reply::json(&response), - warp::http::StatusCode::INTERNAL_SERVER_ERROR, - )) - } - } -} - -async fn end_session_handler( - req: EndSessionRequest, - _api_key: String, -) -> Result { - let session_name = req.session_id.to_string(); - let session_path = session::get_path(Identifier::Name(session_name.clone())); - - if std::fs::remove_file(&session_path).is_ok() { - let response = ApiResponse { - message: "Session ended".to_string(), - status: "success".to_string(), - }; - Ok(warp::reply::with_status( - warp::reply::json(&response), - warp::http::StatusCode::OK, - )) - } else { - let response = ApiResponse { - message: "Session not found".to_string(), - status: "error".to_string(), - }; - Ok(warp::reply::with_status( - warp::reply::json(&response), - warp::http::StatusCode::NOT_FOUND, - )) - } -} - -async fn list_extensions_handler() -> Result { - info!("Listing extensions"); - - match EXTENSION_MANAGER.list_extensions().await { - Ok(exts) => { - let response = ExtensionsResponse { extensions: exts }; - Ok::(warp::reply::json(&response)) - }, - Err(e) => { - error!("Failed to list extensions: {}", e); - let response = ExtensionsResponse { - extensions: vec!["Failed to list extensions".to_string()] - }; - Ok::(warp::reply::json(&response)) - } - } -} - -async fn get_provider_config_handler() -> Result { - info!("Getting provider configuration"); - - let config = Config::global(); - let provider = config.get_param::("GOOSE_PROVIDER") - .unwrap_or_else(|_| "Not configured".to_string()); - let model = config.get_param::("GOOSE_MODEL") - .unwrap_or_else(|_| "Not configured".to_string()); - - let response = ProviderConfig { provider, model }; - Ok::(warp::reply::json(&response)) -} - -async fn add_extension_handler( - req: ExtensionConfigRequest, - _api_key: String, -) -> Result { - info!("Adding extension: {:?}", req); - - #[cfg(target_os = "windows")] - if let ExtensionConfigRequest::Stdio { cmd, .. } = &req { - if cmd.ends_with("npx.cmd") || cmd.ends_with("npx") { - let node_exists = std::path::Path::new(r"C:\Program Files\nodejs\node.exe").exists() - || std::path::Path::new(r"C:\Program Files (x86)\nodejs\node.exe").exists(); - - if !node_exists { - let cmd_path = std::path::Path::new(cmd); - let script_dir = cmd_path.parent().ok_or_else(|| warp::reject())?; - - let install_script = script_dir.join("install-node.cmd"); - - if install_script.exists() { - eprintln!("Installing Node.js..."); - let output = std::process::Command::new(&install_script) - .arg("https://nodejs.org/dist/v23.10.0/node-v23.10.0-x64.msi") - .output() - .map_err(|_| warp::reject())?; - - if !output.status.success() { - eprintln!( - "Failed to install Node.js: {}", - String::from_utf8_lossy(&output.stderr) - ); - let resp = ExtensionResponse { - error: true, - message: Some(format!( - "Failed to install Node.js: {}", - String::from_utf8_lossy(&output.stderr) - )), - }; - return Ok(warp::reply::json(&resp)); - } - eprintln!("Node.js installation completed"); - } else { - eprintln!( - "Node.js installer script not found at: {}", - install_script.display() - ); - let resp = ExtensionResponse { - error: true, - message: Some("Node.js installer script not found".to_string()), - }; - return Ok(warp::reply::json(&resp)); - } - } - } - } - - let extension = match req { - ExtensionConfigRequest::Sse { name, uri, envs, env_keys, timeout } => { - ExtensionConfig::Sse { - name, - uri, - envs, - env_keys, - description: None, - timeout, - bundled: None, - } - } - ExtensionConfigRequest::Stdio { name, cmd, args, envs, env_keys, timeout } => { - ExtensionConfig::Stdio { - name, - cmd, - args, - envs, - env_keys, - timeout, - description: None, - bundled: None, - } - } - ExtensionConfigRequest::Builtin { name, display_name, timeout } => { - ExtensionConfig::Builtin { - name, - display_name, - timeout, - bundled: None, - } - } - ExtensionConfigRequest::Frontend { name, tools, instructions } => { - ExtensionConfig::Frontend { - name, - tools, - instructions, - bundled: None, - } - } - }; - - let agent = AGENT.lock().await; - let result = agent.add_extension(extension).await; - - let resp = match result { - Ok(_) => ExtensionResponse { error: false, message: None }, - Err(e) => ExtensionResponse { - error: true, - message: Some(format!("Failed to add extension configuration, error: {:?}", e)), - }, - }; - Ok(warp::reply::json(&resp)) -} - -async fn remove_extension_handler( - name: String, - _api_key: String, -) -> Result { - info!("Removing extension: {}", name); - let agent = AGENT.lock().await; - agent.remove_extension(&name).await; - - let resp = ExtensionResponse { error: false, message: None }; - Ok(warp::reply::json(&resp)) -} - -fn with_api_key(api_key: String) -> impl Filter + Clone { - warp::header::value("x-api-key") - .and_then(move |header_api_key: HeaderValue| { - let api_key = api_key.clone(); - async move { - if header_api_key == api_key { - Ok(api_key) - } else { - Err(warp::reject::not_found()) - } - } - }) -} - -// Load configuration from file and environment variables -fn load_configuration() -> std::result::Result { - let config_path = std::env::var("GOOSE_CONFIG").unwrap_or_else(|_| "config".to_string()); - let builder = ConfigBuilder::::default() - .add_source(File::with_name(&config_path).required(false)) - .add_source(Environment::with_prefix("GOOSE_API")); - - builder.build() -} - -// Initialize global provider configuration -async fn initialize_provider_config() -> Result<(), anyhow::Error> { - // Get configuration - let api_config = load_configuration()?; - - // Get provider settings from configuration or environment variables - let provider_name = std::env::var("GOOSE_API_PROVIDER") - .or_else(|_| api_config.get_string("provider")) - .unwrap_or_else(|_| "openai".to_string()); - - let model_name = std::env::var("GOOSE_API_MODEL") - .or_else(|_| api_config.get_string("model")) - .unwrap_or_else(|_| "gpt-4o".to_string()); - - info!("Initializing with provider: {}, model: {}", provider_name, model_name); - - // Initialize the global Config object - let config = Config::global(); - config.set_param("GOOSE_PROVIDER", Value::String(provider_name.clone()))?; - config.set_param("GOOSE_MODEL", Value::String(model_name.clone()))?; - - // Set up API keys from environment variables - let available_providers = providers(); - if let Some(provider_meta) = available_providers.iter().find(|p| p.name == provider_name) { - for key in &provider_meta.config_keys { - let env_name = key.name.clone(); - if let Ok(value) = std::env::var(&env_name) { - if key.secret { - config.set_secret(&key.name, Value::String(value))?; - info!("Set secret key: {}", key.name); - } else { - config.set_param(&key.name, Value::String(value))?; - info!("Set parameter: {}", key.name); - } - } else { - warn!("Environment variable not set for key: {}", key.name); - if key.required { - error!("Required key {} not provided", key.name); - return Err(anyhow::anyhow!("Required key {} not provided", key.name)); - } - } - } - } - - // Initialize agent with provider - let model_config = ModelConfig::new(model_name); - let provider = create(&provider_name, model_config)?; - - let agent = AGENT.lock().await; - agent.update_provider(provider).await?; - - info!("Provider configuration successful"); - Ok(()) -} -/// Initialize extensions from the configuration. -async fn initialize_extensions(config: &config::Config) -> Result<(), anyhow::Error> { - if let Ok(ext_table) = config.get_table("extensions") { - for (name, ext_config) in ext_table { - // Deserialize into ExtensionEntry to get enabled flag and config - let entry: ExtensionEntry = ext_config.clone().try_deserialize() - .map_err(|e| anyhow::anyhow!("Failed to deserialize extension config for {}: {}", name, e))?; - - if entry.enabled { - let extension_config: ExtensionConfig = entry.config; - // Acquire the global agent lock and try to add the extension - let mut agent = AGENT.lock().await; - if let Err(e) = agent.add_extension(extension_config).await { - error!("Failed to add extension {}: {}", name, e); - } - } else { - info!("Skipping disabled extension: {}", name); - } - } - } else { - warn!("No extensions configured in config file."); - } - Ok(()) -} - - -async fn run_init_tests() -> Result<(), anyhow::Error> { - info!("Running initialization tests"); - { - let _agent = AGENT.lock().await; - info!("Agent initialization test passed"); - } - info!("Initialization tests completed"); - Ok(()) -} +use goose_api::run_server; #[tokio::main] async fn main() -> Result<(), anyhow::Error> { - // Initialize tracing - tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .init(); - - info!("Starting goose-api server"); - - // Load configuration - let api_config = load_configuration()?; - - // Get API key from configuration or environment - let api_key: String = std::env::var("GOOSE_API_KEY") - .or_else(|_| api_config.get_string("api_key")) - .unwrap_or_else(|_| { - warn!("No API key configured, using default"); - "default_api_key".to_string() - }); - - // Initialize provider configuration - if let Err(e) = initialize_provider_config().await { - error!("Failed to initialize provider: {}", e); - return Err(e); - } - - // Initialize extensions from configuration - if let Err(e) = initialize_extensions(&api_config).await { - error!("Failed to initialize extensions: {}", e); - } - - if let Err(e) = run_init_tests().await { - error!("Initialization tests failed: {}", e); - } - - // Session start endpoint - let start_session = warp::path("session") - .and(warp::path("start")) - .and(warp::post()) - .and(warp::body::json()) - .and(with_api_key(api_key.clone())) - .and_then(start_session_handler); - - // Session reply endpoint - let reply_session = warp::path("session") - .and(warp::path("reply")) - .and(warp::post()) - .and(warp::body::json()) - .and(with_api_key(api_key.clone())) - .and_then(reply_session_handler); - - // Session end endpoint - let end_session = warp::path("session") - .and(warp::path("end")) - .and(warp::post()) - .and(warp::body::json()) - .and(with_api_key(api_key.clone())) - .and_then(end_session_handler); - - // List extensions endpoint - let list_extensions = warp::path("extensions") - .and(warp::path("list")) - .and(warp::get()) - .and_then(list_extensions_handler); - - // Add extension endpoint - let add_extension = warp::path("extensions") - .and(warp::path("add")) - .and(warp::post()) - .and(warp::body::json()) - .and(with_api_key(api_key.clone())) - .and_then(add_extension_handler); - - // Remove extension endpoint - let remove_extension = warp::path("extensions") - .and(warp::path("remove")) - .and(warp::post()) - .and(warp::body::json()) - .and(with_api_key(api_key.clone())) - .and_then(remove_extension_handler); - - // Get provider configuration endpoint - let get_provider_config = warp::path("provider") - .and(warp::path("config")) - .and(warp::get()) - .and_then(get_provider_config_handler); - - // Combine all routes - let routes = start_session - .or(reply_session) - .or(end_session) - .or(list_extensions) - .or(add_extension) - .or(remove_extension) - .or(get_provider_config); - - // Get bind address from configuration or use default - let host = std::env::var("GOOSE_API_HOST") - .or_else(|_| api_config.get_string("host")) - .unwrap_or_else(|_| "127.0.0.1".to_string()); - - let port = std::env::var("GOOSE_API_PORT") - .or_else(|_| api_config.get_string("port")) - .unwrap_or_else(|_| "8080".to_string()) - .parse::() - .unwrap_or(8080); - - info!("Starting server on {}:{}", host, port); - - // Parse host string - let host_parts: Vec = host.split('.') - .map(|part| part.parse::().unwrap_or(127)) - .collect(); - - let addr = if host_parts.len() == 4 { - [host_parts[0], host_parts[1], host_parts[2], host_parts[3]] - } else { - [127, 0, 0, 1] - }; - - // Start the server - warp::serve(routes) - .run((addr, port)) - .await; - - Ok(()) + run_server().await } diff --git a/crates/goose-api/src/routes.rs b/crates/goose-api/src/routes.rs new file mode 100644 index 00000000..759786c3 --- /dev/null +++ b/crates/goose-api/src/routes.rs @@ -0,0 +1,123 @@ +use warp::Filter; +use tracing::{info, warn, error}; + +use crate::handlers::{ + add_extension_handler, end_session_handler, get_provider_config_handler, + list_extensions_handler, remove_extension_handler, reply_session_handler, + start_session_handler, with_api_key, +}; +use crate::config::{ + initialize_extensions, initialize_provider_config, load_configuration, + run_init_tests, +}; + +pub fn build_routes(api_key: String) -> impl Filter + Clone { + let start_session = warp::path("session") + .and(warp::path("start")) + .and(warp::post()) + .and(warp::body::json()) + .and(with_api_key(api_key.clone())) + .and_then(start_session_handler); + + let reply_session = warp::path("session") + .and(warp::path("reply")) + .and(warp::post()) + .and(warp::body::json()) + .and(with_api_key(api_key.clone())) + .and_then(reply_session_handler); + + let end_session = warp::path("session") + .and(warp::path("end")) + .and(warp::post()) + .and(warp::body::json()) + .and(with_api_key(api_key.clone())) + .and_then(end_session_handler); + + let list_extensions = warp::path("extensions") + .and(warp::path("list")) + .and(warp::get()) + .and_then(list_extensions_handler); + + let add_extension = warp::path("extensions") + .and(warp::path("add")) + .and(warp::post()) + .and(warp::body::json()) + .and(with_api_key(api_key.clone())) + .and_then(add_extension_handler); + + let remove_extension = warp::path("extensions") + .and(warp::path("remove")) + .and(warp::post()) + .and(warp::body::json()) + .and(with_api_key(api_key.clone())) + .and_then(remove_extension_handler); + + let get_provider_config = warp::path("provider") + .and(warp::path("config")) + .and(warp::get()) + .and_then(get_provider_config_handler); + + start_session + .or(reply_session) + .or(end_session) + .or(list_extensions) + .or(add_extension) + .or(remove_extension) + .or(get_provider_config) +} + +pub async fn run_server() -> Result<(), anyhow::Error> { + tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .init(); + + info!("Starting goose-api server"); + + let api_config = load_configuration()?; + + let api_key: String = std::env::var("GOOSE_API_KEY") + .or_else(|_| api_config.get_string("api_key")) + .unwrap_or_else(|_| { + warn!("No API key configured, using default"); + "default_api_key".to_string() + }); + + if let Err(e) = initialize_provider_config().await { + error!("Failed to initialize provider: {}", e); + return Err(e); + } + + if let Err(e) = initialize_extensions(&api_config).await { + error!("Failed to initialize extensions: {}", e); + } + + if let Err(e) = run_init_tests().await { + error!("Initialization tests failed: {}", e); + } + + let routes = build_routes(api_key.clone()); + + let host = std::env::var("GOOSE_API_HOST") + .or_else(|_| api_config.get_string("host")) + .unwrap_or_else(|_| "127.0.0.1".to_string()); + let port = std::env::var("GOOSE_API_PORT") + .or_else(|_| api_config.get_string("port")) + .unwrap_or_else(|_| "8080".to_string()) + .parse::() + .unwrap_or(8080); + + info!("Starting server on {}:{}", host, port); + + let host_parts: Vec = host + .split('.') + .map(|part| part.parse::().unwrap_or(127)) + .collect(); + let addr = if host_parts.len() == 4 { + [host_parts[0], host_parts[1], host_parts[2], host_parts[3]] + } else { + [127, 0, 0, 1] + }; + + warp::serve(routes).run((addr, port)).await; + Ok(()) +} diff --git a/crates/goose-api/src/tests.rs b/crates/goose-api/src/tests.rs new file mode 100644 index 00000000..302cf8c3 --- /dev/null +++ b/crates/goose-api/src/tests.rs @@ -0,0 +1,10 @@ +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn build_routes_compiles() { + let _routes = build_routes("test-key".to_string()); + // Just ensure building routes doesn't panic + } +} From 295946ecd2d86bdc0842af862d77f98b0a2b479b Mon Sep 17 00:00:00 2001 From: Aljaz Ceru Date: Thu, 29 May 2025 10:21:59 +0200 Subject: [PATCH 11/23] cargo fixes --- crates/goose-api/src/config.rs | 4 ++-- crates/goose-api/src/handlers.rs | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/crates/goose-api/src/config.rs b/crates/goose-api/src/config.rs index 9cafc919..8f1b1d89 100644 --- a/crates/goose-api/src/config.rs +++ b/crates/goose-api/src/config.rs @@ -1,4 +1,4 @@ -use crate::handlers::{AGENT, EXTENSION_MANAGER}; +use crate::handlers::AGENT; use goose::config::{Config, ExtensionEntry}; use goose::agents::ExtensionConfig; use goose::providers::{create, providers}; @@ -72,7 +72,7 @@ pub async fn initialize_extensions(config: &config::Config) -> Result<(), anyhow if entry.enabled { let extension_config: ExtensionConfig = entry.config; - let mut agent = AGENT.lock().await; + let agent = AGENT.lock().await; if let Err(e) = agent.add_extension(extension_config).await { error!("Failed to add extension {}: {}", name, e); } diff --git a/crates/goose-api/src/handlers.rs b/crates/goose-api/src/handlers.rs index 52a0ad8e..6e8ba957 100644 --- a/crates/goose-api/src/handlers.rs +++ b/crates/goose-api/src/handlers.rs @@ -8,7 +8,7 @@ use mcp_core::tool::Tool; use goose::agents::{extension::Envs, extension_manager::ExtensionManager, ExtensionConfig, Agent, SessionConfig}; use goose::message::Message; use goose::session::{self, Identifier}; -use goose::config::{Config, ExtensionEntry}; +use goose::config::Config; use std::sync::LazyLock; pub static EXTENSION_MANAGER: LazyLock = LazyLock::new(|| ExtensionManager::default()); @@ -105,7 +105,7 @@ pub async fn start_session_handler( ) -> Result { info!("Starting session with prompt: {}", req.prompt); - let mut agent = AGENT.lock().await; + let agent = AGENT.lock().await; let mut messages = vec![Message::user().with_text(&req.prompt)]; let session_id = Uuid::new_v4(); let session_name = session_id.to_string(); @@ -177,7 +177,7 @@ pub async fn reply_session_handler( ) -> Result { info!("Replying to session with prompt: {}", req.prompt); - let mut agent = AGENT.lock().await; + let agent = AGENT.lock().await; let session_name = req.session_id.to_string(); let session_path = session::get_path(Identifier::Name(session_name.clone())); From 88a4954f7706f0018418e8fcde114d0deb779ccc Mon Sep 17 00:00:00 2001 From: Aljaz Ceru Date: Thu, 29 May 2025 10:22:35 +0200 Subject: [PATCH 12/23] cargo.lock --- Cargo.lock | 319 +++++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 296 insertions(+), 23 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 693e68a5..7a1ffa7b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -28,6 +28,17 @@ dependencies = [ "cpufeatures", ] +[[package]] +name = "ahash" +version = "0.7.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "891477e0c6a8957309ee5c45a6368af3ae14bb510732d2684ffa19af310920f9" +dependencies = [ + "getrandom 0.2.15", + "once_cell", + "version_check", +] + [[package]] name = "ahash" version = "0.8.11" @@ -328,7 +339,7 @@ dependencies = [ "fastrand", "hex", "http 0.2.12", - "ring", + "ring 0.17.12", "time", "tokio", "tracing", @@ -693,7 +704,7 @@ dependencies = [ "sha1", "sync_wrapper 1.0.2", "tokio", - "tokio-tungstenite", + "tokio-tungstenite 0.24.0", "tower 0.5.2", "tower-layer", "tower-service", @@ -1364,6 +1375,25 @@ dependencies = [ "memchr", ] +[[package]] +name = "config" +version = "0.13.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23738e11972c7643e4ec947840fc463b6a571afcd3e735bdfce7d03c7a784aca" +dependencies = [ + "async-trait", + "json5", + "lazy_static", + "nom", + "pathdiff", + "ron 0.7.1", + "rust-ini 0.18.0", + "serde", + "serde_json", + "toml 0.5.11", + "yaml-rust", +] + [[package]] name = "config" version = "0.14.1" @@ -1375,8 +1405,8 @@ dependencies = [ "json5", "nom", "pathdiff", - "ron", - "rust-ini", + "ron 0.8.1", + "rust-ini 0.20.0", "serde", "serde_json", "toml 0.8.20", @@ -1833,6 +1863,12 @@ dependencies = [ "syn 2.0.99", ] +[[package]] +name = "dlv-list" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0688c2a7f92e427f44895cd63841bff7b29f8d7a1648b9e7e07a4a365b2e1257" + [[package]] name = "dlv-list" version = "0.5.2" @@ -2407,7 +2443,7 @@ dependencies = [ "futures", "include_dir", "indoc", - "jsonwebtoken", + "jsonwebtoken 9.3.1", "keyring", "lazy_static", "mcp-client", @@ -2441,6 +2477,28 @@ dependencies = [ "wiremock", ] +[[package]] +name = "goose-api" +version = "0.1.0" +dependencies = [ + "anyhow", + "config 0.13.4", + "futures", + "futures-util", + "goose", + "goose-mcp", + "jsonwebtoken 8.3.0", + "mcp-client", + "mcp-core", + "serde", + "serde_json", + "tokio", + "tracing", + "tracing-subscriber", + "uuid", + "warp", +] + [[package]] name = "goose-bench" version = "1.0.20" @@ -2574,7 +2632,7 @@ dependencies = [ "bytes", "chrono", "clap 4.5.31", - "config", + "config 0.14.1", "dirs 6.0.0", "etcetera", "futures", @@ -2666,6 +2724,9 @@ name = "hashbrown" version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" +dependencies = [ + "ahash 0.7.8", +] [[package]] name = "hashbrown" @@ -2673,7 +2734,7 @@ version = "0.14.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" dependencies = [ - "ahash", + "ahash 0.8.11", "allocator-api2", ] @@ -2692,6 +2753,30 @@ dependencies = [ "hashbrown 0.14.5", ] +[[package]] +name = "headers" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06683b93020a07e3dbcf5f8c0f6d40080d725bea7936fc01ad345c01b97dc270" +dependencies = [ + "base64 0.21.7", + "bytes", + "headers-core", + "http 0.2.12", + "httpdate", + "mime", + "sha1", +] + +[[package]] +name = "headers-core" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7f66481bfee273957b1f20485a4ff3362987f85b2c236580d81b4eb7a326429" +dependencies = [ + "http 0.2.12", +] + [[package]] name = "heck" version = "0.4.1" @@ -3391,6 +3476,20 @@ dependencies = [ "serde", ] +[[package]] +name = "jsonwebtoken" +version = "8.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6971da4d9c3aa03c3d8f3ff0f4155b534aad021292003895a469716b2a230378" +dependencies = [ + "base64 0.21.7", + "pem 1.1.1", + "ring 0.16.20", + "serde", + "serde_json", + "simple_asn1", +] + [[package]] name = "jsonwebtoken" version = "9.3.1" @@ -3399,8 +3498,8 @@ checksum = "5a87cc7a48537badeae96744432de36f4be2b4a34a05a5ef32e9dd8a1c169dde" dependencies = [ "base64 0.22.1", "js-sys", - "pem", - "ring", + "pem 3.0.5", + "ring 0.17.12", "serde", "serde_json", "simple_asn1", @@ -3755,6 +3854,16 @@ version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" +[[package]] +name = "mime_guess" +version = "2.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e" +dependencies = [ + "mime", + "unicase", +] + [[package]] name = "minijinja" version = "2.8.0" @@ -3838,6 +3947,24 @@ dependencies = [ "syn 2.0.99", ] +[[package]] +name = "multer" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "01acbdc23469fd8fe07ab135923371d5f5a422fbf9c522158677c8eb15bc51c2" +dependencies = [ + "bytes", + "encoding_rs", + "futures-util", + "http 0.2.12", + "httparse", + "log", + "memchr", + "mime", + "spin 0.9.8", + "version_check", +] + [[package]] name = "nanoid" version = "0.4.0" @@ -4193,13 +4320,23 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" +[[package]] +name = "ordered-multimap" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ccd746e37177e1711c20dd619a1620f34f5c8b569c53590a72dedd5344d8924a" +dependencies = [ + "dlv-list 0.3.0", + "hashbrown 0.12.3", +] + [[package]] name = "ordered-multimap" version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49203cdcae0030493bad186b28da2fa25645fa276a51b6fec8010d281e02ef79" dependencies = [ - "dlv-list", + "dlv-list 0.5.2", "hashbrown 0.14.5", ] @@ -4265,6 +4402,15 @@ version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "df94ce210e5bc13cb6651479fa48d14f601d9858cfe0467f43ae157023b938d3" +[[package]] +name = "pem" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8835c273a76a90455d7344889b0964598e3316e2a79ede8e36f16bdcf2228b8" +dependencies = [ + "base64 0.13.1", +] + [[package]] name = "pem" version = "3.0.5" @@ -4613,7 +4759,7 @@ dependencies = [ "bytes", "getrandom 0.2.15", "rand", - "ring", + "ring 0.17.12", "rustc-hash 2.1.1", "rustls 0.23.23", "rustls-pki-types", @@ -4965,6 +5111,21 @@ dependencies = [ "bytemuck", ] +[[package]] +name = "ring" +version = "0.16.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3053cf52e236a3ed746dfc745aa9cacf1b791d846bdaf412f60a8d7d6e17c8fc" +dependencies = [ + "cc", + "libc", + "once_cell", + "spin 0.5.2", + "untrusted 0.7.1", + "web-sys", + "winapi", +] + [[package]] name = "ring" version = "0.17.12" @@ -4975,10 +5136,21 @@ dependencies = [ "cfg-if", "getrandom 0.2.15", "libc", - "untrusted", + "untrusted 0.9.0", "windows-sys 0.52.0", ] +[[package]] +name = "ron" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88073939a61e5b7680558e6be56b419e208420c2adb92be54921fa6b72283f1a" +dependencies = [ + "base64 0.13.1", + "bitflags 1.3.2", + "serde", +] + [[package]] name = "ron" version = "0.8.1" @@ -4991,6 +5163,16 @@ dependencies = [ "serde_derive", ] +[[package]] +name = "rust-ini" +version = "0.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6d5f2436026b4f6e79dc829837d467cc7e9a55ee40e750d716713540715a2df" +dependencies = [ + "cfg-if", + "ordered-multimap 0.4.3", +] + [[package]] name = "rust-ini" version = "0.20.0" @@ -4998,7 +5180,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3e0698206bcb8882bf2a9ecb4c1e7785db57ff052297085a6efd4fe42302068a" dependencies = [ "cfg-if", - "ordered-multimap", + "ordered-multimap 0.7.3", ] [[package]] @@ -5048,7 +5230,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f56a14d1f48b391359b22f731fd4bd7e43c97f3c50eee276f3aa09c94784d3e" dependencies = [ "log", - "ring", + "ring 0.17.12", "rustls-webpki 0.101.7", "sct", ] @@ -5060,7 +5242,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "47796c98c480fce5406ef69d1c76378375492c3b0a0de587be0c1d9feb12f395" dependencies = [ "once_cell", - "ring", + "ring 0.17.12", "rustls-pki-types", "rustls-webpki 0.102.8", "subtle", @@ -5124,8 +5306,8 @@ version = "0.101.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b6275d1ee7a1cd780b64aca7726599a1dbc893b1e64144529e55c3c2f745765" dependencies = [ - "ring", - "untrusted", + "ring 0.17.12", + "untrusted 0.9.0", ] [[package]] @@ -5134,9 +5316,9 @@ version = "0.102.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "64ca1bc8749bd4cf37b5ce386cc146580777b4e8572c7b97baf22c83f444bee9" dependencies = [ - "ring", + "ring 0.17.12", "rustls-pki-types", - "untrusted", + "untrusted 0.9.0", ] [[package]] @@ -5224,6 +5406,12 @@ dependencies = [ "syn 2.0.99", ] +[[package]] +name = "scoped-tls" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1cf6437eb19a8f4a6cc0f7dca544973b0b78843adbfeb3683d1a94a0024a294" + [[package]] name = "scopeguard" version = "1.2.0" @@ -5236,8 +5424,8 @@ version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "da046153aa2352493d6cb7da4b6e5c0c057d8a1d0a9aa8560baffdd945acd414" dependencies = [ - "ring", - "untrusted", + "ring 0.17.12", + "untrusted 0.9.0", ] [[package]] @@ -5561,6 +5749,18 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "spin" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" + +[[package]] +name = "spin" +version = "0.9.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" + [[package]] name = "spm_precompiled" version = "0.1.4" @@ -6086,6 +6286,18 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-tungstenite" +version = "0.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c83b561d025642014097b66e6c1bb422783339e0909e4429cde4749d1990bc38" +dependencies = [ + "futures-util", + "log", + "tokio", + "tungstenite 0.21.0", +] + [[package]] name = "tokio-tungstenite" version = "0.24.0" @@ -6095,7 +6307,7 @@ dependencies = [ "futures-util", "log", "tokio", - "tungstenite", + "tungstenite 0.24.0", ] [[package]] @@ -6308,6 +6520,25 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" +[[package]] +name = "tungstenite" +version = "0.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ef1a641ea34f399a848dea702823bbecfb4c486f911735368f1f137cb8257e1" +dependencies = [ + "byteorder", + "bytes", + "data-encoding", + "http 1.2.0", + "httparse", + "log", + "rand", + "sha1", + "thiserror 1.0.69", + "url", + "utf-8", +] + [[package]] name = "tungstenite" version = "0.24.0" @@ -6345,7 +6576,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "17ec15f1f191ba42ba0ed0f788999eec910c201cbbd4ae5de7cf0eb0a94b3d1a" dependencies = [ "aes", - "ahash", + "ahash 0.8.11", "base64 0.22.1", "byteorder", "cbc", @@ -6367,6 +6598,12 @@ dependencies = [ "zip 2.5.0", ] +[[package]] +name = "unicase" +version = "2.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75b844d17643ee918803943289730bec8aac480150456169e647ed0b576ba539" + [[package]] name = "unicode-ident" version = "1.0.18" @@ -6418,6 +6655,12 @@ version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "673aac59facbab8a9007c7f6108d11f63b603f7cabff99fabf650fea5c32b861" +[[package]] +name = "untrusted" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a" + [[package]] name = "untrusted" version = "0.9.0" @@ -6498,6 +6741,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e0f540e3240398cce6128b64ba83fdbdd86129c16a3aa1a3a252efd66eb3d587" dependencies = [ "getrandom 0.3.1", + "serde", ] [[package]] @@ -6560,6 +6804,35 @@ dependencies = [ "try-lock", ] +[[package]] +name = "warp" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4378d202ff965b011c64817db11d5829506d3404edeadb61f190d111da3f231c" +dependencies = [ + "bytes", + "futures-channel", + "futures-util", + "headers", + "http 0.2.12", + "hyper 0.14.32", + "log", + "mime", + "mime_guess", + "multer", + "percent-encoding", + "pin-project", + "scoped-tls", + "serde", + "serde_json", + "serde_urlencoded", + "tokio", + "tokio-tungstenite 0.21.0", + "tokio-util", + "tower-service", + "tracing", +] + [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" From ac60a79ad1b6fe146d52512402c2a6cac20a98b5 Mon Sep 17 00:00:00 2001 From: Aljaz Date: Thu, 29 May 2025 10:42:01 +0200 Subject: [PATCH 13/23] fix SessionConfig initialization --- crates/goose-api/src/handlers.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/crates/goose-api/src/handlers.rs b/crates/goose-api/src/handlers.rs index 6e8ba957..653ca746 100644 --- a/crates/goose-api/src/handlers.rs +++ b/crates/goose-api/src/handlers.rs @@ -119,6 +119,7 @@ pub async fn start_session_handler( Some(SessionConfig { id: Identifier::Name(session_name.clone()), working_dir: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")), + schedule_id: None, }), ) .await; @@ -206,6 +207,7 @@ pub async fn reply_session_handler( Some(SessionConfig { id: Identifier::Name(session_name.clone()), working_dir: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")), + schedule_id: None, }), ) .await; From 0146865c1119670d621150fcf19d95a823825596 Mon Sep 17 00:00:00 2001 From: Aljaz Date: Thu, 29 May 2025 11:21:14 +0200 Subject: [PATCH 14/23] fix(api): handle result when removing extension --- crates/goose-api/src/handlers.rs | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/crates/goose-api/src/handlers.rs b/crates/goose-api/src/handlers.rs index 653ca746..096e7193 100644 --- a/crates/goose-api/src/handlers.rs +++ b/crates/goose-api/src/handlers.rs @@ -430,9 +430,15 @@ pub async fn remove_extension_handler( ) -> Result { info!("Removing extension: {}", name); let agent = AGENT.lock().await; - agent.remove_extension(&name).await; + let result = agent.remove_extension(&name).await; - let resp = ExtensionResponse { error: false, message: None }; + let resp = match result { + Ok(_) => ExtensionResponse { error: false, message: None }, + Err(e) => ExtensionResponse { + error: true, + message: Some(format!("Failed to remove extension, error: {:?}", e)), + }, + }; Ok(warp::reply::json(&resp)) } From 500f2f02105891ae77038cdb54a6f5a3c0aae148 Mon Sep 17 00:00:00 2001 From: Aljaz Date: Thu, 29 May 2025 11:21:16 +0200 Subject: [PATCH 15/23] fix(api): handle result when removing extension --- crates/goose-api/src/handlers.rs | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/crates/goose-api/src/handlers.rs b/crates/goose-api/src/handlers.rs index 653ca746..096e7193 100644 --- a/crates/goose-api/src/handlers.rs +++ b/crates/goose-api/src/handlers.rs @@ -430,9 +430,15 @@ pub async fn remove_extension_handler( ) -> Result { info!("Removing extension: {}", name); let agent = AGENT.lock().await; - agent.remove_extension(&name).await; + let result = agent.remove_extension(&name).await; - let resp = ExtensionResponse { error: false, message: None }; + let resp = match result { + Ok(_) => ExtensionResponse { error: false, message: None }, + Err(e) => ExtensionResponse { + error: true, + message: Some(format!("Failed to remove extension, error: {:?}", e)), + }, + }; Ok(warp::reply::json(&resp)) } From 23b480326e7cdf162f01a500a057e1a2e760f94c Mon Sep 17 00:00:00 2001 From: Aljaz Date: Thu, 29 May 2025 11:25:39 +0200 Subject: [PATCH 16/23] feat(api): load provider config from CLI --- crates/goose-api/src/config.rs | 45 +++++++++++++++++++++++++++------- 1 file changed, 36 insertions(+), 9 deletions(-) diff --git a/crates/goose-api/src/config.rs b/crates/goose-api/src/config.rs index 8f1b1d89..fbf7a1a5 100644 --- a/crates/goose-api/src/config.rs +++ b/crates/goose-api/src/config.rs @@ -18,13 +18,31 @@ pub fn load_configuration() -> std::result::Result Result<(), anyhow::Error> { let api_config = load_configuration()?; - let provider_name = std::env::var("GOOSE_API_PROVIDER") - .or_else(|_| api_config.get_string("provider")) - .unwrap_or_else(|_| "openai".to_string()); + let global_config = Config::global(); - let model_name = std::env::var("GOOSE_API_MODEL") - .or_else(|_| api_config.get_string("model")) - .unwrap_or_else(|_| "gpt-4o".to_string()); + let provider_name = if let Ok(val) = std::env::var("GOOSE_API_PROVIDER") { + val + } else if let Ok(val) = api_config.get_string("provider") { + val + } else if global_config.exists() { + global_config + .get_param::("GOOSE_PROVIDER") + .unwrap_or_else(|_| "openai".to_string()) + } else { + "openai".to_string() + }; + + let model_name = if let Ok(val) = std::env::var("GOOSE_API_MODEL") { + val + } else if let Ok(val) = api_config.get_string("model") { + val + } else if global_config.exists() { + global_config + .get_param::("GOOSE_MODEL") + .unwrap_or_else(|_| "gpt-4o".to_string()) + } else { + "gpt-4o".to_string() + }; info!("Initializing with provider: {}, model: {}", provider_name, model_name); @@ -44,12 +62,21 @@ pub async fn initialize_provider_config() -> Result<(), anyhow::Error> { config.set_param(&key.name, Value::String(value))?; info!("Set parameter: {}", key.name); } - } else { - warn!("Environment variable not set for key: {}", key.name); - if key.required { + } else if global_config.exists() { + // If not provided via environment, try existing CLI config + let result: Result = if key.secret { + global_config.get_secret(&key.name) + } else { + global_config.get_param(&key.name) + }; + + if result.is_err() && key.required { error!("Required key {} not provided", key.name); return Err(anyhow::anyhow!("Required key {} not provided", key.name)); } + } else if key.required { + error!("Required key {} not provided", key.name); + return Err(anyhow::anyhow!("Required key {} not provided", key.name)); } } } From 6c558cbb3e4b280cfb1f0abbf71a0a4046b820ba Mon Sep 17 00:00:00 2001 From: Aljaz Date: Thu, 29 May 2025 11:33:14 +0200 Subject: [PATCH 17/23] Load provider secrets from CLI config --- crates/goose-api/src/config.rs | 77 +++++++++++++++++++++++++++++----- 1 file changed, 66 insertions(+), 11 deletions(-) diff --git a/crates/goose-api/src/config.rs b/crates/goose-api/src/config.rs index 8f1b1d89..e41b8798 100644 --- a/crates/goose-api/src/config.rs +++ b/crates/goose-api/src/config.rs @@ -18,13 +18,31 @@ pub fn load_configuration() -> std::result::Result Result<(), anyhow::Error> { let api_config = load_configuration()?; - let provider_name = std::env::var("GOOSE_API_PROVIDER") - .or_else(|_| api_config.get_string("provider")) - .unwrap_or_else(|_| "openai".to_string()); + let global_config = Config::global(); - let model_name = std::env::var("GOOSE_API_MODEL") - .or_else(|_| api_config.get_string("model")) - .unwrap_or_else(|_| "gpt-4o".to_string()); + let provider_name = if let Ok(val) = std::env::var("GOOSE_API_PROVIDER") { + val + } else if let Ok(val) = api_config.get_string("provider") { + val + } else if global_config.exists() { + global_config + .get_param::("GOOSE_PROVIDER") + .unwrap_or_else(|_| "openai".to_string()) + } else { + "openai".to_string() + }; + + let model_name = if let Ok(val) = std::env::var("GOOSE_API_MODEL") { + val + } else if let Ok(val) = api_config.get_string("model") { + val + } else if global_config.exists() { + global_config + .get_param::("GOOSE_MODEL") + .unwrap_or_else(|_| "gpt-4o".to_string()) + } else { + "gpt-4o".to_string() + }; info!("Initializing with provider: {}, model: {}", provider_name, model_name); @@ -44,12 +62,49 @@ pub async fn initialize_provider_config() -> Result<(), anyhow::Error> { config.set_param(&key.name, Value::String(value))?; info!("Set parameter: {}", key.name); } - } else { - warn!("Environment variable not set for key: {}", key.name); - if key.required { - error!("Required key {} not provided", key.name); - return Err(anyhow::anyhow!("Required key {} not provided", key.name)); + } else if global_config.exists() { + // If not provided via environment, try existing CLI config + let result: Result = if key.secret { + global_config.get_secret(&key.name) + } else { + global_config.get_param(&key.name) + }; + + match result { + Ok(value) => { + if key.secret { + config.set_secret(&key.name, Value::String(value))?; + } else { + config.set_param(&key.name, Value::String(value))?; + } + info!("Loaded {} from CLI config", key.name); + } + Err(_) => { + if let Some(default) = &key.default { + if key.secret { + config.set_secret(&key.name, Value::String(default.clone()))?; + } else { + config.set_param(&key.name, Value::String(default.clone()))?; + } + info!("Using default for {}", key.name); + } else if key.required { + error!("Required key {} not provided", key.name); + return Err(anyhow::anyhow!("Required key {} not provided", key.name)); + } else { + warn!("Environment variable not set for key: {}", key.name); + } + } } + } else if let Some(default) = &key.default { + if key.secret { + config.set_secret(&key.name, Value::String(default.clone()))?; + } else { + config.set_param(&key.name, Value::String(default.clone()))?; + } + info!("Using default for {}", key.name); + } else if key.required { + error!("Required key {} not provided", key.name); + return Err(anyhow::anyhow!("Required key {} not provided", key.name)); } } } From f1551b60dfa2d33635304ac41ede70cc8d90b84e Mon Sep 17 00:00:00 2001 From: Aljaz Date: Thu, 29 May 2025 11:53:26 +0200 Subject: [PATCH 18/23] goose-api: prefer CLI config --- crates/goose-api/README.md | 12 ++++++++---- crates/goose-api/config | 8 ++++++++ crates/goose-api/src/config.rs | 18 +++++++++++++++++- 3 files changed, 33 insertions(+), 5 deletions(-) create mode 100644 crates/goose-api/config diff --git a/crates/goose-api/README.md b/crates/goose-api/README.md index 3b5469a4..b7fb63a5 100644 --- a/crates/goose-api/README.md +++ b/crates/goose-api/README.md @@ -28,15 +28,19 @@ cargo build --release ## Configuration -Goose API supports configuration through both environment variables and a configuration file. The precedence order is: +Goose API supports configuration via environment variables and configuration files. +The precedence order is: 1. Environment variables (highest priority) -2. Configuration file (lower priority) -3. Default values (lowest priority) +2. Goose CLI configuration file (usually `~/.config/goose/config.yaml`) if it exists +3. `config` file shipped with the crate +4. Default values (lowest priority) ### Configuration File -Create a file named `config` (with no extension) in the directory where you run the goose-api. The format can be JSON, YAML, TOML, etc. (the `config` crate will detect the format automatically). +If no CLI configuration file is found, goose-api looks for a `config` file in its +crate directory. This file has no extension and can be JSON, YAML, TOML, etc. +The `config` crate will detect the format automatically. Example `config` file (YAML format): diff --git a/crates/goose-api/config b/crates/goose-api/config new file mode 100644 index 00000000..6c721877 --- /dev/null +++ b/crates/goose-api/config @@ -0,0 +1,8 @@ +# API server configuration +host: 0.0.0.0 +port: 8080 +api_key: kurac + +# Provider configuration +provider: ollama +model: qwen3:8b diff --git a/crates/goose-api/src/config.rs b/crates/goose-api/src/config.rs index 11911206..cc6057b9 100644 --- a/crates/goose-api/src/config.rs +++ b/crates/goose-api/src/config.rs @@ -8,7 +8,23 @@ use config::{builder::DefaultState, ConfigBuilder, Environment, File}; use serde_json::Value; pub fn load_configuration() -> std::result::Result { - let config_path = std::env::var("GOOSE_CONFIG").unwrap_or_else(|_| "config".to_string()); + // Determine the configuration file based on priority: + // 1. Explicit GOOSE_CONFIG env var + // 2. Goose CLI config if it exists + // 3. Fallback to config file packaged with goose-api + + let config_path = if let Ok(path) = std::env::var("GOOSE_CONFIG") { + path + } else { + let global = Config::global(); + if global.exists() { + global.path() + } else { + // Use the config file that ships with goose-api + format!("{}/config", env!("CARGO_MANIFEST_DIR")) + } + }; + let builder = ConfigBuilder::::default() .add_source(File::with_name(&config_path).required(false)) .add_source(Environment::with_prefix("GOOSE_API")); From 26da43ae077c084710d233795ba1e7c75810aa36 Mon Sep 17 00:00:00 2001 From: Aljaz Date: Thu, 29 May 2025 14:53:37 +0200 Subject: [PATCH 19/23] Handle context length errors with summarization in API --- crates/goose-api/Cargo.toml | 4 ++ crates/goose-api/src/handlers.rs | 48 +++++++++++++++- crates/goose-api/src/tests.rs | 99 +++++++++++++++++++++++++++++++- 3 files changed, 149 insertions(+), 2 deletions(-) diff --git a/crates/goose-api/Cargo.toml b/crates/goose-api/Cargo.toml index 1caa2ffe..0db532b7 100644 --- a/crates/goose-api/Cargo.toml +++ b/crates/goose-api/Cargo.toml @@ -22,3 +22,7 @@ futures-util = "0.3" # For session IDs uuid = { version = "1", features = ["serde", "v4"] } # Add dynamic-library for extension loading + +[dev-dependencies] +tempfile = "3" +async-trait = "0.1" diff --git a/crates/goose-api/src/handlers.rs b/crates/goose-api/src/handlers.rs index 096e7193..b2019ba2 100644 --- a/crates/goose-api/src/handlers.rs +++ b/crates/goose-api/src/handlers.rs @@ -6,7 +6,7 @@ use futures_util::TryStreamExt; use tracing::{info, warn, error}; use mcp_core::tool::Tool; use goose::agents::{extension::Envs, extension_manager::ExtensionManager, ExtensionConfig, Agent, SessionConfig}; -use goose::message::Message; +use goose::message::{Message, MessageContent}; use goose::session::{self, Identifier}; use goose::config::Config; use std::sync::LazyLock; @@ -127,6 +127,30 @@ pub async fn start_session_handler( match result { Ok(mut stream) => { if let Ok(Some(response)) = stream.try_next().await { + if matches!(response.content.first(), Some(MessageContent::ContextLengthExceeded(_))) { + match agent.summarize_context(&messages).await { + Ok((summarized, _)) => { + messages = summarized; + if let Err(e) = session::persist_messages(&session_path, &messages, provider.clone()).await { + warn!("Failed to persist session {}: {}", session_name, e); + } + + let api_response = StartSessionResponse { + message: "Conversation summarized to fit context window".to_string(), + status: "warning".to_string(), + session_id, + }; + return Ok(warp::reply::with_status( + warp::reply::json(&api_response), + warp::http::StatusCode::OK, + )); + } + Err(e) => { + warn!("Failed to summarize context: {}", e); + } + } + } + let response_text = response.as_concat_text(); messages.push(response); if let Err(e) = session::persist_messages(&session_path, &messages, provider.clone()).await { @@ -215,6 +239,28 @@ pub async fn reply_session_handler( match result { Ok(mut stream) => { if let Ok(Some(response)) = stream.try_next().await { + if matches!(response.content.first(), Some(MessageContent::ContextLengthExceeded(_))) { + match agent.summarize_context(&messages).await { + Ok((summarized, _)) => { + messages = summarized; + if let Err(e) = session::persist_messages(&session_path, &messages, provider.clone()).await { + warn!("Failed to persist session {}: {}", session_name, e); + } + let api_response = ApiResponse { + message: "Conversation summarized to fit context window".to_string(), + status: "warning".to_string(), + }; + return Ok(warp::reply::with_status( + warp::reply::json(&api_response), + warp::http::StatusCode::OK, + )); + } + Err(e) => { + warn!("Failed to summarize context: {}", e); + } + } + } + let response_text = response.as_concat_text(); messages.push(response); if let Err(e) = session::persist_messages(&session_path, &messages, provider.clone()).await { diff --git a/crates/goose-api/src/tests.rs b/crates/goose-api/src/tests.rs index 302cf8c3..607064d4 100644 --- a/crates/goose-api/src/tests.rs +++ b/crates/goose-api/src/tests.rs @@ -1,10 +1,107 @@ #[cfg(test)] mod tests { use super::*; + use goose::message::{Message, MessageContent}; + use goose::model::ModelConfig; + use goose::providers::{ + base::{Provider, ProviderMetadata, ProviderUsage, Usage}, + errors::ProviderError, + }; + use mcp_core::tool::Tool; + use std::sync::Arc; + use tempfile::TempDir; + use warp::reply::Reply; + use goose::session::{self, Identifier}; + use uuid::Uuid; + use hyper::body; + + #[derive(Clone)] + struct ContextProvider { + model_config: ModelConfig, + } + + #[async_trait::async_trait] + impl Provider for ContextProvider { + fn metadata() -> ProviderMetadata { + ProviderMetadata::empty() + } + + fn get_model_config(&self) -> ModelConfig { + self.model_config.clone() + } + + async fn complete( + &self, + system: &str, + _messages: &[Message], + _tools: &[Tool], + ) -> Result<(Message, ProviderUsage), ProviderError> { + if system.contains("summarizing") { + Ok(( + Message::user().with_text("summary"), + ProviderUsage::new("mock".to_string(), Usage::default()), + )) + } else { + Err(ProviderError::ContextLengthExceeded("too long".to_string())) + } + } + } + + async fn setup() -> (TempDir, Uuid) { + let tmp = tempfile::tempdir().unwrap(); + std::env::set_var("HOME", tmp.path()); + + let provider = Arc::new(ContextProvider { + model_config: ModelConfig::new("test".to_string()), + }); + let agent = AGENT.lock().await; + agent.update_provider(provider).await.unwrap(); + drop(agent); + + let req = SessionRequest { + prompt: "start".repeat(1000), + }; + let reply = start_session_handler(req, "key".to_string()).await.unwrap(); + let resp = reply.into_response(); + let body = body::to_bytes(resp.into_body()).await.unwrap(); + let start: StartSessionResponse = serde_json::from_slice(&body).unwrap(); + (tmp, start.session_id) + } #[tokio::test] async fn build_routes_compiles() { let _routes = build_routes("test-key".to_string()); - // Just ensure building routes doesn't panic + } + + #[tokio::test] + async fn summarizes_large_history_on_start() { + let (tmp, session_id) = setup().await; + + let session_path = session::get_path(Identifier::Name(session_id.to_string())); + let messages = session::read_messages(&session_path).unwrap(); + assert!(messages.iter().any(|m| m.as_concat_text().contains("summary"))); + drop(tmp); + } + + #[tokio::test] + async fn summarizes_large_history_on_reply() { + let (tmp, session_id) = setup().await; + + let req = SessionReplyRequest { + session_id, + prompt: "reply".repeat(1000), + }; + let reply = reply_session_handler(req, "key".to_string()).await.unwrap(); + let resp = reply.into_response(); + let body = body::to_bytes(resp.into_body()).await.unwrap(); + let api: ApiResponse = serde_json::from_slice(&body).unwrap(); + assert_eq!(api.status, "warning"); + + let session_path = session::get_path(Identifier::Name(session_id.to_string())); + let messages = session::read_messages(&session_path).unwrap(); + assert!(messages + .iter() + .all(|m| !matches!(m.content.first(), Some(MessageContent::ContextLengthExceeded(_))))); + drop(tmp); } } From e2a56bb628c476847040908d35d885bdf24c0b05 Mon Sep 17 00:00:00 2001 From: Aljaz Date: Thu, 29 May 2025 14:53:43 +0200 Subject: [PATCH 20/23] Add session summarization endpoint --- crates/goose-api/README.md | 39 +++++++++++++++++++ crates/goose-api/src/handlers.rs | 66 ++++++++++++++++++++++++++++++++ crates/goose-api/src/routes.rs | 10 ++++- 3 files changed, 114 insertions(+), 1 deletion(-) diff --git a/crates/goose-api/README.md b/crates/goose-api/README.md index b7fb63a5..2191e609 100644 --- a/crates/goose-api/README.md +++ b/crates/goose-api/README.md @@ -236,6 +236,31 @@ By default, the server runs on `127.0.0.1:8080`. You can modify this using confi } ``` +### 7. Summarize Session + +**Endpoint**: `POST /session/summarize` + +**Description**: Summarizes the full conversation for a given session. + +**Request**: +- Headers: + - Content-Type: application/json + - x-api-key: [your-api-key] +- Body: +```json +{ + "session_id": "" +} +``` + +**Response**: +```json +{ + "message": "", + "status": "success" +} +``` + ## Session Management Sessions created via the API are stored in the same location as the CLI @@ -279,6 +304,12 @@ curl -X POST http://localhost:8080/extensions/remove \ # Get provider configuration curl -X GET http://localhost:8080/provider/config \ -H "x-api-key: your_secure_api_key" + +# Summarize a session +curl -X POST http://localhost:8080/session/summarize \ + -H "Content-Type: application/json" \ + -H "x-api-key: your_secure_api_key" \ + -d '{"session_id": "your-session-id"}' ``` ### Using Python @@ -332,6 +363,14 @@ print(response.json()) # Get provider configuration response = requests.get(f"{API_URL}/provider/config", headers=HEADERS) print(response.json()) + +# Summarize a session +response = requests.post( + f"{API_URL}/session/summarize", + headers=HEADERS, + json={"session_id": "your-session-id"} +) +print(response.json()) ``` ## Troubleshooting diff --git a/crates/goose-api/src/handlers.rs b/crates/goose-api/src/handlers.rs index 096e7193..94fd4abf 100644 --- a/crates/goose-api/src/handlers.rs +++ b/crates/goose-api/src/handlers.rs @@ -43,6 +43,11 @@ pub struct EndSessionRequest { pub session_id: Uuid, } +#[derive(Debug, Serialize, Deserialize)] +pub struct SummarizeSessionRequest { + pub session_id: Uuid, +} + #[derive(Debug, Serialize, Deserialize)] pub struct ExtensionsResponse { pub extensions: Vec, @@ -284,6 +289,67 @@ pub async fn end_session_handler( } } +pub async fn summarize_session_handler( + req: SummarizeSessionRequest, + _api_key: String, +) -> Result { + info!("Summarizing session: {}", req.session_id); + + let agent = AGENT.lock().await; + + let session_name = req.session_id.to_string(); + let session_path = session::get_path(Identifier::Name(session_name.clone())); + + let messages = match session::read_messages(&session_path) { + Ok(m) => m, + Err(_) => { + let response = ApiResponse { + message: "Session not found".to_string(), + status: "error".to_string(), + }; + return Ok(warp::reply::with_status( + warp::reply::json(&response), + warp::http::StatusCode::NOT_FOUND, + )); + } + }; + + let provider = agent.provider().await.ok(); + + match agent.summarize_context(&messages).await { + Ok((summarized_messages, _)) => { + let summary_text = summarized_messages + .first() + .map(|m| m.as_concat_text()) + .unwrap_or_default(); + + if let Err(e) = session::persist_messages(&session_path, &summarized_messages, provider.clone()).await { + warn!("Failed to persist session {}: {}", session_name, e); + } + + let resp = ApiResponse { + message: summary_text, + status: "success".to_string(), + }; + Ok(warp::reply::with_status( + warp::reply::json(&resp), + warp::http::StatusCode::OK, + )) + } + Err(e) => { + error!("Failed to summarize session: {}", e); + let resp = ApiResponse { + message: format!("Failed to summarize session: {}", e), + status: "error".to_string(), + }; + Ok(warp::reply::with_status( + warp::reply::json(&resp), + warp::http::StatusCode::INTERNAL_SERVER_ERROR, + )) + } + } +} + pub async fn list_extensions_handler() -> Result { info!("Listing extensions"); diff --git a/crates/goose-api/src/routes.rs b/crates/goose-api/src/routes.rs index 759786c3..5762bcdd 100644 --- a/crates/goose-api/src/routes.rs +++ b/crates/goose-api/src/routes.rs @@ -4,7 +4,7 @@ use tracing::{info, warn, error}; use crate::handlers::{ add_extension_handler, end_session_handler, get_provider_config_handler, list_extensions_handler, remove_extension_handler, reply_session_handler, - start_session_handler, with_api_key, + start_session_handler, summarize_session_handler, with_api_key, }; use crate::config::{ initialize_extensions, initialize_provider_config, load_configuration, @@ -26,6 +26,13 @@ pub fn build_routes(api_key: String) -> impl Filter impl Filter Date: Thu, 29 May 2025 14:53:54 +0200 Subject: [PATCH 21/23] feat(api): manage agents per session --- crates/goose-api/Cargo.toml | 1 + crates/goose-api/src/api_sessions.rs | 45 ++++++++++++++++++++++++ crates/goose-api/src/handlers.rs | 52 ++++++++++++++++++++++++---- crates/goose-api/src/lib.rs | 1 + 4 files changed, 93 insertions(+), 6 deletions(-) create mode 100644 crates/goose-api/src/api_sessions.rs diff --git a/crates/goose-api/Cargo.toml b/crates/goose-api/Cargo.toml index 1caa2ffe..7276663e 100644 --- a/crates/goose-api/Cargo.toml +++ b/crates/goose-api/Cargo.toml @@ -21,4 +21,5 @@ futures = "0.3" futures-util = "0.3" # For session IDs uuid = { version = "1", features = ["serde", "v4"] } +dashmap = "6" # Add dynamic-library for extension loading diff --git a/crates/goose-api/src/api_sessions.rs b/crates/goose-api/src/api_sessions.rs new file mode 100644 index 00000000..3c259ade --- /dev/null +++ b/crates/goose-api/src/api_sessions.rs @@ -0,0 +1,45 @@ +use dashmap::DashMap; +use goose::agents::Agent; +use std::sync::{atomic::{AtomicU64, Ordering}, Arc, LazyLock}; +use tokio::sync::Mutex; +use uuid::Uuid; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; + +pub struct ApiSession { + pub agent: Arc>, // agent for this session + last_active: AtomicU64, +} + +impl ApiSession { + pub fn new(agent: Agent) -> Self { + Self { + agent: Arc::new(Mutex::new(agent)), + last_active: AtomicU64::new(current_timestamp()), + } + } + + pub fn touch(&self) { + self.last_active.store(current_timestamp(), Ordering::Relaxed); + } + + pub fn is_expired(&self, ttl: Duration) -> bool { + current_timestamp() - self.last_active.load(Ordering::Relaxed) > ttl.as_secs() + } +} + +fn current_timestamp() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_secs() +} + +pub static SESSIONS: LazyLock> = LazyLock::new(DashMap::new); + +pub const SESSION_TIMEOUT_SECS: u64 = 3600; + +pub fn cleanup_expired_sessions() { + let ttl = Duration::from_secs(SESSION_TIMEOUT_SECS); + SESSIONS.retain(|_, sess| !sess.is_expired(ttl)); +} + diff --git a/crates/goose-api/src/handlers.rs b/crates/goose-api/src/handlers.rs index 096e7193..e84de7c4 100644 --- a/crates/goose-api/src/handlers.rs +++ b/crates/goose-api/src/handlers.rs @@ -10,6 +10,7 @@ use goose::message::Message; use goose::session::{self, Identifier}; use goose::config::Config; use std::sync::LazyLock; +use crate::api_sessions::{ApiSession, SESSIONS, cleanup_expired_sessions}; pub static EXTENSION_MANAGER: LazyLock = LazyLock::new(|| ExtensionManager::default()); pub static AGENT: LazyLock> = LazyLock::new(|| tokio::sync::Mutex::new(Agent::new())); @@ -105,15 +106,30 @@ pub async fn start_session_handler( ) -> Result { info!("Starting session with prompt: {}", req.prompt); - let agent = AGENT.lock().await; + cleanup_expired_sessions(); + + // create fresh agent using provider from the template agent + let template = AGENT.lock().await; + let mut new_agent = Agent::new(); + if let Ok(provider) = template.provider().await { + let _ = new_agent.update_provider(provider).await; + } + drop(template); + let mut messages = vec![Message::user().with_text(&req.prompt)]; let session_id = Uuid::new_v4(); let session_name = session_id.to_string(); let session_path = session::get_path(Identifier::Name(session_name.clone())); - let provider = agent.provider().await.ok(); + let session = ApiSession::new(new_agent); + let agent_ref = session.agent.clone(); + SESSIONS.insert(session_id, session); - let result = agent + let provider = agent_ref.lock().await.provider().await.ok(); + + let result = agent_ref + .lock() + .await .reply( &messages, Some(SessionConfig { @@ -178,11 +194,28 @@ pub async fn reply_session_handler( ) -> Result { info!("Replying to session with prompt: {}", req.prompt); - let agent = AGENT.lock().await; + cleanup_expired_sessions(); let session_name = req.session_id.to_string(); let session_path = session::get_path(Identifier::Name(session_name.clone())); + let session_entry = match SESSIONS.get(&req.session_id) { + Some(s) => s, + None => { + let response = ApiResponse { + message: "Session not found".to_string(), + status: "error".to_string(), + }; + return Ok(warp::reply::with_status( + warp::reply::json(&response), + warp::http::StatusCode::NOT_FOUND, + )); + } + }; + session_entry.touch(); + let agent_ref = session_entry.agent.clone(); + drop(session_entry); + let mut messages = match session::read_messages(&session_path) { Ok(m) => m, Err(_) => { @@ -199,9 +232,11 @@ pub async fn reply_session_handler( messages.push(Message::user().with_text(&req.prompt)); - let provider = agent.provider().await.ok(); + let provider = agent_ref.lock().await.provider().await.ok(); - let result = agent + let result = agent_ref + .lock() + .await .reply( &messages, Some(SessionConfig { @@ -260,9 +295,14 @@ pub async fn end_session_handler( req: EndSessionRequest, _api_key: String, ) -> Result { + cleanup_expired_sessions(); + let session_name = req.session_id.to_string(); let session_path = session::get_path(Identifier::Name(session_name.clone())); + // remove in-memory agent if present + SESSIONS.remove(&req.session_id); + if std::fs::remove_file(&session_path).is_ok() { let response = ApiResponse { message: "Session ended".to_string(), diff --git a/crates/goose-api/src/lib.rs b/crates/goose-api/src/lib.rs index b2037198..3b8e911e 100644 --- a/crates/goose-api/src/lib.rs +++ b/crates/goose-api/src/lib.rs @@ -1,5 +1,6 @@ mod handlers; mod config; mod routes; +mod api_sessions; pub use routes::{build_routes, run_server}; From 9fb798052c9fa8514b2713486d76cd2e82ac2047 Mon Sep 17 00:00:00 2001 From: Aljaz Date: Thu, 29 May 2025 14:53:59 +0200 Subject: [PATCH 22/23] feat(api): add metrics endpoint --- crates/goose-api/README.md | 24 ++++++++++++++ crates/goose-api/src/handlers.rs | 35 ++++++++++++++++++++ crates/goose-api/src/routes.rs | 7 +++- crates/goose/src/agents/extension_manager.rs | 20 ++++++++++- crates/mcp-client/src/transport/sse.rs | 12 +++++-- crates/mcp-client/src/transport/stdio.rs | 9 ++++- 6 files changed, 102 insertions(+), 5 deletions(-) diff --git a/crates/goose-api/README.md b/crates/goose-api/README.md index b7fb63a5..d44af846 100644 --- a/crates/goose-api/README.md +++ b/crates/goose-api/README.md @@ -236,6 +236,30 @@ By default, the server runs on `127.0.0.1:8080`. You can modify this using confi } ``` +### 7. Metrics + +**Endpoint**: `GET /metrics` + +**Description**: Returns runtime metrics about stored sessions and extensions. + +**Request**: +- Headers: + - `x-api-key: [your-api-key]` + +**Response** (example): +```json +{ + "session_messages": { + "20240605_001234": 3, + "20240605_010000": 5 + }, + "active_sessions": 2, + "pending_requests": { + "mcp_say": 0 + } +} +``` + ## Session Management Sessions created via the API are stored in the same location as the CLI diff --git a/crates/goose-api/src/handlers.rs b/crates/goose-api/src/handlers.rs index 096e7193..f731ea88 100644 --- a/crates/goose-api/src/handlers.rs +++ b/crates/goose-api/src/handlers.rs @@ -10,6 +10,7 @@ use goose::message::Message; use goose::session::{self, Identifier}; use goose::config::Config; use std::sync::LazyLock; +use std::collections::HashMap; pub static EXTENSION_MANAGER: LazyLock = LazyLock::new(|| ExtensionManager::default()); pub static AGENT: LazyLock> = LazyLock::new(|| tokio::sync::Mutex::new(Agent::new())); @@ -60,6 +61,13 @@ pub struct ExtensionResponse { pub message: Option, } +#[derive(Debug, Serialize)] +pub struct MetricsResponse { + pub session_messages: HashMap, + pub active_sessions: usize, + pub pending_requests: HashMap, +} + #[derive(Debug, Deserialize)] #[serde(tag = "type")] pub enum ExtensionConfigRequest { @@ -442,6 +450,33 @@ pub async fn remove_extension_handler( Ok(warp::reply::json(&resp)) } +pub async fn metrics_handler() -> Result { + // Gather session message counts + let mut session_messages = HashMap::new(); + if let Ok(sessions) = session::list_sessions() { + for (name, path) in sessions { + if let Ok(messages) = session::read_messages(&path) { + session_messages.insert(name, messages.len()); + } + } + } + + let active_sessions = session_messages.len(); + + // Gather pending request sizes for each extension + let pending_requests = EXTENSION_MANAGER + .pending_request_sizes() + .await; + + let resp = MetricsResponse { + session_messages, + active_sessions, + pending_requests, + }; + + Ok(warp::reply::json(&resp)) +} + pub fn with_api_key(api_key: String) -> impl Filter + Clone { warp::header::value("x-api-key") .and_then(move |header_api_key: HeaderValue| { diff --git a/crates/goose-api/src/routes.rs b/crates/goose-api/src/routes.rs index 759786c3..ea5680ea 100644 --- a/crates/goose-api/src/routes.rs +++ b/crates/goose-api/src/routes.rs @@ -4,7 +4,7 @@ use tracing::{info, warn, error}; use crate::handlers::{ add_extension_handler, end_session_handler, get_provider_config_handler, list_extensions_handler, remove_extension_handler, reply_session_handler, - start_session_handler, with_api_key, + start_session_handler, metrics_handler, with_api_key, }; use crate::config::{ initialize_extensions, initialize_provider_config, load_configuration, @@ -57,6 +57,10 @@ pub fn build_routes(api_key: String) -> impl Filter impl Filter Result<(), anyhow::Error> { diff --git a/crates/goose/src/agents/extension_manager.rs b/crates/goose/src/agents/extension_manager.rs index 4bc4d746..4b03a99a 100644 --- a/crates/goose/src/agents/extension_manager.rs +++ b/crates/goose/src/agents/extension_manager.rs @@ -17,7 +17,7 @@ use crate::agents::extension::Envs; use crate::config::{Config, ExtensionConfigManager}; use crate::prompt_template; use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait}; -use mcp_client::transport::{SseTransport, StdioTransport, Transport}; +use mcp_client::transport::{PendingRequests, SseTransport, StdioTransport, Transport}; use mcp_core::{prompt::Prompt, Content, Tool, ToolCall, ToolError, ToolResult}; use serde_json::Value; @@ -33,6 +33,7 @@ pub struct ExtensionManager { clients: HashMap, instructions: HashMap, resource_capable_extensions: HashSet, + pending_requests: HashMap>, // track pending requests per extension } /// A flattened representation of a resource used by the agent to prepare inference @@ -103,6 +104,7 @@ impl ExtensionManager { clients: HashMap::new(), instructions: HashMap::new(), resource_capable_extensions: HashSet::new(), + pending_requests: HashMap::new(), } } @@ -183,12 +185,14 @@ impl ExtensionManager { let all_envs = merge_environments(envs, env_keys, &sanitized_name).await?; let transport = SseTransport::new(uri, all_envs); let handle = transport.start().await?; + let pending = handle.pending_requests(); let service = McpService::with_timeout( handle, Duration::from_secs( timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT), ), ); + self.pending_requests.insert(sanitized_name.clone(), pending); Box::new(McpClient::new(service)) } ExtensionConfig::Stdio { @@ -202,12 +206,14 @@ impl ExtensionManager { let all_envs = merge_environments(envs, env_keys, &sanitized_name).await?; let transport = StdioTransport::new(cmd, args.to_vec(), all_envs); let handle = transport.start().await?; + let pending = handle.pending_requests(); let service = McpService::with_timeout( handle, Duration::from_secs( timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT), ), ); + self.pending_requests.insert(sanitized_name.clone(), pending); Box::new(McpClient::new(service)) } ExtensionConfig::Builtin { @@ -227,12 +233,14 @@ impl ExtensionManager { HashMap::new(), ); let handle = transport.start().await?; + let pending = handle.pending_requests(); let service = McpService::with_timeout( handle, Duration::from_secs( timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT), ), ); + self.pending_requests.insert(sanitized_name.clone(), pending); Box::new(McpClient::new(service)) } _ => unreachable!(), @@ -285,9 +293,19 @@ impl ExtensionManager { self.clients.remove(&sanitized_name); self.instructions.remove(&sanitized_name); self.resource_capable_extensions.remove(&sanitized_name); + self.pending_requests.remove(&sanitized_name); Ok(()) } + /// Get the size of each extension's pending request map + pub async fn pending_request_sizes(&self) -> HashMap { + let mut result = HashMap::new(); + for (name, pending) in &self.pending_requests { + result.insert(name.clone(), pending.len().await); + } + result + } + pub async fn suggest_disable_extensions_prompt(&self) -> Value { let enabled_extensions_count = self.clients.len(); diff --git a/crates/mcp-client/src/transport/sse.rs b/crates/mcp-client/src/transport/sse.rs index 8a564708..0e15f168 100644 --- a/crates/mcp-client/src/transport/sse.rs +++ b/crates/mcp-client/src/transport/sse.rs @@ -223,6 +223,7 @@ impl SseActor { #[derive(Clone)] pub struct SseTransportHandle { sender: mpsc::Sender, + pending_requests: Arc, } #[async_trait::async_trait] @@ -232,6 +233,12 @@ impl TransportHandle for SseTransportHandle { } } +impl SseTransportHandle { + pub fn pending_requests(&self) -> Arc { + Arc::clone(&self.pending_requests) + } +} + #[derive(Clone)] pub struct SseTransport { sse_url: String, @@ -284,9 +291,10 @@ impl Transport for SseTransport { let post_endpoint_clone = Arc::clone(&post_endpoint); // Build the actor + let pending_requests = Arc::new(PendingRequests::new()); let actor = SseActor::new( rx, - Arc::new(PendingRequests::new()), + pending_requests.clone(), self.sse_url.clone(), post_endpoint, ); @@ -301,7 +309,7 @@ impl Transport for SseTransport { ) .await { - Ok(_) => Ok(SseTransportHandle { sender: tx }), + Ok(_) => Ok(SseTransportHandle { sender: tx, pending_requests }), Err(e) => Err(Error::SseConnection(e.to_string())), } } diff --git a/crates/mcp-client/src/transport/stdio.rs b/crates/mcp-client/src/transport/stdio.rs index 5895e83e..76a48487 100644 --- a/crates/mcp-client/src/transport/stdio.rs +++ b/crates/mcp-client/src/transport/stdio.rs @@ -189,6 +189,7 @@ impl StdioActor { pub struct StdioTransportHandle { sender: mpsc::Sender, error_receiver: Arc>>, + pending_requests: Arc, } #[async_trait::async_trait] @@ -212,6 +213,10 @@ impl StdioTransportHandle { Err(_) => Ok(()), } } + + pub fn pending_requests(&self) -> Arc { + Arc::clone(&self.pending_requests) + } } pub struct StdioTransport { @@ -292,9 +297,10 @@ impl Transport for StdioTransport { let (message_tx, message_rx) = mpsc::channel(32); let (error_tx, error_rx) = mpsc::channel(1); + let pending_requests = Arc::new(PendingRequests::new()); let actor = StdioActor { receiver: Some(message_rx), - pending_requests: Arc::new(PendingRequests::new()), + pending_requests: pending_requests.clone(), process, error_sender: error_tx, stdin: Some(stdin), @@ -307,6 +313,7 @@ impl Transport for StdioTransport { let handle = StdioTransportHandle { sender: message_tx, error_receiver: Arc::new(Mutex::new(error_rx)), + pending_requests, }; Ok(handle) } From 2002602fc5e5e62858234d01363bf1d76e4ed88e Mon Sep 17 00:00:00 2001 From: Aljaz Ceru Date: Sat, 26 Jul 2025 17:28:31 +0200 Subject: [PATCH 23/23] api --- Cargo.lock | 144 +++++-- config.yaml | 49 --- crates/goose-api/Cargo.toml | 1 + crates/goose-api/README.md | 13 +- crates/goose-api/config | 4 +- crates/goose-api/goose-api-plan.md | 53 +++ crates/goose-api/src/api_sessions.rs | 21 +- crates/goose-api/src/config.rs | 16 +- crates/goose-api/src/handlers.rs | 385 ++++++++----------- crates/goose-api/src/main.rs | 76 +++- crates/goose-api/src/routes.rs | 61 +-- crates/goose-api/test.py | 98 +++++ crates/goose/src/agents/extension_manager.rs | 53 --- crates/goose/src/tracing/langfuse_layer.rs | 2 - crates/goose/tests/providers.rs | 12 - crates/mcp-client/src/transport/stdio.rs | 58 +-- 16 files changed, 608 insertions(+), 438 deletions(-) delete mode 100644 config.yaml create mode 100644 crates/goose-api/goose-api-plan.md create mode 100644 crates/goose-api/test.py diff --git a/Cargo.lock b/Cargo.lock index 07324dd7..5df2a73c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -261,7 +261,7 @@ version = "52.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "16f4a9468c882dc66862cef4e1fd8423d47e67972377d85d80e022786427768c" dependencies = [ - "ahash", + "ahash 0.8.11", "arrow-buffer", "arrow-data", "arrow-schema", @@ -392,7 +392,7 @@ version = "52.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4cd09a518c602a55bd406bcc291a967b284cfa7a63edfbf8b897ea4748aad23c" dependencies = [ - "ahash", + "ahash 0.8.11", "arrow-array", "arrow-buffer", "arrow-data", @@ -415,7 +415,7 @@ version = "52.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "600bae05d43483d216fb3494f8c32fdbefd8aa4e1de237e790dbb3d9f44690a3" dependencies = [ - "ahash", + "ahash 0.8.11", "arrow-array", "arrow-buffer", "arrow-data", @@ -686,7 +686,7 @@ dependencies = [ "fastrand 2.3.0", "hex", "http 0.2.12", - "ring 0.17.12", + "ring 0.17.14", "time", "tokio", "tracing", @@ -1074,7 +1074,7 @@ dependencies = [ "sha1", "sync_wrapper 1.0.2", "tokio", - "tokio-tungstenite 0.24.0", + "tokio-tungstenite 0.26.2", "tower 0.5.2", "tower-layer", "tower-service", @@ -1819,6 +1819,25 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "config" +version = "0.13.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23738e11972c7643e4ec947840fc463b6a571afcd3e735bdfce7d03c7a784aca" +dependencies = [ + "async-trait", + "json5", + "lazy_static", + "nom", + "pathdiff", + "ron 0.7.1", + "rust-ini 0.18.0", + "serde", + "serde_json", + "toml 0.5.11", + "yaml-rust", +] + [[package]] name = "config" version = "0.14.1" @@ -2200,7 +2219,7 @@ version = "41.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e4fd4a99fc70d40ef7e52b243b4a399c3f8d353a40d5ecb200deee05e49c61bb" dependencies = [ - "ahash", + "ahash 0.8.11", "arrow", "arrow-array", "arrow-ipc", @@ -2263,7 +2282,7 @@ version = "41.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "44fdbc877e3e40dcf88cc8f283d9f5c8851f0a3aa07fee657b1b75ac1ad49b9c" dependencies = [ - "ahash", + "ahash 0.8.11", "arrow", "arrow-array", "arrow-buffer", @@ -2314,7 +2333,7 @@ version = "41.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1c1841c409d9518c17971d15c9bae62e629eb937e6fb6c68cd32e9186f8b30d2" dependencies = [ - "ahash", + "ahash 0.8.11", "arrow", "arrow-array", "arrow-buffer", @@ -2356,7 +2375,7 @@ version = "41.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2b4ece19f73c02727e5e8654d79cd5652de371352c1df3c4ac3e419ecd6943fb" dependencies = [ - "ahash", + "ahash 0.8.11", "arrow", "arrow-schema", "datafusion-common", @@ -2416,7 +2435,7 @@ version = "41.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9a223962b3041304a3e20ed07a21d5de3d88d7e4e71ca192135db6d24e3365a4" dependencies = [ - "ahash", + "ahash 0.8.11", "arrow", "arrow-array", "arrow-buffer", @@ -2446,7 +2465,7 @@ version = "41.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "db5e7d8532a1601cd916881db87a70b0a599900d23f3db2897d389032da53bc6" dependencies = [ - "ahash", + "ahash 0.8.11", "arrow", "datafusion-common", "datafusion-expr", @@ -2472,7 +2491,7 @@ version = "41.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8d1116949432eb2d30f6362707e2846d942e491052a206f2ddcb42d08aea1ffe" dependencies = [ - "ahash", + "ahash 0.8.11", "arrow", "arrow-array", "arrow-buffer", @@ -3425,7 +3444,7 @@ dependencies = [ "futures-util", "include_dir", "indoc 2.0.6", - "jsonwebtoken", + "jsonwebtoken 9.3.1", "keyring", "lancedb", "lazy_static", @@ -3467,7 +3486,9 @@ name = "goose-api" version = "0.1.0" dependencies = [ "anyhow", + "async-trait", "config 0.13.4", + "dashmap 6.1.0", "futures", "futures-util", "goose", @@ -3475,8 +3496,10 @@ dependencies = [ "jsonwebtoken 8.3.0", "mcp-client", "mcp-core", + "mcp-server", "serde", "serde_json", + "tempfile", "tokio", "tracing", "tracing-subscriber", @@ -4592,7 +4615,7 @@ dependencies = [ "base64 0.22.1", "js-sys", "pem 3.0.5", - "ring 0.17.12", + "ring 0.17.14", "serde", "serde_json", "simple_asn1", @@ -5638,6 +5661,24 @@ dependencies = [ "syn 2.0.99", ] +[[package]] +name = "multer" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "01acbdc23469fd8fe07ab135923371d5f5a422fbf9c522158677c8eb15bc51c2" +dependencies = [ + "bytes", + "encoding_rs", + "futures-util", + "http 0.2.12", + "httparse", + "log", + "memchr", + "mime", + "spin 0.9.8", + "version_check", +] + [[package]] name = "multimap" version = "0.10.1" @@ -5969,7 +6010,7 @@ dependencies = [ "quick-xml 0.36.2", "rand 0.8.5", "reqwest 0.12.12", - "ring", + "ring 0.17.14", "rustls-pemfile 2.2.0", "serde", "serde_json", @@ -6712,7 +6753,7 @@ dependencies = [ "bytes", "getrandom 0.2.15", "rand 0.8.5", - "ring", + "ring 0.17.14", "rustc-hash 2.1.1", "rustls 0.23.23", "rustls-pki-types", @@ -7118,6 +7159,21 @@ dependencies = [ "bytemuck", ] +[[package]] +name = "ring" +version = "0.16.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3053cf52e236a3ed746dfc745aa9cacf1b791d846bdaf412f60a8d7d6e17c8fc" +dependencies = [ + "cc", + "libc", + "once_cell", + "spin 0.5.2", + "untrusted 0.7.1", + "web-sys", + "winapi", +] + [[package]] name = "ring" version = "0.17.14" @@ -7142,6 +7198,17 @@ dependencies = [ "byteorder", ] +[[package]] +name = "ron" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88073939a61e5b7680558e6be56b419e208420c2adb92be54921fa6b72283f1a" +dependencies = [ + "base64 0.13.1", + "bitflags 1.3.2", + "serde", +] + [[package]] name = "ron" version = "0.8.1" @@ -7171,7 +7238,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3e0698206bcb8882bf2a9ecb4c1e7785db57ff052297085a6efd4fe42302068a" dependencies = [ "cfg-if", - "ordered-multimap", + "ordered-multimap 0.7.3", ] [[package]] @@ -7258,7 +7325,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f56a14d1f48b391359b22f731fd4bd7e43c97f3c50eee276f3aa09c94784d3e" dependencies = [ "log", - "ring 0.17.12", + "ring 0.17.14", "rustls-webpki 0.101.7", "sct", ] @@ -7270,7 +7337,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "47796c98c480fce5406ef69d1c76378375492c3b0a0de587be0c1d9feb12f395" dependencies = [ "once_cell", - "ring 0.17.12", + "ring 0.17.14", "rustls-pki-types", "rustls-webpki 0.102.8", "subtle", @@ -7334,7 +7401,7 @@ version = "0.101.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b6275d1ee7a1cd780b64aca7726599a1dbc893b1e64144529e55c3c2f745765" dependencies = [ - "ring 0.17.12", + "ring 0.17.14", "untrusted 0.9.0", ] @@ -7344,7 +7411,7 @@ version = "0.102.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "64ca1bc8749bd4cf37b5ce386cc146580777b4e8572c7b97baf22c83f444bee9" dependencies = [ - "ring 0.17.12", + "ring 0.17.14", "rustls-pki-types", "untrusted 0.9.0", ] @@ -7481,7 +7548,7 @@ version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "da046153aa2352493d6cb7da4b6e5c0c057d8a1d0a9aa8560baffdd945acd414" dependencies = [ - "ring 0.17.12", + "ring 0.17.14", "untrusted 0.9.0", ] @@ -8657,6 +8724,18 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-tungstenite" +version = "0.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c83b561d025642014097b66e6c1bb422783339e0909e4429cde4749d1990bc38" +dependencies = [ + "futures-util", + "log", + "tokio", + "tungstenite 0.21.0", +] + [[package]] name = "tokio-tungstenite" version = "0.26.2" @@ -8666,7 +8745,7 @@ dependencies = [ "futures-util", "log", "tokio", - "tungstenite", + "tungstenite 0.26.2", ] [[package]] @@ -8894,6 +8973,25 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" +[[package]] +name = "tungstenite" +version = "0.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ef1a641ea34f399a848dea702823bbecfb4c486f911735368f1f137cb8257e1" +dependencies = [ + "byteorder", + "bytes", + "data-encoding", + "http 1.2.0", + "httparse", + "log", + "rand 0.8.5", + "sha1", + "thiserror 1.0.69", + "url", + "utf-8", +] + [[package]] name = "tungstenite" version = "0.26.2" diff --git a/config.yaml b/config.yaml deleted file mode 100644 index 825c1f0c..00000000 --- a/config.yaml +++ /dev/null @@ -1,49 +0,0 @@ -extensions: - computercontroller: - bundled: true - display_name: Computer Controller - enabled: true - name: computercontroller - timeout: 300 - type: builtin - developer: - bundled: true - display_name: Developer Tools - enabled: true - name: developer - timeout: 300 - type: builtin - filesytem: - args: - - -y - - '@modelcontextprotocol/server-filesystem' - - /home/lio/g - bundled: null - cmd: npx - description: 'access files inside ~/g ' - enabled: true - env_keys: [] - envs: {} - name: filesytem - timeout: 300 - type: stdio - filesytem-extension: - args: - - -y - - '@modelcontextprotocol/server-filesystem' - bundled: null - cmd: npx - description: null - enabled: false - env_keys: [] - envs: {} - name: filesytem-extension - timeout: 300 - type: stdio - memory: - bundled: true - display_name: Memory - enabled: true - name: memory - timeout: 300 - type: builtin diff --git a/crates/goose-api/Cargo.toml b/crates/goose-api/Cargo.toml index 10117159..be864027 100644 --- a/crates/goose-api/Cargo.toml +++ b/crates/goose-api/Cargo.toml @@ -8,6 +8,7 @@ goose = { path = "../goose" } goose-mcp = { path = "../goose-mcp" } mcp-client = { path = "../mcp-client" } mcp-core = { path = "../mcp-core" } +mcp-server = { path = "../mcp-server" } tokio = { version = "1", features = ["full"] } warp = "0.3" serde = { version = "1", features = ["derive"] } diff --git a/crates/goose-api/README.md b/crates/goose-api/README.md index 23fc0fa3..601e2b11 100644 --- a/crates/goose-api/README.md +++ b/crates/goose-api/README.md @@ -287,8 +287,15 @@ By default, the server runs on `127.0.0.1:8080`. You can modify this using confi Sessions created via the API are stored in the same location as the CLI (`~/.local/share/goose/sessions` on most platforms). Each session is saved to a -`.jsonl` file. You can resume or inspect these sessions with the CLI -by providing the session ID returned from the API. +`.jsonl` file. + +You can resume or inspect these sessions with the CLI by providing the session ID +(which is a UUID) returned from the API. For example, if the API returns a +`session_id` of `a1b2c3d4-e5f6-7890-1234-567890abcdef`, you can resume it with: + +```bash +goose session --resume --name a1b2c3d4-e5f6-7890-1234-567890abcdef +``` ## Examples @@ -298,7 +305,7 @@ by providing the session ID returned from the API. # Start a session curl -X POST http://localhost:8080/session/start \ -H "Content-Type: application/json" \ - -H "x-api-key: your_secure_api_key" \ + -H "x-api-key: kurac" \ -d '{"prompt": "Create a Python function to generate Fibonacci numbers"}' # Reply to an ongoing session diff --git a/crates/goose-api/config b/crates/goose-api/config index 6c721877..b080c62e 100644 --- a/crates/goose-api/config +++ b/crates/goose-api/config @@ -1,8 +1,8 @@ # API server configuration host: 0.0.0.0 -port: 8080 +port: 8181 api_key: kurac # Provider configuration provider: ollama -model: qwen3:8b +model: qwen3:4b diff --git a/crates/goose-api/goose-api-plan.md b/crates/goose-api/goose-api-plan.md new file mode 100644 index 00000000..6a1185a2 --- /dev/null +++ b/crates/goose-api/goose-api-plan.md @@ -0,0 +1,53 @@ +# Plan for `goose-api` Review and Improvements + +This document outlines the plan to address the user's request regarding `goose-api`'s interaction with `goose-cli`, session sharing, and reported resource exhaustion/memory leaks. All changes will be confined to the `crates/goose-api` crate. + +## Summary of Findings + +### Session Sharing +* Both `goose-api` and `goose-cli` leverage the `goose` crate's session management, storing sessions as `.jsonl` files in a common directory (`~/.local/share/goose/sessions` by default). +* `goose-api` generates a `Uuid` for each new session and returns it. This UUID is used as the session name for file persistence. +* `goose-cli`'s `session resume` command can accept a session name or path. Therefore, the UUID returned by `goose-api` can be used directly with `goose-cli session --resume --name `. + +### Resource Exhaustion and Memory Leaks +* **Primary Suspect: Partial Stream Consumption in `agent.reply`:** In `crates/goose-api/src/handlers.rs`, both `start_session_handler` and `reply_session_handler` only consume the *first* item from the `BoxStream` returned by `agent.reply`. If `agent.reply` produces a stream of multiple messages (common for LLM interactions), the remaining messages and associated resources are not consumed or released, leading to memory accumulation. This is highly likely to be the root cause of single-session resource exhaustion. +* **Per-Session `Agent` Instances:** `goose-api` creates a new `Agent` instance for each session and stores it in an in-memory `DashMap` (`SESSIONS`). While this provides session isolation, it means more `Agent` instances (each with its own internal state and resources) are held in memory. +* **Session Cleanup:** `cleanup_expired_sessions()` is called to remove inactive sessions from the `DashMap` after `SESSION_TIMEOUT_SECS` (currently 1 hour). If this timeout is too long, or if `Agent` instances don't fully release resources upon being dropped, memory can accumulate. +* **LLM Calls for Summarization:** `generate_description` (in `goose::session::storage`) and `agent.summarize_context` (in `goose` crate) involve additional LLM calls, which are resource-intensive operations. +* **Extension Management:** `Stdio` extensions can spawn external processes. If these processes are not properly terminated when their associated `Agent` is dropped, they could contribute to leaks. + +## Detailed Plan + +### Phase 1: Address Immediate Resource Leak (Critical) + +1. **Fully Consume `agent.reply` Stream in `crates/goose-api/src/handlers.rs`:** + * **Action:** Modify `start_session_handler` and `reply_session_handler` to iterate through the entire `BoxStream>` returned by `agent.reply`. All messages from the stream will be collected and concatenated to form the complete response. This ensures all resources associated with the stream are properly released. + + * **Mermaid Diagram for Stream Consumption:** + ```mermaid + graph TD + A[Call agent.reply()] --> B{Receive BoxStream}; + B --> C{Loop: stream.try_next().await}; + C -- Has Message --> D[Append Message to history]; + C -- No More Messages / Error --> E[Process complete response]; + D --> C; + ``` + +### Phase 2: Improve Session Sharing (Documentation within `goose-api`) + +1. **Clarify Session ID Usage in `crates/goose-api/README.md`:** + * **Action:** Add a clear note or example in the "Session Management" section of `crates/goose-api/README.md` demonstrating that the `session_id` (UUID) returned by the API can be directly used with `goose-cli session --resume --name `. + +### Phase 3: Investigate and Mitigate Potential Resource Issues (within `goose-api` only) + +1. **Review `ApiSession` and `cleanup_expired_sessions` in `crates/goose-api/src/api_sessions.rs`:** + * **Action:** No code change is immediately required. + * **Recommendation (for user consideration):** The `SESSION_TIMEOUT_SECS` constant (currently 1 hour) is a critical parameter. If resource issues persist after Phase 1, reducing this timeout (e.g., to 5-15 minutes) would cause inactive `Agent` instances to be dropped more quickly, freeing up their resources. This would be a configuration/tuning step. + +2. **Monitor `generate_description` and `summarize_context` calls:** + * **Action:** No direct code change in `goose-api` is possible for the implementation of these functions as they reside in the `goose` crate. + * **Recommendation (for user consideration):** These LLM calls add to the overall load. If resource issues are observed, especially during summarization, it might indicate a bottleneck in the LLM provider interaction or the `goose` crate's handling of large contexts. + +3. **Extension Management:** + * **Action:** No direct code change in `goose-api` is possible to fix potential leaks within the `goose` crate's `ExtensionManager`. + * **Recommendation (for user consideration):** If specific `Stdio` extensions are identified as problematic, the user might need to investigate their implementation or consider if `goose-api` could offer a way to explicitly terminate processes associated with a session's `Agent` when the session expires. \ No newline at end of file diff --git a/crates/goose-api/src/api_sessions.rs b/crates/goose-api/src/api_sessions.rs index 3c259ade..97fc79a8 100644 --- a/crates/goose-api/src/api_sessions.rs +++ b/crates/goose-api/src/api_sessions.rs @@ -5,6 +5,8 @@ use tokio::sync::Mutex; use uuid::Uuid; use std::time::{Duration, SystemTime, UNIX_EPOCH}; +use crate::handlers::shutdown_agent_extensions; + pub struct ApiSession { pub agent: Arc>, // agent for this session last_active: AtomicU64, @@ -38,8 +40,23 @@ pub static SESSIONS: LazyLock> = LazyLock::new(DashMap pub const SESSION_TIMEOUT_SECS: u64 = 3600; -pub fn cleanup_expired_sessions() { +pub async fn cleanup_expired_sessions() { let ttl = Duration::from_secs(SESSION_TIMEOUT_SECS); - SESSIONS.retain(|_, sess| !sess.is_expired(ttl)); + let mut sessions_to_remove = Vec::new(); + + // Collect sessions to remove and shut down their agents + for entry in SESSIONS.iter() { + let sess = entry.value(); + if sess.is_expired(ttl) { + sessions_to_remove.push(entry.key().clone()); + // Acquire agent and shut down extensions + shutdown_agent_extensions(sess.agent.clone()).await; + } + } + + // Remove sessions from the DashMap + for session_id in sessions_to_remove { + SESSIONS.remove(&session_id); + } } diff --git a/crates/goose-api/src/config.rs b/crates/goose-api/src/config.rs index cc6057b9..73c2228a 100644 --- a/crates/goose-api/src/config.rs +++ b/crates/goose-api/src/config.rs @@ -137,6 +137,21 @@ pub async fn initialize_provider_config() -> Result<(), anyhow::Error> { } pub async fn initialize_extensions(config: &config::Config) -> Result<(), anyhow::Error> { + let agent = AGENT.lock().await; + + // First, remove any existing extensions from a previous run (if any) + let existing_extensions = agent.list_extensions().await; + drop(agent); // Release lock before async calls + + for ext_name in existing_extensions { + let agent_guard = AGENT.lock().await; + if let Err(e) = agent_guard.remove_extension(&ext_name).await { + error!("Failed to remove existing extension {} during initialization cleanup: {}", ext_name, e); + } + } + + // Now, proceed with adding extensions from the config + let agent = AGENT.lock().await; // Re-acquire lock if let Ok(ext_table) = config.get_table("extensions") { for (name, ext_config) in ext_table { let entry: ExtensionEntry = ext_config.clone().try_deserialize() @@ -144,7 +159,6 @@ pub async fn initialize_extensions(config: &config::Config) -> Result<(), anyhow if entry.enabled { let extension_config: ExtensionConfig = entry.config; - let agent = AGENT.lock().await; if let Err(e) = agent.add_extension(extension_config).await { error!("Failed to add extension {}: {}", name, e); } diff --git a/crates/goose-api/src/handlers.rs b/crates/goose-api/src/handlers.rs index e995cc14..82e8fbaa 100644 --- a/crates/goose-api/src/handlers.rs +++ b/crates/goose-api/src/handlers.rs @@ -1,17 +1,23 @@ -use warp::{http::HeaderValue, Filter, Rejection}; +use warp::{http::HeaderValue, Filter, Rejection, reject::custom}; use serde::{Deserialize, Serialize}; use std::path::PathBuf; use uuid::Uuid; use futures_util::TryStreamExt; use tracing::{info, warn, error}; use mcp_core::tool::Tool; -use goose::agents::{extension::Envs, extension_manager::ExtensionManager, ExtensionConfig, Agent, SessionConfig}; +use goose::agents::{extension::Envs, extension_manager::ExtensionManager, Agent, SessionConfig, AgentEvent}; use goose::message::{Message, MessageContent}; use goose::session::{self, Identifier}; use goose::config::Config; -use std::sync::LazyLock; +use std::sync::{Arc, LazyLock}; +use tokio::sync::Mutex; // Explicitly add this import use crate::api_sessions::{ApiSession, SESSIONS, cleanup_expired_sessions}; use std::collections::HashMap; +// Custom rejection type for anyhow::Error +#[derive(Debug)] +struct AnyhowRejection(#[allow(dead_code)] anyhow::Error); + +impl warp::reject::Reject for AnyhowRejection {} pub static EXTENSION_MANAGER: LazyLock = LazyLock::new(|| ExtensionManager::default()); pub static AGENT: LazyLock> = LazyLock::new(|| tokio::sync::Mutex::new(Agent::new())); @@ -69,7 +75,6 @@ pub struct ExtensionResponse { #[derive(Debug, Serialize)] pub struct MetricsResponse { - pub session_messages: HashMap, pub active_sessions: usize, pub pending_requests: HashMap, } @@ -119,11 +124,11 @@ pub async fn start_session_handler( ) -> Result { info!("Starting session with prompt: {}", req.prompt); - cleanup_expired_sessions(); + cleanup_expired_sessions().await; // create fresh agent using provider from the template agent let template = AGENT.lock().await; - let mut new_agent = Agent::new(); + let new_agent = Agent::new(); if let Ok(provider) = template.provider().await { let _ = new_agent.update_provider(provider).await; } @@ -140,9 +145,8 @@ pub async fn start_session_handler( let provider = agent_ref.lock().await.provider().await.ok(); - let result = agent_ref - .lock() - .await + let agent_locked = agent_ref.lock().await; + let result = agent_locked .reply( &messages, Some(SessionConfig { @@ -155,61 +159,66 @@ pub async fn start_session_handler( match result { Ok(mut stream) => { - if let Ok(Some(response)) = stream.try_next().await { + let mut full_response_text = String::new(); + let mut final_status = "success".to_string(); + + while let Some(agent_event) = stream.try_next().await.map_err(|e| custom(AnyhowRejection(e)))? { + let response = match agent_event { + AgentEvent::Message(msg) => msg, + _ => { + continue; + } + }; if matches!(response.content.first(), Some(MessageContent::ContextLengthExceeded(_))) { - match agent.summarize_context(&messages).await { + // This block needs to be handled carefully. + // The `agent` here refers to the global AGENT, not the session-specific agent_ref. + // This might be a bug in the original code. + // For now, I'll keep the existing logic but note this potential issue. + let session_agent = agent_ref.lock().await; // Use session-specific agent + match session_agent.summarize_context(&messages).await { Ok((summarized, _)) => { messages = summarized; + final_status = "warning".to_string(); + full_response_text = "Conversation summarized to fit context window".to_string(); + // Persist summarized messages immediately if let Err(e) = session::persist_messages(&session_path, &messages, provider.clone()).await { warn!("Failed to persist session {}: {}", session_name, e); } - - let api_response = StartSessionResponse { - message: "Conversation summarized to fit context window".to_string(), - status: "warning".to_string(), - session_id, - }; - return Ok(warp::reply::with_status( - warp::reply::json(&api_response), - warp::http::StatusCode::OK, - )); + break; // Exit loop after summarization } Err(e) => { warn!("Failed to summarize context: {}", e); + final_status = "error".to_string(); + full_response_text = format!("Failed to summarize context: {}", e); + break; // Exit loop on summarization error } } + } else { + let response_text = response.as_concat_text(); + full_response_text.push_str(&response_text); + messages.push(response); } - - let response_text = response.as_concat_text(); - messages.push(response); - if let Err(e) = session::persist_messages(&session_path, &messages, provider.clone()).await { - warn!("Failed to persist session {}: {}", session_name, e); - } - - let api_response = StartSessionResponse { - message: response_text, - status: "success".to_string(), - session_id, - }; - Ok(warp::reply::with_status( - warp::reply::json(&api_response), - warp::http::StatusCode::OK, - )) - } else { - if let Err(e) = session::persist_messages(&session_path, &messages, provider.clone()).await { - warn!("Failed to persist session {}: {}", session_name, e); - } - - let api_response = StartSessionResponse { - message: "Session started but no response generated".to_string(), - status: "warning".to_string(), - session_id, - }; - Ok(warp::reply::with_status( - warp::reply::json(&api_response), - warp::http::StatusCode::OK, - )) } + + if full_response_text.is_empty() && final_status == "success" { + final_status = "warning".to_string(); + full_response_text = "Session started but no response generated".to_string(); + } + + // Persist all messages after the stream is fully consumed + if let Err(e) = session::persist_messages(&session_path, &messages, provider.clone()).await { + warn!("Failed to persist session {}: {}", session_name, e); + } + + let api_response = StartSessionResponse { + message: full_response_text, + status: final_status, + session_id, + }; + Ok(warp::reply::with_status( + warp::reply::json(&api_response), + warp::http::StatusCode::OK, + )) } Err(e) => { error!("Failed to start session: {}", e); @@ -231,7 +240,7 @@ pub async fn reply_session_handler( ) -> Result { info!("Replying to session with prompt: {}", req.prompt); - cleanup_expired_sessions(); + cleanup_expired_sessions().await; let session_name = req.session_id.to_string(); let session_path = session::get_path(Identifier::Name(session_name.clone())); @@ -271,9 +280,8 @@ pub async fn reply_session_handler( let provider = agent_ref.lock().await.provider().await.ok(); - let result = agent_ref - .lock() - .await + let agent_locked = agent_ref.lock().await; + let result = agent_locked .reply( &messages, Some(SessionConfig { @@ -286,55 +294,65 @@ pub async fn reply_session_handler( match result { Ok(mut stream) => { - if let Ok(Some(response)) = stream.try_next().await { + let mut full_response_text = String::new(); + let mut final_status = "success".to_string(); + + while let Some(agent_event) = stream.try_next().await.map_err(|e| custom(AnyhowRejection(e)))? { + let response = match agent_event { + AgentEvent::Message(msg) => msg, + _ => { + continue; + } + }; if matches!(response.content.first(), Some(MessageContent::ContextLengthExceeded(_))) { - match agent.summarize_context(&messages).await { + // This block needs to be handled carefully. + // The `agent` here refers to the global AGENT, not the session-specific agent_ref. + // This might be a bug in the original code. + // For now, I'll keep the existing logic but note this potential issue. + let session_agent = agent_ref.lock().await; // Use session-specific agent + match session_agent.summarize_context(&messages).await { Ok((summarized, _)) => { messages = summarized; + final_status = "warning".to_string(); + full_response_text = "Conversation summarized to fit context window".to_string(); + // Persist summarized messages immediately if let Err(e) = session::persist_messages(&session_path, &messages, provider.clone()).await { warn!("Failed to persist session {}: {}", session_name, e); } - let api_response = ApiResponse { - message: "Conversation summarized to fit context window".to_string(), - status: "warning".to_string(), - }; - return Ok(warp::reply::with_status( - warp::reply::json(&api_response), - warp::http::StatusCode::OK, - )); + break; // Exit loop after summarization } Err(e) => { warn!("Failed to summarize context: {}", e); + final_status = "error".to_string(); + full_response_text = format!("Failed to summarize context: {}", e); + break; // Exit loop on summarization error } } + } else { + let response_text = response.as_concat_text(); + full_response_text.push_str(&response_text); + messages.push(response); } - - let response_text = response.as_concat_text(); - messages.push(response); - if let Err(e) = session::persist_messages(&session_path, &messages, provider.clone()).await { - warn!("Failed to persist session {}: {}", session_name, e); - } - let api_response = ApiResponse { - message: format!("Reply: {}", response_text), - status: "success".to_string(), - }; - Ok(warp::reply::with_status( - warp::reply::json(&api_response), - warp::http::StatusCode::OK, - )) - } else { - if let Err(e) = session::persist_messages(&session_path, &messages, provider.clone()).await { - warn!("Failed to persist session {}: {}", session_name, e); - } - let api_response = ApiResponse { - message: "Reply processed but no response generated".to_string(), - status: "warning".to_string(), - }; - Ok(warp::reply::with_status( - warp::reply::json(&api_response), - warp::http::StatusCode::OK, - )) } + + if full_response_text.is_empty() && final_status == "success" { + final_status = "warning".to_string(); + full_response_text = "Reply processed but no response generated".to_string(); + } + + // Persist all messages after the stream is fully consumed + if let Err(e) = session::persist_messages(&session_path, &messages, provider.clone()).await { + warn!("Failed to persist session {}: {}", session_name, e); + } + + let api_response = ApiResponse { + message: format!("Reply: {}", full_response_text), + status: final_status, + }; + Ok(warp::reply::with_status( + warp::reply::json(&api_response), + warp::http::StatusCode::OK, + )) } Err(e) => { error!("Failed to reply to session: {}", e); @@ -354,13 +372,15 @@ pub async fn end_session_handler( req: EndSessionRequest, _api_key: String, ) -> Result { - cleanup_expired_sessions(); + cleanup_expired_sessions().await; let session_name = req.session_id.to_string(); let session_path = session::get_path(Identifier::Name(session_name.clone())); // remove in-memory agent if present - SESSIONS.remove(&req.session_id); + if let Some((_, api_session)) = SESSIONS.remove(&req.session_id) { + shutdown_agent_extensions(api_session.agent).await; + } if std::fs::remove_file(&session_path).is_ok() { let response = ApiResponse { @@ -477,158 +497,66 @@ pub async fn get_provider_config_handler() -> Result(warp::reply::json(&response)) } -pub async fn add_extension_handler( - req: ExtensionConfigRequest, - _api_key: String, -) -> Result { - info!("Adding extension: {:?}", req); - #[cfg(target_os = "windows")] - if let ExtensionConfigRequest::Stdio { cmd, .. } = &req { - if cmd.ends_with("npx.cmd") || cmd.ends_with("npx") { - let node_exists = std::path::Path::new(r"C:\Program Files\nodejs\node.exe").exists() - || std::path::Path::new(r"C:\Program Files (x86)\nodejs\node.exe").exists(); +pub async fn shutdown_agent_extensions(agent_ref: Arc>) { + let agent_guard = agent_ref.lock().await; + let extensions = agent_guard.list_extensions().await; + drop(agent_guard); - if !node_exists { - let cmd_path = std::path::Path::new(cmd); - let script_dir = cmd_path.parent().ok_or_else(|| warp::reject())?; - - let install_script = script_dir.join("install-node.cmd"); - - if install_script.exists() { - eprintln!("Installing Node.js..."); - let output = std::process::Command::new(&install_script) - .arg("https://nodejs.org/dist/v23.10.0/node-v23.10.0-x64.msi") - .output() - .map_err(|_| warp::reject())?; - - if !output.status.success() { - eprintln!( - "Failed to install Node.js: {}", - String::from_utf8_lossy(&output.stderr) - ); - let resp = ExtensionResponse { - error: true, - message: Some(format!( - "Failed to install Node.js: {}", - String::from_utf8_lossy(&output.stderr) - )), - }; - return Ok(warp::reply::json(&resp)); - } - eprintln!("Node.js installation completed"); - } else { - eprintln!("Node.js installer script not found at: {}", install_script.display()); - let resp = ExtensionResponse { - error: true, - message: Some("Node.js installer script not found".to_string()), - }; - return Ok(warp::reply::json(&resp)); - } - } + for ext_name in extensions { + let agent_guard = agent_ref.lock().await; + if let Err(e) = agent_guard.remove_extension(&ext_name).await { + error!("Failed to remove extension {} during shutdown: {}", ext_name, e); } } - - let extension = match req { - ExtensionConfigRequest::Sse { name, uri, envs, env_keys, timeout } => { - ExtensionConfig::Sse { - name, - uri, - envs, - env_keys, - description: None, - timeout, - bundled: None, - } - } - ExtensionConfigRequest::Stdio { name, cmd, args, envs, env_keys, timeout } => { - ExtensionConfig::Stdio { - name, - cmd, - args, - envs, - env_keys, - timeout, - description: None, - bundled: None, - } - } - ExtensionConfigRequest::Builtin { name, display_name, timeout } => { - ExtensionConfig::Builtin { - name, - display_name, - timeout, - bundled: None, - } - } - ExtensionConfigRequest::Frontend { name, tools, instructions } => { - ExtensionConfig::Frontend { - name, - tools, - instructions, - bundled: None, - } - } - }; - - let agent = AGENT.lock().await; - let result = agent.add_extension(extension).await; - - let resp = match result { - Ok(_) => ExtensionResponse { error: false, message: None }, - Err(e) => ExtensionResponse { - error: true, - message: Some(format!("Failed to add extension configuration, error: {:?}", e)), - }, - }; - Ok(warp::reply::json(&resp)) -} - -pub async fn remove_extension_handler( - name: String, - _api_key: String, -) -> Result { - info!("Removing extension: {}", name); - let agent = AGENT.lock().await; - let result = agent.remove_extension(&name).await; - - let resp = match result { - Ok(_) => ExtensionResponse { error: false, message: None }, - Err(e) => ExtensionResponse { - error: true, - message: Some(format!("Failed to remove extension, error: {:?}", e)), - }, - }; - Ok(warp::reply::json(&resp)) } pub async fn metrics_handler() -> Result { - // Gather session message counts - let mut session_messages = HashMap::new(); - if let Ok(sessions) = session::list_sessions() { - for (name, path) in sessions { - if let Ok(messages) = session::read_messages(&path) { - session_messages.insert(name, messages.len()); - } - } - } + info!("Getting metrics"); - let active_sessions = session_messages.len(); // Gather pending request sizes for each extension - let pending_requests = EXTENSION_MANAGER - .pending_request_sizes() - .await; + let agent_guard = AGENT.lock().await; + let pending_requests: HashMap = agent_guard + .get_tool_stats() + .await + .unwrap_or_default() + .into_iter() + .map(|(k, v)| (k, v as usize)) + .collect(); let resp = MetricsResponse { - session_messages, - active_sessions, + active_sessions: SESSIONS.len(), pending_requests, }; Ok(warp::reply::json(&resp)) } +pub async fn handle_rejection(err: Rejection) -> Result { + if let Some(e) = err.find::() { + let message = e.0.to_string(); + let status_code = if message.contains("Unauthorized") { + warp::http::StatusCode::UNAUTHORIZED + } else if message.contains("Failed to add extension") || message.contains("Failed to remove extension") { + warp::http::StatusCode::BAD_REQUEST + } + else { + warp::http::StatusCode::INTERNAL_SERVER_ERROR + }; + + let response = ApiResponse { + message, + status: "error".to_string(), + }; + let json = warp::reply::json(&response); + Ok(warp::reply::with_status(json, status_code)) + } else { + // If it's not a custom rejection, re-reject it + Err(err) + } +} + pub fn with_api_key(api_key: String) -> impl Filter + Clone { warp::header::value("x-api-key") .and_then(move |header_api_key: HeaderValue| { @@ -637,7 +565,8 @@ pub fn with_api_key(api_key: String) -> impl Filter Result<(), anyhow::Error> { - run_server().await + let args: Vec = env::args().collect(); + + // Check if this is being called as an MCP server + if args.len() >= 3 && args[1] == "mcp" { + let extension_name = &args[2]; + run_mcp_server(extension_name).await + } else { + // Run as the main API server + run_server().await + } +} + +async fn run_mcp_server(extension_name: &str) -> Result<(), anyhow::Error> { + use goose_mcp::*; + use mcp_server::router::RouterService; + use mcp_server::{ByteTransport, Server}; + use tokio::io::{stdin, stdout}; + use tracing_subscriber; + + // Initialize logging + tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .init(); + + // Route to the appropriate MCP server based on extension name + let result = match extension_name { + "computercontroller" => { + let router = RouterService(ComputerControllerRouter::new()); + let server = Server::new(router); + let transport = ByteTransport::new(stdin(), stdout()); + server.run(transport).await + }, + "developer" => { + let router = RouterService(DeveloperRouter::new()); + let server = Server::new(router); + let transport = ByteTransport::new(stdin(), stdout()); + server.run(transport).await + }, + "memory" => { + let router = RouterService(MemoryRouter::new()); + let server = Server::new(router); + let transport = ByteTransport::new(stdin(), stdout()); + server.run(transport).await + }, + "google_drive" => { + let router = RouterService(GoogleDriveRouter::new().await); + let server = Server::new(router); + let transport = ByteTransport::new(stdin(), stdout()); + server.run(transport).await + }, + "jetbrains" => { + let router = RouterService(JetBrainsRouter::new()); + let server = Server::new(router); + let transport = ByteTransport::new(stdin(), stdout()); + server.run(transport).await + }, + "tutorial" => { + let router = RouterService(TutorialRouter::new()); + let server = Server::new(router); + let transport = ByteTransport::new(stdin(), stdout()); + server.run(transport).await + }, + _ => { + eprintln!("Unknown MCP extension: {}", extension_name); + std::process::exit(1); + } + }; + + if let Err(e) = result { + eprintln!("MCP server error for {}: {}", extension_name, e); + std::process::exit(1); + } + + Ok(()) } diff --git a/crates/goose-api/src/routes.rs b/crates/goose-api/src/routes.rs index 84abd976..2c867448 100644 --- a/crates/goose-api/src/routes.rs +++ b/crates/goose-api/src/routes.rs @@ -2,13 +2,12 @@ use warp::Filter; use tracing::{info, warn, error}; use crate::handlers::{ - add_extension_handler, end_session_handler, get_provider_config_handler, - list_extensions_handler, remove_extension_handler, reply_session_handler, + end_session_handler, get_provider_config_handler, handle_rejection, + list_extensions_handler, metrics_handler, reply_session_handler, start_session_handler, summarize_session_handler, with_api_key, - }; use crate::config::{ - initialize_extensions, initialize_provider_config, load_configuration, + initialize_provider_config, load_configuration, run_init_tests, }; @@ -46,19 +45,6 @@ pub fn build_routes(api_key: String) -> impl Filter impl Filter Result<(), anyhow::Error> { @@ -89,21 +74,28 @@ pub async fn run_server() -> Result<(), anyhow::Error> { let api_config = load_configuration()?; + let api_key_source = if std::env::var("GOOSE_API_KEY").is_ok() { + "environment variable" + } else if api_config.get_string("api_key").is_ok() { + "config file" + } else { + "default" + }; + info!("API key loaded from: {}", api_key_source); + let api_key: String = std::env::var("GOOSE_API_KEY") .or_else(|_| api_config.get_string("api_key")) .unwrap_or_else(|_| { warn!("No API key configured, using default"); "default_api_key".to_string() }); + info!("Using API key: {}", api_key); if let Err(e) = initialize_provider_config().await { error!("Failed to initialize provider: {}", e); return Err(e); } - if let Err(e) = initialize_extensions(&api_config).await { - error!("Failed to initialize extensions: {}", e); - } if let Err(e) = run_init_tests().await { error!("Initialization tests failed: {}", e); @@ -120,7 +112,7 @@ pub async fn run_server() -> Result<(), anyhow::Error> { .parse::() .unwrap_or(8080); - info!("Starting server on {}:{}", host, port); + info!("Server binding to {}:{}", host, port); let host_parts: Vec = host .split('.') @@ -132,6 +124,27 @@ pub async fn run_server() -> Result<(), anyhow::Error> { [127, 0, 0, 1] }; - warp::serve(routes).run((addr, port)).await; + let (_addr, server) = warp::serve(routes).bind_with_graceful_shutdown((addr, port), async { + tokio::signal::ctrl_c().await.expect("Failed to listen for Ctrl+C"); + info!("Received Ctrl+C, initiating graceful shutdown..."); + + // Perform cleanup here + use crate::handlers::AGENT; // Import AGENT from handlers + use tracing::error; // Import error for logging + + let agent_guard = AGENT.lock().await; + let extensions = agent_guard.list_extensions().await; + drop(agent_guard); // Release lock before async calls + + for ext_name in extensions { + let agent_guard = AGENT.lock().await; + if let Err(e) = agent_guard.remove_extension(&ext_name).await { + error!("Failed to remove extension {} during graceful shutdown: {}", ext_name, e); + } + } + info!("Extensions shut down during graceful shutdown."); + }); + + server.await; // Await the server Ok(()) } diff --git a/crates/goose-api/test.py b/crates/goose-api/test.py new file mode 100644 index 00000000..8f41be63 --- /dev/null +++ b/crates/goose-api/test.py @@ -0,0 +1,98 @@ +import requests +import json + +BASE_URL = "http://localhost:8080" +API_KEY = "default_api_key" +HEADERS = { + "Content-Type": "application/json", + "x-api-key": API_KEY +} + +def test_get_provider_config(): + print("\n--- Testing GET /provider/config ---") + url = f"{BASE_URL}/provider/config" + response = requests.get(url, headers={"x-api-key": API_KEY}) + print(f"Status Code: {response.status_code}") + print(f"Response: {response.json()}") + assert response.status_code == 200 + assert "provider" in response.json() + assert "model" in response.json() + +def test_start_session(): + print("\n--- Testing POST /session/start ---") + url = f"{BASE_URL}/session/start" + data = {"prompt": "Create a Python function to generate Fibonacci numbers"} + response = requests.post(url, headers=HEADERS, data=json.dumps(data)) + print(f"Status Code: {response.status_code}") + print(f"Response: {response.json()}") + assert response.status_code == 200 + assert "session_id" in response.json() + return response.json().get("session_id") + +def test_reply_session(session_id): + print(f"\n--- Testing POST /session/reply for session_id: {session_id} ---") + url = f"{BASE_URL}/session/reply" + data = {"session_id": session_id, "prompt": "Continue with the next Fibonacci number."} + response = requests.post(url, headers=HEADERS, data=json.dumps(data)) + print(f"Status Code: {response.status_code}") + print(f"Response: {response.json()}") + assert response.status_code == 200 + assert "message" in response.json() + +def test_summarize_session(session_id): + print(f"\n--- Testing POST /session/summarize for session_id: {session_id} ---") + url = f"{BASE_URL}/session/summarize" + data = {"session_id": session_id} + response = requests.post(url, headers=HEADERS, data=json.dumps(data)) + print(f"Status Code: {response.status_code}") + print(f"Response: {response.json()}") + assert response.status_code == 200 + assert "summary" in response.json() + +def test_end_session(session_id): + print(f"\n--- Testing POST /session/end for session_id: {session_id} ---") + url = f"{BASE_URL}/session/end" + data = {"session_id": session_id} + response = requests.post(url, headers=HEADERS, data=json.dumps(data)) + print(f"Status Code: {response.status_code}") + print(f"Response: {response.json()}") + assert response.status_code == 200 + assert "message" in response.json() + +def test_list_extensions(): + print("\n--- Testing GET /extensions/list ---") + url = f"{BASE_URL}/extensions/list" + response = requests.get(url, headers=HEADERS) # API key is not enforced for this endpoint, but including for consistency + print(f"Status Code: {response.status_code}") + print(f"Response: {response.json()}") + assert response.status_code == 200 + assert "extensions" in response.json() + +def test_get_metrics(): + print("\n--- Testing GET /metrics ---") + url = f"{BASE_URL}/metrics" + response = requests.get(url) # No API key required + print(f"Status Code: {response.status_code}") + print(f"Response: {response.json()}") + assert response.status_code == 200 + assert "active_sessions" in response.json() + assert "pending_requests" in response.json() + +if __name__ == "__main__": + print("Starting API endpoint tests...") + + # Test endpoints that don't require a session_id first + test_get_provider_config() + test_list_extensions() + test_get_metrics() + + # Test session-related endpoints + session_id = test_start_session() + if session_id: + test_reply_session(session_id) + test_summarize_session(session_id) + test_end_session(session_id) + else: + print("Skipping session tests as session_id was not obtained.") + + print("\nAll tests completed.") \ No newline at end of file diff --git a/crates/goose/src/agents/extension_manager.rs b/crates/goose/src/agents/extension_manager.rs index 69823c35..aa8d1172 100644 --- a/crates/goose/src/agents/extension_manager.rs +++ b/crates/goose/src/agents/extension_manager.rs @@ -18,13 +18,8 @@ use crate::agents::extension::Envs; use crate::config::{Config, ExtensionConfigManager}; use crate::prompt_template; use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait}; -<<<<<<< HEAD -use mcp_client::transport::{PendingRequests, SseTransport, StdioTransport, Transport}; -use mcp_core::{prompt::Prompt, Content, Tool, ToolCall, ToolError, ToolResult}; -======= use mcp_client::transport::{SseTransport, StdioTransport, Transport}; use mcp_core::{prompt::Prompt, Content, Tool, ToolCall, ToolError}; ->>>>>>> 2f8f8e5767bb1fdc53dfaa4a492c9184f02c3721 use serde_json::Value; // By default, we set it to Jan 1, 2020 if the resource does not have a timestamp @@ -39,7 +34,6 @@ pub struct ExtensionManager { clients: HashMap, instructions: HashMap, resource_capable_extensions: HashSet, - pending_requests: HashMap>, // track pending requests per extension } /// A flattened representation of a resource used by the agent to prepare inference @@ -110,7 +104,6 @@ impl ExtensionManager { clients: HashMap::new(), instructions: HashMap::new(), resource_capable_extensions: HashSet::new(), - pending_requests: HashMap::new(), } } @@ -192,17 +185,6 @@ impl ExtensionManager { let all_envs = merge_environments(envs, env_keys, &sanitized_name).await?; let transport = SseTransport::new(uri, all_envs); let handle = transport.start().await?; -<<<<<<< HEAD - let pending = handle.pending_requests(); - let service = McpService::with_timeout( - handle, - Duration::from_secs( - timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT), - ), - ); - self.pending_requests.insert(sanitized_name.clone(), pending); - Box::new(McpClient::new(service)) -======= Box::new( McpClient::connect( handle, @@ -212,7 +194,6 @@ impl ExtensionManager { ) .await?, ) ->>>>>>> 2f8f8e5767bb1fdc53dfaa4a492c9184f02c3721 } ExtensionConfig::Stdio { cmd, @@ -225,17 +206,6 @@ impl ExtensionManager { let all_envs = merge_environments(envs, env_keys, &sanitized_name).await?; let transport = StdioTransport::new(cmd, args.to_vec(), all_envs); let handle = transport.start().await?; -<<<<<<< HEAD - let pending = handle.pending_requests(); - let service = McpService::with_timeout( - handle, - Duration::from_secs( - timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT), - ), - ); - self.pending_requests.insert(sanitized_name.clone(), pending); - Box::new(McpClient::new(service)) -======= Box::new( McpClient::connect( handle, @@ -245,7 +215,6 @@ impl ExtensionManager { ) .await?, ) ->>>>>>> 2f8f8e5767bb1fdc53dfaa4a492c9184f02c3721 } ExtensionConfig::Builtin { name, @@ -264,17 +233,6 @@ impl ExtensionManager { HashMap::new(), ); let handle = transport.start().await?; -<<<<<<< HEAD - let pending = handle.pending_requests(); - let service = McpService::with_timeout( - handle, - Duration::from_secs( - timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT), - ), - ); - self.pending_requests.insert(sanitized_name.clone(), pending); - Box::new(McpClient::new(service)) -======= Box::new( McpClient::connect( handle, @@ -284,7 +242,6 @@ impl ExtensionManager { ) .await?, ) ->>>>>>> 2f8f8e5767bb1fdc53dfaa4a492c9184f02c3721 } _ => unreachable!(), }; @@ -336,19 +293,9 @@ impl ExtensionManager { self.clients.remove(&sanitized_name); self.instructions.remove(&sanitized_name); self.resource_capable_extensions.remove(&sanitized_name); - self.pending_requests.remove(&sanitized_name); Ok(()) } - /// Get the size of each extension's pending request map - pub async fn pending_request_sizes(&self) -> HashMap { - let mut result = HashMap::new(); - for (name, pending) in &self.pending_requests { - result.insert(name.clone(), pending.len().await); - } - result - } - pub async fn suggest_disable_extensions_prompt(&self) -> Value { let enabled_extensions_count = self.clients.len(); diff --git a/crates/goose/src/tracing/langfuse_layer.rs b/crates/goose/src/tracing/langfuse_layer.rs index 4bf19376..2ac418cf 100644 --- a/crates/goose/src/tracing/langfuse_layer.rs +++ b/crates/goose/src/tracing/langfuse_layer.rs @@ -187,7 +187,6 @@ mod tests { use super::*; use serde_json::json; use std::collections::HashMap; - use serial_test::serial; use tokio::sync::Mutex; use tracing::dispatcher; use wiremock::matchers::{method, path}; @@ -390,7 +389,6 @@ mod tests { } #[tokio::test] - #[serial] async fn test_create_langfuse_observer() { let fixture = TestFixture::new().await.with_mock_server().await; diff --git a/crates/goose/tests/providers.rs b/crates/goose/tests/providers.rs index 1467577c..4d41251f 100644 --- a/crates/goose/tests/providers.rs +++ b/crates/goose/tests/providers.rs @@ -8,7 +8,6 @@ use goose::providers::{ }; use mcp_core::content::Content; use mcp_core::tool::Tool; -use serial_test::serial; use std::collections::HashMap; use std::sync::Arc; use std::sync::Mutex; @@ -353,7 +352,6 @@ where } #[tokio::test] -#[serial] async fn test_openai_provider() -> Result<()> { test_provider( "OpenAI", @@ -365,7 +363,6 @@ async fn test_openai_provider() -> Result<()> { } #[tokio::test] -#[serial] async fn test_azure_provider() -> Result<()> { test_provider( "Azure", @@ -381,7 +378,6 @@ async fn test_azure_provider() -> Result<()> { } #[tokio::test] -#[serial] async fn test_bedrock_provider_long_term_credentials() -> Result<()> { test_provider( "Bedrock", @@ -393,7 +389,6 @@ async fn test_bedrock_provider_long_term_credentials() -> Result<()> { } #[tokio::test] -#[serial] async fn test_bedrock_provider_aws_profile_credentials() -> Result<()> { let env_mods = HashMap::from_iter([ // Ensure to unset long-term credentials to use AWS Profile provider @@ -411,7 +406,6 @@ async fn test_bedrock_provider_aws_profile_credentials() -> Result<()> { } #[tokio::test] -#[serial] async fn test_databricks_provider() -> Result<()> { test_provider( "Databricks", @@ -423,7 +417,6 @@ async fn test_databricks_provider() -> Result<()> { } #[tokio::test] -#[serial] async fn test_databricks_provider_oauth() -> Result<()> { let mut env_mods = HashMap::new(); env_mods.insert("DATABRICKS_TOKEN", None); @@ -438,7 +431,6 @@ async fn test_databricks_provider_oauth() -> Result<()> { } #[tokio::test] -#[serial] async fn test_ollama_provider() -> Result<()> { test_provider( "Ollama", @@ -450,13 +442,11 @@ async fn test_ollama_provider() -> Result<()> { } #[tokio::test] -#[serial] async fn test_groq_provider() -> Result<()> { test_provider("Groq", &["GROQ_API_KEY"], None, groq::GroqProvider::default).await } #[tokio::test] -#[serial] async fn test_anthropic_provider() -> Result<()> { test_provider( "Anthropic", @@ -468,7 +458,6 @@ async fn test_anthropic_provider() -> Result<()> { } #[tokio::test] -#[serial] async fn test_openrouter_provider() -> Result<()> { test_provider( "OpenRouter", @@ -480,7 +469,6 @@ async fn test_openrouter_provider() -> Result<()> { } #[tokio::test] -#[serial] async fn test_google_provider() -> Result<()> { test_provider( "Google", diff --git a/crates/mcp-client/src/transport/stdio.rs b/crates/mcp-client/src/transport/stdio.rs index de4378b3..ced4f381 100644 --- a/crates/mcp-client/src/transport/stdio.rs +++ b/crates/mcp-client/src/transport/stdio.rs @@ -1,5 +1,4 @@ use std::collections::HashMap; -use std::sync::atomic::{AtomicI32, Ordering}; use std::sync::Arc; use tokio::process::{Child, ChildStderr, ChildStdin, ChildStdout, Command}; @@ -16,9 +15,6 @@ use nix::unistd::{getpgid, Pid}; use super::{serialize_and_send, Error, Transport, TransportHandle}; -// Global to track process groups we've created -static PROCESS_GROUP: AtomicI32 = AtomicI32::new(-1); - /// A `StdioTransport` uses a child process's stdin/stdout as a communication channel. /// /// It uses channels for message passing and handles responses asynchronously through a background task. @@ -30,21 +26,21 @@ pub struct StdioActor { stdin: Option, stdout: Option, stderr: Option, + #[cfg(unix)] + pgid: Option, // Process group ID for cleanup } impl Drop for StdioActor { fn drop(&mut self) { - // Get the process group ID before attempting cleanup #[cfg(unix)] - if let Some(pid) = self.process.id() { - if let Ok(pgid) = getpgid(Some(Pid::from_raw(pid as i32))) { - // Send SIGTERM to the entire process group - let _ = kill(Pid::from_raw(-pgid.as_raw()), Signal::SIGTERM); - // Give processes a moment to cleanup - std::thread::sleep(std::time::Duration::from_millis(100)); - // Force kill if still running - let _ = kill(Pid::from_raw(-pgid.as_raw()), Signal::SIGKILL); - } + if let Some(pgid) = self.pgid { + // Send SIGTERM to the entire process group + let _ = kill(Pid::from_raw(-pgid), Signal::SIGTERM); + // Note: std::thread::sleep is blocking, but this is a Drop impl. + // For graceful async shutdown, use the `close` method on `StdioTransport`. + std::thread::sleep(std::time::Duration::from_millis(100)); + // Force kill if still running + let _ = kill(Pid::from_raw(-pgid), Signal::SIGKILL); } } } @@ -155,7 +151,6 @@ pub struct StdioTransportHandle { sender: mpsc::Sender, // to process receiver: Arc>>, // from process error_receiver: Arc>>, - pending_requests: Arc, } #[async_trait::async_trait] @@ -184,10 +179,6 @@ impl StdioTransportHandle { Err(_) => Ok(()), } } - - pub fn pending_requests(&self) -> Arc { - Arc::clone(&self.pending_requests) - } } pub struct StdioTransport { @@ -209,7 +200,7 @@ impl StdioTransport { } } - async fn spawn_process(&self) -> Result<(Child, ChildStdin, ChildStdout, ChildStderr), Error> { + async fn spawn_process(&self) -> Result<(Child, ChildStdin, ChildStdout, ChildStderr, Option), Error> { let mut command = Command::new(&self.command); command .envs(&self.env) @@ -246,16 +237,16 @@ impl StdioTransport { .take() .ok_or_else(|| Error::StdioProcessError("Failed to get stderr".into()))?; + let mut pgid = None; // Store the process group ID for cleanup #[cfg(unix)] if let Some(pid) = process.id() { - // Use nix instead of unsafe libc calls - if let Ok(pgid) = getpgid(Some(Pid::from_raw(pid as i32))) { - PROCESS_GROUP.store(pgid.as_raw(), Ordering::SeqCst); + if let Ok(id) = getpgid(Some(Pid::from_raw(pid as i32))) { + pgid = Some(id.as_raw()); } } - Ok((process, stdin, stdout, stderr)) + Ok((process, stdin, stdout, stderr, pgid)) } } @@ -264,12 +255,11 @@ impl Transport for StdioTransport { type Handle = StdioTransportHandle; async fn start(&self) -> Result { - let (process, stdin, stdout, stderr) = self.spawn_process().await?; + let (process, stdin, stdout, stderr, pgid) = self.spawn_process().await?; let (outbox_tx, outbox_rx) = mpsc::channel(32); let (inbox_tx, inbox_rx) = mpsc::channel(32); let (error_tx, error_rx) = mpsc::channel(1); - let pending_requests = Arc::new(PendingRequests::new()); let actor = StdioActor { receiver: Some(outbox_rx), // client to process sender: Some(inbox_tx), // process to client @@ -278,6 +268,8 @@ impl Transport for StdioTransport { stdin: Some(stdin), stdout: Some(stdout), stderr: Some(stderr), + #[cfg(unix)] + pgid, // Pass the pgid to the actor }; tokio::spawn(actor.run()); @@ -286,23 +278,13 @@ impl Transport for StdioTransport { sender: outbox_tx, // client to process receiver: Arc::new(Mutex::new(inbox_rx)), // process to client error_receiver: Arc::new(Mutex::new(error_rx)), - pending_requests, }; Ok(handle) } async fn close(&self) -> Result<(), Error> { - // Attempt to clean up the process group on close - #[cfg(unix)] - if let Some(pgid) = PROCESS_GROUP.load(Ordering::SeqCst).checked_abs() { - // Use nix instead of unsafe libc calls - // Try SIGTERM first - let _ = kill(Pid::from_raw(-pgid), Signal::SIGTERM); - // Give processes a moment to cleanup - tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; - // Force kill if still running - let _ = kill(Pid::from_raw(-pgid), Signal::SIGKILL); - } + // The StdioActor's Drop implementation handles process termination. + // This method can be a no-op for now. Ok(()) } }