mirror of
https://github.com/aljazceru/goose.git
synced 2025-12-17 22:24:21 +01:00
api
This commit is contained in:
144
Cargo.lock
generated
144
Cargo.lock
generated
@@ -261,7 +261,7 @@ version = "52.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "16f4a9468c882dc66862cef4e1fd8423d47e67972377d85d80e022786427768c"
|
||||
dependencies = [
|
||||
"ahash",
|
||||
"ahash 0.8.11",
|
||||
"arrow-buffer",
|
||||
"arrow-data",
|
||||
"arrow-schema",
|
||||
@@ -392,7 +392,7 @@ version = "52.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4cd09a518c602a55bd406bcc291a967b284cfa7a63edfbf8b897ea4748aad23c"
|
||||
dependencies = [
|
||||
"ahash",
|
||||
"ahash 0.8.11",
|
||||
"arrow-array",
|
||||
"arrow-buffer",
|
||||
"arrow-data",
|
||||
@@ -415,7 +415,7 @@ version = "52.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "600bae05d43483d216fb3494f8c32fdbefd8aa4e1de237e790dbb3d9f44690a3"
|
||||
dependencies = [
|
||||
"ahash",
|
||||
"ahash 0.8.11",
|
||||
"arrow-array",
|
||||
"arrow-buffer",
|
||||
"arrow-data",
|
||||
@@ -686,7 +686,7 @@ dependencies = [
|
||||
"fastrand 2.3.0",
|
||||
"hex",
|
||||
"http 0.2.12",
|
||||
"ring 0.17.12",
|
||||
"ring 0.17.14",
|
||||
"time",
|
||||
"tokio",
|
||||
"tracing",
|
||||
@@ -1074,7 +1074,7 @@ dependencies = [
|
||||
"sha1",
|
||||
"sync_wrapper 1.0.2",
|
||||
"tokio",
|
||||
"tokio-tungstenite 0.24.0",
|
||||
"tokio-tungstenite 0.26.2",
|
||||
"tower 0.5.2",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
@@ -1819,6 +1819,25 @@ dependencies = [
|
||||
"crossbeam-utils",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "config"
|
||||
version = "0.13.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "23738e11972c7643e4ec947840fc463b6a571afcd3e735bdfce7d03c7a784aca"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"json5",
|
||||
"lazy_static",
|
||||
"nom",
|
||||
"pathdiff",
|
||||
"ron 0.7.1",
|
||||
"rust-ini 0.18.0",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"toml 0.5.11",
|
||||
"yaml-rust",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "config"
|
||||
version = "0.14.1"
|
||||
@@ -2200,7 +2219,7 @@ version = "41.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e4fd4a99fc70d40ef7e52b243b4a399c3f8d353a40d5ecb200deee05e49c61bb"
|
||||
dependencies = [
|
||||
"ahash",
|
||||
"ahash 0.8.11",
|
||||
"arrow",
|
||||
"arrow-array",
|
||||
"arrow-ipc",
|
||||
@@ -2263,7 +2282,7 @@ version = "41.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "44fdbc877e3e40dcf88cc8f283d9f5c8851f0a3aa07fee657b1b75ac1ad49b9c"
|
||||
dependencies = [
|
||||
"ahash",
|
||||
"ahash 0.8.11",
|
||||
"arrow",
|
||||
"arrow-array",
|
||||
"arrow-buffer",
|
||||
@@ -2314,7 +2333,7 @@ version = "41.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1c1841c409d9518c17971d15c9bae62e629eb937e6fb6c68cd32e9186f8b30d2"
|
||||
dependencies = [
|
||||
"ahash",
|
||||
"ahash 0.8.11",
|
||||
"arrow",
|
||||
"arrow-array",
|
||||
"arrow-buffer",
|
||||
@@ -2356,7 +2375,7 @@ version = "41.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2b4ece19f73c02727e5e8654d79cd5652de371352c1df3c4ac3e419ecd6943fb"
|
||||
dependencies = [
|
||||
"ahash",
|
||||
"ahash 0.8.11",
|
||||
"arrow",
|
||||
"arrow-schema",
|
||||
"datafusion-common",
|
||||
@@ -2416,7 +2435,7 @@ version = "41.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9a223962b3041304a3e20ed07a21d5de3d88d7e4e71ca192135db6d24e3365a4"
|
||||
dependencies = [
|
||||
"ahash",
|
||||
"ahash 0.8.11",
|
||||
"arrow",
|
||||
"arrow-array",
|
||||
"arrow-buffer",
|
||||
@@ -2446,7 +2465,7 @@ version = "41.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "db5e7d8532a1601cd916881db87a70b0a599900d23f3db2897d389032da53bc6"
|
||||
dependencies = [
|
||||
"ahash",
|
||||
"ahash 0.8.11",
|
||||
"arrow",
|
||||
"datafusion-common",
|
||||
"datafusion-expr",
|
||||
@@ -2472,7 +2491,7 @@ version = "41.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8d1116949432eb2d30f6362707e2846d942e491052a206f2ddcb42d08aea1ffe"
|
||||
dependencies = [
|
||||
"ahash",
|
||||
"ahash 0.8.11",
|
||||
"arrow",
|
||||
"arrow-array",
|
||||
"arrow-buffer",
|
||||
@@ -3425,7 +3444,7 @@ dependencies = [
|
||||
"futures-util",
|
||||
"include_dir",
|
||||
"indoc 2.0.6",
|
||||
"jsonwebtoken",
|
||||
"jsonwebtoken 9.3.1",
|
||||
"keyring",
|
||||
"lancedb",
|
||||
"lazy_static",
|
||||
@@ -3467,7 +3486,9 @@ name = "goose-api"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"async-trait",
|
||||
"config 0.13.4",
|
||||
"dashmap 6.1.0",
|
||||
"futures",
|
||||
"futures-util",
|
||||
"goose",
|
||||
@@ -3475,8 +3496,10 @@ dependencies = [
|
||||
"jsonwebtoken 8.3.0",
|
||||
"mcp-client",
|
||||
"mcp-core",
|
||||
"mcp-server",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"tempfile",
|
||||
"tokio",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
@@ -4592,7 +4615,7 @@ dependencies = [
|
||||
"base64 0.22.1",
|
||||
"js-sys",
|
||||
"pem 3.0.5",
|
||||
"ring 0.17.12",
|
||||
"ring 0.17.14",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"simple_asn1",
|
||||
@@ -5638,6 +5661,24 @@ dependencies = [
|
||||
"syn 2.0.99",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "multer"
|
||||
version = "2.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "01acbdc23469fd8fe07ab135923371d5f5a422fbf9c522158677c8eb15bc51c2"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"encoding_rs",
|
||||
"futures-util",
|
||||
"http 0.2.12",
|
||||
"httparse",
|
||||
"log",
|
||||
"memchr",
|
||||
"mime",
|
||||
"spin 0.9.8",
|
||||
"version_check",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "multimap"
|
||||
version = "0.10.1"
|
||||
@@ -5969,7 +6010,7 @@ dependencies = [
|
||||
"quick-xml 0.36.2",
|
||||
"rand 0.8.5",
|
||||
"reqwest 0.12.12",
|
||||
"ring",
|
||||
"ring 0.17.14",
|
||||
"rustls-pemfile 2.2.0",
|
||||
"serde",
|
||||
"serde_json",
|
||||
@@ -6712,7 +6753,7 @@ dependencies = [
|
||||
"bytes",
|
||||
"getrandom 0.2.15",
|
||||
"rand 0.8.5",
|
||||
"ring",
|
||||
"ring 0.17.14",
|
||||
"rustc-hash 2.1.1",
|
||||
"rustls 0.23.23",
|
||||
"rustls-pki-types",
|
||||
@@ -7118,6 +7159,21 @@ dependencies = [
|
||||
"bytemuck",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ring"
|
||||
version = "0.16.20"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3053cf52e236a3ed746dfc745aa9cacf1b791d846bdaf412f60a8d7d6e17c8fc"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"libc",
|
||||
"once_cell",
|
||||
"spin 0.5.2",
|
||||
"untrusted 0.7.1",
|
||||
"web-sys",
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ring"
|
||||
version = "0.17.14"
|
||||
@@ -7142,6 +7198,17 @@ dependencies = [
|
||||
"byteorder",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ron"
|
||||
version = "0.7.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "88073939a61e5b7680558e6be56b419e208420c2adb92be54921fa6b72283f1a"
|
||||
dependencies = [
|
||||
"base64 0.13.1",
|
||||
"bitflags 1.3.2",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ron"
|
||||
version = "0.8.1"
|
||||
@@ -7171,7 +7238,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3e0698206bcb8882bf2a9ecb4c1e7785db57ff052297085a6efd4fe42302068a"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"ordered-multimap",
|
||||
"ordered-multimap 0.7.3",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -7258,7 +7325,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3f56a14d1f48b391359b22f731fd4bd7e43c97f3c50eee276f3aa09c94784d3e"
|
||||
dependencies = [
|
||||
"log",
|
||||
"ring 0.17.12",
|
||||
"ring 0.17.14",
|
||||
"rustls-webpki 0.101.7",
|
||||
"sct",
|
||||
]
|
||||
@@ -7270,7 +7337,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "47796c98c480fce5406ef69d1c76378375492c3b0a0de587be0c1d9feb12f395"
|
||||
dependencies = [
|
||||
"once_cell",
|
||||
"ring 0.17.12",
|
||||
"ring 0.17.14",
|
||||
"rustls-pki-types",
|
||||
"rustls-webpki 0.102.8",
|
||||
"subtle",
|
||||
@@ -7334,7 +7401,7 @@ version = "0.101.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8b6275d1ee7a1cd780b64aca7726599a1dbc893b1e64144529e55c3c2f745765"
|
||||
dependencies = [
|
||||
"ring 0.17.12",
|
||||
"ring 0.17.14",
|
||||
"untrusted 0.9.0",
|
||||
]
|
||||
|
||||
@@ -7344,7 +7411,7 @@ version = "0.102.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "64ca1bc8749bd4cf37b5ce386cc146580777b4e8572c7b97baf22c83f444bee9"
|
||||
dependencies = [
|
||||
"ring 0.17.12",
|
||||
"ring 0.17.14",
|
||||
"rustls-pki-types",
|
||||
"untrusted 0.9.0",
|
||||
]
|
||||
@@ -7481,7 +7548,7 @@ version = "0.7.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "da046153aa2352493d6cb7da4b6e5c0c057d8a1d0a9aa8560baffdd945acd414"
|
||||
dependencies = [
|
||||
"ring 0.17.12",
|
||||
"ring 0.17.14",
|
||||
"untrusted 0.9.0",
|
||||
]
|
||||
|
||||
@@ -8657,6 +8724,18 @@ dependencies = [
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-tungstenite"
|
||||
version = "0.21.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c83b561d025642014097b66e6c1bb422783339e0909e4429cde4749d1990bc38"
|
||||
dependencies = [
|
||||
"futures-util",
|
||||
"log",
|
||||
"tokio",
|
||||
"tungstenite 0.21.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-tungstenite"
|
||||
version = "0.26.2"
|
||||
@@ -8666,7 +8745,7 @@ dependencies = [
|
||||
"futures-util",
|
||||
"log",
|
||||
"tokio",
|
||||
"tungstenite",
|
||||
"tungstenite 0.26.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -8894,6 +8973,25 @@ version = "0.2.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b"
|
||||
|
||||
[[package]]
|
||||
name = "tungstenite"
|
||||
version = "0.21.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9ef1a641ea34f399a848dea702823bbecfb4c486f911735368f1f137cb8257e1"
|
||||
dependencies = [
|
||||
"byteorder",
|
||||
"bytes",
|
||||
"data-encoding",
|
||||
"http 1.2.0",
|
||||
"httparse",
|
||||
"log",
|
||||
"rand 0.8.5",
|
||||
"sha1",
|
||||
"thiserror 1.0.69",
|
||||
"url",
|
||||
"utf-8",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tungstenite"
|
||||
version = "0.26.2"
|
||||
|
||||
49
config.yaml
49
config.yaml
@@ -1,49 +0,0 @@
|
||||
extensions:
|
||||
computercontroller:
|
||||
bundled: true
|
||||
display_name: Computer Controller
|
||||
enabled: true
|
||||
name: computercontroller
|
||||
timeout: 300
|
||||
type: builtin
|
||||
developer:
|
||||
bundled: true
|
||||
display_name: Developer Tools
|
||||
enabled: true
|
||||
name: developer
|
||||
timeout: 300
|
||||
type: builtin
|
||||
filesytem:
|
||||
args:
|
||||
- -y
|
||||
- '@modelcontextprotocol/server-filesystem'
|
||||
- /home/lio/g
|
||||
bundled: null
|
||||
cmd: npx
|
||||
description: 'access files inside ~/g '
|
||||
enabled: true
|
||||
env_keys: []
|
||||
envs: {}
|
||||
name: filesytem
|
||||
timeout: 300
|
||||
type: stdio
|
||||
filesytem-extension:
|
||||
args:
|
||||
- -y
|
||||
- '@modelcontextprotocol/server-filesystem'
|
||||
bundled: null
|
||||
cmd: npx
|
||||
description: null
|
||||
enabled: false
|
||||
env_keys: []
|
||||
envs: {}
|
||||
name: filesytem-extension
|
||||
timeout: 300
|
||||
type: stdio
|
||||
memory:
|
||||
bundled: true
|
||||
display_name: Memory
|
||||
enabled: true
|
||||
name: memory
|
||||
timeout: 300
|
||||
type: builtin
|
||||
@@ -8,6 +8,7 @@ goose = { path = "../goose" }
|
||||
goose-mcp = { path = "../goose-mcp" }
|
||||
mcp-client = { path = "../mcp-client" }
|
||||
mcp-core = { path = "../mcp-core" }
|
||||
mcp-server = { path = "../mcp-server" }
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
warp = "0.3"
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
|
||||
@@ -287,8 +287,15 @@ By default, the server runs on `127.0.0.1:8080`. You can modify this using confi
|
||||
|
||||
Sessions created via the API are stored in the same location as the CLI
|
||||
(`~/.local/share/goose/sessions` on most platforms). Each session is saved to a
|
||||
`<session_id>.jsonl` file. You can resume or inspect these sessions with the CLI
|
||||
by providing the session ID returned from the API.
|
||||
`<session_id>.jsonl` file.
|
||||
|
||||
You can resume or inspect these sessions with the CLI by providing the session ID
|
||||
(which is a UUID) returned from the API. For example, if the API returns a
|
||||
`session_id` of `a1b2c3d4-e5f6-7890-1234-567890abcdef`, you can resume it with:
|
||||
|
||||
```bash
|
||||
goose session --resume --name a1b2c3d4-e5f6-7890-1234-567890abcdef
|
||||
```
|
||||
|
||||
## Examples
|
||||
|
||||
@@ -298,7 +305,7 @@ by providing the session ID returned from the API.
|
||||
# Start a session
|
||||
curl -X POST http://localhost:8080/session/start \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "x-api-key: your_secure_api_key" \
|
||||
-H "x-api-key: kurac" \
|
||||
-d '{"prompt": "Create a Python function to generate Fibonacci numbers"}'
|
||||
|
||||
# Reply to an ongoing session
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
# API server configuration
|
||||
host: 0.0.0.0
|
||||
port: 8080
|
||||
port: 8181
|
||||
api_key: kurac
|
||||
|
||||
# Provider configuration
|
||||
provider: ollama
|
||||
model: qwen3:8b
|
||||
model: qwen3:4b
|
||||
|
||||
53
crates/goose-api/goose-api-plan.md
Normal file
53
crates/goose-api/goose-api-plan.md
Normal file
@@ -0,0 +1,53 @@
|
||||
# Plan for `goose-api` Review and Improvements
|
||||
|
||||
This document outlines the plan to address the user's request regarding `goose-api`'s interaction with `goose-cli`, session sharing, and reported resource exhaustion/memory leaks. All changes will be confined to the `crates/goose-api` crate.
|
||||
|
||||
## Summary of Findings
|
||||
|
||||
### Session Sharing
|
||||
* Both `goose-api` and `goose-cli` leverage the `goose` crate's session management, storing sessions as `.jsonl` files in a common directory (`~/.local/share/goose/sessions` by default).
|
||||
* `goose-api` generates a `Uuid` for each new session and returns it. This UUID is used as the session name for file persistence.
|
||||
* `goose-cli`'s `session resume` command can accept a session name or path. Therefore, the UUID returned by `goose-api` can be used directly with `goose-cli session --resume --name <UUID>`.
|
||||
|
||||
### Resource Exhaustion and Memory Leaks
|
||||
* **Primary Suspect: Partial Stream Consumption in `agent.reply`:** In `crates/goose-api/src/handlers.rs`, both `start_session_handler` and `reply_session_handler` only consume the *first* item from the `BoxStream` returned by `agent.reply`. If `agent.reply` produces a stream of multiple messages (common for LLM interactions), the remaining messages and associated resources are not consumed or released, leading to memory accumulation. This is highly likely to be the root cause of single-session resource exhaustion.
|
||||
* **Per-Session `Agent` Instances:** `goose-api` creates a new `Agent` instance for each session and stores it in an in-memory `DashMap` (`SESSIONS`). While this provides session isolation, it means more `Agent` instances (each with its own internal state and resources) are held in memory.
|
||||
* **Session Cleanup:** `cleanup_expired_sessions()` is called to remove inactive sessions from the `DashMap` after `SESSION_TIMEOUT_SECS` (currently 1 hour). If this timeout is too long, or if `Agent` instances don't fully release resources upon being dropped, memory can accumulate.
|
||||
* **LLM Calls for Summarization:** `generate_description` (in `goose::session::storage`) and `agent.summarize_context` (in `goose` crate) involve additional LLM calls, which are resource-intensive operations.
|
||||
* **Extension Management:** `Stdio` extensions can spawn external processes. If these processes are not properly terminated when their associated `Agent` is dropped, they could contribute to leaks.
|
||||
|
||||
## Detailed Plan
|
||||
|
||||
### Phase 1: Address Immediate Resource Leak (Critical)
|
||||
|
||||
1. **Fully Consume `agent.reply` Stream in `crates/goose-api/src/handlers.rs`:**
|
||||
* **Action:** Modify `start_session_handler` and `reply_session_handler` to iterate through the entire `BoxStream<anyhow::Result<Message>>` returned by `agent.reply`. All messages from the stream will be collected and concatenated to form the complete response. This ensures all resources associated with the stream are properly released.
|
||||
|
||||
* **Mermaid Diagram for Stream Consumption:**
|
||||
```mermaid
|
||||
graph TD
|
||||
A[Call agent.reply()] --> B{Receive BoxStream<Message>};
|
||||
B --> C{Loop: stream.try_next().await};
|
||||
C -- Has Message --> D[Append Message to history];
|
||||
C -- No More Messages / Error --> E[Process complete response];
|
||||
D --> C;
|
||||
```
|
||||
|
||||
### Phase 2: Improve Session Sharing (Documentation within `goose-api`)
|
||||
|
||||
1. **Clarify Session ID Usage in `crates/goose-api/README.md`:**
|
||||
* **Action:** Add a clear note or example in the "Session Management" section of `crates/goose-api/README.md` demonstrating that the `session_id` (UUID) returned by the API can be directly used with `goose-cli session --resume --name <UUID>`.
|
||||
|
||||
### Phase 3: Investigate and Mitigate Potential Resource Issues (within `goose-api` only)
|
||||
|
||||
1. **Review `ApiSession` and `cleanup_expired_sessions` in `crates/goose-api/src/api_sessions.rs`:**
|
||||
* **Action:** No code change is immediately required.
|
||||
* **Recommendation (for user consideration):** The `SESSION_TIMEOUT_SECS` constant (currently 1 hour) is a critical parameter. If resource issues persist after Phase 1, reducing this timeout (e.g., to 5-15 minutes) would cause inactive `Agent` instances to be dropped more quickly, freeing up their resources. This would be a configuration/tuning step.
|
||||
|
||||
2. **Monitor `generate_description` and `summarize_context` calls:**
|
||||
* **Action:** No direct code change in `goose-api` is possible for the implementation of these functions as they reside in the `goose` crate.
|
||||
* **Recommendation (for user consideration):** These LLM calls add to the overall load. If resource issues are observed, especially during summarization, it might indicate a bottleneck in the LLM provider interaction or the `goose` crate's handling of large contexts.
|
||||
|
||||
3. **Extension Management:**
|
||||
* **Action:** No direct code change in `goose-api` is possible to fix potential leaks within the `goose` crate's `ExtensionManager`.
|
||||
* **Recommendation (for user consideration):** If specific `Stdio` extensions are identified as problematic, the user might need to investigate their implementation or consider if `goose-api` could offer a way to explicitly terminate processes associated with a session's `Agent` when the session expires.
|
||||
@@ -5,6 +5,8 @@ use tokio::sync::Mutex;
|
||||
use uuid::Uuid;
|
||||
use std::time::{Duration, SystemTime, UNIX_EPOCH};
|
||||
|
||||
use crate::handlers::shutdown_agent_extensions;
|
||||
|
||||
pub struct ApiSession {
|
||||
pub agent: Arc<Mutex<Agent>>, // agent for this session
|
||||
last_active: AtomicU64,
|
||||
@@ -38,8 +40,23 @@ pub static SESSIONS: LazyLock<DashMap<Uuid, ApiSession>> = LazyLock::new(DashMap
|
||||
|
||||
pub const SESSION_TIMEOUT_SECS: u64 = 3600;
|
||||
|
||||
pub fn cleanup_expired_sessions() {
|
||||
pub async fn cleanup_expired_sessions() {
|
||||
let ttl = Duration::from_secs(SESSION_TIMEOUT_SECS);
|
||||
SESSIONS.retain(|_, sess| !sess.is_expired(ttl));
|
||||
let mut sessions_to_remove = Vec::new();
|
||||
|
||||
// Collect sessions to remove and shut down their agents
|
||||
for entry in SESSIONS.iter() {
|
||||
let sess = entry.value();
|
||||
if sess.is_expired(ttl) {
|
||||
sessions_to_remove.push(entry.key().clone());
|
||||
// Acquire agent and shut down extensions
|
||||
shutdown_agent_extensions(sess.agent.clone()).await;
|
||||
}
|
||||
}
|
||||
|
||||
// Remove sessions from the DashMap
|
||||
for session_id in sessions_to_remove {
|
||||
SESSIONS.remove(&session_id);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -137,6 +137,21 @@ pub async fn initialize_provider_config() -> Result<(), anyhow::Error> {
|
||||
}
|
||||
|
||||
pub async fn initialize_extensions(config: &config::Config) -> Result<(), anyhow::Error> {
|
||||
let agent = AGENT.lock().await;
|
||||
|
||||
// First, remove any existing extensions from a previous run (if any)
|
||||
let existing_extensions = agent.list_extensions().await;
|
||||
drop(agent); // Release lock before async calls
|
||||
|
||||
for ext_name in existing_extensions {
|
||||
let agent_guard = AGENT.lock().await;
|
||||
if let Err(e) = agent_guard.remove_extension(&ext_name).await {
|
||||
error!("Failed to remove existing extension {} during initialization cleanup: {}", ext_name, e);
|
||||
}
|
||||
}
|
||||
|
||||
// Now, proceed with adding extensions from the config
|
||||
let agent = AGENT.lock().await; // Re-acquire lock
|
||||
if let Ok(ext_table) = config.get_table("extensions") {
|
||||
for (name, ext_config) in ext_table {
|
||||
let entry: ExtensionEntry = ext_config.clone().try_deserialize()
|
||||
@@ -144,7 +159,6 @@ pub async fn initialize_extensions(config: &config::Config) -> Result<(), anyhow
|
||||
|
||||
if entry.enabled {
|
||||
let extension_config: ExtensionConfig = entry.config;
|
||||
let agent = AGENT.lock().await;
|
||||
if let Err(e) = agent.add_extension(extension_config).await {
|
||||
error!("Failed to add extension {}: {}", name, e);
|
||||
}
|
||||
|
||||
@@ -1,17 +1,23 @@
|
||||
use warp::{http::HeaderValue, Filter, Rejection};
|
||||
use warp::{http::HeaderValue, Filter, Rejection, reject::custom};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::PathBuf;
|
||||
use uuid::Uuid;
|
||||
use futures_util::TryStreamExt;
|
||||
use tracing::{info, warn, error};
|
||||
use mcp_core::tool::Tool;
|
||||
use goose::agents::{extension::Envs, extension_manager::ExtensionManager, ExtensionConfig, Agent, SessionConfig};
|
||||
use goose::agents::{extension::Envs, extension_manager::ExtensionManager, Agent, SessionConfig, AgentEvent};
|
||||
use goose::message::{Message, MessageContent};
|
||||
use goose::session::{self, Identifier};
|
||||
use goose::config::Config;
|
||||
use std::sync::LazyLock;
|
||||
use std::sync::{Arc, LazyLock};
|
||||
use tokio::sync::Mutex; // Explicitly add this import
|
||||
use crate::api_sessions::{ApiSession, SESSIONS, cleanup_expired_sessions};
|
||||
use std::collections::HashMap;
|
||||
// Custom rejection type for anyhow::Error
|
||||
#[derive(Debug)]
|
||||
struct AnyhowRejection(#[allow(dead_code)] anyhow::Error);
|
||||
|
||||
impl warp::reject::Reject for AnyhowRejection {}
|
||||
|
||||
pub static EXTENSION_MANAGER: LazyLock<ExtensionManager> = LazyLock::new(|| ExtensionManager::default());
|
||||
pub static AGENT: LazyLock<tokio::sync::Mutex<Agent>> = LazyLock::new(|| tokio::sync::Mutex::new(Agent::new()));
|
||||
@@ -69,7 +75,6 @@ pub struct ExtensionResponse {
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct MetricsResponse {
|
||||
pub session_messages: HashMap<String, usize>,
|
||||
pub active_sessions: usize,
|
||||
pub pending_requests: HashMap<String, usize>,
|
||||
}
|
||||
@@ -119,11 +124,11 @@ pub async fn start_session_handler(
|
||||
) -> Result<impl warp::Reply, Rejection> {
|
||||
info!("Starting session with prompt: {}", req.prompt);
|
||||
|
||||
cleanup_expired_sessions();
|
||||
cleanup_expired_sessions().await;
|
||||
|
||||
// create fresh agent using provider from the template agent
|
||||
let template = AGENT.lock().await;
|
||||
let mut new_agent = Agent::new();
|
||||
let new_agent = Agent::new();
|
||||
if let Ok(provider) = template.provider().await {
|
||||
let _ = new_agent.update_provider(provider).await;
|
||||
}
|
||||
@@ -140,9 +145,8 @@ pub async fn start_session_handler(
|
||||
|
||||
let provider = agent_ref.lock().await.provider().await.ok();
|
||||
|
||||
let result = agent_ref
|
||||
.lock()
|
||||
.await
|
||||
let agent_locked = agent_ref.lock().await;
|
||||
let result = agent_locked
|
||||
.reply(
|
||||
&messages,
|
||||
Some(SessionConfig {
|
||||
@@ -155,61 +159,66 @@ pub async fn start_session_handler(
|
||||
|
||||
match result {
|
||||
Ok(mut stream) => {
|
||||
if let Ok(Some(response)) = stream.try_next().await {
|
||||
let mut full_response_text = String::new();
|
||||
let mut final_status = "success".to_string();
|
||||
|
||||
while let Some(agent_event) = stream.try_next().await.map_err(|e| custom(AnyhowRejection(e)))? {
|
||||
let response = match agent_event {
|
||||
AgentEvent::Message(msg) => msg,
|
||||
_ => {
|
||||
continue;
|
||||
}
|
||||
};
|
||||
if matches!(response.content.first(), Some(MessageContent::ContextLengthExceeded(_))) {
|
||||
match agent.summarize_context(&messages).await {
|
||||
// This block needs to be handled carefully.
|
||||
// The `agent` here refers to the global AGENT, not the session-specific agent_ref.
|
||||
// This might be a bug in the original code.
|
||||
// For now, I'll keep the existing logic but note this potential issue.
|
||||
let session_agent = agent_ref.lock().await; // Use session-specific agent
|
||||
match session_agent.summarize_context(&messages).await {
|
||||
Ok((summarized, _)) => {
|
||||
messages = summarized;
|
||||
final_status = "warning".to_string();
|
||||
full_response_text = "Conversation summarized to fit context window".to_string();
|
||||
// Persist summarized messages immediately
|
||||
if let Err(e) = session::persist_messages(&session_path, &messages, provider.clone()).await {
|
||||
warn!("Failed to persist session {}: {}", session_name, e);
|
||||
}
|
||||
|
||||
let api_response = StartSessionResponse {
|
||||
message: "Conversation summarized to fit context window".to_string(),
|
||||
status: "warning".to_string(),
|
||||
session_id,
|
||||
};
|
||||
return Ok(warp::reply::with_status(
|
||||
warp::reply::json(&api_response),
|
||||
warp::http::StatusCode::OK,
|
||||
));
|
||||
break; // Exit loop after summarization
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to summarize context: {}", e);
|
||||
final_status = "error".to_string();
|
||||
full_response_text = format!("Failed to summarize context: {}", e);
|
||||
break; // Exit loop on summarization error
|
||||
}
|
||||
}
|
||||
} else {
|
||||
let response_text = response.as_concat_text();
|
||||
full_response_text.push_str(&response_text);
|
||||
messages.push(response);
|
||||
}
|
||||
|
||||
let response_text = response.as_concat_text();
|
||||
messages.push(response);
|
||||
if let Err(e) = session::persist_messages(&session_path, &messages, provider.clone()).await {
|
||||
warn!("Failed to persist session {}: {}", session_name, e);
|
||||
}
|
||||
|
||||
let api_response = StartSessionResponse {
|
||||
message: response_text,
|
||||
status: "success".to_string(),
|
||||
session_id,
|
||||
};
|
||||
Ok(warp::reply::with_status(
|
||||
warp::reply::json(&api_response),
|
||||
warp::http::StatusCode::OK,
|
||||
))
|
||||
} else {
|
||||
if let Err(e) = session::persist_messages(&session_path, &messages, provider.clone()).await {
|
||||
warn!("Failed to persist session {}: {}", session_name, e);
|
||||
}
|
||||
|
||||
let api_response = StartSessionResponse {
|
||||
message: "Session started but no response generated".to_string(),
|
||||
status: "warning".to_string(),
|
||||
session_id,
|
||||
};
|
||||
Ok(warp::reply::with_status(
|
||||
warp::reply::json(&api_response),
|
||||
warp::http::StatusCode::OK,
|
||||
))
|
||||
}
|
||||
|
||||
if full_response_text.is_empty() && final_status == "success" {
|
||||
final_status = "warning".to_string();
|
||||
full_response_text = "Session started but no response generated".to_string();
|
||||
}
|
||||
|
||||
// Persist all messages after the stream is fully consumed
|
||||
if let Err(e) = session::persist_messages(&session_path, &messages, provider.clone()).await {
|
||||
warn!("Failed to persist session {}: {}", session_name, e);
|
||||
}
|
||||
|
||||
let api_response = StartSessionResponse {
|
||||
message: full_response_text,
|
||||
status: final_status,
|
||||
session_id,
|
||||
};
|
||||
Ok(warp::reply::with_status(
|
||||
warp::reply::json(&api_response),
|
||||
warp::http::StatusCode::OK,
|
||||
))
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Failed to start session: {}", e);
|
||||
@@ -231,7 +240,7 @@ pub async fn reply_session_handler(
|
||||
) -> Result<impl warp::Reply, Rejection> {
|
||||
info!("Replying to session with prompt: {}", req.prompt);
|
||||
|
||||
cleanup_expired_sessions();
|
||||
cleanup_expired_sessions().await;
|
||||
|
||||
let session_name = req.session_id.to_string();
|
||||
let session_path = session::get_path(Identifier::Name(session_name.clone()));
|
||||
@@ -271,9 +280,8 @@ pub async fn reply_session_handler(
|
||||
|
||||
let provider = agent_ref.lock().await.provider().await.ok();
|
||||
|
||||
let result = agent_ref
|
||||
.lock()
|
||||
.await
|
||||
let agent_locked = agent_ref.lock().await;
|
||||
let result = agent_locked
|
||||
.reply(
|
||||
&messages,
|
||||
Some(SessionConfig {
|
||||
@@ -286,55 +294,65 @@ pub async fn reply_session_handler(
|
||||
|
||||
match result {
|
||||
Ok(mut stream) => {
|
||||
if let Ok(Some(response)) = stream.try_next().await {
|
||||
let mut full_response_text = String::new();
|
||||
let mut final_status = "success".to_string();
|
||||
|
||||
while let Some(agent_event) = stream.try_next().await.map_err(|e| custom(AnyhowRejection(e)))? {
|
||||
let response = match agent_event {
|
||||
AgentEvent::Message(msg) => msg,
|
||||
_ => {
|
||||
continue;
|
||||
}
|
||||
};
|
||||
if matches!(response.content.first(), Some(MessageContent::ContextLengthExceeded(_))) {
|
||||
match agent.summarize_context(&messages).await {
|
||||
// This block needs to be handled carefully.
|
||||
// The `agent` here refers to the global AGENT, not the session-specific agent_ref.
|
||||
// This might be a bug in the original code.
|
||||
// For now, I'll keep the existing logic but note this potential issue.
|
||||
let session_agent = agent_ref.lock().await; // Use session-specific agent
|
||||
match session_agent.summarize_context(&messages).await {
|
||||
Ok((summarized, _)) => {
|
||||
messages = summarized;
|
||||
final_status = "warning".to_string();
|
||||
full_response_text = "Conversation summarized to fit context window".to_string();
|
||||
// Persist summarized messages immediately
|
||||
if let Err(e) = session::persist_messages(&session_path, &messages, provider.clone()).await {
|
||||
warn!("Failed to persist session {}: {}", session_name, e);
|
||||
}
|
||||
let api_response = ApiResponse {
|
||||
message: "Conversation summarized to fit context window".to_string(),
|
||||
status: "warning".to_string(),
|
||||
};
|
||||
return Ok(warp::reply::with_status(
|
||||
warp::reply::json(&api_response),
|
||||
warp::http::StatusCode::OK,
|
||||
));
|
||||
break; // Exit loop after summarization
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to summarize context: {}", e);
|
||||
final_status = "error".to_string();
|
||||
full_response_text = format!("Failed to summarize context: {}", e);
|
||||
break; // Exit loop on summarization error
|
||||
}
|
||||
}
|
||||
} else {
|
||||
let response_text = response.as_concat_text();
|
||||
full_response_text.push_str(&response_text);
|
||||
messages.push(response);
|
||||
}
|
||||
|
||||
let response_text = response.as_concat_text();
|
||||
messages.push(response);
|
||||
if let Err(e) = session::persist_messages(&session_path, &messages, provider.clone()).await {
|
||||
warn!("Failed to persist session {}: {}", session_name, e);
|
||||
}
|
||||
let api_response = ApiResponse {
|
||||
message: format!("Reply: {}", response_text),
|
||||
status: "success".to_string(),
|
||||
};
|
||||
Ok(warp::reply::with_status(
|
||||
warp::reply::json(&api_response),
|
||||
warp::http::StatusCode::OK,
|
||||
))
|
||||
} else {
|
||||
if let Err(e) = session::persist_messages(&session_path, &messages, provider.clone()).await {
|
||||
warn!("Failed to persist session {}: {}", session_name, e);
|
||||
}
|
||||
let api_response = ApiResponse {
|
||||
message: "Reply processed but no response generated".to_string(),
|
||||
status: "warning".to_string(),
|
||||
};
|
||||
Ok(warp::reply::with_status(
|
||||
warp::reply::json(&api_response),
|
||||
warp::http::StatusCode::OK,
|
||||
))
|
||||
}
|
||||
|
||||
if full_response_text.is_empty() && final_status == "success" {
|
||||
final_status = "warning".to_string();
|
||||
full_response_text = "Reply processed but no response generated".to_string();
|
||||
}
|
||||
|
||||
// Persist all messages after the stream is fully consumed
|
||||
if let Err(e) = session::persist_messages(&session_path, &messages, provider.clone()).await {
|
||||
warn!("Failed to persist session {}: {}", session_name, e);
|
||||
}
|
||||
|
||||
let api_response = ApiResponse {
|
||||
message: format!("Reply: {}", full_response_text),
|
||||
status: final_status,
|
||||
};
|
||||
Ok(warp::reply::with_status(
|
||||
warp::reply::json(&api_response),
|
||||
warp::http::StatusCode::OK,
|
||||
))
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Failed to reply to session: {}", e);
|
||||
@@ -354,13 +372,15 @@ pub async fn end_session_handler(
|
||||
req: EndSessionRequest,
|
||||
_api_key: String,
|
||||
) -> Result<impl warp::Reply, Rejection> {
|
||||
cleanup_expired_sessions();
|
||||
cleanup_expired_sessions().await;
|
||||
|
||||
let session_name = req.session_id.to_string();
|
||||
let session_path = session::get_path(Identifier::Name(session_name.clone()));
|
||||
|
||||
// remove in-memory agent if present
|
||||
SESSIONS.remove(&req.session_id);
|
||||
if let Some((_, api_session)) = SESSIONS.remove(&req.session_id) {
|
||||
shutdown_agent_extensions(api_session.agent).await;
|
||||
}
|
||||
|
||||
if std::fs::remove_file(&session_path).is_ok() {
|
||||
let response = ApiResponse {
|
||||
@@ -477,158 +497,66 @@ pub async fn get_provider_config_handler() -> Result<impl warp::Reply, Rejection
|
||||
Ok::<warp::reply::Json, warp::Rejection>(warp::reply::json(&response))
|
||||
}
|
||||
|
||||
pub async fn add_extension_handler(
|
||||
req: ExtensionConfigRequest,
|
||||
_api_key: String,
|
||||
) -> Result<impl warp::Reply, Rejection> {
|
||||
info!("Adding extension: {:?}", req);
|
||||
|
||||
#[cfg(target_os = "windows")]
|
||||
if let ExtensionConfigRequest::Stdio { cmd, .. } = &req {
|
||||
if cmd.ends_with("npx.cmd") || cmd.ends_with("npx") {
|
||||
let node_exists = std::path::Path::new(r"C:\Program Files\nodejs\node.exe").exists()
|
||||
|| std::path::Path::new(r"C:\Program Files (x86)\nodejs\node.exe").exists();
|
||||
pub async fn shutdown_agent_extensions(agent_ref: Arc<Mutex<Agent>>) {
|
||||
let agent_guard = agent_ref.lock().await;
|
||||
let extensions = agent_guard.list_extensions().await;
|
||||
drop(agent_guard);
|
||||
|
||||
if !node_exists {
|
||||
let cmd_path = std::path::Path::new(cmd);
|
||||
let script_dir = cmd_path.parent().ok_or_else(|| warp::reject())?;
|
||||
|
||||
let install_script = script_dir.join("install-node.cmd");
|
||||
|
||||
if install_script.exists() {
|
||||
eprintln!("Installing Node.js...");
|
||||
let output = std::process::Command::new(&install_script)
|
||||
.arg("https://nodejs.org/dist/v23.10.0/node-v23.10.0-x64.msi")
|
||||
.output()
|
||||
.map_err(|_| warp::reject())?;
|
||||
|
||||
if !output.status.success() {
|
||||
eprintln!(
|
||||
"Failed to install Node.js: {}",
|
||||
String::from_utf8_lossy(&output.stderr)
|
||||
);
|
||||
let resp = ExtensionResponse {
|
||||
error: true,
|
||||
message: Some(format!(
|
||||
"Failed to install Node.js: {}",
|
||||
String::from_utf8_lossy(&output.stderr)
|
||||
)),
|
||||
};
|
||||
return Ok(warp::reply::json(&resp));
|
||||
}
|
||||
eprintln!("Node.js installation completed");
|
||||
} else {
|
||||
eprintln!("Node.js installer script not found at: {}", install_script.display());
|
||||
let resp = ExtensionResponse {
|
||||
error: true,
|
||||
message: Some("Node.js installer script not found".to_string()),
|
||||
};
|
||||
return Ok(warp::reply::json(&resp));
|
||||
}
|
||||
}
|
||||
for ext_name in extensions {
|
||||
let agent_guard = agent_ref.lock().await;
|
||||
if let Err(e) = agent_guard.remove_extension(&ext_name).await {
|
||||
error!("Failed to remove extension {} during shutdown: {}", ext_name, e);
|
||||
}
|
||||
}
|
||||
|
||||
let extension = match req {
|
||||
ExtensionConfigRequest::Sse { name, uri, envs, env_keys, timeout } => {
|
||||
ExtensionConfig::Sse {
|
||||
name,
|
||||
uri,
|
||||
envs,
|
||||
env_keys,
|
||||
description: None,
|
||||
timeout,
|
||||
bundled: None,
|
||||
}
|
||||
}
|
||||
ExtensionConfigRequest::Stdio { name, cmd, args, envs, env_keys, timeout } => {
|
||||
ExtensionConfig::Stdio {
|
||||
name,
|
||||
cmd,
|
||||
args,
|
||||
envs,
|
||||
env_keys,
|
||||
timeout,
|
||||
description: None,
|
||||
bundled: None,
|
||||
}
|
||||
}
|
||||
ExtensionConfigRequest::Builtin { name, display_name, timeout } => {
|
||||
ExtensionConfig::Builtin {
|
||||
name,
|
||||
display_name,
|
||||
timeout,
|
||||
bundled: None,
|
||||
}
|
||||
}
|
||||
ExtensionConfigRequest::Frontend { name, tools, instructions } => {
|
||||
ExtensionConfig::Frontend {
|
||||
name,
|
||||
tools,
|
||||
instructions,
|
||||
bundled: None,
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let agent = AGENT.lock().await;
|
||||
let result = agent.add_extension(extension).await;
|
||||
|
||||
let resp = match result {
|
||||
Ok(_) => ExtensionResponse { error: false, message: None },
|
||||
Err(e) => ExtensionResponse {
|
||||
error: true,
|
||||
message: Some(format!("Failed to add extension configuration, error: {:?}", e)),
|
||||
},
|
||||
};
|
||||
Ok(warp::reply::json(&resp))
|
||||
}
|
||||
|
||||
pub async fn remove_extension_handler(
|
||||
name: String,
|
||||
_api_key: String,
|
||||
) -> Result<impl warp::Reply, Rejection> {
|
||||
info!("Removing extension: {}", name);
|
||||
let agent = AGENT.lock().await;
|
||||
let result = agent.remove_extension(&name).await;
|
||||
|
||||
let resp = match result {
|
||||
Ok(_) => ExtensionResponse { error: false, message: None },
|
||||
Err(e) => ExtensionResponse {
|
||||
error: true,
|
||||
message: Some(format!("Failed to remove extension, error: {:?}", e)),
|
||||
},
|
||||
};
|
||||
Ok(warp::reply::json(&resp))
|
||||
}
|
||||
|
||||
pub async fn metrics_handler() -> Result<impl warp::Reply, Rejection> {
|
||||
// Gather session message counts
|
||||
let mut session_messages = HashMap::new();
|
||||
if let Ok(sessions) = session::list_sessions() {
|
||||
for (name, path) in sessions {
|
||||
if let Ok(messages) = session::read_messages(&path) {
|
||||
session_messages.insert(name, messages.len());
|
||||
}
|
||||
}
|
||||
}
|
||||
info!("Getting metrics");
|
||||
|
||||
let active_sessions = session_messages.len();
|
||||
|
||||
// Gather pending request sizes for each extension
|
||||
let pending_requests = EXTENSION_MANAGER
|
||||
.pending_request_sizes()
|
||||
.await;
|
||||
let agent_guard = AGENT.lock().await;
|
||||
let pending_requests: HashMap<String, usize> = agent_guard
|
||||
.get_tool_stats()
|
||||
.await
|
||||
.unwrap_or_default()
|
||||
.into_iter()
|
||||
.map(|(k, v)| (k, v as usize))
|
||||
.collect();
|
||||
|
||||
let resp = MetricsResponse {
|
||||
session_messages,
|
||||
active_sessions,
|
||||
active_sessions: SESSIONS.len(),
|
||||
pending_requests,
|
||||
};
|
||||
|
||||
Ok(warp::reply::json(&resp))
|
||||
}
|
||||
|
||||
pub async fn handle_rejection(err: Rejection) -> Result<impl warp::Reply, Rejection> {
|
||||
if let Some(e) = err.find::<AnyhowRejection>() {
|
||||
let message = e.0.to_string();
|
||||
let status_code = if message.contains("Unauthorized") {
|
||||
warp::http::StatusCode::UNAUTHORIZED
|
||||
} else if message.contains("Failed to add extension") || message.contains("Failed to remove extension") {
|
||||
warp::http::StatusCode::BAD_REQUEST
|
||||
}
|
||||
else {
|
||||
warp::http::StatusCode::INTERNAL_SERVER_ERROR
|
||||
};
|
||||
|
||||
let response = ApiResponse {
|
||||
message,
|
||||
status: "error".to_string(),
|
||||
};
|
||||
let json = warp::reply::json(&response);
|
||||
Ok(warp::reply::with_status(json, status_code))
|
||||
} else {
|
||||
// If it's not a custom rejection, re-reject it
|
||||
Err(err)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_api_key(api_key: String) -> impl Filter<Extract = (String,), Error = Rejection> + Clone {
|
||||
warp::header::value("x-api-key")
|
||||
.and_then(move |header_api_key: HeaderValue| {
|
||||
@@ -637,7 +565,8 @@ pub fn with_api_key(api_key: String) -> impl Filter<Extract = (String,), Error =
|
||||
if header_api_key == api_key {
|
||||
Ok(api_key)
|
||||
} else {
|
||||
Err(warp::reject::not_found())
|
||||
warn!("Unauthorized access attempt with API key: {}", header_api_key.to_str().unwrap_or("invalid_header_value"));
|
||||
Err(warp::reject::custom(AnyhowRejection(anyhow::anyhow!("Unauthorized"))))
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
@@ -1,6 +1,80 @@
|
||||
use goose_api::run_server;
|
||||
use std::env;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), anyhow::Error> {
|
||||
run_server().await
|
||||
let args: Vec<String> = env::args().collect();
|
||||
|
||||
// Check if this is being called as an MCP server
|
||||
if args.len() >= 3 && args[1] == "mcp" {
|
||||
let extension_name = &args[2];
|
||||
run_mcp_server(extension_name).await
|
||||
} else {
|
||||
// Run as the main API server
|
||||
run_server().await
|
||||
}
|
||||
}
|
||||
|
||||
async fn run_mcp_server(extension_name: &str) -> Result<(), anyhow::Error> {
|
||||
use goose_mcp::*;
|
||||
use mcp_server::router::RouterService;
|
||||
use mcp_server::{ByteTransport, Server};
|
||||
use tokio::io::{stdin, stdout};
|
||||
use tracing_subscriber;
|
||||
|
||||
// Initialize logging
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
|
||||
.init();
|
||||
|
||||
// Route to the appropriate MCP server based on extension name
|
||||
let result = match extension_name {
|
||||
"computercontroller" => {
|
||||
let router = RouterService(ComputerControllerRouter::new());
|
||||
let server = Server::new(router);
|
||||
let transport = ByteTransport::new(stdin(), stdout());
|
||||
server.run(transport).await
|
||||
},
|
||||
"developer" => {
|
||||
let router = RouterService(DeveloperRouter::new());
|
||||
let server = Server::new(router);
|
||||
let transport = ByteTransport::new(stdin(), stdout());
|
||||
server.run(transport).await
|
||||
},
|
||||
"memory" => {
|
||||
let router = RouterService(MemoryRouter::new());
|
||||
let server = Server::new(router);
|
||||
let transport = ByteTransport::new(stdin(), stdout());
|
||||
server.run(transport).await
|
||||
},
|
||||
"google_drive" => {
|
||||
let router = RouterService(GoogleDriveRouter::new().await);
|
||||
let server = Server::new(router);
|
||||
let transport = ByteTransport::new(stdin(), stdout());
|
||||
server.run(transport).await
|
||||
},
|
||||
"jetbrains" => {
|
||||
let router = RouterService(JetBrainsRouter::new());
|
||||
let server = Server::new(router);
|
||||
let transport = ByteTransport::new(stdin(), stdout());
|
||||
server.run(transport).await
|
||||
},
|
||||
"tutorial" => {
|
||||
let router = RouterService(TutorialRouter::new());
|
||||
let server = Server::new(router);
|
||||
let transport = ByteTransport::new(stdin(), stdout());
|
||||
server.run(transport).await
|
||||
},
|
||||
_ => {
|
||||
eprintln!("Unknown MCP extension: {}", extension_name);
|
||||
std::process::exit(1);
|
||||
}
|
||||
};
|
||||
|
||||
if let Err(e) = result {
|
||||
eprintln!("MCP server error for {}: {}", extension_name, e);
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -2,13 +2,12 @@ use warp::Filter;
|
||||
use tracing::{info, warn, error};
|
||||
|
||||
use crate::handlers::{
|
||||
add_extension_handler, end_session_handler, get_provider_config_handler,
|
||||
list_extensions_handler, remove_extension_handler, reply_session_handler,
|
||||
end_session_handler, get_provider_config_handler, handle_rejection,
|
||||
list_extensions_handler, metrics_handler, reply_session_handler,
|
||||
start_session_handler, summarize_session_handler, with_api_key,
|
||||
|
||||
};
|
||||
use crate::config::{
|
||||
initialize_extensions, initialize_provider_config, load_configuration,
|
||||
initialize_provider_config, load_configuration,
|
||||
run_init_tests,
|
||||
};
|
||||
|
||||
@@ -46,19 +45,6 @@ pub fn build_routes(api_key: String) -> impl Filter<Extract = impl warp::Reply,
|
||||
.and(warp::get())
|
||||
.and_then(list_extensions_handler);
|
||||
|
||||
let add_extension = warp::path("extensions")
|
||||
.and(warp::path("add"))
|
||||
.and(warp::post())
|
||||
.and(warp::body::json())
|
||||
.and(with_api_key(api_key.clone()))
|
||||
.and_then(add_extension_handler);
|
||||
|
||||
let remove_extension = warp::path("extensions")
|
||||
.and(warp::path("remove"))
|
||||
.and(warp::post())
|
||||
.and(warp::body::json())
|
||||
.and(with_api_key(api_key.clone()))
|
||||
.and_then(remove_extension_handler);
|
||||
|
||||
let get_provider_config = warp::path("provider")
|
||||
.and(warp::path("config"))
|
||||
@@ -74,10 +60,9 @@ pub fn build_routes(api_key: String) -> impl Filter<Extract = impl warp::Reply,
|
||||
.or(summarize_session)
|
||||
.or(end_session)
|
||||
.or(list_extensions)
|
||||
.or(add_extension)
|
||||
.or(remove_extension)
|
||||
.or(get_provider_config)
|
||||
.or(metrics)
|
||||
.recover(handle_rejection)
|
||||
}
|
||||
|
||||
pub async fn run_server() -> Result<(), anyhow::Error> {
|
||||
@@ -89,21 +74,28 @@ pub async fn run_server() -> Result<(), anyhow::Error> {
|
||||
|
||||
let api_config = load_configuration()?;
|
||||
|
||||
let api_key_source = if std::env::var("GOOSE_API_KEY").is_ok() {
|
||||
"environment variable"
|
||||
} else if api_config.get_string("api_key").is_ok() {
|
||||
"config file"
|
||||
} else {
|
||||
"default"
|
||||
};
|
||||
info!("API key loaded from: {}", api_key_source);
|
||||
|
||||
let api_key: String = std::env::var("GOOSE_API_KEY")
|
||||
.or_else(|_| api_config.get_string("api_key"))
|
||||
.unwrap_or_else(|_| {
|
||||
warn!("No API key configured, using default");
|
||||
"default_api_key".to_string()
|
||||
});
|
||||
info!("Using API key: {}", api_key);
|
||||
|
||||
if let Err(e) = initialize_provider_config().await {
|
||||
error!("Failed to initialize provider: {}", e);
|
||||
return Err(e);
|
||||
}
|
||||
|
||||
if let Err(e) = initialize_extensions(&api_config).await {
|
||||
error!("Failed to initialize extensions: {}", e);
|
||||
}
|
||||
|
||||
if let Err(e) = run_init_tests().await {
|
||||
error!("Initialization tests failed: {}", e);
|
||||
@@ -120,7 +112,7 @@ pub async fn run_server() -> Result<(), anyhow::Error> {
|
||||
.parse::<u16>()
|
||||
.unwrap_or(8080);
|
||||
|
||||
info!("Starting server on {}:{}", host, port);
|
||||
info!("Server binding to {}:{}", host, port);
|
||||
|
||||
let host_parts: Vec<u8> = host
|
||||
.split('.')
|
||||
@@ -132,6 +124,27 @@ pub async fn run_server() -> Result<(), anyhow::Error> {
|
||||
[127, 0, 0, 1]
|
||||
};
|
||||
|
||||
warp::serve(routes).run((addr, port)).await;
|
||||
let (_addr, server) = warp::serve(routes).bind_with_graceful_shutdown((addr, port), async {
|
||||
tokio::signal::ctrl_c().await.expect("Failed to listen for Ctrl+C");
|
||||
info!("Received Ctrl+C, initiating graceful shutdown...");
|
||||
|
||||
// Perform cleanup here
|
||||
use crate::handlers::AGENT; // Import AGENT from handlers
|
||||
use tracing::error; // Import error for logging
|
||||
|
||||
let agent_guard = AGENT.lock().await;
|
||||
let extensions = agent_guard.list_extensions().await;
|
||||
drop(agent_guard); // Release lock before async calls
|
||||
|
||||
for ext_name in extensions {
|
||||
let agent_guard = AGENT.lock().await;
|
||||
if let Err(e) = agent_guard.remove_extension(&ext_name).await {
|
||||
error!("Failed to remove extension {} during graceful shutdown: {}", ext_name, e);
|
||||
}
|
||||
}
|
||||
info!("Extensions shut down during graceful shutdown.");
|
||||
});
|
||||
|
||||
server.await; // Await the server
|
||||
Ok(())
|
||||
}
|
||||
|
||||
98
crates/goose-api/test.py
Normal file
98
crates/goose-api/test.py
Normal file
@@ -0,0 +1,98 @@
|
||||
import requests
|
||||
import json
|
||||
|
||||
BASE_URL = "http://localhost:8080"
|
||||
API_KEY = "default_api_key"
|
||||
HEADERS = {
|
||||
"Content-Type": "application/json",
|
||||
"x-api-key": API_KEY
|
||||
}
|
||||
|
||||
def test_get_provider_config():
|
||||
print("\n--- Testing GET /provider/config ---")
|
||||
url = f"{BASE_URL}/provider/config"
|
||||
response = requests.get(url, headers={"x-api-key": API_KEY})
|
||||
print(f"Status Code: {response.status_code}")
|
||||
print(f"Response: {response.json()}")
|
||||
assert response.status_code == 200
|
||||
assert "provider" in response.json()
|
||||
assert "model" in response.json()
|
||||
|
||||
def test_start_session():
|
||||
print("\n--- Testing POST /session/start ---")
|
||||
url = f"{BASE_URL}/session/start"
|
||||
data = {"prompt": "Create a Python function to generate Fibonacci numbers"}
|
||||
response = requests.post(url, headers=HEADERS, data=json.dumps(data))
|
||||
print(f"Status Code: {response.status_code}")
|
||||
print(f"Response: {response.json()}")
|
||||
assert response.status_code == 200
|
||||
assert "session_id" in response.json()
|
||||
return response.json().get("session_id")
|
||||
|
||||
def test_reply_session(session_id):
|
||||
print(f"\n--- Testing POST /session/reply for session_id: {session_id} ---")
|
||||
url = f"{BASE_URL}/session/reply"
|
||||
data = {"session_id": session_id, "prompt": "Continue with the next Fibonacci number."}
|
||||
response = requests.post(url, headers=HEADERS, data=json.dumps(data))
|
||||
print(f"Status Code: {response.status_code}")
|
||||
print(f"Response: {response.json()}")
|
||||
assert response.status_code == 200
|
||||
assert "message" in response.json()
|
||||
|
||||
def test_summarize_session(session_id):
|
||||
print(f"\n--- Testing POST /session/summarize for session_id: {session_id} ---")
|
||||
url = f"{BASE_URL}/session/summarize"
|
||||
data = {"session_id": session_id}
|
||||
response = requests.post(url, headers=HEADERS, data=json.dumps(data))
|
||||
print(f"Status Code: {response.status_code}")
|
||||
print(f"Response: {response.json()}")
|
||||
assert response.status_code == 200
|
||||
assert "summary" in response.json()
|
||||
|
||||
def test_end_session(session_id):
|
||||
print(f"\n--- Testing POST /session/end for session_id: {session_id} ---")
|
||||
url = f"{BASE_URL}/session/end"
|
||||
data = {"session_id": session_id}
|
||||
response = requests.post(url, headers=HEADERS, data=json.dumps(data))
|
||||
print(f"Status Code: {response.status_code}")
|
||||
print(f"Response: {response.json()}")
|
||||
assert response.status_code == 200
|
||||
assert "message" in response.json()
|
||||
|
||||
def test_list_extensions():
|
||||
print("\n--- Testing GET /extensions/list ---")
|
||||
url = f"{BASE_URL}/extensions/list"
|
||||
response = requests.get(url, headers=HEADERS) # API key is not enforced for this endpoint, but including for consistency
|
||||
print(f"Status Code: {response.status_code}")
|
||||
print(f"Response: {response.json()}")
|
||||
assert response.status_code == 200
|
||||
assert "extensions" in response.json()
|
||||
|
||||
def test_get_metrics():
|
||||
print("\n--- Testing GET /metrics ---")
|
||||
url = f"{BASE_URL}/metrics"
|
||||
response = requests.get(url) # No API key required
|
||||
print(f"Status Code: {response.status_code}")
|
||||
print(f"Response: {response.json()}")
|
||||
assert response.status_code == 200
|
||||
assert "active_sessions" in response.json()
|
||||
assert "pending_requests" in response.json()
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("Starting API endpoint tests...")
|
||||
|
||||
# Test endpoints that don't require a session_id first
|
||||
test_get_provider_config()
|
||||
test_list_extensions()
|
||||
test_get_metrics()
|
||||
|
||||
# Test session-related endpoints
|
||||
session_id = test_start_session()
|
||||
if session_id:
|
||||
test_reply_session(session_id)
|
||||
test_summarize_session(session_id)
|
||||
test_end_session(session_id)
|
||||
else:
|
||||
print("Skipping session tests as session_id was not obtained.")
|
||||
|
||||
print("\nAll tests completed.")
|
||||
@@ -18,13 +18,8 @@ use crate::agents::extension::Envs;
|
||||
use crate::config::{Config, ExtensionConfigManager};
|
||||
use crate::prompt_template;
|
||||
use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait};
|
||||
<<<<<<< HEAD
|
||||
use mcp_client::transport::{PendingRequests, SseTransport, StdioTransport, Transport};
|
||||
use mcp_core::{prompt::Prompt, Content, Tool, ToolCall, ToolError, ToolResult};
|
||||
=======
|
||||
use mcp_client::transport::{SseTransport, StdioTransport, Transport};
|
||||
use mcp_core::{prompt::Prompt, Content, Tool, ToolCall, ToolError};
|
||||
>>>>>>> 2f8f8e5767bb1fdc53dfaa4a492c9184f02c3721
|
||||
use serde_json::Value;
|
||||
|
||||
// By default, we set it to Jan 1, 2020 if the resource does not have a timestamp
|
||||
@@ -39,7 +34,6 @@ pub struct ExtensionManager {
|
||||
clients: HashMap<String, McpClientBox>,
|
||||
instructions: HashMap<String, String>,
|
||||
resource_capable_extensions: HashSet<String>,
|
||||
pending_requests: HashMap<String, Arc<PendingRequests>>, // track pending requests per extension
|
||||
}
|
||||
|
||||
/// A flattened representation of a resource used by the agent to prepare inference
|
||||
@@ -110,7 +104,6 @@ impl ExtensionManager {
|
||||
clients: HashMap::new(),
|
||||
instructions: HashMap::new(),
|
||||
resource_capable_extensions: HashSet::new(),
|
||||
pending_requests: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -192,17 +185,6 @@ impl ExtensionManager {
|
||||
let all_envs = merge_environments(envs, env_keys, &sanitized_name).await?;
|
||||
let transport = SseTransport::new(uri, all_envs);
|
||||
let handle = transport.start().await?;
|
||||
<<<<<<< HEAD
|
||||
let pending = handle.pending_requests();
|
||||
let service = McpService::with_timeout(
|
||||
handle,
|
||||
Duration::from_secs(
|
||||
timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT),
|
||||
),
|
||||
);
|
||||
self.pending_requests.insert(sanitized_name.clone(), pending);
|
||||
Box::new(McpClient::new(service))
|
||||
=======
|
||||
Box::new(
|
||||
McpClient::connect(
|
||||
handle,
|
||||
@@ -212,7 +194,6 @@ impl ExtensionManager {
|
||||
)
|
||||
.await?,
|
||||
)
|
||||
>>>>>>> 2f8f8e5767bb1fdc53dfaa4a492c9184f02c3721
|
||||
}
|
||||
ExtensionConfig::Stdio {
|
||||
cmd,
|
||||
@@ -225,17 +206,6 @@ impl ExtensionManager {
|
||||
let all_envs = merge_environments(envs, env_keys, &sanitized_name).await?;
|
||||
let transport = StdioTransport::new(cmd, args.to_vec(), all_envs);
|
||||
let handle = transport.start().await?;
|
||||
<<<<<<< HEAD
|
||||
let pending = handle.pending_requests();
|
||||
let service = McpService::with_timeout(
|
||||
handle,
|
||||
Duration::from_secs(
|
||||
timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT),
|
||||
),
|
||||
);
|
||||
self.pending_requests.insert(sanitized_name.clone(), pending);
|
||||
Box::new(McpClient::new(service))
|
||||
=======
|
||||
Box::new(
|
||||
McpClient::connect(
|
||||
handle,
|
||||
@@ -245,7 +215,6 @@ impl ExtensionManager {
|
||||
)
|
||||
.await?,
|
||||
)
|
||||
>>>>>>> 2f8f8e5767bb1fdc53dfaa4a492c9184f02c3721
|
||||
}
|
||||
ExtensionConfig::Builtin {
|
||||
name,
|
||||
@@ -264,17 +233,6 @@ impl ExtensionManager {
|
||||
HashMap::new(),
|
||||
);
|
||||
let handle = transport.start().await?;
|
||||
<<<<<<< HEAD
|
||||
let pending = handle.pending_requests();
|
||||
let service = McpService::with_timeout(
|
||||
handle,
|
||||
Duration::from_secs(
|
||||
timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT),
|
||||
),
|
||||
);
|
||||
self.pending_requests.insert(sanitized_name.clone(), pending);
|
||||
Box::new(McpClient::new(service))
|
||||
=======
|
||||
Box::new(
|
||||
McpClient::connect(
|
||||
handle,
|
||||
@@ -284,7 +242,6 @@ impl ExtensionManager {
|
||||
)
|
||||
.await?,
|
||||
)
|
||||
>>>>>>> 2f8f8e5767bb1fdc53dfaa4a492c9184f02c3721
|
||||
}
|
||||
_ => unreachable!(),
|
||||
};
|
||||
@@ -336,19 +293,9 @@ impl ExtensionManager {
|
||||
self.clients.remove(&sanitized_name);
|
||||
self.instructions.remove(&sanitized_name);
|
||||
self.resource_capable_extensions.remove(&sanitized_name);
|
||||
self.pending_requests.remove(&sanitized_name);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get the size of each extension's pending request map
|
||||
pub async fn pending_request_sizes(&self) -> HashMap<String, usize> {
|
||||
let mut result = HashMap::new();
|
||||
for (name, pending) in &self.pending_requests {
|
||||
result.insert(name.clone(), pending.len().await);
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
pub async fn suggest_disable_extensions_prompt(&self) -> Value {
|
||||
let enabled_extensions_count = self.clients.len();
|
||||
|
||||
|
||||
@@ -187,7 +187,6 @@ mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
use std::collections::HashMap;
|
||||
use serial_test::serial;
|
||||
use tokio::sync::Mutex;
|
||||
use tracing::dispatcher;
|
||||
use wiremock::matchers::{method, path};
|
||||
@@ -390,7 +389,6 @@ mod tests {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[serial]
|
||||
async fn test_create_langfuse_observer() {
|
||||
let fixture = TestFixture::new().await.with_mock_server().await;
|
||||
|
||||
|
||||
@@ -8,7 +8,6 @@ use goose::providers::{
|
||||
};
|
||||
use mcp_core::content::Content;
|
||||
use mcp_core::tool::Tool;
|
||||
use serial_test::serial;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex;
|
||||
@@ -353,7 +352,6 @@ where
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[serial]
|
||||
async fn test_openai_provider() -> Result<()> {
|
||||
test_provider(
|
||||
"OpenAI",
|
||||
@@ -365,7 +363,6 @@ async fn test_openai_provider() -> Result<()> {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[serial]
|
||||
async fn test_azure_provider() -> Result<()> {
|
||||
test_provider(
|
||||
"Azure",
|
||||
@@ -381,7 +378,6 @@ async fn test_azure_provider() -> Result<()> {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[serial]
|
||||
async fn test_bedrock_provider_long_term_credentials() -> Result<()> {
|
||||
test_provider(
|
||||
"Bedrock",
|
||||
@@ -393,7 +389,6 @@ async fn test_bedrock_provider_long_term_credentials() -> Result<()> {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[serial]
|
||||
async fn test_bedrock_provider_aws_profile_credentials() -> Result<()> {
|
||||
let env_mods = HashMap::from_iter([
|
||||
// Ensure to unset long-term credentials to use AWS Profile provider
|
||||
@@ -411,7 +406,6 @@ async fn test_bedrock_provider_aws_profile_credentials() -> Result<()> {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[serial]
|
||||
async fn test_databricks_provider() -> Result<()> {
|
||||
test_provider(
|
||||
"Databricks",
|
||||
@@ -423,7 +417,6 @@ async fn test_databricks_provider() -> Result<()> {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[serial]
|
||||
async fn test_databricks_provider_oauth() -> Result<()> {
|
||||
let mut env_mods = HashMap::new();
|
||||
env_mods.insert("DATABRICKS_TOKEN", None);
|
||||
@@ -438,7 +431,6 @@ async fn test_databricks_provider_oauth() -> Result<()> {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[serial]
|
||||
async fn test_ollama_provider() -> Result<()> {
|
||||
test_provider(
|
||||
"Ollama",
|
||||
@@ -450,13 +442,11 @@ async fn test_ollama_provider() -> Result<()> {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[serial]
|
||||
async fn test_groq_provider() -> Result<()> {
|
||||
test_provider("Groq", &["GROQ_API_KEY"], None, groq::GroqProvider::default).await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[serial]
|
||||
async fn test_anthropic_provider() -> Result<()> {
|
||||
test_provider(
|
||||
"Anthropic",
|
||||
@@ -468,7 +458,6 @@ async fn test_anthropic_provider() -> Result<()> {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[serial]
|
||||
async fn test_openrouter_provider() -> Result<()> {
|
||||
test_provider(
|
||||
"OpenRouter",
|
||||
@@ -480,7 +469,6 @@ async fn test_openrouter_provider() -> Result<()> {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[serial]
|
||||
async fn test_google_provider() -> Result<()> {
|
||||
test_provider(
|
||||
"Google",
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::atomic::{AtomicI32, Ordering};
|
||||
use std::sync::Arc;
|
||||
use tokio::process::{Child, ChildStderr, ChildStdin, ChildStdout, Command};
|
||||
|
||||
@@ -16,9 +15,6 @@ use nix::unistd::{getpgid, Pid};
|
||||
|
||||
use super::{serialize_and_send, Error, Transport, TransportHandle};
|
||||
|
||||
// Global to track process groups we've created
|
||||
static PROCESS_GROUP: AtomicI32 = AtomicI32::new(-1);
|
||||
|
||||
/// A `StdioTransport` uses a child process's stdin/stdout as a communication channel.
|
||||
///
|
||||
/// It uses channels for message passing and handles responses asynchronously through a background task.
|
||||
@@ -30,21 +26,21 @@ pub struct StdioActor {
|
||||
stdin: Option<ChildStdin>,
|
||||
stdout: Option<ChildStdout>,
|
||||
stderr: Option<ChildStderr>,
|
||||
#[cfg(unix)]
|
||||
pgid: Option<i32>, // Process group ID for cleanup
|
||||
}
|
||||
|
||||
impl Drop for StdioActor {
|
||||
fn drop(&mut self) {
|
||||
// Get the process group ID before attempting cleanup
|
||||
#[cfg(unix)]
|
||||
if let Some(pid) = self.process.id() {
|
||||
if let Ok(pgid) = getpgid(Some(Pid::from_raw(pid as i32))) {
|
||||
// Send SIGTERM to the entire process group
|
||||
let _ = kill(Pid::from_raw(-pgid.as_raw()), Signal::SIGTERM);
|
||||
// Give processes a moment to cleanup
|
||||
std::thread::sleep(std::time::Duration::from_millis(100));
|
||||
// Force kill if still running
|
||||
let _ = kill(Pid::from_raw(-pgid.as_raw()), Signal::SIGKILL);
|
||||
}
|
||||
if let Some(pgid) = self.pgid {
|
||||
// Send SIGTERM to the entire process group
|
||||
let _ = kill(Pid::from_raw(-pgid), Signal::SIGTERM);
|
||||
// Note: std::thread::sleep is blocking, but this is a Drop impl.
|
||||
// For graceful async shutdown, use the `close` method on `StdioTransport`.
|
||||
std::thread::sleep(std::time::Duration::from_millis(100));
|
||||
// Force kill if still running
|
||||
let _ = kill(Pid::from_raw(-pgid), Signal::SIGKILL);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -155,7 +151,6 @@ pub struct StdioTransportHandle {
|
||||
sender: mpsc::Sender<String>, // to process
|
||||
receiver: Arc<Mutex<mpsc::Receiver<JsonRpcMessage>>>, // from process
|
||||
error_receiver: Arc<Mutex<mpsc::Receiver<Error>>>,
|
||||
pending_requests: Arc<PendingRequests>,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
@@ -184,10 +179,6 @@ impl StdioTransportHandle {
|
||||
Err(_) => Ok(()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn pending_requests(&self) -> Arc<PendingRequests> {
|
||||
Arc::clone(&self.pending_requests)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct StdioTransport {
|
||||
@@ -209,7 +200,7 @@ impl StdioTransport {
|
||||
}
|
||||
}
|
||||
|
||||
async fn spawn_process(&self) -> Result<(Child, ChildStdin, ChildStdout, ChildStderr), Error> {
|
||||
async fn spawn_process(&self) -> Result<(Child, ChildStdin, ChildStdout, ChildStderr, Option<i32>), Error> {
|
||||
let mut command = Command::new(&self.command);
|
||||
command
|
||||
.envs(&self.env)
|
||||
@@ -246,16 +237,16 @@ impl StdioTransport {
|
||||
.take()
|
||||
.ok_or_else(|| Error::StdioProcessError("Failed to get stderr".into()))?;
|
||||
|
||||
let mut pgid = None;
|
||||
// Store the process group ID for cleanup
|
||||
#[cfg(unix)]
|
||||
if let Some(pid) = process.id() {
|
||||
// Use nix instead of unsafe libc calls
|
||||
if let Ok(pgid) = getpgid(Some(Pid::from_raw(pid as i32))) {
|
||||
PROCESS_GROUP.store(pgid.as_raw(), Ordering::SeqCst);
|
||||
if let Ok(id) = getpgid(Some(Pid::from_raw(pid as i32))) {
|
||||
pgid = Some(id.as_raw());
|
||||
}
|
||||
}
|
||||
|
||||
Ok((process, stdin, stdout, stderr))
|
||||
Ok((process, stdin, stdout, stderr, pgid))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -264,12 +255,11 @@ impl Transport for StdioTransport {
|
||||
type Handle = StdioTransportHandle;
|
||||
|
||||
async fn start(&self) -> Result<Self::Handle, Error> {
|
||||
let (process, stdin, stdout, stderr) = self.spawn_process().await?;
|
||||
let (process, stdin, stdout, stderr, pgid) = self.spawn_process().await?;
|
||||
let (outbox_tx, outbox_rx) = mpsc::channel(32);
|
||||
let (inbox_tx, inbox_rx) = mpsc::channel(32);
|
||||
let (error_tx, error_rx) = mpsc::channel(1);
|
||||
|
||||
let pending_requests = Arc::new(PendingRequests::new());
|
||||
let actor = StdioActor {
|
||||
receiver: Some(outbox_rx), // client to process
|
||||
sender: Some(inbox_tx), // process to client
|
||||
@@ -278,6 +268,8 @@ impl Transport for StdioTransport {
|
||||
stdin: Some(stdin),
|
||||
stdout: Some(stdout),
|
||||
stderr: Some(stderr),
|
||||
#[cfg(unix)]
|
||||
pgid, // Pass the pgid to the actor
|
||||
};
|
||||
|
||||
tokio::spawn(actor.run());
|
||||
@@ -286,23 +278,13 @@ impl Transport for StdioTransport {
|
||||
sender: outbox_tx, // client to process
|
||||
receiver: Arc::new(Mutex::new(inbox_rx)), // process to client
|
||||
error_receiver: Arc::new(Mutex::new(error_rx)),
|
||||
pending_requests,
|
||||
};
|
||||
Ok(handle)
|
||||
}
|
||||
|
||||
async fn close(&self) -> Result<(), Error> {
|
||||
// Attempt to clean up the process group on close
|
||||
#[cfg(unix)]
|
||||
if let Some(pgid) = PROCESS_GROUP.load(Ordering::SeqCst).checked_abs() {
|
||||
// Use nix instead of unsafe libc calls
|
||||
// Try SIGTERM first
|
||||
let _ = kill(Pid::from_raw(-pgid), Signal::SIGTERM);
|
||||
// Give processes a moment to cleanup
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
|
||||
// Force kill if still running
|
||||
let _ = kill(Pid::from_raw(-pgid), Signal::SIGKILL);
|
||||
}
|
||||
// The StdioActor's Drop implementation handles process termination.
|
||||
// This method can be a no-op for now.
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user