diff --git a/go.mod b/go.mod index de2f7fc..c402ba9 100644 --- a/go.mod +++ b/go.mod @@ -8,7 +8,6 @@ require ( github.com/nbd-wtf/go-nostr v0.30.0 github.com/puzpuzpuz/xsync/v3 v3.0.2 github.com/rs/cors v1.7.0 - github.com/sebest/xff v0.0.0-20210106013422-671bd2870b3a ) require ( diff --git a/go.sum b/go.sum index c1a2a14..c4dbf4c 100644 --- a/go.sum +++ b/go.sum @@ -113,8 +113,6 @@ github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJ github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= github.com/mattn/go-sqlite3 v1.14.18 h1:JL0eqdCOq6DJVNPSvArO/bIV9/P7fbGrV00LZHc+5aI= github.com/mattn/go-sqlite3 v1.14.18/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= -github.com/nbd-wtf/go-nostr v0.28.1 h1:XQi/lBsigBXHRm7IDBJE7SR9citCh9srgf8sA5iVW3A= -github.com/nbd-wtf/go-nostr v0.28.1/go.mod h1:OQ8sNLFJnsj17BdqZiLSmjJBIFTfDqckEYC3utS4qoY= github.com/nbd-wtf/go-nostr v0.30.0 h1:rN085pe4IxmSBVht8LChZbWLggonjA8hPIk8l4/+Hjk= github.com/nbd-wtf/go-nostr v0.30.0/go.mod h1:tiKJY6fWYSujbTQb201Y+IQ3l4szqYVt+fsTnsm7FCk= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= @@ -128,8 +126,6 @@ github.com/rs/cors v1.7.0 h1:+88SsELBHx5r+hZ8TCkggzSstaWNbDvThkVK8H6f9ik= github.com/rs/cors v1.7.0/go.mod h1:gFx+x8UowdsKA9AchylcLynDq+nNFfI8FkUZdN/jGCU= github.com/savsgio/gotils v0.0.0-20230208104028-c358bd845dee h1:8Iv5m6xEo1NR1AvpV+7XmhI4r39LGNzwUL4YpMuL5vk= github.com/savsgio/gotils v0.0.0-20230208104028-c358bd845dee/go.mod h1:qwtSXrKuJh/zsFQ12yEE89xfCrGKK63Rr7ctU/uCo4g= -github.com/sebest/xff v0.0.0-20210106013422-671bd2870b3a h1:iLcLb5Fwwz7g/DLK89F+uQBDeAhHhwdzB5fSlVdhGcM= -github.com/sebest/xff v0.0.0-20210106013422-671bd2870b3a/go.mod h1:wozgYq9WEBQBaIJe4YZ0qTSFAMxmcwBhQH0fO0R34Z0= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= diff --git a/handlers.go b/handlers.go index 3659033..42ba130 100644 --- a/handlers.go +++ b/handlers.go @@ -34,6 +34,13 @@ func (rl *Relay) ServeHTTP(w http.ResponseWriter, r *http.Request) { } func (rl *Relay) HandleWebsocket(w http.ResponseWriter, r *http.Request) { + for _, reject := range rl.RejectConnection { + if reject(r) { + w.WriteHeader(418) // I'm a teapot + return + } + } + conn, err := rl.upgrader.Upgrade(w, r, nil) if err != nil { rl.Log.Printf("failed to upgrade websocket: %v\n", err) diff --git a/helpers.go b/helpers.go index a58066d..05c0fea 100644 --- a/helpers.go +++ b/helpers.go @@ -1,6 +1,7 @@ package khatru import ( + "net" "net/http" "strconv" "strings" @@ -34,3 +35,43 @@ func getServiceBaseURL(r *http.Request) string { } return proto + "://" + host } + +var privateMasks = func() []net.IPNet { + privateCIDRs := []string{ + "127.0.0.0/8", + "10.0.0.0/8", + "172.16.0.0/12", + "192.168.0.0/16", + "fc00::/7", + } + masks := make([]net.IPNet, len(privateCIDRs)) + for i, cidr := range privateCIDRs { + _, netw, err := net.ParseCIDR(cidr) + if err != nil { + return nil + } + masks[i] = *netw + } + return masks +}() + +func isPrivate(ip net.IP) bool { + for _, mask := range privateMasks { + if mask.Contains(ip) { + return true + } + } + return false +} + +func GetIPFromRequest(r *http.Request) string { + if xffh := r.Header.Get("X-Forwarded-For"); xffh != "" { + for _, v := range strings.Split(xffh, ",") { + if ip := net.ParseIP(strings.TrimSpace(v)); ip != nil && ip.IsGlobalUnicast() && !isPrivate(ip) { + return ip.String() + } + } + } + ip, _, _ := net.SplitHostPort(r.RemoteAddr) + return ip +} diff --git a/policies/helpers.go b/policies/helpers.go new file mode 100644 index 0000000..b3bfe70 --- /dev/null +++ b/policies/helpers.go @@ -0,0 +1,43 @@ +package policies + +import ( + "sync/atomic" + "time" + + "github.com/puzpuzpuz/xsync/v3" +) + +func startRateLimitSystem[K comparable]( + tokensPerInterval int, + interval time.Duration, + maxTokens int, +) func(key K) (ratelimited bool) { + negativeBuckets := xsync.NewMapOf[K, *atomic.Int32]() + maxTokensInt32 := int32(maxTokens) + + go func() { + for { + time.Sleep(interval) + negativeBuckets.Range(func(key K, bucket *atomic.Int32) bool { + newv := bucket.Add(int32(-tokensPerInterval)) + if newv <= 0 { + negativeBuckets.Delete(key) + } + return true + }) + } + }() + + return func(key K) bool { + nb, _ := negativeBuckets.LoadOrStore(key, &atomic.Int32{}) + + if nb.Load() < maxTokensInt32 { + nb.Add(1) + // rate limit not reached yet + return false + } + + // rate limit reached + return true + } +} diff --git a/policies/ratelimits.go b/policies/ratelimits.go new file mode 100644 index 0000000..6a78e7c --- /dev/null +++ b/policies/ratelimits.go @@ -0,0 +1,42 @@ +package policies + +import ( + "context" + "net/http" + "time" + + "github.com/fiatjaf/khatru" + "github.com/nbd-wtf/go-nostr" +) + +func EventIPRateLimiter(tokensPerInterval int, interval time.Duration, maxTokens int) func(ctx context.Context, _ *nostr.Event) (reject bool, msg string) { + rl := startRateLimitSystem[string](tokensPerInterval, interval, maxTokens) + + return func(ctx context.Context, _ *nostr.Event) (reject bool, msg string) { + return rl(khatru.GetIP(ctx)), "rate-limited: slow down, please" + } +} + +func EventPubKeyRateLimiter(tokensPerInterval int, interval time.Duration, maxTokens int) func(ctx context.Context, _ *nostr.Event) (reject bool, msg string) { + rl := startRateLimitSystem[string](tokensPerInterval, interval, maxTokens) + + return func(ctx context.Context, evt *nostr.Event) (reject bool, msg string) { + return rl(evt.PubKey), "rate-limited: slow down, please" + } +} + +func ConnectionRateLimiter(tokensPerInterval int, interval time.Duration, maxTokens int) func(r *http.Request) bool { + rl := startRateLimitSystem[string](tokensPerInterval, interval, maxTokens) + + return func(r *http.Request) bool { + return rl(khatru.GetIPFromRequest(r)) + } +} + +func FilterIPRateLimiter(tokensPerInterval int, interval time.Duration, maxTokens int) func(ctx context.Context, _ nostr.Filter) (reject bool, msg string) { + rl := startRateLimitSystem[string](tokensPerInterval, interval, maxTokens) + + return func(ctx context.Context, _ nostr.Filter) (reject bool, msg string) { + return rl(khatru.GetIP(ctx)), "rate-limited: there is a bug in the client, no one should be making so many requests" + } +} diff --git a/policies/sane_defaults.go b/policies/sane_defaults.go index 4fd1f2c..6b775ac 100644 --- a/policies/sane_defaults.go +++ b/policies/sane_defaults.go @@ -1,14 +1,24 @@ package policies -import "github.com/fiatjaf/khatru" +import ( + "time" + + "github.com/fiatjaf/khatru" +) func ApplySaneDefaults(relay *khatru.Relay) { relay.RejectEvent = append(relay.RejectEvent, RejectEventsWithBase64Media, + EventIPRateLimiter(2, time.Minute*3, 5), ) relay.RejectFilter = append(relay.RejectFilter, NoEmptyFilters, NoComplexFilters, + FilterIPRateLimiter(20, time.Minute, 100), + ) + + relay.RejectConnection = append(relay.RejectConnection, + ConnectionRateLimiter(1, time.Minute*5, 3), ) } diff --git a/relay.go b/relay.go index 4be8219..2ba192e 100644 --- a/relay.go +++ b/relay.go @@ -45,6 +45,7 @@ type Relay struct { RejectEvent []func(ctx context.Context, event *nostr.Event) (reject bool, msg string) RejectFilter []func(ctx context.Context, filter nostr.Filter) (reject bool, msg string) RejectCountFilter []func(ctx context.Context, filter nostr.Filter) (reject bool, msg string) + RejectConnection []func(r *http.Request) bool OverwriteDeletionOutcome []func(ctx context.Context, target *nostr.Event, deletion *nostr.Event) (acceptDeletion bool, msg string) OverwriteResponseEvent []func(ctx context.Context, event *nostr.Event) OverwriteFilter []func(ctx context.Context, filter *nostr.Filter) diff --git a/utils.go b/utils.go index ae150dd..3e3620d 100644 --- a/utils.go +++ b/utils.go @@ -4,7 +4,6 @@ import ( "context" "github.com/nbd-wtf/go-nostr" - "github.com/sebest/xff" ) const ( @@ -31,7 +30,7 @@ func GetAuthed(ctx context.Context) string { } func GetIP(ctx context.Context) string { - return xff.GetRemoteAddr(GetConnection(ctx).Request) + return GetIPFromRequest(GetConnection(ctx).Request) } func GetSubscriptionID(ctx context.Context) string {