diff --git a/crates/goose-api/Cargo.toml b/crates/goose-api/Cargo.toml index d3ba498b..8cf2fec5 100644 --- a/crates/goose-api/Cargo.toml +++ b/crates/goose-api/Cargo.toml @@ -18,4 +18,6 @@ config = "0.13" jsonwebtoken = "8" futures = "0.3" futures-util = "0.3" +# For session IDs +uuid = { version = "1", features = ["serde", "v4"] } # Add dynamic-library for extension loading diff --git a/crates/goose-api/src/main.rs b/crates/goose-api/src/main.rs index 4edae16c..ac2e7f82 100644 --- a/crates/goose-api/src/main.rs +++ b/crates/goose-api/src/main.rs @@ -2,6 +2,8 @@ use warp::{Filter, Rejection}; use warp::http::HeaderValue; use serde::{Deserialize, Serialize}; use std::sync::LazyLock; +use std::collections::HashMap; +use uuid::Uuid; use goose::agents::{Agent, extension_manager::ExtensionManager}; use goose::config::Config; use goose::providers::{create, providers}; @@ -20,6 +22,11 @@ static AGENT: LazyLock> = LazyLock::new(|| { tokio::sync::Mutex::new(Agent::new()) }); +// Global store for session histories +static SESSION_HISTORY: LazyLock>>> = LazyLock::new(|| { + tokio::sync::Mutex::new(HashMap::new()) +}); + #[derive(Debug, Serialize, Deserialize)] struct SessionRequest { prompt: String, @@ -31,6 +38,24 @@ struct ApiResponse { status: String, } +#[derive(Debug, Serialize, Deserialize)] +struct StartSessionResponse { + message: String, + status: String, + session_id: Uuid, +} + +#[derive(Debug, Serialize, Deserialize)] +struct SessionReplyRequest { + session_id: Uuid, + prompt: String, +} + +#[derive(Debug, Serialize, Deserialize)] +struct EndSessionRequest { + session_id: Uuid, +} + #[derive(Debug, Serialize, Deserialize)] struct ExtensionsResponse { extensions: Vec, @@ -47,32 +72,43 @@ async fn start_session_handler( _api_key: String, ) -> Result { info!("Starting session with prompt: {}", req.prompt); - - let agent = AGENT.lock().await; - + + let mut agent = AGENT.lock().await; + // Create a user message with the prompt - let messages = vec![Message::user().with_text(&req.prompt)]; - - // Process the messages through the agent + let mut messages = vec![Message::user().with_text(&req.prompt)]; + + // Generate a new session ID and process the messages + let session_id = Uuid::new_v4(); + let result = agent.reply(&messages, None).await; - + match result { Ok(mut stream) => { // Process the stream to get the first response if let Ok(Some(response)) = stream.try_next().await { let response_text = response.as_concat_text(); - let api_response = ApiResponse { - message: format!("Session started with prompt: {}. Response: {}", req.prompt, response_text), + messages.push(response); + let mut history = SESSION_HISTORY.lock().await; + history.insert(session_id, messages); + + let api_response = StartSessionResponse { + message: response_text, status: "success".to_string(), + session_id, }; Ok(warp::reply::with_status( warp::reply::json(&api_response), warp::http::StatusCode::OK, )) } else { - let api_response = ApiResponse { - message: format!("Session started but no response generated"), + let mut history = SESSION_HISTORY.lock().await; + history.insert(session_id, messages); + + let api_response = StartSessionResponse { + message: "Session started but no response generated".to_string(), status: "warning".to_string(), + session_id, }; Ok(warp::reply::with_status( warp::reply::json(&api_response), @@ -95,16 +131,33 @@ async fn start_session_handler( } async fn reply_session_handler( - req: SessionRequest, + req: SessionReplyRequest, _api_key: String, ) -> Result { info!("Replying to session with prompt: {}", req.prompt); - - let agent = AGENT.lock().await; - - // Create a user message with the prompt - let messages = vec![Message::user().with_text(&req.prompt)]; - + + let mut agent = AGENT.lock().await; + + // Retrieve existing session history + let mut history = SESSION_HISTORY.lock().await; + let entry = match history.get_mut(&req.session_id) { + Some(messages) => messages, + None => { + let response = ApiResponse { + message: "Session not found".to_string(), + status: "error".to_string(), + }; + return Ok(warp::reply::with_status( + warp::reply::json(&response), + warp::http::StatusCode::NOT_FOUND, + )); + } + }; + + // Append the new user message + entry.push(Message::user().with_text(&req.prompt)); + let messages = entry.clone(); + // Process the messages through the agent let result = agent.reply(&messages, None).await; @@ -113,6 +166,8 @@ async fn reply_session_handler( // Process the stream to get the first response if let Ok(Some(response)) = stream.try_next().await { let response_text = response.as_concat_text(); + // store assistant response in history + entry.push(response); let api_response = ApiResponse { message: format!("Reply: {}", response_text), status: "success".to_string(), @@ -123,7 +178,7 @@ async fn reply_session_handler( )) } else { let api_response = ApiResponse { - message: format!("Reply processed but no response generated"), + message: "Reply processed but no response generated".to_string(), status: "warning".to_string(), }; Ok(warp::reply::with_status( @@ -146,6 +201,32 @@ async fn reply_session_handler( } } +async fn end_session_handler( + req: EndSessionRequest, + _api_key: String, +) -> Result { + let mut history = SESSION_HISTORY.lock().await; + if history.remove(&req.session_id).is_some() { + let response = ApiResponse { + message: "Session ended".to_string(), + status: "success".to_string(), + }; + Ok(warp::reply::with_status( + warp::reply::json(&response), + warp::http::StatusCode::OK, + )) + } else { + let response = ApiResponse { + message: "Session not found".to_string(), + status: "error".to_string(), + }; + Ok(warp::reply::with_status( + warp::reply::json(&response), + warp::http::StatusCode::NOT_FOUND, + )) + } +} + async fn list_extensions_handler() -> Result { info!("Listing extensions"); @@ -331,13 +412,21 @@ async fn main() -> Result<(), anyhow::Error> { .and(with_api_key(api_key.clone())) .and_then(start_session_handler); - // Session reply endpoint + // Session reply endpoint let reply_session = warp::path("session") .and(warp::path("reply")) .and(warp::post()) .and(warp::body::json()) .and(with_api_key(api_key.clone())) .and_then(reply_session_handler); + + // Session end endpoint + let end_session = warp::path("session") + .and(warp::path("end")) + .and(warp::post()) + .and(warp::body::json()) + .and(with_api_key(api_key.clone())) + .and_then(end_session_handler); // List extensions endpoint let list_extensions = warp::path("extensions") @@ -354,6 +443,7 @@ async fn main() -> Result<(), anyhow::Error> { // Combine all routes let routes = start_session .or(reply_session) + .or(end_session) .or(list_extensions) .or(get_provider_config);