mirror of
https://github.com/aljazceru/goose.git
synced 2026-02-23 07:24:24 +01:00
Merge branch 'goose-api' into codex/implement-summarize_session_handler
This commit is contained in:
@@ -22,3 +22,7 @@ futures-util = "0.3"
|
||||
# For session IDs
|
||||
uuid = { version = "1", features = ["serde", "v4"] }
|
||||
# Add dynamic-library for extension loading
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = "3"
|
||||
async-trait = "0.1"
|
||||
|
||||
@@ -258,6 +258,28 @@ By default, the server runs on `127.0.0.1:8080`. You can modify this using confi
|
||||
{
|
||||
"message": "<summarized conversation>",
|
||||
"status": "success"
|
||||
=======
|
||||
### 7. Metrics
|
||||
|
||||
**Endpoint**: `GET /metrics`
|
||||
|
||||
**Description**: Returns runtime metrics about stored sessions and extensions.
|
||||
|
||||
**Request**:
|
||||
- Headers:
|
||||
- `x-api-key: [your-api-key]`
|
||||
|
||||
**Response** (example):
|
||||
```json
|
||||
{
|
||||
"session_messages": {
|
||||
"20240605_001234": 3,
|
||||
"20240605_010000": 5
|
||||
},
|
||||
"active_sessions": 2,
|
||||
"pending_requests": {
|
||||
"mcp_say": 0
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
@@ -6,10 +6,11 @@ use futures_util::TryStreamExt;
|
||||
use tracing::{info, warn, error};
|
||||
use mcp_core::tool::Tool;
|
||||
use goose::agents::{extension::Envs, extension_manager::ExtensionManager, ExtensionConfig, Agent, SessionConfig};
|
||||
use goose::message::Message;
|
||||
use goose::message::{Message, MessageContent};
|
||||
use goose::session::{self, Identifier};
|
||||
use goose::config::Config;
|
||||
use std::sync::LazyLock;
|
||||
use std::collections::HashMap;
|
||||
|
||||
pub static EXTENSION_MANAGER: LazyLock<ExtensionManager> = LazyLock::new(|| ExtensionManager::default());
|
||||
pub static AGENT: LazyLock<tokio::sync::Mutex<Agent>> = LazyLock::new(|| tokio::sync::Mutex::new(Agent::new()));
|
||||
@@ -65,6 +66,13 @@ pub struct ExtensionResponse {
|
||||
pub message: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct MetricsResponse {
|
||||
pub session_messages: HashMap<String, usize>,
|
||||
pub active_sessions: usize,
|
||||
pub pending_requests: HashMap<String, usize>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum ExtensionConfigRequest {
|
||||
@@ -132,6 +140,30 @@ pub async fn start_session_handler(
|
||||
match result {
|
||||
Ok(mut stream) => {
|
||||
if let Ok(Some(response)) = stream.try_next().await {
|
||||
if matches!(response.content.first(), Some(MessageContent::ContextLengthExceeded(_))) {
|
||||
match agent.summarize_context(&messages).await {
|
||||
Ok((summarized, _)) => {
|
||||
messages = summarized;
|
||||
if let Err(e) = session::persist_messages(&session_path, &messages, provider.clone()).await {
|
||||
warn!("Failed to persist session {}: {}", session_name, e);
|
||||
}
|
||||
|
||||
let api_response = StartSessionResponse {
|
||||
message: "Conversation summarized to fit context window".to_string(),
|
||||
status: "warning".to_string(),
|
||||
session_id,
|
||||
};
|
||||
return Ok(warp::reply::with_status(
|
||||
warp::reply::json(&api_response),
|
||||
warp::http::StatusCode::OK,
|
||||
));
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to summarize context: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let response_text = response.as_concat_text();
|
||||
messages.push(response);
|
||||
if let Err(e) = session::persist_messages(&session_path, &messages, provider.clone()).await {
|
||||
@@ -220,6 +252,28 @@ pub async fn reply_session_handler(
|
||||
match result {
|
||||
Ok(mut stream) => {
|
||||
if let Ok(Some(response)) = stream.try_next().await {
|
||||
if matches!(response.content.first(), Some(MessageContent::ContextLengthExceeded(_))) {
|
||||
match agent.summarize_context(&messages).await {
|
||||
Ok((summarized, _)) => {
|
||||
messages = summarized;
|
||||
if let Err(e) = session::persist_messages(&session_path, &messages, provider.clone()).await {
|
||||
warn!("Failed to persist session {}: {}", session_name, e);
|
||||
}
|
||||
let api_response = ApiResponse {
|
||||
message: "Conversation summarized to fit context window".to_string(),
|
||||
status: "warning".to_string(),
|
||||
};
|
||||
return Ok(warp::reply::with_status(
|
||||
warp::reply::json(&api_response),
|
||||
warp::http::StatusCode::OK,
|
||||
));
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to summarize context: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let response_text = response.as_concat_text();
|
||||
messages.push(response);
|
||||
if let Err(e) = session::persist_messages(&session_path, &messages, provider.clone()).await {
|
||||
@@ -508,6 +562,33 @@ pub async fn remove_extension_handler(
|
||||
Ok(warp::reply::json(&resp))
|
||||
}
|
||||
|
||||
pub async fn metrics_handler() -> Result<impl warp::Reply, Rejection> {
|
||||
// Gather session message counts
|
||||
let mut session_messages = HashMap::new();
|
||||
if let Ok(sessions) = session::list_sessions() {
|
||||
for (name, path) in sessions {
|
||||
if let Ok(messages) = session::read_messages(&path) {
|
||||
session_messages.insert(name, messages.len());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let active_sessions = session_messages.len();
|
||||
|
||||
// Gather pending request sizes for each extension
|
||||
let pending_requests = EXTENSION_MANAGER
|
||||
.pending_request_sizes()
|
||||
.await;
|
||||
|
||||
let resp = MetricsResponse {
|
||||
session_messages,
|
||||
active_sessions,
|
||||
pending_requests,
|
||||
};
|
||||
|
||||
Ok(warp::reply::json(&resp))
|
||||
}
|
||||
|
||||
pub fn with_api_key(api_key: String) -> impl Filter<Extract = (String,), Error = Rejection> + Clone {
|
||||
warp::header::value("x-api-key")
|
||||
.and_then(move |header_api_key: HeaderValue| {
|
||||
|
||||
@@ -5,6 +5,7 @@ use crate::handlers::{
|
||||
add_extension_handler, end_session_handler, get_provider_config_handler,
|
||||
list_extensions_handler, remove_extension_handler, reply_session_handler,
|
||||
start_session_handler, summarize_session_handler, with_api_key,
|
||||
|
||||
};
|
||||
use crate::config::{
|
||||
initialize_extensions, initialize_provider_config, load_configuration,
|
||||
@@ -64,6 +65,10 @@ pub fn build_routes(api_key: String) -> impl Filter<Extract = impl warp::Reply,
|
||||
.and(warp::get())
|
||||
.and_then(get_provider_config_handler);
|
||||
|
||||
let metrics = warp::path("metrics")
|
||||
.and(warp::get())
|
||||
.and_then(metrics_handler);
|
||||
|
||||
start_session
|
||||
.or(reply_session)
|
||||
.or(summarize_session)
|
||||
@@ -72,6 +77,7 @@ pub fn build_routes(api_key: String) -> impl Filter<Extract = impl warp::Reply,
|
||||
.or(add_extension)
|
||||
.or(remove_extension)
|
||||
.or(get_provider_config)
|
||||
.or(metrics)
|
||||
}
|
||||
|
||||
pub async fn run_server() -> Result<(), anyhow::Error> {
|
||||
|
||||
@@ -1,10 +1,107 @@
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use goose::message::{Message, MessageContent};
|
||||
use goose::model::ModelConfig;
|
||||
use goose::providers::{
|
||||
base::{Provider, ProviderMetadata, ProviderUsage, Usage},
|
||||
errors::ProviderError,
|
||||
};
|
||||
use mcp_core::tool::Tool;
|
||||
use std::sync::Arc;
|
||||
use tempfile::TempDir;
|
||||
use warp::reply::Reply;
|
||||
use goose::session::{self, Identifier};
|
||||
use uuid::Uuid;
|
||||
use hyper::body;
|
||||
|
||||
#[derive(Clone)]
|
||||
struct ContextProvider {
|
||||
model_config: ModelConfig,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Provider for ContextProvider {
|
||||
fn metadata() -> ProviderMetadata {
|
||||
ProviderMetadata::empty()
|
||||
}
|
||||
|
||||
fn get_model_config(&self) -> ModelConfig {
|
||||
self.model_config.clone()
|
||||
}
|
||||
|
||||
async fn complete(
|
||||
&self,
|
||||
system: &str,
|
||||
_messages: &[Message],
|
||||
_tools: &[Tool],
|
||||
) -> Result<(Message, ProviderUsage), ProviderError> {
|
||||
if system.contains("summarizing") {
|
||||
Ok((
|
||||
Message::user().with_text("summary"),
|
||||
ProviderUsage::new("mock".to_string(), Usage::default()),
|
||||
))
|
||||
} else {
|
||||
Err(ProviderError::ContextLengthExceeded("too long".to_string()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn setup() -> (TempDir, Uuid) {
|
||||
let tmp = tempfile::tempdir().unwrap();
|
||||
std::env::set_var("HOME", tmp.path());
|
||||
|
||||
let provider = Arc::new(ContextProvider {
|
||||
model_config: ModelConfig::new("test".to_string()),
|
||||
});
|
||||
let agent = AGENT.lock().await;
|
||||
agent.update_provider(provider).await.unwrap();
|
||||
drop(agent);
|
||||
|
||||
let req = SessionRequest {
|
||||
prompt: "start".repeat(1000),
|
||||
};
|
||||
let reply = start_session_handler(req, "key".to_string()).await.unwrap();
|
||||
let resp = reply.into_response();
|
||||
let body = body::to_bytes(resp.into_body()).await.unwrap();
|
||||
let start: StartSessionResponse = serde_json::from_slice(&body).unwrap();
|
||||
(tmp, start.session_id)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn build_routes_compiles() {
|
||||
let _routes = build_routes("test-key".to_string());
|
||||
// Just ensure building routes doesn't panic
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn summarizes_large_history_on_start() {
|
||||
let (tmp, session_id) = setup().await;
|
||||
|
||||
let session_path = session::get_path(Identifier::Name(session_id.to_string()));
|
||||
let messages = session::read_messages(&session_path).unwrap();
|
||||
assert!(messages.iter().any(|m| m.as_concat_text().contains("summary")));
|
||||
drop(tmp);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn summarizes_large_history_on_reply() {
|
||||
let (tmp, session_id) = setup().await;
|
||||
|
||||
let req = SessionReplyRequest {
|
||||
session_id,
|
||||
prompt: "reply".repeat(1000),
|
||||
};
|
||||
let reply = reply_session_handler(req, "key".to_string()).await.unwrap();
|
||||
let resp = reply.into_response();
|
||||
let body = body::to_bytes(resp.into_body()).await.unwrap();
|
||||
let api: ApiResponse = serde_json::from_slice(&body).unwrap();
|
||||
assert_eq!(api.status, "warning");
|
||||
|
||||
let session_path = session::get_path(Identifier::Name(session_id.to_string()));
|
||||
let messages = session::read_messages(&session_path).unwrap();
|
||||
assert!(messages
|
||||
.iter()
|
||||
.all(|m| !matches!(m.content.first(), Some(MessageContent::ContextLengthExceeded(_)))));
|
||||
drop(tmp);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,7 +17,7 @@ use crate::agents::extension::Envs;
|
||||
use crate::config::{Config, ExtensionConfigManager};
|
||||
use crate::prompt_template;
|
||||
use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait};
|
||||
use mcp_client::transport::{SseTransport, StdioTransport, Transport};
|
||||
use mcp_client::transport::{PendingRequests, SseTransport, StdioTransport, Transport};
|
||||
use mcp_core::{prompt::Prompt, Content, Tool, ToolCall, ToolError, ToolResult};
|
||||
use serde_json::Value;
|
||||
|
||||
@@ -33,6 +33,7 @@ pub struct ExtensionManager {
|
||||
clients: HashMap<String, McpClientBox>,
|
||||
instructions: HashMap<String, String>,
|
||||
resource_capable_extensions: HashSet<String>,
|
||||
pending_requests: HashMap<String, Arc<PendingRequests>>, // track pending requests per extension
|
||||
}
|
||||
|
||||
/// A flattened representation of a resource used by the agent to prepare inference
|
||||
@@ -103,6 +104,7 @@ impl ExtensionManager {
|
||||
clients: HashMap::new(),
|
||||
instructions: HashMap::new(),
|
||||
resource_capable_extensions: HashSet::new(),
|
||||
pending_requests: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -183,12 +185,14 @@ impl ExtensionManager {
|
||||
let all_envs = merge_environments(envs, env_keys, &sanitized_name).await?;
|
||||
let transport = SseTransport::new(uri, all_envs);
|
||||
let handle = transport.start().await?;
|
||||
let pending = handle.pending_requests();
|
||||
let service = McpService::with_timeout(
|
||||
handle,
|
||||
Duration::from_secs(
|
||||
timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT),
|
||||
),
|
||||
);
|
||||
self.pending_requests.insert(sanitized_name.clone(), pending);
|
||||
Box::new(McpClient::new(service))
|
||||
}
|
||||
ExtensionConfig::Stdio {
|
||||
@@ -202,12 +206,14 @@ impl ExtensionManager {
|
||||
let all_envs = merge_environments(envs, env_keys, &sanitized_name).await?;
|
||||
let transport = StdioTransport::new(cmd, args.to_vec(), all_envs);
|
||||
let handle = transport.start().await?;
|
||||
let pending = handle.pending_requests();
|
||||
let service = McpService::with_timeout(
|
||||
handle,
|
||||
Duration::from_secs(
|
||||
timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT),
|
||||
),
|
||||
);
|
||||
self.pending_requests.insert(sanitized_name.clone(), pending);
|
||||
Box::new(McpClient::new(service))
|
||||
}
|
||||
ExtensionConfig::Builtin {
|
||||
@@ -227,12 +233,14 @@ impl ExtensionManager {
|
||||
HashMap::new(),
|
||||
);
|
||||
let handle = transport.start().await?;
|
||||
let pending = handle.pending_requests();
|
||||
let service = McpService::with_timeout(
|
||||
handle,
|
||||
Duration::from_secs(
|
||||
timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT),
|
||||
),
|
||||
);
|
||||
self.pending_requests.insert(sanitized_name.clone(), pending);
|
||||
Box::new(McpClient::new(service))
|
||||
}
|
||||
_ => unreachable!(),
|
||||
@@ -285,9 +293,19 @@ impl ExtensionManager {
|
||||
self.clients.remove(&sanitized_name);
|
||||
self.instructions.remove(&sanitized_name);
|
||||
self.resource_capable_extensions.remove(&sanitized_name);
|
||||
self.pending_requests.remove(&sanitized_name);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get the size of each extension's pending request map
|
||||
pub async fn pending_request_sizes(&self) -> HashMap<String, usize> {
|
||||
let mut result = HashMap::new();
|
||||
for (name, pending) in &self.pending_requests {
|
||||
result.insert(name.clone(), pending.len().await);
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
pub async fn suggest_disable_extensions_prompt(&self) -> Value {
|
||||
let enabled_extensions_count = self.clients.len();
|
||||
|
||||
|
||||
@@ -223,6 +223,7 @@ impl SseActor {
|
||||
#[derive(Clone)]
|
||||
pub struct SseTransportHandle {
|
||||
sender: mpsc::Sender<TransportMessage>,
|
||||
pending_requests: Arc<PendingRequests>,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
@@ -232,6 +233,12 @@ impl TransportHandle for SseTransportHandle {
|
||||
}
|
||||
}
|
||||
|
||||
impl SseTransportHandle {
|
||||
pub fn pending_requests(&self) -> Arc<PendingRequests> {
|
||||
Arc::clone(&self.pending_requests)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct SseTransport {
|
||||
sse_url: String,
|
||||
@@ -284,9 +291,10 @@ impl Transport for SseTransport {
|
||||
let post_endpoint_clone = Arc::clone(&post_endpoint);
|
||||
|
||||
// Build the actor
|
||||
let pending_requests = Arc::new(PendingRequests::new());
|
||||
let actor = SseActor::new(
|
||||
rx,
|
||||
Arc::new(PendingRequests::new()),
|
||||
pending_requests.clone(),
|
||||
self.sse_url.clone(),
|
||||
post_endpoint,
|
||||
);
|
||||
@@ -301,7 +309,7 @@ impl Transport for SseTransport {
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(_) => Ok(SseTransportHandle { sender: tx }),
|
||||
Ok(_) => Ok(SseTransportHandle { sender: tx, pending_requests }),
|
||||
Err(e) => Err(Error::SseConnection(e.to_string())),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -189,6 +189,7 @@ impl StdioActor {
|
||||
pub struct StdioTransportHandle {
|
||||
sender: mpsc::Sender<TransportMessage>,
|
||||
error_receiver: Arc<Mutex<mpsc::Receiver<Error>>>,
|
||||
pending_requests: Arc<PendingRequests>,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
@@ -212,6 +213,10 @@ impl StdioTransportHandle {
|
||||
Err(_) => Ok(()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn pending_requests(&self) -> Arc<PendingRequests> {
|
||||
Arc::clone(&self.pending_requests)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct StdioTransport {
|
||||
@@ -292,9 +297,10 @@ impl Transport for StdioTransport {
|
||||
let (message_tx, message_rx) = mpsc::channel(32);
|
||||
let (error_tx, error_rx) = mpsc::channel(1);
|
||||
|
||||
let pending_requests = Arc::new(PendingRequests::new());
|
||||
let actor = StdioActor {
|
||||
receiver: Some(message_rx),
|
||||
pending_requests: Arc::new(PendingRequests::new()),
|
||||
pending_requests: pending_requests.clone(),
|
||||
process,
|
||||
error_sender: error_tx,
|
||||
stdin: Some(stdin),
|
||||
@@ -307,6 +313,7 @@ impl Transport for StdioTransport {
|
||||
let handle = StdioTransportHandle {
|
||||
sender: message_tx,
|
||||
error_receiver: Arc::new(Mutex::new(error_rx)),
|
||||
pending_requests,
|
||||
};
|
||||
Ok(handle)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user