mirror of
https://github.com/aljazceru/goose.git
synced 2025-12-18 22:54:24 +01:00
Merge pull request #17 from aljazceru/codex/summarize-context-on-message-size
Handle context length exceed in goose API
This commit is contained in:
@@ -22,3 +22,7 @@ futures-util = "0.3"
|
||||
# For session IDs
|
||||
uuid = { version = "1", features = ["serde", "v4"] }
|
||||
# Add dynamic-library for extension loading
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = "3"
|
||||
async-trait = "0.1"
|
||||
|
||||
@@ -6,7 +6,7 @@ use futures_util::TryStreamExt;
|
||||
use tracing::{info, warn, error};
|
||||
use mcp_core::tool::Tool;
|
||||
use goose::agents::{extension::Envs, extension_manager::ExtensionManager, ExtensionConfig, Agent, SessionConfig};
|
||||
use goose::message::Message;
|
||||
use goose::message::{Message, MessageContent};
|
||||
use goose::session::{self, Identifier};
|
||||
use goose::config::Config;
|
||||
use std::sync::LazyLock;
|
||||
@@ -135,6 +135,30 @@ pub async fn start_session_handler(
|
||||
match result {
|
||||
Ok(mut stream) => {
|
||||
if let Ok(Some(response)) = stream.try_next().await {
|
||||
if matches!(response.content.first(), Some(MessageContent::ContextLengthExceeded(_))) {
|
||||
match agent.summarize_context(&messages).await {
|
||||
Ok((summarized, _)) => {
|
||||
messages = summarized;
|
||||
if let Err(e) = session::persist_messages(&session_path, &messages, provider.clone()).await {
|
||||
warn!("Failed to persist session {}: {}", session_name, e);
|
||||
}
|
||||
|
||||
let api_response = StartSessionResponse {
|
||||
message: "Conversation summarized to fit context window".to_string(),
|
||||
status: "warning".to_string(),
|
||||
session_id,
|
||||
};
|
||||
return Ok(warp::reply::with_status(
|
||||
warp::reply::json(&api_response),
|
||||
warp::http::StatusCode::OK,
|
||||
));
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to summarize context: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let response_text = response.as_concat_text();
|
||||
messages.push(response);
|
||||
if let Err(e) = session::persist_messages(&session_path, &messages, provider.clone()).await {
|
||||
@@ -223,6 +247,28 @@ pub async fn reply_session_handler(
|
||||
match result {
|
||||
Ok(mut stream) => {
|
||||
if let Ok(Some(response)) = stream.try_next().await {
|
||||
if matches!(response.content.first(), Some(MessageContent::ContextLengthExceeded(_))) {
|
||||
match agent.summarize_context(&messages).await {
|
||||
Ok((summarized, _)) => {
|
||||
messages = summarized;
|
||||
if let Err(e) = session::persist_messages(&session_path, &messages, provider.clone()).await {
|
||||
warn!("Failed to persist session {}: {}", session_name, e);
|
||||
}
|
||||
let api_response = ApiResponse {
|
||||
message: "Conversation summarized to fit context window".to_string(),
|
||||
status: "warning".to_string(),
|
||||
};
|
||||
return Ok(warp::reply::with_status(
|
||||
warp::reply::json(&api_response),
|
||||
warp::http::StatusCode::OK,
|
||||
));
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to summarize context: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let response_text = response.as_concat_text();
|
||||
messages.push(response);
|
||||
if let Err(e) = session::persist_messages(&session_path, &messages, provider.clone()).await {
|
||||
|
||||
@@ -1,10 +1,107 @@
|
||||
#[cfg(test)]
mod tests {
    //! Handler-level tests for the "context length exceeded" path.
    //!
    //! A mock provider rejects every ordinary completion with
    //! `ProviderError::ContextLengthExceeded` and answers only summarization
    //! prompts, so the handlers are driven into their summarize-and-persist
    //! branch.

    use super::*;
    use goose::message::{Message, MessageContent};
    use goose::model::ModelConfig;
    use goose::providers::{
        base::{Provider, ProviderMetadata, ProviderUsage, Usage},
        errors::ProviderError,
    };
    use mcp_core::tool::Tool;
    use std::sync::Arc;
    use tempfile::TempDir;
    use warp::reply::Reply;
    use goose::session::{self, Identifier};
    use uuid::Uuid;
    use hyper::body;

    /// Serializes the tests that mutate process-global state.
    ///
    /// `setup` overwrites the `HOME` environment variable and replaces the
    /// provider on the shared `AGENT`. The test harness runs tests on parallel
    /// threads by default, so without this lock two tests could interleave and
    /// observe each other's `HOME` or a temp dir that was already removed.
    /// An async mutex is used because the guard is held across `.await` points;
    /// it also does not poison, so one failing test does not cascade.
    static GLOBAL_STATE: std::sync::LazyLock<tokio::sync::Mutex<()>> =
        std::sync::LazyLock::new(|| tokio::sync::Mutex::new(()));

    /// Mock provider that only ever succeeds for summarization requests.
    #[derive(Clone)]
    struct ContextProvider {
        model_config: ModelConfig,
    }

    #[async_trait::async_trait]
    impl Provider for ContextProvider {
        fn metadata() -> ProviderMetadata {
            ProviderMetadata::empty()
        }

        fn get_model_config(&self) -> ModelConfig {
            self.model_config.clone()
        }

        /// Returns a canned `"summary"` message when the system prompt contains
        /// `"summarizing"`; every other call fails with
        /// `ProviderError::ContextLengthExceeded`.
        async fn complete(
            &self,
            system: &str,
            _messages: &[Message],
            _tools: &[Tool],
        ) -> Result<(Message, ProviderUsage), ProviderError> {
            if system.contains("summarizing") {
                Ok((
                    Message::user().with_text("summary"),
                    ProviderUsage::new("mock".to_string(), Usage::default()),
                ))
            } else {
                Err(ProviderError::ContextLengthExceeded("too long".to_string()))
            }
        }
    }

    /// Shared fixture: installs the mock provider and starts one session.
    ///
    /// Points `HOME` at a fresh temp dir (presumably so session files land
    /// there; confirm against `session::get_path`), swaps `ContextProvider`
    /// into the global `AGENT`, and calls `start_session_handler` with a
    /// 5000-character prompt.
    ///
    /// Returns the temp dir, which must be kept alive for as long as the
    /// session files are read, and the id of the session that was created.
    ///
    /// Callers must hold `GLOBAL_STATE` for the whole test, because this
    /// function mutates process-wide state.
    async fn setup() -> (TempDir, Uuid) {
        let tmp = tempfile::tempdir().expect("failed to create temp HOME");
        std::env::set_var("HOME", tmp.path());

        let provider = Arc::new(ContextProvider {
            model_config: ModelConfig::new("test".to_string()),
        });
        // Scope the agent lock so it is released before the handler, which
        // takes the same lock, is invoked.
        {
            let agent = AGENT.lock().await;
            agent.update_provider(provider).await.unwrap();
        }

        let req = SessionRequest {
            prompt: "start".repeat(1000),
        };
        let reply = start_session_handler(req, "key".to_string()).await.unwrap();
        let resp = reply.into_response();
        let body = body::to_bytes(resp.into_body())
            .await
            .expect("failed to read start-session response body");
        let start: StartSessionResponse = serde_json::from_slice(&body)
            .expect("start-session response is not a valid StartSessionResponse");
        (tmp, start.session_id)
    }

    #[tokio::test]
    async fn build_routes_compiles() {
        let _routes = build_routes("test-key".to_string());
        // Just ensure building routes doesn't panic
    }

    /// Starting a session whose first completion overflows the context must
    /// persist a summarized history.
    #[tokio::test]
    async fn summarizes_large_history_on_start() {
        // Declared first so it is dropped last, after the temp dir is gone.
        let _serial = GLOBAL_STATE.lock().await;
        let (tmp, session_id) = setup().await;

        let session_path = session::get_path(Identifier::Name(session_id.to_string()));
        let messages = session::read_messages(&session_path).unwrap();
        assert!(messages.iter().any(|m| m.as_concat_text().contains("summary")));
        drop(tmp);
    }

    /// Replying into a session whose completion overflows the context must
    /// answer with a `"warning"` status and must not persist the
    /// `ContextLengthExceeded` marker message.
    #[tokio::test]
    async fn summarizes_large_history_on_reply() {
        // Declared first so it is dropped last, after the temp dir is gone.
        let _serial = GLOBAL_STATE.lock().await;
        let (tmp, session_id) = setup().await;

        let req = SessionReplyRequest {
            session_id,
            prompt: "reply".repeat(1000),
        };
        let reply = reply_session_handler(req, "key".to_string()).await.unwrap();
        let resp = reply.into_response();
        let body = body::to_bytes(resp.into_body()).await.unwrap();
        let api: ApiResponse = serde_json::from_slice(&body).unwrap();
        assert_eq!(api.status, "warning");

        let session_path = session::get_path(Identifier::Name(session_id.to_string()));
        let messages = session::read_messages(&session_path).unwrap();
        assert!(messages
            .iter()
            .all(|m| !matches!(m.content.first(), Some(MessageContent::ContextLengthExceeded(_)))));
        drop(tmp);
    }
}
|
||||
|
||||
Reference in New Issue
Block a user