feat: Flexible endpoint rate limiting (#124)

* before ThrottledBody

* added upload rate limiting

* before dynamic rate limiting

* flexible rate limiter + config

* fixed tests

* fmt and clippy

* reset auth.js e2e

* more cleaning up

* improved code and added comments

* limit downloads too

* fmt and clippy

* improved sample comments

* fixed comment

* removed http dependency in favour of axum:http

* parse speed units as lowercase

* replaced regex with glob

* clippy and fmt
This commit is contained in:
Severin Alexander Bühler
2025-05-13 11:06:23 +03:00
committed by GitHub
parent 00aafb2163
commit 8a0cec71ef
21 changed files with 1341 additions and 14 deletions

59
Cargo.lock generated
View File

@@ -36,6 +36,12 @@ dependencies = [
"memchr",
]
[[package]]
name = "allocator-api2"
version = "0.2.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923"
[[package]]
name = "anstream"
version = "0.6.18"
@@ -923,6 +929,15 @@ dependencies = [
"windows-sys 0.59.0",
]
[[package]]
name = "fast-glob"
version = "0.4.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39ea3f6bbcf4dbe2076b372186fc7aeecd5f6f84754582e56ee7db262b15a6f0"
dependencies = [
"arrayvec",
]
[[package]]
name = "fastrand"
version = "2.3.0"
@@ -953,6 +968,12 @@ version = "1.0.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1"
[[package]]
name = "foldhash"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2"
[[package]]
name = "foreign-types"
version = "0.3.2"
@@ -1185,6 +1206,29 @@ dependencies = [
"spinning_top",
]
[[package]]
name = "governor"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3cbe789d04bf14543f03c4b60cd494148aa79438c8440ae7d81a7778147745c3"
dependencies = [
"cfg-if",
"dashmap",
"futures-sink",
"futures-timer",
"futures-util",
"getrandom 0.3.1",
"hashbrown 0.15.2",
"nonzero_ext",
"parking_lot",
"portable-atomic",
"quanta",
"rand 0.9.0",
"smallvec",
"spinning_top",
"web-time",
]
[[package]]
name = "h2"
version = "0.4.8"
@@ -1224,6 +1268,11 @@ name = "hashbrown"
version = "0.15.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289"
dependencies = [
"allocator-api2",
"equivalent",
"foldhash",
]
[[package]]
name = "headers"
@@ -1321,9 +1370,9 @@ checksum = "f558a64ac9af88b5ba400d99b579451af0d39c6d360980045b91aac966d705e2"
[[package]]
name = "http"
version = "1.2.0"
version = "1.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f16ca2af56261c99fba8bac40a10251ce8188205a4c448fbb745a2e4daa76fea"
checksum = "f4a85d31aea989eead29a3aaf9e1115a180df8282431156e533de47660892565"
dependencies = [
"bytes",
"fnv",
@@ -2157,7 +2206,7 @@ dependencies = [
"bytes",
"clap",
"dirs-next",
"governor",
"governor 0.8.0",
"http",
"httpdate",
"pkarr",
@@ -2343,9 +2392,11 @@ dependencies = [
"clap",
"dirs",
"dyn-clone",
"fast-glob",
"flume",
"futures-lite",
"futures-util",
"governor 0.10.0",
"heed",
"hex",
"hostname-validator",
@@ -3500,7 +3551,7 @@ checksum = "57a2ccff6830fa835371af7541e561a90e4c07b84f72991ebac4b3cb6790dc0d"
dependencies = [
"axum",
"forwarded-header-value",
"governor",
"governor 0.8.0",
"http",
"pin-project",
"thiserror 2.0.12",

View File

@@ -49,7 +49,7 @@ impl Client {
.expect("testnet relays are valid urls")
});
let mut client = builder.build().expect("testnet build should be infallibl");
let mut client = builder.build().expect("testnet build should be infallible");
client.testnet = true;

View File

@@ -53,7 +53,10 @@ axum-test = "17.2.0"
tempfile = { version = "3.10.1" }
dyn-clone = "1.0.19"
reqwest = "0.12.15"
governor = "0.10.0"
fast-glob = "0.4.5"
[dev-dependencies]
futures-lite = "2.6.0"

View File

@@ -24,6 +24,35 @@ pubky_listen_socket = "127.0.0.1:6287"
# May be put behind a reverse proxy with TLS enabled.
icann_listen_socket = "127.0.0.1:6286"
# Rate limit endpoints dynamically.
# `path` is a glob pattern of the path. See syntax in https://crates.io/crates/fast-glob
# `method` is the HTTP method. Examples: GET, POST, PUT, HEAD, DELETE
# `quota` defines the limit itself in the format $rate$rate_unit/$time_unit.
# - $rate is a positive integer, max 4'294'967'295.
# - $rate_unit is either "r" for requests, "kb" for kilobytes,
# "mb" for megabytes, "gb" for gigabytes.
# Speed limits limit download and upload.
# - $time_unit is either "s" for second, "m" for minute.
# `key` defines who is rate limited.
# - "ip" limits based on the IP address.
# - "user" limits based on the user pubkey. Requires the endpoint to have authentication.
# `burst` is a temporary allowance of quota that is added to the limit.
# By default, burst is equal to the quota rate.
#
# Limit login attempts to 20 requests per minute per IP.
[[drive.rate_limits]]
path = "/session"
method = "POST"
quota = "20r/m"
key = "ip"
#
# Limit file uploads to 1 megabyte per second per user with a temporary burst of 10 megabytes.
[[drive.rate_limits]]
path = "/pub/**"
method = "PUT"
quota = "1mb/s"
key = "user"
burst = 10
[admin]
# The port number to run the admin HTTP (clear text) server on.

View File

@@ -151,7 +151,7 @@ impl HomeserverCore {
signup_mode: context.config_toml.general.signup_mode.clone(),
user_quota_bytes: quota_bytes,
};
super::routes::create_app(state.clone())
super::routes::create_app(state.clone(), context)
}
/// Start the ICANN HTTP server

View File

@@ -1,3 +1,4 @@
pub mod authz;
pub mod pubky_host;
pub mod rate_limiter;
pub mod trace;

View File

@@ -0,0 +1,37 @@
use axum::extract::Request;
use axum::http::HeaderMap;
use std::net::{IpAddr, SocketAddr};
// From https://github.com/benwis/tower-governor/blob/main/src/key_extractor.rs#L121
/// Header name whose value is parsed as a single IP address.
const X_REAL_IP: &str = "x-real-ip";
/// Header name whose value is parsed as a comma-separated list of IP addresses.
const X_FORWARDED_FOR: &str = "x-forwarded-for";
/// Returns the first entry of the `x-forwarded-for` header that parses as an IP address.
///
/// The header value is split on `,` and each entry is trimmed before parsing.
/// Yields `None` when the header is missing, cannot be read as a string, or
/// holds no parsable address.
fn maybe_x_forwarded_for(headers: &HeaderMap) -> Option<IpAddr> {
    let raw = headers.get(X_FORWARDED_FOR)?.to_str().ok()?;
    raw.split(',')
        .map(str::trim)
        .find_map(|candidate| candidate.parse::<IpAddr>().ok())
}
/// Returns the IP address carried by the `x-real-ip` header.
///
/// Yields `None` when the header is missing, cannot be read as a string, or
/// is not a valid IP address.
fn maybe_x_real_ip(headers: &HeaderMap) -> Option<IpAddr> {
    let raw = headers.get(X_REAL_IP)?.to_str().ok()?;
    raw.parse::<IpAddr>().ok()
}
/// Returns the IP stored in the request's `ConnectInfo<SocketAddr>` extension, if present.
fn maybe_connect_info<T>(req: &Request<T>) -> Option<IpAddr> {
    let connect_info = req
        .extensions()
        .get::<axum::extract::ConnectInfo<SocketAddr>>()?;
    Some(connect_info.ip())
}
/// Determines the client IP address of a request.
///
/// Sources are tried in this order and the first one that yields an address wins:
/// 1. the `x-forwarded-for` header,
/// 2. the `x-real-ip` header,
/// 3. the `ConnectInfo<SocketAddr>` request extension.
///
/// NOTE(review): both headers are taken from the request as-is and are preferred
/// over the socket address. A client that reaches the server directly can set them
/// to arbitrary values — confirm the deployment always sits behind a proxy that
/// overwrites them before relying on this for IP-based limits.
///
/// # Errors
/// Returns an error when none of the sources yields an IP address.
pub fn extract_ip<T>(req: &Request<T>) -> anyhow::Result<IpAddr> {
    let headers = req.headers();
    maybe_x_forwarded_for(headers)
        .or_else(|| maybe_x_real_ip(headers))
        .or_else(|| maybe_connect_info(req))
        // Build the error lazily so the success path does not pay for it.
        .ok_or_else(|| anyhow::anyhow!("Failed to extract ip."))
}

View File

@@ -0,0 +1,539 @@
//!
//! Implements rate limiting with governor.
//!
//! Would love to use tower_governor but I can't type it properly due to
//! https://github.com/benwis/tower-governor/issues/49.
//!
//! So we implement our own rate limiter here.
//!
use axum::response::{IntoResponse, Response};
use axum::{
body::Body,
http::{Request, StatusCode},
};
use futures_util::future::BoxFuture;
use governor::clock::QuantaClock;
use governor::state::keyed::DashMapStateStore;
use std::num::NonZero;
use std::sync::Arc;
use std::time::Duration;
use std::{convert::Infallible, task::Poll};
use tower::{Layer, Service};
use crate::core::error::Result;
use crate::core::extractors::PubkyHost;
use crate::core::Error;
use crate::quota_config::{LimitKey, PathLimit, RateUnit};
use futures_util::StreamExt;
use governor::{Jitter, Quota, RateLimiter};
use super::extract_ip::extract_ip;
/// A Tower Layer to handle general rate limiting.
///
/// Supports rate limiting by request count and by upload/download speed.
///
/// Limits keyed by user read the `PubkyHost` request extension, so a
/// `PubkyHostLayer` must process the request before this layer does.
///
/// Returns 400 BAD REQUEST if the key of a matching limit (client IP or
/// user pubkey aka pubky-host) cannot be extracted.
///
#[derive(Debug, Clone)]
pub struct RateLimiterLayer {
    /// The configured limits; a limiter is built from each of them in `layer`.
    limits: Vec<PathLimit>,
}
impl RateLimiterLayer {
    /// Builds the layer from the configured path limits and logs the configuration.
    ///
    /// An empty `limits` list means no request ever matches a limit, which is
    /// logged as rate limiting being disabled.
    pub fn new(limits: Vec<PathLimit>) -> Self {
        if limits.is_empty() {
            tracing::info!("Rate limiting is disabled.");
            return Self { limits };
        }

        // Quote every limit so the entries stay distinguishable in the log line.
        let mut quoted: Vec<String> = Vec::with_capacity(limits.len());
        for limit in &limits {
            quoted.push(format!("\"{limit}\""));
        }
        tracing::info!("Rate limits configured: {}", quoted.join(", "));

        Self { limits }
    }
}
impl<S> Layer<S> for RateLimiterLayer {
    type Service = RateLimiterMiddleware<S>;

    /// Wraps `inner` in the rate limiting middleware.
    ///
    /// Every call builds a fresh `LimitTuple` (and therefore a fresh limiter)
    /// for each configured limit; limiter state is not shared between the
    /// services produced by separate `layer` calls.
    fn layer(&self, inner: S) -> Self::Service {
        let mut limits = Vec::with_capacity(self.limits.len());
        for path_limit in &self.limits {
            limits.push(LimitTuple::new(path_limit.clone()));
        }
        RateLimiterMiddleware { inner, limits }
    }
}
/// A tuple of a path limit and the actual governor rate limiter.
///
/// The limiter is keyed by a string (client IP or user pubkey, see
/// `extract_key`). It sits behind an `Arc`, so clones of the tuple share the
/// same limiter.
#[derive(Debug, Clone)]
struct LimitTuple(
    /// The configured limit this tuple enforces.
    pub PathLimit,
    /// The keyed governor limiter built from the limit's quota.
    pub Arc<RateLimiter<String, DashMapStateStore<String>, QuantaClock>>,
);
impl LimitTuple {
pub fn new(path_limit: PathLimit) -> Self {
let quota: Quota = path_limit.clone().into();
let limiter = Arc::new(RateLimiter::keyed(quota));
Self(path_limit, limiter)
}
/// Extract the key from the request.
///
/// The key is either the ip address of the client
/// or the user pubkey.
fn extract_key(&self, req: &Request<Body>) -> anyhow::Result<String> {
match self.0.key {
LimitKey::Ip => extract_ip(req).map(|ip| ip.to_string()),
LimitKey::User => {
// Extract the user pubkey from the request.
req.extensions()
.get::<PubkyHost>()
.map(|pk| pk.public_key().to_string())
.ok_or(anyhow::anyhow!("Failed to extract user pubkey."))
}
}
}
/// Check if the request matches the limit.
pub fn is_match(&self, req: &Request<Body>) -> bool {
self.0.path.is_match(req.uri().path()) && self.0.method.0 == req.method()
}
}
/// The service produced by `RateLimiterLayer`.
///
/// Wraps `inner` and applies every limit in `limits` that matches the request.
#[derive(Debug, Clone)]
pub struct RateLimiterMiddleware<S> {
    /// The wrapped service that handles requests that were not rejected.
    inner: S,
    /// All configured limits, each paired with its limiter.
    limits: Vec<LimitTuple>,
}
impl<S> RateLimiterMiddleware<S> {
/// Throttle the upload body.
fn throttle_upload(
req: Request<Body>,
key: &str,
limiter: &Arc<RateLimiter<String, DashMapStateStore<String>, QuantaClock>>,
) -> Request<Body> {
let (parts, body) = req.into_parts();
let new_body = Self::throttle_body(body, key, limiter);
Request::from_parts(parts, new_body)
}
/// Throttle the download body.
fn throttle_download(
res: Response<Body>,
key: &str,
limiter: &Arc<RateLimiter<String, DashMapStateStore<String>, QuantaClock>>,
) -> Response<Body> {
let (parts, body) = res.into_parts();
let new_body = Self::throttle_body(body, key, limiter);
Response::from_parts(parts, new_body)
}
/// Throttle the up or download body.
///
/// Important: The speed quotas are always in kilobytes, not bytes.
/// Counting bytes is not practical.
///
fn throttle_body(
body: Body,
key: &str,
limiter: &Arc<RateLimiter<String, DashMapStateStore<String>, QuantaClock>>,
) -> Body {
let body_stream = body.into_data_stream();
let limiter = limiter.clone();
let key = key.to_string();
let throttled = body_stream
.map(move |chunk| {
let limiter = limiter.clone();
let key = key.to_string();
// When the rate limit is exceeded, we wait between 25ms and 500ms before retrying.
// This is to avoid overwhelming the server with requests when the rate limit is exceeded.
// Randomization is used to avoid thundering herd problem.
let jitter = Jitter::new(
Duration::from_millis(25),
Duration::from_millis(500),
);
async move {
let bytes = match chunk {
Ok(actual_chunk) => {
actual_chunk
}
Err(e) => return Err(e),
};
// --- Round up to the nearest kilobyte. ---
// Important: If the chunk is < 1KB, it will be rounded up to 1 kb.
// Many small uploads will be counted as more than they actually are.
// I am not too concerned about this though because small random disk writes are stressing
// the disk more anyway compared to larger writes.
// Why are we doing this? governor::Quota is defined as a u32. u32 can only count up to 4GB.
// To support 4GB/s+ limits we need to count in kilobytes.
//
// --- Chunk Size ---
// The chunk size is determined by the client library.
// Common chunk sizes: 16KB to 10MB.
// HTTP based uploads are usually between 256KB and 1MB.
// Asking the limiter for 1KB packets is tradeoff between
// - Not calling the limiter too much
// - Guaranteeing the call size (1kb) is low enough to not cause race condition issues.
let chunk_kilobytes = bytes.len().div_ceil(1024);
for _ in 0..chunk_kilobytes {
// Check each kilobyte
if limiter
.until_key_n_ready_with_jitter(
&key,
NonZero::new(1).expect("1 is always non zero"),
jitter,
)
.await.is_err()
{
// Requested rate (1kb) is higher then the set limit (x kb/s).
// This should never happen.
unreachable!("Rate limiting is based on the number of kilobytes, not bytes. So 1 kb should always be allowed.");
};
}
Ok(bytes)
}
})
.buffered(1);
Body::from_stream(throttled)
}
/// Get the limits that match the request.
fn get_limit_matches(&self, req: &Request<Body>) -> Vec<&LimitTuple> {
self.limits
.iter()
.filter(|limit| limit.is_match(req))
.collect()
}
}
impl<S> Service<Request<Body>> for RateLimiterMiddleware<S>
where
S: Service<Request<Body>, Response = axum::response::Response, Error = Infallible>
+ Send
+ 'static
+ Clone,
S::Future: Send + 'static,
{
type Response = S::Response;
type Error = S::Error;
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx).map_err(|_| unreachable!()) // `Infallible` conversion
}
fn call(&mut self, mut req: Request<Body>) -> Self::Future {
let mut inner = self.inner.clone();
// Match the request path + method to the defined limits.
let limits = self.get_limit_matches(&req);
if limits.is_empty() {
// No limits matched, so we can just call the next layer.
return Box::pin(async move { inner.call(req).await.map_err(|_| unreachable!()) });
}
// Go through all the limits and check if we need to throttle or reject the request.
for limit in limits.clone() {
let key = match limit.extract_key(&req) {
Ok(key) => key,
Err(e) => {
// Failed to extract the key, so we reject the request.
// This should only happen if the pubky-host couldn't be extracted.
tracing::warn!(
"{} {} Failed to extract key for rate limiting: {}",
limit.0.path.0,
limit.0.method.0,
e
);
return Box::pin(async move {
Ok(Error::new(
StatusCode::BAD_REQUEST,
Some("Failed to extract key for rate limiting".to_string()),
)
.into_response())
});
}
};
match limit.0.quota.rate_unit {
RateUnit::SpeedRateUnit(_) => {
// Speed limiting is enabled, so we need to throttle the upload.
req = Self::throttle_upload(req, &key, &limit.1);
}
RateUnit::Request => {
// Request limiting is enabled, so we need to limit the number of requests.
if let Err(e) = limit.1.check_key(&key) {
tracing::debug!(
"Rate limit of {} exceeded for {key}: {}",
limit.0.quota,
e
);
return Box::pin(async move {
Ok(Error::new(
StatusCode::TOO_MANY_REQUESTS,
Some("Rate limit exceeded".to_string()),
)
.into_response())
});
};
}
};
}
// Create a clone of the request without the body.
// This way, we can extract the keys for the response too.
let (parts, body) = req.into_parts();
let req_clone = Request::from_parts(parts.clone(), Body::empty());
let req = Request::from_parts(parts, body);
let speed_limits = limits
.into_iter()
.filter(|limit| limit.0.quota.rate_unit.is_speed_rate_unit())
.cloned()
.collect::<Vec<_>>();
Box::pin(async move {
// Call the next layer and receive the response.
let mut response = match inner.call(req).await.map_err(|_| unreachable!()) {
Ok(response) => response,
Err(e) => return Err(e),
};
// Rate limit the download speed.
for limit in speed_limits {
if let Ok(key) = limit.extract_key(&req_clone) {
response = Self::throttle_download(response, &key, &limit.1);
};
}
Ok(response)
})
}
}
#[cfg(test)]
mod tests {
    use std::net::{IpAddr, Ipv4Addr, SocketAddr};

    use axum::http::Method;
    use axum::{
        routing::{get, post},
        Router,
    };
    use axum_server::Server;
    use pkarr::{Keypair, PublicKey};
    use reqwest::{Client, Response};
    use tokio::{task::JoinHandle, time::Instant};

    use crate::{core::layers::pubky_host::PubkyHostLayer, quota_config::GlobPattern};

    use super::*;

    // Fake upload handler that just consumes the body.
    pub async fn upload_handler(body: Body) -> Result<impl IntoResponse> {
        let mut stream = body.into_data_stream();
        while let Some(chunk) = stream.next().await.transpose()? {
            // Consume body
            let _ = chunk;
        }
        Ok((StatusCode::CREATED, ()))
    }

    // Fake download handler that returns a fixed 3kb body.
    pub async fn download_handler() -> Result<impl IntoResponse> {
        let response_body = vec![0u8; 3 * 1024]; // 3kb
        Ok((StatusCode::OK, response_body))
    }

    // Counts how many of the given statuses equal `expected`.
    fn count_status(statuses: &[StatusCode], expected: StatusCode) -> usize {
        statuses
            .iter()
            .filter(|status| **status == expected)
            .count()
    }

    // Start a server with the given quota config on a random port.
    async fn start_server(config: Vec<PathLimit>) -> SocketAddr {
        // The layer added last runs first, so `PubkyHostLayer` sees the
        // request before the rate limiter does.
        let app = Router::new()
            .route("/upload", post(upload_handler))
            .route("/download", get(download_handler))
            .layer(RateLimiterLayer::new(config))
            .layer(PubkyHostLayer);

        // Create a TCP listener to bind to the socket first
        // Use port 0 to let the OS assign a random available port
        let listener = tokio::net::TcpListener::bind(SocketAddr::new(
            IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
            0,
        ))
        .await
        .unwrap();

        // Get the actual socket address with the OS-assigned port
        let socket = listener.local_addr().unwrap();

        // Use the listener with axum_server
        let server = Server::from_tcp(listener.into_std().unwrap());
        tokio::spawn(async move {
            server
                .serve(app.into_make_service_with_connect_info::<SocketAddr>())
                .await
                .unwrap();
        });

        socket
    }

    #[tokio::test]
    async fn test_throttle_upload() {
        let path_limit = PathLimit::new(
            GlobPattern::new("/upload"),
            Method::POST,
            "1kb/s".parse().unwrap(),
            LimitKey::Ip,
            None,
        );
        let socket = start_server(vec![path_limit]).await;

        fn upload_data(socket: SocketAddr, kilobytes: usize) -> JoinHandle<()> {
            tokio::spawn(async move {
                let client = Client::new();
                let data = vec![0u8; kilobytes * 1024];
                let response = client
                    .post(format!("http://{}/upload", socket))
                    .body(data)
                    .send()
                    .await
                    .unwrap();
                assert_eq!(response.status(), StatusCode::CREATED);
            })
        }

        let start = Instant::now();
        // Spawn in the background to test 2 uploads in parallel
        // Upload 3kb each
        let handle1 = upload_data(socket, 3);
        let handle2 = upload_data(socket, 3);

        // Wait for the uploads to finish.
        // A failed assertion inside a task surfaces as a join error, so it must not be ignored.
        tokio::try_join!(handle1, handle2).expect("upload tasks must not panic");

        let time_taken = start.elapsed();
        assert!(time_taken > Duration::from_secs(6), "Should at least take 6s because uploads are limited to 1kb/s and the sum of the uploads is 6kb");
    }

    #[tokio::test]
    async fn test_throttle_download() {
        let path_limit = PathLimit::new(
            GlobPattern::new("/download"),
            Method::GET,
            "1kb/s".parse().unwrap(),
            LimitKey::Ip,
            None,
        );
        let socket = start_server(vec![path_limit]).await;

        fn download_data(socket: SocketAddr) -> JoinHandle<()> {
            tokio::spawn(async move {
                let client = Client::new();
                let response = client
                    .get(format!("http://{}/download", socket))
                    .send()
                    .await
                    .unwrap();
                assert_eq!(response.status(), StatusCode::OK);
                response.bytes().await.unwrap(); // Download the body
            })
        }

        let start = Instant::now();
        // Spawn in the background to test 2 downloads in parallel
        // Download 3kb each
        let handle1 = download_data(socket);
        let handle2 = download_data(socket);

        // Wait for the downloads to finish.
        // A failed assertion inside a task surfaces as a join error, so it must not be ignored.
        tokio::try_join!(handle1, handle2).expect("download tasks must not panic");

        let time_taken = start.elapsed();
        assert!(time_taken > Duration::from_secs(6), "Should at least take 6s because downloads are limited to 1kb/s and the sum of the downloads is 6kb");
    }

    #[tokio::test]
    async fn test_limit_parallel_requests_with_ip_key() {
        let path_limit = PathLimit::new(
            GlobPattern::new("/upload"),
            Method::POST,
            "1r/m".parse().unwrap(),
            LimitKey::Ip,
            None,
        );
        let socket = start_server(vec![path_limit]).await;

        fn send_request(socket: SocketAddr) -> JoinHandle<Response> {
            tokio::spawn(async move {
                let client = Client::new();
                client
                    .post(format!("http://{}/upload", socket))
                    .send()
                    .await
                    .unwrap()
            })
        }

        // Spawn in the background to test 2 requests in parallel
        let handle1 = send_request(socket);
        let handle2 = send_request(socket);

        // Wait for the requests to finish
        let (res1, res2) = tokio::try_join!(handle1, handle2).unwrap();

        // The requests race each other, so only the totals are deterministic:
        // exactly one is accepted and exactly one is rejected.
        let statuses = [res1.status(), res2.status()];
        assert_eq!(count_status(&statuses, StatusCode::CREATED), 1);
        assert_eq!(count_status(&statuses, StatusCode::TOO_MANY_REQUESTS), 1);
    }

    #[tokio::test]
    async fn test_limit_parallel_requests_with_user_key() {
        let path_limit = PathLimit::new(
            GlobPattern::new("/upload"),
            Method::POST,
            "1r/m".parse().unwrap(),
            LimitKey::User,
            None,
        );
        let socket = start_server(vec![path_limit]).await;

        fn send_request(socket: SocketAddr, user_pubkey: PublicKey) -> JoinHandle<Response> {
            tokio::spawn(async move {
                let client = Client::new();
                client
                    .post(format!("http://{}/upload?pubky-host={user_pubkey}", socket))
                    .send()
                    .await
                    .unwrap()
            })
        }

        // Spawn in the background to test 2 requests of the same user in parallel
        let user1_pubkey = Keypair::random().public_key();
        let handle1 = send_request(socket, user1_pubkey.clone());
        let handle2 = send_request(socket, user1_pubkey.clone());

        // A different user has its own budget
        let user2_pubkey = Keypair::random().public_key();
        let handle3 = send_request(socket, user2_pubkey.clone());

        // Wait for the requests to finish
        let (res1, res2, res3) = tokio::try_join!(handle1, handle2, handle3).unwrap();

        // The two requests of user 1 race each other, so only the totals are deterministic.
        let user1_statuses = [res1.status(), res2.status()];
        assert_eq!(count_status(&user1_statuses, StatusCode::CREATED), 1);
        assert_eq!(
            count_status(&user1_statuses, StatusCode::TOO_MANY_REQUESTS),
            1
        );
        assert_eq!(res3.status(), StatusCode::CREATED);
    }
}

View File

@@ -0,0 +1,3 @@
mod extract_ip;
mod layer;
pub use layer::*;

View File

@@ -13,9 +13,11 @@ use tower::ServiceBuilder;
use tower_cookies::CookieManagerLayer;
use tower_http::cors::CorsLayer;
use crate::core::AppState;
use crate::{core::AppState, AppContext};
use super::layers::{pubky_host::PubkyHostLayer, trace::with_trace_layer};
use super::layers::{
pubky_host::PubkyHostLayer, rate_limiter::RateLimiterLayer, trace::with_trace_layer,
};
mod auth;
mod feed;
@@ -37,16 +39,20 @@ fn base() -> Router<AppState> {
// TODO: maybe add to a separate router (drive router?).
}
pub fn create_app(state: AppState) -> Router {
pub fn create_app(state: AppState, context: &AppContext) -> Router {
let app = base()
.merge(tenants::router(state.clone()))
.layer(CookieManagerLayer::new())
.layer(CorsLayer::very_permissive())
.layer(ServiceBuilder::new().layer(middleware::from_fn(add_server_header)))
.layer(PubkyHostLayer)
.layer(RateLimiterLayer::new(
context.config_toml.drive.rate_limits.clone(),
))
.with_state(state);
// Apply trace and pubky host layers to the complete router.
with_trace_layer(app, &TRACING_EXCLUDED_PATHS).layer(PubkyHostLayer)
with_trace_layer(app, &TRACING_EXCLUDED_PATHS)
}
// Middleware to add a `Server` header to all responses

View File

@@ -3,9 +3,11 @@ signup_mode = "token_required"
lmdb_backup_interval_s = 0
user_storage_quota_mb = 0
[drive]
pubky_listen_socket = "127.0.0.1:6287"
icann_listen_socket = "127.0.0.1:6286"
rate_limits = []
[admin]
listen_socket = "127.0.0.1:6288"

View File

@@ -4,7 +4,7 @@
//! This module embeds that file at compile-time, parses it once,
//! and lets callers optionally layer their own TOML on top.
use super::{domain_port::DomainPort, Domain, SignupMode};
use super::{domain_port::DomainPort, quota_config::PathLimit, Domain, SignupMode};
use serde::{Deserialize, Serialize};
use std::{
fmt::Debug,
@@ -54,6 +54,7 @@ pub struct PkdnsToml {
pub struct DriveToml {
    /// Value of the `pubky_listen_socket` setting.
    pub pubky_listen_socket: SocketAddr,
    /// Value of the `icann_listen_socket` setting.
    pub icann_listen_socket: SocketAddr,
    /// Endpoint rate limits, one entry per `[[drive.rate_limits]]` table.
    pub rate_limits: Vec<PathLimit>,
}
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
@@ -144,9 +145,8 @@ impl ConfigToml {
.lines()
.map(|line| {
let trimmed = line.trim_start();
let is_title = trimmed.starts_with('[');
let is_comment = trimmed.starts_with('#');
if !is_title && !is_comment && !trimmed.is_empty() {
if !is_comment && !trimmed.is_empty() {
format!("# {}", line)
} else {
line.to_string()
@@ -216,12 +216,13 @@ mod tests {
assert_eq!(c.pkdns.dht_bootstrap_nodes, None);
assert_eq!(c.pkdns.dht_relay_nodes, None);
assert_eq!(c.pkdns.dht_request_timeout_ms, None);
assert_eq!(c.drive.rate_limits, vec![]);
}
#[test]
fn test_sample_config() {
    // Validate that the embedded sample config can be parsed.
    ConfigToml::from_str(SAMPLE_CONFIG).expect("Embedded config.sample.toml must be valid");
}
#[test]

View File

@@ -4,6 +4,8 @@ mod domain;
mod domain_port;
mod mock_data_dir;
mod persistent_data_dir;
/// Quota configuration for the TomlConfig.
pub mod quota_config;
mod signup_mode;
pub use config_toml::{ConfigReadError, ConfigToml};

View File

@@ -0,0 +1,89 @@
use serde::{Deserialize, Serialize};
use std::{fmt::Display, str::FromStr};
/// A wrapper around fast_glob to allow serialize/deserialize.
/// Pattern matches glob patterns.
///
/// The wrapped string is the raw pattern. It is stored as given and is only
/// interpreted when `is_match` is called.
///
/// Syntax - Meaning
/// `?` - Matches any single character.
/// `*` - Matches zero or more characters, except for path separators (e.g. /).
/// `**` - Matches zero or more characters, including path separators. Must match a complete path segment (i.e. followed by a / or the end of the pattern).
/// `[ab]` - Matches one of the characters contained in the brackets. Character ranges, e.g. `[a-z]` are also supported. Use `[!ab]` or `[^ab]` to match any character except those contained in the brackets.
/// `{a,b}` - Matches one of the patterns contained in the braces. Any of the wildcard characters can be used in the sub-patterns. Braces may be nested up to 10 levels deep.
/// `!` - When at the start of the glob, this negates the result. Multiple `!` characters negate the glob multiple times.
/// `\` - A backslash character may be used to escape any of the above special characters.
#[derive(Debug, Clone)]
pub struct GlobPattern(pub String);
impl GlobPattern {
/// Create a new glob pattern.
pub fn new(pattern: &str) -> Self {
Self(pattern.to_string())
}
/// Check if the path matches the glob pattern.
pub fn is_match(&self, path: &str) -> bool {
fast_glob::glob_match(&self.0, path)
}
}
impl std::hash::Hash for GlobPattern {
    /// Hashes the raw pattern string.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        let pattern: &str = self.0.as_str();
        std::hash::Hash::hash(pattern, state);
    }
}
impl FromStr for GlobPattern {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(GlobPattern(s.to_string()))
}
}
impl Display for GlobPattern {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0.as_str())
}
}
impl Serialize for GlobPattern {
    /// Serializes the pattern as a plain string.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let Self(pattern) = self;
        serializer.serialize_str(pattern)
    }
}
impl<'de> Deserialize<'de> for GlobPattern {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
GlobPattern::from_str(&s).map_err(serde::de::Error::custom)
}
}
impl PartialEq for GlobPattern {
fn eq(&self, other: &Self) -> bool {
self.0.as_str() == other.0.as_str()
}
}
impl Eq for GlobPattern {}
#[cfg(test)]
mod tests {
    use super::*;

    /// `/pub/**` matches everything below `/pub/` and nothing outside of it.
    #[test]
    fn test_glob_pattern() {
        let glob_pattern: GlobPattern = "/pub/**".parse().unwrap();

        for path in ["/pub/test.txt", "/pub/test/test.txt"] {
            assert!(glob_pattern.is_match(path));
        }
        for path in ["/priv/test.pdf", "/session/test.txt"] {
            assert!(!glob_pattern.is_match(path));
        }
    }
}

View File

@@ -0,0 +1,55 @@
use std::{fmt::Display, str::FromStr};
use axum::http::Method;
use serde::{Deserialize, Serialize};
/// A wrapper around http::Method to implement serde traits.
///
/// It is serialized as the method's string form and parsed from a string
/// that is upper-cased first.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct HttpMethod(pub Method);
impl From<Method> for HttpMethod {
fn from(method: Method) -> Self {
HttpMethod(method)
}
}
impl From<HttpMethod> for Method {
fn from(method: HttpMethod) -> Self {
method.0
}
}
impl FromStr for HttpMethod {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Method::from_str(s.to_uppercase().as_str())
.map(HttpMethod)
.map_err(|_| format!("Invalid method: {}", s))
}
}
impl Display for HttpMethod {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
impl Serialize for HttpMethod {
    /// Serializes the method as its string form.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let text = self.to_string();
        serializer.serialize_str(text.as_str())
    }
}
impl<'de> Deserialize<'de> for HttpMethod {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
HttpMethod::from_str(&s).map_err(serde::de::Error::custom)
}
}

View File

@@ -0,0 +1,55 @@
use std::fmt;
use std::str::FromStr;
/// The key to limit the quota on.
///
/// The string forms are `"user"` and `"ip"`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum LimitKey {
    /// Limit on the user id
    User,
    /// Limit on the ip address
    Ip,
}
impl fmt::Display for LimitKey {
    /// Writes the lowercase name of the key: `user` or `ip`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            LimitKey::User => "user",
            LimitKey::Ip => "ip",
        };
        f.write_str(label)
    }
}
impl FromStr for LimitKey {
    type Err = String;

    /// Parses `"user"` or `"ip"`, ignoring ASCII case.
    ///
    /// # Errors
    /// Returns `Invalid limit key: <input>` for any other input.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Case is ignored so that the key is parsed as leniently as the
        // method, which is upper-cased before parsing.
        if s.eq_ignore_ascii_case("user") {
            Ok(LimitKey::User)
        } else if s.eq_ignore_ascii_case("ip") {
            Ok(LimitKey::Ip)
        } else {
            Err(format!("Invalid limit key: {}", s))
        }
    }
}
impl serde::Serialize for LimitKey {
    /// Serializes the key as its string form.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let text = self.to_string();
        serializer.serialize_str(&text)
    }
}
impl<'de> serde::Deserialize<'de> for LimitKey {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
LimitKey::from_str(&s).map_err(serde::de::Error::custom)
}
}

View File

@@ -0,0 +1,15 @@
mod glob_pattern;
mod http_method;
mod limit_key;
mod path_limit;
mod quota_value;
mod rate_unit;
mod time_unit;
pub use glob_pattern::GlobPattern;
pub use http_method::HttpMethod;
pub use limit_key::LimitKey;
pub use path_limit::*;
pub use quota_value::QuotaValue;
pub use rate_unit::RateUnit;
pub use time_unit::TimeUnit;

View File

@@ -0,0 +1,77 @@
use super::{GlobPattern, HttpMethod, LimitKey, QuotaValue};
use axum::http::Method;
use serde::{Deserialize, Serialize};
use std::num::NonZeroU32;
/// A limit on a path for a specific method.
///
/// One `PathLimit` corresponds to one `[[drive.rate_limits]]` entry of the config.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
pub struct PathLimit {
    /// The path glob pattern to match against.
    pub path: GlobPattern,
    /// The method to limit.
    pub method: HttpMethod,
    /// The limit to apply.
    pub quota: QuotaValue,
    /// The key to limit.
    pub key: LimitKey,
    /// The burst to apply. `None` keeps the burst of the quota itself.
    pub burst: Option<NonZeroU32>,
}
impl PathLimit {
/// Create a new path limit.
pub fn new(
path: GlobPattern,
method: Method,
quota: QuotaValue,
key: LimitKey,
burst: Option<NonZeroU32>,
) -> Self {
Self {
path,
method: HttpMethod(method),
quota,
key,
burst,
}
}
}
impl std::fmt::Display for PathLimit {
    /// Formats the limit as `METHOD PATH: QUOTA[-BURST] by KEY`.
    ///
    /// The `-BURST` part is only written when a burst is configured, so a
    /// limit without burst no longer ends its quota with a dangling `-`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{} {}: {}", self.method, self.path, self.quota)?;
        if let Some(burst) = self.burst {
            write!(f, "-{burst}")?;
        }
        write!(f, " by {}", self.key)
    }
}
impl From<PathLimit> for governor::Quota {
fn from(value: PathLimit) -> Self {
let quota: governor::Quota = value.quota.into();
if let Some(burst) = value.burst {
quota.allow_burst(burst);
}
quota
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `HttpMethod` prints as the method name and parses back to the same value.
    #[test]
    fn test_http_method_serde() {
        let http_method = HttpMethod(Method::GET);
        assert_eq!(http_method.to_string(), "GET");

        let parsed = "GET".parse::<HttpMethod>().unwrap();
        assert_eq!(parsed, http_method);
    }
}

View File

@@ -0,0 +1,201 @@
use std::fmt;
use std::str::FromStr;
use std::{num::NonZeroU32, time::Duration};
use super::{RateUnit, TimeUnit};
/// Quota value
///
/// Written as `{rate}{rate_unit}/{time_unit}`.
///
/// Examples:
/// - 5r/m
/// - 5r/s
/// - 5kb/m
/// - 5mb/m
/// - 5gb/s
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct QuotaValue {
    /// The rate.
    pub rate: NonZeroU32,
    /// The unit of the rate.
    pub rate_unit: RateUnit,
    /// The unit of the time.
    pub time_unit: TimeUnit,
}
impl From<QuotaValue> for governor::Quota {
/// Get the quota to do the actual rate limiting.
///
/// Important: The speed quotas are always in kilobytes, not bytes.
/// Counting bytes is not practical.
///
fn from(value: QuotaValue) -> Self {
let rate_count = value.rate.get();
let rate_unit = value.rate_unit.multiplier().get();
let rate = NonZeroU32::new(rate_count * rate_unit)
.expect("Is always non-zero because rate count and rate unit multiplier are non-zero");
let time_unit = Duration::from_secs(value.time_unit.multiplier_in_seconds().get() as u64);
let replenish_1_per = time_unit / rate.get();
let base_quota = governor::Quota::with_period(replenish_1_per)
.expect("Is always non-zero because replenish_1_per is non-zero");
base_quota.allow_burst(rate)
}
}
impl fmt::Display for QuotaValue {
    /// Writes the quota in its textual form, e.g. `5r/s` or `10mb/m`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            rate,
            rate_unit,
            time_unit,
        } = self;
        write!(f, "{rate}{rate_unit}/{time_unit}")
    }
}
impl FromStr for QuotaValue {
    type Err = String;

    /// Parses a quota written as `{rate}{unit}/{time}`, e.g. `5r/s` or `10mb/m`.
    ///
    /// # Errors
    ///
    /// Returns a message describing the problem when the input does not contain
    /// exactly one `/`, when the time unit or rate unit is not recognised, or
    /// when the rate is missing or is not a non-zero `u32`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Exactly one '/' must separate the rate part from the time unit.
        let mut pieces = s.split('/');
        let (rate_with_unit, time_part) = match (pieces.next(), pieces.next(), pieces.next()) {
            (Some(rate_with_unit), Some(time_part), None) => (rate_with_unit, time_part),
            _ => {
                return Err(format!(
                    "Invalid rate format: '{}', expected {{rate}}{{unit}}/{{time}}",
                    s
                ));
            }
        };
        let time_unit = TimeUnit::from_str(time_part)?;

        // The leading ASCII digits are the rate; whatever follows is the unit.
        // Everything before the first non-digit is single-byte ASCII, so the
        // byte offset returned by `find` is a valid split point.
        let digit_count = rate_with_unit
            .find(|c: char| !c.is_ascii_digit())
            .unwrap_or(rate_with_unit.len());
        if digit_count == 0 {
            return Err(format!("Missing rate value in '{}'", rate_with_unit));
        }
        let (rate_str, rate_unit_str) = rate_with_unit.split_at(digit_count);

        let rate = rate_str
            .parse::<NonZeroU32>()
            .map_err(|_| format!("Failed to parse rate from '{}'", rate_str))?;
        let rate_unit = RateUnit::from_str(rate_unit_str)?;

        Ok(QuotaValue {
            rate,
            rate_unit,
            time_unit,
        })
    }
}
impl serde::Serialize for QuotaValue {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serializer.serialize_str(&self.to_string())
}
}
impl<'de> serde::Deserialize<'de> for QuotaValue {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
// Parse the quota string
QuotaValue::from_str(&s).map_err(serde::de::Error::custom)
}
}
#[cfg(test)]
mod tests {
    use crate::quota_config::rate_unit::SpeedRateUnit;

    use super::*;

    /// Builds a `NonZeroU32` from a literal that is known to be non-zero.
    fn nz(value: u32) -> NonZeroU32 {
        NonZeroU32::new(value).unwrap()
    }

    /// Converting a parsed quota yields the same `governor::Quota` as the
    /// matching governor constructor.
    #[test]
    fn test_get_quota() {
        let cases = [
            ("5r/s", governor::Quota::per_second(nz(5))),
            ("5r/m", governor::Quota::per_minute(nz(5))),
            // Speed quotas are counted in kilobytes, so 5kb is 5 units.
            ("5kb/s", governor::Quota::per_second(nz(5))),
            // One megabyte is 1024 kilobyte units.
            ("5mb/m", governor::Quota::per_minute(nz(5 * 1024))),
        ];
        for (input, expected) in cases {
            let quota = QuotaValue::from_str(input).unwrap();
            assert_eq!(governor::Quota::from(quota), expected);
        }
    }

    /// Parsing extracts the rate, the rate unit and the time unit.
    #[test]
    fn test_quota_value_from_str() {
        // Request-based quota.
        let requests = QuotaValue::from_str("5r/s").unwrap();
        assert_eq!(requests.rate, nz(5));
        assert_eq!(requests.rate_unit, RateUnit::Request);
        assert_eq!(requests.time_unit, TimeUnit::Second);

        // Speed-based quota.
        let speed = QuotaValue::from_str("10mb/m").unwrap();
        assert_eq!(speed.rate, nz(10));
        assert_eq!(
            speed.rate_unit,
            RateUnit::SpeedRateUnit(SpeedRateUnit::Megabyte)
        );
        assert_eq!(speed.time_unit, TimeUnit::Minute);
    }

    /// `Display` produces the same textual form that `FromStr` accepts.
    #[test]
    fn test_quota_value_display() {
        // Request-based quota.
        let requests = QuotaValue {
            rate: nz(5),
            rate_unit: RateUnit::Request,
            time_unit: TimeUnit::Second,
        };
        assert_eq!(requests.to_string(), "5r/s");

        // Speed-based quota.
        let speed = QuotaValue {
            rate: nz(10),
            rate_unit: RateUnit::SpeedRateUnit(SpeedRateUnit::Megabyte),
            time_unit: TimeUnit::Minute,
        };
        assert_eq!(speed.to_string(), "10mb/m");
    }

    /// Malformed quota strings are rejected.
    #[test]
    fn test_quota_value_invalid_formats() {
        let invalid = [
            // Missing '/'.
            "5rs",
            // Missing rate.
            "r/s",
            // Invalid rate unit.
            "5x/s",
            // Invalid time unit.
            "5r/x",
            // A burst suffix is not part of the quota syntax.
            "5r/s-2burst",
        ];
        for input in invalid {
            assert!(QuotaValue::from_str(input).is_err());
        }
    }
}

View File

@@ -0,0 +1,112 @@
use std::{num::NonZeroU32, str::FromStr};
/// Data-size units accepted for speed quotas.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum SpeedRateUnit {
    /// Kilobyte
    Kilobyte,
    /// Megabyte
    Megabyte,
    /// Gigabyte
    Gigabyte,
}

impl SpeedRateUnit {
    /// Returns how many kilobytes one of this unit represents.
    ///
    /// Speed quotas are always counted in kilobytes. Counting bytes is not
    /// practical because a `u32` would cap the quota at 4GB; counting in
    /// kilobytes raises that ceiling to 4GB * 1024 = 4TB.
    pub const fn multiplier(&self) -> NonZeroU32 {
        let kilobytes: u32 = match self {
            SpeedRateUnit::Kilobyte => 1,
            SpeedRateUnit::Megabyte => 1024,
            SpeedRateUnit::Gigabyte => 1024 * 1024,
        };
        NonZeroU32::new(kilobytes).expect("Is always non-zero")
    }
}

impl std::fmt::Display for SpeedRateUnit {
    /// Writes the lowercase unit symbol: `kb`, `mb` or `gb`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let symbol = match self {
            SpeedRateUnit::Kilobyte => "kb",
            SpeedRateUnit::Megabyte => "mb",
            SpeedRateUnit::Gigabyte => "gb",
        };
        f.write_str(symbol)
    }
}

impl FromStr for SpeedRateUnit {
    type Err = String;

    /// Parses `kb`, `mb` or `gb`, ignoring case.
    ///
    /// # Errors
    ///
    /// Returns a message containing the original input when it is not one of
    /// the supported units.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let lowered = s.to_lowercase();
        let unit = match lowered.as_str() {
            "kb" => SpeedRateUnit::Kilobyte,
            "mb" => SpeedRateUnit::Megabyte,
            "gb" => SpeedRateUnit::Gigabyte,
            _ => return Err(format!("Invalid speedrate unit: {}", s)),
        };
        Ok(unit)
    }
}
/// The unit of the rate.
///
/// Examples:
/// - "r" -> request
/// - "kb" -> kilobyte
/// - "mb" -> megabyte
/// - "gb" -> gigabyte
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum RateUnit {
/// Request
Request,
/// Speed rate unit (a data size, see `SpeedRateUnit`)
SpeedRateUnit(SpeedRateUnit),
}
impl RateUnit {
    /// Reports whether this unit is a data-size unit rather than a request count.
    pub fn is_speed_rate_unit(&self) -> bool {
        match self {
            RateUnit::SpeedRateUnit(_) => true,
            RateUnit::Request => false,
        }
    }
}
impl std::fmt::Display for RateUnit {
    /// Writes `r` for requests, or the symbol of the wrapped speed unit.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            RateUnit::Request => f.write_str("r"),
            RateUnit::SpeedRateUnit(unit) => write!(f, "{}", unit),
        }
    }
}
impl FromStr for RateUnit {
    type Err = String;

    /// Parses `r` as a request unit and anything else as a speed unit,
    /// ignoring case.
    ///
    /// # Errors
    ///
    /// Returns a message containing the original input when it is neither `r`
    /// nor a supported speed unit.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let lowered = s.to_lowercase();
        if lowered == "r" {
            return Ok(RateUnit::Request);
        }
        SpeedRateUnit::from_str(&lowered)
            .map(RateUnit::SpeedRateUnit)
            .map_err(|_| format!("Invalid rate unit: {}", s))
    }
}
impl RateUnit {
    /// Returns the factor that turns a rate count into quota units.
    ///
    /// Requests count one-for-one; speed units delegate to
    /// `SpeedRateUnit::multiplier`.
    pub const fn multiplier(&self) -> NonZeroU32 {
        if let RateUnit::SpeedRateUnit(unit) = self {
            return unit.multiplier();
        }
        NonZeroU32::new(1).expect("Is always non-zero")
    }
}

View File

@@ -0,0 +1,49 @@
use std::num::NonZeroU32;
/// The time unit of the quota.
///
/// Examples:
/// - "s" -> second
/// - "m" -> minute
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum TimeUnit {
    /// Second
    Second,
    /// Minute
    Minute,
}

impl std::fmt::Display for TimeUnit {
    /// Writes the lowercase unit symbol: `s` or `m`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "{}",
            match self {
                TimeUnit::Second => "s",
                TimeUnit::Minute => "m",
            }
        )
    }
}

impl std::str::FromStr for TimeUnit {
    type Err = String;

    /// Parses `s` or `m`, ignoring ASCII case.
    ///
    /// Rate units are lowercased before matching, so time units are parsed
    /// case-insensitively as well; otherwise `5KB/s` would be accepted while
    /// `5kb/S` is rejected.
    ///
    /// # Errors
    ///
    /// Returns a message containing the original input when it is not a
    /// supported time unit.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        if s.eq_ignore_ascii_case("s") {
            Ok(TimeUnit::Second)
        } else if s.eq_ignore_ascii_case("m") {
            Ok(TimeUnit::Minute)
        } else {
            Err(format!("Invalid time unit: {}", s))
        }
    }
}

impl TimeUnit {
    /// Returns the number of seconds for each unit
    pub const fn multiplier_in_seconds(&self) -> NonZeroU32 {
        match self {
            TimeUnit::Second => NonZeroU32::new(1).expect("Is always non-zero"),
            TimeUnit::Minute => NonZeroU32::new(60).expect("Is always non-zero"),
        }
    }
}