diff --git a/lib/service/checkpayments.go b/lib/service/checkpayments.go index 10e1e00..7805e36 100644 --- a/lib/service/checkpayments.go +++ b/lib/service/checkpayments.go @@ -12,15 +12,17 @@ import ( "github.com/lightningnetwork/lnd/lnrpc/routerrpc" ) +func (svc *LndhubService) GetAllPendingPayments(ctx context.Context) ([]models.Invoice, error) { + payments := []models.Invoice{} + err := svc.DB.NewSelect().Model(&payments).Where("state = 'initialized'").Where("type = 'outgoing'").Where("r_hash != ''").Where("created_at >= (now() - interval '2 weeks') ").Scan(ctx) + return payments, err +} func (svc *LndhubService) CheckAllPendingOutgoingPayments(ctx context.Context) (err error) { - //check database for all pending payments - pendingPayments := []models.Invoice{} - //since this part is synchronously executed before the main server starts, we should not get into race conditions - //only fetch invoices from the last 2 weeks which should be a safe timeframe for hodl invoices to avoid refetching old invoices again and again - err = svc.DB.NewSelect().Model(&pendingPayments).Where("state = 'initialized'").Where("type = 'outgoing'").Where("r_hash != ''").Where("created_at >= (now() - interval '2 weeks') ").Scan(ctx) - if err != nil { - return err - } + pendingPayments, err := svc.GetAllPendingPayments(ctx) + if err != nil { + return err + } + svc.Logger.Infof("Found %d pending payments", len(pendingPayments)) //call trackoutgoingpaymentstatus for each one var wg sync.WaitGroup @@ -39,6 +41,13 @@ func (svc *LndhubService) CheckAllPendingOutgoingPayments(ctx context.Context) ( return nil } +func (svc *LndhubService) GetTransactionEntryByInvoiceId(ctx context.Context, id int64) (models.TransactionEntry, error) { + entry := models.TransactionEntry{} + + err := svc.DB.NewSelect().Model(&entry).Where("invoice_id = ?", id).Limit(1).Scan(ctx) + return entry, err +} + // Should be called in a goroutine as the tracking can potentially take a long time func (svc *LndhubService) TrackOutgoingPaymentstatus(ctx context.Context, invoice *models.Invoice) { //ask lnd using TrackPaymentV2 by hash of payment @@ -55,8 +64,7 @@ func (svc *LndhubService) TrackOutgoingPaymentstatus(ctx context.Context, invoic svc.Logger.Errorf("Error tracking payment %s: %s", invoice.RHash, err.Error()) return } - //fetch the tx entry for the invoice - entry := models.TransactionEntry{} + //call HandleFailedPayment or HandleSuccesfulPayment for { payment, err := paymentTracker.Recv() @@ -64,7 +72,7 @@ func (svc *LndhubService) TrackOutgoingPaymentstatus(ctx context.Context, invoic svc.Logger.Errorf("Error tracking payment with hash %s: %s", invoice.RHash, err.Error()) return } - err = svc.DB.NewSelect().Model(&entry).Where("invoice_id = ?", invoice.ID).Limit(1).Scan(ctx) + entry, err := svc.GetTransactionEntryByInvoiceId(ctx, invoice.ID) if err != nil { svc.Logger.Errorf("Error tracking payment %s: %s", invoice.RHash, err.Error()) return diff --git a/lib/service/config.go b/lib/service/config.go index 26115d5..0acf0ec 100644 --- a/lib/service/config.go +++ b/lib/service/config.go @@ -46,6 +46,7 @@ type Config struct { RabbitMQLndInvoiceExchange string `envconfig:"RABBITMQ_LND_INVOICE_EXCHANGE" default:"lnd_invoice"` RabbitMQInvoiceConsumerQueueName string `envconfig:"RABBITMQ_INVOICE_CONSUMER_QUEUE_NAME" default:"lnd_invoice_consumer"` SubscriptionConsumerType string `envconfig:"SUBSCRIPTION_CONSUMER_TYPE" default:"grpc"` + FinalizePendingPaymentsWith string `envconfig:"FINALIZE_PAYMENTS_WITH" default:"native"` Branding BrandingConfig } diff --git a/main.go b/main.go index 448a660..51f8de7 100644 --- a/main.go +++ b/main.go @@ -156,6 +156,15 @@ func main() { } logger.Infof("Connected to LND: %s - %s", getInfo.Alias, getInfo.IdentityPubkey) + svc := &service.LndhubService{ + Config: c, + DB: dbConn, + LndClient: lndClient, + Logger: logger, + IdentityPubkey: getInfo.IdentityPubkey, + InvoicePubSub: service.NewPubsub(), + } + // If no RABBITMQ_URI was provided we will not attempt to create a client // No rabbitmq features will be available in this case. var rabbitmqClient rabbitmq.Client @@ -165,6 +174,7 @@ func main() { rabbitmq.WithLndInvoiceExchange(c.RabbitMQLndInvoiceExchange), rabbitmq.WithLndHubInvoiceExchange(c.RabbitMQLndhubInvoiceExchange), rabbitmq.WithLndInvoiceConsumerQueueName(c.RabbitMQInvoiceConsumerQueueName), + rabbitmq.WithLndHubService(svc), ) if err != nil { logger.Fatal(err) @@ -174,15 +184,7 @@ func main() { defer rabbitmqClient.Close() } - svc := &service.LndhubService{ - Config: c, - DB: dbConn, - LndClient: lndClient, - RabbitMQClient: rabbitmqClient, - Logger: logger, - IdentityPubkey: getInfo.IdentityPubkey, - InvoicePubSub: service.NewPubsub(), - } + svc.RabbitMQClient = rabbitmqClient logMw := createLoggingMiddleware(logger) // strict rate limit for requests for sending payments @@ -230,10 +232,20 @@ func main() { // A goroutine will be spawned for each one backgroundWg.Add(1) go func() { - err = svc.CheckAllPendingOutgoingPayments(backGroundCtx) - if err != nil { - svc.Logger.Error(err) + switch svc.Config.FinalizePendingPaymentsWith { + case "rabbitmq": + err = svc.RabbitMQClient.FinalizeInitializedPayments(backGroundCtx) + if err != nil { + svc.Logger.Error(err) + } + + default: + err = svc.CheckAllPendingOutgoingPayments(backGroundCtx) + if err != nil { + svc.Logger.Error(err) + } } + svc.Logger.Info("Pending payment check routines done") backgroundWg.Done() }() diff --git a/rabbitmq/rabbitmq.go b/rabbitmq/rabbitmq.go index 91dec61..c830376 100644 --- a/rabbitmq/rabbitmq.go +++ b/rabbitmq/rabbitmq.go @@ -4,10 +4,12 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "io" "os" "sync" + "time" "github.com/getAlby/lndhub.go/db/models" "github.com/getsentry/sentry-go" @@ -39,12 +41,21 @@ type ( type Client interface { SubscribeToLndInvoices(context.Context, IncomingInvoiceHandler) error StartPublishInvoices(context.Context, SubscribeToInvoicesFunc, EncodeOutgoingInvoiceFunc) error + FinalizeInitializedPayments(context.Context) error // Close will close all connections to rabbitmq Close() error } +type LndHubService interface { + HandleFailedPayment(context.Context, *models.Invoice, models.TransactionEntry, error) error + HandleSuccessfulPayment(context.Context, *models.Invoice, models.TransactionEntry) error + GetAllPendingPayments(context.Context) ([]models.Invoice, error) + GetTransactionEntryByInvoiceId(context.Context, int64) (models.TransactionEntry, error) +} + type DefaultClient struct { - conn *amqp.Connection + conn *amqp.Connection + lndHubService LndHubService // It is recommended that, when possible, publishers and consumers // use separate connections so that consumers are isolated from potential @@ -55,7 +66,9 @@ type DefaultClient struct { logger *lecho.Logger lndInvoiceConsumerQueueName string + lndPaymentConsumerQueueName string lndInvoiceExchange string + lndPaymentExchange string lndHubInvoiceExchange string } @@ -85,6 +98,12 @@ func WithLogger(logger *lecho.Logger) ClientOption { } } +func WithLndHubService(svc LndHubService) ClientOption { + return func(client *DefaultClient) { + client.lndHubService = svc + } +} + // Dial sets up a connection to rabbitmq with two channels that are ready to produce and consume func Dial(uri string, options ...ClientOption) (Client, error) { conn, err := amqp.Dial(uri) @@ -128,6 +147,171 @@ func Dial(uri string, options ...ClientOption) (Client, error) { func (client *DefaultClient) Close() error { return client.conn.Close() } +func (client *DefaultClient) FinalizeInitializedPayments(ctx context.Context) error { + // Sanity check + if client.lndHubService == nil { + return errors.New("no LndHubService provided to rabbitmqClient") + } + + err := client.publishChannel.ExchangeDeclare( + client.lndPaymentExchange, + // topic is a type of exchange that allows routing messages to different queue's bases on a routing key + "topic", + // Durable and Non-Auto-Deleted exchanges will survive server restarts and remain + // declared when there are no remaining bindings. + true, + false, + // Non-Internal exchange's accept direct publishing + false, + // Nowait: We set this to false as we want to wait for a server response + // to check whether the exchange was created succesfully + false, + nil, + ) + if err != nil { + return err + } + + queue, err := client.consumeChannel.QueueDeclare( + client.lndPaymentConsumerQueueName, + // Durable and Non-Auto-Deleted queues will survive server restarts and remain + // declared when there are no remaining bindings. + true, + false, + // None-Exclusive means other consumers can consume from this queue. + // Messages from queues are spread out and load balanced between consumers. + // So multiple lndhub.go instances will spread the load of invoices between them + false, + // Nowait: We set this to false as we want to wait for a server response + // to check whether the queue was created successfully + false, + nil, + ) + if err != nil { + return err + } + + err = client.consumeChannel.QueueBind( + queue.Name, + "payment.outgoing.#", + client.lndPaymentExchange, + // Nowait: We set this to false as we want to wait for a server response + // to check whether the queue was created successfully + false, + nil, + ) + if err != nil { + return err + } + + deliveryChan, err := client.consumeChannel.Consume( + queue.Name, + "", + false, + false, + false, + false, + nil, + ) + if err != nil { + return err + } + + getInvoicesTable := func(ctx context.Context) (map[string]models.Invoice, error) { + invoicesByHash := map[string]models.Invoice{} + pendingInvoices, err := client.lndHubService.GetAllPendingPayments(ctx) + + if err != nil { + return invoicesByHash, err + } + + for _, invoice := range pendingInvoices { + invoicesByHash[invoice.RHash] = invoice + } + return invoicesByHash, nil + } + + pendingInvoices, err := getInvoicesTable(ctx) + if err != nil { + return err + } + + for { + select { + case <-ctx.Done(): + return context.Canceled + case delivery, ok := <-deliveryChan: + // Shortcircuit if no pending invoices are left + if len(pendingInvoices) == 0 { + break + } + + if !ok { + return err + } + + payment := lnrpc.Payment{} + + err := json.Unmarshal(delivery.Body, &payment) + if err != nil { + delivery.Nack(false, false) + + continue + } + + ticker := time.NewTicker(time.Hour) + defer ticker.Stop() + + // Check if paymentHash corresponds to one of the pending invoices + if invoice, ok := pendingInvoices[payment.PaymentHash]; ok == true { + switch payment.Status { + case lnrpc.Payment_SUCCEEDED: + t, err := client.lndHubService.GetTransactionEntryByInvoiceId(ctx, invoice.ID) + if err != nil { + delivery.Nack(false, true) + + continue + } + + if err = client.lndHubService.HandleSuccessfulPayment(ctx, &invoice, t); err != nil { + delivery.Nack(false, true) + + continue + } + + case lnrpc.Payment_FAILED: + t, err := client.lndHubService.GetTransactionEntryByInvoiceId(ctx, invoice.ID) + if err != nil { + delivery.Nack(false, true) + + continue + } + + if err = client.lndHubService.HandleFailedPayment(ctx, &invoice, t, fmt.Errorf(payment.FailureReason.String())); err != nil { + delivery.Nack(false, true) + + continue + } + } + } + + // Refresh the pending invoice table after each tick + select { + case <-ticker.C: + invoices, err := getInvoicesTable(ctx) + pendingInvoices = invoices + if err != nil { + return err + } + } + + delivery.Ack(false) + } + + return nil + } +} + func (client *DefaultClient) SubscribeToLndInvoices(ctx context.Context, handler IncomingInvoiceHandler) error { err := client.publishChannel.ExchangeDeclare( client.lndInvoiceExchange,