mirror of
https://github.com/aljazceru/goose.git
synced 2025-12-18 22:54:24 +01:00
Merge pull request #6 from aljazceru/codex/modify-with_api_key-rejection-and-update-tests
Return 401 for invalid API key
This commit is contained in:
@@ -1,5 +1,6 @@
|
|||||||
use warp::{Filter, Rejection};
|
use warp::{Filter, Rejection, Reply};
|
||||||
use warp::http::HeaderValue;
|
use warp::http::HeaderValue;
|
||||||
|
use std::convert::Infallible;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use goose::config::{Config, ExtensionEntry};
|
use goose::config::{Config, ExtensionEntry};
|
||||||
@@ -466,6 +467,24 @@ async fn remove_extension_handler(
|
|||||||
Ok(warp::reply::json(&resp))
|
Ok(warp::reply::json(&resp))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct Unauthorized;
|
||||||
|
|
||||||
|
impl warp::reject::Reject for Unauthorized {}
|
||||||
|
|
||||||
|
async fn handle_rejection(err: Rejection) -> Result<impl Reply, Infallible> {
|
||||||
|
if err.find::<Unauthorized>().is_some() {
|
||||||
|
Ok(warp::reply::with_status("UNAUTHORIZED", warp::http::StatusCode::UNAUTHORIZED))
|
||||||
|
} else if err.is_not_found() {
|
||||||
|
Ok(warp::reply::with_status("NOT_FOUND", warp::http::StatusCode::NOT_FOUND))
|
||||||
|
} else {
|
||||||
|
Ok(warp::reply::with_status(
|
||||||
|
"INTERNAL_SERVER_ERROR",
|
||||||
|
warp::http::StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn with_api_key(api_key: String) -> impl Filter<Extract = (String,), Error = Rejection> + Clone {
|
fn with_api_key(api_key: String) -> impl Filter<Extract = (String,), Error = Rejection> + Clone {
|
||||||
warp::header::value("x-api-key")
|
warp::header::value("x-api-key")
|
||||||
.and_then(move |header_api_key: HeaderValue| {
|
.and_then(move |header_api_key: HeaderValue| {
|
||||||
@@ -474,12 +493,45 @@ fn with_api_key(api_key: String) -> impl Filter<Extract = (String,), Error = Rej
|
|||||||
if header_api_key == api_key {
|
if header_api_key == api_key {
|
||||||
Ok(api_key)
|
Ok(api_key)
|
||||||
} else {
|
} else {
|
||||||
Err(warp::reject::not_found())
|
Err(warp::reject::custom(Unauthorized))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use warp::http::StatusCode;
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn valid_key_allows_request() {
|
||||||
|
let filter = with_api_key("secret".to_string())
|
||||||
|
.map(|k: String| k)
|
||||||
|
.recover(handle_rejection);
|
||||||
|
|
||||||
|
let res = warp::test::request()
|
||||||
|
.header("x-api-key", "secret")
|
||||||
|
.reply(&filter)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
assert_eq!(res.status(), StatusCode::OK);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn invalid_key_is_rejected() {
|
||||||
|
let filter = with_api_key("secret".to_string())
|
||||||
|
.map(|k: String| k)
|
||||||
|
.recover(handle_rejection);
|
||||||
|
|
||||||
|
let res = warp::test::request()
|
||||||
|
.header("x-api-key", "wrong")
|
||||||
|
.reply(&filter)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
|
||||||
|
}
|
||||||
|
}
|
||||||
// Load configuration from file and environment variables
|
// Load configuration from file and environment variables
|
||||||
fn load_configuration() -> std::result::Result<config::Config, config::ConfigError> {
|
fn load_configuration() -> std::result::Result<config::Config, config::ConfigError> {
|
||||||
let config_path = std::env::var("GOOSE_CONFIG").unwrap_or_else(|_| "config".to_string());
|
let config_path = std::env::var("GOOSE_CONFIG").unwrap_or_else(|_| "config".to_string());
|
||||||
@@ -688,7 +740,8 @@ async fn main() -> Result<(), anyhow::Error> {
|
|||||||
.or(list_extensions)
|
.or(list_extensions)
|
||||||
.or(add_extension)
|
.or(add_extension)
|
||||||
.or(remove_extension)
|
.or(remove_extension)
|
||||||
.or(get_provider_config);
|
.or(get_provider_config)
|
||||||
|
.recover(handle_rejection);
|
||||||
|
|
||||||
// Get bind address from configuration or use default
|
// Get bind address from configuration or use default
|
||||||
let host = std::env::var("GOOSE_API_HOST")
|
let host = std::env::var("GOOSE_API_HOST")
|
||||||
|
|||||||
Reference in New Issue
Block a user