Files
aperture/proxy/ratelimiter.go
2026-01-23 10:12:13 -05:00

253 lines
6.6 KiB
Go

package proxy
import (
"bytes"
"net"
"net/http"
"sync"
"time"
"github.com/lightninglabs/aperture/l402"
"github.com/lightninglabs/aperture/netutil"
"github.com/lightninglabs/neutrino/cache/lru"
"golang.org/x/time/rate"
)
const (
// DefaultMaxCacheSize is the default maximum number of rate limiter
// entries to keep in the LRU cache.
DefaultMaxCacheSize = 10_000
)
// limiterKey is a composite key for the rate limiter cache. Using a struct
// instead of a concatenated string saves memory because the pathPattern field
// can reference the same underlying string across multiple keys.
type limiterKey struct {
// clientKey identifies the client (e.g., "ip:1.2.3.4" or "token:abc").
clientKey string
// pathPattern is the rate limit rule's PathRegexp (pointer to config's
// string, not a copy).
pathPattern string
}
// limiterEntry holds a rate.Limiter. Implements cache.Value interface.
type limiterEntry struct {
limiter *rate.Limiter
}
// Size implements cache.Value. Returns 1 so the LRU cache counts entries
// rather than bytes.
func (e *limiterEntry) Size() (uint64, error) {
return 1, nil
}
// RateLimiter manages per-key rate limiters with LRU eviction.
type RateLimiter struct {
// cacheMu protects the LRU cache which is not concurrency-safe.
cacheMu sync.Mutex
// configs is the list of rate limit configurations for this limiter.
configs []*RateLimitConfig
// cache is the LRU cache of rate limiter entries.
cache *lru.Cache[limiterKey, *limiterEntry]
// maxSize is the maximum number of entries in the cache.
maxSize int
// serviceName is used for metrics labels.
serviceName string
}
// RateLimiterOption is a functional option for configuring a RateLimiter.
type RateLimiterOption func(*RateLimiter)
// WithMaxCacheSize sets the maximum cache size.
func WithMaxCacheSize(size int) RateLimiterOption {
return func(rl *RateLimiter) {
rl.maxSize = size
}
}
// NewRateLimiter creates a new RateLimiter with the given configurations.
func NewRateLimiter(serviceName string, configs []*RateLimitConfig,
opts ...RateLimiterOption) *RateLimiter {
rl := &RateLimiter{
configs: configs,
maxSize: DefaultMaxCacheSize,
serviceName: serviceName,
}
for _, opt := range opts {
opt(rl)
}
// Initialize the LRU cache with the configured max size.
rl.cache = lru.NewCache[limiterKey, *limiterEntry](uint64(rl.maxSize))
return rl
}
// Allow checks if a request should be allowed based on all matching rate
// limits. Returns (allowed, retryAfter) where retryAfter is the suggested
// duration to wait if denied.
func (rl *RateLimiter) Allow(r *http.Request, key string) (bool,
time.Duration) {
path := r.URL.Path
// Collect all matching configs and their reservations. We need to check
// all rules before consuming any tokens, so that if any rule denies we
// can cancel all reservations.
type ruleReservation struct {
cfg *RateLimitConfig
reservation *rate.Reservation
}
reservations := make([]ruleReservation, 0, len(rl.configs))
for _, cfg := range rl.configs {
if !cfg.Matches(path) {
continue
}
// Create composite key: client key + path pattern for
// independent limiting per rule. Using a struct instead of
// string concatenation saves memory since pathPattern
// references the config's string.
cacheKey := limiterKey{
clientKey: key,
pathPattern: cfg.PathRegexp,
}
limiter := rl.getOrCreateLimiter(cacheKey, cfg)
reservation := limiter.Reserve()
reservations = append(reservations, ruleReservation{
cfg: cfg,
reservation: reservation,
})
}
// If no rules matched, allow the request.
if len(reservations) == 0 {
return true, 0
}
// Check if all reservations can proceed immediately. If any rule
// denies, we must cancel ALL reservations to avoid consuming tokens
// unfairly.
var maxWait time.Duration
allAllowed := true
for _, rr := range reservations {
if !rr.reservation.OK() {
// Rate is zero or infinity.
allAllowed = false
maxWait = time.Second
break
}
delay := rr.reservation.Delay()
if delay > 0 {
allAllowed = false
if delay > maxWait {
maxWait = delay
}
}
}
// If any rule denied, cancel all reservations and return denied.
if !allAllowed {
for _, rr := range reservations {
rr.reservation.Cancel()
rateLimitDenied.WithLabelValues(
rl.serviceName, rr.cfg.PathRegexp,
).Inc()
}
return false, maxWait
}
// All rules allowed - tokens are consumed, record metrics.
for _, rr := range reservations {
rateLimitAllowed.WithLabelValues(
rl.serviceName, rr.cfg.PathRegexp,
).Inc()
}
return true, 0
}
// getOrCreateLimiter retrieves an existing limiter or creates a new one.
func (rl *RateLimiter) getOrCreateLimiter(key limiterKey,
cfg *RateLimitConfig) *rate.Limiter {
rl.cacheMu.Lock()
defer rl.cacheMu.Unlock()
// Try to get existing entry from cache (also updates LRU order).
if entry, err := rl.cache.Get(key); err == nil {
return entry.limiter
}
// Create a new limiter.
limiter := rate.NewLimiter(
rate.Limit(cfg.Rate()), cfg.EffectiveBurst(),
)
entry := &limiterEntry{
limiter: limiter,
}
// Put handles eviction automatically when cache is full.
evicted, _ := rl.cache.Put(key, entry)
if evicted {
rateLimitEvictions.WithLabelValues(rl.serviceName).Inc()
}
rateLimitCacheSize.WithLabelValues(rl.serviceName).Set(
float64(rl.cache.Len()),
)
return limiter
}
// Size returns the current number of entries in the cache.
func (rl *RateLimiter) Size() int {
rl.cacheMu.Lock()
defer rl.cacheMu.Unlock()
return rl.cache.Len()
}
// ExtractRateLimitKey extracts the rate-limiting key from a request.
// For authenticated requests, it uses the L402 token ID. For unauthenticated
// requests, it falls back to the client IP address.
//
// IMPORTANT: The authenticated parameter should only be true if the L402 token
// has been validated by the authenticator. Using unvalidated L402 tokens as
// keys is a DoS vector since attackers can flood the cache with garbage tokens.
func ExtractRateLimitKey(r *http.Request, remoteIP net.IP,
authenticated bool) string {
// Only use L402 token ID if the request has been authenticated.
// This prevents DoS attacks where garbage L402 tokens flood the cache.
if authenticated {
mac, _, err := l402.FromHeader(&r.Header)
if err == nil && mac != nil {
identifier, err := l402.DecodeIdentifier(
bytes.NewBuffer(mac.Id()),
)
if err == nil {
return "token:" + identifier.TokenID.String()
}
}
}
// Fall back to IP address for unauthenticated requests.
// Mask the IP to group clients from the same network segment.
return "ip:" + netutil.MaskIP(remoteIP).String()
}