Add session management API

This commit is contained in:
2025-05-28 19:26:25 +02:00
parent 8cf60ffea3
commit 2d79624551
2 changed files with 112 additions and 20 deletions

View File

@@ -18,4 +18,6 @@ config = "0.13"
jsonwebtoken = "8"
futures = "0.3"
futures-util = "0.3"
# For session IDs
uuid = { version = "1", features = ["serde", "v4"] }
# Add dynamic-library for extension loading

View File

@@ -2,6 +2,8 @@ use warp::{Filter, Rejection};
use warp::http::HeaderValue;
use serde::{Deserialize, Serialize};
use std::sync::LazyLock;
use std::collections::HashMap;
use uuid::Uuid;
use goose::agents::{Agent, extension_manager::ExtensionManager};
use goose::config::Config;
use goose::providers::{create, providers};
@@ -20,6 +22,11 @@ static AGENT: LazyLock<tokio::sync::Mutex<Agent>> = LazyLock::new(|| {
tokio::sync::Mutex::new(Agent::new())
});
// Global store for session histories
static SESSION_HISTORY: LazyLock<tokio::sync::Mutex<HashMap<Uuid, Vec<Message>>>> = LazyLock::new(|| {
tokio::sync::Mutex::new(HashMap::new())
});
#[derive(Debug, Serialize, Deserialize)]
struct SessionRequest {
prompt: String,
@@ -31,6 +38,24 @@ struct ApiResponse {
status: String,
}
#[derive(Debug, Serialize, Deserialize)]
struct StartSessionResponse {
message: String,
status: String,
session_id: Uuid,
}
#[derive(Debug, Serialize, Deserialize)]
struct SessionReplyRequest {
session_id: Uuid,
prompt: String,
}
#[derive(Debug, Serialize, Deserialize)]
struct EndSessionRequest {
session_id: Uuid,
}
#[derive(Debug, Serialize, Deserialize)]
struct ExtensionsResponse {
extensions: Vec<String>,
@@ -47,32 +72,43 @@ async fn start_session_handler(
_api_key: String,
) -> Result<impl warp::Reply, Rejection> {
info!("Starting session with prompt: {}", req.prompt);
let agent = AGENT.lock().await;
let mut agent = AGENT.lock().await;
// Create a user message with the prompt
let messages = vec![Message::user().with_text(&req.prompt)];
// Process the messages through the agent
let mut messages = vec![Message::user().with_text(&req.prompt)];
// Generate a new session ID and process the messages
let session_id = Uuid::new_v4();
let result = agent.reply(&messages, None).await;
match result {
Ok(mut stream) => {
// Process the stream to get the first response
if let Ok(Some(response)) = stream.try_next().await {
let response_text = response.as_concat_text();
let api_response = ApiResponse {
message: format!("Session started with prompt: {}. Response: {}", req.prompt, response_text),
messages.push(response);
let mut history = SESSION_HISTORY.lock().await;
history.insert(session_id, messages);
let api_response = StartSessionResponse {
message: response_text,
status: "success".to_string(),
session_id,
};
Ok(warp::reply::with_status(
warp::reply::json(&api_response),
warp::http::StatusCode::OK,
))
} else {
let api_response = ApiResponse {
message: format!("Session started but no response generated"),
let mut history = SESSION_HISTORY.lock().await;
history.insert(session_id, messages);
let api_response = StartSessionResponse {
message: "Session started but no response generated".to_string(),
status: "warning".to_string(),
session_id,
};
Ok(warp::reply::with_status(
warp::reply::json(&api_response),
@@ -95,16 +131,33 @@ async fn start_session_handler(
}
async fn reply_session_handler(
req: SessionRequest,
req: SessionReplyRequest,
_api_key: String,
) -> Result<impl warp::Reply, Rejection> {
info!("Replying to session with prompt: {}", req.prompt);
let agent = AGENT.lock().await;
// Create a user message with the prompt
let messages = vec![Message::user().with_text(&req.prompt)];
let mut agent = AGENT.lock().await;
// Retrieve existing session history
let mut history = SESSION_HISTORY.lock().await;
let entry = match history.get_mut(&req.session_id) {
Some(messages) => messages,
None => {
let response = ApiResponse {
message: "Session not found".to_string(),
status: "error".to_string(),
};
return Ok(warp::reply::with_status(
warp::reply::json(&response),
warp::http::StatusCode::NOT_FOUND,
));
}
};
// Append the new user message
entry.push(Message::user().with_text(&req.prompt));
let messages = entry.clone();
// Process the messages through the agent
let result = agent.reply(&messages, None).await;
@@ -113,6 +166,8 @@ async fn reply_session_handler(
// Process the stream to get the first response
if let Ok(Some(response)) = stream.try_next().await {
let response_text = response.as_concat_text();
// store assistant response in history
entry.push(response);
let api_response = ApiResponse {
message: format!("Reply: {}", response_text),
status: "success".to_string(),
@@ -123,7 +178,7 @@ async fn reply_session_handler(
))
} else {
let api_response = ApiResponse {
message: format!("Reply processed but no response generated"),
message: "Reply processed but no response generated".to_string(),
status: "warning".to_string(),
};
Ok(warp::reply::with_status(
@@ -146,6 +201,32 @@ async fn reply_session_handler(
}
}
async fn end_session_handler(
req: EndSessionRequest,
_api_key: String,
) -> Result<impl warp::Reply, Rejection> {
let mut history = SESSION_HISTORY.lock().await;
if history.remove(&req.session_id).is_some() {
let response = ApiResponse {
message: "Session ended".to_string(),
status: "success".to_string(),
};
Ok(warp::reply::with_status(
warp::reply::json(&response),
warp::http::StatusCode::OK,
))
} else {
let response = ApiResponse {
message: "Session not found".to_string(),
status: "error".to_string(),
};
Ok(warp::reply::with_status(
warp::reply::json(&response),
warp::http::StatusCode::NOT_FOUND,
))
}
}
async fn list_extensions_handler() -> Result<impl warp::Reply, Rejection> {
info!("Listing extensions");
@@ -331,13 +412,21 @@ async fn main() -> Result<(), anyhow::Error> {
.and(with_api_key(api_key.clone()))
.and_then(start_session_handler);
// Session reply endpoint
// Session reply endpoint
let reply_session = warp::path("session")
.and(warp::path("reply"))
.and(warp::post())
.and(warp::body::json())
.and(with_api_key(api_key.clone()))
.and_then(reply_session_handler);
// Session end endpoint
let end_session = warp::path("session")
.and(warp::path("end"))
.and(warp::post())
.and(warp::body::json())
.and(with_api_key(api_key.clone()))
.and_then(end_session_handler);
// List extensions endpoint
let list_extensions = warp::path("extensions")
@@ -354,6 +443,7 @@ async fn main() -> Result<(), anyhow::Error> {
// Combine all routes
let routes = start_session
.or(reply_session)
.or(end_session)
.or(list_extensions)
.or(get_provider_config);