Fix: Rate limit - can not extract user limit key & memory leak (#127)

* before ThrottledBody

* added upload rate limiting

* before dynamic rate limiting

* flexible rate limiter + config

* fixed tests

* fmt and clippy

* reset auth.js e2e

* more cleaning up

* improved code and added comments

* limit downloads too

* fmt and clippy

* improved sample comments

* fixed comment

* removed http dependency in favour of axum:http

* parse speed units as lowercase

* replaced regex with glob

* clippy and fmt

* fix: user rate limiting finding limit key

* fmt

* turn **/* patterns into **

* improved cleanup task
This commit is contained in:
Severin Alexander Bühler
2025-05-14 10:05:01 +03:00
committed by GitHub
parent 1be8bf5e9e
commit 44eb42dd9a
6 changed files with 202 additions and 18 deletions

View File

@@ -1,3 +1,4 @@
mod auth;
mod http;
mod public;
mod rate_limiting;

View File

@@ -0,0 +1,132 @@
use std::time::Duration;
use pkarr::Keypair;
use pubky_testnet::{
pubky_homeserver::{
quota_config::{GlobPattern, LimitKey, PathLimit},
ConfigToml, MockDataDir,
},
Testnet,
};
use reqwest::{Method, StatusCode, Url};
use tokio::time::Instant;
#[tokio::test]
async fn test_limit_signin_get_session() {
let mut testnet = Testnet::new().await.unwrap();
let client = testnet.pubky_client_builder().build().unwrap();
let mut config = ConfigToml::test();
config.drive.rate_limits = vec![
PathLimit::new(
GlobPattern::new("/session"),
Method::POST,
"1r/m".parse().unwrap(),
LimitKey::Ip,
None,
), // Limit signins
PathLimit::new(
GlobPattern::new("/session"),
Method::GET,
"1r/m".parse().unwrap(),
LimitKey::User,
None,
), // Limit decode sessions
];
let mock_dir = MockDataDir::new(config, None).unwrap();
let server = testnet
.create_homeserver_suite_with_mock(mock_dir)
.await
.unwrap();
// Create a new user
let keypair = Keypair::random();
client
.signup(&keypair, &server.public_key(), None)
.await
.unwrap();
client.signin(&keypair).await.unwrap(); // First signin should be ok
client.session(&keypair.public_key()).await.unwrap(); // First session should be ok
client
.session(&keypair.public_key())
.await
.expect_err("Should be rate limited"); // Second session should be rate limited
client
.signin(&keypair)
.await
.expect_err("Should be rate limited"); // Second signin should be rate limited
}
#[tokio::test]
async fn test_limit_events() {
let mut testnet = Testnet::new().await.unwrap();
let client = testnet.pubky_client_builder().build().unwrap();
let mut config = ConfigToml::test();
config.drive.rate_limits = vec![
PathLimit::new(
GlobPattern::new("/events/"),
Method::GET,
"1r/m".parse().unwrap(),
LimitKey::Ip,
None,
), // Limit events
];
let mock_dir = MockDataDir::new(config, None).unwrap();
let server = testnet
.create_homeserver_suite_with_mock(mock_dir)
.await
.unwrap();
let url = server.pubky_url().join("/events/").unwrap();
let res = client.get(url.clone()).send().await.unwrap(); // First event should be ok
assert_eq!(res.status(), StatusCode::OK);
let res = client.get(url).send().await.unwrap(); // Second event should be rate limited
assert_eq!(res.status(), StatusCode::TOO_MANY_REQUESTS);
}
#[tokio::test]
async fn test_limit_upload() {
let mut testnet = Testnet::new().await.unwrap();
let client = testnet.pubky_client_builder().build().unwrap();
let mut config = ConfigToml::test();
config.drive.rate_limits = vec![
PathLimit::new(
GlobPattern::new("/pub/**"),
Method::PUT,
"1kb/s".parse().unwrap(),
LimitKey::User,
None,
), // Limit events
];
let mock_dir = MockDataDir::new(config, None).unwrap();
let server = testnet
.create_homeserver_suite_with_mock(mock_dir)
.await
.unwrap();
// Create a new user
let keypair = Keypair::random();
client
.signup(&keypair, &server.public_key(), None)
.await
.unwrap();
let url: Url = format!("pubky://{}/pub/test.txt", keypair.public_key())
.parse()
.unwrap();
let start = Instant::now();
let res = client
.put(url)
.body(vec![0u8; 3 * 1024]) // 2kb
.send()
.await
.unwrap();
assert_eq!(res.status(), StatusCode::CREATED);
assert!(start.elapsed() > Duration::from_secs(2));
}

View File

@@ -80,16 +80,32 @@ impl<S> Layer<S> for RateLimiterLayer {
/// A tuple of a path limit and the actual governor rate limiter.
#[derive(Debug, Clone)]
struct LimitTuple(
pub PathLimit,
pub Arc<RateLimiter<String, DashMapStateStore<String>, QuantaClock>>,
);
struct LimitTuple {
pub limit: PathLimit,
pub limiter: Arc<RateLimiter<String, DashMapStateStore<String>, QuantaClock>>,
}
impl LimitTuple {
pub fn new(path_limit: PathLimit) -> Self {
let quota: Quota = path_limit.clone().into();
let limiter = Arc::new(RateLimiter::keyed(quota));
Self(path_limit, limiter)
// Forget keys that are not used anymore. This is to prevent memory leaks.
let limiter_clone = limiter.clone();
tokio::spawn(async move {
let mut interval = tokio::time::interval(Duration::from_secs(60));
interval.tick().await;
loop {
interval.tick().await;
limiter_clone.retain_recent();
limiter_clone.shrink_to_fit();
}
});
Self {
limit: path_limit,
limiter,
}
}
/// Extract the key from the request.
@@ -97,7 +113,7 @@ impl LimitTuple {
/// The key is either the ip address of the client
/// or the user pubkey.
fn extract_key(&self, req: &Request<Body>) -> anyhow::Result<String> {
match self.0.key {
match self.limit.key {
LimitKey::Ip => extract_ip(req).map(|ip| ip.to_string()),
LimitKey::User => {
// Extract the user pubkey from the request.
@@ -111,7 +127,10 @@ impl LimitTuple {
/// Check if the request matches the limit.
pub fn is_match(&self, req: &Request<Body>) -> bool {
self.0.path.is_match(req.uri().path()) && self.0.method.0 == req.method()
let path = req.uri().path();
let glob_match = self.limit.path.is_match(path);
let method_match = self.limit.method.0 == req.method();
glob_match && method_match
}
}
@@ -259,13 +278,13 @@ where
// This should only happen if the pubky-host couldn't be extracted.
tracing::warn!(
"{} {} Failed to extract key for rate limiting: {}",
limit.0.path.0,
limit.0.method.0,
limit.limit.path.0,
limit.limit.method.0,
e
);
return Box::pin(async move {
Ok(Error::new(
StatusCode::BAD_REQUEST,
StatusCode::INTERNAL_SERVER_ERROR,
Some("Failed to extract key for rate limiting".to_string()),
)
.into_response())
@@ -273,17 +292,17 @@ where
}
};
match limit.0.quota.rate_unit {
match limit.limit.quota.rate_unit {
RateUnit::SpeedRateUnit(_) => {
// Speed limiting is enabled, so we need to throttle the upload.
req = Self::throttle_upload(req, &key, &limit.1);
req = Self::throttle_upload(req, &key, &limit.limiter);
}
RateUnit::Request => {
// Request limiting is enabled, so we need to limit the number of requests.
if let Err(e) = limit.1.check_key(&key) {
if let Err(e) = limit.limiter.check_key(&key) {
tracing::debug!(
"Rate limit of {} exceeded for {key}: {}",
limit.0.quota,
limit.limit.quota,
e
);
return Box::pin(async move {
@@ -306,7 +325,7 @@ where
let speed_limits = limits
.into_iter()
.filter(|limit| limit.0.quota.rate_unit.is_speed_rate_unit())
.filter(|limit| limit.limit.quota.rate_unit.is_speed_rate_unit())
.cloned()
.collect::<Vec<_>>();
Box::pin(async move {
@@ -318,7 +337,7 @@ where
// Rate limit the download speed.
for limit in speed_limits {
if let Ok(key) = limit.extract_key(&req_clone) {
response = Self::throttle_download(response, &key, &limit.1);
response = Self::throttle_download(response, &key, &limit.limiter);
};
}
Ok(response)

View File

@@ -45,10 +45,10 @@ pub fn create_app(state: AppState, context: &AppContext) -> Router {
.layer(CookieManagerLayer::new())
.layer(CorsLayer::very_permissive())
.layer(ServiceBuilder::new().layer(middleware::from_fn(add_server_header)))
.layer(PubkyHostLayer)
.layer(RateLimiterLayer::new(
context.config_toml.drive.rate_limits.clone(),
))
.layer(PubkyHostLayer)
.with_state(state);
// Apply trace and pubky host layers to the complete router.

View File

@@ -86,4 +86,36 @@ mod tests {
assert!(!glob_pattern.is_match("/priv/test.pdf"));
assert!(!glob_pattern.is_match("/session/test.txt"));
}
#[test]
fn test_glob_pattern2() {
let glob_pattern = GlobPattern::from_str("/events/").unwrap();
assert!(glob_pattern.is_match("/events/"));
}
#[test]
fn test_glob_pattern3() {
let glob_pattern = GlobPattern::from_str("/pub/**").unwrap();
assert!(glob_pattern.is_match("/pub/test.txt"));
assert!(glob_pattern.is_match("/pub/test/test.txt"));
assert!(glob_pattern.is_match("/pub/"));
}
#[test]
fn test_glob_pattern4() {
let glob_pattern = GlobPattern::from_str("/pub/**/update").unwrap();
assert!(glob_pattern.is_match("/pub/test.txt/update"));
assert!(glob_pattern.is_match("/pub/test/test.txt/update"));
assert!(!glob_pattern.is_match("/pub/test/test.txt"));
assert!(!glob_pattern.is_match("/pub/"));
}
#[test]
fn test_glob_pattern5() {
let glob_pattern = GlobPattern::from_str("/pub/**/update/*").unwrap();
assert!(glob_pattern.is_match("/pub/test.txt/update/test.txt"));
assert!(glob_pattern.is_match("/pub/test/test.txt/update/test.txt"));
assert!(!glob_pattern.is_match("/pub/test/test.txt"));
assert!(!glob_pattern.is_match("/pub/"));
}
}

View File

@@ -78,7 +78,7 @@ impl HomeserverSuite {
/// Returns the `https://<server public key>` url
pub fn pubky_url(&self) -> url::Url {
url::Url::parse(&format!("pubky://{}", self.public_key())).expect("valid url")
url::Url::parse(&format!("https://{}", self.public_key())).expect("valid url")
}
/// Returns the `https://<server public key>` url