This commit is contained in:
2025-07-26 17:28:31 +02:00
parent dad76ef9ac
commit 2002602fc5
16 changed files with 608 additions and 438 deletions

144
Cargo.lock generated
View File

@@ -261,7 +261,7 @@ version = "52.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "16f4a9468c882dc66862cef4e1fd8423d47e67972377d85d80e022786427768c"
dependencies = [
"ahash",
"ahash 0.8.11",
"arrow-buffer",
"arrow-data",
"arrow-schema",
@@ -392,7 +392,7 @@ version = "52.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4cd09a518c602a55bd406bcc291a967b284cfa7a63edfbf8b897ea4748aad23c"
dependencies = [
"ahash",
"ahash 0.8.11",
"arrow-array",
"arrow-buffer",
"arrow-data",
@@ -415,7 +415,7 @@ version = "52.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "600bae05d43483d216fb3494f8c32fdbefd8aa4e1de237e790dbb3d9f44690a3"
dependencies = [
"ahash",
"ahash 0.8.11",
"arrow-array",
"arrow-buffer",
"arrow-data",
@@ -686,7 +686,7 @@ dependencies = [
"fastrand 2.3.0",
"hex",
"http 0.2.12",
"ring 0.17.12",
"ring 0.17.14",
"time",
"tokio",
"tracing",
@@ -1074,7 +1074,7 @@ dependencies = [
"sha1",
"sync_wrapper 1.0.2",
"tokio",
"tokio-tungstenite 0.24.0",
"tokio-tungstenite 0.26.2",
"tower 0.5.2",
"tower-layer",
"tower-service",
@@ -1819,6 +1819,25 @@ dependencies = [
"crossbeam-utils",
]
[[package]]
name = "config"
version = "0.13.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "23738e11972c7643e4ec947840fc463b6a571afcd3e735bdfce7d03c7a784aca"
dependencies = [
"async-trait",
"json5",
"lazy_static",
"nom",
"pathdiff",
"ron 0.7.1",
"rust-ini 0.18.0",
"serde",
"serde_json",
"toml 0.5.11",
"yaml-rust",
]
[[package]]
name = "config"
version = "0.14.1"
@@ -2200,7 +2219,7 @@ version = "41.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e4fd4a99fc70d40ef7e52b243b4a399c3f8d353a40d5ecb200deee05e49c61bb"
dependencies = [
"ahash",
"ahash 0.8.11",
"arrow",
"arrow-array",
"arrow-ipc",
@@ -2263,7 +2282,7 @@ version = "41.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "44fdbc877e3e40dcf88cc8f283d9f5c8851f0a3aa07fee657b1b75ac1ad49b9c"
dependencies = [
"ahash",
"ahash 0.8.11",
"arrow",
"arrow-array",
"arrow-buffer",
@@ -2314,7 +2333,7 @@ version = "41.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1c1841c409d9518c17971d15c9bae62e629eb937e6fb6c68cd32e9186f8b30d2"
dependencies = [
"ahash",
"ahash 0.8.11",
"arrow",
"arrow-array",
"arrow-buffer",
@@ -2356,7 +2375,7 @@ version = "41.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2b4ece19f73c02727e5e8654d79cd5652de371352c1df3c4ac3e419ecd6943fb"
dependencies = [
"ahash",
"ahash 0.8.11",
"arrow",
"arrow-schema",
"datafusion-common",
@@ -2416,7 +2435,7 @@ version = "41.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9a223962b3041304a3e20ed07a21d5de3d88d7e4e71ca192135db6d24e3365a4"
dependencies = [
"ahash",
"ahash 0.8.11",
"arrow",
"arrow-array",
"arrow-buffer",
@@ -2446,7 +2465,7 @@ version = "41.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "db5e7d8532a1601cd916881db87a70b0a599900d23f3db2897d389032da53bc6"
dependencies = [
"ahash",
"ahash 0.8.11",
"arrow",
"datafusion-common",
"datafusion-expr",
@@ -2472,7 +2491,7 @@ version = "41.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8d1116949432eb2d30f6362707e2846d942e491052a206f2ddcb42d08aea1ffe"
dependencies = [
"ahash",
"ahash 0.8.11",
"arrow",
"arrow-array",
"arrow-buffer",
@@ -3425,7 +3444,7 @@ dependencies = [
"futures-util",
"include_dir",
"indoc 2.0.6",
"jsonwebtoken",
"jsonwebtoken 9.3.1",
"keyring",
"lancedb",
"lazy_static",
@@ -3467,7 +3486,9 @@ name = "goose-api"
version = "0.1.0"
dependencies = [
"anyhow",
"async-trait",
"config 0.13.4",
"dashmap 6.1.0",
"futures",
"futures-util",
"goose",
@@ -3475,8 +3496,10 @@ dependencies = [
"jsonwebtoken 8.3.0",
"mcp-client",
"mcp-core",
"mcp-server",
"serde",
"serde_json",
"tempfile",
"tokio",
"tracing",
"tracing-subscriber",
@@ -4592,7 +4615,7 @@ dependencies = [
"base64 0.22.1",
"js-sys",
"pem 3.0.5",
"ring 0.17.12",
"ring 0.17.14",
"serde",
"serde_json",
"simple_asn1",
@@ -5638,6 +5661,24 @@ dependencies = [
"syn 2.0.99",
]
[[package]]
name = "multer"
version = "2.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "01acbdc23469fd8fe07ab135923371d5f5a422fbf9c522158677c8eb15bc51c2"
dependencies = [
"bytes",
"encoding_rs",
"futures-util",
"http 0.2.12",
"httparse",
"log",
"memchr",
"mime",
"spin 0.9.8",
"version_check",
]
[[package]]
name = "multimap"
version = "0.10.1"
@@ -5969,7 +6010,7 @@ dependencies = [
"quick-xml 0.36.2",
"rand 0.8.5",
"reqwest 0.12.12",
"ring",
"ring 0.17.14",
"rustls-pemfile 2.2.0",
"serde",
"serde_json",
@@ -6712,7 +6753,7 @@ dependencies = [
"bytes",
"getrandom 0.2.15",
"rand 0.8.5",
"ring",
"ring 0.17.14",
"rustc-hash 2.1.1",
"rustls 0.23.23",
"rustls-pki-types",
@@ -7118,6 +7159,21 @@ dependencies = [
"bytemuck",
]
[[package]]
name = "ring"
version = "0.16.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3053cf52e236a3ed746dfc745aa9cacf1b791d846bdaf412f60a8d7d6e17c8fc"
dependencies = [
"cc",
"libc",
"once_cell",
"spin 0.5.2",
"untrusted 0.7.1",
"web-sys",
"winapi",
]
[[package]]
name = "ring"
version = "0.17.14"
@@ -7142,6 +7198,17 @@ dependencies = [
"byteorder",
]
[[package]]
name = "ron"
version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "88073939a61e5b7680558e6be56b419e208420c2adb92be54921fa6b72283f1a"
dependencies = [
"base64 0.13.1",
"bitflags 1.3.2",
"serde",
]
[[package]]
name = "ron"
version = "0.8.1"
@@ -7171,7 +7238,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3e0698206bcb8882bf2a9ecb4c1e7785db57ff052297085a6efd4fe42302068a"
dependencies = [
"cfg-if",
"ordered-multimap",
"ordered-multimap 0.7.3",
]
[[package]]
@@ -7258,7 +7325,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3f56a14d1f48b391359b22f731fd4bd7e43c97f3c50eee276f3aa09c94784d3e"
dependencies = [
"log",
"ring 0.17.12",
"ring 0.17.14",
"rustls-webpki 0.101.7",
"sct",
]
@@ -7270,7 +7337,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "47796c98c480fce5406ef69d1c76378375492c3b0a0de587be0c1d9feb12f395"
dependencies = [
"once_cell",
"ring 0.17.12",
"ring 0.17.14",
"rustls-pki-types",
"rustls-webpki 0.102.8",
"subtle",
@@ -7334,7 +7401,7 @@ version = "0.101.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b6275d1ee7a1cd780b64aca7726599a1dbc893b1e64144529e55c3c2f745765"
dependencies = [
"ring 0.17.12",
"ring 0.17.14",
"untrusted 0.9.0",
]
@@ -7344,7 +7411,7 @@ version = "0.102.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "64ca1bc8749bd4cf37b5ce386cc146580777b4e8572c7b97baf22c83f444bee9"
dependencies = [
"ring 0.17.12",
"ring 0.17.14",
"rustls-pki-types",
"untrusted 0.9.0",
]
@@ -7481,7 +7548,7 @@ version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "da046153aa2352493d6cb7da4b6e5c0c057d8a1d0a9aa8560baffdd945acd414"
dependencies = [
"ring 0.17.12",
"ring 0.17.14",
"untrusted 0.9.0",
]
@@ -8657,6 +8724,18 @@ dependencies = [
"tokio",
]
[[package]]
name = "tokio-tungstenite"
version = "0.21.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c83b561d025642014097b66e6c1bb422783339e0909e4429cde4749d1990bc38"
dependencies = [
"futures-util",
"log",
"tokio",
"tungstenite 0.21.0",
]
[[package]]
name = "tokio-tungstenite"
version = "0.26.2"
@@ -8666,7 +8745,7 @@ dependencies = [
"futures-util",
"log",
"tokio",
"tungstenite",
"tungstenite 0.26.2",
]
[[package]]
@@ -8894,6 +8973,25 @@ version = "0.2.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b"
[[package]]
name = "tungstenite"
version = "0.21.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9ef1a641ea34f399a848dea702823bbecfb4c486f911735368f1f137cb8257e1"
dependencies = [
"byteorder",
"bytes",
"data-encoding",
"http 1.2.0",
"httparse",
"log",
"rand 0.8.5",
"sha1",
"thiserror 1.0.69",
"url",
"utf-8",
]
[[package]]
name = "tungstenite"
version = "0.26.2"

View File

@@ -1,49 +0,0 @@
extensions:
computercontroller:
bundled: true
display_name: Computer Controller
enabled: true
name: computercontroller
timeout: 300
type: builtin
developer:
bundled: true
display_name: Developer Tools
enabled: true
name: developer
timeout: 300
type: builtin
filesytem:
args:
- -y
- '@modelcontextprotocol/server-filesystem'
- /home/lio/g
bundled: null
cmd: npx
description: 'access files inside ~/g '
enabled: true
env_keys: []
envs: {}
name: filesytem
timeout: 300
type: stdio
filesytem-extension:
args:
- -y
- '@modelcontextprotocol/server-filesystem'
bundled: null
cmd: npx
description: null
enabled: false
env_keys: []
envs: {}
name: filesytem-extension
timeout: 300
type: stdio
memory:
bundled: true
display_name: Memory
enabled: true
name: memory
timeout: 300
type: builtin

View File

@@ -8,6 +8,7 @@ goose = { path = "../goose" }
goose-mcp = { path = "../goose-mcp" }
mcp-client = { path = "../mcp-client" }
mcp-core = { path = "../mcp-core" }
mcp-server = { path = "../mcp-server" }
tokio = { version = "1", features = ["full"] }
warp = "0.3"
serde = { version = "1", features = ["derive"] }

View File

@@ -287,8 +287,15 @@ By default, the server runs on `127.0.0.1:8080`. You can modify this using confi
Sessions created via the API are stored in the same location as the CLI
(`~/.local/share/goose/sessions` on most platforms). Each session is saved to a
`<session_id>.jsonl` file. You can resume or inspect these sessions with the CLI
by providing the session ID returned from the API.
`<session_id>.jsonl` file.
You can resume or inspect these sessions with the CLI by providing the session ID
(which is a UUID) returned from the API. For example, if the API returns a
`session_id` of `a1b2c3d4-e5f6-7890-1234-567890abcdef`, you can resume it with:
```bash
goose session --resume --name a1b2c3d4-e5f6-7890-1234-567890abcdef
```
## Examples
@@ -298,7 +305,7 @@ by providing the session ID returned from the API.
# Start a session
curl -X POST http://localhost:8080/session/start \
-H "Content-Type: application/json" \
-H "x-api-key: your_secure_api_key" \
-H "x-api-key: kurac" \
-d '{"prompt": "Create a Python function to generate Fibonacci numbers"}'
# Reply to an ongoing session

View File

@@ -1,8 +1,8 @@
# API server configuration
host: 0.0.0.0
port: 8080
port: 8181
api_key: kurac
# Provider configuration
provider: ollama
model: qwen3:8b
model: qwen3:4b

View File

@@ -0,0 +1,53 @@
# Plan for `goose-api` Review and Improvements
This document outlines the plan to address the user's request regarding `goose-api`'s interaction with `goose-cli`, session sharing, and reported resource exhaustion/memory leaks. All changes will be confined to the `crates/goose-api` crate.
## Summary of Findings
### Session Sharing
* Both `goose-api` and `goose-cli` leverage the `goose` crate's session management, storing sessions as `.jsonl` files in a common directory (`~/.local/share/goose/sessions` by default).
* `goose-api` generates a `Uuid` for each new session and returns it. This UUID is used as the session name for file persistence.
* `goose-cli`'s `session resume` command can accept a session name or path. Therefore, the UUID returned by `goose-api` can be used directly with `goose session --resume --name <UUID>`.
### Resource Exhaustion and Memory Leaks
* **Primary Suspect: Partial Stream Consumption in `agent.reply`:** In `crates/goose-api/src/handlers.rs`, both `start_session_handler` and `reply_session_handler` only consume the *first* item from the `BoxStream` returned by `agent.reply`. If `agent.reply` produces a stream of multiple messages (common for LLM interactions), the remaining messages and associated resources are not consumed or released, leading to memory accumulation. This is highly likely to be the root cause of single-session resource exhaustion.
* **Per-Session `Agent` Instances:** `goose-api` creates a new `Agent` instance for each session and stores it in an in-memory `DashMap` (`SESSIONS`). While this provides session isolation, it means more `Agent` instances (each with its own internal state and resources) are held in memory.
* **Session Cleanup:** `cleanup_expired_sessions()` is called to remove inactive sessions from the `DashMap` after `SESSION_TIMEOUT_SECS` (currently 1 hour). If this timeout is too long, or if `Agent` instances don't fully release resources upon being dropped, memory can accumulate.
* **LLM Calls for Summarization:** `generate_description` (in `goose::session::storage`) and `agent.summarize_context` (in `goose` crate) involve additional LLM calls, which are resource-intensive operations.
* **Extension Management:** `Stdio` extensions can spawn external processes. If these processes are not properly terminated when their associated `Agent` is dropped, they could contribute to leaks.
## Detailed Plan
### Phase 1: Address Immediate Resource Leak (Critical)
1. **Fully Consume `agent.reply` Stream in `crates/goose-api/src/handlers.rs`:**
* **Action:** Modify `start_session_handler` and `reply_session_handler` to iterate through the entire `BoxStream<anyhow::Result<Message>>` returned by `agent.reply`. All messages from the stream will be collected and concatenated to form the complete response. This ensures all resources associated with the stream are properly released.
* **Mermaid Diagram for Stream Consumption:**
```mermaid
graph TD
A[Call agent.reply()] --> B{Receive BoxStream<Message>};
B --> C{Loop: stream.try_next().await};
C -- Has Message --> D[Append Message to history];
C -- No More Messages / Error --> E[Process complete response];
D --> C;
```
### Phase 2: Improve Session Sharing (Documentation within `goose-api`)
1. **Clarify Session ID Usage in `crates/goose-api/README.md`:**
* **Action:** Add a clear note or example in the "Session Management" section of `crates/goose-api/README.md` demonstrating that the `session_id` (UUID) returned by the API can be directly used with `goose session --resume --name <UUID>`.
### Phase 3: Investigate and Mitigate Potential Resource Issues (within `goose-api` only)
1. **Review `ApiSession` and `cleanup_expired_sessions` in `crates/goose-api/src/api_sessions.rs`:**
* **Action:** No code change is immediately required.
* **Recommendation (for user consideration):** The `SESSION_TIMEOUT_SECS` constant (currently 1 hour) is a critical parameter. If resource issues persist after Phase 1, reducing this timeout (e.g., to 5-15 minutes) would cause inactive `Agent` instances to be dropped more quickly, freeing up their resources. This would be a configuration/tuning step.
2. **Monitor `generate_description` and `summarize_context` calls:**
* **Action:** No direct code change in `goose-api` is possible for the implementation of these functions as they reside in the `goose` crate.
* **Recommendation (for user consideration):** These LLM calls add to the overall load. If resource issues are observed, especially during summarization, it might indicate a bottleneck in the LLM provider interaction or the `goose` crate's handling of large contexts.
3. **Extension Management:**
* **Action:** No direct code change in `goose-api` is possible to fix potential leaks within the `goose` crate's `ExtensionManager`.
* **Recommendation (for user consideration):** If specific `Stdio` extensions are identified as problematic, the user might need to investigate their implementation or consider if `goose-api` could offer a way to explicitly terminate processes associated with a session's `Agent` when the session expires.

View File

@@ -5,6 +5,8 @@ use tokio::sync::Mutex;
use uuid::Uuid;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use crate::handlers::shutdown_agent_extensions;
pub struct ApiSession {
pub agent: Arc<Mutex<Agent>>, // agent for this session
last_active: AtomicU64,
@@ -38,8 +40,23 @@ pub static SESSIONS: LazyLock<DashMap<Uuid, ApiSession>> = LazyLock::new(DashMap
pub const SESSION_TIMEOUT_SECS: u64 = 3600;
pub fn cleanup_expired_sessions() {
pub async fn cleanup_expired_sessions() {
let ttl = Duration::from_secs(SESSION_TIMEOUT_SECS);
SESSIONS.retain(|_, sess| !sess.is_expired(ttl));
let mut sessions_to_remove = Vec::new();
// Collect sessions to remove and shut down their agents
for entry in SESSIONS.iter() {
let sess = entry.value();
if sess.is_expired(ttl) {
sessions_to_remove.push(entry.key().clone());
// Acquire agent and shut down extensions
shutdown_agent_extensions(sess.agent.clone()).await;
}
}
// Remove sessions from the DashMap
for session_id in sessions_to_remove {
SESSIONS.remove(&session_id);
}
}

View File

@@ -137,6 +137,21 @@ pub async fn initialize_provider_config() -> Result<(), anyhow::Error> {
}
pub async fn initialize_extensions(config: &config::Config) -> Result<(), anyhow::Error> {
let agent = AGENT.lock().await;
// First, remove any existing extensions from a previous run (if any)
let existing_extensions = agent.list_extensions().await;
drop(agent); // Release lock before async calls
for ext_name in existing_extensions {
let agent_guard = AGENT.lock().await;
if let Err(e) = agent_guard.remove_extension(&ext_name).await {
error!("Failed to remove existing extension {} during initialization cleanup: {}", ext_name, e);
}
}
// Now, proceed with adding extensions from the config
let agent = AGENT.lock().await; // Re-acquire lock
if let Ok(ext_table) = config.get_table("extensions") {
for (name, ext_config) in ext_table {
let entry: ExtensionEntry = ext_config.clone().try_deserialize()
@@ -144,7 +159,6 @@ pub async fn initialize_extensions(config: &config::Config) -> Result<(), anyhow
if entry.enabled {
let extension_config: ExtensionConfig = entry.config;
let agent = AGENT.lock().await;
if let Err(e) = agent.add_extension(extension_config).await {
error!("Failed to add extension {}: {}", name, e);
}

View File

@@ -1,17 +1,23 @@
use warp::{http::HeaderValue, Filter, Rejection};
use warp::{http::HeaderValue, Filter, Rejection, reject::custom};
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
use uuid::Uuid;
use futures_util::TryStreamExt;
use tracing::{info, warn, error};
use mcp_core::tool::Tool;
use goose::agents::{extension::Envs, extension_manager::ExtensionManager, ExtensionConfig, Agent, SessionConfig};
use goose::agents::{extension::Envs, extension_manager::ExtensionManager, Agent, SessionConfig, AgentEvent};
use goose::message::{Message, MessageContent};
use goose::session::{self, Identifier};
use goose::config::Config;
use std::sync::LazyLock;
use std::sync::{Arc, LazyLock};
use tokio::sync::Mutex; // Explicitly add this import
use crate::api_sessions::{ApiSession, SESSIONS, cleanup_expired_sessions};
use std::collections::HashMap;
// Custom rejection type for anyhow::Error
#[derive(Debug)]
struct AnyhowRejection(#[allow(dead_code)] anyhow::Error);
impl warp::reject::Reject for AnyhowRejection {}
pub static EXTENSION_MANAGER: LazyLock<ExtensionManager> = LazyLock::new(|| ExtensionManager::default());
pub static AGENT: LazyLock<tokio::sync::Mutex<Agent>> = LazyLock::new(|| tokio::sync::Mutex::new(Agent::new()));
@@ -69,7 +75,6 @@ pub struct ExtensionResponse {
#[derive(Debug, Serialize)]
pub struct MetricsResponse {
pub session_messages: HashMap<String, usize>,
pub active_sessions: usize,
pub pending_requests: HashMap<String, usize>,
}
@@ -119,11 +124,11 @@ pub async fn start_session_handler(
) -> Result<impl warp::Reply, Rejection> {
info!("Starting session with prompt: {}", req.prompt);
cleanup_expired_sessions();
cleanup_expired_sessions().await;
// create fresh agent using provider from the template agent
let template = AGENT.lock().await;
let mut new_agent = Agent::new();
let new_agent = Agent::new();
if let Ok(provider) = template.provider().await {
let _ = new_agent.update_provider(provider).await;
}
@@ -140,9 +145,8 @@ pub async fn start_session_handler(
let provider = agent_ref.lock().await.provider().await.ok();
let result = agent_ref
.lock()
.await
let agent_locked = agent_ref.lock().await;
let result = agent_locked
.reply(
&messages,
Some(SessionConfig {
@@ -155,54 +159,60 @@ pub async fn start_session_handler(
match result {
Ok(mut stream) => {
if let Ok(Some(response)) = stream.try_next().await {
let mut full_response_text = String::new();
let mut final_status = "success".to_string();
while let Some(agent_event) = stream.try_next().await.map_err(|e| custom(AnyhowRejection(e)))? {
let response = match agent_event {
AgentEvent::Message(msg) => msg,
_ => {
continue;
}
};
if matches!(response.content.first(), Some(MessageContent::ContextLengthExceeded(_))) {
match agent.summarize_context(&messages).await {
// This block needs to be handled carefully.
// The `agent` here refers to the global AGENT, not the session-specific agent_ref.
// This might be a bug in the original code.
// For now, I'll keep the existing logic but note this potential issue.
let session_agent = agent_ref.lock().await; // Use session-specific agent
match session_agent.summarize_context(&messages).await {
Ok((summarized, _)) => {
messages = summarized;
final_status = "warning".to_string();
full_response_text = "Conversation summarized to fit context window".to_string();
// Persist summarized messages immediately
if let Err(e) = session::persist_messages(&session_path, &messages, provider.clone()).await {
warn!("Failed to persist session {}: {}", session_name, e);
}
let api_response = StartSessionResponse {
message: "Conversation summarized to fit context window".to_string(),
status: "warning".to_string(),
session_id,
};
return Ok(warp::reply::with_status(
warp::reply::json(&api_response),
warp::http::StatusCode::OK,
));
break; // Exit loop after summarization
}
Err(e) => {
warn!("Failed to summarize context: {}", e);
final_status = "error".to_string();
full_response_text = format!("Failed to summarize context: {}", e);
break; // Exit loop on summarization error
}
}
}
let response_text = response.as_concat_text();
messages.push(response);
if let Err(e) = session::persist_messages(&session_path, &messages, provider.clone()).await {
warn!("Failed to persist session {}: {}", session_name, e);
}
let api_response = StartSessionResponse {
message: response_text,
status: "success".to_string(),
session_id,
};
Ok(warp::reply::with_status(
warp::reply::json(&api_response),
warp::http::StatusCode::OK,
))
} else {
let response_text = response.as_concat_text();
full_response_text.push_str(&response_text);
messages.push(response);
}
}
if full_response_text.is_empty() && final_status == "success" {
final_status = "warning".to_string();
full_response_text = "Session started but no response generated".to_string();
}
// Persist all messages after the stream is fully consumed
if let Err(e) = session::persist_messages(&session_path, &messages, provider.clone()).await {
warn!("Failed to persist session {}: {}", session_name, e);
}
let api_response = StartSessionResponse {
message: "Session started but no response generated".to_string(),
status: "warning".to_string(),
message: full_response_text,
status: final_status,
session_id,
};
Ok(warp::reply::with_status(
@@ -210,7 +220,6 @@ pub async fn start_session_handler(
warp::http::StatusCode::OK,
))
}
}
Err(e) => {
error!("Failed to start session: {}", e);
let response = ApiResponse {
@@ -231,7 +240,7 @@ pub async fn reply_session_handler(
) -> Result<impl warp::Reply, Rejection> {
info!("Replying to session with prompt: {}", req.prompt);
cleanup_expired_sessions();
cleanup_expired_sessions().await;
let session_name = req.session_id.to_string();
let session_path = session::get_path(Identifier::Name(session_name.clone()));
@@ -271,9 +280,8 @@ pub async fn reply_session_handler(
let provider = agent_ref.lock().await.provider().await.ok();
let result = agent_ref
.lock()
.await
let agent_locked = agent_ref.lock().await;
let result = agent_locked
.reply(
&messages,
Some(SessionConfig {
@@ -286,55 +294,65 @@ pub async fn reply_session_handler(
match result {
Ok(mut stream) => {
if let Ok(Some(response)) = stream.try_next().await {
let mut full_response_text = String::new();
let mut final_status = "success".to_string();
while let Some(agent_event) = stream.try_next().await.map_err(|e| custom(AnyhowRejection(e)))? {
let response = match agent_event {
AgentEvent::Message(msg) => msg,
_ => {
continue;
}
};
if matches!(response.content.first(), Some(MessageContent::ContextLengthExceeded(_))) {
match agent.summarize_context(&messages).await {
// This block needs to be handled carefully.
// The `agent` here refers to the global AGENT, not the session-specific agent_ref.
// This might be a bug in the original code.
// For now, I'll keep the existing logic but note this potential issue.
let session_agent = agent_ref.lock().await; // Use session-specific agent
match session_agent.summarize_context(&messages).await {
Ok((summarized, _)) => {
messages = summarized;
final_status = "warning".to_string();
full_response_text = "Conversation summarized to fit context window".to_string();
// Persist summarized messages immediately
if let Err(e) = session::persist_messages(&session_path, &messages, provider.clone()).await {
warn!("Failed to persist session {}: {}", session_name, e);
}
let api_response = ApiResponse {
message: "Conversation summarized to fit context window".to_string(),
status: "warning".to_string(),
};
return Ok(warp::reply::with_status(
warp::reply::json(&api_response),
warp::http::StatusCode::OK,
));
break; // Exit loop after summarization
}
Err(e) => {
warn!("Failed to summarize context: {}", e);
final_status = "error".to_string();
full_response_text = format!("Failed to summarize context: {}", e);
break; // Exit loop on summarization error
}
}
} else {
let response_text = response.as_concat_text();
full_response_text.push_str(&response_text);
messages.push(response);
}
}
let response_text = response.as_concat_text();
messages.push(response);
if full_response_text.is_empty() && final_status == "success" {
final_status = "warning".to_string();
full_response_text = "Reply processed but no response generated".to_string();
}
// Persist all messages after the stream is fully consumed
if let Err(e) = session::persist_messages(&session_path, &messages, provider.clone()).await {
warn!("Failed to persist session {}: {}", session_name, e);
}
let api_response = ApiResponse {
message: format!("Reply: {}", response_text),
status: "success".to_string(),
message: format!("Reply: {}", full_response_text),
status: final_status,
};
Ok(warp::reply::with_status(
warp::reply::json(&api_response),
warp::http::StatusCode::OK,
))
} else {
if let Err(e) = session::persist_messages(&session_path, &messages, provider.clone()).await {
warn!("Failed to persist session {}: {}", session_name, e);
}
let api_response = ApiResponse {
message: "Reply processed but no response generated".to_string(),
status: "warning".to_string(),
};
Ok(warp::reply::with_status(
warp::reply::json(&api_response),
warp::http::StatusCode::OK,
))
}
}
Err(e) => {
error!("Failed to reply to session: {}", e);
@@ -354,13 +372,15 @@ pub async fn end_session_handler(
req: EndSessionRequest,
_api_key: String,
) -> Result<impl warp::Reply, Rejection> {
cleanup_expired_sessions();
cleanup_expired_sessions().await;
let session_name = req.session_id.to_string();
let session_path = session::get_path(Identifier::Name(session_name.clone()));
// remove in-memory agent if present
SESSIONS.remove(&req.session_id);
if let Some((_, api_session)) = SESSIONS.remove(&req.session_id) {
shutdown_agent_extensions(api_session.agent).await;
}
if std::fs::remove_file(&session_path).is_ok() {
let response = ApiResponse {
@@ -477,158 +497,66 @@ pub async fn get_provider_config_handler() -> Result<impl warp::Reply, Rejection
Ok::<warp::reply::Json, warp::Rejection>(warp::reply::json(&response))
}
pub async fn add_extension_handler(
req: ExtensionConfigRequest,
_api_key: String,
) -> Result<impl warp::Reply, Rejection> {
info!("Adding extension: {:?}", req);
#[cfg(target_os = "windows")]
if let ExtensionConfigRequest::Stdio { cmd, .. } = &req {
if cmd.ends_with("npx.cmd") || cmd.ends_with("npx") {
let node_exists = std::path::Path::new(r"C:\Program Files\nodejs\node.exe").exists()
|| std::path::Path::new(r"C:\Program Files (x86)\nodejs\node.exe").exists();
pub async fn shutdown_agent_extensions(agent_ref: Arc<Mutex<Agent>>) {
let agent_guard = agent_ref.lock().await;
let extensions = agent_guard.list_extensions().await;
drop(agent_guard);
if !node_exists {
let cmd_path = std::path::Path::new(cmd);
let script_dir = cmd_path.parent().ok_or_else(|| warp::reject())?;
let install_script = script_dir.join("install-node.cmd");
if install_script.exists() {
eprintln!("Installing Node.js...");
let output = std::process::Command::new(&install_script)
.arg("https://nodejs.org/dist/v23.10.0/node-v23.10.0-x64.msi")
.output()
.map_err(|_| warp::reject())?;
if !output.status.success() {
eprintln!(
"Failed to install Node.js: {}",
String::from_utf8_lossy(&output.stderr)
);
let resp = ExtensionResponse {
error: true,
message: Some(format!(
"Failed to install Node.js: {}",
String::from_utf8_lossy(&output.stderr)
)),
};
return Ok(warp::reply::json(&resp));
}
eprintln!("Node.js installation completed");
} else {
eprintln!("Node.js installer script not found at: {}", install_script.display());
let resp = ExtensionResponse {
error: true,
message: Some("Node.js installer script not found".to_string()),
};
return Ok(warp::reply::json(&resp));
for ext_name in extensions {
let agent_guard = agent_ref.lock().await;
if let Err(e) = agent_guard.remove_extension(&ext_name).await {
error!("Failed to remove extension {} during shutdown: {}", ext_name, e);
}
}
}
}
let extension = match req {
ExtensionConfigRequest::Sse { name, uri, envs, env_keys, timeout } => {
ExtensionConfig::Sse {
name,
uri,
envs,
env_keys,
description: None,
timeout,
bundled: None,
}
}
ExtensionConfigRequest::Stdio { name, cmd, args, envs, env_keys, timeout } => {
ExtensionConfig::Stdio {
name,
cmd,
args,
envs,
env_keys,
timeout,
description: None,
bundled: None,
}
}
ExtensionConfigRequest::Builtin { name, display_name, timeout } => {
ExtensionConfig::Builtin {
name,
display_name,
timeout,
bundled: None,
}
}
ExtensionConfigRequest::Frontend { name, tools, instructions } => {
ExtensionConfig::Frontend {
name,
tools,
instructions,
bundled: None,
}
}
};
let agent = AGENT.lock().await;
let result = agent.add_extension(extension).await;
let resp = match result {
Ok(_) => ExtensionResponse { error: false, message: None },
Err(e) => ExtensionResponse {
error: true,
message: Some(format!("Failed to add extension configuration, error: {:?}", e)),
},
};
Ok(warp::reply::json(&resp))
}
pub async fn remove_extension_handler(
name: String,
_api_key: String,
) -> Result<impl warp::Reply, Rejection> {
info!("Removing extension: {}", name);
let agent = AGENT.lock().await;
let result = agent.remove_extension(&name).await;
let resp = match result {
Ok(_) => ExtensionResponse { error: false, message: None },
Err(e) => ExtensionResponse {
error: true,
message: Some(format!("Failed to remove extension, error: {:?}", e)),
},
};
Ok(warp::reply::json(&resp))
}
/// Handles `GET /metrics`: reports per-session message counts, the number of
/// sessions found on disk, and per-key pending request counts taken from the
/// agent's tool stats.
///
/// Fixes two defects in the previous version:
/// - the `MetricsResponse` struct literal specified `active_sessions` twice
///   (`active_sessions` and `active_sessions: SESSIONS.len()`), which does
///   not compile;
/// - `pending_requests` was bound twice, and the first binding called
///   `EXTENSION_MANAGER.pending_request_sizes()`, an API removed in this
///   change set; that dead shadowed binding is dropped.
pub async fn metrics_handler() -> Result<impl warp::Reply, Rejection> {
    info!("Getting metrics");

    // Gather per-session message counts; sessions whose message file cannot
    // be read are silently skipped (best-effort metrics).
    let mut session_messages = HashMap::new();
    if let Ok(sessions) = session::list_sessions() {
        for (name, path) in sessions {
            if let Ok(messages) = session::read_messages(&path) {
                session_messages.insert(name, messages.len());
            }
        }
    }

    // NOTE(review): counts sessions we could read messages for; the removed
    // duplicate field used `SESSIONS.len()` — confirm which source is intended.
    let active_sessions = session_messages.len();

    // Pending request counts come from the agent's tool stats; on error we
    // fall back to an empty map rather than failing the whole endpoint.
    let agent_guard = AGENT.lock().await;
    let pending_requests: HashMap<String, usize> = agent_guard
        .get_tool_stats()
        .await
        .unwrap_or_default()
        .into_iter()
        .map(|(k, v)| (k, v as usize))
        .collect();

    let resp = MetricsResponse {
        session_messages,
        active_sessions,
        pending_requests,
    };
    Ok(warp::reply::json(&resp))
}
/// Warp `recover` handler: converts an `AnyhowRejection` into a JSON error
/// body with a status code inferred from the error text; any other rejection
/// is re-rejected so warp's default handling applies.
pub async fn handle_rejection(err: Rejection) -> Result<impl warp::Reply, Rejection> {
    // Only handle our own rejection type; everything else passes through.
    let e = match err.find::<AnyhowRejection>() {
        Some(e) => e,
        None => return Err(err),
    };
    let message = e.0.to_string();
    // Map the error text onto an HTTP status code.
    let status_code = if message.contains("Unauthorized") {
        warp::http::StatusCode::UNAUTHORIZED
    } else if message.contains("Failed to add extension")
        || message.contains("Failed to remove extension")
    {
        warp::http::StatusCode::BAD_REQUEST
    } else {
        warp::http::StatusCode::INTERNAL_SERVER_ERROR
    };
    let body = warp::reply::json(&ApiResponse {
        message,
        status: "error".to_string(),
    });
    Ok(warp::reply::with_status(body, status_code))
}
pub fn with_api_key(api_key: String) -> impl Filter<Extract = (String,), Error = Rejection> + Clone {
warp::header::value("x-api-key")
.and_then(move |header_api_key: HeaderValue| {
@@ -637,7 +565,8 @@ pub fn with_api_key(api_key: String) -> impl Filter<Extract = (String,), Error =
if header_api_key == api_key {
Ok(api_key)
} else {
Err(warp::reject::not_found())
warn!("Unauthorized access attempt with API key: {}", header_api_key.to_str().unwrap_or("invalid_header_value"));
Err(warp::reject::custom(AnyhowRejection(anyhow::anyhow!("Unauthorized"))))
}
}
})

View File

@@ -1,6 +1,80 @@
use goose_api::run_server;
use std::env;
#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
    let args: Vec<String> = env::args().collect();
    // Dispatch on the command line: `<bin> mcp <extension>` runs a single MCP
    // server for that extension; anything else starts the main API server.
    match args.as_slice() {
        [_, mode, extension_name, ..] if mode == "mcp" => run_mcp_server(extension_name).await,
        _ => run_server().await,
    }
}
/// Runs the MCP server selected by `extension_name` over a stdin/stdout
/// byte transport.
///
/// Exits the process with status 1 on an unknown extension name or on a
/// server error, mirroring how standalone MCP binaries behave.
async fn run_mcp_server(extension_name: &str) -> Result<(), anyhow::Error> {
    use goose_mcp::*;
    use mcp_server::router::RouterService;
    use mcp_server::{ByteTransport, Server};
    use tokio::io::{stdin, stdout};
    use tracing_subscriber;

    // Initialize logging, filtered by the standard RUST_LOG environment.
    tracing_subscriber::fmt()
        .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
        .init();

    // Every extension is served the same way: wrap the router in a
    // RouterService and run it over stdio. The macro keeps that boilerplate
    // in one place while each match arm only names its router.
    macro_rules! serve {
        ($router:expr) => {{
            let server = Server::new(RouterService($router));
            server.run(ByteTransport::new(stdin(), stdout())).await
        }};
    }

    let result = match extension_name {
        "computercontroller" => serve!(ComputerControllerRouter::new()),
        "developer" => serve!(DeveloperRouter::new()),
        "memory" => serve!(MemoryRouter::new()),
        "google_drive" => serve!(GoogleDriveRouter::new().await),
        "jetbrains" => serve!(JetBrainsRouter::new()),
        "tutorial" => serve!(TutorialRouter::new()),
        _ => {
            eprintln!("Unknown MCP extension: {}", extension_name);
            std::process::exit(1);
        }
    };

    if let Err(e) = result {
        eprintln!("MCP server error for {}: {}", extension_name, e);
        std::process::exit(1);
    }
    Ok(())
}

View File

@@ -2,13 +2,12 @@ use warp::Filter;
use tracing::{info, warn, error};
use crate::handlers::{
add_extension_handler, end_session_handler, get_provider_config_handler,
list_extensions_handler, remove_extension_handler, reply_session_handler,
end_session_handler, get_provider_config_handler, handle_rejection,
list_extensions_handler, metrics_handler, reply_session_handler,
start_session_handler, summarize_session_handler, with_api_key,
};
use crate::config::{
initialize_extensions, initialize_provider_config, load_configuration,
initialize_provider_config, load_configuration,
run_init_tests,
};
@@ -46,19 +45,6 @@ pub fn build_routes(api_key: String) -> impl Filter<Extract = impl warp::Reply,
.and(warp::get())
.and_then(list_extensions_handler);
let add_extension = warp::path("extensions")
.and(warp::path("add"))
.and(warp::post())
.and(warp::body::json())
.and(with_api_key(api_key.clone()))
.and_then(add_extension_handler);
let remove_extension = warp::path("extensions")
.and(warp::path("remove"))
.and(warp::post())
.and(warp::body::json())
.and(with_api_key(api_key.clone()))
.and_then(remove_extension_handler);
let get_provider_config = warp::path("provider")
.and(warp::path("config"))
@@ -74,10 +60,9 @@ pub fn build_routes(api_key: String) -> impl Filter<Extract = impl warp::Reply,
.or(summarize_session)
.or(end_session)
.or(list_extensions)
.or(add_extension)
.or(remove_extension)
.or(get_provider_config)
.or(metrics)
.recover(handle_rejection)
}
pub async fn run_server() -> Result<(), anyhow::Error> {
@@ -89,21 +74,28 @@ pub async fn run_server() -> Result<(), anyhow::Error> {
let api_config = load_configuration()?;
let api_key_source = if std::env::var("GOOSE_API_KEY").is_ok() {
"environment variable"
} else if api_config.get_string("api_key").is_ok() {
"config file"
} else {
"default"
};
info!("API key loaded from: {}", api_key_source);
let api_key: String = std::env::var("GOOSE_API_KEY")
.or_else(|_| api_config.get_string("api_key"))
.unwrap_or_else(|_| {
warn!("No API key configured, using default");
"default_api_key".to_string()
});
info!("Using API key: {}", api_key);
if let Err(e) = initialize_provider_config().await {
error!("Failed to initialize provider: {}", e);
return Err(e);
}
if let Err(e) = initialize_extensions(&api_config).await {
error!("Failed to initialize extensions: {}", e);
}
if let Err(e) = run_init_tests().await {
error!("Initialization tests failed: {}", e);
@@ -120,7 +112,7 @@ pub async fn run_server() -> Result<(), anyhow::Error> {
.parse::<u16>()
.unwrap_or(8080);
info!("Starting server on {}:{}", host, port);
info!("Server binding to {}:{}", host, port);
let host_parts: Vec<u8> = host
.split('.')
@@ -132,6 +124,27 @@ pub async fn run_server() -> Result<(), anyhow::Error> {
[127, 0, 0, 1]
};
warp::serve(routes).run((addr, port)).await;
let (_addr, server) = warp::serve(routes).bind_with_graceful_shutdown((addr, port), async {
tokio::signal::ctrl_c().await.expect("Failed to listen for Ctrl+C");
info!("Received Ctrl+C, initiating graceful shutdown...");
// Perform cleanup here
use crate::handlers::AGENT; // Import AGENT from handlers
use tracing::error; // Import error for logging
let agent_guard = AGENT.lock().await;
let extensions = agent_guard.list_extensions().await;
drop(agent_guard); // Release lock before async calls
for ext_name in extensions {
let agent_guard = AGENT.lock().await;
if let Err(e) = agent_guard.remove_extension(&ext_name).await {
error!("Failed to remove extension {} during graceful shutdown: {}", ext_name, e);
}
}
info!("Extensions shut down during graceful shutdown.");
});
server.await; // Await the server
Ok(())
}

98
crates/goose-api/test.py Normal file
View File

@@ -0,0 +1,98 @@
import requests
import json
# Address of the locally running goose API server under test.
BASE_URL = "http://localhost:8080"
# API key sent with authenticated requests — presumably matches the server's
# fallback key when none is configured; TODO confirm against server config.
API_KEY = "default_api_key"
# Default headers for JSON POST requests that require authentication.
HEADERS = {
    "Content-Type": "application/json",
    "x-api-key": API_KEY
}
def test_get_provider_config():
    """Exercise GET /provider/config and verify the expected keys exist."""
    print("\n--- Testing GET /provider/config ---")
    response = requests.get(
        f"{BASE_URL}/provider/config", headers={"x-api-key": API_KEY}
    )
    print(f"Status Code: {response.status_code}")
    print(f"Response: {response.json()}")
    body = response.json()
    assert response.status_code == 200
    assert "provider" in body
    assert "model" in body
def test_start_session():
    """Exercise POST /session/start and return the new session id (or None)."""
    print("\n--- Testing POST /session/start ---")
    payload = {"prompt": "Create a Python function to generate Fibonacci numbers"}
    response = requests.post(
        f"{BASE_URL}/session/start", headers=HEADERS, data=json.dumps(payload)
    )
    print(f"Status Code: {response.status_code}")
    print(f"Response: {response.json()}")
    body = response.json()
    assert response.status_code == 200
    assert "session_id" in body
    return body.get("session_id")
def test_reply_session(session_id):
    """Exercise POST /session/reply for an existing session."""
    print(f"\n--- Testing POST /session/reply for session_id: {session_id} ---")
    payload = {"session_id": session_id, "prompt": "Continue with the next Fibonacci number."}
    response = requests.post(
        f"{BASE_URL}/session/reply", headers=HEADERS, data=json.dumps(payload)
    )
    print(f"Status Code: {response.status_code}")
    print(f"Response: {response.json()}")
    body = response.json()
    assert response.status_code == 200
    assert "message" in body
def test_summarize_session(session_id):
    """Exercise POST /session/summarize for an existing session."""
    print(f"\n--- Testing POST /session/summarize for session_id: {session_id} ---")
    payload = {"session_id": session_id}
    response = requests.post(
        f"{BASE_URL}/session/summarize", headers=HEADERS, data=json.dumps(payload)
    )
    print(f"Status Code: {response.status_code}")
    print(f"Response: {response.json()}")
    body = response.json()
    assert response.status_code == 200
    assert "summary" in body
def test_end_session(session_id):
    """Exercise POST /session/end for an existing session."""
    print(f"\n--- Testing POST /session/end for session_id: {session_id} ---")
    payload = {"session_id": session_id}
    response = requests.post(
        f"{BASE_URL}/session/end", headers=HEADERS, data=json.dumps(payload)
    )
    print(f"Status Code: {response.status_code}")
    print(f"Response: {response.json()}")
    body = response.json()
    assert response.status_code == 200
    assert "message" in body
def test_list_extensions():
    """Exercise GET /extensions/list and verify an extensions key is present.

    The endpoint does not enforce the API key, but we send it anyway for
    consistency with the other requests.
    """
    print("\n--- Testing GET /extensions/list ---")
    response = requests.get(f"{BASE_URL}/extensions/list", headers=HEADERS)
    print(f"Status Code: {response.status_code}")
    print(f"Response: {response.json()}")
    body = response.json()
    assert response.status_code == 200
    assert "extensions" in body
def test_get_metrics():
    """Exercise GET /metrics (unauthenticated) and verify the metric keys."""
    print("\n--- Testing GET /metrics ---")
    response = requests.get(f"{BASE_URL}/metrics")
    print(f"Status Code: {response.status_code}")
    print(f"Response: {response.json()}")
    body = response.json()
    assert response.status_code == 200
    assert "active_sessions" in body
    assert "pending_requests" in body
if __name__ == "__main__":
    print("Starting API endpoint tests...")
    # Endpoints that need no session run first.
    test_get_provider_config()
    test_list_extensions()
    test_get_metrics()
    # Then the full session lifecycle: start -> reply -> summarize -> end.
    session_id = test_start_session()
    if not session_id:
        print("Skipping session tests as session_id was not obtained.")
    else:
        test_reply_session(session_id)
        test_summarize_session(session_id)
        test_end_session(session_id)
    print("\nAll tests completed.")

View File

@@ -18,13 +18,8 @@ use crate::agents::extension::Envs;
use crate::config::{Config, ExtensionConfigManager};
use crate::prompt_template;
use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait};
<<<<<<< HEAD
use mcp_client::transport::{PendingRequests, SseTransport, StdioTransport, Transport};
use mcp_core::{prompt::Prompt, Content, Tool, ToolCall, ToolError, ToolResult};
=======
use mcp_client::transport::{SseTransport, StdioTransport, Transport};
use mcp_core::{prompt::Prompt, Content, Tool, ToolCall, ToolError};
>>>>>>> 2f8f8e5767bb1fdc53dfaa4a492c9184f02c3721
use serde_json::Value;
// By default, we set it to Jan 1, 2020 if the resource does not have a timestamp
@@ -39,7 +34,6 @@ pub struct ExtensionManager {
clients: HashMap<String, McpClientBox>,
instructions: HashMap<String, String>,
resource_capable_extensions: HashSet<String>,
pending_requests: HashMap<String, Arc<PendingRequests>>, // track pending requests per extension
}
/// A flattened representation of a resource used by the agent to prepare inference
@@ -110,7 +104,6 @@ impl ExtensionManager {
clients: HashMap::new(),
instructions: HashMap::new(),
resource_capable_extensions: HashSet::new(),
pending_requests: HashMap::new(),
}
}
@@ -192,17 +185,6 @@ impl ExtensionManager {
let all_envs = merge_environments(envs, env_keys, &sanitized_name).await?;
let transport = SseTransport::new(uri, all_envs);
let handle = transport.start().await?;
<<<<<<< HEAD
let pending = handle.pending_requests();
let service = McpService::with_timeout(
handle,
Duration::from_secs(
timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT),
),
);
self.pending_requests.insert(sanitized_name.clone(), pending);
Box::new(McpClient::new(service))
=======
Box::new(
McpClient::connect(
handle,
@@ -212,7 +194,6 @@ impl ExtensionManager {
)
.await?,
)
>>>>>>> 2f8f8e5767bb1fdc53dfaa4a492c9184f02c3721
}
ExtensionConfig::Stdio {
cmd,
@@ -225,17 +206,6 @@ impl ExtensionManager {
let all_envs = merge_environments(envs, env_keys, &sanitized_name).await?;
let transport = StdioTransport::new(cmd, args.to_vec(), all_envs);
let handle = transport.start().await?;
<<<<<<< HEAD
let pending = handle.pending_requests();
let service = McpService::with_timeout(
handle,
Duration::from_secs(
timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT),
),
);
self.pending_requests.insert(sanitized_name.clone(), pending);
Box::new(McpClient::new(service))
=======
Box::new(
McpClient::connect(
handle,
@@ -245,7 +215,6 @@ impl ExtensionManager {
)
.await?,
)
>>>>>>> 2f8f8e5767bb1fdc53dfaa4a492c9184f02c3721
}
ExtensionConfig::Builtin {
name,
@@ -264,17 +233,6 @@ impl ExtensionManager {
HashMap::new(),
);
let handle = transport.start().await?;
<<<<<<< HEAD
let pending = handle.pending_requests();
let service = McpService::with_timeout(
handle,
Duration::from_secs(
timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT),
),
);
self.pending_requests.insert(sanitized_name.clone(), pending);
Box::new(McpClient::new(service))
=======
Box::new(
McpClient::connect(
handle,
@@ -284,7 +242,6 @@ impl ExtensionManager {
)
.await?,
)
>>>>>>> 2f8f8e5767bb1fdc53dfaa4a492c9184f02c3721
}
_ => unreachable!(),
};
@@ -336,19 +293,9 @@ impl ExtensionManager {
self.clients.remove(&sanitized_name);
self.instructions.remove(&sanitized_name);
self.resource_capable_extensions.remove(&sanitized_name);
self.pending_requests.remove(&sanitized_name);
Ok(())
}
/// Get the size of each extension's pending request map
pub async fn pending_request_sizes(&self) -> HashMap<String, usize> {
let mut result = HashMap::new();
for (name, pending) in &self.pending_requests {
result.insert(name.clone(), pending.len().await);
}
result
}
pub async fn suggest_disable_extensions_prompt(&self) -> Value {
let enabled_extensions_count = self.clients.len();

View File

@@ -187,7 +187,6 @@ mod tests {
use super::*;
use serde_json::json;
use std::collections::HashMap;
use serial_test::serial;
use tokio::sync::Mutex;
use tracing::dispatcher;
use wiremock::matchers::{method, path};
@@ -390,7 +389,6 @@ mod tests {
}
#[tokio::test]
#[serial]
async fn test_create_langfuse_observer() {
let fixture = TestFixture::new().await.with_mock_server().await;

View File

@@ -8,7 +8,6 @@ use goose::providers::{
};
use mcp_core::content::Content;
use mcp_core::tool::Tool;
use serial_test::serial;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::Mutex;
@@ -353,7 +352,6 @@ where
}
#[tokio::test]
#[serial]
async fn test_openai_provider() -> Result<()> {
test_provider(
"OpenAI",
@@ -365,7 +363,6 @@ async fn test_openai_provider() -> Result<()> {
}
#[tokio::test]
#[serial]
async fn test_azure_provider() -> Result<()> {
test_provider(
"Azure",
@@ -381,7 +378,6 @@ async fn test_azure_provider() -> Result<()> {
}
#[tokio::test]
#[serial]
async fn test_bedrock_provider_long_term_credentials() -> Result<()> {
test_provider(
"Bedrock",
@@ -393,7 +389,6 @@ async fn test_bedrock_provider_long_term_credentials() -> Result<()> {
}
#[tokio::test]
#[serial]
async fn test_bedrock_provider_aws_profile_credentials() -> Result<()> {
let env_mods = HashMap::from_iter([
// Ensure to unset long-term credentials to use AWS Profile provider
@@ -411,7 +406,6 @@ async fn test_bedrock_provider_aws_profile_credentials() -> Result<()> {
}
#[tokio::test]
#[serial]
async fn test_databricks_provider() -> Result<()> {
test_provider(
"Databricks",
@@ -423,7 +417,6 @@ async fn test_databricks_provider() -> Result<()> {
}
#[tokio::test]
#[serial]
async fn test_databricks_provider_oauth() -> Result<()> {
let mut env_mods = HashMap::new();
env_mods.insert("DATABRICKS_TOKEN", None);
@@ -438,7 +431,6 @@ async fn test_databricks_provider_oauth() -> Result<()> {
}
#[tokio::test]
#[serial]
async fn test_ollama_provider() -> Result<()> {
test_provider(
"Ollama",
@@ -450,13 +442,11 @@ async fn test_ollama_provider() -> Result<()> {
}
#[tokio::test]
#[serial]
async fn test_groq_provider() -> Result<()> {
test_provider("Groq", &["GROQ_API_KEY"], None, groq::GroqProvider::default).await
}
#[tokio::test]
#[serial]
async fn test_anthropic_provider() -> Result<()> {
test_provider(
"Anthropic",
@@ -468,7 +458,6 @@ async fn test_anthropic_provider() -> Result<()> {
}
#[tokio::test]
#[serial]
async fn test_openrouter_provider() -> Result<()> {
test_provider(
"OpenRouter",
@@ -480,7 +469,6 @@ async fn test_openrouter_provider() -> Result<()> {
}
#[tokio::test]
#[serial]
async fn test_google_provider() -> Result<()> {
test_provider(
"Google",

View File

@@ -1,5 +1,4 @@
use std::collections::HashMap;
use std::sync::atomic::{AtomicI32, Ordering};
use std::sync::Arc;
use tokio::process::{Child, ChildStderr, ChildStdin, ChildStdout, Command};
@@ -16,9 +15,6 @@ use nix::unistd::{getpgid, Pid};
use super::{serialize_and_send, Error, Transport, TransportHandle};
// Global to track process groups we've created
static PROCESS_GROUP: AtomicI32 = AtomicI32::new(-1);
/// A `StdioTransport` uses a child process's stdin/stdout as a communication channel.
///
/// It uses channels for message passing and handles responses asynchronously through a background task.
@@ -30,21 +26,21 @@ pub struct StdioActor {
stdin: Option<ChildStdin>,
stdout: Option<ChildStdout>,
stderr: Option<ChildStderr>,
#[cfg(unix)]
pgid: Option<i32>, // Process group ID for cleanup
}
impl Drop for StdioActor {
fn drop(&mut self) {
// Get the process group ID before attempting cleanup
#[cfg(unix)]
if let Some(pid) = self.process.id() {
if let Ok(pgid) = getpgid(Some(Pid::from_raw(pid as i32))) {
if let Some(pgid) = self.pgid {
// Send SIGTERM to the entire process group
let _ = kill(Pid::from_raw(-pgid.as_raw()), Signal::SIGTERM);
// Give processes a moment to cleanup
let _ = kill(Pid::from_raw(-pgid), Signal::SIGTERM);
// Note: std::thread::sleep is blocking, but this is a Drop impl.
// For graceful async shutdown, use the `close` method on `StdioTransport`.
std::thread::sleep(std::time::Duration::from_millis(100));
// Force kill if still running
let _ = kill(Pid::from_raw(-pgid.as_raw()), Signal::SIGKILL);
}
let _ = kill(Pid::from_raw(-pgid), Signal::SIGKILL);
}
}
}
@@ -155,7 +151,6 @@ pub struct StdioTransportHandle {
sender: mpsc::Sender<String>, // to process
receiver: Arc<Mutex<mpsc::Receiver<JsonRpcMessage>>>, // from process
error_receiver: Arc<Mutex<mpsc::Receiver<Error>>>,
pending_requests: Arc<PendingRequests>,
}
#[async_trait::async_trait]
@@ -184,10 +179,6 @@ impl StdioTransportHandle {
Err(_) => Ok(()),
}
}
pub fn pending_requests(&self) -> Arc<PendingRequests> {
Arc::clone(&self.pending_requests)
}
}
pub struct StdioTransport {
@@ -209,7 +200,7 @@ impl StdioTransport {
}
}
async fn spawn_process(&self) -> Result<(Child, ChildStdin, ChildStdout, ChildStderr), Error> {
async fn spawn_process(&self) -> Result<(Child, ChildStdin, ChildStdout, ChildStderr, Option<i32>), Error> {
let mut command = Command::new(&self.command);
command
.envs(&self.env)
@@ -246,16 +237,16 @@ impl StdioTransport {
.take()
.ok_or_else(|| Error::StdioProcessError("Failed to get stderr".into()))?;
let mut pgid = None;
// Store the process group ID for cleanup
#[cfg(unix)]
if let Some(pid) = process.id() {
// Use nix instead of unsafe libc calls
if let Ok(pgid) = getpgid(Some(Pid::from_raw(pid as i32))) {
PROCESS_GROUP.store(pgid.as_raw(), Ordering::SeqCst);
if let Ok(id) = getpgid(Some(Pid::from_raw(pid as i32))) {
pgid = Some(id.as_raw());
}
}
Ok((process, stdin, stdout, stderr))
Ok((process, stdin, stdout, stderr, pgid))
}
}
@@ -264,12 +255,11 @@ impl Transport for StdioTransport {
type Handle = StdioTransportHandle;
async fn start(&self) -> Result<Self::Handle, Error> {
let (process, stdin, stdout, stderr) = self.spawn_process().await?;
let (process, stdin, stdout, stderr, pgid) = self.spawn_process().await?;
let (outbox_tx, outbox_rx) = mpsc::channel(32);
let (inbox_tx, inbox_rx) = mpsc::channel(32);
let (error_tx, error_rx) = mpsc::channel(1);
let pending_requests = Arc::new(PendingRequests::new());
let actor = StdioActor {
receiver: Some(outbox_rx), // client to process
sender: Some(inbox_tx), // process to client
@@ -278,6 +268,8 @@ impl Transport for StdioTransport {
stdin: Some(stdin),
stdout: Some(stdout),
stderr: Some(stderr),
#[cfg(unix)]
pgid, // Pass the pgid to the actor
};
tokio::spawn(actor.run());
@@ -286,23 +278,13 @@ impl Transport for StdioTransport {
sender: outbox_tx, // client to process
receiver: Arc::new(Mutex::new(inbox_rx)), // process to client
error_receiver: Arc::new(Mutex::new(error_rx)),
pending_requests,
};
Ok(handle)
}
async fn close(&self) -> Result<(), Error> {
// Attempt to clean up the process group on close
#[cfg(unix)]
if let Some(pgid) = PROCESS_GROUP.load(Ordering::SeqCst).checked_abs() {
// Use nix instead of unsafe libc calls
// Try SIGTERM first
let _ = kill(Pid::from_raw(-pgid), Signal::SIGTERM);
// Give processes a moment to cleanup
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
// Force kill if still running
let _ = kill(Pid::from_raw(-pgid), Signal::SIGKILL);
}
// The StdioActor's Drop implementation handles process termination.
// This method can be a no-op for now.
Ok(())
}
}