From a974090b5e719ddcc6aec2399828458c0367f916 Mon Sep 17 00:00:00 2001 From: Aljaz Date: Thu, 29 May 2025 10:12:07 +0200 Subject: [PATCH 1/4] refactor(api): inject server state --- crates/goose-api/src/main.rs | 71 ++++++++++++++++++++++-------------- 1 file changed, 44 insertions(+), 27 deletions(-) diff --git a/crates/goose-api/src/main.rs b/crates/goose-api/src/main.rs index 9a68f872..99ce0157 100644 --- a/crates/goose-api/src/main.rs +++ b/crates/goose-api/src/main.rs @@ -1,7 +1,7 @@ use warp::{Filter, Rejection}; use warp::http::HeaderValue; use serde::{Deserialize, Serialize}; -use std::sync::LazyLock; +use std::sync::Arc; use goose::config::{Config, ExtensionEntry}; use goose::agents::{ extension::Envs, @@ -14,7 +14,6 @@ use uuid::Uuid; use goose::session::{self, Identifier}; use goose::agents::SessionConfig; use std::path::PathBuf; -use std::sync::Arc; use goose::providers::{create, providers}; use goose::model::ModelConfig; @@ -24,13 +23,11 @@ use config::{builder::DefaultState, ConfigBuilder, Environment, File}; use serde_json::Value; // Import the correct Value type use futures_util::TryStreamExt; -// Global extension manager for extension listing -static EXTENSION_MANAGER: LazyLock = LazyLock::new(|| ExtensionManager::default()); - -// Global agent for handling sessions -static AGENT: LazyLock> = LazyLock::new(|| { - tokio::sync::Mutex::new(Agent::new()) -}); +#[derive(Clone)] +struct ServerState { + agent: Arc>, + extension_manager: Arc>, +} #[derive(Debug, Serialize, Deserialize)] @@ -120,11 +117,12 @@ enum ExtensionConfigRequest { async fn start_session_handler( req: SessionRequest, + state: ServerState, _api_key: String, ) -> Result { info!("Starting session with prompt: {}", req.prompt); - let mut agent = AGENT.lock().await; + let mut agent = state.agent.lock().await; // Create a user message with the prompt let mut messages = vec![Message::user().with_text(&req.prompt)]; @@ -197,11 +195,12 @@ async fn start_session_handler( async fn reply_session_handler( req: SessionReplyRequest, + state: ServerState, _api_key: String, ) -> Result { info!("Replying to session with prompt: {}", req.prompt); - let mut agent = AGENT.lock().await; + let mut agent = state.agent.lock().await; let session_name = req.session_id.to_string(); let session_path = session::get_path(Identifier::Name(session_name.clone())); @@ -284,6 +283,7 @@ async fn reply_session_handler( async fn end_session_handler( req: EndSessionRequest, + _state: ServerState, _api_key: String, ) -> Result { let session_name = req.session_id.to_string(); @@ -310,10 +310,11 @@ async fn end_session_handler( } } -async fn list_extensions_handler() -> Result { +async fn list_extensions_handler(state: ServerState) -> Result { info!("Listing extensions"); - match EXTENSION_MANAGER.list_extensions().await { + let manager = state.extension_manager.lock().await; + match manager.list_extensions().await { Ok(exts) => { let response = ExtensionsResponse { extensions: exts }; Ok::(warp::reply::json(&response)) @@ -328,7 +329,7 @@ async fn list_extensions_handler() -> Result { } } -async fn get_provider_config_handler() -> Result { +async fn get_provider_config_handler(_state: ServerState) -> Result { info!("Getting provider configuration"); let config = Config::global(); @@ -343,6 +344,7 @@ async fn get_provider_config_handler() -> Result { async fn add_extension_handler( req: ExtensionConfigRequest, + state: ServerState, _api_key: String, ) -> Result { info!("Adding extension: {:?}", req); @@ -438,7 +440,7 @@ async fn add_extension_handler( } }; - let agent = AGENT.lock().await; + let agent = state.agent.lock().await; let result = agent.add_extension(extension).await; let resp = match result { @@ -453,10 +455,11 @@ async fn add_extension_handler( async fn remove_extension_handler( name: String, + state: ServerState, _api_key: String, ) -> Result { info!("Removing extension: {}", name); - let agent = AGENT.lock().await; + let agent = state.agent.lock().await; agent.remove_extension(&name).await; let resp = ExtensionResponse { error: false, message: None }; @@ -488,7 +491,7 @@ fn load_configuration() -> std::result::Result Result<(), anyhow::Error> { +async fn initialize_provider_config(state: &ServerState) -> Result<(), anyhow::Error> { // Get configuration let api_config = load_configuration()?; @@ -535,14 +538,14 @@ async fn initialize_provider_config() -> Result<(), anyhow::Error> { let model_config = ModelConfig::new(model_name); let provider = create(&provider_name, model_config)?; - let agent = AGENT.lock().await; + let agent = state.agent.lock().await; agent.update_provider(provider).await?; info!("Provider configuration successful"); Ok(()) } /// Initialize extensions from the configuration. -async fn initialize_extensions(config: &config::Config) -> Result<(), anyhow::Error> { +async fn initialize_extensions(state: &ServerState, config: &config::Config) -> Result<(), anyhow::Error> { if let Ok(ext_table) = config.get_table("extensions") { for (name, ext_config) in ext_table { // Deserialize into ExtensionEntry to get enabled flag and config @@ -552,7 +555,7 @@ async fn initialize_extensions(config: &config::Config) -> Result<(), anyhow::Er if entry.enabled { let extension_config: ExtensionConfig = entry.config; // Acquire the global agent lock and try to add the extension - let mut agent = AGENT.lock().await; + let mut agent = state.agent.lock().await; if let Err(e) = agent.add_extension(extension_config).await { error!("Failed to add extension {}: {}", name, e); } @@ -567,10 +570,10 @@ async fn initialize_extensions(config: &config::Config) -> Result<(), anyhow::Er } -async fn run_init_tests() -> Result<(), anyhow::Error> { +async fn run_init_tests(state: &ServerState) -> Result<(), anyhow::Error> { info!("Running initialization tests"); { - let _agent = AGENT.lock().await; + let _agent = state.agent.lock().await; info!("Agent initialization test passed"); } info!("Initialization tests completed"); @@ -597,26 +600,34 @@ async fn main() -> Result<(), anyhow::Error> { "default_api_key".to_string() }); + let state = ServerState { + agent: Arc::new(tokio::sync::Mutex::new(Agent::new())), + extension_manager: Arc::new(tokio::sync::Mutex::new(ExtensionManager::default())), + }; + // Initialize provider configuration - if let Err(e) = initialize_provider_config().await { + if let Err(e) = initialize_provider_config(&state).await { error!("Failed to initialize provider: {}", e); return Err(e); } - + // Initialize extensions from configuration - if let Err(e) = initialize_extensions(&api_config).await { + if let Err(e) = initialize_extensions(&state, &api_config).await { error!("Failed to initialize extensions: {}", e); } - - if let Err(e) = run_init_tests().await { + + if let Err(e) = run_init_tests(&state).await { error!("Initialization tests failed: {}", e); } + let state_filter = warp::any().map(move || state.clone()); + // Session start endpoint let start_session = warp::path("session") .and(warp::path("start")) .and(warp::post()) .and(warp::body::json()) + .and(state_filter.clone()) .and(with_api_key(api_key.clone())) .and_then(start_session_handler); @@ -625,6 +636,7 @@ async fn main() -> Result<(), anyhow::Error> { .and(warp::path("reply")) .and(warp::post()) .and(warp::body::json()) + .and(state_filter.clone()) .and(with_api_key(api_key.clone())) .and_then(reply_session_handler); @@ -633,6 +645,7 @@ async fn main() -> Result<(), anyhow::Error> { .and(warp::path("end")) .and(warp::post()) .and(warp::body::json()) + .and(state_filter.clone()) .and(with_api_key(api_key.clone())) .and_then(end_session_handler); @@ -640,6 +653,7 @@ async fn main() -> Result<(), anyhow::Error> { let list_extensions = warp::path("extensions") .and(warp::path("list")) .and(warp::get()) + .and(state_filter.clone()) .and_then(list_extensions_handler); // Add extension endpoint @@ -647,6 +661,7 @@ async fn main() -> Result<(), anyhow::Error> { .and(warp::path("add")) .and(warp::post()) .and(warp::body::json()) + .and(state_filter.clone()) .and(with_api_key(api_key.clone())) .and_then(add_extension_handler); @@ -655,6 +670,7 @@ async fn main() -> Result<(), anyhow::Error> { .and(warp::path("remove")) .and(warp::post()) .and(warp::body::json()) + .and(state_filter.clone()) .and(with_api_key(api_key.clone())) .and_then(remove_extension_handler); @@ -662,6 +678,7 @@ async fn main() -> Result<(), anyhow::Error> { let get_provider_config = warp::path("provider") .and(warp::path("config")) .and(warp::get()) + .and(state_filter.clone()) .and_then(get_provider_config_handler); // Combine all routes From 0a9cd1eea712fd463cd87069c55879774cb6844a Mon Sep 17 00:00:00 2001 From: Aljaz Date: Thu, 29 May 2025 10:12:12 +0200 Subject: [PATCH 2/4] api: return 401 for invalid api key --- crates/goose-api/src/main.rs | 59 ++++++++++++++++++++++++++++++++++-- 1 file changed, 56 insertions(+), 3 deletions(-) diff --git a/crates/goose-api/src/main.rs b/crates/goose-api/src/main.rs index 9a68f872..71660b33 100644 --- a/crates/goose-api/src/main.rs +++ b/crates/goose-api/src/main.rs @@ -1,5 +1,6 @@ -use warp::{Filter, Rejection}; +use warp::{Filter, Rejection, Reply}; use warp::http::HeaderValue; +use std::convert::Infallible; use serde::{Deserialize, Serialize}; use std::sync::LazyLock; use goose::config::{Config, ExtensionEntry}; @@ -463,6 +464,24 @@ async fn remove_extension_handler( Ok(warp::reply::json(&resp)) } +#[derive(Debug)] +struct Unauthorized; + +impl warp::reject::Reject for Unauthorized {} + +async fn handle_rejection(err: Rejection) -> Result { + if err.find::().is_some() { + Ok(warp::reply::with_status("UNAUTHORIZED", warp::http::StatusCode::UNAUTHORIZED)) + } else if err.is_not_found() { + Ok(warp::reply::with_status("NOT_FOUND", warp::http::StatusCode::NOT_FOUND)) + } else { + Ok(warp::reply::with_status( + "INTERNAL_SERVER_ERROR", + warp::http::StatusCode::INTERNAL_SERVER_ERROR, + )) + } +} + fn with_api_key(api_key: String) -> impl Filter + Clone { warp::header::value("x-api-key") .and_then(move |header_api_key: HeaderValue| { @@ -471,12 +490,45 @@ fn with_api_key(api_key: String) -> impl Filter std::result::Result { let config_path = std::env::var("GOOSE_CONFIG").unwrap_or_else(|_| "config".to_string()); @@ -671,7 +723,8 @@ async fn main() -> Result<(), anyhow::Error> { .or(list_extensions) .or(add_extension) .or(remove_extension) - .or(get_provider_config); + .or(get_provider_config) + .recover(handle_rejection); // Get bind address from configuration or use default let host = std::env::var("GOOSE_API_HOST") From 45bddbdf1230791a95b1f9ed6a723bc1c25bc1d4 Mon Sep 17 00:00:00 2001 From: Aljaz Date: Thu, 29 May 2025 10:12:40 +0200 Subject: [PATCH 3/4] test: serialize env-modifying tests --- crates/goose/src/tracing/langfuse_layer.rs | 2 ++ crates/goose/tests/providers.rs | 12 ++++++++++++ 2 files changed, 14 insertions(+) diff --git a/crates/goose/src/tracing/langfuse_layer.rs b/crates/goose/src/tracing/langfuse_layer.rs index 2ac418cf..4bf19376 100644 --- a/crates/goose/src/tracing/langfuse_layer.rs +++ b/crates/goose/src/tracing/langfuse_layer.rs @@ -187,6 +187,7 @@ mod tests { use super::*; use serde_json::json; use std::collections::HashMap; + use serial_test::serial; use tokio::sync::Mutex; use tracing::dispatcher; use wiremock::matchers::{method, path}; @@ -389,6 +390,7 @@ mod tests { } #[tokio::test] + #[serial] async fn test_create_langfuse_observer() { let fixture = TestFixture::new().await.with_mock_server().await; diff --git a/crates/goose/tests/providers.rs b/crates/goose/tests/providers.rs index e65aff66..d18d4226 100644 --- a/crates/goose/tests/providers.rs +++ b/crates/goose/tests/providers.rs @@ -8,6 +8,7 @@ use goose::providers::{ }; use mcp_core::content::Content; use mcp_core::tool::Tool; +use serial_test::serial; use std::collections::HashMap; use std::sync::Arc; use std::sync::Mutex; @@ -352,6 +353,7 @@ where } #[tokio::test] +#[serial] async fn test_openai_provider() -> Result<()> { test_provider( "OpenAI", @@ -363,6 +365,7 @@ async fn test_openai_provider() -> Result<()> { } #[tokio::test] +#[serial] async fn test_azure_provider() -> Result<()> { test_provider( "Azure", @@ -378,6 +381,7 @@ async fn test_azure_provider() -> Result<()> { } #[tokio::test] +#[serial] async fn test_bedrock_provider_long_term_credentials() -> Result<()> { test_provider( "Bedrock", @@ -389,6 +393,7 @@ async fn test_bedrock_provider_long_term_credentials() -> Result<()> { } #[tokio::test] +#[serial] async fn test_bedrock_provider_aws_profile_credentials() -> Result<()> { let env_mods = HashMap::from_iter([ // Ensure to unset long-term credentials to use AWS Profile provider @@ -406,6 +411,7 @@ async fn test_bedrock_provider_aws_profile_credentials() -> Result<()> { } #[tokio::test] +#[serial] async fn test_databricks_provider() -> Result<()> { test_provider( "Databricks", @@ -417,6 +423,7 @@ async fn test_databricks_provider() -> Result<()> { } #[tokio::test] +#[serial] async fn test_databricks_provider_oauth() -> Result<()> { let mut env_mods = HashMap::new(); env_mods.insert("DATABRICKS_TOKEN", None); @@ -431,6 +438,7 @@ async fn test_databricks_provider_oauth() -> Result<()> { } #[tokio::test] +#[serial] async fn test_ollama_provider() -> Result<()> { test_provider( "Ollama", @@ -442,11 +450,13 @@ async fn test_ollama_provider() -> Result<()> { } #[tokio::test] +#[serial] async fn test_groq_provider() -> Result<()> { test_provider("Groq", &["GROQ_API_KEY"], None, groq::GroqProvider::default).await } #[tokio::test] +#[serial] async fn test_anthropic_provider() -> Result<()> { test_provider( "Anthropic", @@ -458,6 +468,7 @@ async fn test_anthropic_provider() -> Result<()> { } #[tokio::test] +#[serial] async fn test_openrouter_provider() -> Result<()> { test_provider( "OpenRouter", @@ -469,6 +480,7 @@ async fn test_openrouter_provider() -> Result<()> { } #[tokio::test] +#[serial] async fn test_google_provider() -> Result<()> { test_provider( "Google", From f6e305958e9c38c460ee4c4cc61ef8931fac0322 Mon Sep 17 00:00:00 2001 From: Aljaz Date: Thu, 29 May 2025 10:13:24 +0200 Subject: [PATCH 4/4] docs(api): update implementation status --- crates/goose-api/README.md | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/crates/goose-api/README.md b/crates/goose-api/README.md index 86e49ae9..3b5469a4 100644 --- a/crates/goose-api/README.md +++ b/crates/goose-api/README.md @@ -355,14 +355,24 @@ The current implementation includes the following features from the implementati 🟡 **Step 7**: Extension loading mechanism (partial implementation) 🟡 **Step 8**: MCP support (partial implementation) -✅ **Step 10**: Documentation -❌ **Step 11**: Tests (not yet implemented) +✅ **Step 10**: Documentation +✅ **Step 11**: Tests + +## Running Tests + +Run all unit and integration tests with: + +```bash +cargo test +``` + +This command executes the entire workspace test suite. To test a single crate, use `cargo test -p `. ## Future Work - Extend session management capabilities - Add more comprehensive error handling -- Implement unit and integration tests +- Expand unit and integration tests - Complete MCP integration - Add metrics and monitoring - Add OpenAPI documentation generation \ No newline at end of file