diff --git a/handlers.go b/handlers.go index 885f799..510479d 100644 --- a/handlers.go +++ b/handlers.go @@ -33,8 +33,6 @@ func (rl *Relay) ServeHTTP(w http.ResponseWriter, r *http.Request) { } func (rl *Relay) HandleWebsocket(w http.ResponseWriter, r *http.Request) { - connectionContext := r.Context() - conn, err := rl.upgrader.Upgrade(w, r, nil) if err != nil { rl.Log.Printf("failed to upgrade websocket: %v\n", err) @@ -54,12 +52,17 @@ func (rl *Relay) HandleWebsocket(w http.ResponseWriter, r *http.Request) { Authed: make(chan struct{}), } - connectionContext = context.WithValue(connectionContext, WS_KEY, ws) - // reader go func() { + ctx, cancel := context.WithCancel( + context.WithValue( + context.Background(), + WS_KEY, ws, + ), + ) defer func() { ticker.Stop() + cancel() if _, ok := rl.clients.Load(conn); ok { conn.Close() rl.clients.Delete(conn) @@ -75,7 +78,7 @@ func (rl *Relay) HandleWebsocket(w http.ResponseWriter, r *http.Request) { }) for _, onconnect := range rl.OnConnect { - onconnect(connectionContext) + onconnect(ctx) } for { @@ -99,14 +102,6 @@ func (rl *Relay) HandleWebsocket(w http.ResponseWriter, r *http.Request) { } go func(message []byte) { - ctx := context.WithValue( - context.WithValue( - context.Background(), - AUTH_CONTEXT_KEY, connectionContext.Value(AUTH_CONTEXT_KEY), - ), - WS_KEY, ws, - ) - envelope := nostr.ParseMessage(message) if envelope == nil { // stop silently @@ -163,8 +158,10 @@ func (rl *Relay) HandleWebsocket(w http.ResponseWriter, r *http.Request) { eose := sync.WaitGroup{} eose.Add(len(env.Filters)) + reqCtx, cancel := context.WithCancelCause(ctx) + for _, filter := range env.Filters { - err := rl.handleRequest(ctx, env.SubscriptionID, &eose, ws, filter) + err := rl.handleRequest(reqCtx, env.SubscriptionID, &eose, ws, filter) if err != nil { reason := nostr.NormalizeOKMessage(err.Error(), "blocked") if isAuthRequired(reason) { @@ -180,7 +177,7 @@ func (rl *Relay) HandleWebsocket(w http.ResponseWriter, r *http.Request) { ws.WriteJSON(nostr.EOSEEnvelope(env.SubscriptionID)) }() - setListener(env.SubscriptionID, ws, env.Filters) + setListener(env.SubscriptionID, ws, env.Filters, cancel) case *nostr.CloseEnvelope: removeListenerId(ws, string(*env)) case *nostr.AuthEnvelope: @@ -188,7 +185,7 @@ func (rl *Relay) HandleWebsocket(w http.ResponseWriter, r *http.Request) { if pubkey, ok := nip42.ValidateAuthEvent(&env.Event, ws.Challenge, wsBaseUrl); ok { ws.AuthedPublicKey = pubkey close(ws.Authed) - connectionContext = context.WithValue(ctx, AUTH_CONTEXT_KEY, pubkey) + ctx = context.WithValue(ctx, AUTH_CONTEXT_KEY, pubkey) ws.WriteJSON(nostr.OKEnvelope{EventID: env.Event.ID, OK: true}) } else { ws.WriteJSON(nostr.OKEnvelope{EventID: env.Event.ID, OK: false, Reason: "error: failed to authenticate"}) diff --git a/listener.go b/listener.go index 0c51862..0ccb5ca 100644 --- a/listener.go +++ b/listener.go @@ -1,12 +1,16 @@ package khatru import ( + "context" + "fmt" + "github.com/nbd-wtf/go-nostr" "github.com/puzpuzpuz/xsync/v2" ) type Listener struct { filters nostr.Filters + cancel context.CancelCauseFunc } var listeners = xsync.NewTypedMapOf[*WebSocket, *xsync.MapOf[string, *Listener]](pointerHasher[WebSocket]) @@ -43,24 +47,28 @@ func GetListeningFilters() nostr.Filters { return respfilters } -func setListener(id string, ws *WebSocket, filters nostr.Filters) { +func setListener(id string, ws *WebSocket, filters nostr.Filters, cancel context.CancelCauseFunc) { subs, _ := listeners.LoadOrCompute(ws, func() *xsync.MapOf[string, *Listener] { return xsync.NewMapOf[*Listener]() }) - subs.Store(id, &Listener{filters: filters}) + subs.Store(id, &Listener{filters: filters, cancel: cancel}) } -// Remove a specific subscription id from listeners for a given ws client +// remove a specific subscription id from listeners for a given ws client +// and cancel its specific context func removeListenerId(ws *WebSocket, id string) { if subs, ok := listeners.Load(ws); ok { - subs.Delete(id) + if listener, ok := subs.LoadAndDelete(id); ok { + listener.cancel(fmt.Errorf("subscription closed by client")) + } if subs.Size() == 0 { listeners.Delete(ws) } } } -// Remove WebSocket conn from listeners +// remove WebSocket conn from listeners +// (no need to cancel contexts as they are all inherited from the main connection context) func removeListener(ws *WebSocket) { listeners.Delete(ws) }