diff --git a/listener_fuzz_test.go b/listener_fuzz_test.go new file mode 100644 index 0000000..683d4bc --- /dev/null +++ b/listener_fuzz_test.go @@ -0,0 +1,129 @@ +package khatru + +import ( + "math/rand" + "testing" + + "github.com/nbd-wtf/go-nostr" + "github.com/stretchr/testify/require" +) + +func FuzzRandomListenerClientRemoving(f *testing.F) { + f.Add(uint(20), uint(20), uint(1)) + f.Fuzz(func(t *testing.T, utw uint, ubs uint, ualf uint) { + totalWebsockets := int(utw) + baseSubs := int(ubs) + addListenerFreq := int(ualf) + 1 + + rl := NewRelay() + + f := nostr.Filter{Kinds: []int{1}} + cancel := func(cause error) {} + + websockets := make([]*WebSocket, 0, totalWebsockets*baseSubs) + + l := 0 + + for i := 0; i < totalWebsockets; i++ { + ws := &WebSocket{} + websockets = append(websockets, ws) + rl.clients[ws] = nil + } + + s := 0 + for j := 0; j < baseSubs; j++ { + for i := 0; i < totalWebsockets; i++ { + ws := websockets[i] + w := idFromSeqUpper(i) + + if s%addListenerFreq == 0 { + l++ + rl.addListener(ws, w+":"+idFromSeqLower(j), rl, f, cancel) + } + + s++ + } + } + + require.Len(t, rl.clients, totalWebsockets) + require.Len(t, rl.listeners, l) + + for ws := range rl.clients { + rl.removeClientAndListeners(ws) + } + + require.Len(t, rl.clients, 0) + require.Len(t, rl.listeners, 0) + }) +} + +func FuzzRandomListenerIdRemoving(f *testing.F) { + f.Add(uint(20), uint(20), uint(1), uint(4)) + f.Fuzz(func(t *testing.T, utw uint, ubs uint, ualf uint, ualef uint) { + totalWebsockets := int(utw) + baseSubs := int(ubs) + addListenerFreq := int(ualf) + 1 + addExtraListenerFreq := int(ualef) + 1 + + if totalWebsockets > 1024 || baseSubs > 1024 { + return + } + + rl := NewRelay() + + f := nostr.Filter{Kinds: []int{1}} + cancel := func(cause error) {} + websockets := make([]*WebSocket, 0, totalWebsockets) + + type wsid struct { + ws *WebSocket + id string + } + + subs := make([]wsid, 0, totalWebsockets*baseSubs) + extra := 0 + + for i := 0; i < totalWebsockets; i++ { + ws := &WebSocket{} + websockets = append(websockets, ws) + rl.clients[ws] = nil + } + + s := 0 + for j := 0; j < baseSubs; j++ { + for i := 0; i < totalWebsockets; i++ { + ws := websockets[i] + w := idFromSeqUpper(i) + + if s%addListenerFreq == 0 { + id := w + ":" + idFromSeqLower(j) + rl.addListener(ws, id, rl, f, cancel) + subs = append(subs, wsid{ws, id}) + + if s%addExtraListenerFreq == 0 { + rl.addListener(ws, id, rl, f, cancel) + extra++ + } + } + + s++ + } + } + + require.Len(t, rl.clients, totalWebsockets) + require.Len(t, rl.listeners, len(subs)+extra) + + rand.Shuffle(len(subs), func(i, j int) { + subs[i], subs[j] = subs[j], subs[i] + }) + for _, wsidToRemove := range subs { + rl.removeListenerId(wsidToRemove.ws, wsidToRemove.id) + } + + require.Len(t, rl.listeners, 0) + require.Len(t, rl.clients, totalWebsockets) + for _, specs := range rl.clients { + require.Len(t, specs, 0) + } + }) +} diff --git a/listener_test.go b/listener_test.go index 6507d51..1cf0c79 100644 --- a/listener_test.go +++ b/listener_test.go @@ -2,12 +2,27 @@ package khatru import ( "math/rand" + "strings" "testing" "github.com/nbd-wtf/go-nostr" "github.com/stretchr/testify/require" ) +func idFromSeqUpper(seq int) string { return idFromSeq(seq, 65, 90) } +func idFromSeqLower(seq int) string { return idFromSeq(seq, 97, 122) } +func idFromSeq(seq int, min, max int) string { + maxSeq := max - min + 1 + nLetters := seq/maxSeq + 1 + result := strings.Builder{} + result.Grow(nLetters) + for l := 0; l < nLetters; l++ { + letter := rune(seq%maxSeq + min) + result.WriteRune(letter) + } + return result.String() +} + func TestListenerSetupAndRemoveOnce(t *testing.T) { rl := NewRelay() @@ -425,11 +440,11 @@ func TestRandomListenerClientRemoving(t *testing.T) { for j := 0; j < 20; j++ { for i := 0; i < 20; i++ { ws := websockets[i] - w := string(rune(i + 65)) + w := idFromSeqUpper(i) if rand.Intn(2) < 1 { l++ - rl.addListener(ws, w+":"+string(rune(j+97)), rl, f, cancel) + rl.addListener(ws, w+":"+idFromSeqLower(j), rl, f, cancel) } } } @@ -470,10 +485,10 @@ func TestRandomListenerIdRemoving(t *testing.T) { for j := 0; j < 20; j++ { for i := 0; i < 20; i++ { ws := websockets[i] - w := string(rune(i + 65)) + w := idFromSeqUpper(i) if rand.Intn(2) < 1 { - id := w + ":" + string(rune(j+97)) + id := w + ":" + idFromSeqLower(j) rl.addListener(ws, id, rl, f, cancel) subs = append(subs, wsid{ws, id})