diff --git a/integration_tests/internal_payment_test.go b/integration_tests/internal_payment_test.go index 950a1dc..dcc9c7f 100644 --- a/integration_tests/internal_payment_test.go +++ b/integration_tests/internal_payment_test.go @@ -69,7 +69,6 @@ func (suite *PaymentTestSuite) SetupSuite() { suite.aliceToken = userTokens[0] suite.bobLogin = users[1] suite.bobToken = userTokens[1] - suite.echo.Use(tokens.Middleware([]byte(suite.service.Config.JWTSecret))) suite.echo.GET("/balance", controllers.NewBalanceController(suite.service).Balance) suite.echo.POST("/addinvoice", controllers.NewAddInvoiceController(suite.service).AddInvoice) diff --git a/lib/service/user.go b/lib/service/user.go index 8f25cca..196a9e5 100644 --- a/lib/service/user.go +++ b/lib/service/user.go @@ -134,6 +134,25 @@ func (svc *LndhubService) CheckOutgoingPaymentAllowed(ctx context.Context, lnpay } } + if limits.MaxSendVolume > 0 { + volume, err := svc.GetVolumeOverPeriod(ctx, userId, common.InvoiceTypeOutgoing, time.Duration(svc.Config.MaxVolumePeriod*int64(time.Second))) + if err != nil { + svc.Logger.Errorj( + log.JSON{ + "message": "error fetching volume", + "error": err, + "lndhub_user_id": userId, + }, + ) + return nil, err + } + if volume > limits.MaxSendVolume { + svc.Logger.Errorf("Transaction volume exceeded for user_id %d", userId) + sentry.CaptureMessage(fmt.Sprintf("transaction volume exceeded for user %d", userId)) + return &responses.TooMuchVolumeError, nil + } + } + currentBalance, err := svc.CurrentUserBalance(ctx, userId) if err != nil { svc.Logger.Errorj( @@ -154,7 +173,7 @@ func (svc *LndhubService) CheckOutgoingPaymentAllowed(ctx context.Context, lnpay return &responses.NotEnoughBalanceError, nil } - return svc.CheckVolumeAllowed(ctx, userId, limits.MaxSendVolume, common.InvoiceTypeOutgoing) + return nil, nil } func (svc *LndhubService) CheckIncomingPaymentAllowed(ctx context.Context, amount, userId int64, limits *models.Limits) (result *responses.ErrorResponse, err error) { @@ -165,6 +184,25 @@ func (svc *LndhubService) CheckIncomingPaymentAllowed(ctx context.Context, amoun } } + if limits.MaxReceiveVolume > 0 { + volume, err := svc.GetVolumeOverPeriod(ctx, userId, common.InvoiceTypeIncoming, time.Duration(svc.Config.MaxVolumePeriod*int64(time.Second))) + if err != nil { + svc.Logger.Errorj( + log.JSON{ + "message": "error fetching volume", + "error": err, + "lndhub_user_id": userId, + }, + ) + return nil, err + } + if volume > limits.MaxReceiveVolume { + svc.Logger.Errorf("Transaction volume exceeded for user_id %d", userId) + sentry.CaptureMessage(fmt.Sprintf("transaction volume exceeded for user %d", userId)) + return &responses.TooMuchVolumeError, nil + } + } + if limits.MaxAccountBalance > 0 { currentBalance, err := svc.CurrentUserBalance(ctx, userId) if err != nil { @@ -183,32 +221,9 @@ func (svc *LndhubService) CheckIncomingPaymentAllowed(ctx context.Context, amoun } } - return svc.CheckVolumeAllowed(ctx, userId, limits.MaxReceiveVolume, common.InvoiceTypeIncoming) -} - -func (svc *LndhubService) CheckVolumeAllowed(ctx context.Context, userId, maxVolume int64, invoiceType string) (result *responses.ErrorResponse, err error) { - if maxVolume > 0 { - volume, err := svc.GetVolumeOverPeriod(ctx, userId, invoiceType, time.Duration(svc.Config.MaxVolumePeriod*int64(time.Second))) - if err != nil { - svc.Logger.Errorj( - log.JSON{ - "message": "error fetching volume", - "error": err, - "lndhub_user_id": userId, - }, - ) - return nil, err - } - if volume > maxVolume { - svc.Logger.Errorf("Transaction volume exceeded for user_id %d", userId) - sentry.CaptureMessage(fmt.Sprintf("transaction volume exceeded for user %d", userId)) - return &responses.TooMuchVolumeError, nil - } - } return nil, nil } - func (svc *LndhubService) CalcFeeLimit(destination string, amount int64) int64 { if svc.LndClient.IsIdentityPubkey(destination) { return 0