diff --git a/crates/goose-api/Cargo.toml b/crates/goose-api/Cargo.toml index 1caa2ffe..7276663e 100644 --- a/crates/goose-api/Cargo.toml +++ b/crates/goose-api/Cargo.toml @@ -21,4 +21,5 @@ futures = "0.3" futures-util = "0.3" # For session IDs uuid = { version = "1", features = ["serde", "v4"] } +dashmap = "6" # Add dynamic-library for extension loading diff --git a/crates/goose-api/src/api_sessions.rs b/crates/goose-api/src/api_sessions.rs new file mode 100644 index 00000000..3c259ade --- /dev/null +++ b/crates/goose-api/src/api_sessions.rs @@ -0,0 +1,45 @@ +use dashmap::DashMap; +use goose::agents::Agent; +use std::sync::{atomic::{AtomicU64, Ordering}, Arc, LazyLock}; +use tokio::sync::Mutex; +use uuid::Uuid; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; + +pub struct ApiSession { + pub agent: Arc>, // agent for this session + last_active: AtomicU64, +} + +impl ApiSession { + pub fn new(agent: Agent) -> Self { + Self { + agent: Arc::new(Mutex::new(agent)), + last_active: AtomicU64::new(current_timestamp()), + } + } + + pub fn touch(&self) { + self.last_active.store(current_timestamp(), Ordering::Relaxed); + } + + pub fn is_expired(&self, ttl: Duration) -> bool { + current_timestamp() - self.last_active.load(Ordering::Relaxed) > ttl.as_secs() + } +} + +fn current_timestamp() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_secs() +} + +pub static SESSIONS: LazyLock> = LazyLock::new(DashMap::new); + +pub const SESSION_TIMEOUT_SECS: u64 = 3600; + +pub fn cleanup_expired_sessions() { + let ttl = Duration::from_secs(SESSION_TIMEOUT_SECS); + SESSIONS.retain(|_, sess| !sess.is_expired(ttl)); +} + diff --git a/crates/goose-api/src/handlers.rs b/crates/goose-api/src/handlers.rs index 096e7193..e84de7c4 100644 --- a/crates/goose-api/src/handlers.rs +++ b/crates/goose-api/src/handlers.rs @@ -10,6 +10,7 @@ use goose::message::Message; use goose::session::{self, Identifier}; use goose::config::Config; use std::sync::LazyLock; +use crate::api_sessions::{ApiSession, SESSIONS, cleanup_expired_sessions}; pub static EXTENSION_MANAGER: LazyLock = LazyLock::new(|| ExtensionManager::default()); pub static AGENT: LazyLock> = LazyLock::new(|| tokio::sync::Mutex::new(Agent::new())); @@ -105,15 +106,30 @@ pub async fn start_session_handler( ) -> Result { info!("Starting session with prompt: {}", req.prompt); - let agent = AGENT.lock().await; + cleanup_expired_sessions(); + + // create fresh agent using provider from the template agent + let template = AGENT.lock().await; + let mut new_agent = Agent::new(); + if let Ok(provider) = template.provider().await { + let _ = new_agent.update_provider(provider).await; + } + drop(template); + let mut messages = vec![Message::user().with_text(&req.prompt)]; let session_id = Uuid::new_v4(); let session_name = session_id.to_string(); let session_path = session::get_path(Identifier::Name(session_name.clone())); - let provider = agent.provider().await.ok(); + let session = ApiSession::new(new_agent); + let agent_ref = session.agent.clone(); + SESSIONS.insert(session_id, session); - let result = agent + let provider = agent_ref.lock().await.provider().await.ok(); + + let result = agent_ref + .lock() + .await .reply( &messages, Some(SessionConfig { @@ -178,11 +194,28 @@ pub async fn reply_session_handler( ) -> Result { info!("Replying to session with prompt: {}", req.prompt); - let agent = AGENT.lock().await; + cleanup_expired_sessions(); let session_name = req.session_id.to_string(); let session_path = session::get_path(Identifier::Name(session_name.clone())); + let session_entry = match SESSIONS.get(&req.session_id) { + Some(s) => s, + None => { + let response = ApiResponse { + message: "Session not found".to_string(), + status: "error".to_string(), + }; + return Ok(warp::reply::with_status( + warp::reply::json(&response), + warp::http::StatusCode::NOT_FOUND, + )); + } + }; + session_entry.touch(); + let agent_ref = session_entry.agent.clone(); + drop(session_entry); + let mut messages = match session::read_messages(&session_path) { Ok(m) => m, Err(_) => { @@ -199,9 +232,11 @@ pub async fn reply_session_handler( messages.push(Message::user().with_text(&req.prompt)); - let provider = agent.provider().await.ok(); + let provider = agent_ref.lock().await.provider().await.ok(); - let result = agent + let result = agent_ref + .lock() + .await .reply( &messages, Some(SessionConfig { @@ -260,9 +295,14 @@ pub async fn end_session_handler( req: EndSessionRequest, _api_key: String, ) -> Result { + cleanup_expired_sessions(); + let session_name = req.session_id.to_string(); let session_path = session::get_path(Identifier::Name(session_name.clone())); + // remove in-memory agent if present + SESSIONS.remove(&req.session_id); + if std::fs::remove_file(&session_path).is_ok() { let response = ApiResponse { message: "Session ended".to_string(), diff --git a/crates/goose-api/src/lib.rs b/crates/goose-api/src/lib.rs index b2037198..3b8e911e 100644 --- a/crates/goose-api/src/lib.rs +++ b/crates/goose-api/src/lib.rs @@ -1,5 +1,6 @@ mod handlers; mod config; mod routes; +mod api_sessions; pub use routes::{build_routes, run_server};