diff --git a/.github/workflows/integration_tests.yaml b/.github/workflows/integration_tests.yaml index ad3235d..3acf1ef 100644 --- a/.github/workflows/integration_tests.yaml +++ b/.github/workflows/integration_tests.yaml @@ -43,6 +43,4 @@ jobs: - name: Run tests run: go test -p 1 -v -covermode=atomic -coverprofile=coverage.out -cover -coverpkg=./... ./... env: - RABBITMQ_URI: amqp://root:password@localhost:5672 - - name: Upload coverage reports to Codecov - uses: codecov/codecov-action@v3 + RABBITMQ_URI: amqp://root:password@localhost:5672 \ No newline at end of file diff --git a/controllers/keysend.ctrl.go b/controllers/keysend.ctrl.go index 1ae22a9..add95a3 100644 --- a/controllers/keysend.ctrl.go +++ b/controllers/keysend.ctrl.go @@ -82,7 +82,7 @@ func (controller *KeySendController) KeySend(c echo.Context) error { }) } - ok, err := controller.svc.BalanceCheck(c.Request().Context(), lnPayReq, userID) + resp, err := controller.svc.CheckPaymentAllowed(c.Request().Context(), lnPayReq, userID) if err != nil { c.Logger().Errorj( log.JSON{ @@ -93,9 +93,9 @@ func (controller *KeySendController) KeySend(c echo.Context) error { ) return c.JSON(http.StatusBadRequest, responses.BadArgumentsError) } - if !ok { + if resp != nil { c.Logger().Errorf("User does not have enough balance user_id:%v amount:%v", userID, lnPayReq.PayReq.NumSatoshis) - return c.JSON(http.StatusBadRequest, responses.NotEnoughBalanceError) + return c.JSON(http.StatusBadRequest, resp) } invoice, err := controller.svc.AddOutgoingInvoice(c.Request().Context(), userID, "", lnPayReq) if err != nil { diff --git a/controllers/payinvoice.ctrl.go b/controllers/payinvoice.ctrl.go index 69bc67e..91bbc69 100644 --- a/controllers/payinvoice.ctrl.go +++ b/controllers/payinvoice.ctrl.go @@ -97,7 +97,7 @@ func (controller *PayInvoiceController) PayInvoice(c echo.Context) error { } } - ok, err := controller.svc.BalanceCheck(c.Request().Context(), lnPayReq, userID) + resp, err := controller.svc.CheckPaymentAllowed(c.Request().Context(), lnPayReq, userID) if err != nil { c.Logger().Errorj( log.JSON{ @@ -108,9 +108,9 @@ func (controller *PayInvoiceController) PayInvoice(c echo.Context) error { ) return c.JSON(http.StatusBadRequest, responses.GeneralServerError) } - if !ok { + if resp != nil { c.Logger().Errorf("User does not have enough balance user_id:%v amount:%v", userID, lnPayReq.PayReq.NumSatoshis) - return c.JSON(http.StatusBadRequest, responses.NotEnoughBalanceError) + return c.JSON(http.StatusBadRequest, resp) } invoice, err := controller.svc.AddOutgoingInvoice(c.Request().Context(), userID, paymentRequest, lnPayReq) diff --git a/controllers_v2/keysend.ctrl.go b/controllers_v2/keysend.ctrl.go index d947b8c..96490f0 100644 --- a/controllers_v2/keysend.ctrl.go +++ b/controllers_v2/keysend.ctrl.go @@ -172,14 +172,14 @@ func (controller *KeySendController) SingleKeySend(ctx context.Context, reqBody HttpStatusCode: 400, } } - ok, err := controller.svc.BalanceCheck(ctx, lnPayReq, userID) + resp, err := controller.svc.CheckPaymentAllowed(c.Request().Context(), lnPayReq, userID) if err != nil { controller.svc.Logger.Error(err) return nil, &responses.GeneralServerError } - if !ok { - controller.svc.Logger.Errorf("User does not have enough balance user_id:%v amount:%v", userID, lnPayReq.PayReq.NumSatoshis) - return nil, &responses.NotEnoughBalanceError + if resp != nil { + c.Logger().Errorf("User does not have enough balance user_id:%v amount:%v", userID, lnPayReq.PayReq.NumSatoshis) + return nil, resp } invoice, err := controller.svc.AddOutgoingInvoice(ctx, userID, "", lnPayReq) if err != nil { diff --git a/controllers_v2/payinvoice.ctrl.go b/controllers_v2/payinvoice.ctrl.go index 1724e81..ac14d64 100644 --- a/controllers_v2/payinvoice.ctrl.go +++ b/controllers_v2/payinvoice.ctrl.go @@ -98,7 +98,7 @@ func (controller *PayInvoiceController) PayInvoice(c echo.Context) error { } lnPayReq.PayReq.NumSatoshis = amt } - ok, err := controller.svc.BalanceCheck(c.Request().Context(), lnPayReq, userID) + resp, err := controller.svc.CheckPaymentAllowed(c.Request().Context(), lnPayReq, userID) if err != nil { c.Logger().Errorj( log.JSON{ @@ -109,9 +109,9 @@ func (controller *PayInvoiceController) PayInvoice(c echo.Context) error { ) return err } - if !ok { + if resp != nil { c.Logger().Errorf("User does not have enough balance user_id:%v amount:%v", userID, lnPayReq.PayReq.NumSatoshis) - return c.JSON(http.StatusInternalServerError, responses.NotEnoughBalanceError) + return c.JSON(http.StatusInternalServerError, resp) } invoice, err := controller.svc.AddOutgoingInvoice(c.Request().Context(), userID, paymentRequest, lnPayReq) if err != nil { diff --git a/db/migrations/20230928130000_add_invoice_settled_index.up.sql b/db/migrations/20230928130000_add_invoice_settled_index.up.sql new file mode 100644 index 0000000..e4b547b --- /dev/null +++ b/db/migrations/20230928130000_add_invoice_settled_index.up.sql @@ -0,0 +1,3 @@ +CREATE INDEX CONCURRENTLY IF NOT EXISTS index_invoices_on_user_id_settled_at + ON invoices(user_id, settled_at) + INCLUDE(amount); diff --git a/integration_tests/internal_payment_test.go b/integration_tests/internal_payment_test.go index 3d117fa..ec68810 100644 --- a/integration_tests/internal_payment_test.go +++ b/integration_tests/internal_payment_test.go @@ -133,6 +133,67 @@ func (suite *PaymentTestSuite) TestPaymentFeeReserve() { //reset fee reserve so it's not used in other tests suite.service.Config.FeeReserve = false } +func (suite *PaymentTestSuite) TestVolumeExceeded() { + //this will cause the payment to fail as the account was already funded + //with 1000 sats + suite.service.Config.MaxVolume = 999 + suite.service.Config.MaxVolumePeriod = 2592000 + aliceFundingSats := 1000 + //fund alice account + invoiceResponse := suite.createAddInvoiceReq(aliceFundingSats, "integration test internal payment alice", suite.aliceToken) + err := suite.mlnd.mockPaidInvoice(invoiceResponse, 0, false, nil) + assert.NoError(suite.T(), err) + + //wait a bit for the payment to be processed + time.Sleep(10 * time.Millisecond) + + //try to make external payment + //which should fail + //create external invoice + externalSatRequested := 1000 + externalInvoice := lnrpc.Invoice{ + Memo: "integration tests: external pay from user", + Value: int64(externalSatRequested), + } + invoice, err := suite.externalLND.AddInvoice(context.Background(), &externalInvoice) + assert.NoError(suite.T(), err) + //pay external invoice + rec := httptest.NewRecorder() + var buf bytes.Buffer + assert.NoError(suite.T(), json.NewEncoder(&buf).Encode(&ExpectedPayInvoiceRequestBody{ + Invoice: invoice.PaymentRequest, + })) + req := httptest.NewRequest(http.MethodPost, "/payinvoice", &buf) + req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) + req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", suite.aliceToken)) + suite.echo.ServeHTTP(rec, req) + //should fail because max volume check + assert.Equal(suite.T(), http.StatusBadRequest, rec.Code) + resp := &responses.ErrorResponse{} + err = json.NewDecoder(rec.Body).Decode(resp) + assert.NoError(suite.T(), err) + assert.Equal(suite.T(), responses.TooMuchVolumeError.Message, resp.Message) + + //change the period to be 1 second, sleep for 2 seconds, try to make another payment, this should work + suite.service.Config.MaxVolumePeriod = 1 + time.Sleep(2 * time.Second) + rec = httptest.NewRecorder() + externalInvoice = lnrpc.Invoice{ + Memo: "integration tests: external pay from user", + Value: int64(externalSatRequested), + } + invoice, err = suite.externalLND.AddInvoice(context.Background(), &externalInvoice) + assert.NoError(suite.T(), err) + assert.NoError(suite.T(), json.NewEncoder(&buf).Encode(&ExpectedPayInvoiceRequestBody{ + Invoice: invoice.PaymentRequest, + })) + suite.echo.ServeHTTP(rec, req) + assert.Equal(suite.T(), http.StatusOK, rec.Code) + + //change the config back + suite.service.Config.MaxVolumePeriod = 0 + suite.service.Config.MaxVolume = 1e6 +} func (suite *PaymentTestSuite) TestInternalPayment() { aliceFundingSats := 1000 bobSatRequested := 500 diff --git a/integration_tests/keysend_failure_test.go b/integration_tests/keysend_failure_test.go new file mode 100644 index 0000000..25659c7 --- /dev/null +++ b/integration_tests/keysend_failure_test.go @@ -0,0 +1,118 @@ +package integration_tests + +import ( + "context" + "fmt" + "log" + "testing" + "time" + + "github.com/getAlby/lndhub.go/common" + "github.com/getAlby/lndhub.go/controllers" + "github.com/getAlby/lndhub.go/lib" + "github.com/getAlby/lndhub.go/lib/responses" + "github.com/getAlby/lndhub.go/lib/service" + "github.com/getAlby/lndhub.go/lib/tokens" + "github.com/go-playground/validator/v10" + "github.com/labstack/echo/v4" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" +) + +type KeySendFailureTestSuite struct { + TestSuite + service *service.LndhubService + mlnd *MockLND + aliceLogin ExpectedCreateUserResponseBody + aliceToken string + invoiceUpdateSubCancelFn context.CancelFunc + serviceClient *LNDMockWrapperAsync +} + +func (suite *KeySendFailureTestSuite) TearDownTest() { + clearTable(suite.service, "transaction_entries") + clearTable(suite.service, "invoices") +} + +func (suite *KeySendFailureTestSuite) TearDownSuite() { + suite.invoiceUpdateSubCancelFn() +} + +func (suite *KeySendFailureTestSuite) SetupSuite() { + fee := int64(1) + mlnd := newDefaultMockLND() + mlnd.fee = fee + suite.mlnd = mlnd + // inject fake lnd client with failing send payment sync into service + lndClient, err := NewLNDMockWrapperAsync(mlnd) + suite.serviceClient = lndClient + if err != nil { + log.Fatalf("Error setting up test client: %v", err) + } + svc, err := LndHubTestServiceInit(lndClient) + if err != nil { + log.Fatalf("Error initializing test service: %v", err) + } + users, userTokens, err := createUsers(svc, 1) + if err != nil { + log.Fatalf("Error creating test users: %v", err) + } + // Subscribe to LND invoice updates in the background + // store cancel func to be called in tear down suite + ctx, cancel := context.WithCancel(context.Background()) + suite.invoiceUpdateSubCancelFn = cancel + go svc.InvoiceUpdateSubscription(ctx) + suite.service = svc + e := echo.New() + + e.HTTPErrorHandler = responses.HTTPErrorHandler + e.Validator = &lib.CustomValidator{Validator: validator.New()} + suite.echo = e + assert.Equal(suite.T(), 1, len(users)) + assert.Equal(suite.T(), 1, len(userTokens)) + suite.aliceLogin = users[0] + suite.aliceToken = userTokens[0] + suite.echo.Use(tokens.Middleware([]byte(suite.service.Config.JWTSecret))) + suite.echo.POST("/addinvoice", controllers.NewAddInvoiceController(suite.service).AddInvoice) + suite.echo.POST("/keysend", controllers.NewKeySendController(suite.service).KeySend) +} + + +func (suite *KeySendFailureTestSuite) TestKeysendPayment() { + aliceFundingSats := 1000 + externalSatRequested := 500 + //fund alice account + invoiceResponse := suite.createAddInvoiceReq(aliceFundingSats, "integration test external payment alice", suite.aliceToken) + err := suite.mlnd.mockPaidInvoice(invoiceResponse, 0, false, nil) + assert.NoError(suite.T(), err) + + //wait a bit for the callback event to hit + time.Sleep(10 * time.Millisecond) + + go suite.createKeySendReqError(int64(externalSatRequested), "key send test", "123456789012345678901234567890123456789012345678901234567890abcdef", suite.aliceToken) + + suite.serviceClient.FailPayment(SendPaymentMockError) + time.Sleep(2 * time.Second) + + // check that balance was reverted + userId := getUserIdFromToken(suite.aliceToken) + aliceBalance, err := suite.service.CurrentUserBalance(context.Background(), userId) + if err != nil { + fmt.Printf("Error when getting balance %v\n", err.Error()) + } + assert.Equal(suite.T(), int64(aliceFundingSats), aliceBalance) + + invoices, err := suite.service.InvoicesFor(context.Background(), userId, common.InvoiceTypeOutgoing) + if err != nil { + fmt.Printf("Error when getting invoices %v\n", err.Error()) + } + assert.Equal(suite.T(), 1, len(invoices)) + assert.Equal(suite.T(), common.InvoiceStateError, invoices[0].State) + assert.Equal(suite.T(), SendPaymentMockError, invoices[0].ErrorMessage) + assert.NotEmpty(suite.T(), invoices[0].RHash) + assert.NotEmpty(suite.T(), invoices[0].Preimage) +} + +func TestKeySendFailureTestSuite(t *testing.T) { + suite.Run(t, new(KeySendFailureTestSuite)) +} diff --git a/integration_tests/util.go b/integration_tests/util.go index 1d7812d..2fe3e7d 100644 --- a/integration_tests/util.go +++ b/integration_tests/util.go @@ -51,6 +51,7 @@ func LndHubTestServiceInit(lndClientMock lnd.LightningClientWrapper) (svc *servi DatabaseMaxConns: 1, DatabaseMaxIdleConns: 1, DatabaseConnMaxLifetime: 10, + MaxFeeAmount: 1e6, JWTSecret: []byte("SECRET"), JWTAccessTokenExpiry: 3600, JWTRefreshTokenExpiry: 3600, diff --git a/lib/responses/errors.go b/lib/responses/errors.go index 5a293fd..d8126ea 100644 --- a/lib/responses/errors.go +++ b/lib/responses/errors.go @@ -64,6 +64,13 @@ var NotEnoughBalanceError = ErrorResponse{ HttpStatusCode: 400, } +var TooMuchVolumeError = ErrorResponse{ + Error: true, + Code: 2, + Message: "transaction volume too high. please contact support for further assistance.", + HttpStatusCode: 400, +} + var AccountDeactivatedError = ErrorResponse{ Error: true, Code: 1, diff --git a/lib/service/config.go b/lib/service/config.go index 1ceb911..fa53971 100644 --- a/lib/service/config.go +++ b/lib/service/config.go @@ -37,6 +37,8 @@ type Config struct { MaxSendAmount int64 `envconfig:"MAX_SEND_AMOUNT" default:"0"` MaxAccountBalance int64 `envconfig:"MAX_ACCOUNT_BALANCE" default:"0"` MaxFeeAmount int64 `envconfig:"MAX_FEE_AMOUNT" default:"5000"` + MaxVolume int64 `envconfig:"MAX_VOLUME" default:"0"` //0 means the volume check is disabled by default + MaxVolumePeriod int64 `envconfig:"MAX_VOLUME_PERIOD" default:"2592000"` //in seconds, default 1 month RabbitMQUri string `envconfig:"RABBITMQ_URI"` RabbitMQLndhubInvoiceExchange string `envconfig:"RABBITMQ_INVOICE_EXCHANGE" default:"lndhub_invoice"` RabbitMQLndInvoiceExchange string `envconfig:"RABBITMQ_LND_INVOICE_EXCHANGE" default:"lnd_invoice"` diff --git a/lib/service/invoices.go b/lib/service/invoices.go index 634cee2..3a8cd6b 100644 --- a/lib/service/invoices.go +++ b/lib/service/invoices.go @@ -160,23 +160,25 @@ func (svc *LndhubService) createLnRpcSendRequest(invoice *models.Invoice) (*lnrp }, nil } - preImage, err := makePreimageHex() - if err != nil { - return nil, err - } - pHash := sha256.New() - pHash.Write(preImage) // Prepare the LNRPC call //See: https://github.com/hsjoberg/blixt-wallet/blob/9fcc56a7dc25237bc14b85e6490adb9e044c009c/src/lndmobile/index.ts#L251-L270 destBytes, err := hex.DecodeString(invoice.DestinationPubkeyHex) if err != nil { return nil, err } + preImage, err := hex.DecodeString(invoice.Preimage) + if err != nil { + return nil, err + } invoice.DestinationCustomRecords[KEYSEND_CUSTOM_RECORD] = preImage + paymentHash, err := hex.DecodeString(invoice.RHash) + if err != nil { + return nil, err + } return &lnrpc.SendRequest{ Dest: destBytes, Amt: invoice.Amount, - PaymentHash: pHash.Sum(nil), + PaymentHash: paymentHash, FeeLimit: &feeLimit, DestFeatures: []lnrpc.FeatureBit{lnrpc.FeatureBit_TLV_ONION_REQ}, DestCustomRecords: invoice.DestinationCustomRecords, @@ -453,6 +455,18 @@ func (svc *LndhubService) AddOutgoingInvoice(ctx context.Context, userID int64, ExpiresAt: bun.NullTime{Time: time.Unix(lnPayReq.PayReq.Timestamp, 0).Add(time.Duration(lnPayReq.PayReq.Expiry) * time.Second)}, } + if lnPayReq.Keysend { + preImage, err := makePreimageHex() + if err != nil { + return nil, err + } + pHash := sha256.New() + pHash.Write(preImage) + + invoice.RHash = hex.EncodeToString(pHash.Sum(nil)) + invoice.Preimage = hex.EncodeToString(preImage) + } + // Save invoice _, err := svc.DB.NewInsert().Model(&invoice).Exec(ctx) if err != nil { diff --git a/lib/service/invoices_test.go b/lib/service/invoices_test.go index 6ae55d9..5ee1536 100644 --- a/lib/service/invoices_test.go +++ b/lib/service/invoices_test.go @@ -10,6 +10,9 @@ import ( var svc = &LndhubService{ LndClient: &lnd.LNDWrapper{IdentityPubkey: "123pubkey"}, + Config: &Config{ + MaxFeeAmount: 1e6, + }, } func TestCalcFeeWithInvoiceLessThan1000(t *testing.T) { @@ -42,3 +45,14 @@ func TestCalcFeeWithInvoiceMoreThan1000(t *testing.T) { expectedFee := int64(16) assert.Equal(t, expectedFee, feeLimit) } + +func TestCalcFeeWithMaxGlobalFee(t *testing.T) { + invoice := &models.Invoice{ + Amount: 1500, + } + svc.Config.MaxFeeAmount = 1 + + feeLimit := svc.CalcFeeLimit("dummy", invoice.Amount) + expectedFee := svc.Config.MaxFeeAmount + assert.Equal(t, expectedFee, feeLimit) +} diff --git a/lib/service/invoicesubscription.go b/lib/service/invoicesubscription.go index b9d24f8..8e1083d 100644 --- a/lib/service/invoicesubscription.go +++ b/lib/service/invoicesubscription.go @@ -2,7 +2,6 @@ package service import ( "context" - "crypto/sha256" "database/sql" "encoding/hex" "errors" @@ -27,12 +26,6 @@ func (svc *LndhubService) HandleInternalKeysendPayment(ctx context.Context, invo if err != nil { return nil, err } - preImage, err := makePreimageHex() - if err != nil { - return nil, err - } - pHash := sha256.New() - pHash.Write(preImage) expiry := time.Hour * 24 incomingInvoice := models.Invoice{ Type: common.InvoiceTypeIncoming, @@ -43,8 +36,8 @@ func (svc *LndhubService) HandleInternalKeysendPayment(ctx context.Context, invo State: common.InvoiceStateInitialized, ExpiresAt: bun.NullTime{Time: time.Now().Add(expiry)}, Keysend: true, - RHash: hex.EncodeToString(pHash.Sum(nil)), - Preimage: hex.EncodeToString(preImage), + RHash: invoice.RHash, + Preimage: invoice.Preimage, DestinationCustomRecords: invoice.DestinationCustomRecords, DestinationPubkeyHex: svc.LndClient.GetMainPubkey(), AddIndex: invoice.AddIndex, diff --git a/lib/service/user.go b/lib/service/user.go index a1e9a75..26d6d4d 100644 --- a/lib/service/user.go +++ b/lib/service/user.go @@ -5,11 +5,14 @@ import ( "database/sql" "fmt" "math" + "time" "github.com/getAlby/lndhub.go/common" "github.com/getAlby/lndhub.go/db/models" + "github.com/getAlby/lndhub.go/lib/responses" "github.com/getAlby/lndhub.go/lib/security" "github.com/getAlby/lndhub.go/lnd" + "github.com/getsentry/sentry-go" "github.com/uptrace/bun" passwordvalidator "github.com/wagslane/go-password-validator" ) @@ -121,17 +124,31 @@ func (svc *LndhubService) FindUserByLogin(ctx context.Context, login string) (*m return &user, nil } -func (svc *LndhubService) BalanceCheck(ctx context.Context, lnpayReq *lnd.LNPayReq, userId int64) (ok bool, err error) { +func (svc *LndhubService) CheckPaymentAllowed(ctx context.Context, lnpayReq *lnd.LNPayReq, userId int64) (result *responses.ErrorResponse, err error) { currentBalance, err := svc.CurrentUserBalance(ctx, userId) if err != nil { - return false, err + return nil, err } minimumBalance := lnpayReq.PayReq.NumSatoshis if svc.Config.FeeReserve { minimumBalance += svc.CalcFeeLimit(lnpayReq.PayReq.Destination, lnpayReq.PayReq.NumSatoshis) } - return currentBalance >= minimumBalance, nil + if currentBalance < minimumBalance { + return &responses.NotEnoughBalanceError, nil + } + //only check for volume if configured + if svc.Config.MaxVolume > 0 { + volume, err := svc.GetVolumeOverPeriod(ctx, userId, time.Duration(svc.Config.MaxVolumePeriod*int64(time.Second))) + if err != nil { + return nil, err + } + if volume > svc.Config.MaxVolume { + sentry.CaptureMessage(fmt.Sprintf("transaction volume exceeded for user %d", userId)) + return &responses.TooMuchVolumeError, nil + } + } + return nil, nil } func (svc *LndhubService) CalcFeeLimit(destination string, amount int64) int64 { @@ -185,3 +202,16 @@ func (svc *LndhubService) InvoicesFor(ctx context.Context, userId int64, invoice } return invoices, nil } + +func (svc *LndhubService) GetVolumeOverPeriod(ctx context.Context, userId int64, period time.Duration) (result int64, err error) { + + err = svc.DB.NewSelect().Table("invoices"). + ColumnExpr("sum(invoices.amount) as result"). + Where("invoices.user_id = ?", userId). + Where("invoices.settled_at >= ?", time.Now().Add(-1*period)). + Scan(ctx, &result) + if err != nil { + return 0, err + } + return result, nil +}