mirror of
https://github.com/aljazceru/khatru.git
synced 2026-02-05 21:24:22 +01:00
fun with connection contexts and context cancelations.
This commit is contained in:
29
handlers.go
29
handlers.go
@@ -33,8 +33,6 @@ func (rl *Relay) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (rl *Relay) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
|
||||
connectionContext := r.Context()
|
||||
|
||||
conn, err := rl.upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
rl.Log.Printf("failed to upgrade websocket: %v\n", err)
|
||||
@@ -54,12 +52,17 @@ func (rl *Relay) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
|
||||
Authed: make(chan struct{}),
|
||||
}
|
||||
|
||||
connectionContext = context.WithValue(connectionContext, WS_KEY, ws)
|
||||
|
||||
// reader
|
||||
go func() {
|
||||
ctx, cancel := context.WithCancel(
|
||||
context.WithValue(
|
||||
context.Background(),
|
||||
WS_KEY, ws,
|
||||
),
|
||||
)
|
||||
defer func() {
|
||||
ticker.Stop()
|
||||
cancel()
|
||||
if _, ok := rl.clients.Load(conn); ok {
|
||||
conn.Close()
|
||||
rl.clients.Delete(conn)
|
||||
@@ -75,7 +78,7 @@ func (rl *Relay) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
|
||||
for _, onconnect := range rl.OnConnect {
|
||||
onconnect(connectionContext)
|
||||
onconnect(ctx)
|
||||
}
|
||||
|
||||
for {
|
||||
@@ -99,14 +102,6 @@ func (rl *Relay) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
go func(message []byte) {
|
||||
ctx := context.WithValue(
|
||||
context.WithValue(
|
||||
context.Background(),
|
||||
AUTH_CONTEXT_KEY, connectionContext.Value(AUTH_CONTEXT_KEY),
|
||||
),
|
||||
WS_KEY, ws,
|
||||
)
|
||||
|
||||
envelope := nostr.ParseMessage(message)
|
||||
if envelope == nil {
|
||||
// stop silently
|
||||
@@ -163,8 +158,10 @@ func (rl *Relay) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
|
||||
eose := sync.WaitGroup{}
|
||||
eose.Add(len(env.Filters))
|
||||
|
||||
reqCtx, cancel := context.WithCancelCause(ctx)
|
||||
|
||||
for _, filter := range env.Filters {
|
||||
err := rl.handleRequest(ctx, env.SubscriptionID, &eose, ws, filter)
|
||||
err := rl.handleRequest(reqCtx, env.SubscriptionID, &eose, ws, filter)
|
||||
if err != nil {
|
||||
reason := nostr.NormalizeOKMessage(err.Error(), "blocked")
|
||||
if isAuthRequired(reason) {
|
||||
@@ -180,7 +177,7 @@ func (rl *Relay) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
|
||||
ws.WriteJSON(nostr.EOSEEnvelope(env.SubscriptionID))
|
||||
}()
|
||||
|
||||
setListener(env.SubscriptionID, ws, env.Filters)
|
||||
setListener(env.SubscriptionID, ws, env.Filters, cancel)
|
||||
case *nostr.CloseEnvelope:
|
||||
removeListenerId(ws, string(*env))
|
||||
case *nostr.AuthEnvelope:
|
||||
@@ -188,7 +185,7 @@ func (rl *Relay) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
|
||||
if pubkey, ok := nip42.ValidateAuthEvent(&env.Event, ws.Challenge, wsBaseUrl); ok {
|
||||
ws.AuthedPublicKey = pubkey
|
||||
close(ws.Authed)
|
||||
connectionContext = context.WithValue(ctx, AUTH_CONTEXT_KEY, pubkey)
|
||||
ctx = context.WithValue(ctx, AUTH_CONTEXT_KEY, pubkey)
|
||||
ws.WriteJSON(nostr.OKEnvelope{EventID: env.Event.ID, OK: true})
|
||||
} else {
|
||||
ws.WriteJSON(nostr.OKEnvelope{EventID: env.Event.ID, OK: false, Reason: "error: failed to authenticate"})
|
||||
|
||||
18
listener.go
18
listener.go
@@ -1,12 +1,16 @@
|
||||
package khatru
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/nbd-wtf/go-nostr"
|
||||
"github.com/puzpuzpuz/xsync/v2"
|
||||
)
|
||||
|
||||
type Listener struct {
|
||||
filters nostr.Filters
|
||||
cancel context.CancelCauseFunc
|
||||
}
|
||||
|
||||
var listeners = xsync.NewTypedMapOf[*WebSocket, *xsync.MapOf[string, *Listener]](pointerHasher[WebSocket])
|
||||
@@ -43,24 +47,28 @@ func GetListeningFilters() nostr.Filters {
|
||||
return respfilters
|
||||
}
|
||||
|
||||
func setListener(id string, ws *WebSocket, filters nostr.Filters) {
|
||||
func setListener(id string, ws *WebSocket, filters nostr.Filters, cancel context.CancelCauseFunc) {
|
||||
subs, _ := listeners.LoadOrCompute(ws, func() *xsync.MapOf[string, *Listener] {
|
||||
return xsync.NewMapOf[*Listener]()
|
||||
})
|
||||
subs.Store(id, &Listener{filters: filters})
|
||||
subs.Store(id, &Listener{filters: filters, cancel: cancel})
|
||||
}
|
||||
|
||||
// Remove a specific subscription id from listeners for a given ws client
|
||||
// remove a specific subscription id from listeners for a given ws client
|
||||
// and cancel its specific context
|
||||
func removeListenerId(ws *WebSocket, id string) {
|
||||
if subs, ok := listeners.Load(ws); ok {
|
||||
subs.Delete(id)
|
||||
if listener, ok := subs.LoadAndDelete(id); ok {
|
||||
listener.cancel(fmt.Errorf("subscription closed by client"))
|
||||
}
|
||||
if subs.Size() == 0 {
|
||||
listeners.Delete(ws)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Remove WebSocket conn from listeners
|
||||
// remove WebSocket conn from listeners
|
||||
// (no need to cancel contexts as they are all inherited from the main connection context)
|
||||
func removeListener(ws *WebSocket) {
|
||||
listeners.Delete(ws)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user