diff --git a/crates/goose-api/src/main.rs b/crates/goose-api/src/main.rs index 99ce0157..dfe561ae 100644 --- a/crates/goose-api/src/main.rs +++ b/crates/goose-api/src/main.rs @@ -1,5 +1,6 @@ -use warp::{Filter, Rejection}; +use warp::{Filter, Rejection, Reply}; use warp::http::HeaderValue; +use std::convert::Infallible; use serde::{Deserialize, Serialize}; use std::sync::Arc; use goose::config::{Config, ExtensionEntry}; @@ -466,6 +467,24 @@ async fn remove_extension_handler( Ok(warp::reply::json(&resp)) } +#[derive(Debug)] +struct Unauthorized; + +impl warp::reject::Reject for Unauthorized {} + +async fn handle_rejection(err: Rejection) -> Result { + if err.find::().is_some() { + Ok(warp::reply::with_status("UNAUTHORIZED", warp::http::StatusCode::UNAUTHORIZED)) + } else if err.is_not_found() { + Ok(warp::reply::with_status("NOT_FOUND", warp::http::StatusCode::NOT_FOUND)) + } else { + Ok(warp::reply::with_status( + "INTERNAL_SERVER_ERROR", + warp::http::StatusCode::INTERNAL_SERVER_ERROR, + )) + } +} + fn with_api_key(api_key: String) -> impl Filter + Clone { warp::header::value("x-api-key") .and_then(move |header_api_key: HeaderValue| { @@ -474,12 +493,45 @@ fn with_api_key(api_key: String) -> impl Filter std::result::Result { let config_path = std::env::var("GOOSE_CONFIG").unwrap_or_else(|_| "config".to_string()); @@ -688,7 +740,8 @@ async fn main() -> Result<(), anyhow::Error> { .or(list_extensions) .or(add_extension) .or(remove_extension) - .or(get_provider_config); + .or(get_provider_config) + .recover(handle_rejection); // Get bind address from configuration or use default let host = std::env::var("GOOSE_API_HOST")