diff --git a/controllers/invoicestream.ctrl.go b/controllers/invoicestream.ctrl.go index 9b4d2f0..42be2ef 100644 --- a/controllers/invoicestream.ctrl.go +++ b/controllers/invoicestream.ctrl.go @@ -35,7 +35,6 @@ func (controller *InvoiceStreamController) StreamInvoices(c echo.Context) error invoiceChan := make(chan models.Invoice) reqId := c.Response().Header().Get(echo.HeaderXRequestID) controller.svc.InvoicePubSub.Subscribe(reqId, userId, invoiceChan) - ctx := c.Request().Context() upgrader := websocket.Upgrader{} upgrader.CheckOrigin = func(r *http.Request) bool { return true } ticker := time.NewTicker(30 * time.Second) @@ -44,6 +43,19 @@ func (controller *InvoiceStreamController) StreamInvoices(c echo.Context) error return err } defer ws.Close() + + //start listening for close messages + done := make(chan struct{}) + go func() { + defer close(done) + for { + _, _, err := ws.ReadMessage() + if err != nil { + return + } + } + }() + //start with keepalive message err = ws.WriteJSON(&InvoiceEventWrapper{Type: "keepalive"}) if err != nil { @@ -53,7 +65,7 @@ func (controller *InvoiceStreamController) StreamInvoices(c echo.Context) error SocketLoop: for { select { - case <-ctx.Done(): + case <-done: break SocketLoop case <-ticker.C: err := ws.WriteJSON(&InvoiceEventWrapper{Type: "keepalive"}) @@ -81,6 +93,5 @@ SocketLoop: } } } - controller.svc.InvoicePubSub.Unsubscribe(reqId, userId) - return nil + return controller.svc.InvoicePubSub.Unsubscribe(reqId, userId) } diff --git a/lib/service/pubsub.go b/lib/service/pubsub.go index ac43857..b5faad8 100644 --- a/lib/service/pubsub.go +++ b/lib/service/pubsub.go @@ -26,10 +26,12 @@ func (ps *Pubsub) Subscribe(id string, topic int64, ch chan models.Invoice) { ps.subs[topic][id] = ch } -func (ps *Pubsub) Unsubscribe(id string, topic int64) { +func (ps *Pubsub) Unsubscribe(id string, topic int64) error { ps.mu.Lock() defer ps.mu.Unlock() + close(ps.subs[topic][id]) delete(ps.subs[topic], id) + return nil } func (ps *Pubsub) Publish(topic int64, msg models.Invoice) { @@ -40,11 +42,3 @@ func (ps *Pubsub) Publish(topic int64, msg models.Invoice) { ch <- msg } } - -func (ps *Pubsub) CloseAll() { - for _, subs := range ps.subs { - for _, ch := range subs { - close(ch) - } - } -} diff --git a/main.go b/main.go index edab439..ed2487b 100644 --- a/main.go +++ b/main.go @@ -202,8 +202,6 @@ func main() { if err := e.Shutdown(ctx); err != nil { e.Logger.Fatal(err) } - //close all channels - svc.InvoicePubSub.CloseAll() if echoPrometheus != nil { if err := echoPrometheus.Shutdown(ctx); err != nil { e.Logger.Fatal(err)