diff --git a/controllers/addinvoice.ctrl.go b/controllers/addinvoice.ctrl.go index 8a6230b..8eda455 100644 --- a/controllers/addinvoice.ctrl.go +++ b/controllers/addinvoice.ctrl.go @@ -5,7 +5,6 @@ import ( "github.com/getAlby/lndhub.go/lib/responses" "github.com/getAlby/lndhub.go/lib/service" - "github.com/getsentry/sentry-go" "github.com/labstack/echo/v4" "github.com/labstack/gommon/log" ) @@ -62,38 +61,11 @@ func AddInvoice(c echo.Context, svc *service.LndhubService, userID int64) error return c.JSON(http.StatusBadRequest, responses.BadArgumentsError) } - if svc.Config.MaxReceiveAmount > 0 { - if amount > svc.Config.MaxReceiveAmount { - c.Logger().Errorf("Max receive amount exceeded for user_id:%v (amount:%v)", userID, amount) - return c.JSON(http.StatusBadRequest, responses.BadArgumentsError) - } - } - - if svc.Config.MaxAccountBalance > 0 { - currentBalance, err := svc.CurrentUserBalance(c.Request().Context(), userID) - if err != nil { - c.Logger().Errorj( - log.JSON{ - "message": "error fetching balance", - "lndhub_user_id": userID, - "error": err, - }, - ) - return c.JSON(http.StatusBadRequest, responses.BadArgumentsError) - } - if currentBalance+amount > svc.Config.MaxAccountBalance { - c.Logger().Errorf("Max account balance exceeded for user_id:%v (balance:%v + amount:%v)", userID, currentBalance, amount) - return c.JSON(http.StatusBadRequest, responses.BadArgumentsError) - } - } - c.Logger().Infof("Adding invoice: user_id:%v memo:%s value:%v description_hash:%s", userID, body.Memo, amount, body.DescriptionHash) - invoice, err := svc.AddIncomingInvoice(c.Request().Context(), userID, amount, body.Memo, body.DescriptionHash) - if err != nil { - c.Logger().Errorf("Error creating invoice: user_id:%v error: %v", userID, err) - sentry.CaptureException(err) - return c.JSON(http.StatusBadRequest, responses.BadArgumentsError) + invoice, errResp := svc.AddIncomingInvoice(c.Request().Context(), userID, amount, body.Memo, body.DescriptionHash) + if errResp != nil { + return c.JSON(errResp.HttpStatusCode, errResp) } responseBody := AddInvoiceResponseBody{} responseBody.RHash = invoice.RHash diff --git a/controllers_v2/invoice.ctrl.go b/controllers_v2/invoice.ctrl.go index 24a4b32..fff42a9 100644 --- a/controllers_v2/invoice.ctrl.go +++ b/controllers_v2/invoice.ctrl.go @@ -7,7 +7,6 @@ import ( "github.com/getAlby/lndhub.go/common" "github.com/getAlby/lndhub.go/lib/responses" "github.com/getAlby/lndhub.go/lib/service" - "github.com/getsentry/sentry-go" "github.com/labstack/echo/v4" "github.com/labstack/gommon/log" ) @@ -180,11 +179,9 @@ func (controller *InvoiceController) AddInvoice(c echo.Context) error { c.Logger().Infof("Adding invoice: user_id:%v memo:%s value:%v description_hash:%s", userID, body.Description, body.Amount, body.DescriptionHash) - invoice, err := controller.svc.AddIncomingInvoice(c.Request().Context(), userID, body.Amount, body.Description, body.DescriptionHash) - if err != nil { - c.Logger().Errorf("Error creating invoice: user_id:%v error: %v", userID, err) - sentry.CaptureException(err) - return c.JSON(http.StatusBadRequest, responses.BadArgumentsError) + invoice, errResp := controller.svc.AddIncomingInvoice(c.Request().Context(), userID, body.Amount, body.Description, body.DescriptionHash) + if errResp != nil { + return c.JSON(errResp.HttpStatusCode, errResp) } responseBody := AddInvoiceResponseBody{ PaymentHash: invoice.RHash, diff --git a/integration_tests/internal_payment_test.go b/integration_tests/internal_payment_test.go index ec68810..0fa4ea6 100644 --- a/integration_tests/internal_payment_test.go +++ b/integration_tests/internal_payment_test.go @@ -133,7 +133,7 @@ func (suite *PaymentTestSuite) TestPaymentFeeReserve() { //reset fee reserve so it's not used in other tests suite.service.Config.FeeReserve = false } -func (suite *PaymentTestSuite) TestVolumeExceeded() { +func (suite *PaymentTestSuite) TestIncomingExceededChecks() { //this will cause the payment to fail as the account was already funded //with 1000 sats suite.service.Config.MaxVolume = 999 @@ -150,7 +150,7 @@ func (suite *PaymentTestSuite) TestVolumeExceeded() { //try to make external payment //which should fail //create external invoice - externalSatRequested := 1000 + externalSatRequested := 500 externalInvoice := lnrpc.Invoice{ Memo: "integration tests: external pay from user", Value: int64(externalSatRequested), @@ -190,9 +190,53 @@ func (suite *PaymentTestSuite) TestVolumeExceeded() { suite.echo.ServeHTTP(rec, req) assert.Equal(suite.T(), http.StatusOK, rec.Code) - //change the config back + suite.service.Config.MaxReceiveAmount = 21 + rec = httptest.NewRecorder() + assert.NoError(suite.T(), json.NewEncoder(&buf).Encode(&ExpectedAddInvoiceRequestBody{ + Amount: aliceFundingSats, + Memo: "memo", + })) + req = httptest.NewRequest(http.MethodPost, "/addinvoice", &buf) + req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) + req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", suite.aliceToken)) + suite.echo.ServeHTTP(rec, req) + //should fail because max receive amount check + assert.Equal(suite.T(), http.StatusBadRequest, rec.Code) + resp = &responses.ErrorResponse{} + err = json.NewDecoder(rec.Body).Decode(resp) + assert.NoError(suite.T(), err) + assert.Equal(suite.T(), responses.ReceiveExceededError.Message, resp.Message) + + // remove volume and receive config and check if it works + suite.service.Config.MaxVolume = 0 suite.service.Config.MaxVolumePeriod = 0 - suite.service.Config.MaxVolume = 1e6 + suite.service.Config.MaxReceiveAmount = 0 + invoiceResponse = suite.createAddInvoiceReq(aliceFundingSats, "integration test internal payment alice", suite.aliceToken) + err = suite.mlnd.mockPaidInvoice(invoiceResponse, 0, false, nil) + assert.NoError(suite.T(), err) + + // add max account + suite.service.Config.MaxAccountBalance = 500 + assert.NoError(suite.T(), json.NewEncoder(&buf).Encode(&ExpectedAddInvoiceRequestBody{ + Amount: aliceFundingSats, + Memo: "memo", + })) + req = httptest.NewRequest(http.MethodPost, "/addinvoice", &buf) + req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) + req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", suite.aliceToken)) + suite.echo.ServeHTTP(rec, req) + //should fail because max balance check + assert.Equal(suite.T(), http.StatusBadRequest, rec.Code) + resp = &responses.ErrorResponse{} + err = json.NewDecoder(rec.Body).Decode(resp) + assert.NoError(suite.T(), err) + assert.Equal(suite.T(), responses.BalanceExceededError.Message, resp.Message) + + //change the config back and add sats, it should work now + suite.service.Config.MaxAccountBalance = 0 + invoiceResponse = suite.createAddInvoiceReq(aliceFundingSats, "integration test internal payment alice", suite.aliceToken) + err = suite.mlnd.mockPaidInvoice(invoiceResponse, 0, false, nil) + assert.NoError(suite.T(), err) } func (suite *PaymentTestSuite) TestInternalPayment() { aliceFundingSats := 1000 diff --git a/integration_tests/invoice_test.go b/integration_tests/invoice_test.go index c426165..083671c 100644 --- a/integration_tests/invoice_test.go +++ b/integration_tests/invoice_test.go @@ -92,8 +92,8 @@ func (suite *InvoiceTestSuite) TestPreimageEntropy() { user, _ := suite.service.FindUserByLogin(context.Background(), suite.aliceLogin.Login) preimageChars := map[byte]int{} for i := 0; i < 1000; i++ { - inv, err := suite.service.AddIncomingInvoice(context.Background(), user.ID, 10, "test entropy", "") - assert.NoError(suite.T(), err) + inv, errResp := suite.service.AddIncomingInvoice(context.Background(), user.ID, 10, "test entropy", "") + assert.Nil(suite.T(), errResp) primgBytes, _ := hex.DecodeString(inv.Preimage) for _, char := range primgBytes { preimageChars[char] += 1 diff --git a/lib/responses/errors.go b/lib/responses/errors.go index d8126ea..15cc8e7 100644 --- a/lib/responses/errors.go +++ b/lib/responses/errors.go @@ -64,6 +64,20 @@ var NotEnoughBalanceError = ErrorResponse{ HttpStatusCode: 400, } +var ReceiveExceededError = ErrorResponse{ + Error: true, + Code: 2, + Message: "max receive amount exceeded. please contact support for further assistance.", + HttpStatusCode: 400, +} + +var BalanceExceededError = ErrorResponse{ + Error: true, + Code: 2, + Message: "max account balance exceeded. please contact support for further assistance.", + HttpStatusCode: 400, +} + var TooMuchVolumeError = ErrorResponse{ Error: true, Code: 2, diff --git a/lib/service/invoices.go b/lib/service/invoices.go index 3a8cd6b..7ef884a 100644 --- a/lib/service/invoices.go +++ b/lib/service/invoices.go @@ -13,8 +13,10 @@ import ( "github.com/getAlby/lndhub.go/common" "github.com/getAlby/lndhub.go/db/models" + "github.com/getAlby/lndhub.go/lib/responses" "github.com/getAlby/lndhub.go/lnd" "github.com/getsentry/sentry-go" + "github.com/labstack/gommon/log" "github.com/lightningnetwork/lnd/lnrpc" "github.com/uptrace/bun" "github.com/uptrace/bun/schema" @@ -475,10 +477,48 @@ func (svc *LndhubService) AddOutgoingInvoice(ctx context.Context, userID int64, return &invoice, nil } -func (svc *LndhubService) AddIncomingInvoice(ctx context.Context, userID int64, amount int64, memo, descriptionHashStr string) (*models.Invoice, error) { +func (svc *LndhubService) AddIncomingInvoice(ctx context.Context, userID int64, amount int64, memo, descriptionHashStr string) (*models.Invoice, *responses.ErrorResponse) { + + if svc.Config.MaxReceiveAmount > 0 { + if amount > svc.Config.MaxReceiveAmount { + svc.Logger.Errorf("Max receive amount exceeded for user_id %d", userID) + return nil, &responses.ReceiveExceededError + } + } + + if svc.Config.MaxAccountBalance > 0 { + currentBalance, err := svc.CurrentUserBalance(ctx, userID) + if err != nil { + svc.Logger.Errorj( + log.JSON{ + "message": "error fetching balance", + "lndhub_user_id": userID, + "error": err, + }, + ) + return nil, &responses.GeneralServerError + } + if currentBalance+amount > svc.Config.MaxAccountBalance { + svc.Logger.Errorf("Max account balance exceeded for user_id %d", userID) + return nil, &responses.BalanceExceededError + } + } + + if svc.Config.MaxVolume > 0 { + volume, err := svc.GetVolumeOverPeriod(ctx, userID, time.Duration(svc.Config.MaxVolumePeriod*int64(time.Second))) + if err != nil { + return nil, &responses.GeneralServerError + } + if volume > svc.Config.MaxVolume { + svc.Logger.Errorf("Transaction volume exceeded for user_id %d", userID) + sentry.CaptureMessage(fmt.Sprintf("transaction volume exceeded for user %d", userID)) + return nil, &responses.TooMuchVolumeError + } + } + preimage, err := makePreimageHex() if err != nil { - return nil, err + return nil, &responses.GeneralServerError } expiry := time.Hour * 24 // invoice expires in 24h // Initialize new DB invoice @@ -495,12 +535,12 @@ func (svc *LndhubService) AddIncomingInvoice(ctx context.Context, userID int64, // Save invoice - we save the invoice early to have a record in case the LN call fails _, err = svc.DB.NewInsert().Model(&invoice).Exec(ctx) if err != nil { - return nil, err + return nil, &responses.GeneralServerError } descriptionHash, err := hex.DecodeString(descriptionHashStr) if err != nil { - return nil, err + return nil, &responses.GeneralServerError } // Initialize lnrpc invoice lnInvoice := lnrpc.Invoice{ @@ -513,7 +553,8 @@ func (svc *LndhubService) AddIncomingInvoice(ctx context.Context, userID int64, // Call LND lnInvoiceResult, err := svc.LndClient.AddInvoice(ctx, &lnInvoice) if err != nil { - return nil, err + svc.Logger.Errorf("Error creating invoice: user_id:%v error: %v", userID, err) + return nil, &responses.GeneralServerError } // Update the DB invoice with the data from the LND gRPC call @@ -526,7 +567,7 @@ func (svc *LndhubService) AddIncomingInvoice(ctx context.Context, userID int64, _, err = svc.DB.NewUpdate().Model(&invoice).WherePK().Exec(ctx) if err != nil { - return nil, err + return nil, &responses.GeneralServerError } return &invoice, nil