From 543bfddbd5dad0fa8eceb8043ecca3b84c8b72ad Mon Sep 17 00:00:00 2001 From: Aljaz Date: Thu, 29 May 2025 10:13:42 +0200 Subject: [PATCH] Refactor goose-api into modules --- crates/goose-api/src/config.rs | 97 +++++ crates/goose-api/src/handlers.rs | 449 ++++++++++++++++++++ crates/goose-api/src/lib.rs | 5 + crates/goose-api/src/main.rs | 704 +------------------------------ crates/goose-api/src/routes.rs | 123 ++++++ crates/goose-api/src/tests.rs | 10 + 6 files changed, 686 insertions(+), 702 deletions(-) create mode 100644 crates/goose-api/src/config.rs create mode 100644 crates/goose-api/src/handlers.rs create mode 100644 crates/goose-api/src/lib.rs create mode 100644 crates/goose-api/src/routes.rs create mode 100644 crates/goose-api/src/tests.rs diff --git a/crates/goose-api/src/config.rs b/crates/goose-api/src/config.rs new file mode 100644 index 00000000..9cafc919 --- /dev/null +++ b/crates/goose-api/src/config.rs @@ -0,0 +1,97 @@ +use crate::handlers::{AGENT, EXTENSION_MANAGER}; +use goose::config::{Config, ExtensionEntry}; +use goose::agents::ExtensionConfig; +use goose::providers::{create, providers}; +use goose::model::ModelConfig; +use tracing::{info, warn, error}; +use config::{builder::DefaultState, ConfigBuilder, Environment, File}; +use serde_json::Value; + +pub fn load_configuration() -> std::result::Result { + let config_path = std::env::var("GOOSE_CONFIG").unwrap_or_else(|_| "config".to_string()); + let builder = ConfigBuilder::::default() + .add_source(File::with_name(&config_path).required(false)) + .add_source(Environment::with_prefix("GOOSE_API")); + builder.build() +} + +pub async fn initialize_provider_config() -> Result<(), anyhow::Error> { + let api_config = load_configuration()?; + + let provider_name = std::env::var("GOOSE_API_PROVIDER") + .or_else(|_| api_config.get_string("provider")) + .unwrap_or_else(|_| "openai".to_string()); + + let model_name = std::env::var("GOOSE_API_MODEL") + .or_else(|_| api_config.get_string("model")) + .unwrap_or_else(|_| "gpt-4o".to_string()); + + info!("Initializing with provider: {}, model: {}", provider_name, model_name); + + let config = Config::global(); + config.set_param("GOOSE_PROVIDER", Value::String(provider_name.clone()))?; + config.set_param("GOOSE_MODEL", Value::String(model_name.clone()))?; + + let available_providers = providers(); + if let Some(provider_meta) = available_providers.iter().find(|p| p.name == provider_name) { + for key in &provider_meta.config_keys { + let env_name = key.name.clone(); + if let Ok(value) = std::env::var(&env_name) { + if key.secret { + config.set_secret(&key.name, Value::String(value))?; + info!("Set secret key: {}", key.name); + } else { + config.set_param(&key.name, Value::String(value))?; + info!("Set parameter: {}", key.name); + } + } else { + warn!("Environment variable not set for key: {}", key.name); + if key.required { + error!("Required key {} not provided", key.name); + return Err(anyhow::anyhow!("Required key {} not provided", key.name)); + } + } + } + } + + let model_config = ModelConfig::new(model_name); + let provider = create(&provider_name, model_config)?; + + let agent = AGENT.lock().await; + agent.update_provider(provider).await?; + + info!("Provider configuration successful"); + Ok(()) +} + +pub async fn initialize_extensions(config: &config::Config) -> Result<(), anyhow::Error> { + if let Ok(ext_table) = config.get_table("extensions") { + for (name, ext_config) in ext_table { + let entry: ExtensionEntry = ext_config.clone().try_deserialize() + .map_err(|e| anyhow::anyhow!("Failed to deserialize extension config for {}: {}", name, e))?; + + if entry.enabled { + let extension_config: ExtensionConfig = entry.config; + let mut agent = AGENT.lock().await; + if let Err(e) = agent.add_extension(extension_config).await { + error!("Failed to add extension {}: {}", name, e); + } + } else { + info!("Skipping disabled extension: {}", name); + } + } + } else { + warn!("No extensions configured in config file."); + } + Ok(()) +} + +pub async fn run_init_tests() -> Result<(), anyhow::Error> { + info!("Running initialization tests"); + { + let _agent = AGENT.lock().await; + info!("Agent initialization test passed"); + } + info!("Initialization tests completed"); + Ok(()) +} diff --git a/crates/goose-api/src/handlers.rs b/crates/goose-api/src/handlers.rs new file mode 100644 index 00000000..52a0ad8e --- /dev/null +++ b/crates/goose-api/src/handlers.rs @@ -0,0 +1,449 @@ +use warp::{http::HeaderValue, Filter, Rejection}; +use serde::{Deserialize, Serialize}; +use std::path::PathBuf; +use uuid::Uuid; +use futures_util::TryStreamExt; +use tracing::{info, warn, error}; +use mcp_core::tool::Tool; +use goose::agents::{extension::Envs, extension_manager::ExtensionManager, ExtensionConfig, Agent, SessionConfig}; +use goose::message::Message; +use goose::session::{self, Identifier}; +use goose::config::{Config, ExtensionEntry}; +use std::sync::LazyLock; + +pub static EXTENSION_MANAGER: LazyLock = LazyLock::new(|| ExtensionManager::default()); +pub static AGENT: LazyLock> = LazyLock::new(|| tokio::sync::Mutex::new(Agent::new())); + +#[derive(Debug, Serialize, Deserialize)] +pub struct SessionRequest { + pub prompt: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ApiResponse { + pub message: String, + pub status: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct StartSessionResponse { + pub message: String, + pub status: String, + pub session_id: Uuid, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct SessionReplyRequest { + pub session_id: Uuid, + pub prompt: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct EndSessionRequest { + pub session_id: Uuid, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ExtensionsResponse { + pub extensions: Vec, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ProviderConfig { + pub provider: String, + pub model: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ExtensionResponse { + pub error: bool, + pub message: Option, +} + +#[derive(Debug, Deserialize)] +#[serde(tag = "type")] +pub enum ExtensionConfigRequest { + #[serde(rename = "sse")] + Sse { + name: String, + uri: String, + #[serde(default)] + envs: Envs, + #[serde(default)] + env_keys: Vec, + timeout: Option, + }, + #[serde(rename = "stdio")] + Stdio { + name: String, + cmd: String, + #[serde(default)] + args: Vec, + #[serde(default)] + envs: Envs, + #[serde(default)] + env_keys: Vec, + timeout: Option, + }, + #[serde(rename = "builtin")] + Builtin { + name: String, + display_name: Option, + timeout: Option, + }, + #[serde(rename = "frontend")] + Frontend { + name: String, + tools: Vec, + instructions: Option, + }, +} + +pub async fn start_session_handler( + req: SessionRequest, + _api_key: String, +) -> Result { + info!("Starting session with prompt: {}", req.prompt); + + let mut agent = AGENT.lock().await; + let mut messages = vec![Message::user().with_text(&req.prompt)]; + let session_id = Uuid::new_v4(); + let session_name = session_id.to_string(); + let session_path = session::get_path(Identifier::Name(session_name.clone())); + + let provider = agent.provider().await.ok(); + + let result = agent + .reply( + &messages, + Some(SessionConfig { + id: Identifier::Name(session_name.clone()), + working_dir: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")), + }), + ) + .await; + + match result { + Ok(mut stream) => { + if let Ok(Some(response)) = stream.try_next().await { + let response_text = response.as_concat_text(); + messages.push(response); + if let Err(e) = session::persist_messages(&session_path, &messages, provider.clone()).await { + warn!("Failed to persist session {}: {}", session_name, e); + } + + let api_response = StartSessionResponse { + message: response_text, + status: "success".to_string(), + session_id, + }; + Ok(warp::reply::with_status( + warp::reply::json(&api_response), + warp::http::StatusCode::OK, + )) + } else { + if let Err(e) = session::persist_messages(&session_path, &messages, provider.clone()).await { + warn!("Failed to persist session {}: {}", session_name, e); + } + + let api_response = StartSessionResponse { + message: "Session started but no response generated".to_string(), + status: "warning".to_string(), + session_id, + }; + Ok(warp::reply::with_status( + warp::reply::json(&api_response), + warp::http::StatusCode::OK, + )) + } + } + Err(e) => { + error!("Failed to start session: {}", e); + let response = ApiResponse { + message: format!("Failed to start session: {}", e), + status: "error".to_string(), + }; + Ok(warp::reply::with_status( + warp::reply::json(&response), + warp::http::StatusCode::INTERNAL_SERVER_ERROR, + )) + } + } +} + +pub async fn reply_session_handler( + req: SessionReplyRequest, + _api_key: String, +) -> Result { + info!("Replying to session with prompt: {}", req.prompt); + + let mut agent = AGENT.lock().await; + + let session_name = req.session_id.to_string(); + let session_path = session::get_path(Identifier::Name(session_name.clone())); + + let mut messages = match session::read_messages(&session_path) { + Ok(m) => m, + Err(_) => { + let response = ApiResponse { + message: "Session not found".to_string(), + status: "error".to_string(), + }; + return Ok(warp::reply::with_status( + warp::reply::json(&response), + warp::http::StatusCode::NOT_FOUND, + )); + } + }; + + messages.push(Message::user().with_text(&req.prompt)); + + let provider = agent.provider().await.ok(); + + let result = agent + .reply( + &messages, + Some(SessionConfig { + id: Identifier::Name(session_name.clone()), + working_dir: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")), + }), + ) + .await; + + match result { + Ok(mut stream) => { + if let Ok(Some(response)) = stream.try_next().await { + let response_text = response.as_concat_text(); + messages.push(response); + if let Err(e) = session::persist_messages(&session_path, &messages, provider.clone()).await { + warn!("Failed to persist session {}: {}", session_name, e); + } + let api_response = ApiResponse { + message: format!("Reply: {}", response_text), + status: "success".to_string(), + }; + Ok(warp::reply::with_status( + warp::reply::json(&api_response), + warp::http::StatusCode::OK, + )) + } else { + if let Err(e) = session::persist_messages(&session_path, &messages, provider.clone()).await { + warn!("Failed to persist session {}: {}", session_name, e); + } + let api_response = ApiResponse { + message: "Reply processed but no response generated".to_string(), + status: "warning".to_string(), + }; + Ok(warp::reply::with_status( + warp::reply::json(&api_response), + warp::http::StatusCode::OK, + )) + } + } + Err(e) => { + error!("Failed to reply to session: {}", e); + let response = ApiResponse { + message: format!("Failed to reply to session: {}", e), + status: "error".to_string(), + }; + Ok(warp::reply::with_status( + warp::reply::json(&response), + warp::http::StatusCode::INTERNAL_SERVER_ERROR, + )) + } + } +} + +pub async fn end_session_handler( + req: EndSessionRequest, + _api_key: String, +) -> Result { + let session_name = req.session_id.to_string(); + let session_path = session::get_path(Identifier::Name(session_name.clone())); + + if std::fs::remove_file(&session_path).is_ok() { + let response = ApiResponse { + message: "Session ended".to_string(), + status: "success".to_string(), + }; + Ok(warp::reply::with_status( + warp::reply::json(&response), + warp::http::StatusCode::OK, + )) + } else { + let response = ApiResponse { + message: "Session not found".to_string(), + status: "error".to_string(), + }; + Ok(warp::reply::with_status( + warp::reply::json(&response), + warp::http::StatusCode::NOT_FOUND, + )) + } +} + +pub async fn list_extensions_handler() -> Result { + info!("Listing extensions"); + + match EXTENSION_MANAGER.list_extensions().await { + Ok(exts) => { + let response = ExtensionsResponse { extensions: exts }; + Ok::(warp::reply::json(&response)) + } + Err(e) => { + error!("Failed to list extensions: {}", e); + let response = ExtensionsResponse { + extensions: vec!["Failed to list extensions".to_string()], + }; + Ok::(warp::reply::json(&response)) + } + } +} + +pub async fn get_provider_config_handler() -> Result { + info!("Getting provider configuration"); + + let config = Config::global(); + let provider = config + .get_param::("GOOSE_PROVIDER") + .unwrap_or_else(|_| "Not configured".to_string()); + let model = config + .get_param::("GOOSE_MODEL") + .unwrap_or_else(|_| "Not configured".to_string()); + + let response = ProviderConfig { provider, model }; + Ok::(warp::reply::json(&response)) +} + +pub async fn add_extension_handler( + req: ExtensionConfigRequest, + _api_key: String, +) -> Result { + info!("Adding extension: {:?}", req); + + #[cfg(target_os = "windows")] + if let ExtensionConfigRequest::Stdio { cmd, .. } = &req { + if cmd.ends_with("npx.cmd") || cmd.ends_with("npx") { + let node_exists = std::path::Path::new(r"C:\Program Files\nodejs\node.exe").exists() + || std::path::Path::new(r"C:\Program Files (x86)\nodejs\node.exe").exists(); + + if !node_exists { + let cmd_path = std::path::Path::new(cmd); + let script_dir = cmd_path.parent().ok_or_else(|| warp::reject())?; + + let install_script = script_dir.join("install-node.cmd"); + + if install_script.exists() { + eprintln!("Installing Node.js..."); + let output = std::process::Command::new(&install_script) + .arg("https://nodejs.org/dist/v23.10.0/node-v23.10.0-x64.msi") + .output() + .map_err(|_| warp::reject())?; + + if !output.status.success() { + eprintln!( + "Failed to install Node.js: {}", + String::from_utf8_lossy(&output.stderr) + ); + let resp = ExtensionResponse { + error: true, + message: Some(format!( + "Failed to install Node.js: {}", + String::from_utf8_lossy(&output.stderr) + )), + }; + return Ok(warp::reply::json(&resp)); + } + eprintln!("Node.js installation completed"); + } else { + eprintln!("Node.js installer script not found at: {}", install_script.display()); + let resp = ExtensionResponse { + error: true, + message: Some("Node.js installer script not found".to_string()), + }; + return Ok(warp::reply::json(&resp)); + } + } + } + } + + let extension = match req { + ExtensionConfigRequest::Sse { name, uri, envs, env_keys, timeout } => { + ExtensionConfig::Sse { + name, + uri, + envs, + env_keys, + description: None, + timeout, + bundled: None, + } + } + ExtensionConfigRequest::Stdio { name, cmd, args, envs, env_keys, timeout } => { + ExtensionConfig::Stdio { + name, + cmd, + args, + envs, + env_keys, + timeout, + description: None, + bundled: None, + } + } + ExtensionConfigRequest::Builtin { name, display_name, timeout } => { + ExtensionConfig::Builtin { + name, + display_name, + timeout, + bundled: None, + } + } + ExtensionConfigRequest::Frontend { name, tools, instructions } => { + ExtensionConfig::Frontend { + name, + tools, + instructions, + bundled: None, + } + } + }; + + let agent = AGENT.lock().await; + let result = agent.add_extension(extension).await; + + let resp = match result { + Ok(_) => ExtensionResponse { error: false, message: None }, + Err(e) => ExtensionResponse { + error: true, + message: Some(format!("Failed to add extension configuration, error: {:?}", e)), + }, + }; + Ok(warp::reply::json(&resp)) +} + +pub async fn remove_extension_handler( + name: String, + _api_key: String, +) -> Result { + info!("Removing extension: {}", name); + let agent = AGENT.lock().await; + agent.remove_extension(&name).await; + + let resp = ExtensionResponse { error: false, message: None }; + Ok(warp::reply::json(&resp)) +} + +pub fn with_api_key(api_key: String) -> impl Filter + Clone { + warp::header::value("x-api-key") + .and_then(move |header_api_key: HeaderValue| { + let api_key = api_key.clone(); + async move { + if header_api_key == api_key { + Ok(api_key) + } else { + Err(warp::reject::not_found()) + } + } + }) +} diff --git a/crates/goose-api/src/lib.rs b/crates/goose-api/src/lib.rs new file mode 100644 index 00000000..b2037198 --- /dev/null +++ b/crates/goose-api/src/lib.rs @@ -0,0 +1,5 @@ +mod handlers; +mod config; +mod routes; + +pub use routes::{build_routes, run_server}; diff --git a/crates/goose-api/src/main.rs b/crates/goose-api/src/main.rs index 9a68f872..451d0428 100644 --- a/crates/goose-api/src/main.rs +++ b/crates/goose-api/src/main.rs @@ -1,706 +1,6 @@ -use warp::{Filter, Rejection}; -use warp::http::HeaderValue; -use serde::{Deserialize, Serialize}; -use std::sync::LazyLock; -use goose::config::{Config, ExtensionEntry}; -use goose::agents::{ - extension::Envs, - Agent, - extension_manager::ExtensionManager, - ExtensionConfig, -}; -use mcp_core::tool::Tool; -use uuid::Uuid; -use goose::session::{self, Identifier}; -use goose::agents::SessionConfig; -use std::path::PathBuf; -use std::sync::Arc; - -use goose::providers::{create, providers}; -use goose::model::ModelConfig; -use goose::message::Message; -use tracing::{info, warn, error}; -use config::{builder::DefaultState, ConfigBuilder, Environment, File}; -use serde_json::Value; // Import the correct Value type -use futures_util::TryStreamExt; - -// Global extension manager for extension listing -static EXTENSION_MANAGER: LazyLock = LazyLock::new(|| ExtensionManager::default()); - -// Global agent for handling sessions -static AGENT: LazyLock> = LazyLock::new(|| { - tokio::sync::Mutex::new(Agent::new()) -}); - - -#[derive(Debug, Serialize, Deserialize)] -struct SessionRequest { - prompt: String, -} - -#[derive(Debug, Serialize, Deserialize)] -struct ApiResponse { - message: String, - status: String, -} - -#[derive(Debug, Serialize, Deserialize)] -struct StartSessionResponse { - message: String, - status: String, - session_id: Uuid, -} - -#[derive(Debug, Serialize, Deserialize)] -struct SessionReplyRequest { - session_id: Uuid, - prompt: String, -} - -#[derive(Debug, Serialize, Deserialize)] -struct EndSessionRequest { - session_id: Uuid, -} - -#[derive(Debug, Serialize, Deserialize)] -struct ExtensionsResponse { - extensions: Vec, -} - -#[derive(Debug, Serialize, Deserialize)] -struct ProviderConfig { - provider: String, - model: String, -} - -#[derive(Debug, Serialize, Deserialize)] -struct ExtensionResponse { - error: bool, - message: Option, -} - -#[derive(Debug, Deserialize)] -#[serde(tag = "type")] -enum ExtensionConfigRequest { - #[serde(rename = "sse")] - Sse { - name: String, - uri: String, - #[serde(default)] - envs: Envs, - #[serde(default)] - env_keys: Vec, - timeout: Option, - }, - #[serde(rename = "stdio")] - Stdio { - name: String, - cmd: String, - #[serde(default)] - args: Vec, - #[serde(default)] - envs: Envs, - #[serde(default)] - env_keys: Vec, - timeout: Option, - }, - #[serde(rename = "builtin")] - Builtin { - name: String, - display_name: Option, - timeout: Option, - }, - #[serde(rename = "frontend")] - Frontend { - name: String, - tools: Vec, - instructions: Option, - }, -} - -async fn start_session_handler( - req: SessionRequest, - _api_key: String, -) -> Result { - info!("Starting session with prompt: {}", req.prompt); - - let mut agent = AGENT.lock().await; - - // Create a user message with the prompt - let mut messages = vec![Message::user().with_text(&req.prompt)]; - - // Generate a new session ID and process the messages - let session_id = Uuid::new_v4(); - let session_name = session_id.to_string(); - let session_path = session::get_path(Identifier::Name(session_name.clone())); - - let provider = agent.provider().await.ok(); - - let result = agent - .reply( - &messages, - Some(SessionConfig { - id: Identifier::Name(session_name.clone()), - working_dir: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")), - }), - ) - .await; - - match result { - Ok(mut stream) => { - // Process the stream to get the first response - if let Ok(Some(response)) = stream.try_next().await { - let response_text = response.as_concat_text(); - messages.push(response); - if let Err(e) = session::persist_messages(&session_path, &messages, provider.clone()).await { - warn!("Failed to persist session {}: {}", session_name, e); - } - - let api_response = StartSessionResponse { - message: response_text, - status: "success".to_string(), - session_id, - }; - Ok(warp::reply::with_status( - warp::reply::json(&api_response), - warp::http::StatusCode::OK, - )) - } else { - if let Err(e) = session::persist_messages(&session_path, &messages, provider.clone()).await { - warn!("Failed to persist session {}: {}", session_name, e); - } - - let api_response = StartSessionResponse { - message: "Session started but no response generated".to_string(), - status: "warning".to_string(), - session_id, - }; - Ok(warp::reply::with_status( - warp::reply::json(&api_response), - warp::http::StatusCode::OK, - )) - } - }, - Err(e) => { - error!("Failed to start session: {}", e); - let response = ApiResponse { - message: format!("Failed to start session: {}", e), - status: "error".to_string(), - }; - Ok(warp::reply::with_status( - warp::reply::json(&response), - warp::http::StatusCode::INTERNAL_SERVER_ERROR, - )) - } - } -} - -async fn reply_session_handler( - req: SessionReplyRequest, - _api_key: String, -) -> Result { - info!("Replying to session with prompt: {}", req.prompt); - - let mut agent = AGENT.lock().await; - - let session_name = req.session_id.to_string(); - let session_path = session::get_path(Identifier::Name(session_name.clone())); - - // Retrieve existing session history from disk - let mut messages = match session::read_messages(&session_path) { - Ok(m) => m, - Err(_) => { - let response = ApiResponse { - message: "Session not found".to_string(), - status: "error".to_string(), - }; - return Ok(warp::reply::with_status( - warp::reply::json(&response), - warp::http::StatusCode::NOT_FOUND, - )); - } - }; - - // Append the new user message - messages.push(Message::user().with_text(&req.prompt)); - - let provider = agent.provider().await.ok(); - - // Process the messages through the agent - let result = agent - .reply( - &messages, - Some(SessionConfig { - id: Identifier::Name(session_name.clone()), - working_dir: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")), - }), - ) - .await; - - match result { - Ok(mut stream) => { - // Process the stream to get the first response - if let Ok(Some(response)) = stream.try_next().await { - let response_text = response.as_concat_text(); - messages.push(response); - if let Err(e) = session::persist_messages(&session_path, &messages, provider.clone()).await { - warn!("Failed to persist session {}: {}", session_name, e); - } - let api_response = ApiResponse { - message: format!("Reply: {}", response_text), - status: "success".to_string(), - }; - Ok(warp::reply::with_status( - warp::reply::json(&api_response), - warp::http::StatusCode::OK, - )) - } else { - if let Err(e) = session::persist_messages(&session_path, &messages, provider.clone()).await { - warn!("Failed to persist session {}: {}", session_name, e); - } - let api_response = ApiResponse { - message: "Reply processed but no response generated".to_string(), - status: "warning".to_string(), - }; - Ok(warp::reply::with_status( - warp::reply::json(&api_response), - warp::http::StatusCode::OK, - )) - } - }, - Err(e) => { - error!("Failed to reply to session: {}", e); - let response = ApiResponse { - message: format!("Failed to reply to session: {}", e), - status: "error".to_string(), - }; - Ok(warp::reply::with_status( - warp::reply::json(&response), - warp::http::StatusCode::INTERNAL_SERVER_ERROR, - )) - } - } -} - -async fn end_session_handler( - req: EndSessionRequest, - _api_key: String, -) -> Result { - let session_name = req.session_id.to_string(); - let session_path = session::get_path(Identifier::Name(session_name.clone())); - - if std::fs::remove_file(&session_path).is_ok() { - let response = ApiResponse { - message: "Session ended".to_string(), - status: "success".to_string(), - }; - Ok(warp::reply::with_status( - warp::reply::json(&response), - warp::http::StatusCode::OK, - )) - } else { - let response = ApiResponse { - message: "Session not found".to_string(), - status: "error".to_string(), - }; - Ok(warp::reply::with_status( - warp::reply::json(&response), - warp::http::StatusCode::NOT_FOUND, - )) - } -} - -async fn list_extensions_handler() -> Result { - info!("Listing extensions"); - - match EXTENSION_MANAGER.list_extensions().await { - Ok(exts) => { - let response = ExtensionsResponse { extensions: exts }; - Ok::(warp::reply::json(&response)) - }, - Err(e) => { - error!("Failed to list extensions: {}", e); - let response = ExtensionsResponse { - extensions: vec!["Failed to list extensions".to_string()] - }; - Ok::(warp::reply::json(&response)) - } - } -} - -async fn get_provider_config_handler() -> Result { - info!("Getting provider configuration"); - - let config = Config::global(); - let provider = config.get_param::("GOOSE_PROVIDER") - .unwrap_or_else(|_| "Not configured".to_string()); - let model = config.get_param::("GOOSE_MODEL") - .unwrap_or_else(|_| "Not configured".to_string()); - - let response = ProviderConfig { provider, model }; - Ok::(warp::reply::json(&response)) -} - -async fn add_extension_handler( - req: ExtensionConfigRequest, - _api_key: String, -) -> Result { - info!("Adding extension: {:?}", req); - - #[cfg(target_os = "windows")] - if let ExtensionConfigRequest::Stdio { cmd, .. } = &req { - if cmd.ends_with("npx.cmd") || cmd.ends_with("npx") { - let node_exists = std::path::Path::new(r"C:\Program Files\nodejs\node.exe").exists() - || std::path::Path::new(r"C:\Program Files (x86)\nodejs\node.exe").exists(); - - if !node_exists { - let cmd_path = std::path::Path::new(cmd); - let script_dir = cmd_path.parent().ok_or_else(|| warp::reject())?; - - let install_script = script_dir.join("install-node.cmd"); - - if install_script.exists() { - eprintln!("Installing Node.js..."); - let output = std::process::Command::new(&install_script) - .arg("https://nodejs.org/dist/v23.10.0/node-v23.10.0-x64.msi") - .output() - .map_err(|_| warp::reject())?; - - if !output.status.success() { - eprintln!( - "Failed to install Node.js: {}", - String::from_utf8_lossy(&output.stderr) - ); - let resp = ExtensionResponse { - error: true, - message: Some(format!( - "Failed to install Node.js: {}", - String::from_utf8_lossy(&output.stderr) - )), - }; - return Ok(warp::reply::json(&resp)); - } - eprintln!("Node.js installation completed"); - } else { - eprintln!( - "Node.js installer script not found at: {}", - install_script.display() - ); - let resp = ExtensionResponse { - error: true, - message: Some("Node.js installer script not found".to_string()), - }; - return Ok(warp::reply::json(&resp)); - } - } - } - } - - let extension = match req { - ExtensionConfigRequest::Sse { name, uri, envs, env_keys, timeout } => { - ExtensionConfig::Sse { - name, - uri, - envs, - env_keys, - description: None, - timeout, - bundled: None, - } - } - ExtensionConfigRequest::Stdio { name, cmd, args, envs, env_keys, timeout } => { - ExtensionConfig::Stdio { - name, - cmd, - args, - envs, - env_keys, - timeout, - description: None, - bundled: None, - } - } - ExtensionConfigRequest::Builtin { name, display_name, timeout } => { - ExtensionConfig::Builtin { - name, - display_name, - timeout, - bundled: None, - } - } - ExtensionConfigRequest::Frontend { name, tools, instructions } => { - ExtensionConfig::Frontend { - name, - tools, - instructions, - bundled: None, - } - } - }; - - let agent = AGENT.lock().await; - let result = agent.add_extension(extension).await; - - let resp = match result { - Ok(_) => ExtensionResponse { error: false, message: None }, - Err(e) => ExtensionResponse { - error: true, - message: Some(format!("Failed to add extension configuration, error: {:?}", e)), - }, - }; - Ok(warp::reply::json(&resp)) -} - -async fn remove_extension_handler( - name: String, - _api_key: String, -) -> Result { - info!("Removing extension: {}", name); - let agent = AGENT.lock().await; - agent.remove_extension(&name).await; - - let resp = ExtensionResponse { error: false, message: None }; - Ok(warp::reply::json(&resp)) -} - -fn with_api_key(api_key: String) -> impl Filter + Clone { - warp::header::value("x-api-key") - .and_then(move |header_api_key: HeaderValue| { - let api_key = api_key.clone(); - async move { - if header_api_key == api_key { - Ok(api_key) - } else { - Err(warp::reject::not_found()) - } - } - }) -} - -// Load configuration from file and environment variables -fn load_configuration() -> std::result::Result { - let config_path = std::env::var("GOOSE_CONFIG").unwrap_or_else(|_| "config".to_string()); - let builder = ConfigBuilder::::default() - .add_source(File::with_name(&config_path).required(false)) - .add_source(Environment::with_prefix("GOOSE_API")); - - builder.build() -} - -// Initialize global provider configuration -async fn initialize_provider_config() -> Result<(), anyhow::Error> { - // Get configuration - let api_config = load_configuration()?; - - // Get provider settings from configuration or environment variables - let provider_name = std::env::var("GOOSE_API_PROVIDER") - .or_else(|_| api_config.get_string("provider")) - .unwrap_or_else(|_| "openai".to_string()); - - let model_name = std::env::var("GOOSE_API_MODEL") - .or_else(|_| api_config.get_string("model")) - .unwrap_or_else(|_| "gpt-4o".to_string()); - - info!("Initializing with provider: {}, model: {}", provider_name, model_name); - - // Initialize the global Config object - let config = Config::global(); - config.set_param("GOOSE_PROVIDER", Value::String(provider_name.clone()))?; - config.set_param("GOOSE_MODEL", Value::String(model_name.clone()))?; - - // Set up API keys from environment variables - let available_providers = providers(); - if let Some(provider_meta) = available_providers.iter().find(|p| p.name == provider_name) { - for key in &provider_meta.config_keys { - let env_name = key.name.clone(); - if let Ok(value) = std::env::var(&env_name) { - if key.secret { - config.set_secret(&key.name, Value::String(value))?; - info!("Set secret key: {}", key.name); - } else { - config.set_param(&key.name, Value::String(value))?; - info!("Set parameter: {}", key.name); - } - } else { - warn!("Environment variable not set for key: {}", key.name); - if key.required { - error!("Required key {} not provided", key.name); - return Err(anyhow::anyhow!("Required key {} not provided", key.name)); - } - } - } - } - - // Initialize agent with provider - let model_config = ModelConfig::new(model_name); - let provider = create(&provider_name, model_config)?; - - let agent = AGENT.lock().await; - agent.update_provider(provider).await?; - - info!("Provider configuration successful"); - Ok(()) -} -/// Initialize extensions from the configuration. -async fn initialize_extensions(config: &config::Config) -> Result<(), anyhow::Error> { - if let Ok(ext_table) = config.get_table("extensions") { - for (name, ext_config) in ext_table { - // Deserialize into ExtensionEntry to get enabled flag and config - let entry: ExtensionEntry = ext_config.clone().try_deserialize() - .map_err(|e| anyhow::anyhow!("Failed to deserialize extension config for {}: {}", name, e))?; - - if entry.enabled { - let extension_config: ExtensionConfig = entry.config; - // Acquire the global agent lock and try to add the extension - let mut agent = AGENT.lock().await; - if let Err(e) = agent.add_extension(extension_config).await { - error!("Failed to add extension {}: {}", name, e); - } - } else { - info!("Skipping disabled extension: {}", name); - } - } - } else { - warn!("No extensions configured in config file."); - } - Ok(()) -} - - -async fn run_init_tests() -> Result<(), anyhow::Error> { - info!("Running initialization tests"); - { - let _agent = AGENT.lock().await; - info!("Agent initialization test passed"); - } - info!("Initialization tests completed"); - Ok(()) -} +use goose_api::run_server; #[tokio::main] async fn main() -> Result<(), anyhow::Error> { - // Initialize tracing - tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .init(); - - info!("Starting goose-api server"); - - // Load configuration - let api_config = load_configuration()?; - - // Get API key from configuration or environment - let api_key: String = std::env::var("GOOSE_API_KEY") - .or_else(|_| api_config.get_string("api_key")) - .unwrap_or_else(|_| { - warn!("No API key configured, using default"); - "default_api_key".to_string() - }); - - // Initialize provider configuration - if let Err(e) = initialize_provider_config().await { - error!("Failed to initialize provider: {}", e); - return Err(e); - } - - // Initialize extensions from configuration - if let Err(e) = initialize_extensions(&api_config).await { - error!("Failed to initialize extensions: {}", e); - } - - if let Err(e) = run_init_tests().await { - error!("Initialization tests failed: {}", e); - } - - // Session start endpoint - let start_session = warp::path("session") - .and(warp::path("start")) - .and(warp::post()) - .and(warp::body::json()) - .and(with_api_key(api_key.clone())) - .and_then(start_session_handler); - - // Session reply endpoint - let reply_session = warp::path("session") - .and(warp::path("reply")) - .and(warp::post()) - .and(warp::body::json()) - .and(with_api_key(api_key.clone())) - .and_then(reply_session_handler); - - // Session end endpoint - let end_session = warp::path("session") - .and(warp::path("end")) - .and(warp::post()) - .and(warp::body::json()) - .and(with_api_key(api_key.clone())) - .and_then(end_session_handler); - - // List extensions endpoint - let list_extensions = warp::path("extensions") - .and(warp::path("list")) - .and(warp::get()) - .and_then(list_extensions_handler); - - // Add extension endpoint - let add_extension = warp::path("extensions") - .and(warp::path("add")) - .and(warp::post()) - .and(warp::body::json()) - .and(with_api_key(api_key.clone())) - .and_then(add_extension_handler); - - // Remove extension endpoint - let remove_extension = warp::path("extensions") - .and(warp::path("remove")) - .and(warp::post()) - .and(warp::body::json()) - .and(with_api_key(api_key.clone())) - .and_then(remove_extension_handler); - - // Get provider configuration endpoint - let get_provider_config = warp::path("provider") - .and(warp::path("config")) - .and(warp::get()) - .and_then(get_provider_config_handler); - - // Combine all routes - let routes = start_session - .or(reply_session) - .or(end_session) - .or(list_extensions) - .or(add_extension) - .or(remove_extension) - .or(get_provider_config); - - // Get bind address from configuration or use default - let host = std::env::var("GOOSE_API_HOST") - .or_else(|_| api_config.get_string("host")) - .unwrap_or_else(|_| "127.0.0.1".to_string()); - - let port = std::env::var("GOOSE_API_PORT") - .or_else(|_| api_config.get_string("port")) - .unwrap_or_else(|_| "8080".to_string()) - .parse::() - .unwrap_or(8080); - - info!("Starting server on {}:{}", host, port); - - // Parse host string - let host_parts: Vec = host.split('.') - .map(|part| part.parse::().unwrap_or(127)) - .collect(); - - let addr = if host_parts.len() == 4 { - [host_parts[0], host_parts[1], host_parts[2], host_parts[3]] - } else { - [127, 0, 0, 1] - }; - - // Start the server - warp::serve(routes) - .run((addr, port)) - .await; - - Ok(()) + run_server().await } diff --git a/crates/goose-api/src/routes.rs b/crates/goose-api/src/routes.rs new file mode 100644 index 00000000..759786c3 --- /dev/null +++ b/crates/goose-api/src/routes.rs @@ -0,0 +1,123 @@ +use warp::Filter; +use tracing::{info, warn, error}; + +use crate::handlers::{ + add_extension_handler, end_session_handler, get_provider_config_handler, + list_extensions_handler, remove_extension_handler, reply_session_handler, + start_session_handler, with_api_key, +}; +use crate::config::{ + initialize_extensions, initialize_provider_config, load_configuration, + run_init_tests, +}; + +pub fn build_routes(api_key: String) -> impl Filter + Clone { + let start_session = warp::path("session") + .and(warp::path("start")) + .and(warp::post()) + .and(warp::body::json()) + .and(with_api_key(api_key.clone())) + .and_then(start_session_handler); + + let reply_session = warp::path("session") + .and(warp::path("reply")) + .and(warp::post()) + .and(warp::body::json()) + .and(with_api_key(api_key.clone())) + .and_then(reply_session_handler); + + let end_session = warp::path("session") + .and(warp::path("end")) + .and(warp::post()) + .and(warp::body::json()) + .and(with_api_key(api_key.clone())) + .and_then(end_session_handler); + + let list_extensions = warp::path("extensions") + .and(warp::path("list")) + .and(warp::get()) + .and_then(list_extensions_handler); + + let add_extension = warp::path("extensions") + .and(warp::path("add")) + .and(warp::post()) + .and(warp::body::json()) + .and(with_api_key(api_key.clone())) + .and_then(add_extension_handler); + + let remove_extension = warp::path("extensions") + .and(warp::path("remove")) + .and(warp::post()) + .and(warp::body::json()) + .and(with_api_key(api_key.clone())) + .and_then(remove_extension_handler); + + let get_provider_config = warp::path("provider") + .and(warp::path("config")) + .and(warp::get()) + .and_then(get_provider_config_handler); + + start_session + .or(reply_session) + .or(end_session) + .or(list_extensions) + .or(add_extension) + .or(remove_extension) + .or(get_provider_config) +} + +pub async fn run_server() -> Result<(), anyhow::Error> { + tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .init(); + + info!("Starting goose-api server"); + + let api_config = load_configuration()?; + + let api_key: String = std::env::var("GOOSE_API_KEY") + .or_else(|_| api_config.get_string("api_key")) + .unwrap_or_else(|_| { + warn!("No API key configured, using default"); + "default_api_key".to_string() + }); + + if let Err(e) = initialize_provider_config().await { + error!("Failed to initialize provider: {}", e); + return Err(e); + } + + if let Err(e) = initialize_extensions(&api_config).await { + error!("Failed to initialize extensions: {}", e); + } + + if let Err(e) = run_init_tests().await { + error!("Initialization tests failed: {}", e); + } + + let routes = build_routes(api_key.clone()); + + let host = std::env::var("GOOSE_API_HOST") + .or_else(|_| api_config.get_string("host")) + .unwrap_or_else(|_| "127.0.0.1".to_string()); + let port = std::env::var("GOOSE_API_PORT") + .or_else(|_| api_config.get_string("port")) + .unwrap_or_else(|_| "8080".to_string()) + .parse::() + .unwrap_or(8080); + + info!("Starting server on {}:{}", host, port); + + let host_parts: Vec = host + .split('.') + .map(|part| part.parse::().unwrap_or(127)) + .collect(); + let addr = if host_parts.len() == 4 { + [host_parts[0], host_parts[1], host_parts[2], host_parts[3]] + } else { + [127, 0, 0, 1] + }; + + warp::serve(routes).run((addr, port)).await; + Ok(()) +} diff --git a/crates/goose-api/src/tests.rs b/crates/goose-api/src/tests.rs new file mode 100644 index 00000000..302cf8c3 --- /dev/null +++ b/crates/goose-api/src/tests.rs @@ -0,0 +1,10 @@ +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn build_routes_compiles() { + let _routes = build_routes("test-key".to_string()); + // Just ensure building routes doesn't panic + } +}