mirror of
https://github.com/aljazceru/goose.git
synced 2025-12-20 15:44:25 +01:00
refactor(api): inject server state
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
use warp::{Filter, Rejection};
|
||||
use warp::http::HeaderValue;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::LazyLock;
|
||||
use std::sync::Arc;
|
||||
use goose::config::{Config, ExtensionEntry};
|
||||
use goose::agents::{
|
||||
extension::Envs,
|
||||
@@ -14,7 +14,6 @@ use uuid::Uuid;
|
||||
use goose::session::{self, Identifier};
|
||||
use goose::agents::SessionConfig;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
use goose::providers::{create, providers};
|
||||
use goose::model::ModelConfig;
|
||||
@@ -24,13 +23,11 @@ use config::{builder::DefaultState, ConfigBuilder, Environment, File};
|
||||
use serde_json::Value; // Import the correct Value type
|
||||
use futures_util::TryStreamExt;
|
||||
|
||||
// Global extension manager for extension listing
|
||||
static EXTENSION_MANAGER: LazyLock<ExtensionManager> = LazyLock::new(|| ExtensionManager::default());
|
||||
|
||||
// Global agent for handling sessions
|
||||
static AGENT: LazyLock<tokio::sync::Mutex<Agent>> = LazyLock::new(|| {
|
||||
tokio::sync::Mutex::new(Agent::new())
|
||||
});
|
||||
#[derive(Clone)]
|
||||
struct ServerState {
|
||||
agent: Arc<tokio::sync::Mutex<Agent>>,
|
||||
extension_manager: Arc<tokio::sync::Mutex<ExtensionManager>>,
|
||||
}
|
||||
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
@@ -120,11 +117,12 @@ enum ExtensionConfigRequest {
|
||||
|
||||
async fn start_session_handler(
|
||||
req: SessionRequest,
|
||||
state: ServerState,
|
||||
_api_key: String,
|
||||
) -> Result<impl warp::Reply, Rejection> {
|
||||
info!("Starting session with prompt: {}", req.prompt);
|
||||
|
||||
let mut agent = AGENT.lock().await;
|
||||
let mut agent = state.agent.lock().await;
|
||||
|
||||
// Create a user message with the prompt
|
||||
let mut messages = vec![Message::user().with_text(&req.prompt)];
|
||||
@@ -197,11 +195,12 @@ async fn start_session_handler(
|
||||
|
||||
async fn reply_session_handler(
|
||||
req: SessionReplyRequest,
|
||||
state: ServerState,
|
||||
_api_key: String,
|
||||
) -> Result<impl warp::Reply, Rejection> {
|
||||
info!("Replying to session with prompt: {}", req.prompt);
|
||||
|
||||
let mut agent = AGENT.lock().await;
|
||||
let mut agent = state.agent.lock().await;
|
||||
|
||||
let session_name = req.session_id.to_string();
|
||||
let session_path = session::get_path(Identifier::Name(session_name.clone()));
|
||||
@@ -284,6 +283,7 @@ async fn reply_session_handler(
|
||||
|
||||
async fn end_session_handler(
|
||||
req: EndSessionRequest,
|
||||
_state: ServerState,
|
||||
_api_key: String,
|
||||
) -> Result<impl warp::Reply, Rejection> {
|
||||
let session_name = req.session_id.to_string();
|
||||
@@ -310,10 +310,11 @@ async fn end_session_handler(
|
||||
}
|
||||
}
|
||||
|
||||
async fn list_extensions_handler() -> Result<impl warp::Reply, Rejection> {
|
||||
async fn list_extensions_handler(state: ServerState) -> Result<impl warp::Reply, Rejection> {
|
||||
info!("Listing extensions");
|
||||
|
||||
match EXTENSION_MANAGER.list_extensions().await {
|
||||
let manager = state.extension_manager.lock().await;
|
||||
match manager.list_extensions().await {
|
||||
Ok(exts) => {
|
||||
let response = ExtensionsResponse { extensions: exts };
|
||||
Ok::<warp::reply::Json, warp::Rejection>(warp::reply::json(&response))
|
||||
@@ -328,7 +329,7 @@ async fn list_extensions_handler() -> Result<impl warp::Reply, Rejection> {
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_provider_config_handler() -> Result<impl warp::Reply, Rejection> {
|
||||
async fn get_provider_config_handler(_state: ServerState) -> Result<impl warp::Reply, Rejection> {
|
||||
info!("Getting provider configuration");
|
||||
|
||||
let config = Config::global();
|
||||
@@ -343,6 +344,7 @@ async fn get_provider_config_handler() -> Result<impl warp::Reply, Rejection> {
|
||||
|
||||
async fn add_extension_handler(
|
||||
req: ExtensionConfigRequest,
|
||||
state: ServerState,
|
||||
_api_key: String,
|
||||
) -> Result<impl warp::Reply, Rejection> {
|
||||
info!("Adding extension: {:?}", req);
|
||||
@@ -438,7 +440,7 @@ async fn add_extension_handler(
|
||||
}
|
||||
};
|
||||
|
||||
let agent = AGENT.lock().await;
|
||||
let agent = state.agent.lock().await;
|
||||
let result = agent.add_extension(extension).await;
|
||||
|
||||
let resp = match result {
|
||||
@@ -453,10 +455,11 @@ async fn add_extension_handler(
|
||||
|
||||
async fn remove_extension_handler(
|
||||
name: String,
|
||||
state: ServerState,
|
||||
_api_key: String,
|
||||
) -> Result<impl warp::Reply, Rejection> {
|
||||
info!("Removing extension: {}", name);
|
||||
let agent = AGENT.lock().await;
|
||||
let agent = state.agent.lock().await;
|
||||
agent.remove_extension(&name).await;
|
||||
|
||||
let resp = ExtensionResponse { error: false, message: None };
|
||||
@@ -488,7 +491,7 @@ fn load_configuration() -> std::result::Result<config::Config, config::ConfigErr
|
||||
}
|
||||
|
||||
// Initialize global provider configuration
|
||||
async fn initialize_provider_config() -> Result<(), anyhow::Error> {
|
||||
async fn initialize_provider_config(state: &ServerState) -> Result<(), anyhow::Error> {
|
||||
// Get configuration
|
||||
let api_config = load_configuration()?;
|
||||
|
||||
@@ -535,14 +538,14 @@ async fn initialize_provider_config() -> Result<(), anyhow::Error> {
|
||||
let model_config = ModelConfig::new(model_name);
|
||||
let provider = create(&provider_name, model_config)?;
|
||||
|
||||
let agent = AGENT.lock().await;
|
||||
let agent = state.agent.lock().await;
|
||||
agent.update_provider(provider).await?;
|
||||
|
||||
info!("Provider configuration successful");
|
||||
Ok(())
|
||||
}
|
||||
/// Initialize extensions from the configuration.
|
||||
async fn initialize_extensions(config: &config::Config) -> Result<(), anyhow::Error> {
|
||||
async fn initialize_extensions(state: &ServerState, config: &config::Config) -> Result<(), anyhow::Error> {
|
||||
if let Ok(ext_table) = config.get_table("extensions") {
|
||||
for (name, ext_config) in ext_table {
|
||||
// Deserialize into ExtensionEntry to get enabled flag and config
|
||||
@@ -552,7 +555,7 @@ async fn initialize_extensions(config: &config::Config) -> Result<(), anyhow::Er
|
||||
if entry.enabled {
|
||||
let extension_config: ExtensionConfig = entry.config;
|
||||
// Acquire the global agent lock and try to add the extension
|
||||
let mut agent = AGENT.lock().await;
|
||||
let mut agent = state.agent.lock().await;
|
||||
if let Err(e) = agent.add_extension(extension_config).await {
|
||||
error!("Failed to add extension {}: {}", name, e);
|
||||
}
|
||||
@@ -567,10 +570,10 @@ async fn initialize_extensions(config: &config::Config) -> Result<(), anyhow::Er
|
||||
}
|
||||
|
||||
|
||||
async fn run_init_tests() -> Result<(), anyhow::Error> {
|
||||
async fn run_init_tests(state: &ServerState) -> Result<(), anyhow::Error> {
|
||||
info!("Running initialization tests");
|
||||
{
|
||||
let _agent = AGENT.lock().await;
|
||||
let _agent = state.agent.lock().await;
|
||||
info!("Agent initialization test passed");
|
||||
}
|
||||
info!("Initialization tests completed");
|
||||
@@ -597,26 +600,34 @@ async fn main() -> Result<(), anyhow::Error> {
|
||||
"default_api_key".to_string()
|
||||
});
|
||||
|
||||
let state = ServerState {
|
||||
agent: Arc::new(tokio::sync::Mutex::new(Agent::new())),
|
||||
extension_manager: Arc::new(tokio::sync::Mutex::new(ExtensionManager::default())),
|
||||
};
|
||||
|
||||
// Initialize provider configuration
|
||||
if let Err(e) = initialize_provider_config().await {
|
||||
if let Err(e) = initialize_provider_config(&state).await {
|
||||
error!("Failed to initialize provider: {}", e);
|
||||
return Err(e);
|
||||
}
|
||||
|
||||
|
||||
// Initialize extensions from configuration
|
||||
if let Err(e) = initialize_extensions(&api_config).await {
|
||||
if let Err(e) = initialize_extensions(&state, &api_config).await {
|
||||
error!("Failed to initialize extensions: {}", e);
|
||||
}
|
||||
|
||||
if let Err(e) = run_init_tests().await {
|
||||
|
||||
if let Err(e) = run_init_tests(&state).await {
|
||||
error!("Initialization tests failed: {}", e);
|
||||
}
|
||||
|
||||
let state_filter = warp::any().map(move || state.clone());
|
||||
|
||||
// Session start endpoint
|
||||
let start_session = warp::path("session")
|
||||
.and(warp::path("start"))
|
||||
.and(warp::post())
|
||||
.and(warp::body::json())
|
||||
.and(state_filter.clone())
|
||||
.and(with_api_key(api_key.clone()))
|
||||
.and_then(start_session_handler);
|
||||
|
||||
@@ -625,6 +636,7 @@ async fn main() -> Result<(), anyhow::Error> {
|
||||
.and(warp::path("reply"))
|
||||
.and(warp::post())
|
||||
.and(warp::body::json())
|
||||
.and(state_filter.clone())
|
||||
.and(with_api_key(api_key.clone()))
|
||||
.and_then(reply_session_handler);
|
||||
|
||||
@@ -633,6 +645,7 @@ async fn main() -> Result<(), anyhow::Error> {
|
||||
.and(warp::path("end"))
|
||||
.and(warp::post())
|
||||
.and(warp::body::json())
|
||||
.and(state_filter.clone())
|
||||
.and(with_api_key(api_key.clone()))
|
||||
.and_then(end_session_handler);
|
||||
|
||||
@@ -640,6 +653,7 @@ async fn main() -> Result<(), anyhow::Error> {
|
||||
let list_extensions = warp::path("extensions")
|
||||
.and(warp::path("list"))
|
||||
.and(warp::get())
|
||||
.and(state_filter.clone())
|
||||
.and_then(list_extensions_handler);
|
||||
|
||||
// Add extension endpoint
|
||||
@@ -647,6 +661,7 @@ async fn main() -> Result<(), anyhow::Error> {
|
||||
.and(warp::path("add"))
|
||||
.and(warp::post())
|
||||
.and(warp::body::json())
|
||||
.and(state_filter.clone())
|
||||
.and(with_api_key(api_key.clone()))
|
||||
.and_then(add_extension_handler);
|
||||
|
||||
@@ -655,6 +670,7 @@ async fn main() -> Result<(), anyhow::Error> {
|
||||
.and(warp::path("remove"))
|
||||
.and(warp::post())
|
||||
.and(warp::body::json())
|
||||
.and(state_filter.clone())
|
||||
.and(with_api_key(api_key.clone()))
|
||||
.and_then(remove_extension_handler);
|
||||
|
||||
@@ -662,6 +678,7 @@ async fn main() -> Result<(), anyhow::Error> {
|
||||
let get_provider_config = warp::path("provider")
|
||||
.and(warp::path("config"))
|
||||
.and(warp::get())
|
||||
.and(state_filter.clone())
|
||||
.and_then(get_provider_config_handler);
|
||||
|
||||
// Combine all routes
|
||||
|
||||
Reference in New Issue
Block a user