diff --git a/Cargo.lock b/Cargo.lock index 40e6165..2bd83b9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -36,6 +36,12 @@ dependencies = [ "memchr", ] +[[package]] +name = "allocator-api2" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" + [[package]] name = "anstream" version = "0.6.18" @@ -923,6 +929,15 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "fast-glob" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39ea3f6bbcf4dbe2076b372186fc7aeecd5f6f84754582e56ee7db262b15a6f0" +dependencies = [ + "arrayvec", +] + [[package]] name = "fastrand" version = "2.3.0" @@ -953,6 +968,12 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "foldhash" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" + [[package]] name = "foreign-types" version = "0.3.2" @@ -1185,6 +1206,29 @@ dependencies = [ "spinning_top", ] +[[package]] +name = "governor" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3cbe789d04bf14543f03c4b60cd494148aa79438c8440ae7d81a7778147745c3" +dependencies = [ + "cfg-if", + "dashmap", + "futures-sink", + "futures-timer", + "futures-util", + "getrandom 0.3.1", + "hashbrown 0.15.2", + "nonzero_ext", + "parking_lot", + "portable-atomic", + "quanta", + "rand 0.9.0", + "smallvec", + "spinning_top", + "web-time", +] + [[package]] name = "h2" version = "0.4.8" @@ -1224,6 +1268,11 @@ name = "hashbrown" version = "0.15.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" +dependencies = [ + "allocator-api2", + "equivalent", + "foldhash", +] [[package]] name = "headers" @@ -1321,9 +1370,9 @@ checksum = "f558a64ac9af88b5ba400d99b579451af0d39c6d360980045b91aac966d705e2" [[package]] name = "http" -version = "1.2.0" +version = "1.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f16ca2af56261c99fba8bac40a10251ce8188205a4c448fbb745a2e4daa76fea" +checksum = "f4a85d31aea989eead29a3aaf9e1115a180df8282431156e533de47660892565" dependencies = [ "bytes", "fnv", @@ -2157,7 +2206,7 @@ dependencies = [ "bytes", "clap", "dirs-next", - "governor", + "governor 0.8.0", "http", "httpdate", "pkarr", @@ -2343,9 +2392,11 @@ dependencies = [ "clap", "dirs", "dyn-clone", + "fast-glob", "flume", "futures-lite", "futures-util", + "governor 0.10.0", "heed", "hex", "hostname-validator", @@ -3500,7 +3551,7 @@ checksum = "57a2ccff6830fa835371af7541e561a90e4c07b84f72991ebac4b3cb6790dc0d" dependencies = [ "axum", "forwarded-header-value", - "governor", + "governor 0.8.0", "http", "pin-project", "thiserror 2.0.12", diff --git a/pubky-client/src/wasm/mod.rs b/pubky-client/src/wasm/mod.rs index 87d2629..8f6a089 100644 --- a/pubky-client/src/wasm/mod.rs +++ b/pubky-client/src/wasm/mod.rs @@ -49,7 +49,7 @@ impl Client { .expect("testnet relays are valid urls") }); - let mut client = builder.build().expect("testnet build should be infallibl"); + let mut client = builder.build().expect("testnet build should be infallible"); client.testnet = true; diff --git a/pubky-homeserver/Cargo.toml b/pubky-homeserver/Cargo.toml index ed93c7d..514e7e4 100644 --- a/pubky-homeserver/Cargo.toml +++ b/pubky-homeserver/Cargo.toml @@ -53,7 +53,10 @@ axum-test = "17.2.0" tempfile = { version = "3.10.1" } dyn-clone = "1.0.19" reqwest = "0.12.15" +governor = "0.10.0" +fast-glob = "0.4.5" [dev-dependencies] futures-lite = "2.6.0" + diff --git a/pubky-homeserver/config.sample.toml b/pubky-homeserver/config.sample.toml index 2ae4fe2..9e4d4a5 100644 --- a/pubky-homeserver/config.sample.toml +++ b/pubky-homeserver/config.sample.toml @@ -24,6 +24,35 @@ pubky_listen_socket = "127.0.0.1:6287" # May be put behind a reverse proxy with TLS enabled. icann_listen_socket = "127.0.0.1:6286" +# Rate limit endpoints dynamically. +# `path` is a glob pattern of the path. See syntax in https://crates.io/crates/fast-glob +# `method` is the HTTP method. Examples: GET, POST, PUT, HEAD, DELETE +# `quota` defines the limit itself in the format $rate$rate_unit/$time_unit. +# - $rate is a positive integer, max 4'294'967'296. +# - $rate_unit is either "r" for requests, "kb" for kilobytes, +# "mb" for megabytes, "gb" for gigabytes. +# Speed limits limit download and upload. +# - $time_unit is either "s" for second, "m" for minute. +# `key` defines what who is rate limited. +# - "ip" limits based on the IP address. +# - "user" limits based on the user pubkey. Requires the endpoint to have authentication. +# `burst` is a temporary allowance of quota that is added to the limit. +# By default, burst is equal the quota rate. +# +# Limit login attempts to 20 requests per minute per IP. +[[drive.rate_limits]] +path = "/session" +method = "POST" +quota = "20r/m" +key = "ip" +# +# Limit file uploads to 1 megabyte per second per user with a temporary burst of 10 megabytes. +[[drive.rate_limits]] +path = "/pub/**" +method = "PUT" +quota = "1mb/s" +key = "user" +burst = 10 [admin] # The port number to run the admin HTTP (clear text) server on. diff --git a/pubky-homeserver/src/core/homeserver_core.rs b/pubky-homeserver/src/core/homeserver_core.rs index 80bb642..02f8b3c 100644 --- a/pubky-homeserver/src/core/homeserver_core.rs +++ b/pubky-homeserver/src/core/homeserver_core.rs @@ -151,7 +151,7 @@ impl HomeserverCore { signup_mode: context.config_toml.general.signup_mode.clone(), user_quota_bytes: quota_bytes, }; - super::routes::create_app(state.clone()) + super::routes::create_app(state.clone(), context) } /// Start the ICANN HTTP server diff --git a/pubky-homeserver/src/core/layers/mod.rs b/pubky-homeserver/src/core/layers/mod.rs index cceacd2..80b820a 100644 --- a/pubky-homeserver/src/core/layers/mod.rs +++ b/pubky-homeserver/src/core/layers/mod.rs @@ -1,3 +1,4 @@ pub mod authz; pub mod pubky_host; +pub mod rate_limiter; pub mod trace; diff --git a/pubky-homeserver/src/core/layers/rate_limiter/extract_ip.rs b/pubky-homeserver/src/core/layers/rate_limiter/extract_ip.rs new file mode 100644 index 0000000..1fa6844 --- /dev/null +++ b/pubky-homeserver/src/core/layers/rate_limiter/extract_ip.rs @@ -0,0 +1,37 @@ +use axum::extract::Request; +use axum::http::HeaderMap; +use std::net::{IpAddr, SocketAddr}; + +// From https://github.com/benwis/tower-governor/blob/main/src/key_extractor.rs#L121 +const X_REAL_IP: &str = "x-real-ip"; +const X_FORWARDED_FOR: &str = "x-forwarded-for"; + +/// Tries to parse the `x-forwarded-for` header +fn maybe_x_forwarded_for(headers: &HeaderMap) -> Option { + headers + .get(X_FORWARDED_FOR) + .and_then(|hv| hv.to_str().ok()) + .and_then(|s| s.split(',').find_map(|s| s.trim().parse::().ok())) +} + +/// Tries to parse the `x-real-ip` header +fn maybe_x_real_ip(headers: &HeaderMap) -> Option { + headers + .get(X_REAL_IP) + .and_then(|hv| hv.to_str().ok()) + .and_then(|s| s.parse::().ok()) +} + +fn maybe_connect_info(req: &Request) -> Option { + req.extensions() + .get::>() + .map(|addr| addr.ip()) +} + +pub fn extract_ip(req: &Request) -> anyhow::Result { + let headers = req.headers(); + maybe_x_forwarded_for(headers) + .or_else(|| maybe_x_real_ip(headers)) + .or_else(|| maybe_connect_info(req)) + .ok_or(anyhow::anyhow!("Failed to extract ip.")) +} diff --git a/pubky-homeserver/src/core/layers/rate_limiter/layer.rs b/pubky-homeserver/src/core/layers/rate_limiter/layer.rs new file mode 100644 index 0000000..bcbfe71 --- /dev/null +++ b/pubky-homeserver/src/core/layers/rate_limiter/layer.rs @@ -0,0 +1,539 @@ +//! +//! Implements rate limiting with governor. +//! +//! Would love to use tower_governor but I can't type it properly due to +//! https://github.com/benwis/tower-governor/issues/49. +//! +//! So we implement our own rate limiter here. +//! +use axum::response::{IntoResponse, Response}; +use axum::{ + body::Body, + http::{Request, StatusCode}, +}; +use futures_util::future::BoxFuture; +use governor::clock::QuantaClock; +use governor::state::keyed::DashMapStateStore; +use std::num::NonZero; +use std::sync::Arc; +use std::time::Duration; +use std::{convert::Infallible, task::Poll}; +use tower::{Layer, Service}; + +use crate::core::error::Result; +use crate::core::extractors::PubkyHost; +use crate::core::Error; +use crate::quota_config::{LimitKey, PathLimit, RateUnit}; +use futures_util::StreamExt; +use governor::{Jitter, Quota, RateLimiter}; + +use super::extract_ip::extract_ip; + +/// A Tower Layer to handle general rate limiting. +/// +/// Supports rate limiting by request count and by upload/download speed. +/// +/// Requires a `PubkyHostLayer` to be applied first. +/// Used to extract the user pubkey as the key for the rate limiter. +/// +/// Returns 400 BAD REQUEST if the user pubkey aka pubky-host cannot be extracted. +/// +#[derive(Debug, Clone)] +pub struct RateLimiterLayer { + limits: Vec, +} + +impl RateLimiterLayer { + /// Create a new rate limiter layer with the given quota. + /// + /// If quota is None, rate limiting is disabled. + pub fn new(limits: Vec) -> Self { + if limits.is_empty() { + tracing::info!("Rate limiting is disabled."); + } else { + let limits_str = limits + .iter() + .map(|limit| format!("\"{limit}\"")) + .collect::>(); + tracing::info!("Rate limits configured: {}", limits_str.join(", ")); + } + Self { limits } + } +} + +impl Layer for RateLimiterLayer { + type Service = RateLimiterMiddleware; + + fn layer(&self, inner: S) -> Self::Service { + let tuples = self + .limits + .iter() + .map(|path| LimitTuple::new(path.clone())) + .collect(); + + RateLimiterMiddleware { + inner, + limits: tuples, + } + } +} + +/// A tuple of a path limit and the actual governor rate limiter. +#[derive(Debug, Clone)] +struct LimitTuple( + pub PathLimit, + pub Arc, QuantaClock>>, +); + +impl LimitTuple { + pub fn new(path_limit: PathLimit) -> Self { + let quota: Quota = path_limit.clone().into(); + let limiter = Arc::new(RateLimiter::keyed(quota)); + Self(path_limit, limiter) + } + + /// Extract the key from the request. + /// + /// The key is either the ip address of the client + /// or the user pubkey. + fn extract_key(&self, req: &Request) -> anyhow::Result { + match self.0.key { + LimitKey::Ip => extract_ip(req).map(|ip| ip.to_string()), + LimitKey::User => { + // Extract the user pubkey from the request. + req.extensions() + .get::() + .map(|pk| pk.public_key().to_string()) + .ok_or(anyhow::anyhow!("Failed to extract user pubkey.")) + } + } + } + + /// Check if the request matches the limit. + pub fn is_match(&self, req: &Request) -> bool { + self.0.path.is_match(req.uri().path()) && self.0.method.0 == req.method() + } +} + +#[derive(Debug, Clone)] +pub struct RateLimiterMiddleware { + inner: S, + limits: Vec, +} + +impl RateLimiterMiddleware { + /// Throttle the upload body. + fn throttle_upload( + req: Request, + key: &str, + limiter: &Arc, QuantaClock>>, + ) -> Request { + let (parts, body) = req.into_parts(); + let new_body = Self::throttle_body(body, key, limiter); + Request::from_parts(parts, new_body) + } + + /// Throttle the download body. + fn throttle_download( + res: Response, + key: &str, + limiter: &Arc, QuantaClock>>, + ) -> Response { + let (parts, body) = res.into_parts(); + let new_body = Self::throttle_body(body, key, limiter); + Response::from_parts(parts, new_body) + } + + /// Throttle the up or download body. + /// + /// Important: The speed quotas are always in kilobytes, not bytes. + /// Counting bytes is not practical. + /// + fn throttle_body( + body: Body, + key: &str, + limiter: &Arc, QuantaClock>>, + ) -> Body { + let body_stream = body.into_data_stream(); + let limiter = limiter.clone(); + let key = key.to_string(); + let throttled = body_stream + .map(move |chunk| { + let limiter = limiter.clone(); + let key = key.to_string(); + // When the rate limit is exceeded, we wait between 25ms and 500ms before retrying. + // This is to avoid overwhelming the server with requests when the rate limit is exceeded. + // Randomization is used to avoid thundering herd problem. + let jitter = Jitter::new( + Duration::from_millis(25), + Duration::from_millis(500), + ); + async move { + let bytes = match chunk { + Ok(actual_chunk) => { + actual_chunk + } + Err(e) => return Err(e), + }; + + // --- Round up to the nearest kilobyte. --- + // Important: If the chunk is < 1KB, it will be rounded up to 1 kb. + // Many small uploads will be counted as more than they actually are. + // I am not too concerned about this though because small random disk writes are stressing + // the disk more anyway compared to larger writes. + // Why are we doing this? governor::Quota is defined as a u32. u32 can only count up to 4GB. + // To support 4GB/s+ limits we need to count in kilobytes. + // + // --- Chunk Size --- + // The chunk size is determined by the client library. + // Common chunk sizes: 16KB to 10MB. + // HTTP based uploads are usually between 256KB and 1MB. + // Asking the limiter for 1KB packets is tradeoff between + // - Not calling the limiter too much + // - Guaranteeing the call size (1kb) is low enough to not cause race condition issues. + let chunk_kilobytes = bytes.len().div_ceil(1024); + for _ in 0..chunk_kilobytes { + // Check each kilobyte + if limiter + .until_key_n_ready_with_jitter( + &key, + NonZero::new(1).expect("1 is always non zero"), + jitter, + ) + .await.is_err() + { + // Requested rate (1kb) is higher then the set limit (x kb/s). + // This should never happen. + unreachable!("Rate limiting is based on the number of kilobytes, not bytes. So 1 kb should always be allowed."); + }; + } + Ok(bytes) + } + }) + .buffered(1); + + Body::from_stream(throttled) + } + + /// Get the limits that match the request. + fn get_limit_matches(&self, req: &Request) -> Vec<&LimitTuple> { + self.limits + .iter() + .filter(|limit| limit.is_match(req)) + .collect() + } +} + +impl Service> for RateLimiterMiddleware +where + S: Service, Response = axum::response::Response, Error = Infallible> + + Send + + 'static + + Clone, + S::Future: Send + 'static, +{ + type Response = S::Response; + type Error = S::Error; + type Future = BoxFuture<'static, Result>; + + fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll> { + self.inner.poll_ready(cx).map_err(|_| unreachable!()) // `Infallible` conversion + } + + fn call(&mut self, mut req: Request) -> Self::Future { + let mut inner = self.inner.clone(); + + // Match the request path + method to the defined limits. + let limits = self.get_limit_matches(&req); + if limits.is_empty() { + // No limits matched, so we can just call the next layer. + return Box::pin(async move { inner.call(req).await.map_err(|_| unreachable!()) }); + } + + // Go through all the limits and check if we need to throttle or reject the request. + for limit in limits.clone() { + let key = match limit.extract_key(&req) { + Ok(key) => key, + Err(e) => { + // Failed to extract the key, so we reject the request. + // This should only happen if the pubky-host couldn't be extracted. + tracing::warn!( + "{} {} Failed to extract key for rate limiting: {}", + limit.0.path.0, + limit.0.method.0, + e + ); + return Box::pin(async move { + Ok(Error::new( + StatusCode::BAD_REQUEST, + Some("Failed to extract key for rate limiting".to_string()), + ) + .into_response()) + }); + } + }; + + match limit.0.quota.rate_unit { + RateUnit::SpeedRateUnit(_) => { + // Speed limiting is enabled, so we need to throttle the upload. + req = Self::throttle_upload(req, &key, &limit.1); + } + RateUnit::Request => { + // Request limiting is enabled, so we need to limit the number of requests. + if let Err(e) = limit.1.check_key(&key) { + tracing::debug!( + "Rate limit of {} exceeded for {key}: {}", + limit.0.quota, + e + ); + return Box::pin(async move { + Ok(Error::new( + StatusCode::TOO_MANY_REQUESTS, + Some("Rate limit exceeded".to_string()), + ) + .into_response()) + }); + }; + } + }; + } + + // Create a clone of the request without the body. + // This way, we can extract the keys for the response too. + let (parts, body) = req.into_parts(); + let req_clone = Request::from_parts(parts.clone(), Body::empty()); + let req = Request::from_parts(parts, body); + + let speed_limits = limits + .into_iter() + .filter(|limit| limit.0.quota.rate_unit.is_speed_rate_unit()) + .cloned() + .collect::>(); + Box::pin(async move { + // Call the next layer and receive the response. + let mut response = match inner.call(req).await.map_err(|_| unreachable!()) { + Ok(response) => response, + Err(e) => return Err(e), + }; + // Rate limit the download speed. + for limit in speed_limits { + if let Ok(key) = limit.extract_key(&req_clone) { + response = Self::throttle_download(response, &key, &limit.1); + }; + } + Ok(response) + }) + } +} + +#[cfg(test)] +mod tests { + use std::net::{IpAddr, Ipv4Addr, SocketAddr}; + + use axum::http::Method; + use axum::{ + routing::{get, post}, + Router, + }; + use axum_server::Server; + use pkarr::{Keypair, PublicKey}; + use reqwest::{Client, Response}; + use tokio::{task::JoinHandle, time::Instant}; + + use crate::{core::layers::pubky_host::PubkyHostLayer, quota_config::GlobPattern}; + + use super::*; + + // Fake upload handler that just consumes the body. + pub async fn upload_handler(body: Body) -> Result { + let mut stream = body.into_data_stream(); + while let Some(chunk) = stream.next().await.transpose()? { + // Consume body + let _ = chunk; + } + Ok((StatusCode::CREATED, ())) + } + + // Fake upload handler that just consumes the body. + pub async fn download_handler() -> Result { + let response_body = vec![0u8; 3 * 1024]; // 3kb + Ok((StatusCode::OK, response_body)) + } + + // Start a server with the given quota config on a random port. + async fn start_server(config: Vec) -> SocketAddr { + let app = Router::new() + .route("/upload", post(upload_handler)) + .route("/download", get(download_handler)) + .layer(RateLimiterLayer::new(config)) + .layer(PubkyHostLayer); + + // Create a TCP listener to bind to the socket first + // Use port 0 to let the OS assign a random available port + let listener = tokio::net::TcpListener::bind(SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), + 0, + )) + .await + .unwrap(); + // Get the actual socket address with the OS-assigned port + let socket = listener.local_addr().unwrap(); + + // Use the listener with axum_server + let server = Server::from_tcp(listener.into_std().unwrap()); + + tokio::spawn(async move { + server + .serve(app.into_make_service_with_connect_info::()) + .await + .unwrap(); + }); + + socket + } + + #[tokio::test] + async fn test_throttle_upload() { + let path_limit = PathLimit::new( + GlobPattern::new("/upload"), + Method::POST, + "1kb/s".parse().unwrap(), + LimitKey::Ip, + None, + ); + let socket = start_server(vec![path_limit]).await; + + fn upload_data(socket: SocketAddr, kilobytes: usize) -> JoinHandle<()> { + tokio::spawn(async move { + let client = Client::new(); + let data = vec![0u8; kilobytes * 1024]; + let response = client + .post(format!("http://{}/upload", socket)) + .body(data.clone()) + .send() + .await + .unwrap(); + assert_eq!(response.status(), StatusCode::CREATED); + }) + } + + let start = Instant::now(); + // Spawn in the background to test 2 uploads in parallel + // Upload 3kb each + let handle1 = upload_data(socket, 3); + let handle2 = upload_data(socket, 3); + + // Wait for the uploads to finish + let _ = tokio::try_join!(handle1, handle2); + + let time_taken = start.elapsed(); + assert!(time_taken > Duration::from_secs(6), "Should at least take 6s because uploads are limited to 1kb/s and the sum of the uploads is 6kb"); + } + + #[tokio::test] + async fn test_throttle_download() { + let path_limit = PathLimit::new( + GlobPattern::new("/download"), + Method::GET, + "1kb/s".parse().unwrap(), + LimitKey::Ip, + None, + ); + let socket = start_server(vec![path_limit]).await; + + fn download_data(socket: SocketAddr) -> JoinHandle<()> { + tokio::spawn(async move { + let client = Client::new(); + let response = client + .get(format!("http://{}/download", socket)) + .send() + .await + .unwrap(); + assert_eq!(response.status(), StatusCode::OK); + response.bytes().await.unwrap(); // Download the body + }) + } + + let start = Instant::now(); + // Spawn in the background to test 2 downloads in parallel + // Download 3kb each + let handle1 = download_data(socket); + let handle2 = download_data(socket); + + // Wait for the uploads to finish + let _ = tokio::try_join!(handle1, handle2); + + let time_taken = start.elapsed(); + assert!(time_taken > Duration::from_secs(6), "Should at least take 6s because downloads are limited to 1kb/s and the sum of the downloads is 6kb"); + } + + #[tokio::test] + async fn test_limit_parallel_requests_with_ip_key() { + let path_limit = PathLimit::new( + GlobPattern::new("/upload"), + Method::POST, + "1r/m".parse().unwrap(), + LimitKey::Ip, + None, + ); + let socket = start_server(vec![path_limit]).await; + + fn send_request(socket: SocketAddr) -> JoinHandle { + tokio::spawn(async move { + let client = Client::new(); + let response = client + .post(format!("http://{}/upload", socket)) + .send() + .await + .unwrap(); + response + }) + } + + // Spawn in the background to test 2 uploads in parallel + let handle1 = send_request(socket); + let handle2 = send_request(socket); + + // Wait for the uploads to finish + let (res1, res2) = tokio::try_join!(handle1, handle2).unwrap(); + assert_eq!(res1.status(), StatusCode::CREATED); + assert_eq!(res2.status(), StatusCode::TOO_MANY_REQUESTS); + } + + #[tokio::test] + async fn test_limit_parallel_requests_with_user_key() { + let path_limit = PathLimit::new( + GlobPattern::new("/upload"), + Method::POST, + "1r/m".parse().unwrap(), + LimitKey::User, + None, + ); + let socket = start_server(vec![path_limit]).await; + + fn send_request(socket: SocketAddr, user_pubkey: PublicKey) -> JoinHandle { + tokio::spawn(async move { + let client = Client::new(); + let response = client + .post(format!("http://{}/upload?pubky-host={user_pubkey}", socket)) + .send() + .await + .unwrap(); + response + }) + } + + // Spawn in the background to test 2 uploads in parallel + let user1_pubkey = Keypair::random().public_key(); + let handle1 = send_request(socket, user1_pubkey.clone()); + let handle2 = send_request(socket, user1_pubkey.clone()); + let user2_pubkey = Keypair::random().public_key(); + let handle3 = send_request(socket, user2_pubkey.clone()); + + // Wait for the uploads to finish + let (res1, res2, res3) = tokio::try_join!(handle1, handle2, handle3).unwrap(); + assert_eq!(res1.status(), StatusCode::CREATED); + assert_eq!(res2.status(), StatusCode::TOO_MANY_REQUESTS); + assert_eq!(res3.status(), StatusCode::CREATED); + } +} diff --git a/pubky-homeserver/src/core/layers/rate_limiter/mod.rs b/pubky-homeserver/src/core/layers/rate_limiter/mod.rs new file mode 100644 index 0000000..ee7d736 --- /dev/null +++ b/pubky-homeserver/src/core/layers/rate_limiter/mod.rs @@ -0,0 +1,3 @@ +mod extract_ip; +mod layer; +pub use layer::*; diff --git a/pubky-homeserver/src/core/routes/mod.rs b/pubky-homeserver/src/core/routes/mod.rs index 19f8b40..b1bfc12 100644 --- a/pubky-homeserver/src/core/routes/mod.rs +++ b/pubky-homeserver/src/core/routes/mod.rs @@ -13,9 +13,11 @@ use tower::ServiceBuilder; use tower_cookies::CookieManagerLayer; use tower_http::cors::CorsLayer; -use crate::core::AppState; +use crate::{core::AppState, AppContext}; -use super::layers::{pubky_host::PubkyHostLayer, trace::with_trace_layer}; +use super::layers::{ + pubky_host::PubkyHostLayer, rate_limiter::RateLimiterLayer, trace::with_trace_layer, +}; mod auth; mod feed; @@ -37,16 +39,20 @@ fn base() -> Router { // TODO: maybe add to a separate router (drive router?). } -pub fn create_app(state: AppState) -> Router { +pub fn create_app(state: AppState, context: &AppContext) -> Router { let app = base() .merge(tenants::router(state.clone())) .layer(CookieManagerLayer::new()) .layer(CorsLayer::very_permissive()) .layer(ServiceBuilder::new().layer(middleware::from_fn(add_server_header))) + .layer(PubkyHostLayer) + .layer(RateLimiterLayer::new( + context.config_toml.drive.rate_limits.clone(), + )) .with_state(state); // Apply trace and pubky host layers to the complete router. - with_trace_layer(app, &TRACING_EXCLUDED_PATHS).layer(PubkyHostLayer) + with_trace_layer(app, &TRACING_EXCLUDED_PATHS) } // Middleware to add a `Server` header to all responses diff --git a/pubky-homeserver/src/data_directory/config.default.toml b/pubky-homeserver/src/data_directory/config.default.toml index c7da0fa..23f2cc8 100644 --- a/pubky-homeserver/src/data_directory/config.default.toml +++ b/pubky-homeserver/src/data_directory/config.default.toml @@ -3,9 +3,11 @@ signup_mode = "token_required" lmdb_backup_interval_s = 0 user_storage_quota_mb = 0 + [drive] pubky_listen_socket = "127.0.0.1:6287" icann_listen_socket = "127.0.0.1:6286" +rate_limits = [] [admin] listen_socket = "127.0.0.1:6288" diff --git a/pubky-homeserver/src/data_directory/config_toml.rs b/pubky-homeserver/src/data_directory/config_toml.rs index 9e7c647..9d8a479 100644 --- a/pubky-homeserver/src/data_directory/config_toml.rs +++ b/pubky-homeserver/src/data_directory/config_toml.rs @@ -4,7 +4,7 @@ //! This module embeds that file at compile-time, parses it once, //! and lets callers optionally layer their own TOML on top. -use super::{domain_port::DomainPort, Domain, SignupMode}; +use super::{domain_port::DomainPort, quota_config::PathLimit, Domain, SignupMode}; use serde::{Deserialize, Serialize}; use std::{ fmt::Debug, @@ -54,6 +54,7 @@ pub struct PkdnsToml { pub struct DriveToml { pub pubky_listen_socket: SocketAddr, pub icann_listen_socket: SocketAddr, + pub rate_limits: Vec, } #[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] @@ -144,9 +145,8 @@ impl ConfigToml { .lines() .map(|line| { let trimmed = line.trim_start(); - let is_title = trimmed.starts_with('['); let is_comment = trimmed.starts_with('#'); - if !is_title && !is_comment && !trimmed.is_empty() { + if !is_comment && !trimmed.is_empty() { format!("# {}", line) } else { line.to_string() @@ -216,12 +216,13 @@ mod tests { assert_eq!(c.pkdns.dht_bootstrap_nodes, None); assert_eq!(c.pkdns.dht_relay_nodes, None); assert_eq!(c.pkdns.dht_request_timeout_ms, None); + assert_eq!(c.drive.rate_limits, vec![]); } #[test] fn test_sample_config() { // Validate that the sample config can be parsed - ConfigToml::from_str(SAMPLE_CONFIG).expect("Embedded config.default.toml must be valid"); + ConfigToml::from_str(SAMPLE_CONFIG).expect("Embedded config.sample.toml must be valid"); } #[test] diff --git a/pubky-homeserver/src/data_directory/mod.rs b/pubky-homeserver/src/data_directory/mod.rs index bf947df..cc3fc1e 100644 --- a/pubky-homeserver/src/data_directory/mod.rs +++ b/pubky-homeserver/src/data_directory/mod.rs @@ -4,6 +4,8 @@ mod domain; mod domain_port; mod mock_data_dir; mod persistent_data_dir; +/// Quota configuration for the TomlConfig. +pub mod quota_config; mod signup_mode; pub use config_toml::{ConfigReadError, ConfigToml}; diff --git a/pubky-homeserver/src/data_directory/quota_config/glob_pattern.rs b/pubky-homeserver/src/data_directory/quota_config/glob_pattern.rs new file mode 100644 index 0000000..7583965 --- /dev/null +++ b/pubky-homeserver/src/data_directory/quota_config/glob_pattern.rs @@ -0,0 +1,89 @@ +use serde::{Deserialize, Serialize}; +use std::{fmt::Display, str::FromStr}; + +/// A wrapper around fast_glob to allow serialize/deserialize. +/// Pattern matches glob patterns. +/// +/// Syntax - Meaning +/// `?` - Matches any single character. +/// `*` - Matches zero or more characters, except for path separators (e.g. /). +/// `**` - Matches zero or more characters, including path separators. Must match a complete path segment (i.e. followed by a / or the end of the pattern). +/// `[ab]` - Matches one of the characters contained in the brackets. Character ranges, e.g. `[a-z]` are also supported. Use `[!ab]` or `[^ab]` to match any character except those contained in the brackets. +/// `{a,b}` - Matches one of the patterns contained in the braces. Any of the wildcard characters can be used in the sub-patterns. Braces may be nested up to 10 levels deep. +/// `!` - When at the start of the glob, this negates the result. Multiple `!` characters negate the glob multiple times. +/// `\` - A backslash character may be used to escape any of the above special characters. +#[derive(Debug, Clone)] +pub struct GlobPattern(pub String); + +impl GlobPattern { + /// Create a new glob pattern. + pub fn new(pattern: &str) -> Self { + Self(pattern.to_string()) + } + + /// Check if the path matches the glob pattern. + pub fn is_match(&self, path: &str) -> bool { + fast_glob::glob_match(&self.0, path) + } +} + +impl std::hash::Hash for GlobPattern { + fn hash(&self, state: &mut H) { + self.0.as_str().hash(state); + } +} + +impl FromStr for GlobPattern { + type Err = String; + + fn from_str(s: &str) -> Result { + Ok(GlobPattern(s.to_string())) + } +} + +impl Display for GlobPattern { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0.as_str()) + } +} + +impl Serialize for GlobPattern { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + serializer.serialize_str(self.0.as_str()) + } +} + +impl<'de> Deserialize<'de> for GlobPattern { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let s = String::deserialize(deserializer)?; + GlobPattern::from_str(&s).map_err(serde::de::Error::custom) + } +} + +impl PartialEq for GlobPattern { + fn eq(&self, other: &Self) -> bool { + self.0.as_str() == other.0.as_str() + } +} + +impl Eq for GlobPattern {} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_glob_pattern() { + let glob_pattern = GlobPattern::from_str("/pub/**").unwrap(); + assert!(glob_pattern.is_match("/pub/test.txt")); + assert!(glob_pattern.is_match("/pub/test/test.txt")); + assert!(!glob_pattern.is_match("/priv/test.pdf")); + assert!(!glob_pattern.is_match("/session/test.txt")); + } +} diff --git a/pubky-homeserver/src/data_directory/quota_config/http_method.rs b/pubky-homeserver/src/data_directory/quota_config/http_method.rs new file mode 100644 index 0000000..23ed3a4 --- /dev/null +++ b/pubky-homeserver/src/data_directory/quota_config/http_method.rs @@ -0,0 +1,55 @@ +use std::{fmt::Display, str::FromStr}; + +use axum::http::Method; +use serde::{Deserialize, Serialize}; + +/// A wrapper around http::Method to implement serde traits +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct HttpMethod(pub Method); + +impl From for HttpMethod { + fn from(method: Method) -> Self { + HttpMethod(method) + } +} + +impl From for Method { + fn from(method: HttpMethod) -> Self { + method.0 + } +} + +impl FromStr for HttpMethod { + type Err = String; + + fn from_str(s: &str) -> Result { + Method::from_str(s.to_uppercase().as_str()) + .map(HttpMethod) + .map_err(|_| format!("Invalid method: {}", s)) + } +} + +impl Display for HttpMethod { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +impl Serialize for HttpMethod { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + serializer.serialize_str(&self.to_string()) + } +} + +impl<'de> Deserialize<'de> for HttpMethod { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let s = String::deserialize(deserializer)?; + HttpMethod::from_str(&s).map_err(serde::de::Error::custom) + } +} diff --git a/pubky-homeserver/src/data_directory/quota_config/limit_key.rs b/pubky-homeserver/src/data_directory/quota_config/limit_key.rs new file mode 100644 index 0000000..567dcb3 --- /dev/null +++ b/pubky-homeserver/src/data_directory/quota_config/limit_key.rs @@ -0,0 +1,55 @@ +use std::fmt; +use std::str::FromStr; + +/// The key to limit the quota on. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum LimitKey { + /// Limit on the user id + User, + /// Limit on the ip address + Ip, +} + +impl fmt::Display for LimitKey { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "{}", + match self { + LimitKey::User => "user", + LimitKey::Ip => "ip", + } + ) + } +} + +impl FromStr for LimitKey { + type Err = String; + + fn from_str(s: &str) -> Result { + match s { + "user" => Ok(LimitKey::User), + "ip" => Ok(LimitKey::Ip), + _ => Err(format!("Invalid limit key: {}", s)), + } + } +} + +impl serde::Serialize for LimitKey { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + serializer.serialize_str(self.to_string().as_str()) + } +} + +impl<'de> serde::Deserialize<'de> for LimitKey { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let s = String::deserialize(deserializer)?; + LimitKey::from_str(&s).map_err(serde::de::Error::custom) + } +} diff --git a/pubky-homeserver/src/data_directory/quota_config/mod.rs b/pubky-homeserver/src/data_directory/quota_config/mod.rs new file mode 100644 index 0000000..f881220 --- /dev/null +++ b/pubky-homeserver/src/data_directory/quota_config/mod.rs @@ -0,0 +1,15 @@ +mod glob_pattern; +mod http_method; +mod limit_key; +mod path_limit; +mod quota_value; +mod rate_unit; +mod time_unit; + +pub use glob_pattern::GlobPattern; +pub use http_method::HttpMethod; +pub use limit_key::LimitKey; +pub use path_limit::*; +pub use quota_value::QuotaValue; +pub use rate_unit::RateUnit; +pub use time_unit::TimeUnit; diff --git a/pubky-homeserver/src/data_directory/quota_config/path_limit.rs b/pubky-homeserver/src/data_directory/quota_config/path_limit.rs new file mode 100644 index 0000000..c1c2722 --- /dev/null +++ b/pubky-homeserver/src/data_directory/quota_config/path_limit.rs @@ -0,0 +1,77 @@ +use super::{GlobPattern, HttpMethod, LimitKey, QuotaValue}; +use axum::http::Method; +use serde::{Deserialize, Serialize}; +use std::num::NonZeroU32; + +/// A limit on a path for a specific method. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)] +pub struct PathLimit { + /// The path glob pattern to match against. + pub path: GlobPattern, + /// The method to limit. + pub method: HttpMethod, + /// The limit to apply. + pub quota: QuotaValue, + /// The key to limit. + pub key: LimitKey, + /// The burst to apply. + pub burst: Option, +} + +impl PathLimit { + /// Create a new path limit. + pub fn new( + path: GlobPattern, + method: Method, + quota: QuotaValue, + key: LimitKey, + burst: Option, + ) -> Self { + Self { + path, + method: HttpMethod(method), + quota, + key, + burst, + } + } +} + +impl std::fmt::Display for PathLimit { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{} {}: {}-{} by {}", + self.method, + self.path, + self.quota, + self.burst.map_or_else(String::new, |b| b.to_string()), + self.key + ) + } +} + +impl From for governor::Quota { + fn from(value: PathLimit) -> Self { + let quota: governor::Quota = value.quota.into(); + if let Some(burst) = value.burst { + quota.allow_burst(burst); + } + quota + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_http_method_serde() { + let method = Method::GET; + let http_method = HttpMethod(method); + assert_eq!(http_method.to_string(), "GET"); + + let deserialized: HttpMethod = "GET".parse().unwrap(); + assert_eq!(deserialized, http_method); + } +} diff --git a/pubky-homeserver/src/data_directory/quota_config/quota_value.rs b/pubky-homeserver/src/data_directory/quota_config/quota_value.rs new file mode 100644 index 0000000..f68e788 --- /dev/null +++ b/pubky-homeserver/src/data_directory/quota_config/quota_value.rs @@ -0,0 +1,201 @@ +use std::fmt; +use std::str::FromStr; +use std::{num::NonZeroU32, time::Duration}; + +use super::{RateUnit, TimeUnit}; + +/// Quota value +/// +/// Examples: +/// - 5r/m +/// - 5r/s +/// - 5kb/m +/// - 5mb/m +/// - 5gb/s +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct QuotaValue { + /// The rate. + pub rate: NonZeroU32, + /// The unit of the rate. + pub rate_unit: RateUnit, + /// The unit of the time. + pub time_unit: TimeUnit, +} + +impl From for governor::Quota { + /// Get the quota to do the actual rate limiting. + /// + /// Important: The speed quotas are always in kilobytes, not bytes. + /// Counting bytes is not practical. + /// + fn from(value: QuotaValue) -> Self { + let rate_count = value.rate.get(); + let rate_unit = value.rate_unit.multiplier().get(); + let rate = NonZeroU32::new(rate_count * rate_unit) + .expect("Is always non-zero because rate count and rate unit multiplier are non-zero"); + let time_unit = Duration::from_secs(value.time_unit.multiplier_in_seconds().get() as u64); + let replenish_1_per = time_unit / rate.get(); + + let base_quota = governor::Quota::with_period(replenish_1_per) + .expect("Is always non-zero because replenish_1_per is non-zero"); + base_quota.allow_burst(rate) + } +} + +impl fmt::Display for QuotaValue { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}{}/{}", self.rate, self.rate_unit, self.time_unit) + } +} + +impl FromStr for QuotaValue { + type Err = String; + + fn from_str(s: &str) -> Result { + // Split rate part by '/' to get rate+unit and time unit + let rate_parts: Vec<&str> = s.split('/').collect(); + if rate_parts.len() != 2 { + return Err(format!( + "Invalid rate format: '{}', expected {{rate}}{{unit}}/{{time}}", + s + )); + } + + let rate_with_unit = rate_parts[0]; + let time_unit = TimeUnit::from_str(rate_parts[1])?; + + // Find the boundary between rate digits and unit + let rate_digit_end = rate_with_unit + .chars() + .position(|c| !c.is_ascii_digit()) + .unwrap_or(rate_with_unit.len()); + + if rate_digit_end == 0 { + return Err(format!("Missing rate value in '{}'", rate_with_unit)); + } + + let rate_str = &rate_with_unit[..rate_digit_end]; + let rate_unit_str = &rate_with_unit[rate_digit_end..]; + + let rate = rate_str + .parse::() + .map_err(|_| format!("Failed to parse rate from '{}'", rate_str))?; + let rate_unit = RateUnit::from_str(rate_unit_str)?; + + Ok(QuotaValue { + rate, + rate_unit, + time_unit, + }) + } +} + +impl serde::Serialize for QuotaValue { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + serializer.serialize_str(&self.to_string()) + } +} + +impl<'de> serde::Deserialize<'de> for QuotaValue { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let s = String::deserialize(deserializer)?; + + // Parse the quota string + QuotaValue::from_str(&s).map_err(serde::de::Error::custom) + } +} + +#[cfg(test)] +mod tests { + use crate::quota_config::rate_unit::SpeedRateUnit; + + use super::*; + + #[test] + fn test_get_quota() { + let quota = QuotaValue::from_str("5r/s").unwrap(); + assert_eq!( + governor::Quota::from(quota), + governor::Quota::per_second(NonZeroU32::new(5).unwrap()) + ); + + let quota = QuotaValue::from_str("5r/m").unwrap(); + assert_eq!( + governor::Quota::from(quota), + governor::Quota::per_minute(NonZeroU32::new(5).unwrap()) + ); + + let quota = QuotaValue::from_str("5kb/s").unwrap(); + assert_eq!( + governor::Quota::from(quota), + governor::Quota::per_second(NonZeroU32::new(5).unwrap()) + ); + + let quota = QuotaValue::from_str("5mb/m").unwrap(); + assert_eq!( + governor::Quota::from(quota), + governor::Quota::per_minute(NonZeroU32::new(5 * 1024).unwrap()) + ); + } + + #[test] + fn test_quota_value_from_str() { + // Test without burst + let quota = QuotaValue::from_str("5r/s").unwrap(); + assert_eq!(quota.rate, NonZeroU32::new(5).unwrap()); + assert_eq!(quota.rate_unit, RateUnit::Request); + assert_eq!(quota.time_unit, TimeUnit::Second); + + // Test with burst (should fail or be handled differently) + let quota = QuotaValue::from_str("10mb/m").unwrap(); + assert_eq!(quota.rate, NonZeroU32::new(10).unwrap()); + assert_eq!( + quota.rate_unit, + RateUnit::SpeedRateUnit(SpeedRateUnit::Megabyte) + ); + assert_eq!(quota.time_unit, TimeUnit::Minute); + } + + #[test] + fn test_quota_value_display() { + // Test without burst + let quota = QuotaValue { + rate: NonZeroU32::new(5).unwrap(), + rate_unit: RateUnit::Request, + time_unit: TimeUnit::Second, + }; + assert_eq!(quota.to_string(), "5r/s"); + + // Test with burst (should be displayed without burst) + let quota = QuotaValue { + rate: NonZeroU32::new(10).unwrap(), + rate_unit: RateUnit::SpeedRateUnit(SpeedRateUnit::Megabyte), + time_unit: TimeUnit::Minute, + }; + assert_eq!(quota.to_string(), "10mb/m"); + } + + #[test] + fn test_quota_value_invalid_formats() { + // Invalid format: missing / + assert!(QuotaValue::from_str("5rs").is_err()); + + // Invalid format: missing rate + assert!(QuotaValue::from_str("r/s").is_err()); + + // Invalid format: invalid rate unit + assert!(QuotaValue::from_str("5x/s").is_err()); + + // Invalid format: invalid time unit + assert!(QuotaValue::from_str("5r/x").is_err()); + + // Invalid format: invalid burst (this test case might need to be removed or updated) + assert!(QuotaValue::from_str("5r/s-2burst").is_err()); + } +} diff --git a/pubky-homeserver/src/data_directory/quota_config/rate_unit.rs b/pubky-homeserver/src/data_directory/quota_config/rate_unit.rs new file mode 100644 index 0000000..7eee078 --- /dev/null +++ b/pubky-homeserver/src/data_directory/quota_config/rate_unit.rs @@ -0,0 +1,112 @@ +use std::{num::NonZeroU32, str::FromStr}; + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum SpeedRateUnit { + /// Kilobyte + Kilobyte, + /// Megabyte + Megabyte, + /// Gigabyte + Gigabyte, +} + +impl SpeedRateUnit { + /// Returns the number of bytes for this unit + pub const fn multiplier(&self) -> NonZeroU32 { + match self { + // Speed quotas are always in kilobytes. + // Because counting bytes is not practical and we are limited to u32 = 4GB max. + // Counting in kb as more practical and we can count up to 4GB*1024 = 4TB. + SpeedRateUnit::Kilobyte => NonZeroU32::new(1).expect("Is always non-zero"), + SpeedRateUnit::Megabyte => NonZeroU32::new(1024).expect("Is always non-zero"), + SpeedRateUnit::Gigabyte => NonZeroU32::new(1024 * 1024).expect("Is always non-zero"), + } + } +} + +impl std::fmt::Display for SpeedRateUnit { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + match self { + SpeedRateUnit::Kilobyte => "kb", + SpeedRateUnit::Megabyte => "mb", + SpeedRateUnit::Gigabyte => "gb", + } + ) + } +} + +impl FromStr for SpeedRateUnit { + type Err = String; + + fn from_str(s: &str) -> Result { + match s.to_lowercase().as_str() { + "kb" => Ok(SpeedRateUnit::Kilobyte), + "mb" => Ok(SpeedRateUnit::Megabyte), + "gb" => Ok(SpeedRateUnit::Gigabyte), + _ => Err(format!("Invalid speedrate unit: {}", s)), + } + } +} + +/// The unit of the rate. +/// +/// Examples: +/// - "r" -> request +/// - "kb" -> kilobyte +/// - "mb" -> megabyte +/// - "gb" -> gigabyte +/// - "tb" -> terabyte +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum RateUnit { + /// Request + Request, + /// Speed rate unit + SpeedRateUnit(SpeedRateUnit), +} + +impl RateUnit { + /// Returns true if the rate unit is a speed rate unit. + pub fn is_speed_rate_unit(&self) -> bool { + matches!(self, RateUnit::SpeedRateUnit(_)) + } +} + +impl std::fmt::Display for RateUnit { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + match self { + RateUnit::Request => "r".to_string(), + RateUnit::SpeedRateUnit(unit) => unit.to_string(), + } + ) + } +} + +impl FromStr for RateUnit { + type Err = String; + + fn from_str(s: &str) -> Result { + match s.to_lowercase().as_str() { + "r" => Ok(RateUnit::Request), + other => match SpeedRateUnit::from_str(other) { + Ok(unit) => Ok(RateUnit::SpeedRateUnit(unit)), + Err(_) => Err(format!("Invalid rate unit: {}", s)), + }, + } + } +} + +impl RateUnit { + /// Returns the number of bytes for this unit + pub const fn multiplier(&self) -> NonZeroU32 { + match self { + RateUnit::Request => NonZeroU32::new(1).expect("Is always non-zero"), + RateUnit::SpeedRateUnit(unit) => unit.multiplier(), + } + } +} diff --git a/pubky-homeserver/src/data_directory/quota_config/time_unit.rs b/pubky-homeserver/src/data_directory/quota_config/time_unit.rs new file mode 100644 index 0000000..796141c --- /dev/null +++ b/pubky-homeserver/src/data_directory/quota_config/time_unit.rs @@ -0,0 +1,49 @@ +use std::num::NonZeroU32; + +/// The time unit of the quota. +/// +/// Examples: +/// - "s" -> second +/// - "m" -> minute +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum TimeUnit { + /// Second + Second, + /// Minute + Minute, +} + +impl std::fmt::Display for TimeUnit { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + match self { + TimeUnit::Second => "s", + TimeUnit::Minute => "m", + } + ) + } +} + +impl std::str::FromStr for TimeUnit { + type Err = String; + + fn from_str(s: &str) -> Result { + match s { + "s" => Ok(TimeUnit::Second), + "m" => Ok(TimeUnit::Minute), + _ => Err(format!("Invalid time unit: {}", s)), + } + } +} + +impl TimeUnit { + /// Returns the number of seconds for each unit + pub const fn multiplier_in_seconds(&self) -> NonZeroU32 { + match self { + TimeUnit::Second => NonZeroU32::new(1).expect("Is always non-zero"), + TimeUnit::Minute => NonZeroU32::new(60).expect("Is always non-zero"), + } + } +}