mirror of
https://github.com/aljazceru/goose.git
synced 2025-12-18 14:44:21 +01:00
Add session management API
This commit is contained in:
@@ -18,4 +18,6 @@ config = "0.13"
|
||||
jsonwebtoken = "8"
|
||||
futures = "0.3"
|
||||
futures-util = "0.3"
|
||||
# For session IDs
|
||||
uuid = { version = "1", features = ["serde", "v4"] }
|
||||
# Add dynamic-library for extension loading
|
||||
|
||||
@@ -2,6 +2,8 @@ use warp::{Filter, Rejection};
|
||||
use warp::http::HeaderValue;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::LazyLock;
|
||||
use std::collections::HashMap;
|
||||
use uuid::Uuid;
|
||||
use goose::agents::{Agent, extension_manager::ExtensionManager};
|
||||
use goose::config::Config;
|
||||
use goose::providers::{create, providers};
|
||||
@@ -20,6 +22,11 @@ static AGENT: LazyLock<tokio::sync::Mutex<Agent>> = LazyLock::new(|| {
|
||||
tokio::sync::Mutex::new(Agent::new())
|
||||
});
|
||||
|
||||
// Global store for session histories
|
||||
static SESSION_HISTORY: LazyLock<tokio::sync::Mutex<HashMap<Uuid, Vec<Message>>>> = LazyLock::new(|| {
|
||||
tokio::sync::Mutex::new(HashMap::new())
|
||||
});
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct SessionRequest {
|
||||
prompt: String,
|
||||
@@ -31,6 +38,24 @@ struct ApiResponse {
|
||||
status: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct StartSessionResponse {
|
||||
message: String,
|
||||
status: String,
|
||||
session_id: Uuid,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct SessionReplyRequest {
|
||||
session_id: Uuid,
|
||||
prompt: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct EndSessionRequest {
|
||||
session_id: Uuid,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct ExtensionsResponse {
|
||||
extensions: Vec<String>,
|
||||
@@ -47,32 +72,43 @@ async fn start_session_handler(
|
||||
_api_key: String,
|
||||
) -> Result<impl warp::Reply, Rejection> {
|
||||
info!("Starting session with prompt: {}", req.prompt);
|
||||
|
||||
let agent = AGENT.lock().await;
|
||||
|
||||
|
||||
let mut agent = AGENT.lock().await;
|
||||
|
||||
// Create a user message with the prompt
|
||||
let messages = vec![Message::user().with_text(&req.prompt)];
|
||||
|
||||
// Process the messages through the agent
|
||||
let mut messages = vec![Message::user().with_text(&req.prompt)];
|
||||
|
||||
// Generate a new session ID and process the messages
|
||||
let session_id = Uuid::new_v4();
|
||||
|
||||
let result = agent.reply(&messages, None).await;
|
||||
|
||||
|
||||
match result {
|
||||
Ok(mut stream) => {
|
||||
// Process the stream to get the first response
|
||||
if let Ok(Some(response)) = stream.try_next().await {
|
||||
let response_text = response.as_concat_text();
|
||||
let api_response = ApiResponse {
|
||||
message: format!("Session started with prompt: {}. Response: {}", req.prompt, response_text),
|
||||
messages.push(response);
|
||||
let mut history = SESSION_HISTORY.lock().await;
|
||||
history.insert(session_id, messages);
|
||||
|
||||
let api_response = StartSessionResponse {
|
||||
message: response_text,
|
||||
status: "success".to_string(),
|
||||
session_id,
|
||||
};
|
||||
Ok(warp::reply::with_status(
|
||||
warp::reply::json(&api_response),
|
||||
warp::http::StatusCode::OK,
|
||||
))
|
||||
} else {
|
||||
let api_response = ApiResponse {
|
||||
message: format!("Session started but no response generated"),
|
||||
let mut history = SESSION_HISTORY.lock().await;
|
||||
history.insert(session_id, messages);
|
||||
|
||||
let api_response = StartSessionResponse {
|
||||
message: "Session started but no response generated".to_string(),
|
||||
status: "warning".to_string(),
|
||||
session_id,
|
||||
};
|
||||
Ok(warp::reply::with_status(
|
||||
warp::reply::json(&api_response),
|
||||
@@ -95,16 +131,33 @@ async fn start_session_handler(
|
||||
}
|
||||
|
||||
async fn reply_session_handler(
|
||||
req: SessionRequest,
|
||||
req: SessionReplyRequest,
|
||||
_api_key: String,
|
||||
) -> Result<impl warp::Reply, Rejection> {
|
||||
info!("Replying to session with prompt: {}", req.prompt);
|
||||
|
||||
let agent = AGENT.lock().await;
|
||||
|
||||
// Create a user message with the prompt
|
||||
let messages = vec![Message::user().with_text(&req.prompt)];
|
||||
|
||||
|
||||
let mut agent = AGENT.lock().await;
|
||||
|
||||
// Retrieve existing session history
|
||||
let mut history = SESSION_HISTORY.lock().await;
|
||||
let entry = match history.get_mut(&req.session_id) {
|
||||
Some(messages) => messages,
|
||||
None => {
|
||||
let response = ApiResponse {
|
||||
message: "Session not found".to_string(),
|
||||
status: "error".to_string(),
|
||||
};
|
||||
return Ok(warp::reply::with_status(
|
||||
warp::reply::json(&response),
|
||||
warp::http::StatusCode::NOT_FOUND,
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
// Append the new user message
|
||||
entry.push(Message::user().with_text(&req.prompt));
|
||||
let messages = entry.clone();
|
||||
|
||||
// Process the messages through the agent
|
||||
let result = agent.reply(&messages, None).await;
|
||||
|
||||
@@ -113,6 +166,8 @@ async fn reply_session_handler(
|
||||
// Process the stream to get the first response
|
||||
if let Ok(Some(response)) = stream.try_next().await {
|
||||
let response_text = response.as_concat_text();
|
||||
// store assistant response in history
|
||||
entry.push(response);
|
||||
let api_response = ApiResponse {
|
||||
message: format!("Reply: {}", response_text),
|
||||
status: "success".to_string(),
|
||||
@@ -123,7 +178,7 @@ async fn reply_session_handler(
|
||||
))
|
||||
} else {
|
||||
let api_response = ApiResponse {
|
||||
message: format!("Reply processed but no response generated"),
|
||||
message: "Reply processed but no response generated".to_string(),
|
||||
status: "warning".to_string(),
|
||||
};
|
||||
Ok(warp::reply::with_status(
|
||||
@@ -146,6 +201,32 @@ async fn reply_session_handler(
|
||||
}
|
||||
}
|
||||
|
||||
async fn end_session_handler(
|
||||
req: EndSessionRequest,
|
||||
_api_key: String,
|
||||
) -> Result<impl warp::Reply, Rejection> {
|
||||
let mut history = SESSION_HISTORY.lock().await;
|
||||
if history.remove(&req.session_id).is_some() {
|
||||
let response = ApiResponse {
|
||||
message: "Session ended".to_string(),
|
||||
status: "success".to_string(),
|
||||
};
|
||||
Ok(warp::reply::with_status(
|
||||
warp::reply::json(&response),
|
||||
warp::http::StatusCode::OK,
|
||||
))
|
||||
} else {
|
||||
let response = ApiResponse {
|
||||
message: "Session not found".to_string(),
|
||||
status: "error".to_string(),
|
||||
};
|
||||
Ok(warp::reply::with_status(
|
||||
warp::reply::json(&response),
|
||||
warp::http::StatusCode::NOT_FOUND,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
async fn list_extensions_handler() -> Result<impl warp::Reply, Rejection> {
|
||||
info!("Listing extensions");
|
||||
|
||||
@@ -331,13 +412,21 @@ async fn main() -> Result<(), anyhow::Error> {
|
||||
.and(with_api_key(api_key.clone()))
|
||||
.and_then(start_session_handler);
|
||||
|
||||
// Session reply endpoint
|
||||
// Session reply endpoint
|
||||
let reply_session = warp::path("session")
|
||||
.and(warp::path("reply"))
|
||||
.and(warp::post())
|
||||
.and(warp::body::json())
|
||||
.and(with_api_key(api_key.clone()))
|
||||
.and_then(reply_session_handler);
|
||||
|
||||
// Session end endpoint
|
||||
let end_session = warp::path("session")
|
||||
.and(warp::path("end"))
|
||||
.and(warp::post())
|
||||
.and(warp::body::json())
|
||||
.and(with_api_key(api_key.clone()))
|
||||
.and_then(end_session_handler);
|
||||
|
||||
// List extensions endpoint
|
||||
let list_extensions = warp::path("extensions")
|
||||
@@ -354,6 +443,7 @@ async fn main() -> Result<(), anyhow::Error> {
|
||||
// Combine all routes
|
||||
let routes = start_session
|
||||
.or(reply_session)
|
||||
.or(end_session)
|
||||
.or(list_extensions)
|
||||
.or(get_provider_config);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user