Merge pull request #6 from aljazceru/codex/modify-with_api_key-rejection-and-update-tests

Return 401 for invalid API key
This commit is contained in:
2025-05-29 10:14:05 +02:00
committed by GitHub

View File

@@ -1,5 +1,6 @@
use warp::{Filter, Rejection}; use warp::{Filter, Rejection, Reply};
use warp::http::HeaderValue; use warp::http::HeaderValue;
use std::convert::Infallible;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::sync::Arc; use std::sync::Arc;
use goose::config::{Config, ExtensionEntry}; use goose::config::{Config, ExtensionEntry};
@@ -466,6 +467,24 @@ async fn remove_extension_handler(
Ok(warp::reply::json(&resp)) Ok(warp::reply::json(&resp))
} }
#[derive(Debug)]
struct Unauthorized;
impl warp::reject::Reject for Unauthorized {}
async fn handle_rejection(err: Rejection) -> Result<impl Reply, Infallible> {
if err.find::<Unauthorized>().is_some() {
Ok(warp::reply::with_status("UNAUTHORIZED", warp::http::StatusCode::UNAUTHORIZED))
} else if err.is_not_found() {
Ok(warp::reply::with_status("NOT_FOUND", warp::http::StatusCode::NOT_FOUND))
} else {
Ok(warp::reply::with_status(
"INTERNAL_SERVER_ERROR",
warp::http::StatusCode::INTERNAL_SERVER_ERROR,
))
}
}
fn with_api_key(api_key: String) -> impl Filter<Extract = (String,), Error = Rejection> + Clone { fn with_api_key(api_key: String) -> impl Filter<Extract = (String,), Error = Rejection> + Clone {
warp::header::value("x-api-key") warp::header::value("x-api-key")
.and_then(move |header_api_key: HeaderValue| { .and_then(move |header_api_key: HeaderValue| {
@@ -474,12 +493,45 @@ fn with_api_key(api_key: String) -> impl Filter<Extract = (String,), Error = Rej
if header_api_key == api_key { if header_api_key == api_key {
Ok(api_key) Ok(api_key)
} else { } else {
Err(warp::reject::not_found()) Err(warp::reject::custom(Unauthorized))
} }
} }
}) })
} }
#[cfg(test)]
mod tests {
use super::*;
use warp::http::StatusCode;
#[tokio::test]
async fn valid_key_allows_request() {
let filter = with_api_key("secret".to_string())
.map(|k: String| k)
.recover(handle_rejection);
let res = warp::test::request()
.header("x-api-key", "secret")
.reply(&filter)
.await;
assert_eq!(res.status(), StatusCode::OK);
}
#[tokio::test]
async fn invalid_key_is_rejected() {
let filter = with_api_key("secret".to_string())
.map(|k: String| k)
.recover(handle_rejection);
let res = warp::test::request()
.header("x-api-key", "wrong")
.reply(&filter)
.await;
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
}
}
// Load configuration from file and environment variables // Load configuration from file and environment variables
fn load_configuration() -> std::result::Result<config::Config, config::ConfigError> { fn load_configuration() -> std::result::Result<config::Config, config::ConfigError> {
let config_path = std::env::var("GOOSE_CONFIG").unwrap_or_else(|_| "config".to_string()); let config_path = std::env::var("GOOSE_CONFIG").unwrap_or_else(|_| "config".to_string());
@@ -688,7 +740,8 @@ async fn main() -> Result<(), anyhow::Error> {
.or(list_extensions) .or(list_extensions)
.or(add_extension) .or(add_extension)
.or(remove_extension) .or(remove_extension)
.or(get_provider_config); .or(get_provider_config)
.recover(handle_rejection);
// Get bind address from configuration or use default // Get bind address from configuration or use default
let host = std::env::var("GOOSE_API_HOST") let host = std::env::var("GOOSE_API_HOST")