chore: use limits from JWT if available

This commit is contained in:
im-adithya
2023-12-05 15:58:45 +05:30
parent 62d7ffb7fe
commit e681e69fa2
24 changed files with 133 additions and 47 deletions

View File

@@ -153,8 +153,17 @@ func main() {
logMw := transport.CreateLoggingMiddleware(logger)
// strict rate limit for requests for sending payments
strictRateLimitMiddleware := transport.CreateRateLimitMiddleware(c.StrictRateLimit, c.BurstRateLimit)
secured := e.Group("", tokens.Middleware(c.JWTSecret), logMw)
securedWithStrictRateLimit := e.Group("", tokens.Middleware(c.JWTSecret), strictRateLimitMiddleware, logMw)
limits := &lnd.Limits{
MaxSendVolume: c.MaxSendVolume,
MaxSendAmount: c.MaxSendAmount,
MaxReceiveVolume: c.MaxReceiveVolume,
MaxReceiveAmount: c.MaxReceiveAmount,
MaxAccountBalance: c.MaxAccountBalance,
}
secured := e.Group("", tokens.Middleware(c.JWTSecret, limits), logMw)
securedWithStrictRateLimit := e.Group("", tokens.Middleware(c.JWTSecret, limits), strictRateLimitMiddleware, logMw)
transport.RegisterLegacyEndpoints(svc, e, secured, securedWithStrictRateLimit, strictRateLimitMiddleware, tokens.AdminTokenMiddleware(c.AdminToken), logMw)
transport.RegisterV2Endpoints(svc, e, secured, securedWithStrictRateLimit, strictRateLimitMiddleware, tokens.AdminTokenMiddleware(c.AdminToken), logMw)

View File

@@ -37,6 +37,7 @@ func (controller *AddInvoiceController) AddInvoice(c echo.Context) error {
func AddInvoice(c echo.Context, svc *service.LndhubService, userID int64) error {
var body AddInvoiceRequestBody
limits := svc.GetLimitsFromContext(c)
if err := c.Bind(&body); err != nil {
c.Logger().Errorf("Failed to load addinvoice request body: %v", err)
@@ -61,7 +62,7 @@ func AddInvoice(c echo.Context, svc *service.LndhubService, userID int64) error
return c.JSON(http.StatusBadRequest, responses.BadArgumentsError)
}
resp, err := svc.CheckIncomingPaymentAllowed(c.Request().Context(), amount, userID)
resp, err := svc.CheckIncomingPaymentAllowed(c.Request().Context(), amount, userID, limits)
if err != nil {
return c.JSON(http.StatusInternalServerError, responses.GeneralServerError)
}

View File

@@ -46,6 +46,7 @@ type KeySendResponseBody struct {
func (controller *KeySendController) KeySend(c echo.Context) error {
userID := c.Get("UserID").(int64)
limits := controller.svc.GetLimitsFromContext(c)
reqBody := KeySendRequestBody{}
if err := c.Bind(&reqBody); err != nil {
c.Logger().Errorf("Failed to load keysend request body: %v", err)
@@ -75,7 +76,7 @@ func (controller *KeySendController) KeySend(c echo.Context) error {
})
}
resp, err := controller.svc.CheckOutgoingPaymentAllowed(c.Request().Context(), lnPayReq, userID)
resp, err := controller.svc.CheckOutgoingPaymentAllowed(c.Request().Context(), lnPayReq, userID, limits)
if err != nil {
return c.JSON(http.StatusInternalServerError, responses.GeneralServerError)
}

View File

@@ -43,6 +43,7 @@ type PayInvoiceResponseBody struct {
func (controller *PayInvoiceController) PayInvoice(c echo.Context) error {
userID := c.Get("UserID").(int64)
limits := controller.svc.GetLimitsFromContext(c)
reqBody := PayInvoiceRequestBody{}
if err := c.Bind(&reqBody); err != nil {
c.Logger().Errorf("Failed to load payinvoice request body: user_id:%v error: %v", userID, err)
@@ -90,7 +91,7 @@ func (controller *PayInvoiceController) PayInvoice(c echo.Context) error {
lnPayReq.PayReq.NumSatoshis = amt
}
resp, err := controller.svc.CheckOutgoingPaymentAllowed(c.Request().Context(), lnPayReq, userID)
resp, err := controller.svc.CheckOutgoingPaymentAllowed(c.Request().Context(), lnPayReq, userID, limits)
if err != nil {
return c.JSON(http.StatusInternalServerError, responses.GeneralServerError)
}

View File

@@ -165,6 +165,8 @@ type AddInvoiceResponseBody struct {
// @Security OAuth2Password
func (controller *InvoiceController) AddInvoice(c echo.Context) error {
userID := c.Get("UserID").(int64)
limits := controller.svc.GetLimitsFromContext(c)
var body AddInvoiceRequestBody
if err := c.Bind(&body); err != nil {
@@ -177,7 +179,7 @@ func (controller *InvoiceController) AddInvoice(c echo.Context) error {
return c.JSON(http.StatusBadRequest, responses.BadArgumentsError)
}
resp, err := controller.svc.CheckIncomingPaymentAllowed(c.Request().Context(), body.Amount, userID)
resp, err := controller.svc.CheckIncomingPaymentAllowed(c.Request().Context(), body.Amount, userID, limits)
if err != nil {
return c.JSON(http.StatusInternalServerError, responses.GeneralServerError)
}

View File

@@ -71,6 +71,7 @@ type KeySendResponseBody struct {
// @Security OAuth2Password
func (controller *KeySendController) KeySend(c echo.Context) error {
userID := c.Get("UserID").(int64)
limits := controller.svc.GetLimitsFromContext(c)
reqBody := KeySendRequestBody{}
if err := c.Bind(&reqBody); err != nil {
c.Logger().Errorf("Failed to load keysend request body: %v", err)
@@ -81,7 +82,7 @@ func (controller *KeySendController) KeySend(c echo.Context) error {
c.Logger().Errorf("Invalid keysend request body: %v", err)
return c.JSON(http.StatusBadRequest, responses.BadArgumentsError)
}
errResp := controller.checkKeysendPaymentAllowed(context.Background(), reqBody.Amount, userID)
errResp := controller.checkKeysendPaymentAllowed(context.Background(), reqBody.Amount, userID, limits)
if errResp != nil {
c.Logger().Errorf("Failed to send keysend: %s", errResp.Message)
return c.JSON(errResp.HttpStatusCode, errResp)
@@ -108,6 +109,7 @@ func (controller *KeySendController) KeySend(c echo.Context) error {
// @Security OAuth2Password
func (controller *KeySendController) MultiKeySend(c echo.Context) error {
userID := c.Get("UserID").(int64)
limits := controller.svc.GetLimitsFromContext(c)
reqBody := MultiKeySendRequestBody{}
if err := c.Bind(&reqBody); err != nil {
c.Logger().Errorf("Failed to load keysend request body: %v", err)
@@ -127,7 +129,7 @@ func (controller *KeySendController) MultiKeySend(c echo.Context) error {
for _, keysend := range reqBody.Keysends {
totalAmount += keysend.Amount
}
errResp := controller.checkKeysendPaymentAllowed(context.Background(), totalAmount, userID)
errResp := controller.checkKeysendPaymentAllowed(context.Background(), totalAmount, userID, limits)
if errResp != nil {
c.Logger().Errorf("Failed to make keysend split payments: %s", errResp.Message)
return c.JSON(errResp.HttpStatusCode, errResp)
@@ -162,14 +164,14 @@ func (controller *KeySendController) MultiKeySend(c echo.Context) error {
return c.JSON(status, result)
}
func (controller *KeySendController) checkKeysendPaymentAllowed(ctx context.Context, amount, userID int64) (resp *responses.ErrorResponse) {
func (controller *KeySendController) checkKeysendPaymentAllowed(ctx context.Context, amount, userID int64, limits *lnd.Limits) (resp *responses.ErrorResponse) {
syntheticPayReq := &lnd.LNPayReq{
PayReq: &lnrpc.PayReq{
NumSatoshis: amount,
},
Keysend: true,
}
resp, err := controller.svc.CheckOutgoingPaymentAllowed(ctx, syntheticPayReq, userID)
resp, err := controller.svc.CheckOutgoingPaymentAllowed(ctx, syntheticPayReq, userID, limits)
if err != nil {
return &responses.GeneralServerError
}

View File

@@ -52,6 +52,7 @@ type PayInvoiceResponseBody struct {
// @Security OAuth2Password
func (controller *PayInvoiceController) PayInvoice(c echo.Context) error {
userID := c.Get("UserID").(int64)
limits := controller.svc.GetLimitsFromContext(c)
reqBody := PayInvoiceRequestBody{}
if err := c.Bind(&reqBody); err != nil {
c.Logger().Errorf("Failed to load payinvoice request body: user_id:%v error: %v", userID, err)
@@ -98,7 +99,7 @@ func (controller *PayInvoiceController) PayInvoice(c echo.Context) error {
}
lnPayReq.PayReq.NumSatoshis = amt
}
resp, err := controller.svc.CheckOutgoingPaymentAllowed(c.Request().Context(), lnPayReq, userID)
resp, err := controller.svc.CheckOutgoingPaymentAllowed(c.Request().Context(), lnPayReq, userID, limits)
if err != nil {
return c.JSON(http.StatusBadRequest, responses.GeneralServerError)
}

View File

@@ -15,6 +15,7 @@ import (
"github.com/getAlby/lndhub.go/lib/responses"
"github.com/getAlby/lndhub.go/lib/service"
"github.com/getAlby/lndhub.go/lib/tokens"
"github.com/getAlby/lndhub.go/lnd"
"github.com/go-playground/validator/v10"
"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
@@ -56,7 +57,7 @@ func (suite *CheckPaymentTestSuite) SetupSuite() {
assert.Equal(suite.T(), 1, len(userTokens))
suite.userLogin = users[0]
suite.userToken = userTokens[0]
suite.echo.Use(tokens.Middleware([]byte(suite.service.Config.JWTSecret)))
suite.echo.Use(tokens.Middleware([]byte(suite.service.Config.JWTSecret), &lnd.Limits{}))
suite.echo.POST("/addinvoice", controllers.NewAddInvoiceController(suite.service).AddInvoice)
suite.echo.POST("/payinvoice", controllers.NewPayInvoiceController(suite.service).PayInvoice)
suite.echo.GET("/checkpayment/:payment_hash", controllers.NewCheckPaymentController(suite.service).CheckPayment)

View File

@@ -13,6 +13,7 @@ import (
"github.com/getAlby/lndhub.go/lib/responses"
"github.com/getAlby/lndhub.go/lib/service"
"github.com/getAlby/lndhub.go/lib/tokens"
"github.com/getAlby/lndhub.go/lnd"
"github.com/go-playground/validator/v10"
"github.com/labstack/echo/v4"
"github.com/lightningnetwork/lnd/lnrpc"
@@ -47,7 +48,7 @@ func (suite *GetInfoTestSuite) SetupSuite() {
assert.Equal(suite.T(), 1, len(userTokens))
suite.userLogin = users[0]
suite.userToken = userTokens[0]
suite.echo.Use(tokens.Middleware([]byte(suite.service.Config.JWTSecret)))
suite.echo.Use(tokens.Middleware([]byte(suite.service.Config.JWTSecret), &lnd.Limits{}))
suite.echo.GET("/getinfo", controllers.NewGetInfoController(svc).GetInfo)
}

View File

@@ -15,6 +15,7 @@ import (
"github.com/getAlby/lndhub.go/lib/responses"
"github.com/getAlby/lndhub.go/lib/service"
"github.com/getAlby/lndhub.go/lib/tokens"
"github.com/getAlby/lndhub.go/lnd"
"github.com/go-playground/validator/v10"
"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
@@ -52,7 +53,7 @@ func (suite *GetTxTestSuite) SetupSuite() {
e.HTTPErrorHandler = responses.HTTPErrorHandler
e.Validator = &lib.CustomValidator{Validator: validator.New()}
suite.echo = e
suite.echo.Use(tokens.Middleware([]byte(suite.Service.Config.JWTSecret)))
suite.echo.Use(tokens.Middleware([]byte(suite.Service.Config.JWTSecret), &lnd.Limits{}))
suite.echo.GET("/gettxs", controllers.NewGetTXSController(suite.Service).GetTXS)
suite.echo.GET("/getuserinvoices", controllers.NewGetTXSController(svc).GetUserInvoices)
suite.echo.POST("/addinvoice", controllers.NewAddInvoiceController(suite.Service).AddInvoice)

View File

@@ -14,6 +14,7 @@ import (
"github.com/getAlby/lndhub.go/lib/responses"
"github.com/getAlby/lndhub.go/lib/service"
"github.com/getAlby/lndhub.go/lib/tokens"
"github.com/getAlby/lndhub.go/lnd"
"github.com/go-playground/validator/v10"
"github.com/labstack/echo/v4"
"github.com/lightningnetwork/lnd/lnrpc"
@@ -72,7 +73,7 @@ func (suite *HodlInvoiceSuite) SetupSuite() {
suite.userLogin = users[0]
suite.userToken = userTokens[0]
suite.userToken2 = userTokens[1]
suite.echo.Use(tokens.Middleware([]byte(suite.service.Config.JWTSecret)))
suite.echo.Use(tokens.Middleware([]byte(suite.service.Config.JWTSecret), &lnd.Limits{}))
suite.echo.GET("/balance", controllers.NewBalanceController(suite.service).Balance)
suite.echo.POST("/addinvoice", controllers.NewAddInvoiceController(suite.service).AddInvoice)
suite.echo.POST("/payinvoice", controllers.NewPayInvoiceController(suite.service).PayInvoice)

View File

@@ -17,6 +17,7 @@ import (
"github.com/getAlby/lndhub.go/lib/responses"
"github.com/getAlby/lndhub.go/lib/service"
"github.com/getAlby/lndhub.go/lib/tokens"
"github.com/getAlby/lndhub.go/lnd"
"github.com/go-playground/validator/v10"
"github.com/labstack/echo/v4"
"github.com/labstack/gommon/random"
@@ -71,7 +72,7 @@ func (suite *IncomingPaymentTestSuite) TestIncomingPayment() {
req := httptest.NewRequest(http.MethodGet, "/balance", &buf)
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", suite.userToken))
rec := httptest.NewRecorder()
suite.echo.Use(tokens.Middleware([]byte(suite.service.Config.JWTSecret)))
suite.echo.Use(tokens.Middleware([]byte(suite.service.Config.JWTSecret), &lnd.Limits{}))
suite.echo.GET("/balance", controllers.NewBalanceController(suite.service).Balance)
suite.echo.POST("/addinvoice", controllers.NewAddInvoiceController(suite.service).AddInvoice)
suite.echo.ServeHTTP(rec, req)
@@ -102,7 +103,7 @@ func (suite *IncomingPaymentTestSuite) TestIncomingPaymentZeroAmt() {
req := httptest.NewRequest(http.MethodGet, "/balance", &buf)
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", suite.userToken))
rec := httptest.NewRecorder()
suite.echo.Use(tokens.Middleware([]byte(suite.service.Config.JWTSecret)))
suite.echo.Use(tokens.Middleware([]byte(suite.service.Config.JWTSecret), &lnd.Limits{}))
suite.echo.GET("/balance", controllers.NewBalanceController(suite.service).Balance)
suite.echo.POST("/addinvoice", controllers.NewAddInvoiceController(suite.service).AddInvoice)
suite.echo.ServeHTTP(rec, req)
@@ -144,7 +145,7 @@ func (suite *IncomingPaymentTestSuite) TestIncomingPaymentKeysend() {
req := httptest.NewRequest(http.MethodGet, "/balance", &buf)
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", suite.userToken))
rec := httptest.NewRecorder()
suite.echo.Use(tokens.Middleware([]byte(suite.service.Config.JWTSecret)))
suite.echo.Use(tokens.Middleware([]byte(suite.service.Config.JWTSecret), &lnd.Limits{}))
suite.echo.GET("/balance", controllers.NewBalanceController(suite.service).Balance)
suite.echo.GET("/getuserinvoices", controllers.NewGetTXSController(suite.service).GetUserInvoices)
suite.echo.ServeHTTP(rec, req)

View File

@@ -17,6 +17,7 @@ import (
"github.com/getAlby/lndhub.go/lib/responses"
"github.com/getAlby/lndhub.go/lib/service"
"github.com/getAlby/lndhub.go/lib/tokens"
"github.com/getAlby/lndhub.go/lnd"
"github.com/go-playground/validator/v10"
"github.com/labstack/echo/v4"
"github.com/lightningnetwork/lnd/lnrpc"
@@ -69,7 +70,8 @@ func (suite *PaymentTestSuite) SetupSuite() {
suite.aliceToken = userTokens[0]
suite.bobLogin = users[1]
suite.bobToken = userTokens[1]
suite.echo.Use(tokens.Middleware([]byte(suite.service.Config.JWTSecret)))
suite.echo.Use(tokens.Middleware([]byte(suite.service.Config.JWTSecret), &lnd.Limits{}))
suite.echo.GET("/balance", controllers.NewBalanceController(suite.service).Balance)
suite.echo.POST("/addinvoice", controllers.NewAddInvoiceController(suite.service).AddInvoice)
suite.echo.POST("/payinvoice", controllers.NewPayInvoiceController(suite.service).PayInvoice)

View File

@@ -18,6 +18,7 @@ import (
"github.com/getAlby/lndhub.go/lib/responses"
"github.com/getAlby/lndhub.go/lib/service"
"github.com/getAlby/lndhub.go/lib/tokens"
"github.com/getAlby/lndhub.go/lnd"
"github.com/go-playground/validator/v10"
"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
@@ -51,7 +52,7 @@ func (suite *InvoiceTestSuite) SetupSuite() {
suite.aliceLogin = users[0]
suite.aliceToken = userTokens[0]
suite.echo.POST("/invoice/:user_login", controllers.NewInvoiceController(svc).Invoice)
suite.echo.POST("/v2/invoices", v2controllers.NewInvoiceController(svc).AddInvoice, tokens.Middleware([]byte(suite.service.Config.JWTSecret)))
suite.echo.POST("/v2/invoices", v2controllers.NewInvoiceController(svc).AddInvoice, tokens.Middleware([]byte(suite.service.Config.JWTSecret), &lnd.Limits{}))
}
func (suite *InvoiceTestSuite) TearDownTest() {

View File

@@ -13,6 +13,7 @@ import (
"github.com/getAlby/lndhub.go/lib/responses"
"github.com/getAlby/lndhub.go/lib/service"
"github.com/getAlby/lndhub.go/lib/tokens"
"github.com/getAlby/lndhub.go/lnd"
"github.com/go-playground/validator/v10"
"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
@@ -72,7 +73,7 @@ func (suite *KeySendFailureTestSuite) SetupSuite() {
assert.Equal(suite.T(), 1, len(userTokens))
suite.aliceLogin = users[0]
suite.aliceToken = userTokens[0]
suite.echo.Use(tokens.Middleware([]byte(suite.service.Config.JWTSecret)))
suite.echo.Use(tokens.Middleware([]byte(suite.service.Config.JWTSecret), &lnd.Limits{}))
suite.echo.POST("/addinvoice", controllers.NewAddInvoiceController(suite.service).AddInvoice)
suite.echo.POST("/keysend", controllers.NewKeySendController(suite.service).KeySend)
}

View File

@@ -17,6 +17,7 @@ import (
"github.com/getAlby/lndhub.go/lib/responses"
"github.com/getAlby/lndhub.go/lib/service"
"github.com/getAlby/lndhub.go/lib/tokens"
"github.com/getAlby/lndhub.go/lnd"
"github.com/go-playground/validator/v10"
"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
@@ -60,7 +61,7 @@ func (suite *KeySendTestSuite) SetupSuite() {
assert.Equal(suite.T(), 1, len(userTokens))
suite.aliceLogin = users[0]
suite.aliceToken = userTokens[0]
suite.echo.Use(tokens.Middleware([]byte(suite.service.Config.JWTSecret)))
suite.echo.Use(tokens.Middleware([]byte(suite.service.Config.JWTSecret), &lnd.Limits{}))
suite.echo.GET("/balance", controllers.NewBalanceController(suite.service).Balance)
suite.echo.POST("/addinvoice", controllers.NewAddInvoiceController(suite.service).AddInvoice)
suite.echo.POST("/payinvoice", controllers.NewPayInvoiceController(suite.service).PayInvoice)

View File

@@ -13,6 +13,7 @@ import (
"github.com/getAlby/lndhub.go/lib/responses"
"github.com/getAlby/lndhub.go/lib/service"
"github.com/getAlby/lndhub.go/lib/tokens"
"github.com/getAlby/lndhub.go/lnd"
"github.com/go-playground/validator/v10"
"github.com/labstack/echo/v4"
"github.com/lightningnetwork/lnd/lnrpc"
@@ -69,7 +70,7 @@ func (suite *PaymentTestAsyncErrorsSuite) SetupSuite() {
assert.Equal(suite.T(), 1, len(userTokens))
suite.userLogin = users[0]
suite.userToken = userTokens[0]
suite.echo.Use(tokens.Middleware([]byte(suite.service.Config.JWTSecret)))
suite.echo.Use(tokens.Middleware([]byte(suite.service.Config.JWTSecret), &lnd.Limits{}))
suite.echo.GET("/balance", controllers.NewBalanceController(suite.service).Balance)
suite.echo.POST("/addinvoice", controllers.NewAddInvoiceController(suite.service).AddInvoice)
suite.echo.POST("/payinvoice", controllers.NewPayInvoiceController(suite.service).PayInvoice)

View File

@@ -17,6 +17,7 @@ import (
"github.com/getAlby/lndhub.go/lib/responses"
"github.com/getAlby/lndhub.go/lib/service"
"github.com/getAlby/lndhub.go/lib/tokens"
"github.com/getAlby/lndhub.go/lnd"
"github.com/go-playground/validator/v10"
"github.com/labstack/echo/v4"
"github.com/lightningnetwork/lnd/lnrpc"
@@ -71,7 +72,7 @@ func (suite *PaymentTestErrorsSuite) SetupSuite() {
assert.Equal(suite.T(), 1, len(userTokens))
suite.userLogin = users[0]
suite.userToken = userTokens[0]
suite.echo.Use(tokens.Middleware([]byte(suite.service.Config.JWTSecret)))
suite.echo.Use(tokens.Middleware([]byte(suite.service.Config.JWTSecret), &lnd.Limits{}))
suite.echo.GET("/balance", controllers.NewBalanceController(suite.service).Balance)
suite.echo.POST("/addinvoice", controllers.NewAddInvoiceController(suite.service).AddInvoice)
suite.echo.POST("/payinvoice", controllers.NewPayInvoiceController(suite.service).PayInvoice)

View File

@@ -16,6 +16,7 @@ import (
"github.com/getAlby/lndhub.go/lib/responses"
"github.com/getAlby/lndhub.go/lib/service"
"github.com/getAlby/lndhub.go/lib/tokens"
"github.com/getAlby/lndhub.go/lnd"
"github.com/go-playground/validator/v10"
"github.com/labstack/echo/v4"
"github.com/lightningnetwork/lnd/lnrpc"
@@ -64,7 +65,7 @@ func (suite *RabbitMQTestSuite) SetupSuite() {
e.Validator = &lib.CustomValidator{Validator: validator.New()}
suite.echo = e
suite.echo.Use(tokens.Middleware(suite.svc.Config.JWTSecret))
suite.echo.Use(tokens.Middleware([]byte(suite.svc.Config.JWTSecret), &lnd.Limits{}))
suite.echo.POST("/addinvoice", controllers.NewAddInvoiceController(suite.svc).AddInvoice)
suite.echo.POST("/payinvoice", controllers.NewPayInvoiceController(suite.svc).PayInvoice)
go func() {

View File

@@ -56,7 +56,7 @@ func (suite *SubscriptionStartTestSuite) SetupSuite() {
suite.echo = e
suite.userLogin = users[0]
suite.userToken = userTokens[0]
suite.echo.Use(tokens.Middleware([]byte(suite.service.Config.JWTSecret)))
suite.echo.Use(tokens.Middleware([]byte(suite.service.Config.JWTSecret), &lnd.Limits{}))
suite.echo.POST("/addinvoice", controllers.NewAddInvoiceController(suite.service).AddInvoice)
}

View File

@@ -15,6 +15,7 @@ import (
"github.com/getAlby/lndhub.go/lib/responses"
"github.com/getAlby/lndhub.go/lib/service"
"github.com/getAlby/lndhub.go/lib/tokens"
"github.com/getAlby/lndhub.go/lnd"
"github.com/go-playground/validator/v10"
"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
@@ -73,7 +74,7 @@ func (suite *WebHookTestSuite) SetupSuite() {
suite.echo = e
suite.userLogin = users[0]
suite.userToken = userTokens[0]
suite.echo.Use(tokens.Middleware([]byte(suite.service.Config.JWTSecret)))
suite.echo.Use(tokens.Middleware([]byte(suite.service.Config.JWTSecret), &lnd.Limits{}))
suite.echo.POST("/addinvoice", controllers.NewAddInvoiceController(suite.service).AddInvoice)
}
func (suite *WebHookTestSuite) TestWebHook() {

View File

@@ -13,6 +13,7 @@ import (
"github.com/getAlby/lndhub.go/lib/security"
"github.com/getAlby/lndhub.go/lnd"
"github.com/getsentry/sentry-go"
"github.com/labstack/echo/v4"
"github.com/labstack/gommon/log"
"github.com/uptrace/bun"
passwordvalidator "github.com/wagslane/go-password-validator"
@@ -125,9 +126,9 @@ func (svc *LndhubService) FindUserByLogin(ctx context.Context, login string) (*m
return &user, nil
}
func (svc *LndhubService) CheckOutgoingPaymentAllowed(ctx context.Context, lnpayReq *lnd.LNPayReq, userId int64) (result *responses.ErrorResponse, err error) {
if svc.Config.MaxSendAmount > 0 {
if lnpayReq.PayReq.NumSatoshis > svc.Config.MaxSendAmount {
func (svc *LndhubService) CheckOutgoingPaymentAllowed(ctx context.Context, lnpayReq *lnd.LNPayReq, userId int64, limits *lnd.Limits) (result *responses.ErrorResponse, err error) {
if limits.MaxSendAmount > 0 {
if lnpayReq.PayReq.NumSatoshis > limits.MaxSendAmount {
svc.Logger.Errorf("Max send amount exceeded for user_id %v (amount:%v)", userId, lnpayReq.PayReq.NumSatoshis)
return &responses.SendExceededError, nil
}
@@ -153,18 +154,18 @@ func (svc *LndhubService) CheckOutgoingPaymentAllowed(ctx context.Context, lnpay
return &responses.NotEnoughBalanceError, nil
}
return svc.CheckVolumeAllowed(ctx, userId, common.InvoiceTypeOutgoing)
return svc.CheckVolumeAllowed(ctx, userId, limits.MaxSendVolume, common.InvoiceTypeOutgoing)
}
func (svc *LndhubService) CheckIncomingPaymentAllowed(ctx context.Context, amount, userId int64) (result *responses.ErrorResponse, err error) {
if svc.Config.MaxReceiveAmount > 0 {
if amount > svc.Config.MaxReceiveAmount {
func (svc *LndhubService) CheckIncomingPaymentAllowed(ctx context.Context, amount, userId int64, limits *lnd.Limits) (result *responses.ErrorResponse, err error) {
if limits.MaxReceiveAmount > 0 {
if amount > limits.MaxReceiveAmount {
svc.Logger.Errorf("Max receive amount exceeded for user_id %d", userId)
return &responses.ReceiveExceededError, nil
}
}
if svc.Config.MaxAccountBalance > 0 {
if limits.MaxAccountBalance > 0 {
currentBalance, err := svc.CurrentUserBalance(ctx, userId)
if err != nil {
svc.Logger.Errorj(
@@ -176,22 +177,16 @@ func (svc *LndhubService) CheckIncomingPaymentAllowed(ctx context.Context, amoun
)
return nil, err
}
if currentBalance+amount > svc.Config.MaxAccountBalance {
if currentBalance+amount > limits.MaxAccountBalance {
svc.Logger.Errorf("Max account balance exceeded for user_id %d", userId)
return &responses.BalanceExceededError, nil
}
}
return svc.CheckVolumeAllowed(ctx, userId, common.InvoiceTypeIncoming)
return svc.CheckVolumeAllowed(ctx, userId, limits.MaxReceiveVolume, common.InvoiceTypeIncoming)
}
func (svc *LndhubService) CheckVolumeAllowed(ctx context.Context, userId int64, invoiceType string) (result *responses.ErrorResponse, err error) {
var maxVolume int64
if invoiceType == common.InvoiceTypeIncoming {
maxVolume = svc.Config.MaxReceiveVolume
} else {
maxVolume = svc.Config.MaxSendVolume
}
func (svc *LndhubService) CheckVolumeAllowed(ctx context.Context, userId, maxVolume int64, invoiceType string) (result *responses.ErrorResponse, err error) {
if maxVolume > 0 {
volume, err := svc.GetVolumeOverPeriod(ctx, userId, invoiceType, time.Duration(svc.Config.MaxVolumePeriod*int64(time.Second)))
if err != nil {
@@ -279,3 +274,24 @@ func (svc *LndhubService) GetVolumeOverPeriod(ctx context.Context, userId int64,
}
return result, nil
}
func (svc *LndhubService) GetLimitsFromContext(c echo.Context) (limits *lnd.Limits) {
limits = &lnd.Limits{}
if val, ok := c.Get("MaxSendVolume").(int64); ok {
limits.MaxSendVolume = val
}
if val, ok := c.Get("MaxSendAmount").(int64); ok {
limits.MaxSendAmount = val
}
if val, ok := c.Get("MaxReceiveVolume").(int64); ok {
limits.MaxReceiveVolume = val
}
if val, ok := c.Get("MaxReceiveAmount").(int64); ok {
limits.MaxReceiveAmount = val
}
if val, ok := c.Get("MaxAccountBalance").(int64); ok {
limits.MaxAccountBalance = val
}
return limits
}

View File

@@ -7,6 +7,7 @@ import (
"time"
"github.com/getAlby/lndhub.go/db/models"
"github.com/getAlby/lndhub.go/lnd"
"github.com/getsentry/sentry-go"
sentryecho "github.com/getsentry/sentry-go/echo"
"github.com/golang-jwt/jwt"
@@ -17,10 +18,15 @@ import (
type jwtCustomClaims struct {
ID int64 `json:"id"`
IsRefresh bool `json:"isRefresh"`
MaxSendVolume int64 `json:"maxSendVolume"`
MaxSendAmount int64 `json:"maxSendAmount"`
MaxReceiveVolume int64 `json:"maxReceiveVolume"`
MaxReceiveAmount int64 `json:"maxReceiveAmount"`
MaxAccountBalance int64 `json:"maxAccountBalance"`
jwt.StandardClaims
}
func Middleware(secret []byte) echo.MiddlewareFunc {
func Middleware(secret []byte, limits *lnd.Limits) echo.MiddlewareFunc {
config := middleware.DefaultJWTConfig
config.Claims = &jwtCustomClaims{}
@@ -38,6 +44,31 @@ func Middleware(secret []byte) echo.MiddlewareFunc {
token := c.Get("UserJwt").(*jwt.Token)
claims := token.Claims.(*jwtCustomClaims)
c.Set("UserID", claims.ID)
if limits.MaxSendVolume == 0 {
c.Set("MaxSendVolume", claims.MaxSendAmount)
} else {
c.Set("MaxSendVolume", limits.MaxSendVolume)
}
if limits.MaxSendAmount == 0 {
c.Set("MaxSendAmount", claims.MaxSendAmount)
} else {
c.Set("MaxSendAmount", limits.MaxSendAmount)
}
if limits.MaxReceiveVolume == 0 {
c.Set("MaxReceiveVolume", claims.MaxReceiveVolume)
} else {
c.Set("MaxReceiveVolume", limits.MaxReceiveVolume)
}
if limits.MaxReceiveAmount == 0 {
c.Set("MaxReceiveAmount", claims.MaxReceiveAmount)
} else {
c.Set("MaxReceiveAmount", limits.MaxReceiveAmount)
}
if limits.MaxAccountBalance == 0 {
c.Set("MaxAccountBalance", claims.MaxAccountBalance)
} else {
c.Set("MaxAccountBalance", limits.MaxAccountBalance)
}
// pass UserID to sentry for exception notifications
if hub := sentryecho.GetHubFromContext(c); hub != nil {
hub.Scope().SetUser(sentry.User{ID: strconv.FormatInt(claims.ID, 10)})

View File

@@ -22,6 +22,14 @@ type Config struct {
LNDClusterPubkeys string `envconfig:"LND_CLUSTER_PUBKEYS"` //comma-seperated list of public keys of the cluster
}
type Limits struct {
MaxSendVolume int64
MaxSendAmount int64
MaxReceiveVolume int64
MaxReceiveAmount int64
MaxAccountBalance int64
}
func LoadConfig() (c *Config, err error) {
c = &Config{}
err = envconfig.Process("", c)