mirror of
https://github.com/aljazceru/goose.git
synced 2025-12-18 06:34:26 +01:00
Merge pull request #6 from aljazceru/codex/modify-with_api_key-rejection-and-update-tests
Return 401 for invalid API key
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
use warp::{Filter, Rejection};
|
||||
use warp::{Filter, Rejection, Reply};
|
||||
use warp::http::HeaderValue;
|
||||
use std::convert::Infallible;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
use goose::config::{Config, ExtensionEntry};
|
||||
@@ -466,6 +467,24 @@ async fn remove_extension_handler(
|
||||
Ok(warp::reply::json(&resp))
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Unauthorized;
|
||||
|
||||
impl warp::reject::Reject for Unauthorized {}
|
||||
|
||||
async fn handle_rejection(err: Rejection) -> Result<impl Reply, Infallible> {
|
||||
if err.find::<Unauthorized>().is_some() {
|
||||
Ok(warp::reply::with_status("UNAUTHORIZED", warp::http::StatusCode::UNAUTHORIZED))
|
||||
} else if err.is_not_found() {
|
||||
Ok(warp::reply::with_status("NOT_FOUND", warp::http::StatusCode::NOT_FOUND))
|
||||
} else {
|
||||
Ok(warp::reply::with_status(
|
||||
"INTERNAL_SERVER_ERROR",
|
||||
warp::http::StatusCode::INTERNAL_SERVER_ERROR,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
fn with_api_key(api_key: String) -> impl Filter<Extract = (String,), Error = Rejection> + Clone {
|
||||
warp::header::value("x-api-key")
|
||||
.and_then(move |header_api_key: HeaderValue| {
|
||||
@@ -474,12 +493,45 @@ fn with_api_key(api_key: String) -> impl Filter<Extract = (String,), Error = Rej
|
||||
if header_api_key == api_key {
|
||||
Ok(api_key)
|
||||
} else {
|
||||
Err(warp::reject::not_found())
|
||||
Err(warp::reject::custom(Unauthorized))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use warp::http::StatusCode;
|
||||
|
||||
#[tokio::test]
|
||||
async fn valid_key_allows_request() {
|
||||
let filter = with_api_key("secret".to_string())
|
||||
.map(|k: String| k)
|
||||
.recover(handle_rejection);
|
||||
|
||||
let res = warp::test::request()
|
||||
.header("x-api-key", "secret")
|
||||
.reply(&filter)
|
||||
.await;
|
||||
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn invalid_key_is_rejected() {
|
||||
let filter = with_api_key("secret".to_string())
|
||||
.map(|k: String| k)
|
||||
.recover(handle_rejection);
|
||||
|
||||
let res = warp::test::request()
|
||||
.header("x-api-key", "wrong")
|
||||
.reply(&filter)
|
||||
.await;
|
||||
|
||||
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
|
||||
}
|
||||
}
|
||||
// Load configuration from file and environment variables
|
||||
fn load_configuration() -> std::result::Result<config::Config, config::ConfigError> {
|
||||
let config_path = std::env::var("GOOSE_CONFIG").unwrap_or_else(|_| "config".to_string());
|
||||
@@ -688,7 +740,8 @@ async fn main() -> Result<(), anyhow::Error> {
|
||||
.or(list_extensions)
|
||||
.or(add_extension)
|
||||
.or(remove_extension)
|
||||
.or(get_provider_config);
|
||||
.or(get_provider_config)
|
||||
.recover(handle_rejection);
|
||||
|
||||
// Get bind address from configuration or use default
|
||||
let host = std::env::var("GOOSE_API_HOST")
|
||||
|
||||
Reference in New Issue
Block a user