diff --git a/controllers/invoicestream.ctrl.go b/controllers/invoicestream.ctrl.go index 2ec80bc..a6630e7 100644 --- a/controllers/invoicestream.ctrl.go +++ b/controllers/invoicestream.ctrl.go @@ -33,28 +33,13 @@ func (controller *InvoiceStreamController) StreamInvoices(c echo.Context) error return err } invoiceChan := make(chan models.Invoice) - subId := controller.svc.InvoicePubSub.Subscribe(userId, invoiceChan) - upgrader := websocket.Upgrader{} - upgrader.CheckOrigin = func(r *http.Request) bool { return true } ticker := time.NewTicker(30 * time.Second) - ws, err := upgrader.Upgrade(c.Response(), c.Request(), nil) + ws, done, err := createWebsocketUpgrader(c) if err != nil { - controller.svc.InvoicePubSub.Unsubscribe(subId, userId) return err } - defer ws.Close() - - //start listening for close messages - done := make(chan struct{}) - go func() { - defer close(done) - for { - _, _, err := ws.ReadMessage() - if err != nil { - return - } - } - }() + //start subscription + subId := controller.svc.InvoicePubSub.Subscribe(userId, invoiceChan) //start with keepalive message err = ws.WriteJSON(&InvoiceEventWrapper{Type: "keepalive"}) @@ -94,5 +79,30 @@ SocketLoop: } } } - return controller.svc.InvoicePubSub.Unsubscribe(subId, userId) + controller.svc.InvoicePubSub.Unsubscribe(subId, userId) + return nil +} + +//open the websocket and start listening for close messages in a goroutine +func createWebsocketUpgrader(c echo.Context) (conn *websocket.Conn, done chan struct{}, err error) { + upgrader := websocket.Upgrader{} + upgrader.CheckOrigin = func(r *http.Request) bool { return true } + ws, err := upgrader.Upgrade(c.Response(), c.Request(), nil) + if err != nil { + return nil, nil, err + } + defer ws.Close() + + //start listening for close messages + done = make(chan struct{}) + go func() { + defer close(done) + for { + _, _, err := ws.ReadMessage() + if err != nil { + return + } + } + }() + return ws, done, nil } diff --git a/lib/service/pubsub.go b/lib/service/pubsub.go index f71b853..47b0486 100644 --- a/lib/service/pubsub.go +++ b/lib/service/pubsub.go @@ -29,18 +29,27 @@ func (ps *Pubsub) Subscribe(topic int64, ch chan models.Invoice) (subId string) return subId } -func (ps *Pubsub) Unsubscribe(id string, topic int64) error { +func (ps *Pubsub) Unsubscribe(id string, topic int64) { ps.mu.Lock() defer ps.mu.Unlock() + if ps.subs[topic] == nil { + return + } + if ps.subs[topic][id] == nil { + return + } close(ps.subs[topic][id]) delete(ps.subs[topic], id) - return nil } func (ps *Pubsub) Publish(topic int64, msg models.Invoice) { ps.mu.RLock() defer ps.mu.RUnlock() + if ps.subs[topic] == nil { + return + } + for _, ch := range ps.subs[topic] { ch <- msg }