mirror of
https://github.com/aljazceru/goose.git
synced 2026-02-15 19:44:20 +01:00
feat(api): manage agents per session
This commit is contained in:
@@ -21,4 +21,5 @@ futures = "0.3"
|
||||
futures-util = "0.3"
|
||||
# For session IDs
|
||||
uuid = { version = "1", features = ["serde", "v4"] }
|
||||
dashmap = "6"
|
||||
# Add dynamic-library for extension loading
|
||||
|
||||
45
crates/goose-api/src/api_sessions.rs
Normal file
45
crates/goose-api/src/api_sessions.rs
Normal file
@@ -0,0 +1,45 @@
|
||||
use dashmap::DashMap;
|
||||
use goose::agents::Agent;
|
||||
use std::sync::{atomic::{AtomicU64, Ordering}, Arc, LazyLock};
|
||||
use tokio::sync::Mutex;
|
||||
use uuid::Uuid;
|
||||
use std::time::{Duration, SystemTime, UNIX_EPOCH};
|
||||
|
||||
pub struct ApiSession {
|
||||
pub agent: Arc<Mutex<Agent>>, // agent for this session
|
||||
last_active: AtomicU64,
|
||||
}
|
||||
|
||||
impl ApiSession {
|
||||
pub fn new(agent: Agent) -> Self {
|
||||
Self {
|
||||
agent: Arc::new(Mutex::new(agent)),
|
||||
last_active: AtomicU64::new(current_timestamp()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn touch(&self) {
|
||||
self.last_active.store(current_timestamp(), Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn is_expired(&self, ttl: Duration) -> bool {
|
||||
current_timestamp() - self.last_active.load(Ordering::Relaxed) > ttl.as_secs()
|
||||
}
|
||||
}
|
||||
|
||||
fn current_timestamp() -> u64 {
|
||||
SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs()
|
||||
}
|
||||
|
||||
pub static SESSIONS: LazyLock<DashMap<Uuid, ApiSession>> = LazyLock::new(DashMap::new);
|
||||
|
||||
pub const SESSION_TIMEOUT_SECS: u64 = 3600;
|
||||
|
||||
pub fn cleanup_expired_sessions() {
|
||||
let ttl = Duration::from_secs(SESSION_TIMEOUT_SECS);
|
||||
SESSIONS.retain(|_, sess| !sess.is_expired(ttl));
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ use goose::message::Message;
|
||||
use goose::session::{self, Identifier};
|
||||
use goose::config::Config;
|
||||
use std::sync::LazyLock;
|
||||
use crate::api_sessions::{ApiSession, SESSIONS, cleanup_expired_sessions};
|
||||
|
||||
pub static EXTENSION_MANAGER: LazyLock<ExtensionManager> = LazyLock::new(|| ExtensionManager::default());
|
||||
pub static AGENT: LazyLock<tokio::sync::Mutex<Agent>> = LazyLock::new(|| tokio::sync::Mutex::new(Agent::new()));
|
||||
@@ -105,15 +106,30 @@ pub async fn start_session_handler(
|
||||
) -> Result<impl warp::Reply, Rejection> {
|
||||
info!("Starting session with prompt: {}", req.prompt);
|
||||
|
||||
let agent = AGENT.lock().await;
|
||||
cleanup_expired_sessions();
|
||||
|
||||
// create fresh agent using provider from the template agent
|
||||
let template = AGENT.lock().await;
|
||||
let mut new_agent = Agent::new();
|
||||
if let Ok(provider) = template.provider().await {
|
||||
let _ = new_agent.update_provider(provider).await;
|
||||
}
|
||||
drop(template);
|
||||
|
||||
let mut messages = vec![Message::user().with_text(&req.prompt)];
|
||||
let session_id = Uuid::new_v4();
|
||||
let session_name = session_id.to_string();
|
||||
let session_path = session::get_path(Identifier::Name(session_name.clone()));
|
||||
|
||||
let provider = agent.provider().await.ok();
|
||||
let session = ApiSession::new(new_agent);
|
||||
let agent_ref = session.agent.clone();
|
||||
SESSIONS.insert(session_id, session);
|
||||
|
||||
let result = agent
|
||||
let provider = agent_ref.lock().await.provider().await.ok();
|
||||
|
||||
let result = agent_ref
|
||||
.lock()
|
||||
.await
|
||||
.reply(
|
||||
&messages,
|
||||
Some(SessionConfig {
|
||||
@@ -178,11 +194,28 @@ pub async fn reply_session_handler(
|
||||
) -> Result<impl warp::Reply, Rejection> {
|
||||
info!("Replying to session with prompt: {}", req.prompt);
|
||||
|
||||
let agent = AGENT.lock().await;
|
||||
cleanup_expired_sessions();
|
||||
|
||||
let session_name = req.session_id.to_string();
|
||||
let session_path = session::get_path(Identifier::Name(session_name.clone()));
|
||||
|
||||
let session_entry = match SESSIONS.get(&req.session_id) {
|
||||
Some(s) => s,
|
||||
None => {
|
||||
let response = ApiResponse {
|
||||
message: "Session not found".to_string(),
|
||||
status: "error".to_string(),
|
||||
};
|
||||
return Ok(warp::reply::with_status(
|
||||
warp::reply::json(&response),
|
||||
warp::http::StatusCode::NOT_FOUND,
|
||||
));
|
||||
}
|
||||
};
|
||||
session_entry.touch();
|
||||
let agent_ref = session_entry.agent.clone();
|
||||
drop(session_entry);
|
||||
|
||||
let mut messages = match session::read_messages(&session_path) {
|
||||
Ok(m) => m,
|
||||
Err(_) => {
|
||||
@@ -199,9 +232,11 @@ pub async fn reply_session_handler(
|
||||
|
||||
messages.push(Message::user().with_text(&req.prompt));
|
||||
|
||||
let provider = agent.provider().await.ok();
|
||||
let provider = agent_ref.lock().await.provider().await.ok();
|
||||
|
||||
let result = agent
|
||||
let result = agent_ref
|
||||
.lock()
|
||||
.await
|
||||
.reply(
|
||||
&messages,
|
||||
Some(SessionConfig {
|
||||
@@ -260,9 +295,14 @@ pub async fn end_session_handler(
|
||||
req: EndSessionRequest,
|
||||
_api_key: String,
|
||||
) -> Result<impl warp::Reply, Rejection> {
|
||||
cleanup_expired_sessions();
|
||||
|
||||
let session_name = req.session_id.to_string();
|
||||
let session_path = session::get_path(Identifier::Name(session_name.clone()));
|
||||
|
||||
// remove in-memory agent if present
|
||||
SESSIONS.remove(&req.session_id);
|
||||
|
||||
if std::fs::remove_file(&session_path).is_ok() {
|
||||
let response = ApiResponse {
|
||||
message: "Session ended".to_string(),
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
mod handlers;
|
||||
mod config;
|
||||
mod routes;
|
||||
mod api_sessions;
|
||||
|
||||
pub use routes::{build_routes, run_server};
|
||||
|
||||
Reference in New Issue
Block a user