mirror of
https://github.com/aljazceru/pubky-core.git
synced 2026-01-26 01:14:24 +01:00
feat: Flexible endpoint rate limiting (#124)
* before ThrottledBody * added upload rate limiting * before dynamic rate limiting * flexible rate limiter + config * fixed tests * fmt and clippy * reset auth.js e2e * more cleaning up * improved code and added comments * limit downloads too * fmt and clippy * improved sample comments * fixed comment * removed http dependency in favour of axum:http * parse speed units as lowercase * replaced regex with glob * clippy and fmt
This commit is contained in:
committed by
GitHub
parent
00aafb2163
commit
8a0cec71ef
59
Cargo.lock
generated
59
Cargo.lock
generated
@@ -36,6 +36,12 @@ dependencies = [
|
||||
"memchr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "allocator-api2"
|
||||
version = "0.2.21"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923"
|
||||
|
||||
[[package]]
|
||||
name = "anstream"
|
||||
version = "0.6.18"
|
||||
@@ -923,6 +929,15 @@ dependencies = [
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fast-glob"
|
||||
version = "0.4.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "39ea3f6bbcf4dbe2076b372186fc7aeecd5f6f84754582e56ee7db262b15a6f0"
|
||||
dependencies = [
|
||||
"arrayvec",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fastrand"
|
||||
version = "2.3.0"
|
||||
@@ -953,6 +968,12 @@ version = "1.0.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1"
|
||||
|
||||
[[package]]
|
||||
name = "foldhash"
|
||||
version = "0.1.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2"
|
||||
|
||||
[[package]]
|
||||
name = "foreign-types"
|
||||
version = "0.3.2"
|
||||
@@ -1185,6 +1206,29 @@ dependencies = [
|
||||
"spinning_top",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "governor"
|
||||
version = "0.10.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3cbe789d04bf14543f03c4b60cd494148aa79438c8440ae7d81a7778147745c3"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"dashmap",
|
||||
"futures-sink",
|
||||
"futures-timer",
|
||||
"futures-util",
|
||||
"getrandom 0.3.1",
|
||||
"hashbrown 0.15.2",
|
||||
"nonzero_ext",
|
||||
"parking_lot",
|
||||
"portable-atomic",
|
||||
"quanta",
|
||||
"rand 0.9.0",
|
||||
"smallvec",
|
||||
"spinning_top",
|
||||
"web-time",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "h2"
|
||||
version = "0.4.8"
|
||||
@@ -1224,6 +1268,11 @@ name = "hashbrown"
|
||||
version = "0.15.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289"
|
||||
dependencies = [
|
||||
"allocator-api2",
|
||||
"equivalent",
|
||||
"foldhash",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "headers"
|
||||
@@ -1321,9 +1370,9 @@ checksum = "f558a64ac9af88b5ba400d99b579451af0d39c6d360980045b91aac966d705e2"
|
||||
|
||||
[[package]]
|
||||
name = "http"
|
||||
version = "1.2.0"
|
||||
version = "1.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f16ca2af56261c99fba8bac40a10251ce8188205a4c448fbb745a2e4daa76fea"
|
||||
checksum = "f4a85d31aea989eead29a3aaf9e1115a180df8282431156e533de47660892565"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"fnv",
|
||||
@@ -2157,7 +2206,7 @@ dependencies = [
|
||||
"bytes",
|
||||
"clap",
|
||||
"dirs-next",
|
||||
"governor",
|
||||
"governor 0.8.0",
|
||||
"http",
|
||||
"httpdate",
|
||||
"pkarr",
|
||||
@@ -2343,9 +2392,11 @@ dependencies = [
|
||||
"clap",
|
||||
"dirs",
|
||||
"dyn-clone",
|
||||
"fast-glob",
|
||||
"flume",
|
||||
"futures-lite",
|
||||
"futures-util",
|
||||
"governor 0.10.0",
|
||||
"heed",
|
||||
"hex",
|
||||
"hostname-validator",
|
||||
@@ -3500,7 +3551,7 @@ checksum = "57a2ccff6830fa835371af7541e561a90e4c07b84f72991ebac4b3cb6790dc0d"
|
||||
dependencies = [
|
||||
"axum",
|
||||
"forwarded-header-value",
|
||||
"governor",
|
||||
"governor 0.8.0",
|
||||
"http",
|
||||
"pin-project",
|
||||
"thiserror 2.0.12",
|
||||
|
||||
@@ -49,7 +49,7 @@ impl Client {
|
||||
.expect("testnet relays are valid urls")
|
||||
});
|
||||
|
||||
let mut client = builder.build().expect("testnet build should be infallibl");
|
||||
let mut client = builder.build().expect("testnet build should be infallible");
|
||||
|
||||
client.testnet = true;
|
||||
|
||||
|
||||
@@ -53,7 +53,10 @@ axum-test = "17.2.0"
|
||||
tempfile = { version = "3.10.1" }
|
||||
dyn-clone = "1.0.19"
|
||||
reqwest = "0.12.15"
|
||||
governor = "0.10.0"
|
||||
fast-glob = "0.4.5"
|
||||
|
||||
|
||||
[dev-dependencies]
|
||||
futures-lite = "2.6.0"
|
||||
|
||||
|
||||
@@ -24,6 +24,35 @@ pubky_listen_socket = "127.0.0.1:6287"
|
||||
# May be put behind a reverse proxy with TLS enabled.
|
||||
icann_listen_socket = "127.0.0.1:6286"
|
||||
|
||||
# Rate limit endpoints dynamically.
|
||||
# `path` is a glob pattern of the path. See syntax in https://crates.io/crates/fast-glob
|
||||
# `method` is the HTTP method. Examples: GET, POST, PUT, HEAD, DELETE
|
||||
# `quota` defines the limit itself in the format $rate$rate_unit/$time_unit.
|
||||
# - $rate is a positive integer, max 4'294'967'296.
|
||||
# - $rate_unit is either "r" for requests, "kb" for kilobytes,
|
||||
# "mb" for megabytes, "gb" for gigabytes.
|
||||
# Speed limits limit download and upload.
|
||||
# - $time_unit is either "s" for second, "m" for minute.
|
||||
# `key` defines what who is rate limited.
|
||||
# - "ip" limits based on the IP address.
|
||||
# - "user" limits based on the user pubkey. Requires the endpoint to have authentication.
|
||||
# `burst` is a temporary allowance of quota that is added to the limit.
|
||||
# By default, burst is equal the quota rate.
|
||||
#
|
||||
# Limit login attempts to 20 requests per minute per IP.
|
||||
[[drive.rate_limits]]
|
||||
path = "/session"
|
||||
method = "POST"
|
||||
quota = "20r/m"
|
||||
key = "ip"
|
||||
#
|
||||
# Limit file uploads to 1 megabyte per second per user with a temporary burst of 10 megabytes.
|
||||
[[drive.rate_limits]]
|
||||
path = "/pub/**"
|
||||
method = "PUT"
|
||||
quota = "1mb/s"
|
||||
key = "user"
|
||||
burst = 10
|
||||
|
||||
[admin]
|
||||
# The port number to run the admin HTTP (clear text) server on.
|
||||
|
||||
@@ -151,7 +151,7 @@ impl HomeserverCore {
|
||||
signup_mode: context.config_toml.general.signup_mode.clone(),
|
||||
user_quota_bytes: quota_bytes,
|
||||
};
|
||||
super::routes::create_app(state.clone())
|
||||
super::routes::create_app(state.clone(), context)
|
||||
}
|
||||
|
||||
/// Start the ICANN HTTP server
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
pub mod authz;
|
||||
pub mod pubky_host;
|
||||
pub mod rate_limiter;
|
||||
pub mod trace;
|
||||
|
||||
37
pubky-homeserver/src/core/layers/rate_limiter/extract_ip.rs
Normal file
37
pubky-homeserver/src/core/layers/rate_limiter/extract_ip.rs
Normal file
@@ -0,0 +1,37 @@
|
||||
use axum::extract::Request;
|
||||
use axum::http::HeaderMap;
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
|
||||
// From https://github.com/benwis/tower-governor/blob/main/src/key_extractor.rs#L121
|
||||
const X_REAL_IP: &str = "x-real-ip";
|
||||
const X_FORWARDED_FOR: &str = "x-forwarded-for";
|
||||
|
||||
/// Tries to parse the `x-forwarded-for` header
|
||||
fn maybe_x_forwarded_for(headers: &HeaderMap) -> Option<IpAddr> {
|
||||
headers
|
||||
.get(X_FORWARDED_FOR)
|
||||
.and_then(|hv| hv.to_str().ok())
|
||||
.and_then(|s| s.split(',').find_map(|s| s.trim().parse::<IpAddr>().ok()))
|
||||
}
|
||||
|
||||
/// Tries to parse the `x-real-ip` header
|
||||
fn maybe_x_real_ip(headers: &HeaderMap) -> Option<IpAddr> {
|
||||
headers
|
||||
.get(X_REAL_IP)
|
||||
.and_then(|hv| hv.to_str().ok())
|
||||
.and_then(|s| s.parse::<IpAddr>().ok())
|
||||
}
|
||||
|
||||
fn maybe_connect_info<T>(req: &Request<T>) -> Option<IpAddr> {
|
||||
req.extensions()
|
||||
.get::<axum::extract::ConnectInfo<SocketAddr>>()
|
||||
.map(|addr| addr.ip())
|
||||
}
|
||||
|
||||
pub fn extract_ip<T>(req: &Request<T>) -> anyhow::Result<IpAddr> {
|
||||
let headers = req.headers();
|
||||
maybe_x_forwarded_for(headers)
|
||||
.or_else(|| maybe_x_real_ip(headers))
|
||||
.or_else(|| maybe_connect_info(req))
|
||||
.ok_or(anyhow::anyhow!("Failed to extract ip."))
|
||||
}
|
||||
539
pubky-homeserver/src/core/layers/rate_limiter/layer.rs
Normal file
539
pubky-homeserver/src/core/layers/rate_limiter/layer.rs
Normal file
@@ -0,0 +1,539 @@
|
||||
//!
|
||||
//! Implements rate limiting with governor.
|
||||
//!
|
||||
//! Would love to use tower_governor but I can't type it properly due to
|
||||
//! https://github.com/benwis/tower-governor/issues/49.
|
||||
//!
|
||||
//! So we implement our own rate limiter here.
|
||||
//!
|
||||
use axum::response::{IntoResponse, Response};
|
||||
use axum::{
|
||||
body::Body,
|
||||
http::{Request, StatusCode},
|
||||
};
|
||||
use futures_util::future::BoxFuture;
|
||||
use governor::clock::QuantaClock;
|
||||
use governor::state::keyed::DashMapStateStore;
|
||||
use std::num::NonZero;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use std::{convert::Infallible, task::Poll};
|
||||
use tower::{Layer, Service};
|
||||
|
||||
use crate::core::error::Result;
|
||||
use crate::core::extractors::PubkyHost;
|
||||
use crate::core::Error;
|
||||
use crate::quota_config::{LimitKey, PathLimit, RateUnit};
|
||||
use futures_util::StreamExt;
|
||||
use governor::{Jitter, Quota, RateLimiter};
|
||||
|
||||
use super::extract_ip::extract_ip;
|
||||
|
||||
/// A Tower Layer to handle general rate limiting.
|
||||
///
|
||||
/// Supports rate limiting by request count and by upload/download speed.
|
||||
///
|
||||
/// Requires a `PubkyHostLayer` to be applied first.
|
||||
/// Used to extract the user pubkey as the key for the rate limiter.
|
||||
///
|
||||
/// Returns 400 BAD REQUEST if the user pubkey aka pubky-host cannot be extracted.
|
||||
///
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RateLimiterLayer {
|
||||
limits: Vec<PathLimit>,
|
||||
}
|
||||
|
||||
impl RateLimiterLayer {
|
||||
/// Create a new rate limiter layer with the given quota.
|
||||
///
|
||||
/// If quota is None, rate limiting is disabled.
|
||||
pub fn new(limits: Vec<PathLimit>) -> Self {
|
||||
if limits.is_empty() {
|
||||
tracing::info!("Rate limiting is disabled.");
|
||||
} else {
|
||||
let limits_str = limits
|
||||
.iter()
|
||||
.map(|limit| format!("\"{limit}\""))
|
||||
.collect::<Vec<String>>();
|
||||
tracing::info!("Rate limits configured: {}", limits_str.join(", "));
|
||||
}
|
||||
Self { limits }
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> Layer<S> for RateLimiterLayer {
|
||||
type Service = RateLimiterMiddleware<S>;
|
||||
|
||||
fn layer(&self, inner: S) -> Self::Service {
|
||||
let tuples = self
|
||||
.limits
|
||||
.iter()
|
||||
.map(|path| LimitTuple::new(path.clone()))
|
||||
.collect();
|
||||
|
||||
RateLimiterMiddleware {
|
||||
inner,
|
||||
limits: tuples,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A tuple of a path limit and the actual governor rate limiter.
|
||||
#[derive(Debug, Clone)]
|
||||
struct LimitTuple(
|
||||
pub PathLimit,
|
||||
pub Arc<RateLimiter<String, DashMapStateStore<String>, QuantaClock>>,
|
||||
);
|
||||
|
||||
impl LimitTuple {
|
||||
pub fn new(path_limit: PathLimit) -> Self {
|
||||
let quota: Quota = path_limit.clone().into();
|
||||
let limiter = Arc::new(RateLimiter::keyed(quota));
|
||||
Self(path_limit, limiter)
|
||||
}
|
||||
|
||||
/// Extract the key from the request.
|
||||
///
|
||||
/// The key is either the ip address of the client
|
||||
/// or the user pubkey.
|
||||
fn extract_key(&self, req: &Request<Body>) -> anyhow::Result<String> {
|
||||
match self.0.key {
|
||||
LimitKey::Ip => extract_ip(req).map(|ip| ip.to_string()),
|
||||
LimitKey::User => {
|
||||
// Extract the user pubkey from the request.
|
||||
req.extensions()
|
||||
.get::<PubkyHost>()
|
||||
.map(|pk| pk.public_key().to_string())
|
||||
.ok_or(anyhow::anyhow!("Failed to extract user pubkey."))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if the request matches the limit.
|
||||
pub fn is_match(&self, req: &Request<Body>) -> bool {
|
||||
self.0.path.is_match(req.uri().path()) && self.0.method.0 == req.method()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RateLimiterMiddleware<S> {
|
||||
inner: S,
|
||||
limits: Vec<LimitTuple>,
|
||||
}
|
||||
|
||||
impl<S> RateLimiterMiddleware<S> {
|
||||
/// Throttle the upload body.
|
||||
fn throttle_upload(
|
||||
req: Request<Body>,
|
||||
key: &str,
|
||||
limiter: &Arc<RateLimiter<String, DashMapStateStore<String>, QuantaClock>>,
|
||||
) -> Request<Body> {
|
||||
let (parts, body) = req.into_parts();
|
||||
let new_body = Self::throttle_body(body, key, limiter);
|
||||
Request::from_parts(parts, new_body)
|
||||
}
|
||||
|
||||
/// Throttle the download body.
|
||||
fn throttle_download(
|
||||
res: Response<Body>,
|
||||
key: &str,
|
||||
limiter: &Arc<RateLimiter<String, DashMapStateStore<String>, QuantaClock>>,
|
||||
) -> Response<Body> {
|
||||
let (parts, body) = res.into_parts();
|
||||
let new_body = Self::throttle_body(body, key, limiter);
|
||||
Response::from_parts(parts, new_body)
|
||||
}
|
||||
|
||||
/// Throttle the up or download body.
|
||||
///
|
||||
/// Important: The speed quotas are always in kilobytes, not bytes.
|
||||
/// Counting bytes is not practical.
|
||||
///
|
||||
fn throttle_body(
|
||||
body: Body,
|
||||
key: &str,
|
||||
limiter: &Arc<RateLimiter<String, DashMapStateStore<String>, QuantaClock>>,
|
||||
) -> Body {
|
||||
let body_stream = body.into_data_stream();
|
||||
let limiter = limiter.clone();
|
||||
let key = key.to_string();
|
||||
let throttled = body_stream
|
||||
.map(move |chunk| {
|
||||
let limiter = limiter.clone();
|
||||
let key = key.to_string();
|
||||
// When the rate limit is exceeded, we wait between 25ms and 500ms before retrying.
|
||||
// This is to avoid overwhelming the server with requests when the rate limit is exceeded.
|
||||
// Randomization is used to avoid thundering herd problem.
|
||||
let jitter = Jitter::new(
|
||||
Duration::from_millis(25),
|
||||
Duration::from_millis(500),
|
||||
);
|
||||
async move {
|
||||
let bytes = match chunk {
|
||||
Ok(actual_chunk) => {
|
||||
actual_chunk
|
||||
}
|
||||
Err(e) => return Err(e),
|
||||
};
|
||||
|
||||
// --- Round up to the nearest kilobyte. ---
|
||||
// Important: If the chunk is < 1KB, it will be rounded up to 1 kb.
|
||||
// Many small uploads will be counted as more than they actually are.
|
||||
// I am not too concerned about this though because small random disk writes are stressing
|
||||
// the disk more anyway compared to larger writes.
|
||||
// Why are we doing this? governor::Quota is defined as a u32. u32 can only count up to 4GB.
|
||||
// To support 4GB/s+ limits we need to count in kilobytes.
|
||||
//
|
||||
// --- Chunk Size ---
|
||||
// The chunk size is determined by the client library.
|
||||
// Common chunk sizes: 16KB to 10MB.
|
||||
// HTTP based uploads are usually between 256KB and 1MB.
|
||||
// Asking the limiter for 1KB packets is tradeoff between
|
||||
// - Not calling the limiter too much
|
||||
// - Guaranteeing the call size (1kb) is low enough to not cause race condition issues.
|
||||
let chunk_kilobytes = bytes.len().div_ceil(1024);
|
||||
for _ in 0..chunk_kilobytes {
|
||||
// Check each kilobyte
|
||||
if limiter
|
||||
.until_key_n_ready_with_jitter(
|
||||
&key,
|
||||
NonZero::new(1).expect("1 is always non zero"),
|
||||
jitter,
|
||||
)
|
||||
.await.is_err()
|
||||
{
|
||||
// Requested rate (1kb) is higher then the set limit (x kb/s).
|
||||
// This should never happen.
|
||||
unreachable!("Rate limiting is based on the number of kilobytes, not bytes. So 1 kb should always be allowed.");
|
||||
};
|
||||
}
|
||||
Ok(bytes)
|
||||
}
|
||||
})
|
||||
.buffered(1);
|
||||
|
||||
Body::from_stream(throttled)
|
||||
}
|
||||
|
||||
/// Get the limits that match the request.
|
||||
fn get_limit_matches(&self, req: &Request<Body>) -> Vec<&LimitTuple> {
|
||||
self.limits
|
||||
.iter()
|
||||
.filter(|limit| limit.is_match(req))
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> Service<Request<Body>> for RateLimiterMiddleware<S>
|
||||
where
|
||||
S: Service<Request<Body>, Response = axum::response::Response, Error = Infallible>
|
||||
+ Send
|
||||
+ 'static
|
||||
+ Clone,
|
||||
S::Future: Send + 'static,
|
||||
{
|
||||
type Response = S::Response;
|
||||
type Error = S::Error;
|
||||
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
|
||||
|
||||
fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
self.inner.poll_ready(cx).map_err(|_| unreachable!()) // `Infallible` conversion
|
||||
}
|
||||
|
||||
fn call(&mut self, mut req: Request<Body>) -> Self::Future {
|
||||
let mut inner = self.inner.clone();
|
||||
|
||||
// Match the request path + method to the defined limits.
|
||||
let limits = self.get_limit_matches(&req);
|
||||
if limits.is_empty() {
|
||||
// No limits matched, so we can just call the next layer.
|
||||
return Box::pin(async move { inner.call(req).await.map_err(|_| unreachable!()) });
|
||||
}
|
||||
|
||||
// Go through all the limits and check if we need to throttle or reject the request.
|
||||
for limit in limits.clone() {
|
||||
let key = match limit.extract_key(&req) {
|
||||
Ok(key) => key,
|
||||
Err(e) => {
|
||||
// Failed to extract the key, so we reject the request.
|
||||
// This should only happen if the pubky-host couldn't be extracted.
|
||||
tracing::warn!(
|
||||
"{} {} Failed to extract key for rate limiting: {}",
|
||||
limit.0.path.0,
|
||||
limit.0.method.0,
|
||||
e
|
||||
);
|
||||
return Box::pin(async move {
|
||||
Ok(Error::new(
|
||||
StatusCode::BAD_REQUEST,
|
||||
Some("Failed to extract key for rate limiting".to_string()),
|
||||
)
|
||||
.into_response())
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
match limit.0.quota.rate_unit {
|
||||
RateUnit::SpeedRateUnit(_) => {
|
||||
// Speed limiting is enabled, so we need to throttle the upload.
|
||||
req = Self::throttle_upload(req, &key, &limit.1);
|
||||
}
|
||||
RateUnit::Request => {
|
||||
// Request limiting is enabled, so we need to limit the number of requests.
|
||||
if let Err(e) = limit.1.check_key(&key) {
|
||||
tracing::debug!(
|
||||
"Rate limit of {} exceeded for {key}: {}",
|
||||
limit.0.quota,
|
||||
e
|
||||
);
|
||||
return Box::pin(async move {
|
||||
Ok(Error::new(
|
||||
StatusCode::TOO_MANY_REQUESTS,
|
||||
Some("Rate limit exceeded".to_string()),
|
||||
)
|
||||
.into_response())
|
||||
});
|
||||
};
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
// Create a clone of the request without the body.
|
||||
// This way, we can extract the keys for the response too.
|
||||
let (parts, body) = req.into_parts();
|
||||
let req_clone = Request::from_parts(parts.clone(), Body::empty());
|
||||
let req = Request::from_parts(parts, body);
|
||||
|
||||
let speed_limits = limits
|
||||
.into_iter()
|
||||
.filter(|limit| limit.0.quota.rate_unit.is_speed_rate_unit())
|
||||
.cloned()
|
||||
.collect::<Vec<_>>();
|
||||
Box::pin(async move {
|
||||
// Call the next layer and receive the response.
|
||||
let mut response = match inner.call(req).await.map_err(|_| unreachable!()) {
|
||||
Ok(response) => response,
|
||||
Err(e) => return Err(e),
|
||||
};
|
||||
// Rate limit the download speed.
|
||||
for limit in speed_limits {
|
||||
if let Ok(key) = limit.extract_key(&req_clone) {
|
||||
response = Self::throttle_download(response, &key, &limit.1);
|
||||
};
|
||||
}
|
||||
Ok(response)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||
|
||||
use axum::http::Method;
|
||||
use axum::{
|
||||
routing::{get, post},
|
||||
Router,
|
||||
};
|
||||
use axum_server::Server;
|
||||
use pkarr::{Keypair, PublicKey};
|
||||
use reqwest::{Client, Response};
|
||||
use tokio::{task::JoinHandle, time::Instant};
|
||||
|
||||
use crate::{core::layers::pubky_host::PubkyHostLayer, quota_config::GlobPattern};
|
||||
|
||||
use super::*;
|
||||
|
||||
// Fake upload handler that just consumes the body.
|
||||
pub async fn upload_handler(body: Body) -> Result<impl IntoResponse> {
|
||||
let mut stream = body.into_data_stream();
|
||||
while let Some(chunk) = stream.next().await.transpose()? {
|
||||
// Consume body
|
||||
let _ = chunk;
|
||||
}
|
||||
Ok((StatusCode::CREATED, ()))
|
||||
}
|
||||
|
||||
// Fake upload handler that just consumes the body.
|
||||
pub async fn download_handler() -> Result<impl IntoResponse> {
|
||||
let response_body = vec![0u8; 3 * 1024]; // 3kb
|
||||
Ok((StatusCode::OK, response_body))
|
||||
}
|
||||
|
||||
// Start a server with the given quota config on a random port.
|
||||
async fn start_server(config: Vec<PathLimit>) -> SocketAddr {
|
||||
let app = Router::new()
|
||||
.route("/upload", post(upload_handler))
|
||||
.route("/download", get(download_handler))
|
||||
.layer(RateLimiterLayer::new(config))
|
||||
.layer(PubkyHostLayer);
|
||||
|
||||
// Create a TCP listener to bind to the socket first
|
||||
// Use port 0 to let the OS assign a random available port
|
||||
let listener = tokio::net::TcpListener::bind(SocketAddr::new(
|
||||
IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
|
||||
0,
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
// Get the actual socket address with the OS-assigned port
|
||||
let socket = listener.local_addr().unwrap();
|
||||
|
||||
// Use the listener with axum_server
|
||||
let server = Server::from_tcp(listener.into_std().unwrap());
|
||||
|
||||
tokio::spawn(async move {
|
||||
server
|
||||
.serve(app.into_make_service_with_connect_info::<SocketAddr>())
|
||||
.await
|
||||
.unwrap();
|
||||
});
|
||||
|
||||
socket
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_throttle_upload() {
|
||||
let path_limit = PathLimit::new(
|
||||
GlobPattern::new("/upload"),
|
||||
Method::POST,
|
||||
"1kb/s".parse().unwrap(),
|
||||
LimitKey::Ip,
|
||||
None,
|
||||
);
|
||||
let socket = start_server(vec![path_limit]).await;
|
||||
|
||||
fn upload_data(socket: SocketAddr, kilobytes: usize) -> JoinHandle<()> {
|
||||
tokio::spawn(async move {
|
||||
let client = Client::new();
|
||||
let data = vec![0u8; kilobytes * 1024];
|
||||
let response = client
|
||||
.post(format!("http://{}/upload", socket))
|
||||
.body(data.clone())
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(response.status(), StatusCode::CREATED);
|
||||
})
|
||||
}
|
||||
|
||||
let start = Instant::now();
|
||||
// Spawn in the background to test 2 uploads in parallel
|
||||
// Upload 3kb each
|
||||
let handle1 = upload_data(socket, 3);
|
||||
let handle2 = upload_data(socket, 3);
|
||||
|
||||
// Wait for the uploads to finish
|
||||
let _ = tokio::try_join!(handle1, handle2);
|
||||
|
||||
let time_taken = start.elapsed();
|
||||
assert!(time_taken > Duration::from_secs(6), "Should at least take 6s because uploads are limited to 1kb/s and the sum of the uploads is 6kb");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_throttle_download() {
|
||||
let path_limit = PathLimit::new(
|
||||
GlobPattern::new("/download"),
|
||||
Method::GET,
|
||||
"1kb/s".parse().unwrap(),
|
||||
LimitKey::Ip,
|
||||
None,
|
||||
);
|
||||
let socket = start_server(vec![path_limit]).await;
|
||||
|
||||
fn download_data(socket: SocketAddr) -> JoinHandle<()> {
|
||||
tokio::spawn(async move {
|
||||
let client = Client::new();
|
||||
let response = client
|
||||
.get(format!("http://{}/download", socket))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
response.bytes().await.unwrap(); // Download the body
|
||||
})
|
||||
}
|
||||
|
||||
let start = Instant::now();
|
||||
// Spawn in the background to test 2 downloads in parallel
|
||||
// Download 3kb each
|
||||
let handle1 = download_data(socket);
|
||||
let handle2 = download_data(socket);
|
||||
|
||||
// Wait for the uploads to finish
|
||||
let _ = tokio::try_join!(handle1, handle2);
|
||||
|
||||
let time_taken = start.elapsed();
|
||||
assert!(time_taken > Duration::from_secs(6), "Should at least take 6s because downloads are limited to 1kb/s and the sum of the downloads is 6kb");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_limit_parallel_requests_with_ip_key() {
|
||||
let path_limit = PathLimit::new(
|
||||
GlobPattern::new("/upload"),
|
||||
Method::POST,
|
||||
"1r/m".parse().unwrap(),
|
||||
LimitKey::Ip,
|
||||
None,
|
||||
);
|
||||
let socket = start_server(vec![path_limit]).await;
|
||||
|
||||
fn send_request(socket: SocketAddr) -> JoinHandle<Response> {
|
||||
tokio::spawn(async move {
|
||||
let client = Client::new();
|
||||
let response = client
|
||||
.post(format!("http://{}/upload", socket))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
response
|
||||
})
|
||||
}
|
||||
|
||||
// Spawn in the background to test 2 uploads in parallel
|
||||
let handle1 = send_request(socket);
|
||||
let handle2 = send_request(socket);
|
||||
|
||||
// Wait for the uploads to finish
|
||||
let (res1, res2) = tokio::try_join!(handle1, handle2).unwrap();
|
||||
assert_eq!(res1.status(), StatusCode::CREATED);
|
||||
assert_eq!(res2.status(), StatusCode::TOO_MANY_REQUESTS);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_limit_parallel_requests_with_user_key() {
|
||||
let path_limit = PathLimit::new(
|
||||
GlobPattern::new("/upload"),
|
||||
Method::POST,
|
||||
"1r/m".parse().unwrap(),
|
||||
LimitKey::User,
|
||||
None,
|
||||
);
|
||||
let socket = start_server(vec![path_limit]).await;
|
||||
|
||||
fn send_request(socket: SocketAddr, user_pubkey: PublicKey) -> JoinHandle<Response> {
|
||||
tokio::spawn(async move {
|
||||
let client = Client::new();
|
||||
let response = client
|
||||
.post(format!("http://{}/upload?pubky-host={user_pubkey}", socket))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
response
|
||||
})
|
||||
}
|
||||
|
||||
// Spawn in the background to test 2 uploads in parallel
|
||||
let user1_pubkey = Keypair::random().public_key();
|
||||
let handle1 = send_request(socket, user1_pubkey.clone());
|
||||
let handle2 = send_request(socket, user1_pubkey.clone());
|
||||
let user2_pubkey = Keypair::random().public_key();
|
||||
let handle3 = send_request(socket, user2_pubkey.clone());
|
||||
|
||||
// Wait for the uploads to finish
|
||||
let (res1, res2, res3) = tokio::try_join!(handle1, handle2, handle3).unwrap();
|
||||
assert_eq!(res1.status(), StatusCode::CREATED);
|
||||
assert_eq!(res2.status(), StatusCode::TOO_MANY_REQUESTS);
|
||||
assert_eq!(res3.status(), StatusCode::CREATED);
|
||||
}
|
||||
}
|
||||
3
pubky-homeserver/src/core/layers/rate_limiter/mod.rs
Normal file
3
pubky-homeserver/src/core/layers/rate_limiter/mod.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
mod extract_ip;
|
||||
mod layer;
|
||||
pub use layer::*;
|
||||
@@ -13,9 +13,11 @@ use tower::ServiceBuilder;
|
||||
use tower_cookies::CookieManagerLayer;
|
||||
use tower_http::cors::CorsLayer;
|
||||
|
||||
use crate::core::AppState;
|
||||
use crate::{core::AppState, AppContext};
|
||||
|
||||
use super::layers::{pubky_host::PubkyHostLayer, trace::with_trace_layer};
|
||||
use super::layers::{
|
||||
pubky_host::PubkyHostLayer, rate_limiter::RateLimiterLayer, trace::with_trace_layer,
|
||||
};
|
||||
|
||||
mod auth;
|
||||
mod feed;
|
||||
@@ -37,16 +39,20 @@ fn base() -> Router<AppState> {
|
||||
// TODO: maybe add to a separate router (drive router?).
|
||||
}
|
||||
|
||||
pub fn create_app(state: AppState) -> Router {
|
||||
pub fn create_app(state: AppState, context: &AppContext) -> Router {
|
||||
let app = base()
|
||||
.merge(tenants::router(state.clone()))
|
||||
.layer(CookieManagerLayer::new())
|
||||
.layer(CorsLayer::very_permissive())
|
||||
.layer(ServiceBuilder::new().layer(middleware::from_fn(add_server_header)))
|
||||
.layer(PubkyHostLayer)
|
||||
.layer(RateLimiterLayer::new(
|
||||
context.config_toml.drive.rate_limits.clone(),
|
||||
))
|
||||
.with_state(state);
|
||||
|
||||
// Apply trace and pubky host layers to the complete router.
|
||||
with_trace_layer(app, &TRACING_EXCLUDED_PATHS).layer(PubkyHostLayer)
|
||||
with_trace_layer(app, &TRACING_EXCLUDED_PATHS)
|
||||
}
|
||||
|
||||
// Middleware to add a `Server` header to all responses
|
||||
|
||||
@@ -3,9 +3,11 @@ signup_mode = "token_required"
|
||||
lmdb_backup_interval_s = 0
|
||||
user_storage_quota_mb = 0
|
||||
|
||||
|
||||
[drive]
|
||||
pubky_listen_socket = "127.0.0.1:6287"
|
||||
icann_listen_socket = "127.0.0.1:6286"
|
||||
rate_limits = []
|
||||
|
||||
[admin]
|
||||
listen_socket = "127.0.0.1:6288"
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
//! This module embeds that file at compile-time, parses it once,
|
||||
//! and lets callers optionally layer their own TOML on top.
|
||||
|
||||
use super::{domain_port::DomainPort, Domain, SignupMode};
|
||||
use super::{domain_port::DomainPort, quota_config::PathLimit, Domain, SignupMode};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{
|
||||
fmt::Debug,
|
||||
@@ -54,6 +54,7 @@ pub struct PkdnsToml {
|
||||
pub struct DriveToml {
|
||||
pub pubky_listen_socket: SocketAddr,
|
||||
pub icann_listen_socket: SocketAddr,
|
||||
pub rate_limits: Vec<PathLimit>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
|
||||
@@ -144,9 +145,8 @@ impl ConfigToml {
|
||||
.lines()
|
||||
.map(|line| {
|
||||
let trimmed = line.trim_start();
|
||||
let is_title = trimmed.starts_with('[');
|
||||
let is_comment = trimmed.starts_with('#');
|
||||
if !is_title && !is_comment && !trimmed.is_empty() {
|
||||
if !is_comment && !trimmed.is_empty() {
|
||||
format!("# {}", line)
|
||||
} else {
|
||||
line.to_string()
|
||||
@@ -216,12 +216,13 @@ mod tests {
|
||||
assert_eq!(c.pkdns.dht_bootstrap_nodes, None);
|
||||
assert_eq!(c.pkdns.dht_relay_nodes, None);
|
||||
assert_eq!(c.pkdns.dht_request_timeout_ms, None);
|
||||
assert_eq!(c.drive.rate_limits, vec![]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sample_config() {
|
||||
// Validate that the sample config can be parsed
|
||||
ConfigToml::from_str(SAMPLE_CONFIG).expect("Embedded config.default.toml must be valid");
|
||||
ConfigToml::from_str(SAMPLE_CONFIG).expect("Embedded config.sample.toml must be valid");
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -4,6 +4,8 @@ mod domain;
|
||||
mod domain_port;
|
||||
mod mock_data_dir;
|
||||
mod persistent_data_dir;
|
||||
/// Quota configuration for the TomlConfig.
|
||||
pub mod quota_config;
|
||||
mod signup_mode;
|
||||
|
||||
pub use config_toml::{ConfigReadError, ConfigToml};
|
||||
|
||||
@@ -0,0 +1,89 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{fmt::Display, str::FromStr};
|
||||
|
||||
/// A wrapper around fast_glob to allow serialize/deserialize.
|
||||
/// Pattern matches glob patterns.
|
||||
///
|
||||
/// Syntax - Meaning
|
||||
/// `?` - Matches any single character.
|
||||
/// `*` - Matches zero or more characters, except for path separators (e.g. /).
|
||||
/// `**` - Matches zero or more characters, including path separators. Must match a complete path segment (i.e. followed by a / or the end of the pattern).
|
||||
/// `[ab]` - Matches one of the characters contained in the brackets. Character ranges, e.g. `[a-z]` are also supported. Use `[!ab]` or `[^ab]` to match any character except those contained in the brackets.
|
||||
/// `{a,b}` - Matches one of the patterns contained in the braces. Any of the wildcard characters can be used in the sub-patterns. Braces may be nested up to 10 levels deep.
|
||||
/// `!` - When at the start of the glob, this negates the result. Multiple `!` characters negate the glob multiple times.
|
||||
/// `\` - A backslash character may be used to escape any of the above special characters.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct GlobPattern(pub String);
|
||||
|
||||
impl GlobPattern {
|
||||
/// Create a new glob pattern.
|
||||
pub fn new(pattern: &str) -> Self {
|
||||
Self(pattern.to_string())
|
||||
}
|
||||
|
||||
/// Check if the path matches the glob pattern.
|
||||
pub fn is_match(&self, path: &str) -> bool {
|
||||
fast_glob::glob_match(&self.0, path)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::hash::Hash for GlobPattern {
|
||||
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
|
||||
self.0.as_str().hash(state);
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for GlobPattern {
|
||||
type Err = String;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
Ok(GlobPattern(s.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for GlobPattern {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.0.as_str())
|
||||
}
|
||||
}
|
||||
|
||||
impl Serialize for GlobPattern {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
serializer.serialize_str(self.0.as_str())
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for GlobPattern {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
let s = String::deserialize(deserializer)?;
|
||||
GlobPattern::from_str(&s).map_err(serde::de::Error::custom)
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialEq for GlobPattern {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.0.as_str() == other.0.as_str()
|
||||
}
|
||||
}
|
||||
|
||||
impl Eq for GlobPattern {}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_glob_pattern() {
|
||||
let glob_pattern = GlobPattern::from_str("/pub/**").unwrap();
|
||||
assert!(glob_pattern.is_match("/pub/test.txt"));
|
||||
assert!(glob_pattern.is_match("/pub/test/test.txt"));
|
||||
assert!(!glob_pattern.is_match("/priv/test.pdf"));
|
||||
assert!(!glob_pattern.is_match("/session/test.txt"));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,55 @@
|
||||
use std::{fmt::Display, str::FromStr};
|
||||
|
||||
use axum::http::Method;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// A wrapper around http::Method to implement serde traits
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
pub struct HttpMethod(pub Method);
|
||||
|
||||
impl From<Method> for HttpMethod {
|
||||
fn from(method: Method) -> Self {
|
||||
HttpMethod(method)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<HttpMethod> for Method {
|
||||
fn from(method: HttpMethod) -> Self {
|
||||
method.0
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for HttpMethod {
|
||||
type Err = String;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
Method::from_str(s.to_uppercase().as_str())
|
||||
.map(HttpMethod)
|
||||
.map_err(|_| format!("Invalid method: {}", s))
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for HttpMethod {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl Serialize for HttpMethod {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
serializer.serialize_str(&self.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for HttpMethod {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
let s = String::deserialize(deserializer)?;
|
||||
HttpMethod::from_str(&s).map_err(serde::de::Error::custom)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,55 @@
|
||||
use std::fmt;
|
||||
use std::str::FromStr;
|
||||
|
||||
/// The key to limit the quota on.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
pub enum LimitKey {
|
||||
/// Limit on the user id
|
||||
User,
|
||||
/// Limit on the ip address
|
||||
Ip,
|
||||
}
|
||||
|
||||
impl fmt::Display for LimitKey {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"{}",
|
||||
match self {
|
||||
LimitKey::User => "user",
|
||||
LimitKey::Ip => "ip",
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for LimitKey {
|
||||
type Err = String;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s {
|
||||
"user" => Ok(LimitKey::User),
|
||||
"ip" => Ok(LimitKey::Ip),
|
||||
_ => Err(format!("Invalid limit key: {}", s)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl serde::Serialize for LimitKey {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
serializer.serialize_str(self.to_string().as_str())
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> serde::Deserialize<'de> for LimitKey {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
let s = String::deserialize(deserializer)?;
|
||||
LimitKey::from_str(&s).map_err(serde::de::Error::custom)
|
||||
}
|
||||
}
|
||||
15
pubky-homeserver/src/data_directory/quota_config/mod.rs
Normal file
15
pubky-homeserver/src/data_directory/quota_config/mod.rs
Normal file
@@ -0,0 +1,15 @@
|
||||
mod glob_pattern;
|
||||
mod http_method;
|
||||
mod limit_key;
|
||||
mod path_limit;
|
||||
mod quota_value;
|
||||
mod rate_unit;
|
||||
mod time_unit;
|
||||
|
||||
pub use glob_pattern::GlobPattern;
|
||||
pub use http_method::HttpMethod;
|
||||
pub use limit_key::LimitKey;
|
||||
pub use path_limit::*;
|
||||
pub use quota_value::QuotaValue;
|
||||
pub use rate_unit::RateUnit;
|
||||
pub use time_unit::TimeUnit;
|
||||
@@ -0,0 +1,77 @@
|
||||
use super::{GlobPattern, HttpMethod, LimitKey, QuotaValue};
|
||||
use axum::http::Method;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::num::NonZeroU32;
|
||||
|
||||
/// A limit on a path for a specific method.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
|
||||
pub struct PathLimit {
|
||||
/// The path glob pattern to match against.
|
||||
pub path: GlobPattern,
|
||||
/// The method to limit.
|
||||
pub method: HttpMethod,
|
||||
/// The limit to apply.
|
||||
pub quota: QuotaValue,
|
||||
/// The key to limit.
|
||||
pub key: LimitKey,
|
||||
/// The burst to apply.
|
||||
pub burst: Option<NonZeroU32>,
|
||||
}
|
||||
|
||||
impl PathLimit {
|
||||
/// Create a new path limit.
|
||||
pub fn new(
|
||||
path: GlobPattern,
|
||||
method: Method,
|
||||
quota: QuotaValue,
|
||||
key: LimitKey,
|
||||
burst: Option<NonZeroU32>,
|
||||
) -> Self {
|
||||
Self {
|
||||
path,
|
||||
method: HttpMethod(method),
|
||||
quota,
|
||||
key,
|
||||
burst,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for PathLimit {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"{} {}: {}-{} by {}",
|
||||
self.method,
|
||||
self.path,
|
||||
self.quota,
|
||||
self.burst.map_or_else(String::new, |b| b.to_string()),
|
||||
self.key
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<PathLimit> for governor::Quota {
|
||||
fn from(value: PathLimit) -> Self {
|
||||
let quota: governor::Quota = value.quota.into();
|
||||
if let Some(burst) = value.burst {
|
||||
quota.allow_burst(burst);
|
||||
}
|
||||
quota
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_http_method_serde() {
|
||||
let method = Method::GET;
|
||||
let http_method = HttpMethod(method);
|
||||
assert_eq!(http_method.to_string(), "GET");
|
||||
|
||||
let deserialized: HttpMethod = "GET".parse().unwrap();
|
||||
assert_eq!(deserialized, http_method);
|
||||
}
|
||||
}
|
||||
201
pubky-homeserver/src/data_directory/quota_config/quota_value.rs
Normal file
201
pubky-homeserver/src/data_directory/quota_config/quota_value.rs
Normal file
@@ -0,0 +1,201 @@
|
||||
use std::fmt;
|
||||
use std::str::FromStr;
|
||||
use std::{num::NonZeroU32, time::Duration};
|
||||
|
||||
use super::{RateUnit, TimeUnit};
|
||||
|
||||
/// Quota value
|
||||
///
|
||||
/// Examples:
|
||||
/// - 5r/m
|
||||
/// - 5r/s
|
||||
/// - 5kb/m
|
||||
/// - 5mb/m
|
||||
/// - 5gb/s
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
pub struct QuotaValue {
|
||||
/// The rate.
|
||||
pub rate: NonZeroU32,
|
||||
/// The unit of the rate.
|
||||
pub rate_unit: RateUnit,
|
||||
/// The unit of the time.
|
||||
pub time_unit: TimeUnit,
|
||||
}
|
||||
|
||||
impl From<QuotaValue> for governor::Quota {
|
||||
/// Get the quota to do the actual rate limiting.
|
||||
///
|
||||
/// Important: The speed quotas are always in kilobytes, not bytes.
|
||||
/// Counting bytes is not practical.
|
||||
///
|
||||
fn from(value: QuotaValue) -> Self {
|
||||
let rate_count = value.rate.get();
|
||||
let rate_unit = value.rate_unit.multiplier().get();
|
||||
let rate = NonZeroU32::new(rate_count * rate_unit)
|
||||
.expect("Is always non-zero because rate count and rate unit multiplier are non-zero");
|
||||
let time_unit = Duration::from_secs(value.time_unit.multiplier_in_seconds().get() as u64);
|
||||
let replenish_1_per = time_unit / rate.get();
|
||||
|
||||
let base_quota = governor::Quota::with_period(replenish_1_per)
|
||||
.expect("Is always non-zero because replenish_1_per is non-zero");
|
||||
base_quota.allow_burst(rate)
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for QuotaValue {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "{}{}/{}", self.rate, self.rate_unit, self.time_unit)
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for QuotaValue {
|
||||
type Err = String;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
// Split rate part by '/' to get rate+unit and time unit
|
||||
let rate_parts: Vec<&str> = s.split('/').collect();
|
||||
if rate_parts.len() != 2 {
|
||||
return Err(format!(
|
||||
"Invalid rate format: '{}', expected {{rate}}{{unit}}/{{time}}",
|
||||
s
|
||||
));
|
||||
}
|
||||
|
||||
let rate_with_unit = rate_parts[0];
|
||||
let time_unit = TimeUnit::from_str(rate_parts[1])?;
|
||||
|
||||
// Find the boundary between rate digits and unit
|
||||
let rate_digit_end = rate_with_unit
|
||||
.chars()
|
||||
.position(|c| !c.is_ascii_digit())
|
||||
.unwrap_or(rate_with_unit.len());
|
||||
|
||||
if rate_digit_end == 0 {
|
||||
return Err(format!("Missing rate value in '{}'", rate_with_unit));
|
||||
}
|
||||
|
||||
let rate_str = &rate_with_unit[..rate_digit_end];
|
||||
let rate_unit_str = &rate_with_unit[rate_digit_end..];
|
||||
|
||||
let rate = rate_str
|
||||
.parse::<NonZeroU32>()
|
||||
.map_err(|_| format!("Failed to parse rate from '{}'", rate_str))?;
|
||||
let rate_unit = RateUnit::from_str(rate_unit_str)?;
|
||||
|
||||
Ok(QuotaValue {
|
||||
rate,
|
||||
rate_unit,
|
||||
time_unit,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl serde::Serialize for QuotaValue {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
serializer.serialize_str(&self.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> serde::Deserialize<'de> for QuotaValue {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
let s = String::deserialize(deserializer)?;
|
||||
|
||||
// Parse the quota string
|
||||
QuotaValue::from_str(&s).map_err(serde::de::Error::custom)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::quota_config::rate_unit::SpeedRateUnit;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_get_quota() {
|
||||
let quota = QuotaValue::from_str("5r/s").unwrap();
|
||||
assert_eq!(
|
||||
governor::Quota::from(quota),
|
||||
governor::Quota::per_second(NonZeroU32::new(5).unwrap())
|
||||
);
|
||||
|
||||
let quota = QuotaValue::from_str("5r/m").unwrap();
|
||||
assert_eq!(
|
||||
governor::Quota::from(quota),
|
||||
governor::Quota::per_minute(NonZeroU32::new(5).unwrap())
|
||||
);
|
||||
|
||||
let quota = QuotaValue::from_str("5kb/s").unwrap();
|
||||
assert_eq!(
|
||||
governor::Quota::from(quota),
|
||||
governor::Quota::per_second(NonZeroU32::new(5).unwrap())
|
||||
);
|
||||
|
||||
let quota = QuotaValue::from_str("5mb/m").unwrap();
|
||||
assert_eq!(
|
||||
governor::Quota::from(quota),
|
||||
governor::Quota::per_minute(NonZeroU32::new(5 * 1024).unwrap())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_quota_value_from_str() {
|
||||
// Test without burst
|
||||
let quota = QuotaValue::from_str("5r/s").unwrap();
|
||||
assert_eq!(quota.rate, NonZeroU32::new(5).unwrap());
|
||||
assert_eq!(quota.rate_unit, RateUnit::Request);
|
||||
assert_eq!(quota.time_unit, TimeUnit::Second);
|
||||
|
||||
// Test with burst (should fail or be handled differently)
|
||||
let quota = QuotaValue::from_str("10mb/m").unwrap();
|
||||
assert_eq!(quota.rate, NonZeroU32::new(10).unwrap());
|
||||
assert_eq!(
|
||||
quota.rate_unit,
|
||||
RateUnit::SpeedRateUnit(SpeedRateUnit::Megabyte)
|
||||
);
|
||||
assert_eq!(quota.time_unit, TimeUnit::Minute);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_quota_value_display() {
|
||||
// Test without burst
|
||||
let quota = QuotaValue {
|
||||
rate: NonZeroU32::new(5).unwrap(),
|
||||
rate_unit: RateUnit::Request,
|
||||
time_unit: TimeUnit::Second,
|
||||
};
|
||||
assert_eq!(quota.to_string(), "5r/s");
|
||||
|
||||
// Test with burst (should be displayed without burst)
|
||||
let quota = QuotaValue {
|
||||
rate: NonZeroU32::new(10).unwrap(),
|
||||
rate_unit: RateUnit::SpeedRateUnit(SpeedRateUnit::Megabyte),
|
||||
time_unit: TimeUnit::Minute,
|
||||
};
|
||||
assert_eq!(quota.to_string(), "10mb/m");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_quota_value_invalid_formats() {
|
||||
// Invalid format: missing /
|
||||
assert!(QuotaValue::from_str("5rs").is_err());
|
||||
|
||||
// Invalid format: missing rate
|
||||
assert!(QuotaValue::from_str("r/s").is_err());
|
||||
|
||||
// Invalid format: invalid rate unit
|
||||
assert!(QuotaValue::from_str("5x/s").is_err());
|
||||
|
||||
// Invalid format: invalid time unit
|
||||
assert!(QuotaValue::from_str("5r/x").is_err());
|
||||
|
||||
// Invalid format: invalid burst (this test case might need to be removed or updated)
|
||||
assert!(QuotaValue::from_str("5r/s-2burst").is_err());
|
||||
}
|
||||
}
|
||||
112
pubky-homeserver/src/data_directory/quota_config/rate_unit.rs
Normal file
112
pubky-homeserver/src/data_directory/quota_config/rate_unit.rs
Normal file
@@ -0,0 +1,112 @@
|
||||
use std::{num::NonZeroU32, str::FromStr};
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
pub enum SpeedRateUnit {
|
||||
/// Kilobyte
|
||||
Kilobyte,
|
||||
/// Megabyte
|
||||
Megabyte,
|
||||
/// Gigabyte
|
||||
Gigabyte,
|
||||
}
|
||||
|
||||
impl SpeedRateUnit {
|
||||
/// Returns the number of bytes for this unit
|
||||
pub const fn multiplier(&self) -> NonZeroU32 {
|
||||
match self {
|
||||
// Speed quotas are always in kilobytes.
|
||||
// Because counting bytes is not practical and we are limited to u32 = 4GB max.
|
||||
// Counting in kb as more practical and we can count up to 4GB*1024 = 4TB.
|
||||
SpeedRateUnit::Kilobyte => NonZeroU32::new(1).expect("Is always non-zero"),
|
||||
SpeedRateUnit::Megabyte => NonZeroU32::new(1024).expect("Is always non-zero"),
|
||||
SpeedRateUnit::Gigabyte => NonZeroU32::new(1024 * 1024).expect("Is always non-zero"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for SpeedRateUnit {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"{}",
|
||||
match self {
|
||||
SpeedRateUnit::Kilobyte => "kb",
|
||||
SpeedRateUnit::Megabyte => "mb",
|
||||
SpeedRateUnit::Gigabyte => "gb",
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for SpeedRateUnit {
|
||||
type Err = String;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s.to_lowercase().as_str() {
|
||||
"kb" => Ok(SpeedRateUnit::Kilobyte),
|
||||
"mb" => Ok(SpeedRateUnit::Megabyte),
|
||||
"gb" => Ok(SpeedRateUnit::Gigabyte),
|
||||
_ => Err(format!("Invalid speedrate unit: {}", s)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The unit of the rate.
|
||||
///
|
||||
/// Examples:
|
||||
/// - "r" -> request
|
||||
/// - "kb" -> kilobyte
|
||||
/// - "mb" -> megabyte
|
||||
/// - "gb" -> gigabyte
|
||||
/// - "tb" -> terabyte
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
pub enum RateUnit {
|
||||
/// Request
|
||||
Request,
|
||||
/// Speed rate unit
|
||||
SpeedRateUnit(SpeedRateUnit),
|
||||
}
|
||||
|
||||
impl RateUnit {
|
||||
/// Returns true if the rate unit is a speed rate unit.
|
||||
pub fn is_speed_rate_unit(&self) -> bool {
|
||||
matches!(self, RateUnit::SpeedRateUnit(_))
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for RateUnit {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"{}",
|
||||
match self {
|
||||
RateUnit::Request => "r".to_string(),
|
||||
RateUnit::SpeedRateUnit(unit) => unit.to_string(),
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for RateUnit {
|
||||
type Err = String;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s.to_lowercase().as_str() {
|
||||
"r" => Ok(RateUnit::Request),
|
||||
other => match SpeedRateUnit::from_str(other) {
|
||||
Ok(unit) => Ok(RateUnit::SpeedRateUnit(unit)),
|
||||
Err(_) => Err(format!("Invalid rate unit: {}", s)),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl RateUnit {
|
||||
/// Returns the number of bytes for this unit
|
||||
pub const fn multiplier(&self) -> NonZeroU32 {
|
||||
match self {
|
||||
RateUnit::Request => NonZeroU32::new(1).expect("Is always non-zero"),
|
||||
RateUnit::SpeedRateUnit(unit) => unit.multiplier(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,49 @@
|
||||
use std::num::NonZeroU32;
|
||||
|
||||
/// The time unit of the quota.
|
||||
///
|
||||
/// Examples:
|
||||
/// - "s" -> second
|
||||
/// - "m" -> minute
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
pub enum TimeUnit {
|
||||
/// Second
|
||||
Second,
|
||||
/// Minute
|
||||
Minute,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for TimeUnit {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"{}",
|
||||
match self {
|
||||
TimeUnit::Second => "s",
|
||||
TimeUnit::Minute => "m",
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::str::FromStr for TimeUnit {
|
||||
type Err = String;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s {
|
||||
"s" => Ok(TimeUnit::Second),
|
||||
"m" => Ok(TimeUnit::Minute),
|
||||
_ => Err(format!("Invalid time unit: {}", s)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TimeUnit {
|
||||
/// Returns the number of seconds for each unit
|
||||
pub const fn multiplier_in_seconds(&self) -> NonZeroU32 {
|
||||
match self {
|
||||
TimeUnit::Second => NonZeroU32::new(1).expect("Is always non-zero"),
|
||||
TimeUnit::Minute => NonZeroU32::new(60).expect("Is always non-zero"),
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user