mirror of
https://github.com/aljazceru/goose.git
synced 2025-12-18 22:54:24 +01:00
Add session management API
This commit is contained in:
@@ -18,4 +18,6 @@ config = "0.13"
|
|||||||
jsonwebtoken = "8"
|
jsonwebtoken = "8"
|
||||||
futures = "0.3"
|
futures = "0.3"
|
||||||
futures-util = "0.3"
|
futures-util = "0.3"
|
||||||
|
# For session IDs
|
||||||
|
uuid = { version = "1", features = ["serde", "v4"] }
|
||||||
# Add dynamic-library for extension loading
|
# Add dynamic-library for extension loading
|
||||||
|
|||||||
@@ -2,6 +2,8 @@ use warp::{Filter, Rejection};
|
|||||||
use warp::http::HeaderValue;
|
use warp::http::HeaderValue;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::sync::LazyLock;
|
use std::sync::LazyLock;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use uuid::Uuid;
|
||||||
use goose::agents::{Agent, extension_manager::ExtensionManager};
|
use goose::agents::{Agent, extension_manager::ExtensionManager};
|
||||||
use goose::config::Config;
|
use goose::config::Config;
|
||||||
use goose::providers::{create, providers};
|
use goose::providers::{create, providers};
|
||||||
@@ -20,6 +22,11 @@ static AGENT: LazyLock<tokio::sync::Mutex<Agent>> = LazyLock::new(|| {
|
|||||||
tokio::sync::Mutex::new(Agent::new())
|
tokio::sync::Mutex::new(Agent::new())
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// Global store for session histories
|
||||||
|
static SESSION_HISTORY: LazyLock<tokio::sync::Mutex<HashMap<Uuid, Vec<Message>>>> = LazyLock::new(|| {
|
||||||
|
tokio::sync::Mutex::new(HashMap::new())
|
||||||
|
});
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
struct SessionRequest {
|
struct SessionRequest {
|
||||||
prompt: String,
|
prompt: String,
|
||||||
@@ -31,6 +38,24 @@ struct ApiResponse {
|
|||||||
status: String,
|
status: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
struct StartSessionResponse {
|
||||||
|
message: String,
|
||||||
|
status: String,
|
||||||
|
session_id: Uuid,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
struct SessionReplyRequest {
|
||||||
|
session_id: Uuid,
|
||||||
|
prompt: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
struct EndSessionRequest {
|
||||||
|
session_id: Uuid,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
struct ExtensionsResponse {
|
struct ExtensionsResponse {
|
||||||
extensions: Vec<String>,
|
extensions: Vec<String>,
|
||||||
@@ -47,32 +72,43 @@ async fn start_session_handler(
|
|||||||
_api_key: String,
|
_api_key: String,
|
||||||
) -> Result<impl warp::Reply, Rejection> {
|
) -> Result<impl warp::Reply, Rejection> {
|
||||||
info!("Starting session with prompt: {}", req.prompt);
|
info!("Starting session with prompt: {}", req.prompt);
|
||||||
|
|
||||||
let agent = AGENT.lock().await;
|
let mut agent = AGENT.lock().await;
|
||||||
|
|
||||||
// Create a user message with the prompt
|
// Create a user message with the prompt
|
||||||
let messages = vec![Message::user().with_text(&req.prompt)];
|
let mut messages = vec![Message::user().with_text(&req.prompt)];
|
||||||
|
|
||||||
// Process the messages through the agent
|
// Generate a new session ID and process the messages
|
||||||
|
let session_id = Uuid::new_v4();
|
||||||
|
|
||||||
let result = agent.reply(&messages, None).await;
|
let result = agent.reply(&messages, None).await;
|
||||||
|
|
||||||
match result {
|
match result {
|
||||||
Ok(mut stream) => {
|
Ok(mut stream) => {
|
||||||
// Process the stream to get the first response
|
// Process the stream to get the first response
|
||||||
if let Ok(Some(response)) = stream.try_next().await {
|
if let Ok(Some(response)) = stream.try_next().await {
|
||||||
let response_text = response.as_concat_text();
|
let response_text = response.as_concat_text();
|
||||||
let api_response = ApiResponse {
|
messages.push(response);
|
||||||
message: format!("Session started with prompt: {}. Response: {}", req.prompt, response_text),
|
let mut history = SESSION_HISTORY.lock().await;
|
||||||
|
history.insert(session_id, messages);
|
||||||
|
|
||||||
|
let api_response = StartSessionResponse {
|
||||||
|
message: response_text,
|
||||||
status: "success".to_string(),
|
status: "success".to_string(),
|
||||||
|
session_id,
|
||||||
};
|
};
|
||||||
Ok(warp::reply::with_status(
|
Ok(warp::reply::with_status(
|
||||||
warp::reply::json(&api_response),
|
warp::reply::json(&api_response),
|
||||||
warp::http::StatusCode::OK,
|
warp::http::StatusCode::OK,
|
||||||
))
|
))
|
||||||
} else {
|
} else {
|
||||||
let api_response = ApiResponse {
|
let mut history = SESSION_HISTORY.lock().await;
|
||||||
message: format!("Session started but no response generated"),
|
history.insert(session_id, messages);
|
||||||
|
|
||||||
|
let api_response = StartSessionResponse {
|
||||||
|
message: "Session started but no response generated".to_string(),
|
||||||
status: "warning".to_string(),
|
status: "warning".to_string(),
|
||||||
|
session_id,
|
||||||
};
|
};
|
||||||
Ok(warp::reply::with_status(
|
Ok(warp::reply::with_status(
|
||||||
warp::reply::json(&api_response),
|
warp::reply::json(&api_response),
|
||||||
@@ -95,16 +131,33 @@ async fn start_session_handler(
|
|||||||
}
|
}
|
||||||
|
|
||||||
async fn reply_session_handler(
|
async fn reply_session_handler(
|
||||||
req: SessionRequest,
|
req: SessionReplyRequest,
|
||||||
_api_key: String,
|
_api_key: String,
|
||||||
) -> Result<impl warp::Reply, Rejection> {
|
) -> Result<impl warp::Reply, Rejection> {
|
||||||
info!("Replying to session with prompt: {}", req.prompt);
|
info!("Replying to session with prompt: {}", req.prompt);
|
||||||
|
|
||||||
let agent = AGENT.lock().await;
|
let mut agent = AGENT.lock().await;
|
||||||
|
|
||||||
// Create a user message with the prompt
|
// Retrieve existing session history
|
||||||
let messages = vec![Message::user().with_text(&req.prompt)];
|
let mut history = SESSION_HISTORY.lock().await;
|
||||||
|
let entry = match history.get_mut(&req.session_id) {
|
||||||
|
Some(messages) => messages,
|
||||||
|
None => {
|
||||||
|
let response = ApiResponse {
|
||||||
|
message: "Session not found".to_string(),
|
||||||
|
status: "error".to_string(),
|
||||||
|
};
|
||||||
|
return Ok(warp::reply::with_status(
|
||||||
|
warp::reply::json(&response),
|
||||||
|
warp::http::StatusCode::NOT_FOUND,
|
||||||
|
));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Append the new user message
|
||||||
|
entry.push(Message::user().with_text(&req.prompt));
|
||||||
|
let messages = entry.clone();
|
||||||
|
|
||||||
// Process the messages through the agent
|
// Process the messages through the agent
|
||||||
let result = agent.reply(&messages, None).await;
|
let result = agent.reply(&messages, None).await;
|
||||||
|
|
||||||
@@ -113,6 +166,8 @@ async fn reply_session_handler(
|
|||||||
// Process the stream to get the first response
|
// Process the stream to get the first response
|
||||||
if let Ok(Some(response)) = stream.try_next().await {
|
if let Ok(Some(response)) = stream.try_next().await {
|
||||||
let response_text = response.as_concat_text();
|
let response_text = response.as_concat_text();
|
||||||
|
// store assistant response in history
|
||||||
|
entry.push(response);
|
||||||
let api_response = ApiResponse {
|
let api_response = ApiResponse {
|
||||||
message: format!("Reply: {}", response_text),
|
message: format!("Reply: {}", response_text),
|
||||||
status: "success".to_string(),
|
status: "success".to_string(),
|
||||||
@@ -123,7 +178,7 @@ async fn reply_session_handler(
|
|||||||
))
|
))
|
||||||
} else {
|
} else {
|
||||||
let api_response = ApiResponse {
|
let api_response = ApiResponse {
|
||||||
message: format!("Reply processed but no response generated"),
|
message: "Reply processed but no response generated".to_string(),
|
||||||
status: "warning".to_string(),
|
status: "warning".to_string(),
|
||||||
};
|
};
|
||||||
Ok(warp::reply::with_status(
|
Ok(warp::reply::with_status(
|
||||||
@@ -146,6 +201,32 @@ async fn reply_session_handler(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn end_session_handler(
|
||||||
|
req: EndSessionRequest,
|
||||||
|
_api_key: String,
|
||||||
|
) -> Result<impl warp::Reply, Rejection> {
|
||||||
|
let mut history = SESSION_HISTORY.lock().await;
|
||||||
|
if history.remove(&req.session_id).is_some() {
|
||||||
|
let response = ApiResponse {
|
||||||
|
message: "Session ended".to_string(),
|
||||||
|
status: "success".to_string(),
|
||||||
|
};
|
||||||
|
Ok(warp::reply::with_status(
|
||||||
|
warp::reply::json(&response),
|
||||||
|
warp::http::StatusCode::OK,
|
||||||
|
))
|
||||||
|
} else {
|
||||||
|
let response = ApiResponse {
|
||||||
|
message: "Session not found".to_string(),
|
||||||
|
status: "error".to_string(),
|
||||||
|
};
|
||||||
|
Ok(warp::reply::with_status(
|
||||||
|
warp::reply::json(&response),
|
||||||
|
warp::http::StatusCode::NOT_FOUND,
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
async fn list_extensions_handler() -> Result<impl warp::Reply, Rejection> {
|
async fn list_extensions_handler() -> Result<impl warp::Reply, Rejection> {
|
||||||
info!("Listing extensions");
|
info!("Listing extensions");
|
||||||
|
|
||||||
@@ -331,13 +412,21 @@ async fn main() -> Result<(), anyhow::Error> {
|
|||||||
.and(with_api_key(api_key.clone()))
|
.and(with_api_key(api_key.clone()))
|
||||||
.and_then(start_session_handler);
|
.and_then(start_session_handler);
|
||||||
|
|
||||||
// Session reply endpoint
|
// Session reply endpoint
|
||||||
let reply_session = warp::path("session")
|
let reply_session = warp::path("session")
|
||||||
.and(warp::path("reply"))
|
.and(warp::path("reply"))
|
||||||
.and(warp::post())
|
.and(warp::post())
|
||||||
.and(warp::body::json())
|
.and(warp::body::json())
|
||||||
.and(with_api_key(api_key.clone()))
|
.and(with_api_key(api_key.clone()))
|
||||||
.and_then(reply_session_handler);
|
.and_then(reply_session_handler);
|
||||||
|
|
||||||
|
// Session end endpoint
|
||||||
|
let end_session = warp::path("session")
|
||||||
|
.and(warp::path("end"))
|
||||||
|
.and(warp::post())
|
||||||
|
.and(warp::body::json())
|
||||||
|
.and(with_api_key(api_key.clone()))
|
||||||
|
.and_then(end_session_handler);
|
||||||
|
|
||||||
// List extensions endpoint
|
// List extensions endpoint
|
||||||
let list_extensions = warp::path("extensions")
|
let list_extensions = warp::path("extensions")
|
||||||
@@ -354,6 +443,7 @@ async fn main() -> Result<(), anyhow::Error> {
|
|||||||
// Combine all routes
|
// Combine all routes
|
||||||
let routes = start_session
|
let routes = start_session
|
||||||
.or(reply_session)
|
.or(reply_session)
|
||||||
|
.or(end_session)
|
||||||
.or(list_extensions)
|
.or(list_extensions)
|
||||||
.or(get_provider_config);
|
.or(get_provider_config);
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user