Merge pull request #195 from hieblmi/rate-limiter

proxy: a configurable rate-limiter per endpoint and L402
This commit is contained in:
Slyghtning
2026-01-23 10:42:39 -05:00
committed by GitHub
11 changed files with 1158 additions and 10 deletions

View File

@@ -50,3 +50,69 @@ services and APIs.
compare with `sample-conf.yaml`.
* Start aperture without any command line parameters (`./aperture`), all configuration
is done in the `~/.aperture/aperture.yaml` file.
## Rate Limiting
Aperture supports optional per-endpoint rate limiting using a token bucket
algorithm. Rate limits are configured per service and applied based on the
client's L402 token ID for authenticated requests, or IP address for
unauthenticated requests.
### Features
* **Token bucket algorithm**: Allows controlled bursting while maintaining a
steady-state request rate.
* **Per-client isolation**: Each L402 token ID or IP address has independent
rate limit buckets.
* **Path-based rules**: Different endpoints can have different rate limits using
regular expressions.
* **Multiple rules**: All matching rules are evaluated; if any rule denies the
request, it is rejected. This allows layering global and endpoint-specific
limits.
* **Protocol-aware responses**: Returns HTTP 429 with `Retry-After` header for
REST requests, and gRPC `ResourceExhausted` status for gRPC requests.
### Configuration
Rate limits are configured in the `ratelimits` section of each service:
```yaml
services:
- name: "myservice"
hostregexp: "api.example.com"
address: "127.0.0.1:8080"
protocol: https
ratelimits:
# Global rate limit for all endpoints
- requests: 100 # Requests allowed per time window
per: 1s # Time window duration (1s, 1m, 1h, etc.)
burst: 100 # Max burst capacity (defaults to 'requests')
# Stricter limit for expensive endpoints
- pathregexp: '^/api/v1/expensive.*$'
requests: 5
per: 1m
burst: 5
```
This example configures two rate limit rules using a token bucket algorithm. Each
client gets a "bucket" of tokens that refills at the `requests/per` rate, up to the
`burst` capacity. A request consumes one token; if no tokens are available, the
request is rejected. This allows clients to make quick bursts of requests (up to
`burst`) while enforcing a steady-state rate limit over time.
1. **Global limit**: All endpoints are limited to 100 requests per second per client,
with a burst capacity of 100.
2. **Endpoint-specific limit**: Paths matching `/api/v1/expensive.*` have a stricter
limit of 5 requests per minute with a burst of 5. Since both rules are evaluated,
requests to expensive endpoints must satisfy both limits.
### Configuration Options
| Option | Description | Required |
|--------|-------------|----------|
| `pathregexp` | Regular expression to match request paths. If omitted, matches all paths. | No |
| `requests` | Number of requests allowed per time window. | Yes |
| `per` | Time window duration (e.g., `1s`, `1m`, `1h`). | Yes |
| `burst` | Maximum burst size. Defaults to `requests` if not set. | No |

View File

@@ -3,10 +3,8 @@ package freebie
import (
"net"
"net/http"
)
var (
defaultIPMask = net.IPv4Mask(0xff, 0xff, 0xff, 0x00)
"github.com/lightninglabs/aperture/netutil"
)
type Count uint16
@@ -17,7 +15,7 @@ type memStore struct {
}
func (m *memStore) getKey(ip net.IP) string {
return ip.Mask(defaultIPMask).String()
return netutil.MaskIP(ip).String()
}
func (m *memStore) currentCount(ip net.IP) Count {
@@ -38,10 +36,10 @@ func (m *memStore) TallyFreebie(r *http.Request, ip net.IP) (bool, error) {
return true, nil
}
// NewMemIPMaskStore creates a new in-memory freebie store that masks the last
// byte of an IP address to keep track of free requests. The last byte of the
// address is discarded for the mapping to reduce risk of abuse by users that
// have a whole range of IPs at their disposal.
// NewMemIPMaskStore creates a new in-memory freebie store that masks IP
// addresses to keep track of free requests. IPv4 addresses are masked to /24
// and IPv6 addresses to /48. This reduces risk of abuse by users that have a
// whole range of IPs at their disposal.
func NewMemIPMaskStore(numFreebies Count) DB {
return &memStore{
numFreebies: numFreebies,

29
netutil/ip.go Normal file
View File

@@ -0,0 +1,29 @@
package netutil
import "net"
var (
// ipv4Mask24 masks IPv4 addresses to /24 (last octet zeroed).
// This groups clients on the same subnet together.
ipv4Mask24 = net.CIDRMask(24, 32)
// ipv6Mask48 masks IPv6 addresses to /48.
// Residential connections typically receive /48 to /64 allocations,
// so /48 provides reasonable grouping for rate limiting purposes.
ipv6Mask48 = net.CIDRMask(48, 128)
)
// MaskIP returns a masked version of the IP address for grouping purposes.
// IPv4 addresses are masked to /24 (zeroing the last octet).
// IPv6 addresses are masked to /48.
//
// This is useful for rate limiting and freebie tracking where we want to
// group requests from the same network segment rather than individual IPs,
// reducing abuse potential from users with multiple addresses.
func MaskIP(ip net.IP) net.IP {
if ip4 := ip.To4(); ip4 != nil {
return ip4.Mask(ipv4Mask24)
}
return ip.Mask(ipv6Mask48)
}

117
netutil/ip_test.go Normal file
View File

@@ -0,0 +1,117 @@
package netutil
import (
"net"
"testing"
"github.com/stretchr/testify/require"
)
// TestMaskIP verifies that MaskIP correctly applies /24 masks to IPv4 and /48
// masks to IPv6 addresses.
func TestMaskIP(t *testing.T) {
	tests := []struct {
		// name identifies the sub-test.
		name string

		// input is the textual IP address fed to MaskIP.
		input string

		// expected is the string form of the masked address.
		expected string
	}{
		{
			name:     "IPv4 masks last octet",
			input:    "192.168.1.123",
			expected: "192.168.1.0",
		},
		{
			name:     "IPv4 already masked",
			input:    "10.0.0.0",
			expected: "10.0.0.0",
		},
		{
			name:     "IPv4 different last octet same result",
			input:    "192.168.1.255",
			expected: "192.168.1.0",
		},
		{
			name:     "IPv6 masks to /48",
			input:    "2001:db8:1234:5678:9abc:def0:1234:5678",
			expected: "2001:db8:1234::",
		},
		{
			name:     "IPv6 already masked",
			input:    "2001:db8:abcd::",
			expected: "2001:db8:abcd::",
		},
		{
			name:     "IPv6 loopback",
			input:    "::1",
			expected: "::",
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			ip := net.ParseIP(tc.input)
			require.NotNil(t, ip, "failed to parse input IP")

			result := MaskIP(ip)
			require.Equal(t, tc.expected, result.String())
		})
	}
}
// TestMaskIP_SameSubnetGroupsTogether verifies that IPv4 addresses in the same
// /24 subnet produce identical masked results.
func TestMaskIP_SameSubnetGroupsTogether(t *testing.T) {
	// Verify that IPs in the same /24 subnet produce the same masked
	// result. The three addresses below only differ in the last octet.
	ips := []string{
		"192.168.1.1",
		"192.168.1.100",
		"192.168.1.255",
	}

	results := make([]string, 0, len(ips))
	for _, ipStr := range ips {
		ip := net.ParseIP(ipStr)
		results = append(results, MaskIP(ip).String())
	}

	// All should be the same.
	for i := 1; i < len(results); i++ {
		require.Equal(t, results[0], results[i],
			"IPs in same /24 should have same masked result")
	}
}
// TestMaskIP_DifferentSubnetsDiffer verifies that IPv4 addresses in different
// /24 subnets produce distinct masked results.
func TestMaskIP_DifferentSubnetsDiffer(t *testing.T) {
	// Same host part, different third octet, so different /24 subnets.
	ip1 := net.ParseIP("192.168.1.100")
	ip2 := net.ParseIP("192.168.2.100")

	result1 := MaskIP(ip1).String()
	result2 := MaskIP(ip2).String()

	require.NotEqual(t, result1, result2,
		"IPs in different /24 subnets should have different masked results")
}
// TestMaskIP_IPv6SamePrefix48GroupsTogether verifies that IPv6 addresses
// sharing the same /48 prefix produce identical masked results.
func TestMaskIP_IPv6SamePrefix48GroupsTogether(t *testing.T) {
	// IPs in the same /48 should produce the same masked result. All
	// three addresses share the 2001:db8:1234::/48 prefix but differ in
	// the lower 80 bits.
	ips := []string{
		"2001:db8:1234:0001::",
		"2001:db8:1234:ffff::",
		"2001:db8:1234:abcd:1234:5678:9abc:def0",
	}

	results := make([]string, 0, len(ips))
	for _, ipStr := range ips {
		ip := net.ParseIP(ipStr)
		results = append(results, MaskIP(ip).String())
	}

	for i := 1; i < len(results); i++ {
		require.Equal(t, results[0], results[i],
			"IPs in same /48 should have same masked result")
	}
}

View File

@@ -4,12 +4,14 @@ import (
"crypto/tls"
"crypto/x509"
"fmt"
"math"
"net"
"net/http"
"net/http/httputil"
"os"
"strconv"
"strings"
"time"
"github.com/lightninglabs/aperture/auth"
"github.com/lightninglabs/aperture/l402"
@@ -172,6 +174,26 @@ func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Determine auth level required to access service and dispatch request
// accordingly.
authLevel := target.AuthRequired(r)
// checkRateLimit is a helper that checks rate limits after determining
// the authentication status. This ensures we only use L402 token IDs
// for authenticated requests, preventing DoS via garbage tokens.
checkRateLimit := func(authenticated bool) bool {
if target.rateLimiter == nil {
return true
}
key := ExtractRateLimitKey(r, remoteIP, authenticated)
allowed, retryAfter := target.rateLimiter.Allow(r, key)
if !allowed {
prefixLog.Infof("Rate limit exceeded for key %s, "+
"retry after %v", key, retryAfter)
addCorsHeaders(w.Header())
sendRateLimitResponse(w, r, retryAfter)
}
return allowed
}
skipInvoiceCreation := target.SkipInvoiceCreation(r)
switch {
case authLevel.IsOn():
@@ -215,6 +237,11 @@ func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
// User is authenticated, apply rate limit with L402 token ID.
if !checkRateLimit(true) {
return
}
case authLevel.IsFreebie():
// We only need to respect the freebie counter if the user
// is not authenticated at all.
@@ -267,6 +294,21 @@ func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
)
return
}
// Unauthenticated freebie user, rate limit by IP.
if !checkRateLimit(false) {
return
}
} else if !checkRateLimit(true) {
// Authenticated user on freebie path, rate limit by
// L402 token.
return
}
default:
// Auth is off, rate limit by IP for unauthenticated access.
if !checkRateLimit(false) {
return
}
}
@@ -486,6 +528,36 @@ func sendDirectResponse(w http.ResponseWriter, r *http.Request,
}
}
// sendRateLimitResponse tells the client it has exceeded its rate limit.
// REST clients receive 429 Too Many Requests; gRPC clients receive a
// ResourceExhausted status. Both responses carry a Retry-After header.
func sendRateLimitResponse(w http.ResponseWriter, r *http.Request,
	retryAfter time.Duration) {

	// Round the wait time up to whole seconds so clients never retry
	// before the limit actually resets, and always advertise at least
	// one second.
	retrySeconds := int(math.Ceil(retryAfter.Seconds()))
	if retrySeconds < 1 {
		retrySeconds = 1
	}

	// The Retry-After header is set for both HTTP and gRPC responses.
	w.Header().Set("Retry-After", strconv.Itoa(retrySeconds))

	// Requests with a gRPC content type get a gRPC style error, all
	// others a plain HTTP error.
	isGrpc := strings.HasPrefix(r.Header.Get(hdrContentType), hdrTypeGrpc)
	if !isGrpc {
		http.Error(
			w, "rate limit exceeded", http.StatusTooManyRequests,
		)

		return
	}

	// gRPC transports its status in headers/trailers and requires a
	// 200 OK HTTP status even for errors.
	w.Header().Set(
		hdrGrpcStatus, strconv.Itoa(int(codes.ResourceExhausted)),
	)
	w.Header().Set(hdrGrpcMessage, "rate limit exceeded")
	w.WriteHeader(http.StatusOK)
}
type trailerFixingTransport struct {
next http.RoundTripper
}

56
proxy/ratelimit_config.go Normal file
View File

@@ -0,0 +1,56 @@
package proxy
import (
"regexp"
"time"
)
// RateLimitConfig defines a rate limiting rule for a specific path pattern.
type RateLimitConfig struct {
// PathRegexp is a regular expression that matches request paths
// to which this rate limit applies. If empty, matches all paths.
PathRegexp string `long:"pathregexp" description:"Regular expression to match the path of the URL against for rate limiting"`
// Requests is the number of requests allowed per time window (Per).
Requests int `long:"requests" description:"Number of requests allowed per time window"`
// Per is the time window duration (e.g., 1s, 1m, 1h). Defaults to 1s.
Per time.Duration `long:"per" description:"Time window for rate limiting (e.g., 1s, 1m, 1h)"`
// Burst is the maximum number of requests that can be made in a burst,
// exceeding the steady-state rate. Defaults to Requests if not set.
Burst int `long:"burst" description:"Maximum burst size (defaults to Requests if not set)"`
// compiledPathRegexp is the compiled version of PathRegexp.
compiledPathRegexp *regexp.Regexp
}
// Rate returns the rate.Limit value (requests per second) for this
// configuration.
func (r *RateLimitConfig) Rate() float64 {
if r.Per == 0 {
return 0
}
return float64(r.Requests) / r.Per.Seconds()
}
// EffectiveBurst returns the burst value, defaulting to Requests if Burst
// is 0.
func (r *RateLimitConfig) EffectiveBurst() int {
if r.Burst == 0 {
return r.Requests
}
return r.Burst
}
// Matches returns true if the given path matches this rate limit's path
// pattern.
func (r *RateLimitConfig) Matches(path string) bool {
if r.compiledPathRegexp == nil {
return true // No pattern means match all
}
return r.compiledPathRegexp.MatchString(path)
}

View File

@@ -0,0 +1,52 @@
package proxy
import (
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
)
var (
	// rateLimitAllowed counts requests that passed rate limiting,
	// labeled by service name and the rule's path pattern.
	rateLimitAllowed = promauto.NewCounterVec(
		prometheus.CounterOpts{
			Namespace: "aperture",
			Subsystem: "ratelimit",
			Name:      "allowed_total",
			Help:      "Total number of requests allowed by rate limiter",
		},
		[]string{"service", "path_pattern"},
	)

	// rateLimitDenied counts requests denied by rate limiting, labeled
	// by service name and the rule's path pattern.
	rateLimitDenied = promauto.NewCounterVec(
		prometheus.CounterOpts{
			Namespace: "aperture",
			Subsystem: "ratelimit",
			Name:      "denied_total",
			Help:      "Total number of requests denied by rate limiter",
		},
		[]string{"service", "path_pattern"},
	)

	// rateLimitCacheSize tracks the current size of the rate limiter
	// cache (number of per-client limiter entries) per service.
	rateLimitCacheSize = promauto.NewGaugeVec(
		prometheus.GaugeOpts{
			Namespace: "aperture",
			Subsystem: "ratelimit",
			Name:      "cache_size",
			Help:      "Current number of entries in the rate limiter cache",
		},
		[]string{"service"},
	)

	// rateLimitEvictions counts LRU cache evictions per service.
	rateLimitEvictions = promauto.NewCounterVec(
		prometheus.CounterOpts{
			Namespace: "aperture",
			Subsystem: "ratelimit",
			Name:      "evictions_total",
			Help:      "Total number of rate limiter cache evictions",
		},
		[]string{"service"},
	)
)

252
proxy/ratelimiter.go Normal file
View File

@@ -0,0 +1,252 @@
package proxy
import (
"bytes"
"net"
"net/http"
"sync"
"time"
"github.com/lightninglabs/aperture/l402"
"github.com/lightninglabs/aperture/netutil"
"github.com/lightninglabs/neutrino/cache/lru"
"golang.org/x/time/rate"
)
const (
	// DefaultMaxCacheSize is the default maximum number of rate limiter
	// entries to keep in the LRU cache. Bounding the cache keeps memory
	// usage predictable even when many distinct clients are seen.
	DefaultMaxCacheSize = 10_000
)
// limiterKey is a composite key for the rate limiter cache. Using a struct
// instead of a concatenated string saves memory because the pathPattern field
// can reference the same underlying string across multiple keys.
type limiterKey struct {
	// clientKey identifies the client (e.g., "ip:1.2.3.4" or
	// "token:abc") as produced by ExtractRateLimitKey.
	clientKey string

	// pathPattern is the rate limit rule's PathRegexp (pointer to
	// config's string, not a copy), so each rule gets its own bucket
	// per client.
	pathPattern string
}
// limiterEntry holds a rate.Limiter. Implements cache.Value interface.
type limiterEntry struct {
	// limiter is the token bucket tracking one client under one rule.
	limiter *rate.Limiter
}

// Size implements cache.Value. Returns 1 so the LRU cache counts entries
// rather than bytes.
func (e *limiterEntry) Size() (uint64, error) {
	return 1, nil
}
// RateLimiter manages per-key rate limiters with LRU eviction.
type RateLimiter struct {
	// cacheMu protects the LRU cache which is not concurrency-safe.
	cacheMu sync.Mutex

	// configs is the list of rate limit configurations for this limiter.
	configs []*RateLimitConfig

	// cache is the LRU cache of rate limiter entries, keyed by client
	// and rule path pattern.
	cache *lru.Cache[limiterKey, *limiterEntry]

	// maxSize is the maximum number of entries in the cache.
	maxSize int

	// serviceName is used for metrics labels.
	serviceName string
}
// RateLimiterOption is a functional option for configuring a RateLimiter.
type RateLimiterOption func(*RateLimiter)

// WithMaxCacheSize sets the maximum cache size, i.e. the number of
// per-client limiter entries kept before LRU eviction kicks in.
func WithMaxCacheSize(size int) RateLimiterOption {
	return func(rl *RateLimiter) {
		rl.maxSize = size
	}
}
// NewRateLimiter creates a new RateLimiter for the given service using
// the supplied rule set. Behavior can be tweaked with functional
// options such as WithMaxCacheSize.
func NewRateLimiter(serviceName string, configs []*RateLimitConfig,
	opts ...RateLimiterOption) *RateLimiter {

	limiter := &RateLimiter{
		configs:     configs,
		maxSize:     DefaultMaxCacheSize,
		serviceName: serviceName,
	}
	for _, apply := range opts {
		apply(limiter)
	}

	// The LRU cache can only be sized once all options have been
	// applied, since an option may override maxSize.
	limiter.cache = lru.NewCache[limiterKey, *limiterEntry](
		uint64(limiter.maxSize),
	)

	return limiter
}
// Allow checks if a request should be allowed based on all matching rate
// limits. Returns (allowed, retryAfter) where retryAfter is the suggested
// duration to wait if denied.
func (rl *RateLimiter) Allow(r *http.Request, key string) (bool,
	time.Duration) {

	path := r.URL.Path

	// Collect all matching configs and their reservations. We need to check
	// all rules before consuming any tokens, so that if any rule denies we
	// can cancel all reservations.
	type ruleReservation struct {
		cfg         *RateLimitConfig
		reservation *rate.Reservation
	}
	reservations := make([]ruleReservation, 0, len(rl.configs))

	for _, cfg := range rl.configs {
		if !cfg.Matches(path) {
			continue
		}

		// Create composite key: client key + path pattern for
		// independent limiting per rule. Using a struct instead of
		// string concatenation saves memory since pathPattern
		// references the config's string.
		cacheKey := limiterKey{
			clientKey:   key,
			pathPattern: cfg.PathRegexp,
		}

		limiter := rl.getOrCreateLimiter(cacheKey, cfg)
		reservation := limiter.Reserve()
		reservations = append(reservations, ruleReservation{
			cfg:         cfg,
			reservation: reservation,
		})
	}

	// If no rules matched, allow the request.
	if len(reservations) == 0 {
		return true, 0
	}

	// Check if all reservations can proceed immediately. If any rule
	// denies, we must cancel ALL reservations to avoid consuming tokens
	// unfairly.
	var maxWait time.Duration
	allAllowed := true
	for _, rr := range reservations {
		if !rr.reservation.OK() {
			// Rate is zero or infinity, so the reservation can
			// never be fulfilled; deny with a nominal retry hint.
			allAllowed = false
			maxWait = time.Second

			break
		}

		// A positive delay means the token bucket for this rule is
		// currently empty; the request would only be admissible after
		// the delay, which we treat as a denial.
		delay := rr.reservation.Delay()
		if delay > 0 {
			allAllowed = false
			if delay > maxWait {
				maxWait = delay
			}
		}
	}

	// If any rule denied, cancel all reservations (returning the tokens
	// where possible) and return denied with the longest wait across
	// all rules.
	if !allAllowed {
		for _, rr := range reservations {
			rr.reservation.Cancel()
			rateLimitDenied.WithLabelValues(
				rl.serviceName, rr.cfg.PathRegexp,
			).Inc()
		}

		return false, maxWait
	}

	// All rules allowed - tokens are consumed, record metrics.
	for _, rr := range reservations {
		rateLimitAllowed.WithLabelValues(
			rl.serviceName, rr.cfg.PathRegexp,
		).Inc()
	}

	return true, 0
}
// getOrCreateLimiter returns the cached limiter for the given key,
// creating and caching a fresh one from the rule's parameters if none
// exists yet.
func (rl *RateLimiter) getOrCreateLimiter(key limiterKey,
	cfg *RateLimitConfig) *rate.Limiter {

	rl.cacheMu.Lock()
	defer rl.cacheMu.Unlock()

	// A cache hit also refreshes the entry's LRU position.
	if entry, err := rl.cache.Get(key); err == nil {
		return entry.limiter
	}

	// Not cached yet, build a fresh token bucket from the rule.
	limiter := rate.NewLimiter(
		rate.Limit(cfg.Rate()), cfg.EffectiveBurst(),
	)

	// Inserting may push out the least recently used entry when the
	// cache is at capacity; record that in the eviction metric.
	evicted, _ := rl.cache.Put(key, &limiterEntry{limiter: limiter})
	if evicted {
		rateLimitEvictions.WithLabelValues(rl.serviceName).Inc()
	}

	rateLimitCacheSize.WithLabelValues(rl.serviceName).Set(
		float64(rl.cache.Len()),
	)

	return limiter
}
// Size returns the current number of entries in the cache.
func (rl *RateLimiter) Size() int {
	rl.cacheMu.Lock()
	numEntries := rl.cache.Len()
	rl.cacheMu.Unlock()

	return numEntries
}
// ExtractRateLimitKey derives the bucket key used to rate limit a
// request: the L402 token ID for authenticated requests, otherwise the
// (masked) client IP address.
//
// IMPORTANT: The authenticated parameter should only be true if the L402
// token has been validated by the authenticator. Using unvalidated L402
// tokens as keys is a DoS vector since attackers can flood the cache
// with garbage tokens.
func ExtractRateLimitKey(r *http.Request, remoteIP net.IP,
	authenticated bool) string {

	// Trusting the L402 header of unauthenticated requests would let
	// attackers fill the limiter cache with made-up tokens, so the
	// token ID is only consulted after successful authentication.
	if authenticated {
		mac, _, err := l402.FromHeader(&r.Header)
		if err == nil && mac != nil {
			id, err := l402.DecodeIdentifier(
				bytes.NewBuffer(mac.Id()),
			)
			if err == nil {
				return "token:" + id.TokenID.String()
			}
		}
	}

	// Group unauthenticated clients by network segment rather than by
	// exact address.
	return "ip:" + netutil.MaskIP(remoteIP).String()
}

436
proxy/ratelimiter_test.go Normal file
View File

@@ -0,0 +1,436 @@
package proxy
import (
"fmt"
"net"
"net/http"
"net/http/httptest"
"regexp"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// TestRateLimiterBasic tests basic rate limiting functionality.
func TestRateLimiterBasic(t *testing.T) {
	// One rule: 10 requests per second with a burst of 10.
	cfg := &RateLimitConfig{
		PathRegexp: "^/api/.*$",
		Requests:   10,
		Per:        time.Second,
		Burst:      10,
	}
	cfg.compiledPathRegexp = regexp.MustCompile(cfg.PathRegexp)

	rl := NewRateLimiter("test-service", []*RateLimitConfig{cfg})

	// First 10 requests should be allowed.
	for i := 0; i < 10; i++ {
		req := httptest.NewRequest("GET", "/api/test", nil)
		allowed, _ := rl.Allow(req, "test-key")
		require.True(t, allowed, "request %d should be allowed", i)
	}

	// 11th request should be denied, and the denial must come with a
	// positive retry hint.
	req := httptest.NewRequest("GET", "/api/test", nil)
	allowed, retryAfter := rl.Allow(req, "test-key")
	require.False(t, allowed, "11th request should be denied")
	require.Greater(t, retryAfter, time.Duration(0))
}
// TestRateLimiterNoMatchingRules tests that requests pass when no rules match.
func TestRateLimiterNoMatchingRules(t *testing.T) {
	// A very strict rule (1/hour) that only covers /api paths; any
	// matching request beyond the first would be denied.
	cfg := &RateLimitConfig{
		PathRegexp: "^/api/.*$",
		Requests:   1,
		Per:        time.Hour,
		Burst:      1,
	}
	cfg.compiledPathRegexp = regexp.MustCompile(cfg.PathRegexp)

	rl := NewRateLimiter("test-service", []*RateLimitConfig{cfg})

	// Request to non-matching path should always be allowed.
	for i := 0; i < 100; i++ {
		req := httptest.NewRequest("GET", "/other/path", nil)
		allowed, _ := rl.Allow(req, "test-key")
		require.True(t, allowed, "non-matching request should be allowed")
	}
}
// TestRateLimiterLRUEviction tests that the LRU cache evicts old entries.
func TestRateLimiterLRUEviction(t *testing.T) {
	cfg := &RateLimitConfig{
		Requests: 100,
		Per:      time.Second,
		Burst:    100,
	}

	// Cap the cache at 5 entries so evictions are forced below.
	rl := NewRateLimiter(
		"test-service", []*RateLimitConfig{cfg},
		WithMaxCacheSize(5),
	)

	// Create 10 different keys, twice the cache capacity.
	for i := 0; i < 10; i++ {
		req := httptest.NewRequest("GET", "/api/test", nil)
		key := fmt.Sprintf("key-%d", i)
		rl.Allow(req, key)
	}

	// Cache should be at max size, the 5 oldest keys were evicted.
	require.Equal(t, 5, rl.Size())
}
// TestRateLimiterPathMatching tests that different path patterns have
// independent limits.
func TestRateLimiterPathMatching(t *testing.T) {
	cfgApi := &RateLimitConfig{
		PathRegexp: "^/api/.*$",
		Requests:   5,
		Per:        time.Second,
		Burst:      5,
	}
	cfgApi.compiledPathRegexp = regexp.MustCompile(cfgApi.PathRegexp)

	cfgAdmin := &RateLimitConfig{
		PathRegexp: "^/admin/.*$",
		Requests:   2,
		Per:        time.Second,
		Burst:      2,
	}
	cfgAdmin.compiledPathRegexp = regexp.MustCompile(cfgAdmin.PathRegexp)

	rl := NewRateLimiter(
		"test-service",
		[]*RateLimitConfig{cfgApi, cfgAdmin},
	)

	// API path should allow 5 requests.
	for i := 0; i < 5; i++ {
		req := httptest.NewRequest("GET", "/api/users", nil)
		allowed, _ := rl.Allow(req, "test-key")
		require.True(t, allowed)
	}

	// Admin path should allow 2 requests.
	for i := 0; i < 2; i++ {
		req := httptest.NewRequest("GET", "/admin/settings", nil)
		allowed, _ := rl.Allow(req, "test-key")
		require.True(t, allowed)
	}

	// Next admin request should be denied.
	req := httptest.NewRequest("GET", "/admin/settings", nil)
	allowed, _ := rl.Allow(req, "test-key")
	require.False(t, allowed)

	// The API bucket is exhausted independently: its burst of 5 was
	// already consumed above, so a 6th API request is denied as well.
	req = httptest.NewRequest("GET", "/api/users", nil)
	allowed, _ = rl.Allow(req, "test-key")
	require.False(t, allowed, "6th API request should be denied")
}
// TestRateLimiterMultipleRulesAllMustPass tests that all matching rules must
// pass for a request to be allowed.
func TestRateLimiterMultipleRulesAllMustPass(t *testing.T) {
	// Global rule: 100 req/sec, matches every path.
	cfgGlobal := &RateLimitConfig{
		Requests: 100,
		Per:      time.Second,
		Burst:    100,
	}

	// Specific rule: 2 req/sec for /expensive.
	cfgExpensive := &RateLimitConfig{
		PathRegexp: "^/expensive$",
		Requests:   2,
		Per:        time.Second,
		Burst:      2,
	}
	cfgExpensive.compiledPathRegexp = regexp.MustCompile(
		cfgExpensive.PathRegexp,
	)

	rl := NewRateLimiter(
		"test-service",
		[]*RateLimitConfig{cfgGlobal, cfgExpensive},
	)

	// Expensive should be limited by the stricter rule even though the
	// global rule still has plenty of tokens left.
	for i := 0; i < 2; i++ {
		req := httptest.NewRequest("GET", "/expensive", nil)
		allowed, _ := rl.Allow(req, "test-key")
		require.True(t, allowed)
	}

	req := httptest.NewRequest("GET", "/expensive", nil)
	allowed, _ := rl.Allow(req, "test-key")
	require.False(t, allowed, "should be denied by /expensive rule")
}
// TestRateLimiterPerKeyIsolation tests that different keys have independent
// rate limits.
func TestRateLimiterPerKeyIsolation(t *testing.T) {
	cfg := &RateLimitConfig{
		Requests: 2,
		Per:      time.Second,
		Burst:    2,
	}

	rl := NewRateLimiter("test-service", []*RateLimitConfig{cfg})

	// User 1 uses their quota.
	for i := 0; i < 2; i++ {
		req := httptest.NewRequest("GET", "/api/test", nil)
		allowed, _ := rl.Allow(req, "user-1")
		require.True(t, allowed)
	}

	// User 1 is now denied.
	req := httptest.NewRequest("GET", "/api/test", nil)
	allowed, _ := rl.Allow(req, "user-1")
	require.False(t, allowed)

	// User 2 should still have full quota, unaffected by user 1.
	for i := 0; i < 2; i++ {
		req := httptest.NewRequest("GET", "/api/test", nil)
		allowed, _ := rl.Allow(req, "user-2")
		require.True(t, allowed)
	}
}
// TestExtractRateLimitKeyIP tests IP-based key extraction for unauthenticated
// requests.
func TestExtractRateLimitKeyIP(t *testing.T) {
	req := httptest.NewRequest("GET", "/api/test", nil)
	ip := net.ParseIP("192.168.1.100")

	// Unauthenticated request should use masked IP (/24 for IPv4), so
	// the last octet is zeroed.
	key := ExtractRateLimitKey(req, ip, false)
	require.Equal(t, "ip:192.168.1.0", key)
}
// TestExtractRateLimitKeyIPv6 tests IPv6 key extraction.
func TestExtractRateLimitKeyIPv6(t *testing.T) {
	req := httptest.NewRequest("GET", "/api/test", nil)
	ip := net.ParseIP("2001:db8:1234:5678::1")

	// IPv6 should be masked to /48, dropping everything below the
	// first three groups.
	key := ExtractRateLimitKey(req, ip, false)
	require.Equal(t, "ip:2001:db8:1234::", key)
}
// TestExtractRateLimitKeyUnauthenticatedIgnoresL402 tests that unauthenticated
// requests fall back to IP even if L402 header is present. This prevents DoS
// attacks where garbage L402 tokens flood the cache.
func TestExtractRateLimitKeyUnauthenticatedIgnoresL402(t *testing.T) {
	req := httptest.NewRequest("GET", "/api/test", nil)

	// Add a garbage L402 header that would be present before authentication.
	req.Header.Set("Authorization", "L402 garbage:token")
	ip := net.ParseIP("192.168.1.100")

	// Even with the L402 header present, authenticated=false must make
	// the function fall back to the masked IP.
	key := ExtractRateLimitKey(req, ip, false)
	require.Equal(t, "ip:192.168.1.0", key)
}
// TestRateLimitConfigRate tests the Rate() calculation.
func TestRateLimitConfigRate(t *testing.T) {
	tests := []struct {
		// name identifies the sub-test.
		name string

		// requests and per are the rule parameters under test.
		requests int
		per      time.Duration

		// wantRate is the expected requests-per-second value.
		wantRate float64
	}{
		{
			name:     "10 per second",
			requests: 10,
			per:      time.Second,
			wantRate: 10.0,
		},
		{
			name:     "60 per minute",
			requests: 60,
			per:      time.Minute,
			wantRate: 1.0,
		},
		{
			name:     "1 per hour",
			requests: 1,
			per:      time.Hour,
			wantRate: 1.0 / 3600.0,
		},
		{
			name:     "zero per",
			requests: 10,
			per:      0,
			wantRate: 0,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			cfg := &RateLimitConfig{
				Requests: tt.requests,
				Per:      tt.per,
			}
			require.InDelta(t, tt.wantRate, cfg.Rate(), 0.0001)
		})
	}
}
// TestRateLimitConfigEffectiveBurst tests the EffectiveBurst() calculation.
func TestRateLimitConfigEffectiveBurst(t *testing.T) {
	tests := []struct {
		// name identifies the sub-test.
		name string

		// requests and burst are the rule parameters under test.
		requests int
		burst    int

		// wantBurst is the expected effective burst capacity.
		wantBurst int
	}{
		{
			name:      "explicit burst",
			requests:  10,
			burst:     20,
			wantBurst: 20,
		},
		{
			name:      "default to requests",
			requests:  10,
			burst:     0,
			wantBurst: 10,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			cfg := &RateLimitConfig{
				Requests: tt.requests,
				Burst:    tt.burst,
			}
			require.Equal(t, tt.wantBurst, cfg.EffectiveBurst())
		})
	}
}
// TestRateLimitConfigMatches tests the Matches() method.
func TestRateLimitConfigMatches(t *testing.T) {
	tests := []struct {
		// name identifies the sub-test.
		name string

		// pathRegexp is the rule's pattern (empty means match-all).
		pathRegexp string

		// path is the request path checked against the rule.
		path string

		// want is the expected match result.
		want bool
	}{
		{
			name:       "no pattern matches all",
			pathRegexp: "",
			path:       "/anything",
			want:       true,
		},
		{
			name:       "pattern matches",
			pathRegexp: "^/api/.*$",
			path:       "/api/users",
			want:       true,
		},
		{
			name:       "pattern does not match",
			pathRegexp: "^/api/.*$",
			path:       "/admin/users",
			want:       false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			cfg := &RateLimitConfig{
				PathRegexp: tt.pathRegexp,
			}

			// Mirror prepareServices: only non-empty patterns
			// are compiled.
			if tt.pathRegexp != "" {
				cfg.compiledPathRegexp = regexp.MustCompile(
					tt.pathRegexp,
				)
			}
			require.Equal(t, tt.want, cfg.Matches(tt.path))
		})
	}
}
// TestSendRateLimitResponseHTTP tests HTTP rate limit response.
func TestSendRateLimitResponseHTTP(t *testing.T) {
	w := httptest.NewRecorder()
	req := httptest.NewRequest("GET", "/api/test", nil)

	// A plain (non-gRPC) request gets a 429 with the retry hint in
	// whole seconds.
	sendRateLimitResponse(w, req, 5*time.Second)

	require.Equal(t, http.StatusTooManyRequests, w.Code)
	require.Equal(t, "5", w.Header().Get("Retry-After"))
	require.Contains(t, w.Body.String(), "rate limit exceeded")
}
// TestSendRateLimitResponseHTTPSubSecond tests that sub-second delays are
// rounded up to 1 second.
func TestSendRateLimitResponseHTTPSubSecond(t *testing.T) {
	w := httptest.NewRecorder()
	req := httptest.NewRequest("GET", "/api/test", nil)

	// 500ms is below the one-second floor of the Retry-After header.
	sendRateLimitResponse(w, req, 500*time.Millisecond)

	require.Equal(t, http.StatusTooManyRequests, w.Code)
	require.Equal(t, "1", w.Header().Get("Retry-After"))
}
// TestSendRateLimitResponseHTTPRoundUp tests that fractional seconds are
// rounded up, not down. This ensures clients don't retry before the limit
// actually resets.
func TestSendRateLimitResponseHTTPRoundUp(t *testing.T) {
	w := httptest.NewRecorder()
	req := httptest.NewRequest("GET", "/api/test", nil)

	// 1.1 seconds should round up to 2 seconds, not down to 1.
	sendRateLimitResponse(w, req, 1100*time.Millisecond)

	require.Equal(t, http.StatusTooManyRequests, w.Code)
	require.Equal(t, "2", w.Header().Get("Retry-After"))
}
// TestSendRateLimitResponseGRPC tests gRPC rate limit response.
func TestSendRateLimitResponseGRPC(t *testing.T) {
	w := httptest.NewRecorder()
	req := httptest.NewRequest("POST", "/grpc.Service/Method", nil)

	// The gRPC content type switches the response to gRPC-style
	// error reporting via headers.
	req.Header.Set("Content-Type", "application/grpc")

	sendRateLimitResponse(w, req, 5*time.Second)

	require.Equal(t, http.StatusOK, w.Code) // gRPC always returns 200.
	require.Equal(t, "5", w.Header().Get("Retry-After"))
	require.Equal(t, "8", w.Header().Get("Grpc-Status")) // ResourceExhausted.
	require.Equal(t, "rate limit exceeded", w.Header().Get("Grpc-Message"))
}
// TestRateLimiterTokenRefill tests that tokens refill over time.
func TestRateLimiterTokenRefill(t *testing.T) {
	// 10 requests per 100ms means one token roughly every 10ms, with
	// only a single token of burst capacity.
	cfg := &RateLimitConfig{
		Requests: 10,
		Per:      100 * time.Millisecond, // Fast refill for testing.
		Burst:    1,
	}

	rl := NewRateLimiter("test-service", []*RateLimitConfig{cfg})

	// Use the one available token.
	req := httptest.NewRequest("GET", "/api/test", nil)
	allowed, _ := rl.Allow(req, "test-key")
	require.True(t, allowed)

	// Immediate second request should be denied.
	allowed, _ = rl.Allow(req, "test-key")
	require.False(t, allowed)

	// Wait for refill; 15ms comfortably covers the ~10ms per-token
	// refill interval.
	time.Sleep(15 * time.Millisecond)

	// Should have a token now.
	allowed, _ = rl.Allow(req, "test-key")
	require.True(t, allowed)
}

View File

@@ -110,6 +110,12 @@ type Service struct {
// request, but still try to do the l402 authentication.
AuthSkipInvoiceCreationPaths []string `long:"authskipinvoicecreationpaths" description:"List of regular expressions for paths that will skip invoice creation'"`
// RateLimits is an optional list of rate-limiting rules for this
// service. Each rule specifies a path pattern and rate limit
// parameters. All matching rules are evaluated; if any rule denies
// the request, it is rejected.
RateLimits []*RateLimitConfig `long:"ratelimits" description:"List of rate limiting rules for this service"`
// compiledHostRegexp is the compiled host regex.
compiledHostRegexp *regexp.Regexp
@@ -123,8 +129,9 @@ type Service struct {
// invoice creation paths.
compiledAuthSkipInvoiceCreationPaths []*regexp.Regexp
freebieDB freebie.DB
pricer pricer.Pricer
freebieDB freebie.DB
pricer pricer.Pricer
rateLimiter *RateLimiter
}
// ResourceName returns the string to be used to identify which resource a
@@ -275,6 +282,47 @@ func prepareServices(services []*Service) error {
)
}
// Validate and compile rate limit configurations.
if len(service.RateLimits) > 0 {
for i, rl := range service.RateLimits {
// Validate required fields.
if rl.Requests <= 0 {
return fmt.Errorf("service %s rate "+
"limit %d: requests must be "+
"positive", service.Name, i)
}
if rl.Per <= 0 {
return fmt.Errorf("service %s rate "+
"limit %d: per duration must "+
"be positive", service.Name, i)
}
// Compile path regex if provided.
if rl.PathRegexp != "" {
compiled, err := regexp.Compile(
rl.PathRegexp,
)
if err != nil {
return fmt.Errorf("service %s "+
"rate limit %d: error "+
"compiling path regex: "+
"%w", service.Name, i,
err)
}
rl.compiledPathRegexp = compiled
}
}
// Create the rate limiter for this service.
service.rateLimiter = NewRateLimiter(
service.Name, service.RateLimits,
)
log.Infof("Initialized rate limiter for service %s "+
"with %d rules", service.Name,
len(service.RateLimits))
}
// If dynamic prices are enabled then use the provided
// DynamicPrice options to initialise a gRPC backed
// pricer client.

View File

@@ -195,6 +195,28 @@ services:
authskipinvoicecreationpaths:
- '^/streamingservice.*$'
# Optional per-endpoint rate limits using a token bucket algorithm.
# Rate limiting is applied per L402 token ID (or IP address for
# unauthenticated requests). All matching rules are evaluated; if any
# rule denies the request, it is rejected.
ratelimits:
# Rate limit for general API endpoints.
- pathregexp: '^/looprpc.SwapServer/LoopOutTerms.*$'
# Number of requests allowed per time window. Must be provided and
# positive.
requests: 5
# Time window duration (e.g., 1s, 1m, 1h). Must be provided and
# positive.
per: 1s
# Maximum burst capacity. Defaults to 'requests' if not set.
burst: 100
# Stricter rate limit for quote endpoints.
- pathregexp: '^/looprpc.SwapServer/LoopOutQuote.*$'
requests: 2
per: 1s
burst: 2
# Options to use for connection to the price serving gRPC server.
dynamicprice:
# Whether or not a gRPC server is available to query price data from. If