diff --git a/README.md b/README.md index 8ed056a..497a8a2 100644 --- a/README.md +++ b/README.md @@ -50,3 +50,69 @@ services and APIs. compare with `sample-conf.yaml`. * Start aperture without any command line parameters (`./aperture`), all configuration is done in the `~/.aperture/aperture.yaml` file. + +## Rate Limiting + +Aperture supports optional per-endpoint rate limiting using a token bucket +algorithm. Rate limits are configured per service and applied based on the +client's L402 token ID for authenticated requests, or IP address for +unauthenticated requests. + +### Features + +* **Token bucket algorithm**: Allows controlled bursting while maintaining a + steady-state request rate. +* **Per-client isolation**: Each L402 token ID or IP address has independent + rate limit buckets. +* **Path-based rules**: Different endpoints can have different rate limits using + regular expressions. +* **Multiple rules**: All matching rules are evaluated; if any rule denies the + request, it is rejected. This allows layering global and endpoint-specific + limits. +* **Protocol-aware responses**: Returns HTTP 429 with `Retry-After` header for + REST requests, and gRPC `ResourceExhausted` status for gRPC requests. + +### Configuration + +Rate limits are configured in the `ratelimits` section of each service: + +```yaml +services: + - name: "myservice" + hostregexp: "api.example.com" + address: "127.0.0.1:8080" + protocol: https + + ratelimits: + # Global rate limit for all endpoints + - requests: 100 # Requests allowed per time window + per: 1s # Time window duration (1s, 1m, 1h, etc.) + burst: 100 # Max burst capacity (defaults to 'requests') + + # Stricter limit for expensive endpoints + - pathregexp: '^/api/v1/expensive.*$' + requests: 5 + per: 1m + burst: 5 +``` + +This example configures two rate limit rules using a token bucket algorithm. Each +client gets a "bucket" of tokens that refills at the `requests/per` rate, up to the +`burst` capacity. A request consumes one token; if no tokens are available, the +request is rejected. This allows clients to make quick bursts of requests (up to +`burst`) while enforcing a steady-state rate limit over time. + +1. **Global limit**: All endpoints are limited to 100 requests per second per client, + with a burst capacity of 100. +2. **Endpoint-specific limit**: Paths matching `/api/v1/expensive.*` have a stricter + limit of 5 requests per minute with a burst of 5. Since both rules are evaluated, + requests to expensive endpoints must satisfy both limits. + +### Configuration Options + +| Option | Description | Required | +|--------|-------------|----------| +| `pathregexp` | Regular expression to match request paths. If omitted, matches all paths. | No | +| `requests` | Number of requests allowed per time window. | Yes | +| `per` | Time window duration (e.g., `1s`, `1m`, `1h`). | Yes | +| `burst` | Maximum burst size. Defaults to `requests` if not set. | No | diff --git a/freebie/mem_store.go b/freebie/mem_store.go index 4a97b98..848404a 100644 --- a/freebie/mem_store.go +++ b/freebie/mem_store.go @@ -3,10 +3,8 @@ package freebie import ( "net" "net/http" -) -var ( - defaultIPMask = net.IPv4Mask(0xff, 0xff, 0xff, 0x00) + "github.com/lightninglabs/aperture/netutil" ) type Count uint16 @@ -17,7 +15,7 @@ type memStore struct { } func (m *memStore) getKey(ip net.IP) string { - return ip.Mask(defaultIPMask).String() + return netutil.MaskIP(ip).String() } func (m *memStore) currentCount(ip net.IP) Count { @@ -38,10 +36,10 @@ func (m *memStore) TallyFreebie(r *http.Request, ip net.IP) (bool, error) { return true, nil } -// NewMemIPMaskStore creates a new in-memory freebie store that masks the last -// byte of an IP address to keep track of free requests. The last byte of the -// address is discarded for the mapping to reduce risk of abuse by users that -// have a whole range of IPs at their disposal. +// NewMemIPMaskStore creates a new in-memory freebie store that masks IP +// addresses to keep track of free requests. IPv4 addresses are masked to /24 +// and IPv6 addresses to /48. This reduces risk of abuse by users that have a +// whole range of IPs at their disposal. func NewMemIPMaskStore(numFreebies Count) DB { return &memStore{ numFreebies: numFreebies, diff --git a/netutil/ip.go b/netutil/ip.go new file mode 100644 index 0000000..d2ede7f --- /dev/null +++ b/netutil/ip.go @@ -0,0 +1,29 @@ +package netutil + +import "net" + +var ( + // ipv4Mask24 masks IPv4 addresses to /24 (last octet zeroed). + // This groups clients on the same subnet together. + ipv4Mask24 = net.CIDRMask(24, 32) + + // ipv6Mask48 masks IPv6 addresses to /48. + // Residential connections typically receive /48 to /64 allocations, + // so /48 provides reasonable grouping for rate limiting purposes. + ipv6Mask48 = net.CIDRMask(48, 128) +) + +// MaskIP returns a masked version of the IP address for grouping purposes. +// IPv4 addresses are masked to /24 (zeroing the last octet). +// IPv6 addresses are masked to /48. +// +// This is useful for rate limiting and freebie tracking where we want to +// group requests from the same network segment rather than individual IPs, +// reducing abuse potential from users with multiple addresses. +func MaskIP(ip net.IP) net.IP { + if ip4 := ip.To4(); ip4 != nil { + return ip4.Mask(ipv4Mask24) + } + + return ip.Mask(ipv6Mask48) +} diff --git a/netutil/ip_test.go b/netutil/ip_test.go new file mode 100644 index 0000000..6c714d0 --- /dev/null +++ b/netutil/ip_test.go @@ -0,0 +1,117 @@ +package netutil + +import ( + "net" + "testing" + + "github.com/stretchr/testify/require" +) + +// TestMaskIP verifies that MaskIP correctly applies /24 masks to IPv4 and /48 +// masks to IPv6 addresses. +func TestMaskIP(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "IPv4 masks last octet", + input: "192.168.1.123", + expected: "192.168.1.0", + }, + { + name: "IPv4 already masked", + input: "10.0.0.0", + expected: "10.0.0.0", + }, + { + name: "IPv4 different last octet same result", + input: "192.168.1.255", + expected: "192.168.1.0", + }, + { + name: "IPv6 masks to /48", + input: "2001:db8:1234:5678:9abc:def0:1234:5678", + expected: "2001:db8:1234::", + }, + { + name: "IPv6 already masked", + input: "2001:db8:abcd::", + expected: "2001:db8:abcd::", + }, + { + name: "IPv6 loopback", + input: "::1", + expected: "::", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ip := net.ParseIP(tc.input) + require.NotNil(t, ip, "failed to parse input IP") + + result := MaskIP(ip) + require.Equal(t, tc.expected, result.String()) + }) + } +} + +// TestMaskIP_SameSubnetGroupsTogether verifies that IPv4 addresses in the same +// /24 subnet produce identical masked results. +func TestMaskIP_SameSubnetGroupsTogether(t *testing.T) { + // Verify that IPs in the same /24 subnet produce the same masked result. + ips := []string{ + "192.168.1.1", + "192.168.1.100", + "192.168.1.255", + } + + results := make([]string, 0, len(ips)) + for _, ipStr := range ips { + ip := net.ParseIP(ipStr) + results = append(results, MaskIP(ip).String()) + } + + // All should be the same. + for i := 1; i < len(results); i++ { + require.Equal(t, results[0], results[i], + "IPs in same /24 should have same masked result") + } +} + +// TestMaskIP_DifferentSubnetsDiffer verifies that IPv4 addresses in different +// /24 subnets produce distinct masked results. +func TestMaskIP_DifferentSubnetsDiffer(t *testing.T) { + ip1 := net.ParseIP("192.168.1.100") + ip2 := net.ParseIP("192.168.2.100") + + result1 := MaskIP(ip1).String() + result2 := MaskIP(ip2).String() + + require.NotEqual(t, result1, result2, + "IPs in different /24 subnets should have different masked results") +} + +// TestMaskIP_IPv6SamePrefix48GroupsTogether verifies that IPv6 addresses +// sharing the same /48 prefix produce identical masked results. +func TestMaskIP_IPv6SamePrefix48GroupsTogether(t *testing.T) { + // IPs in the same /48 should produce the same masked result. + ips := []string{ + "2001:db8:1234:0001::", + "2001:db8:1234:ffff::", + "2001:db8:1234:abcd:1234:5678:9abc:def0", + } + + results := make([]string, 0, len(ips)) + for _, ipStr := range ips { + ip := net.ParseIP(ipStr) + results = append(results, MaskIP(ip).String()) + } + + for i := 1; i < len(results); i++ { + require.Equal(t, results[0], results[i], + "IPs in same /48 should have same masked result") + } +} diff --git a/proxy/proxy.go b/proxy/proxy.go index 7dfb5d4..ff47038 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -4,12 +4,14 @@ import ( "crypto/tls" "crypto/x509" "fmt" + "math" "net" "net/http" "net/http/httputil" "os" "strconv" "strings" + "time" "github.com/lightninglabs/aperture/auth" "github.com/lightninglabs/aperture/l402" @@ -172,6 +174,26 @@ func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Determine auth level required to access service and dispatch request // accordingly. authLevel := target.AuthRequired(r) + + // checkRateLimit is a helper that checks rate limits after determining + // the authentication status. This ensures we only use L402 token IDs + // for authenticated requests, preventing DoS via garbage tokens. + checkRateLimit := func(authenticated bool) bool { + if target.rateLimiter == nil { + return true + } + key := ExtractRateLimitKey(r, remoteIP, authenticated) + allowed, retryAfter := target.rateLimiter.Allow(r, key) + if !allowed { + prefixLog.Infof("Rate limit exceeded for key %s, "+ + "retry after %v", key, retryAfter) + addCorsHeaders(w.Header()) + sendRateLimitResponse(w, r, retryAfter) + } + + return allowed + } + skipInvoiceCreation := target.SkipInvoiceCreation(r) switch { case authLevel.IsOn(): @@ -215,6 +237,11 @@ func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } + // User is authenticated, apply rate limit with L402 token ID. + if !checkRateLimit(true) { + return + } + case authLevel.IsFreebie(): // We only need to respect the freebie counter if the user // is not authenticated at all. @@ -267,6 +294,21 @@ func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { ) return } + + // Unauthenticated freebie user, rate limit by IP. + if !checkRateLimit(false) { + return + } + } else if !checkRateLimit(true) { + // Authenticated user on freebie path, rate limit by + // L402 token. + return + } + + default: + // Auth is off, rate limit by IP for unauthenticated access. + if !checkRateLimit(false) { + return } } @@ -486,6 +528,36 @@ func sendDirectResponse(w http.ResponseWriter, r *http.Request, } } +// sendRateLimitResponse sends a rate limit exceeded response to the client. +// For HTTP clients, it returns 429 Too Many Requests with Retry-After header. +// For gRPC clients, it returns a ResourceExhausted status. +func sendRateLimitResponse(w http.ResponseWriter, r *http.Request, + retryAfter time.Duration) { + + // Round up to ensure clients don't retry before the limit resets. + retrySeconds := int(math.Ceil(retryAfter.Seconds())) + if retrySeconds < 1 { + retrySeconds = 1 + } + + // Set Retry-After header for both HTTP and gRPC. + w.Header().Set("Retry-After", strconv.Itoa(retrySeconds)) + + // Check if this is a gRPC request. + if strings.HasPrefix(r.Header.Get(hdrContentType), hdrTypeGrpc) { + w.Header().Set( + hdrGrpcStatus, + strconv.Itoa(int(codes.ResourceExhausted)), + ) + w.Header().Set(hdrGrpcMessage, "rate limit exceeded") + + // gRPC requires 200 OK even for errors. + w.WriteHeader(http.StatusOK) + } else { + http.Error(w, "rate limit exceeded", http.StatusTooManyRequests) + } +} + type trailerFixingTransport struct { next http.RoundTripper } diff --git a/proxy/ratelimit_config.go b/proxy/ratelimit_config.go new file mode 100644 index 0000000..953b0db --- /dev/null +++ b/proxy/ratelimit_config.go @@ -0,0 +1,56 @@ +package proxy + +import ( + "regexp" + "time" +) + +// RateLimitConfig defines a rate limiting rule for a specific path pattern. +type RateLimitConfig struct { + // PathRegexp is a regular expression that matches request paths + // to which this rate limit applies. If empty, matches all paths. + PathRegexp string `long:"pathregexp" description:"Regular expression to match the path of the URL against for rate limiting"` + + // Requests is the number of requests allowed per time window (Per). + Requests int `long:"requests" description:"Number of requests allowed per time window"` + + // Per is the time window duration (e.g., 1s, 1m, 1h). Defaults to 1s. + Per time.Duration `long:"per" description:"Time window for rate limiting (e.g., 1s, 1m, 1h)"` + + // Burst is the maximum number of requests that can be made in a burst, + // exceeding the steady-state rate. Defaults to Requests if not set. + Burst int `long:"burst" description:"Maximum burst size (defaults to Requests if not set)"` + + // compiledPathRegexp is the compiled version of PathRegexp. + compiledPathRegexp *regexp.Regexp +} + +// Rate returns the rate.Limit value (requests per second) for this +// configuration. +func (r *RateLimitConfig) Rate() float64 { + if r.Per == 0 { + return 0 + } + + return float64(r.Requests) / r.Per.Seconds() +} + +// EffectiveBurst returns the burst value, defaulting to Requests if Burst +// is 0. +func (r *RateLimitConfig) EffectiveBurst() int { + if r.Burst == 0 { + return r.Requests + } + + return r.Burst +} + +// Matches returns true if the given path matches this rate limit's path +// pattern. +func (r *RateLimitConfig) Matches(path string) bool { + if r.compiledPathRegexp == nil { + return true // No pattern means match all + } + + return r.compiledPathRegexp.MatchString(path) +} diff --git a/proxy/ratelimit_metrics.go b/proxy/ratelimit_metrics.go new file mode 100644 index 0000000..2708eff --- /dev/null +++ b/proxy/ratelimit_metrics.go @@ -0,0 +1,52 @@ +package proxy + +import ( + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" +) + +var ( + // rateLimitAllowed counts requests that passed rate limiting. + rateLimitAllowed = promauto.NewCounterVec( + prometheus.CounterOpts{ + Namespace: "aperture", + Subsystem: "ratelimit", + Name: "allowed_total", + Help: "Total number of requests allowed by rate limiter", + }, + []string{"service", "path_pattern"}, + ) + + // rateLimitDenied counts requests denied by rate limiting. + rateLimitDenied = promauto.NewCounterVec( + prometheus.CounterOpts{ + Namespace: "aperture", + Subsystem: "ratelimit", + Name: "denied_total", + Help: "Total number of requests denied by rate limiter", + }, + []string{"service", "path_pattern"}, + ) + + // rateLimitCacheSize tracks the current size of the rate limiter cache. + rateLimitCacheSize = promauto.NewGaugeVec( + prometheus.GaugeOpts{ + Namespace: "aperture", + Subsystem: "ratelimit", + Name: "cache_size", + Help: "Current number of entries in the rate limiter cache", + }, + []string{"service"}, + ) + + // rateLimitEvictions counts LRU cache evictions. + rateLimitEvictions = promauto.NewCounterVec( + prometheus.CounterOpts{ + Namespace: "aperture", + Subsystem: "ratelimit", + Name: "evictions_total", + Help: "Total number of rate limiter cache evictions", + }, + []string{"service"}, + ) +) diff --git a/proxy/ratelimiter.go b/proxy/ratelimiter.go new file mode 100644 index 0000000..ad837f9 --- /dev/null +++ b/proxy/ratelimiter.go @@ -0,0 +1,252 @@ +package proxy + +import ( + "bytes" + "net" + "net/http" + "sync" + "time" + + "github.com/lightninglabs/aperture/l402" + "github.com/lightninglabs/aperture/netutil" + "github.com/lightninglabs/neutrino/cache/lru" + "golang.org/x/time/rate" +) + +const ( + // DefaultMaxCacheSize is the default maximum number of rate limiter + // entries to keep in the LRU cache. + DefaultMaxCacheSize = 10_000 +) + +// limiterKey is a composite key for the rate limiter cache. Using a struct +// instead of a concatenated string saves memory because the pathPattern field +// can reference the same underlying string across multiple keys. +type limiterKey struct { + // clientKey identifies the client (e.g., "ip:1.2.3.4" or "token:abc"). + clientKey string + // pathPattern is the rate limit rule's PathRegexp (pointer to config's + // string, not a copy). + pathPattern string +} + +// limiterEntry holds a rate.Limiter. Implements cache.Value interface. +type limiterEntry struct { + limiter *rate.Limiter +} + +// Size implements cache.Value. Returns 1 so the LRU cache counts entries +// rather than bytes. +func (e *limiterEntry) Size() (uint64, error) { + return 1, nil +} + +// RateLimiter manages per-key rate limiters with LRU eviction. +type RateLimiter struct { + // cacheMu protects the LRU cache which is not concurrency-safe. + cacheMu sync.Mutex + + // configs is the list of rate limit configurations for this limiter. + configs []*RateLimitConfig + + // cache is the LRU cache of rate limiter entries. + cache *lru.Cache[limiterKey, *limiterEntry] + + // maxSize is the maximum number of entries in the cache. + maxSize int + + // serviceName is used for metrics labels. + serviceName string +} + +// RateLimiterOption is a functional option for configuring a RateLimiter. +type RateLimiterOption func(*RateLimiter) + +// WithMaxCacheSize sets the maximum cache size. +func WithMaxCacheSize(size int) RateLimiterOption { + return func(rl *RateLimiter) { + rl.maxSize = size + } +} + +// NewRateLimiter creates a new RateLimiter with the given configurations. +func NewRateLimiter(serviceName string, configs []*RateLimitConfig, + opts ...RateLimiterOption) *RateLimiter { + + rl := &RateLimiter{ + configs: configs, + maxSize: DefaultMaxCacheSize, + serviceName: serviceName, + } + + for _, opt := range opts { + opt(rl) + } + + // Initialize the LRU cache with the configured max size. + rl.cache = lru.NewCache[limiterKey, *limiterEntry](uint64(rl.maxSize)) + + return rl +} + +// Allow checks if a request should be allowed based on all matching rate +// limits. Returns (allowed, retryAfter) where retryAfter is the suggested +// duration to wait if denied. +func (rl *RateLimiter) Allow(r *http.Request, key string) (bool, + time.Duration) { + + path := r.URL.Path + + // Collect all matching configs and their reservations. We need to check + // all rules before consuming any tokens, so that if any rule denies we + // can cancel all reservations. + type ruleReservation struct { + cfg *RateLimitConfig + reservation *rate.Reservation + } + reservations := make([]ruleReservation, 0, len(rl.configs)) + + for _, cfg := range rl.configs { + if !cfg.Matches(path) { + continue + } + + // Create composite key: client key + path pattern for + // independent limiting per rule. Using a struct instead of + // string concatenation saves memory since pathPattern + // references the config's string. + cacheKey := limiterKey{ + clientKey: key, + pathPattern: cfg.PathRegexp, + } + + limiter := rl.getOrCreateLimiter(cacheKey, cfg) + reservation := limiter.Reserve() + + reservations = append(reservations, ruleReservation{ + cfg: cfg, + reservation: reservation, + }) + } + + // If no rules matched, allow the request. + if len(reservations) == 0 { + return true, 0 + } + + // Check if all reservations can proceed immediately. If any rule + // denies, we must cancel ALL reservations to avoid consuming tokens + // unfairly. + var maxWait time.Duration + allAllowed := true + + for _, rr := range reservations { + if !rr.reservation.OK() { + // Rate is zero or infinity. + allAllowed = false + maxWait = time.Second + + break + } + + delay := rr.reservation.Delay() + if delay > 0 { + allAllowed = false + if delay > maxWait { + maxWait = delay + } + } + } + + // If any rule denied, cancel all reservations and return denied. + if !allAllowed { + for _, rr := range reservations { + rr.reservation.Cancel() + rateLimitDenied.WithLabelValues( + rl.serviceName, rr.cfg.PathRegexp, + ).Inc() + } + + return false, maxWait + } + + // All rules allowed - tokens are consumed, record metrics. + for _, rr := range reservations { + rateLimitAllowed.WithLabelValues( + rl.serviceName, rr.cfg.PathRegexp, + ).Inc() + } + + return true, 0 +} + +// getOrCreateLimiter retrieves an existing limiter or creates a new one. +func (rl *RateLimiter) getOrCreateLimiter(key limiterKey, + cfg *RateLimitConfig) *rate.Limiter { + + rl.cacheMu.Lock() + defer rl.cacheMu.Unlock() + + // Try to get existing entry from cache (also updates LRU order). + if entry, err := rl.cache.Get(key); err == nil { + return entry.limiter + } + + // Create a new limiter. + limiter := rate.NewLimiter( + rate.Limit(cfg.Rate()), cfg.EffectiveBurst(), + ) + + entry := &limiterEntry{ + limiter: limiter, + } + + // Put handles eviction automatically when cache is full. + evicted, _ := rl.cache.Put(key, entry) + if evicted { + rateLimitEvictions.WithLabelValues(rl.serviceName).Inc() + } + + rateLimitCacheSize.WithLabelValues(rl.serviceName).Set( + float64(rl.cache.Len()), + ) + + return limiter +} + +// Size returns the current number of entries in the cache. +func (rl *RateLimiter) Size() int { + rl.cacheMu.Lock() + defer rl.cacheMu.Unlock() + + return rl.cache.Len() +} + +// ExtractRateLimitKey extracts the rate-limiting key from a request. +// For authenticated requests, it uses the L402 token ID. For unauthenticated +// requests, it falls back to the client IP address. +// +// IMPORTANT: The authenticated parameter should only be true if the L402 token +// has been validated by the authenticator. Using unvalidated L402 tokens as +// keys is a DoS vector since attackers can flood the cache with garbage tokens. +func ExtractRateLimitKey(r *http.Request, remoteIP net.IP, + authenticated bool) string { + + // Only use L402 token ID if the request has been authenticated. + // This prevents DoS attacks where garbage L402 tokens flood the cache. + if authenticated { + mac, _, err := l402.FromHeader(&r.Header) + if err == nil && mac != nil { + identifier, err := l402.DecodeIdentifier( + bytes.NewBuffer(mac.Id()), + ) + if err == nil { + return "token:" + identifier.TokenID.String() + } + } + } + + // Fall back to IP address for unauthenticated requests. + // Mask the IP to group clients from the same network segment. + return "ip:" + netutil.MaskIP(remoteIP).String() +} diff --git a/proxy/ratelimiter_test.go b/proxy/ratelimiter_test.go new file mode 100644 index 0000000..a445fe7 --- /dev/null +++ b/proxy/ratelimiter_test.go @@ -0,0 +1,436 @@ +package proxy + +import ( + "fmt" + "net" + "net/http" + "net/http/httptest" + "regexp" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +// TestRateLimiterBasic tests basic rate limiting functionality. +func TestRateLimiterBasic(t *testing.T) { + cfg := &RateLimitConfig{ + PathRegexp: "^/api/.*$", + Requests: 10, + Per: time.Second, + Burst: 10, + } + cfg.compiledPathRegexp = regexp.MustCompile(cfg.PathRegexp) + + rl := NewRateLimiter("test-service", []*RateLimitConfig{cfg}) + + // First 10 requests should be allowed. + for i := 0; i < 10; i++ { + req := httptest.NewRequest("GET", "/api/test", nil) + allowed, _ := rl.Allow(req, "test-key") + require.True(t, allowed, "request %d should be allowed", i) + } + + // 11th request should be denied. + req := httptest.NewRequest("GET", "/api/test", nil) + allowed, retryAfter := rl.Allow(req, "test-key") + require.False(t, allowed, "11th request should be denied") + require.Greater(t, retryAfter, time.Duration(0)) +} + +// TestRateLimiterNoMatchingRules tests that requests pass when no rules match. +func TestRateLimiterNoMatchingRules(t *testing.T) { + cfg := &RateLimitConfig{ + PathRegexp: "^/api/.*$", + Requests: 1, + Per: time.Hour, + Burst: 1, + } + cfg.compiledPathRegexp = regexp.MustCompile(cfg.PathRegexp) + + rl := NewRateLimiter("test-service", []*RateLimitConfig{cfg}) + + // Request to non-matching path should always be allowed. + for i := 0; i < 100; i++ { + req := httptest.NewRequest("GET", "/other/path", nil) + allowed, _ := rl.Allow(req, "test-key") + require.True(t, allowed, "non-matching request should be allowed") + } +} + +// TestRateLimiterLRUEviction tests that the LRU cache evicts old entries. +func TestRateLimiterLRUEviction(t *testing.T) { + cfg := &RateLimitConfig{ + Requests: 100, + Per: time.Second, + Burst: 100, + } + + rl := NewRateLimiter( + "test-service", []*RateLimitConfig{cfg}, + WithMaxCacheSize(5), + ) + + // Create 10 different keys. + for i := 0; i < 10; i++ { + req := httptest.NewRequest("GET", "/api/test", nil) + key := fmt.Sprintf("key-%d", i) + rl.Allow(req, key) + } + + // Cache should be at max size. + require.Equal(t, 5, rl.Size()) +} + +// TestRateLimiterPathMatching tests that different path patterns have +// independent limits. +func TestRateLimiterPathMatching(t *testing.T) { + cfgApi := &RateLimitConfig{ + PathRegexp: "^/api/.*$", + Requests: 5, + Per: time.Second, + Burst: 5, + } + cfgApi.compiledPathRegexp = regexp.MustCompile(cfgApi.PathRegexp) + + cfgAdmin := &RateLimitConfig{ + PathRegexp: "^/admin/.*$", + Requests: 2, + Per: time.Second, + Burst: 2, + } + cfgAdmin.compiledPathRegexp = regexp.MustCompile(cfgAdmin.PathRegexp) + + rl := NewRateLimiter( + "test-service", + []*RateLimitConfig{cfgApi, cfgAdmin}, + ) + + // API path should allow 5 requests. + for i := 0; i < 5; i++ { + req := httptest.NewRequest("GET", "/api/users", nil) + allowed, _ := rl.Allow(req, "test-key") + require.True(t, allowed) + } + + // Admin path should allow 2 requests. + for i := 0; i < 2; i++ { + req := httptest.NewRequest("GET", "/admin/settings", nil) + allowed, _ := rl.Allow(req, "test-key") + require.True(t, allowed) + } + + // Next admin request should be denied. + req := httptest.NewRequest("GET", "/admin/settings", nil) + allowed, _ := rl.Allow(req, "test-key") + require.False(t, allowed) + + // API should still have capacity (used 5, burst is 5, but we're testing + // a 6th). + req = httptest.NewRequest("GET", "/api/users", nil) + allowed, _ = rl.Allow(req, "test-key") + require.False(t, allowed, "6th API request should be denied") +} + +// TestRateLimiterMultipleRulesAllMustPass tests that all matching rules must +// pass for a request to be allowed. +func TestRateLimiterMultipleRulesAllMustPass(t *testing.T) { + // Global rule: 100 req/sec. + cfgGlobal := &RateLimitConfig{ + Requests: 100, + Per: time.Second, + Burst: 100, + } + + // Specific rule: 2 req/sec for /expensive. + cfgExpensive := &RateLimitConfig{ + PathRegexp: "^/expensive$", + Requests: 2, + Per: time.Second, + Burst: 2, + } + cfgExpensive.compiledPathRegexp = regexp.MustCompile(cfgExpensive.PathRegexp) + + rl := NewRateLimiter( + "test-service", + []*RateLimitConfig{cfgGlobal, cfgExpensive}, + ) + + // Expensive should be limited by the stricter rule. + for i := 0; i < 2; i++ { + req := httptest.NewRequest("GET", "/expensive", nil) + allowed, _ := rl.Allow(req, "test-key") + require.True(t, allowed) + } + + req := httptest.NewRequest("GET", "/expensive", nil) + allowed, _ := rl.Allow(req, "test-key") + require.False(t, allowed, "should be denied by /expensive rule") +} + +// TestRateLimiterPerKeyIsolation tests that different keys have independent +// rate limits. +func TestRateLimiterPerKeyIsolation(t *testing.T) { + cfg := &RateLimitConfig{ + Requests: 2, + Per: time.Second, + Burst: 2, + } + + rl := NewRateLimiter("test-service", []*RateLimitConfig{cfg}) + + // User 1 uses their quota. + for i := 0; i < 2; i++ { + req := httptest.NewRequest("GET", "/api/test", nil) + allowed, _ := rl.Allow(req, "user-1") + require.True(t, allowed) + } + + // User 1 is now denied. + req := httptest.NewRequest("GET", "/api/test", nil) + allowed, _ := rl.Allow(req, "user-1") + require.False(t, allowed) + + // User 2 should still have full quota. + for i := 0; i < 2; i++ { + req := httptest.NewRequest("GET", "/api/test", nil) + allowed, _ := rl.Allow(req, "user-2") + require.True(t, allowed) + } +} + +// TestExtractRateLimitKeyIP tests IP-based key extraction for unauthenticated +// requests. +func TestExtractRateLimitKeyIP(t *testing.T) { + req := httptest.NewRequest("GET", "/api/test", nil) + ip := net.ParseIP("192.168.1.100") + + // Unauthenticated request should use masked IP (/24 for IPv4). + key := ExtractRateLimitKey(req, ip, false) + require.Equal(t, "ip:192.168.1.0", key) +} + +// TestExtractRateLimitKeyIPv6 tests IPv6 key extraction. +func TestExtractRateLimitKeyIPv6(t *testing.T) { + req := httptest.NewRequest("GET", "/api/test", nil) + ip := net.ParseIP("2001:db8:1234:5678::1") + + // IPv6 should be masked to /48. + key := ExtractRateLimitKey(req, ip, false) + require.Equal(t, "ip:2001:db8:1234::", key) +} + +// TestExtractRateLimitKeyUnauthenticatedIgnoresL402 tests that unauthenticated +// requests fall back to IP even if L402 header is present. This prevents DoS +// attacks where garbage L402 tokens flood the cache. +func TestExtractRateLimitKeyUnauthenticatedIgnoresL402(t *testing.T) { + req := httptest.NewRequest("GET", "/api/test", nil) + // Add a garbage L402 header that would be present before authentication. + req.Header.Set("Authorization", "L402 garbage:token") + ip := net.ParseIP("192.168.1.100") + + // Even with L402 header present, unauthenticated=false should use + // masked IP. + key := ExtractRateLimitKey(req, ip, false) + require.Equal(t, "ip:192.168.1.0", key) +} + +// TestRateLimitConfigRate tests the Rate() calculation. +func TestRateLimitConfigRate(t *testing.T) { + tests := []struct { + name string + requests int + per time.Duration + wantRate float64 + }{ + { + name: "10 per second", + requests: 10, + per: time.Second, + wantRate: 10.0, + }, + { + name: "60 per minute", + requests: 60, + per: time.Minute, + wantRate: 1.0, + }, + { + name: "1 per hour", + requests: 1, + per: time.Hour, + wantRate: 1.0 / 3600.0, + }, + { + name: "zero per", + requests: 10, + per: 0, + wantRate: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := &RateLimitConfig{ + Requests: tt.requests, + Per: tt.per, + } + require.InDelta(t, tt.wantRate, cfg.Rate(), 0.0001) + }) + } +} + +// TestRateLimitConfigEffectiveBurst tests the EffectiveBurst() calculation. +func TestRateLimitConfigEffectiveBurst(t *testing.T) { + tests := []struct { + name string + requests int + burst int + wantBurst int + }{ + { + name: "explicit burst", + requests: 10, + burst: 20, + wantBurst: 20, + }, + { + name: "default to requests", + requests: 10, + burst: 0, + wantBurst: 10, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := &RateLimitConfig{ + Requests: tt.requests, + Burst: tt.burst, + } + require.Equal(t, tt.wantBurst, cfg.EffectiveBurst()) + }) + } +} + +// TestRateLimitConfigMatches tests the Matches() method. +func TestRateLimitConfigMatches(t *testing.T) { + tests := []struct { + name string + pathRegexp string + path string + want bool + }{ + { + name: "no pattern matches all", + pathRegexp: "", + path: "/anything", + want: true, + }, + { + name: "pattern matches", + pathRegexp: "^/api/.*$", + path: "/api/users", + want: true, + }, + { + name: "pattern does not match", + pathRegexp: "^/api/.*$", + path: "/admin/users", + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := &RateLimitConfig{ + PathRegexp: tt.pathRegexp, + } + if tt.pathRegexp != "" { + cfg.compiledPathRegexp = regexp.MustCompile( + tt.pathRegexp, + ) + } + require.Equal(t, tt.want, cfg.Matches(tt.path)) + }) + } +} + +// TestSendRateLimitResponseHTTP tests HTTP rate limit response. +func TestSendRateLimitResponseHTTP(t *testing.T) { + w := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/api/test", nil) + + sendRateLimitResponse(w, req, 5*time.Second) + + require.Equal(t, http.StatusTooManyRequests, w.Code) + require.Equal(t, "5", w.Header().Get("Retry-After")) + require.Contains(t, w.Body.String(), "rate limit exceeded") +} + +// TestSendRateLimitResponseHTTPSubSecond tests that sub-second delays are +// rounded up to 1 second. +func TestSendRateLimitResponseHTTPSubSecond(t *testing.T) { + w := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/api/test", nil) + + sendRateLimitResponse(w, req, 500*time.Millisecond) + + require.Equal(t, http.StatusTooManyRequests, w.Code) + require.Equal(t, "1", w.Header().Get("Retry-After")) +} + +// TestSendRateLimitResponseHTTPRoundUp tests that fractional seconds are +// rounded up, not down. This ensures clients don't retry before the limit +// actually resets. +func TestSendRateLimitResponseHTTPRoundUp(t *testing.T) { + w := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/api/test", nil) + + // 1.1 seconds should round up to 2 seconds, not down to 1. + sendRateLimitResponse(w, req, 1100*time.Millisecond) + + require.Equal(t, http.StatusTooManyRequests, w.Code) + require.Equal(t, "2", w.Header().Get("Retry-After")) +} + +// TestSendRateLimitResponseGRPC tests gRPC rate limit response. +func TestSendRateLimitResponseGRPC(t *testing.T) { + w := httptest.NewRecorder() + req := httptest.NewRequest("POST", "/grpc.Service/Method", nil) + req.Header.Set("Content-Type", "application/grpc") + + sendRateLimitResponse(w, req, 5*time.Second) + + require.Equal(t, http.StatusOK, w.Code) // gRPC always returns 200. + require.Equal(t, "5", w.Header().Get("Retry-After")) + require.Equal(t, "8", w.Header().Get("Grpc-Status")) // ResourceExhausted. + require.Equal(t, "rate limit exceeded", w.Header().Get("Grpc-Message")) +} + +// TestRateLimiterTokenRefill tests that tokens refill over time. +func TestRateLimiterTokenRefill(t *testing.T) { + cfg := &RateLimitConfig{ + Requests: 10, + Per: 100 * time.Millisecond, // Fast refill for testing. + Burst: 1, + } + + rl := NewRateLimiter("test-service", []*RateLimitConfig{cfg}) + + // Use the one available token. + req := httptest.NewRequest("GET", "/api/test", nil) + allowed, _ := rl.Allow(req, "test-key") + require.True(t, allowed) + + // Immediate second request should be denied. + allowed, _ = rl.Allow(req, "test-key") + require.False(t, allowed) + + // Wait for refill. + time.Sleep(15 * time.Millisecond) + + // Should have a token now. + allowed, _ = rl.Allow(req, "test-key") + require.True(t, allowed) +} diff --git a/proxy/service.go b/proxy/service.go index 8f58d03..d2a952d 100644 --- a/proxy/service.go +++ b/proxy/service.go @@ -110,6 +110,12 @@ type Service struct { // request, but still try to do the l402 authentication. AuthSkipInvoiceCreationPaths []string `long:"authskipinvoicecreationpaths" description:"List of regular expressions for paths that will skip invoice creation'"` + // RateLimits is an optional list of rate-limiting rules for this + // service. Each rule specifies a path pattern and rate limit + // parameters. All matching rules are evaluated; if any rule denies + // the request, it is rejected. + RateLimits []*RateLimitConfig `long:"ratelimits" description:"List of rate limiting rules for this service"` + // compiledHostRegexp is the compiled host regex. compiledHostRegexp *regexp.Regexp @@ -123,8 +129,9 @@ type Service struct { // invoice creation paths. compiledAuthSkipInvoiceCreationPaths []*regexp.Regexp - freebieDB freebie.DB - pricer pricer.Pricer + freebieDB freebie.DB + pricer pricer.Pricer + rateLimiter *RateLimiter } // ResourceName returns the string to be used to identify which resource a @@ -275,6 +282,47 @@ func prepareServices(services []*Service) error { ) } + // Validate and compile rate limit configurations. + if len(service.RateLimits) > 0 { + for i, rl := range service.RateLimits { + // Validate required fields. + if rl.Requests <= 0 { + return fmt.Errorf("service %s rate "+ + "limit %d: requests must be "+ + "positive", service.Name, i) + } + if rl.Per <= 0 { + return fmt.Errorf("service %s rate "+ + "limit %d: per duration must "+ + "be positive", service.Name, i) + } + + // Compile path regex if provided. + if rl.PathRegexp != "" { + compiled, err := regexp.Compile( + rl.PathRegexp, + ) + if err != nil { + return fmt.Errorf("service %s "+ + "rate limit %d: error "+ + "compiling path regex: "+ + "%w", service.Name, i, + err) + } + rl.compiledPathRegexp = compiled + } + } + + // Create the rate limiter for this service. + service.rateLimiter = NewRateLimiter( + service.Name, service.RateLimits, + ) + + log.Infof("Initialized rate limiter for service %s "+ + "with %d rules", service.Name, + len(service.RateLimits)) + } + // If dynamic prices are enabled then use the provided // DynamicPrice options to initialise a gRPC backed // pricer client. diff --git a/sample-conf.yaml b/sample-conf.yaml index aaa247d..f3f4ac2 100644 --- a/sample-conf.yaml +++ b/sample-conf.yaml @@ -195,6 +195,28 @@ services: authskipinvoicecreationpaths: - '^/streamingservice.*$' + # Optional per-endpoint rate limits using a token bucket algorithm. + # Rate limiting is applied per L402 token ID (or IP address for + # unauthenticated requests). All matching rules are evaluated; if any + # rule denies the request, it is rejected. + ratelimits: + # Rate limit for general API endpoints. + - pathregexp: '^/looprpc.SwapServer/LoopOutTerms.*$' + # Number of requests allowed per time window. Must be provided and + # positive. + requests: 5 + # Time window duration (e.g., 1s, 1m, 1h). Must be provided and + # positive. + per: 1s + # Maximum burst capacity. Must be positive if provided. + burst: 100 + + # Stricter rate limit for quote endpoints. + - pathregexp: '^/looprpc.SwapServer/LoopOutQuote.*$' + requests: 2 + per: 1s + burst: 2 + # Options to use for connection to the price serving gRPC server. dynamicprice: # Whether or not a gRPC server is available to query price data from. If