diff --git a/controllers/invoicestream.ctrl.go b/controllers/invoicestream.ctrl.go index bf939f3..9b4d2f0 100644 --- a/controllers/invoicestream.ctrl.go +++ b/controllers/invoicestream.ctrl.go @@ -33,7 +33,8 @@ func (controller *InvoiceStreamController) StreamInvoices(c echo.Context) error return err } invoiceChan := make(chan models.Invoice) - controller.svc.InvoiceSubscribers[userId] = invoiceChan + reqId := c.Response().Header().Get(echo.HeaderXRequestID) + controller.svc.InvoicePubSub.Subscribe(reqId, userId, invoiceChan) ctx := c.Request().Context() upgrader := websocket.Upgrader{} upgrader.CheckOrigin = func(r *http.Request) bool { return true } @@ -80,5 +81,6 @@ SocketLoop: } } } + controller.svc.InvoicePubSub.Unsubscribe(reqId, userId) return nil } diff --git a/lib/service/invoicesubscription.go b/lib/service/invoicesubscription.go index bf1c438..e4ae58b 100644 --- a/lib/service/invoicesubscription.go +++ b/lib/service/invoicesubscription.go @@ -97,9 +97,7 @@ func (svc *LndhubService) ProcessInvoiceUpdate(ctx context.Context, rawInvoice * svc.Logger.Errorf("Failed to commit DB transaction user_id:%v invoice_id:%v %v", invoice.UserID, invoice.ID, err) return err } - if sub, ok := svc.InvoiceSubscribers[invoice.UserID]; ok { - sub <- invoice - } + svc.InvoicePubSub.Publish(invoice.UserID, invoice) return nil } diff --git a/lib/service/pubsub.go b/lib/service/pubsub.go new file mode 100644 index 0000000..e38e10f --- /dev/null +++ b/lib/service/pubsub.go @@ -0,0 +1,48 @@ +package service + +import ( + "sync" + + "github.com/getAlby/lndhub.go/db/models" +) + +type Pubsub struct { + mu sync.RWMutex + subs map[int64]map[string]chan models.Invoice +} + +func NewPubsub() *Pubsub { + ps := &Pubsub{} + ps.subs = make(map[int64]map[string]chan models.Invoice) + return ps +} + +func (ps *Pubsub) Subscribe(id string, topic int64, ch chan models.Invoice) { + ps.mu.Lock() + defer ps.mu.Unlock() + + ps.subs[topic][id] = ch +} + +func (ps *Pubsub) Unsubscribe(id string, topic int64) { + ps.mu.Lock() + defer ps.mu.Unlock() + delete(ps.subs[topic], id) +} + +func (ps *Pubsub) Publish(topic int64, msg models.Invoice) { + ps.mu.RLock() + defer ps.mu.RUnlock() + + for _, ch := range ps.subs[topic] { + ch <- msg + } +} + +func (ps *Pubsub) CloseAll() { + for _, subs := range ps.subs { + for _, ch := range subs { + close(ch) + } + } +} diff --git a/lib/service/service.go b/lib/service/service.go index 03484c0..06b3238 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -17,12 +17,12 @@ import ( const alphaNumBytes = random.Alphanumeric type LndhubService struct { - Config *Config - DB *bun.DB - LndClient lnd.LightningClientWrapper - Logger *lecho.Logger - IdentityPubkey string - InvoiceSubscribers map[int64]chan models.Invoice + Config *Config + DB *bun.DB + LndClient lnd.LightningClientWrapper + Logger *lecho.Logger + IdentityPubkey string + InvoicePubSub *Pubsub } func (svc *LndhubService) GenerateToken(ctx context.Context, login, password, inRefreshToken string) (accessToken, refreshToken string, err error) { diff --git a/main.go b/main.go index 2b945d8..edab439 100644 --- a/main.go +++ b/main.go @@ -15,7 +15,6 @@ import ( "github.com/getAlby/lndhub.go/controllers" "github.com/getAlby/lndhub.go/db" "github.com/getAlby/lndhub.go/db/migrations" - "github.com/getAlby/lndhub.go/db/models" "github.com/getAlby/lndhub.go/lib" "github.com/getAlby/lndhub.go/lib/responses" "github.com/getAlby/lndhub.go/lib/service" @@ -119,12 +118,12 @@ func main() { logger.Infof("Connected to LND: %s - %s", getInfo.Alias, getInfo.IdentityPubkey) svc := &service.LndhubService{ - Config: c, - DB: dbConn, - LndClient: lndClient, - Logger: logger, - IdentityPubkey: getInfo.IdentityPubkey, - InvoiceSubscribers: map[int64]chan models.Invoice{}, + Config: c, + DB: dbConn, + LndClient: lndClient, + Logger: logger, + IdentityPubkey: getInfo.IdentityPubkey, + InvoicePubSub: service.NewPubsub(), } strictRateLimitMiddleware := createRateLimitMiddleware(c.StrictRateLimit, c.BurstRateLimit) @@ -204,9 +203,7 @@ func main() { e.Logger.Fatal(err) } //close all channels - for _, sub := range svc.InvoiceSubscribers { - close(sub) - } + svc.InvoicePubSub.CloseAll() if echoPrometheus != nil { if err := echoPrometheus.Shutdown(ctx); err != nil { e.Logger.Fatal(err)