diff --git a/integration_tests/lnd_mock.go b/integration_tests/lnd_mock.go index b2cf99d..86c51ff 100644 --- a/integration_tests/lnd_mock.go +++ b/integration_tests/lnd_mock.go @@ -15,6 +15,7 @@ import ( "github.com/getAlby/lndhub.go/lnd" "github.com/labstack/gommon/random" "github.com/lightningnetwork/lnd/lnrpc" + "github.com/lightningnetwork/lnd/lnrpc/routerrpc" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/zpay32" "google.golang.org/grpc" @@ -58,6 +59,9 @@ func (mockSub *MockSubscribeInvoices) Recv() (*lnrpc.Invoice, error) { inv := <-mockSub.invoiceChan return inv, nil } +func (mlnd *MockLND) SubscribePayment(ctx context.Context, req *routerrpc.TrackPaymentRequest, options ...grpc.CallOption) (lnd.SubscribePaymentWrapper, error) { + return nil, nil +} func (mlnd *MockLND) ListChannels(ctx context.Context, req *lnrpc.ListChannelsRequest, options ...grpc.CallOption) (*lnrpc.ListChannelsResponse, error) { return &lnrpc.ListChannelsResponse{ diff --git a/integration_tests/subscription_start_test.go b/integration_tests/subscription_start_test.go index c900f8c..25d3e6e 100644 --- a/integration_tests/subscription_start_test.go +++ b/integration_tests/subscription_start_test.go @@ -17,6 +17,7 @@ import ( "github.com/go-playground/validator/v10" "github.com/labstack/echo/v4" "github.com/lightningnetwork/lnd/lnrpc" + "github.com/lightningnetwork/lnd/lnrpc/routerrpc" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" "github.com/uptrace/bun" @@ -128,6 +129,9 @@ func (mock *lndSubscriptionStartMockClient) SubscribeInvoices(ctx context.Contex func (mock *lndSubscriptionStartMockClient) Recv() (*lnrpc.Invoice, error) { select {} } +func (mock *lndSubscriptionStartMockClient) SubscribePayment(ctx context.Context, req *routerrpc.TrackPaymentRequest, options ...grpc.CallOption) (lnd.SubscribePaymentWrapper, error) { + return nil, nil +} func (mock *lndSubscriptionStartMockClient) GetInfo(ctx context.Context, req *lnrpc.GetInfoRequest, options ...grpc.CallOption) (*lnrpc.GetInfoResponse, error) { panic("not implemented") // TODO: Implement diff --git a/lib/service/checkpayments.go b/lib/service/checkpayments.go index 1832427..32aa4e2 100644 --- a/lib/service/checkpayments.go +++ b/lib/service/checkpayments.go @@ -3,10 +3,10 @@ package service import ( "context" "encoding/hex" - "fmt" "github.com/getAlby/lndhub.go/db/models" "github.com/lightningnetwork/lnd/lnrpc" + "github.com/lightningnetwork/lnd/lnrpc/routerrpc" ) func (svc *LndhubService) CheckAllPendingOutgoingPayments(ctx context.Context) (err error) { @@ -19,53 +19,63 @@ func (svc *LndhubService) CheckAllPendingOutgoingPayments(ctx context.Context) ( svc.Logger.Infof("Found %d pending payments", len(pendingPayments)) //call trackoutgoingpaymentstatus for each one for _, inv := range pendingPayments { - err = svc.TrackOutgoingPaymentstatus(ctx, &inv) - if err != nil { - svc.Logger.Errorf("Error tracking payment %v: %s", inv, err.Error()) - } + //spawn goroutines + go svc.TrackOutgoingPaymentstatus(ctx, &inv) } return nil } -func (svc *LndhubService) TrackOutgoingPaymentstatus(ctx context.Context, invoice *models.Invoice) error { +//Should be called in a goroutine as the tracking can potentially take a long time +func (svc *LndhubService) TrackOutgoingPaymentstatus(ctx context.Context, invoice *models.Invoice) { //fetch the tx entry for the invoice entry := models.TransactionEntry{} err := svc.DB.NewSelect().Model(&entry).Where("invoice_id = ?", invoice.ID).Limit(1).Scan(ctx) if err != nil { - return err + svc.Logger.Errorf("Error tracking payment %v: %s", invoice, err.Error()) + return + } if entry.UserID != invoice.UserID { - return fmt.Errorf("User ID's don't match : entry %v, invoice %v", entry, invoice) + svc.Logger.Errorf("User ID's don't match : entry %v, invoice %v", entry, invoice) + return } //ask lnd using TrackPaymentV2 by hash of payment rawHash, err := hex.DecodeString(invoice.RHash) if err != nil { - return err + svc.Logger.Errorf("Error tracking payment %v: %s", invoice, err.Error()) + return } - payment, err := svc.LndClient.TrackPayment(ctx, rawHash) + paymentTracker, err := svc.LndClient.SubscribePayment(ctx, &routerrpc.TrackPaymentRequest{ + PaymentHash: rawHash, + NoInflightUpdates: true, + }) if err != nil { - return err + svc.Logger.Errorf("Error tracking payment %v: %s", invoice, err.Error()) + return } //call HandleFailedPayment or HandleSuccesfulPayment - if payment.Status == lnrpc.Payment_FAILED { - svc.Logger.Infof("Failed payment detected: %v", payment) - //todo handle failed payment - //return svc.HandleFailedPayment(ctx, invoice, entry, fmt.Errorf(payment.FailureReason.String())) - return nil + for { + payment, err := paymentTracker.Recv() + if err != nil { + svc.Logger.Errorf("Error tracking payment %v: %s", invoice, err.Error()) + return + } + if payment.Status == lnrpc.Payment_FAILED { + svc.Logger.Infof("Failed payment detected: %v", payment) + //todo handle failed payment + //return svc.HandleFailedPayment(ctx, invoice, entry, fmt.Errorf(payment.FailureReason.String())) + return + } + if payment.Status == lnrpc.Payment_SUCCEEDED { + invoice.Fee = payment.FeeSat + invoice.Preimage = payment.PaymentPreimage + svc.Logger.Infof("Completed payment detected: %v", payment) + //todo handle succesful payment + //return svc.HandleSuccessfulPayment(ctx, invoice, entry) + return + } + //Since we shouldn't get in-flight updates we shouldn't get here + svc.Logger.Warnf("Got an unexpected in-flight update %v", payment) } - if payment.Status == lnrpc.Payment_SUCCEEDED { - invoice.Fee = payment.FeeSat - invoice.Preimage = payment.PaymentPreimage - svc.Logger.Infof("Completed payment detected: %v", payment) - //todo handle succesful payment - //return svc.HandleSuccessfulPayment(ctx, invoice, entry) - return nil - } - if payment.Status == lnrpc.Payment_IN_FLIGHT { - //todo handle inflight payment - svc.Logger.Infof("In-flight payment detected: %v", payment) - return nil - } - return nil } diff --git a/lnd/interface.go b/lnd/interface.go index f4b3da1..e7de050 100644 --- a/lnd/interface.go +++ b/lnd/interface.go @@ -4,6 +4,7 @@ import ( "context" "github.com/lightningnetwork/lnd/lnrpc" + "github.com/lightningnetwork/lnd/lnrpc/routerrpc" "google.golang.org/grpc" ) @@ -12,11 +13,14 @@ type LightningClientWrapper interface { SendPaymentSync(ctx context.Context, req *lnrpc.SendRequest, options ...grpc.CallOption) (*lnrpc.SendResponse, error) AddInvoice(ctx context.Context, req *lnrpc.Invoice, options ...grpc.CallOption) (*lnrpc.AddInvoiceResponse, error) SubscribeInvoices(ctx context.Context, req *lnrpc.InvoiceSubscription, options ...grpc.CallOption) (SubscribeInvoicesWrapper, error) + SubscribePayment(ctx context.Context, req *routerrpc.TrackPaymentRequest, options ...grpc.CallOption) (SubscribePaymentWrapper, error) GetInfo(ctx context.Context, req *lnrpc.GetInfoRequest, options ...grpc.CallOption) (*lnrpc.GetInfoResponse, error) DecodeBolt11(ctx context.Context, bolt11 string, options ...grpc.CallOption) (*lnrpc.PayReq, error) - TrackPayment(ctx context.Context, hash []byte, options ...grpc.CallOption) (*lnrpc.Payment, error) } type SubscribeInvoicesWrapper interface { Recv() (*lnrpc.Invoice, error) } +type SubscribePaymentWrapper interface { + Recv() (*lnrpc.Payment, error) +} diff --git a/lnd/lnd.go b/lnd/lnd.go index 2d4502e..ec044ba 100644 --- a/lnd/lnd.go +++ b/lnd/lnd.go @@ -126,6 +126,10 @@ func (wrapper *LNDWrapper) DecodeBolt11(ctx context.Context, bolt11 string, opti }) } +func (wrapper *LNDWrapper) SubscribePayment(ctx context.Context, req *routerrpc.TrackPaymentRequest, options ...grpc.CallOption) (SubscribePaymentWrapper, error) { + return wrapper.routerClient.TrackPaymentV2(ctx, req, options...) +} + func (wrapper *LNDWrapper) TrackPayment(ctx context.Context, hash []byte, options ...grpc.CallOption) (*lnrpc.Payment, error) { client, err := wrapper.routerClient.TrackPaymentV2(ctx, &routerrpc.TrackPaymentRequest{ PaymentHash: []byte(hash), diff --git a/main.go b/main.go index 9189512..4032e67 100644 --- a/main.go +++ b/main.go @@ -166,12 +166,11 @@ func main() { go svc.InvoiceUpdateSubscription(context.Background()) // Check the status of all pending outgoing payments - go func() { - err = svc.CheckAllPendingOutgoingPayments(context.Background()) - if err != nil { - svc.Logger.Error(err) - } - }() + // A goroutine will be spawned for each one + err = svc.CheckAllPendingOutgoingPayments(context.Background()) + if err != nil { + svc.Logger.Error(err) + } //Start webhook subscription if svc.Config.WebhookUrl != "" {