diff --git a/e2e/src/tests/mod.rs b/e2e/src/tests/mod.rs index af911fc..561aec1 100644 --- a/e2e/src/tests/mod.rs +++ b/e2e/src/tests/mod.rs @@ -1,3 +1,4 @@ mod auth; mod http; mod public; +mod rate_limiting; diff --git a/e2e/src/tests/rate_limiting.rs b/e2e/src/tests/rate_limiting.rs new file mode 100644 index 0000000..daeb708 --- /dev/null +++ b/e2e/src/tests/rate_limiting.rs @@ -0,0 +1,132 @@ +use std::time::Duration; + +use pkarr::Keypair; +use pubky_testnet::{ + pubky_homeserver::{ + quota_config::{GlobPattern, LimitKey, PathLimit}, + ConfigToml, MockDataDir, + }, + Testnet, +}; +use reqwest::{Method, StatusCode, Url}; +use tokio::time::Instant; + +#[tokio::test] +async fn test_limit_signin_get_session() { + let mut testnet = Testnet::new().await.unwrap(); + let client = testnet.pubky_client_builder().build().unwrap(); + + let mut config = ConfigToml::test(); + config.drive.rate_limits = vec![ + PathLimit::new( + GlobPattern::new("/session"), + Method::POST, + "1r/m".parse().unwrap(), + LimitKey::Ip, + None, + ), // Limit signins + PathLimit::new( + GlobPattern::new("/session"), + Method::GET, + "1r/m".parse().unwrap(), + LimitKey::User, + None, + ), // Limit decode sessions + ]; + let mock_dir = MockDataDir::new(config, None).unwrap(); + let server = testnet + .create_homeserver_suite_with_mock(mock_dir) + .await + .unwrap(); + + // Create a new user + let keypair = Keypair::random(); + client + .signup(&keypair, &server.public_key(), None) + .await + .unwrap(); + + client.signin(&keypair).await.unwrap(); // First signin should be ok + + client.session(&keypair.public_key()).await.unwrap(); // First session should be ok + client + .session(&keypair.public_key()) + .await + .expect_err("Should be rate limited"); // Second session should be rate limited + + client + .signin(&keypair) + .await + .expect_err("Should be rate limited"); // Second signin should be rate limited +} + +#[tokio::test] +async fn test_limit_events() { + let mut testnet = Testnet::new().await.unwrap(); + let client = testnet.pubky_client_builder().build().unwrap(); + + let mut config = ConfigToml::test(); + config.drive.rate_limits = vec![ + PathLimit::new( + GlobPattern::new("/events/"), + Method::GET, + "1r/m".parse().unwrap(), + LimitKey::Ip, + None, + ), // Limit events + ]; + let mock_dir = MockDataDir::new(config, None).unwrap(); + let server = testnet + .create_homeserver_suite_with_mock(mock_dir) + .await + .unwrap(); + + let url = server.pubky_url().join("/events/").unwrap(); + let res = client.get(url.clone()).send().await.unwrap(); // First event should be ok + assert_eq!(res.status(), StatusCode::OK); + + let res = client.get(url).send().await.unwrap(); // Second event should be rate limited + assert_eq!(res.status(), StatusCode::TOO_MANY_REQUESTS); +} + +#[tokio::test] +async fn test_limit_upload() { + let mut testnet = Testnet::new().await.unwrap(); + let client = testnet.pubky_client_builder().build().unwrap(); + + let mut config = ConfigToml::test(); + config.drive.rate_limits = vec![ + PathLimit::new( + GlobPattern::new("/pub/**"), + Method::PUT, + "1kb/s".parse().unwrap(), + LimitKey::User, + None, + ), // Limit events + ]; + let mock_dir = MockDataDir::new(config, None).unwrap(); + let server = testnet + .create_homeserver_suite_with_mock(mock_dir) + .await + .unwrap(); + + // Create a new user + let keypair = Keypair::random(); + client + .signup(&keypair, &server.public_key(), None) + .await + .unwrap(); + + let url: Url = format!("pubky://{}/pub/test.txt", keypair.public_key()) + .parse() + .unwrap(); + let start = Instant::now(); + let res = client + .put(url) + .body(vec![0u8; 3 * 1024]) // 2kb + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::CREATED); + assert!(start.elapsed() > Duration::from_secs(2)); +} diff --git a/pubky-homeserver/src/core/layers/rate_limiter/layer.rs b/pubky-homeserver/src/core/layers/rate_limiter/layer.rs index bcbfe71..327ffa2 100644 --- a/pubky-homeserver/src/core/layers/rate_limiter/layer.rs +++ b/pubky-homeserver/src/core/layers/rate_limiter/layer.rs @@ -80,16 +80,32 @@ impl Layer for RateLimiterLayer { /// A tuple of a path limit and the actual governor rate limiter. #[derive(Debug, Clone)] -struct LimitTuple( - pub PathLimit, - pub Arc, QuantaClock>>, -); +struct LimitTuple { + pub limit: PathLimit, + pub limiter: Arc, QuantaClock>>, +} impl LimitTuple { pub fn new(path_limit: PathLimit) -> Self { let quota: Quota = path_limit.clone().into(); let limiter = Arc::new(RateLimiter::keyed(quota)); - Self(path_limit, limiter) + + // Forget keys that are not used anymore. This is to prevent memory leaks. + let limiter_clone = limiter.clone(); + tokio::spawn(async move { + let mut interval = tokio::time::interval(Duration::from_secs(60)); + interval.tick().await; + loop { + interval.tick().await; + limiter_clone.retain_recent(); + limiter_clone.shrink_to_fit(); + } + }); + + Self { + limit: path_limit, + limiter, + } } /// Extract the key from the request. @@ -97,7 +113,7 @@ impl LimitTuple { /// The key is either the ip address of the client /// or the user pubkey. fn extract_key(&self, req: &Request) -> anyhow::Result { - match self.0.key { + match self.limit.key { LimitKey::Ip => extract_ip(req).map(|ip| ip.to_string()), LimitKey::User => { // Extract the user pubkey from the request. @@ -111,7 +127,10 @@ impl LimitTuple { /// Check if the request matches the limit. pub fn is_match(&self, req: &Request) -> bool { - self.0.path.is_match(req.uri().path()) && self.0.method.0 == req.method() + let path = req.uri().path(); + let glob_match = self.limit.path.is_match(path); + let method_match = self.limit.method.0 == req.method(); + glob_match && method_match } } @@ -259,13 +278,13 @@ where // This should only happen if the pubky-host couldn't be extracted. tracing::warn!( "{} {} Failed to extract key for rate limiting: {}", - limit.0.path.0, - limit.0.method.0, + limit.limit.path.0, + limit.limit.method.0, e ); return Box::pin(async move { Ok(Error::new( - StatusCode::BAD_REQUEST, + StatusCode::INTERNAL_SERVER_ERROR, Some("Failed to extract key for rate limiting".to_string()), ) .into_response()) @@ -273,17 +292,17 @@ where } }; - match limit.0.quota.rate_unit { + match limit.limit.quota.rate_unit { RateUnit::SpeedRateUnit(_) => { // Speed limiting is enabled, so we need to throttle the upload. - req = Self::throttle_upload(req, &key, &limit.1); + req = Self::throttle_upload(req, &key, &limit.limiter); } RateUnit::Request => { // Request limiting is enabled, so we need to limit the number of requests. - if let Err(e) = limit.1.check_key(&key) { + if let Err(e) = limit.limiter.check_key(&key) { tracing::debug!( "Rate limit of {} exceeded for {key}: {}", - limit.0.quota, + limit.limit.quota, e ); return Box::pin(async move { @@ -306,7 +325,7 @@ where let speed_limits = limits .into_iter() - .filter(|limit| limit.0.quota.rate_unit.is_speed_rate_unit()) + .filter(|limit| limit.limit.quota.rate_unit.is_speed_rate_unit()) .cloned() .collect::>(); Box::pin(async move { @@ -318,7 +337,7 @@ where // Rate limit the download speed. for limit in speed_limits { if let Ok(key) = limit.extract_key(&req_clone) { - response = Self::throttle_download(response, &key, &limit.1); + response = Self::throttle_download(response, &key, &limit.limiter); }; } Ok(response) diff --git a/pubky-homeserver/src/core/routes/mod.rs b/pubky-homeserver/src/core/routes/mod.rs index b1bfc12..791d91c 100644 --- a/pubky-homeserver/src/core/routes/mod.rs +++ b/pubky-homeserver/src/core/routes/mod.rs @@ -45,10 +45,10 @@ pub fn create_app(state: AppState, context: &AppContext) -> Router { .layer(CookieManagerLayer::new()) .layer(CorsLayer::very_permissive()) .layer(ServiceBuilder::new().layer(middleware::from_fn(add_server_header))) - .layer(PubkyHostLayer) .layer(RateLimiterLayer::new( context.config_toml.drive.rate_limits.clone(), )) + .layer(PubkyHostLayer) .with_state(state); // Apply trace and pubky host layers to the complete router. diff --git a/pubky-homeserver/src/data_directory/quota_config/glob_pattern.rs b/pubky-homeserver/src/data_directory/quota_config/glob_pattern.rs index 7583965..e0fd5bf 100644 --- a/pubky-homeserver/src/data_directory/quota_config/glob_pattern.rs +++ b/pubky-homeserver/src/data_directory/quota_config/glob_pattern.rs @@ -86,4 +86,36 @@ mod tests { assert!(!glob_pattern.is_match("/priv/test.pdf")); assert!(!glob_pattern.is_match("/session/test.txt")); } + + #[test] + fn test_glob_pattern2() { + let glob_pattern = GlobPattern::from_str("/events/").unwrap(); + assert!(glob_pattern.is_match("/events/")); + } + + #[test] + fn test_glob_pattern3() { + let glob_pattern = GlobPattern::from_str("/pub/**").unwrap(); + assert!(glob_pattern.is_match("/pub/test.txt")); + assert!(glob_pattern.is_match("/pub/test/test.txt")); + assert!(glob_pattern.is_match("/pub/")); + } + + #[test] + fn test_glob_pattern4() { + let glob_pattern = GlobPattern::from_str("/pub/**/update").unwrap(); + assert!(glob_pattern.is_match("/pub/test.txt/update")); + assert!(glob_pattern.is_match("/pub/test/test.txt/update")); + assert!(!glob_pattern.is_match("/pub/test/test.txt")); + assert!(!glob_pattern.is_match("/pub/")); + } + + #[test] + fn test_glob_pattern5() { + let glob_pattern = GlobPattern::from_str("/pub/**/update/*").unwrap(); + assert!(glob_pattern.is_match("/pub/test.txt/update/test.txt")); + assert!(glob_pattern.is_match("/pub/test/test.txt/update/test.txt")); + assert!(!glob_pattern.is_match("/pub/test/test.txt")); + assert!(!glob_pattern.is_match("/pub/")); + } } diff --git a/pubky-homeserver/src/homeserver_suite/suite.rs b/pubky-homeserver/src/homeserver_suite/suite.rs index 8508dd8..c439894 100644 --- a/pubky-homeserver/src/homeserver_suite/suite.rs +++ b/pubky-homeserver/src/homeserver_suite/suite.rs @@ -78,7 +78,7 @@ impl HomeserverSuite { /// Returns the `https://` url pub fn pubky_url(&self) -> url::Url { - url::Url::parse(&format!("pubky://{}", self.public_key())).expect("valid url") + url::Url::parse(&format!("https://{}", self.public_key())).expect("valid url") } /// Returns the `https://` url