refactor invoice amount checks

This commit is contained in:
kiwiidb
2023-02-17 14:59:10 +01:00
parent e14c0af5ca
commit c74f39f94d
8 changed files with 46 additions and 38 deletions

View File

@@ -91,7 +91,7 @@ func (controller *KeySendController) KeySend(c echo.Context) error {
minimumBalance := invoice.Amount minimumBalance := invoice.Amount
if controller.svc.Config.FeeReserve { if controller.svc.Config.FeeReserve {
minimumBalance += invoice.CalcFeeLimit(controller.svc.IdentityPubkey) minimumBalance += controller.svc.CalcFeeLimit(invoice.DestinationPubkeyHex, invoice.Amount)
} }
if currentBalance < minimumBalance { if currentBalance < minimumBalance {
c.Logger().Errorf("User does not have enough balance invoice_id:%v user_id:%v balance:%v amount:%v", invoice.ID, userID, currentBalance, invoice.Amount) c.Logger().Errorf("User does not have enough balance invoice_id:%v user_id:%v balance:%v amount:%v", invoice.ID, userID, currentBalance, invoice.Amount)

View File

@@ -80,27 +80,19 @@ func (controller *PayInvoiceController) PayInvoice(c echo.Context) error {
} }
} }
ok, err := controller.svc.BalanceCheck(c.Request().Context(), lnPayReq, userID)
if err != nil {
return err
}
if !ok {
c.Logger().Errorf("User does not have enough balance user_id:%v amount:%v", userID, lnPayReq.PayReq.NumSatoshis)
return c.JSON(http.StatusBadRequest, responses.NotEnoughBalanceError)
}
invoice, err := controller.svc.AddOutgoingInvoice(c.Request().Context(), userID, paymentRequest, lnPayReq) invoice, err := controller.svc.AddOutgoingInvoice(c.Request().Context(), userID, paymentRequest, lnPayReq)
if err != nil { if err != nil {
return err return err
} }
currentBalance, err := controller.svc.CurrentUserBalance(c.Request().Context(), userID)
if err != nil {
controller.svc.DB.NewDelete().Where("id = ?", invoice.ID).Exec(c.Request().Context())
return err
}
minimumBalance := invoice.Amount
if controller.svc.Config.FeeReserve {
minimumBalance += invoice.CalcFeeLimit(controller.svc.IdentityPubkey)
}
if currentBalance < minimumBalance {
c.Logger().Errorf("User does not have enough balance invoice_id:%v user_id:%v balance:%v amount:%v", invoice.ID, userID, currentBalance, invoice.Amount)
controller.svc.DB.NewDelete().Model(&invoice).Where("id = ?", invoice.ID).Exec(c.Request().Context())
return c.JSON(http.StatusBadRequest, responses.NotEnoughBalanceError)
}
sendPaymentResponse, err := controller.svc.PayInvoice(c.Request().Context(), invoice) sendPaymentResponse, err := controller.svc.PayInvoice(c.Request().Context(), invoice)
if err != nil { if err != nil {
c.Logger().Errorf("Payment failed invoice_id:%v user_id:%v error: %v", invoice.ID, userID, err) c.Logger().Errorf("Payment failed invoice_id:%v user_id:%v error: %v", invoice.ID, userID, err)

View File

@@ -182,7 +182,7 @@ func (controller *KeySendController) SingleKeySend(c echo.Context, reqBody *KeyS
minimumBalance := invoice.Amount minimumBalance := invoice.Amount
if controller.svc.Config.FeeReserve { if controller.svc.Config.FeeReserve {
minimumBalance += invoice.CalcFeeLimit(controller.svc.IdentityPubkey) minimumBalance += controller.svc.CalcFeeLimit(invoice.DestinationPubkeyHex, invoice.Amount)
} }
if currentBalance < minimumBalance { if currentBalance < minimumBalance {
c.Logger().Errorf("User does not have enough balance invoice_id:%v user_id:%v balance:%v amount:%v", invoice.ID, userID, currentBalance, invoice.Amount) c.Logger().Errorf("User does not have enough balance invoice_id:%v user_id:%v balance:%v amount:%v", invoice.ID, userID, currentBalance, invoice.Amount)

View File

@@ -95,7 +95,7 @@ func (controller *PayInvoiceController) PayInvoice(c echo.Context) error {
minimumBalance := invoice.Amount minimumBalance := invoice.Amount
if controller.svc.Config.FeeReserve { if controller.svc.Config.FeeReserve {
minimumBalance += invoice.CalcFeeLimit(controller.svc.IdentityPubkey) minimumBalance += controller.svc.CalcFeeLimit(invoice.DestinationPubkeyHex, invoice.Amount)
} }
if currentBalance < minimumBalance { if currentBalance < minimumBalance {
c.Logger().Errorf("User does not have enough balance invoice_id:%v user_id:%v balance:%v amount:%v", invoice.ID, userID, currentBalance, invoice.Amount) c.Logger().Errorf("User does not have enough balance invoice_id:%v user_id:%v balance:%v amount:%v", invoice.ID, userID, currentBalance, invoice.Amount)

View File

@@ -2,7 +2,6 @@ package models
import ( import (
"context" "context"
"math"
"time" "time"
"github.com/uptrace/bun" "github.com/uptrace/bun"
@@ -42,15 +41,4 @@ func (i *Invoice) BeforeAppendModel(ctx context.Context, query bun.Query) error
return nil return nil
} }
func (i *Invoice) CalcFeeLimit(identityPubkey string) int64 {
if i.DestinationPubkeyHex == identityPubkey {
return 0
}
limit := int64(10)
if i.Amount > 1000 {
limit = int64(math.Ceil(float64(i.Amount)*float64(0.01)) + 1)
}
return limit
}
var _ bun.BeforeAppendModelHook = (*Invoice)(nil) var _ bun.BeforeAppendModelHook = (*Invoice)(nil)

View File

@@ -117,7 +117,7 @@ func (svc *LndhubService) SendInternalPayment(ctx context.Context, invoice *mode
func (svc *LndhubService) SendPaymentSync(ctx context.Context, invoice *models.Invoice) (SendPaymentResponse, error) { func (svc *LndhubService) SendPaymentSync(ctx context.Context, invoice *models.Invoice) (SendPaymentResponse, error) {
sendPaymentResponse := SendPaymentResponse{} sendPaymentResponse := SendPaymentResponse{}
sendPaymentRequest, err := createLnRpcSendRequest(invoice) sendPaymentRequest, err := svc.createLnRpcSendRequest(invoice)
if err != nil { if err != nil {
return sendPaymentResponse, err return sendPaymentResponse, err
} }
@@ -143,11 +143,11 @@ func (svc *LndhubService) SendPaymentSync(ctx context.Context, invoice *models.I
return sendPaymentResponse, nil return sendPaymentResponse, nil
} }
func createLnRpcSendRequest(invoice *models.Invoice) (*lnrpc.SendRequest, error) { func (svc *LndhubService) createLnRpcSendRequest(invoice *models.Invoice) (*lnrpc.SendRequest, error) {
feeLimit := lnrpc.FeeLimit{ feeLimit := lnrpc.FeeLimit{
Limit: &lnrpc.FeeLimit_Fixed{ Limit: &lnrpc.FeeLimit_Fixed{
//if we get here, the destination is never ourselves, so we can use a dummy //if we get here, the destination is never ourselves, so we can use a dummy
Fixed: invoice.CalcFeeLimit("dummy"), Fixed: svc.CalcFeeLimit("dummy", invoice.Amount),
}, },
} }

View File

@@ -7,12 +7,14 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
var svc = &LndhubService{}
func TestCalcFeeWithInvoiceLessThan1000(t *testing.T) { func TestCalcFeeWithInvoiceLessThan1000(t *testing.T) {
invoice := &models.Invoice{ invoice := &models.Invoice{
Amount: 500, Amount: 500,
} }
feeLimit := invoice.CalcFeeLimit("dummy") feeLimit := svc.CalcFeeLimit("dummy", invoice.Amount)
expectedFee := int64(10) expectedFee := int64(10)
assert.Equal(t, expectedFee, feeLimit) assert.Equal(t, expectedFee, feeLimit)
} }
@@ -22,7 +24,7 @@ func TestCalcFeeWithInvoiceEqualTo1000(t *testing.T) {
Amount: 500, Amount: 500,
} }
feeLimit := invoice.CalcFeeLimit("dummy") feeLimit := svc.CalcFeeLimit("dummy", invoice.Amount)
expectedFee := int64(10) expectedFee := int64(10)
assert.Equal(t, expectedFee, feeLimit) assert.Equal(t, expectedFee, feeLimit)
} }
@@ -32,7 +34,7 @@ func TestCalcFeeWithInvoiceMoreThan1000(t *testing.T) {
Amount: 1500, Amount: 1500,
} }
feeLimit := invoice.CalcFeeLimit("dummy") feeLimit := svc.CalcFeeLimit("dummy", invoice.Amount)
// 1500 * 0.01 + 1 // 1500 * 0.01 + 1
expectedFee := int64(16) expectedFee := int64(16)
assert.Equal(t, expectedFee, feeLimit) assert.Equal(t, expectedFee, feeLimit)

View File

@@ -4,10 +4,12 @@ import (
"context" "context"
"database/sql" "database/sql"
"fmt" "fmt"
"math"
"github.com/getAlby/lndhub.go/common" "github.com/getAlby/lndhub.go/common"
"github.com/getAlby/lndhub.go/db/models" "github.com/getAlby/lndhub.go/db/models"
"github.com/getAlby/lndhub.go/lib/security" "github.com/getAlby/lndhub.go/lib/security"
"github.com/getAlby/lndhub.go/lnd"
"github.com/uptrace/bun" "github.com/uptrace/bun"
passwordvalidator "github.com/wagslane/go-password-validator" passwordvalidator "github.com/wagslane/go-password-validator"
) )
@@ -91,6 +93,30 @@ func (svc *LndhubService) FindUserByLogin(ctx context.Context, login string) (*m
return &user, nil return &user, nil
} }
func (svc *LndhubService) BalanceCheck(ctx context.Context, lnpayReq *lnd.LNPayReq, userId int64) (ok bool, err error) {
currentBalance, err := svc.CurrentUserBalance(ctx, userId)
if err != nil {
return false, err
}
minimumBalance := lnpayReq.PayReq.NumSatoshis
if svc.Config.FeeReserve {
minimumBalance += svc.CalcFeeLimit(lnpayReq.PayReq.Destination, lnpayReq.PayReq.NumSatoshis)
}
return currentBalance > minimumBalance, nil
}
func (svc *LndhubService) CalcFeeLimit(destination string, amount int64) int64 {
if destination == svc.IdentityPubkey {
return 0
}
limit := int64(10)
if amount > 1000 {
limit = int64(math.Ceil(float64(amount)*float64(0.01)) + 1)
}
return limit
}
func (svc *LndhubService) CurrentUserBalance(ctx context.Context, userId int64) (int64, error) { func (svc *LndhubService) CurrentUserBalance(ctx context.Context, userId int64) (int64, error) {
var balance int64 var balance int64