diff --git a/main.go b/main.go index aeda4c0..4731195 100644 --- a/main.go +++ b/main.go @@ -160,7 +160,12 @@ func main() { // No rabbitmq features will be available in this case. var rabbitmqClient rabbitmq.Client if c.RabbitMQUri != "" { - rabbitmqClient, err = rabbitmq.Dial(c.RabbitMQUri, + amqpClient, err := rabbitmq.DialAMQP(c.RabbitMQUri) + if err != nil { + logger.Fatal(err) + } + + rabbitmqClient, err = rabbitmq.NewClient(amqpClient, rabbitmq.WithLogger(logger), rabbitmq.WithLndInvoiceExchange(c.RabbitMQLndInvoiceExchange), rabbitmq.WithLndHubInvoiceExchange(c.RabbitMQLndhubInvoiceExchange), diff --git a/rabbitmq/amqp.go b/rabbitmq/amqp.go index 4911273..68b2fe9 100644 --- a/rabbitmq/amqp.go +++ b/rabbitmq/amqp.go @@ -13,7 +13,7 @@ type AMQPClient interface { Close() error } -type DefaultAMQPCLient struct { +type defaultAMQPCLient struct { conn *amqp.Connection // It is recommended that, when possible, publishers and consumers @@ -23,9 +23,9 @@ type DefaultAMQPCLient struct { publishChannel *amqp.Channel } -func (c *DefaultAMQPCLient) Close() error { return c.conn.Close() } +func (c *defaultAMQPCLient) Close() error { return c.conn.Close() } -func (c *DefaultAMQPCLient) ExchangeDeclare(name, kind string, durable, autoDelete, internal, noWait bool, args amqp.Table) error { +func (c *defaultAMQPCLient) ExchangeDeclare(name, kind string, durable, autoDelete, internal, noWait bool, args amqp.Table) error { // TODO: Seperate management channel? Or provide way to select channel? ch, err := c.conn.Channel() if err != nil { @@ -38,7 +38,7 @@ func (c *DefaultAMQPCLient) ExchangeDeclare(name, kind string, durable, autoDele } -type listenOptions struct { +type ListenOptions struct { Durable bool AutoDelete bool Internal bool @@ -47,52 +47,52 @@ type listenOptions struct { AutoAck bool } -type AMQPListenOptions = func(opts listenOptions) listenOptions +type AMQPListenOptions = func(opts ListenOptions) ListenOptions func WithDurable(durable bool) AMQPListenOptions { - return func(opts listenOptions) listenOptions { + return func(opts ListenOptions) ListenOptions { opts.Durable = durable return opts } } func WithAutoDelete(autoDelete bool) AMQPListenOptions { - return func(opts listenOptions) listenOptions { + return func(opts ListenOptions) ListenOptions { opts.AutoDelete = autoDelete return opts } } func WithInternal(internal bool) AMQPListenOptions { - return func(opts listenOptions) listenOptions { + return func(opts ListenOptions) ListenOptions { opts.Internal = internal return opts } } func WithWait(wait bool) AMQPListenOptions { - return func(opts listenOptions) listenOptions { + return func(opts ListenOptions) ListenOptions { opts.Wait = wait return opts } } func WithExclusive(exclusive bool) AMQPListenOptions { - return func(opts listenOptions) listenOptions { + return func(opts ListenOptions) ListenOptions { opts.Exclusive = exclusive return opts } } func WithAutoAck(autoAck bool) AMQPListenOptions { - return func(opts listenOptions) listenOptions { + return func(opts ListenOptions) ListenOptions { opts.AutoAck = autoAck return opts } } -func (c *DefaultAMQPCLient) Listen(ctx context.Context, exchange string, routingKey string, queueName string, options ...AMQPListenOptions) (<-chan amqp.Delivery, error) { - opts := listenOptions{ +func (c *defaultAMQPCLient) Listen(ctx context.Context, exchange string, routingKey string, queueName string, options ...AMQPListenOptions) (<-chan amqp.Delivery, error) { + opts := ListenOptions{ Durable: true, AutoDelete: false, Internal: false, @@ -167,11 +167,11 @@ func (c *DefaultAMQPCLient) Listen(ctx context.Context, exchange string, routing ) } -func (c *DefaultAMQPCLient) PublishWithContext(ctx context.Context, exchange string, key string, mandatory bool, immediate bool, msg amqp.Publishing) error { +func (c *defaultAMQPCLient) PublishWithContext(ctx context.Context, exchange string, key string, mandatory bool, immediate bool, msg amqp.Publishing) error { return c.publishChannel.PublishWithContext(ctx, exchange, key, mandatory, immediate, msg) } -func Dial(uri string) (AMQPClient, error) { +func DialAMQP(uri string) (AMQPClient, error) { conn, err := amqp.Dial(uri) if err != nil { return nil, err @@ -187,7 +187,7 @@ func Dial(uri string) (AMQPClient, error) { return nil, err } - return &DefaultAMQPCLient{ + return &defaultAMQPCLient{ conn, consumeChannel, publishChannel, diff --git a/rabbitmq/rabbitmqmocks/rabbitmq.go b/rabbitmq/mock_rabbitmq/rabbitmq.go similarity index 51% rename from rabbitmq/rabbitmqmocks/rabbitmq.go rename to rabbitmq/mock_rabbitmq/rabbitmq.go index 7a6a35a..8c6661f 100644 --- a/rabbitmq/rabbitmqmocks/rabbitmq.go +++ b/rabbitmq/mock_rabbitmq/rabbitmq.go @@ -1,15 +1,17 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/getAlby/lndhub.go/rabbitmq (interfaces: LndHubService) +// Source: github.com/getAlby/lndhub.go/rabbitmq (interfaces: LndHubService,AMQPClient) -// Package rabbitmqmocks is a generated GoMock package. -package rabbitmqmocks +// Package mock_rabbitmq is a generated GoMock package. +package mock_rabbitmq import ( context "context" reflect "reflect" models "github.com/getAlby/lndhub.go/db/models" + rabbitmq "github.com/getAlby/lndhub.go/rabbitmq" gomock "github.com/golang/mock/gomock" + amqp091 "github.com/rabbitmq/amqp091-go" ) // MockLndHubService is a mock of LndHubService interface. @@ -92,3 +94,88 @@ func (mr *MockLndHubServiceMockRecorder) HandleSuccessfulPayment(arg0, arg1, arg mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleSuccessfulPayment", reflect.TypeOf((*MockLndHubService)(nil).HandleSuccessfulPayment), arg0, arg1, arg2) } + +// MockAMQPClient is a mock of AMQPClient interface. +type MockAMQPClient struct { + ctrl *gomock.Controller + recorder *MockAMQPClientMockRecorder +} + +// MockAMQPClientMockRecorder is the mock recorder for MockAMQPClient. +type MockAMQPClientMockRecorder struct { + mock *MockAMQPClient +} + +// NewMockAMQPClient creates a new mock instance. +func NewMockAMQPClient(ctrl *gomock.Controller) *MockAMQPClient { + mock := &MockAMQPClient{ctrl: ctrl} + mock.recorder = &MockAMQPClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockAMQPClient) EXPECT() *MockAMQPClientMockRecorder { + return m.recorder +} + +// Close mocks base method. +func (m *MockAMQPClient) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockAMQPClientMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockAMQPClient)(nil).Close)) +} + +// ExchangeDeclare mocks base method. +func (m *MockAMQPClient) ExchangeDeclare(arg0, arg1 string, arg2, arg3, arg4, arg5 bool, arg6 amqp091.Table) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ExchangeDeclare", arg0, arg1, arg2, arg3, arg4, arg5, arg6) + ret0, _ := ret[0].(error) + return ret0 +} + +// ExchangeDeclare indicates an expected call of ExchangeDeclare. +func (mr *MockAMQPClientMockRecorder) ExchangeDeclare(arg0, arg1, arg2, arg3, arg4, arg5, arg6 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExchangeDeclare", reflect.TypeOf((*MockAMQPClient)(nil).ExchangeDeclare), arg0, arg1, arg2, arg3, arg4, arg5, arg6) +} + +// Listen mocks base method. +func (m *MockAMQPClient) Listen(arg0 context.Context, arg1, arg2, arg3 string, arg4 ...func(rabbitmq.ListenOptions) rabbitmq.ListenOptions) (<-chan amqp091.Delivery, error) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0, arg1, arg2, arg3} + for _, a := range arg4 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Listen", varargs...) + ret0, _ := ret[0].(<-chan amqp091.Delivery) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Listen indicates an expected call of Listen. +func (mr *MockAMQPClientMockRecorder) Listen(arg0, arg1, arg2, arg3 interface{}, arg4 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0, arg1, arg2, arg3}, arg4...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Listen", reflect.TypeOf((*MockAMQPClient)(nil).Listen), varargs...) +} + +// PublishWithContext mocks base method. +func (m *MockAMQPClient) PublishWithContext(arg0 context.Context, arg1, arg2 string, arg3, arg4 bool, arg5 amqp091.Publishing) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "PublishWithContext", arg0, arg1, arg2, arg3, arg4, arg5) + ret0, _ := ret[0].(error) + return ret0 +} + +// PublishWithContext indicates an expected call of PublishWithContext. +func (mr *MockAMQPClientMockRecorder) PublishWithContext(arg0, arg1, arg2, arg3, arg4, arg5 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PublishWithContext", reflect.TypeOf((*MockAMQPClient)(nil).PublishWithContext), arg0, arg1, arg2, arg3, arg4, arg5) +} diff --git a/rabbitmq/rabbitmq.go b/rabbitmq/rabbitmq.go index cb1dffc..ef7884e 100644 --- a/rabbitmq/rabbitmq.go +++ b/rabbitmq/rabbitmq.go @@ -45,6 +45,14 @@ type Client interface { Close() error } +type ClientConfig struct { + lndInvoiceConsumerQueueName string + lndPaymentConsumerQueueName string + lndInvoiceExchange string + lndPaymentExchange string + lndHubInvoiceExchange string +} + type LndHubService interface { HandleFailedPayment(context.Context, *models.Invoice, models.TransactionEntry, error) error HandleSuccessfulPayment(context.Context, *models.Invoice, models.TransactionEntry) error @@ -53,45 +61,41 @@ type LndHubService interface { } type DefaultClient struct { - amqpClient AMQPClient - logger *lecho.Logger + amqpClient AMQPClient + logger *lecho.Logger - lndInvoiceConsumerQueueName string - lndPaymentConsumerQueueName string - lndInvoiceExchange string - lndPaymentExchange string - lndHubInvoiceExchange string + config ClientConfig } type ClientOption = func(client *DefaultClient) func WithLndInvoiceExchange(exchange string) ClientOption { return func(client *DefaultClient) { - client.lndInvoiceExchange = exchange + client.config.lndInvoiceExchange = exchange } } func WithLndHubInvoiceExchange(exchange string) ClientOption { return func(client *DefaultClient) { - client.lndHubInvoiceExchange = exchange + client.config.lndHubInvoiceExchange = exchange } } func WithLndInvoiceConsumerQueueName(name string) ClientOption { return func(client *DefaultClient) { - client.lndInvoiceConsumerQueueName = name + client.config.lndInvoiceConsumerQueueName = name } } func WithLndPaymentConsumerQueueName(name string) ClientOption { return func(client *DefaultClient) { - client.lndPaymentConsumerQueueName = name + client.config.lndPaymentConsumerQueueName = name } } func WithLndPaymentExchange(exchange string) ClientOption { return func(client *DefaultClient) { - client.lndPaymentExchange = exchange + client.config.lndPaymentExchange = exchange } } @@ -104,7 +108,7 @@ func WithLogger(logger *lecho.Logger) ClientOption { // Dial sets up a connection to rabbitmq with two channels that are ready to produce and consume func NewClient(amqpClient AMQPClient, options ...ClientOption) (Client, error) { client := &DefaultClient{ - amqpClient: amqpClient, + amqpClient: amqpClient, logger: lecho.New( os.Stdout, @@ -112,11 +116,13 @@ func NewClient(amqpClient AMQPClient, options ...ClientOption) (Client, error) { lecho.WithTimestamp(), ), - lndInvoiceConsumerQueueName: "lnd_invoice_consumer", - lndPaymentConsumerQueueName: "lnd_payment_consumer", - lndInvoiceExchange: "lnd_invoice", - lndPaymentExchange: "lnd_payment", - lndHubInvoiceExchange: "lndhub_invoice", + config: ClientConfig{ + lndInvoiceConsumerQueueName: "lnd_invoice_consumer", + lndPaymentConsumerQueueName: "lnd_payment_consumer", + lndInvoiceExchange: "lnd_invoice", + lndPaymentExchange: "lnd_payment", + lndHubInvoiceExchange: "lndhub_invoice", + }, } for _, opt := range options { @@ -130,11 +136,11 @@ func (client *DefaultClient) Close() error { return client.amqpClient.Close() } func (client *DefaultClient) FinalizeInitializedPayments(ctx context.Context, svc LndHubService) error { deliveryChan, err := client.amqpClient.Listen( - ctx, - client.lndPaymentExchange, - "payment.outgoing.*", - client.lndPaymentConsumerQueueName, - ) + ctx, + client.config.lndPaymentExchange, + "payment.outgoing.*", + client.config.lndPaymentConsumerQueueName, + ) if err != nil { return err } @@ -157,30 +163,37 @@ func (client *DefaultClient) FinalizeInitializedPayments(ctx context.Context, sv if err != nil { return err } + client.logger.Infof("Payment finalizer: Found %d pending invoices", len(pendingInvoices)) ticker := time.NewTicker(time.Hour) defer ticker.Stop() client.logger.Info("Starting payment finalizer rabbitmq consumer") + for { + // Shortcircuit if no pending invoices are left + if len(pendingInvoices) == 0 { + client.logger.Info("Payment finalizer: Resolved all pending payments, exiting payment finalizer routine") + + return nil + } + select { case <-ctx.Done(): return context.Canceled + case <-ticker.C: invoices, err := getInvoicesTable(ctx) - pendingInvoices = invoices - client.logger.Infof("Payment finalizer: Found %d pending invoices", len(pendingInvoices)) if err != nil { return err } - case delivery, ok := <-deliveryChan: - // Shortcircuit if no pending invoices are left - if len(pendingInvoices) == 0 { - client.logger.Info("Payment finalizer: Resolved all pending payments, exiting payment finalizer routine") - return nil - } + pendingInvoices = invoices + + client.logger.Infof("Payment finalizer: Found %d pending invoices", len(pendingInvoices)) + + case delivery, ok := <-deliveryChan: if !ok { return err } @@ -215,6 +228,7 @@ func (client *DefaultClient) FinalizeInitializedPayments(ctx context.Context, sv continue } + client.logger.Infof("Payment finalizer: updated successful payment with hash: %s", payment.PaymentHash) delete(pendingInvoices, payment.PaymentHash) @@ -225,17 +239,18 @@ func (client *DefaultClient) FinalizeInitializedPayments(ctx context.Context, sv continue } + client.logger.Infof("Payment finalizer: updated failed payment with hash: %s", payment.PaymentHash) delete(pendingInvoices, payment.PaymentHash) } } delivery.Ack(false) } - } + } } func (client *DefaultClient) SubscribeToLndInvoices(ctx context.Context, handler IncomingInvoiceHandler) error { - deliveryChan, err := client.amqpClient.Listen(ctx, client.lndInvoiceExchange, "invoice.incoming.settled", client.lndInvoiceConsumerQueueName) + deliveryChan, err := client.amqpClient.Listen(ctx, client.config.lndInvoiceExchange, "invoice.incoming.settled", client.config.lndInvoiceConsumerQueueName) if err != nil { return err } @@ -291,7 +306,7 @@ func (client *DefaultClient) SubscribeToLndInvoices(ctx context.Context, handler func (client *DefaultClient) StartPublishInvoices(ctx context.Context, invoicesSubscribeFunc SubscribeToInvoicesFunc, payloadFunc EncodeOutgoingInvoiceFunc) error { err := client.amqpClient.ExchangeDeclare( - client.lndHubInvoiceExchange, + client.config.lndHubInvoiceExchange, // topic is a type of exchange that allows routing messages to different queue's bases on a routing key "topic", // Durable and Non-Auto-Deleted exchanges will survive server restarts and remain @@ -346,7 +361,7 @@ func (client *DefaultClient) publishToLndhubExchange(ctx context.Context, invoic key := fmt.Sprintf("invoice.%s.%s", invoice.Type, invoice.State) err = client.amqpClient.PublishWithContext(ctx, - client.lndHubInvoiceExchange, + client.config.lndHubInvoiceExchange, key, false, false, diff --git a/rabbitmq/rabbitmq_test.go b/rabbitmq/rabbitmq_test.go index 092d502..fbcf14f 100644 --- a/rabbitmq/rabbitmq_test.go +++ b/rabbitmq/rabbitmq_test.go @@ -1,11 +1,106 @@ package rabbitmq_test import ( + "context" + "encoding/json" + "sync" "testing" + "time" + + "github.com/getAlby/lndhub.go/db/models" + "github.com/getAlby/lndhub.go/rabbitmq" + "github.com/getAlby/lndhub.go/rabbitmq/mock_rabbitmq" + "github.com/golang/mock/gomock" + "github.com/lightningnetwork/lnd/lnrpc" + amqp "github.com/rabbitmq/amqp091-go" + "github.com/stretchr/testify/assert" ) -//go:generate mockgen -destination=./rabbitmqmocks/rabbitmq.go -package rabbitmqmocks github.com/getAlby/lndhub.go/rabbitmq LndHubService +//go:generate mockgen -destination=./mock_rabbitmq/rabbitmq.go github.com/getAlby/lndhub.go/rabbitmq LndHubService,AMQPClient func TestFinalizedInitializedPayments(t *testing.T) { t.Parallel() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + lndHubService := mock_rabbitmq.NewMockLndHubService(ctrl) + amqpClient := mock_rabbitmq.NewMockAMQPClient(ctrl) + + client, err := rabbitmq.NewClient(amqpClient) + assert.NoError(t, err) + + ch := make(chan amqp.Delivery, 1) + amqpClient.EXPECT(). + Listen(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + MaxTimes(1). + Return(ch, nil) + + hash := "69e5f0f0590be75e30f671d56afe1d55" + + invoices := []models.Invoice{ + { + ID: 0, + RHash: hash, + }, + } + + lndHubService.EXPECT(). + GetAllPendingPayments(gomock.Any()). + MaxTimes(1). + Return(invoices, nil) + + lndHubService.EXPECT(). + HandleFailedPayment(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + AnyTimes(). + Return(nil) + + lndHubService.EXPECT(). + HandleSuccessfulPayment(gomock.Any(), gomock.Any(), gomock.Any()). + AnyTimes(). + Return(nil) + + lndHubService.EXPECT(). + GetTransactionEntryByInvoiceId(gomock.Any(), gomock.Eq(invoices[0].ID)). + AnyTimes(). + Return(models.TransactionEntry{InvoiceID: invoices[0].ID}, nil) + + ctx := context.Background() + b, err := json.Marshal(&lnrpc.Payment{PaymentHash: hash, Status: lnrpc.Payment_SUCCEEDED}) + if err != nil { + t.Error(err) + } + + ch <- amqp.Delivery{Body: b} + + wg := sync.WaitGroup{} + + wg.Add(1) + go func() { + err = client.FinalizeInitializedPayments(ctx, lndHubService) + + assert.NoError(t, err) + wg.Done() + }() + + waitTimeout(&wg, time.Second * 3, t) +} + +// waitTimeout waits for the waitgroup for the specified max timeout. +// Returns true if waiting timed out. +func waitTimeout(wg *sync.WaitGroup, timeout time.Duration, t *testing.T) bool { + c := make(chan struct{}) + go func() { + defer close(c) + wg.Wait() + }() + + select { + case <-c: + return false // completed normally + + case <-time.After(timeout): + t.Errorf("Waiting on waitgroup timed out during test") + + return true // timed out + } }