diff --git a/cmd/server/main.go b/cmd/server/main.go index fa5a61a..9d6c9e3 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -153,8 +153,17 @@ func main() { logMw := transport.CreateLoggingMiddleware(logger) // strict rate limit for requests for sending payments strictRateLimitMiddleware := transport.CreateRateLimitMiddleware(c.StrictRateLimit, c.BurstRateLimit) - secured := e.Group("", tokens.Middleware(c.JWTSecret), logMw) - securedWithStrictRateLimit := e.Group("", tokens.Middleware(c.JWTSecret), strictRateLimitMiddleware, logMw) + + limits := &lnd.Limits{ + MaxSendVolume: c.MaxSendVolume, + MaxSendAmount: c.MaxSendAmount, + MaxReceiveVolume: c.MaxReceiveVolume, + MaxReceiveAmount: c.MaxReceiveAmount, + MaxAccountBalance: c.MaxAccountBalance, + } + + secured := e.Group("", tokens.Middleware(c.JWTSecret, limits), logMw) + securedWithStrictRateLimit := e.Group("", tokens.Middleware(c.JWTSecret, limits), strictRateLimitMiddleware, logMw) transport.RegisterLegacyEndpoints(svc, e, secured, securedWithStrictRateLimit, strictRateLimitMiddleware, tokens.AdminTokenMiddleware(c.AdminToken), logMw) transport.RegisterV2Endpoints(svc, e, secured, securedWithStrictRateLimit, strictRateLimitMiddleware, tokens.AdminTokenMiddleware(c.AdminToken), logMw) diff --git a/controllers/addinvoice.ctrl.go b/controllers/addinvoice.ctrl.go index 90ccbdd..6da8de6 100644 --- a/controllers/addinvoice.ctrl.go +++ b/controllers/addinvoice.ctrl.go @@ -37,6 +37,7 @@ func (controller *AddInvoiceController) AddInvoice(c echo.Context) error { func AddInvoice(c echo.Context, svc *service.LndhubService, userID int64) error { var body AddInvoiceRequestBody + limits := svc.GetLimitsFromContext(c) if err := c.Bind(&body); err != nil { c.Logger().Errorf("Failed to load addinvoice request body: %v", err) @@ -61,7 +62,7 @@ func AddInvoice(c echo.Context, svc *service.LndhubService, userID int64) error return c.JSON(http.StatusBadRequest, responses.BadArgumentsError) } - resp, err := svc.CheckIncomingPaymentAllowed(c.Request().Context(), amount, userID) + resp, err := svc.CheckIncomingPaymentAllowed(c.Request().Context(), amount, userID, limits) if err != nil { return c.JSON(http.StatusInternalServerError, responses.GeneralServerError) } diff --git a/controllers/keysend.ctrl.go b/controllers/keysend.ctrl.go index 0eda406..dff2887 100644 --- a/controllers/keysend.ctrl.go +++ b/controllers/keysend.ctrl.go @@ -46,6 +46,7 @@ type KeySendResponseBody struct { func (controller *KeySendController) KeySend(c echo.Context) error { userID := c.Get("UserID").(int64) + limits := controller.svc.GetLimitsFromContext(c) reqBody := KeySendRequestBody{} if err := c.Bind(&reqBody); err != nil { c.Logger().Errorf("Failed to load keysend request body: %v", err) @@ -75,7 +76,7 @@ func (controller *KeySendController) KeySend(c echo.Context) error { }) } - resp, err := controller.svc.CheckOutgoingPaymentAllowed(c.Request().Context(), lnPayReq, userID) + resp, err := controller.svc.CheckOutgoingPaymentAllowed(c.Request().Context(), lnPayReq, userID, limits) if err != nil { return c.JSON(http.StatusInternalServerError, responses.GeneralServerError) } diff --git a/controllers/payinvoice.ctrl.go b/controllers/payinvoice.ctrl.go index 04dd397..f6c2634 100644 --- a/controllers/payinvoice.ctrl.go +++ b/controllers/payinvoice.ctrl.go @@ -43,6 +43,7 @@ type PayInvoiceResponseBody struct { func (controller *PayInvoiceController) PayInvoice(c echo.Context) error { userID := c.Get("UserID").(int64) + limits := controller.svc.GetLimitsFromContext(c) reqBody := PayInvoiceRequestBody{} if err := c.Bind(&reqBody); err != nil { c.Logger().Errorf("Failed to load payinvoice request body: user_id:%v error: %v", userID, err) @@ -90,7 +91,7 @@ func (controller *PayInvoiceController) PayInvoice(c echo.Context) error { lnPayReq.PayReq.NumSatoshis = amt } - resp, err := controller.svc.CheckOutgoingPaymentAllowed(c.Request().Context(), lnPayReq, userID) + resp, err := controller.svc.CheckOutgoingPaymentAllowed(c.Request().Context(), lnPayReq, userID, limits) if err != nil { return c.JSON(http.StatusInternalServerError, responses.GeneralServerError) } diff --git a/controllers_v2/invoice.ctrl.go b/controllers_v2/invoice.ctrl.go index d38a353..973a364 100644 --- a/controllers_v2/invoice.ctrl.go +++ b/controllers_v2/invoice.ctrl.go @@ -165,6 +165,8 @@ type AddInvoiceResponseBody struct { // @Security OAuth2Password func (controller *InvoiceController) AddInvoice(c echo.Context) error { userID := c.Get("UserID").(int64) + limits := controller.svc.GetLimitsFromContext(c) + var body AddInvoiceRequestBody if err := c.Bind(&body); err != nil { @@ -177,7 +179,7 @@ func (controller *InvoiceController) AddInvoice(c echo.Context) error { return c.JSON(http.StatusBadRequest, responses.BadArgumentsError) } - resp, err := controller.svc.CheckIncomingPaymentAllowed(c.Request().Context(), body.Amount, userID) + resp, err := controller.svc.CheckIncomingPaymentAllowed(c.Request().Context(), body.Amount, userID, limits) if err != nil { return c.JSON(http.StatusInternalServerError, responses.GeneralServerError) } diff --git a/controllers_v2/keysend.ctrl.go b/controllers_v2/keysend.ctrl.go index 1c9e28c..2db1ec0 100644 --- a/controllers_v2/keysend.ctrl.go +++ b/controllers_v2/keysend.ctrl.go @@ -71,6 +71,7 @@ type KeySendResponseBody struct { // @Security OAuth2Password func (controller *KeySendController) KeySend(c echo.Context) error { userID := c.Get("UserID").(int64) + limits := controller.svc.GetLimitsFromContext(c) reqBody := KeySendRequestBody{} if err := c.Bind(&reqBody); err != nil { c.Logger().Errorf("Failed to load keysend request body: %v", err) @@ -81,7 +82,7 @@ func (controller *KeySendController) KeySend(c echo.Context) error { c.Logger().Errorf("Invalid keysend request body: %v", err) return c.JSON(http.StatusBadRequest, responses.BadArgumentsError) } - errResp := controller.checkKeysendPaymentAllowed(context.Background(), reqBody.Amount, userID) + errResp := controller.checkKeysendPaymentAllowed(context.Background(), reqBody.Amount, userID, limits) if errResp != nil { c.Logger().Errorf("Failed to send keysend: %s", errResp.Message) return c.JSON(errResp.HttpStatusCode, errResp) @@ -108,6 +109,7 @@ func (controller *KeySendController) KeySend(c echo.Context) error { // @Security OAuth2Password func (controller *KeySendController) MultiKeySend(c echo.Context) error { userID := c.Get("UserID").(int64) + limits := controller.svc.GetLimitsFromContext(c) reqBody := MultiKeySendRequestBody{} if err := c.Bind(&reqBody); err != nil { c.Logger().Errorf("Failed to load keysend request body: %v", err) @@ -127,7 +129,7 @@ func (controller *KeySendController) MultiKeySend(c echo.Context) error { for _, keysend := range reqBody.Keysends { totalAmount += keysend.Amount } - errResp := controller.checkKeysendPaymentAllowed(context.Background(), totalAmount, userID) + errResp := controller.checkKeysendPaymentAllowed(context.Background(), totalAmount, userID, limits) if errResp != nil { c.Logger().Errorf("Failed to make keysend split payments: %s", errResp.Message) return c.JSON(errResp.HttpStatusCode, errResp) @@ -162,14 +164,14 @@ func (controller *KeySendController) MultiKeySend(c echo.Context) error { return c.JSON(status, result) } -func (controller *KeySendController) checkKeysendPaymentAllowed(ctx context.Context, amount, userID int64) (resp *responses.ErrorResponse) { +func (controller *KeySendController) checkKeysendPaymentAllowed(ctx context.Context, amount, userID int64, limits *lnd.Limits) (resp *responses.ErrorResponse) { syntheticPayReq := &lnd.LNPayReq{ PayReq: &lnrpc.PayReq{ NumSatoshis: amount, }, Keysend: true, } - resp, err := controller.svc.CheckOutgoingPaymentAllowed(ctx, syntheticPayReq, userID) + resp, err := controller.svc.CheckOutgoingPaymentAllowed(ctx, syntheticPayReq, userID, limits) if err != nil { return &responses.GeneralServerError } diff --git a/controllers_v2/payinvoice.ctrl.go b/controllers_v2/payinvoice.ctrl.go index b137507..3bbc6bc 100644 --- a/controllers_v2/payinvoice.ctrl.go +++ b/controllers_v2/payinvoice.ctrl.go @@ -52,6 +52,7 @@ type PayInvoiceResponseBody struct { // @Security OAuth2Password func (controller *PayInvoiceController) PayInvoice(c echo.Context) error { userID := c.Get("UserID").(int64) + limits := controller.svc.GetLimitsFromContext(c) reqBody := PayInvoiceRequestBody{} if err := c.Bind(&reqBody); err != nil { c.Logger().Errorf("Failed to load payinvoice request body: user_id:%v error: %v", userID, err) @@ -98,7 +99,7 @@ func (controller *PayInvoiceController) PayInvoice(c echo.Context) error { } lnPayReq.PayReq.NumSatoshis = amt } - resp, err := controller.svc.CheckOutgoingPaymentAllowed(c.Request().Context(), lnPayReq, userID) + resp, err := controller.svc.CheckOutgoingPaymentAllowed(c.Request().Context(), lnPayReq, userID, limits) if err != nil { return c.JSON(http.StatusBadRequest, responses.GeneralServerError) } diff --git a/integration_tests/checkpayment_test.go b/integration_tests/checkpayment_test.go index 86b9da9..67a379f 100644 --- a/integration_tests/checkpayment_test.go +++ b/integration_tests/checkpayment_test.go @@ -15,6 +15,7 @@ import ( "github.com/getAlby/lndhub.go/lib/responses" "github.com/getAlby/lndhub.go/lib/service" "github.com/getAlby/lndhub.go/lib/tokens" + "github.com/getAlby/lndhub.go/lnd" "github.com/go-playground/validator/v10" "github.com/labstack/echo/v4" "github.com/stretchr/testify/assert" @@ -56,7 +57,7 @@ func (suite *CheckPaymentTestSuite) SetupSuite() { assert.Equal(suite.T(), 1, len(userTokens)) suite.userLogin = users[0] suite.userToken = userTokens[0] - suite.echo.Use(tokens.Middleware([]byte(suite.service.Config.JWTSecret))) + suite.echo.Use(tokens.Middleware([]byte(suite.service.Config.JWTSecret), &lnd.Limits{})) suite.echo.POST("/addinvoice", controllers.NewAddInvoiceController(suite.service).AddInvoice) suite.echo.POST("/payinvoice", controllers.NewPayInvoiceController(suite.service).PayInvoice) suite.echo.GET("/checkpayment/:payment_hash", controllers.NewCheckPaymentController(suite.service).CheckPayment) diff --git a/integration_tests/getinfo_test.go b/integration_tests/getinfo_test.go index b6d0719..2d57b0b 100644 --- a/integration_tests/getinfo_test.go +++ b/integration_tests/getinfo_test.go @@ -13,6 +13,7 @@ import ( "github.com/getAlby/lndhub.go/lib/responses" "github.com/getAlby/lndhub.go/lib/service" "github.com/getAlby/lndhub.go/lib/tokens" + "github.com/getAlby/lndhub.go/lnd" "github.com/go-playground/validator/v10" "github.com/labstack/echo/v4" "github.com/lightningnetwork/lnd/lnrpc" @@ -47,7 +48,7 @@ func (suite *GetInfoTestSuite) SetupSuite() { assert.Equal(suite.T(), 1, len(userTokens)) suite.userLogin = users[0] suite.userToken = userTokens[0] - suite.echo.Use(tokens.Middleware([]byte(suite.service.Config.JWTSecret))) + suite.echo.Use(tokens.Middleware([]byte(suite.service.Config.JWTSecret), &lnd.Limits{})) suite.echo.GET("/getinfo", controllers.NewGetInfoController(svc).GetInfo) } diff --git a/integration_tests/gettxs_test.go b/integration_tests/gettxs_test.go index a1a6f05..206158a 100644 --- a/integration_tests/gettxs_test.go +++ b/integration_tests/gettxs_test.go @@ -15,6 +15,7 @@ import ( "github.com/getAlby/lndhub.go/lib/responses" "github.com/getAlby/lndhub.go/lib/service" "github.com/getAlby/lndhub.go/lib/tokens" + "github.com/getAlby/lndhub.go/lnd" "github.com/go-playground/validator/v10" "github.com/labstack/echo/v4" "github.com/stretchr/testify/assert" @@ -52,7 +53,7 @@ func (suite *GetTxTestSuite) SetupSuite() { e.HTTPErrorHandler = responses.HTTPErrorHandler e.Validator = &lib.CustomValidator{Validator: validator.New()} suite.echo = e - suite.echo.Use(tokens.Middleware([]byte(suite.Service.Config.JWTSecret))) + suite.echo.Use(tokens.Middleware([]byte(suite.Service.Config.JWTSecret), &lnd.Limits{})) suite.echo.GET("/gettxs", controllers.NewGetTXSController(suite.Service).GetTXS) suite.echo.GET("/getuserinvoices", controllers.NewGetTXSController(svc).GetUserInvoices) suite.echo.POST("/addinvoice", controllers.NewAddInvoiceController(suite.Service).AddInvoice) diff --git a/integration_tests/hodl_invoice_test.go b/integration_tests/hodl_invoice_test.go index 3e58e28..8a990c2 100644 --- a/integration_tests/hodl_invoice_test.go +++ b/integration_tests/hodl_invoice_test.go @@ -14,6 +14,7 @@ import ( "github.com/getAlby/lndhub.go/lib/responses" "github.com/getAlby/lndhub.go/lib/service" "github.com/getAlby/lndhub.go/lib/tokens" + "github.com/getAlby/lndhub.go/lnd" "github.com/go-playground/validator/v10" "github.com/labstack/echo/v4" "github.com/lightningnetwork/lnd/lnrpc" @@ -72,7 +73,7 @@ func (suite *HodlInvoiceSuite) SetupSuite() { suite.userLogin = users[0] suite.userToken = userTokens[0] suite.userToken2 = userTokens[1] - suite.echo.Use(tokens.Middleware([]byte(suite.service.Config.JWTSecret))) + suite.echo.Use(tokens.Middleware([]byte(suite.service.Config.JWTSecret), &lnd.Limits{})) suite.echo.GET("/balance", controllers.NewBalanceController(suite.service).Balance) suite.echo.POST("/addinvoice", controllers.NewAddInvoiceController(suite.service).AddInvoice) suite.echo.POST("/payinvoice", controllers.NewPayInvoiceController(suite.service).PayInvoice) diff --git a/integration_tests/incoming_payment_test.go b/integration_tests/incoming_payment_test.go index aabf5d1..fe6c17f 100644 --- a/integration_tests/incoming_payment_test.go +++ b/integration_tests/incoming_payment_test.go @@ -17,6 +17,7 @@ import ( "github.com/getAlby/lndhub.go/lib/responses" "github.com/getAlby/lndhub.go/lib/service" "github.com/getAlby/lndhub.go/lib/tokens" + "github.com/getAlby/lndhub.go/lnd" "github.com/go-playground/validator/v10" "github.com/labstack/echo/v4" "github.com/labstack/gommon/random" @@ -71,7 +72,7 @@ func (suite *IncomingPaymentTestSuite) TestIncomingPayment() { req := httptest.NewRequest(http.MethodGet, "/balance", &buf) req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", suite.userToken)) rec := httptest.NewRecorder() - suite.echo.Use(tokens.Middleware([]byte(suite.service.Config.JWTSecret))) + suite.echo.Use(tokens.Middleware([]byte(suite.service.Config.JWTSecret), &lnd.Limits{})) suite.echo.GET("/balance", controllers.NewBalanceController(suite.service).Balance) suite.echo.POST("/addinvoice", controllers.NewAddInvoiceController(suite.service).AddInvoice) suite.echo.ServeHTTP(rec, req) @@ -102,7 +103,7 @@ func (suite *IncomingPaymentTestSuite) TestIncomingPaymentZeroAmt() { req := httptest.NewRequest(http.MethodGet, "/balance", &buf) req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", suite.userToken)) rec := httptest.NewRecorder() - suite.echo.Use(tokens.Middleware([]byte(suite.service.Config.JWTSecret))) + suite.echo.Use(tokens.Middleware([]byte(suite.service.Config.JWTSecret), &lnd.Limits{})) suite.echo.GET("/balance", controllers.NewBalanceController(suite.service).Balance) suite.echo.POST("/addinvoice", controllers.NewAddInvoiceController(suite.service).AddInvoice) suite.echo.ServeHTTP(rec, req) @@ -144,7 +145,7 @@ func (suite *IncomingPaymentTestSuite) TestIncomingPaymentKeysend() { req := httptest.NewRequest(http.MethodGet, "/balance", &buf) req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", suite.userToken)) rec := httptest.NewRecorder() - suite.echo.Use(tokens.Middleware([]byte(suite.service.Config.JWTSecret))) + suite.echo.Use(tokens.Middleware([]byte(suite.service.Config.JWTSecret), &lnd.Limits{})) suite.echo.GET("/balance", controllers.NewBalanceController(suite.service).Balance) suite.echo.GET("/getuserinvoices", controllers.NewGetTXSController(suite.service).GetUserInvoices) suite.echo.ServeHTTP(rec, req) diff --git a/integration_tests/internal_payment_test.go b/integration_tests/internal_payment_test.go index dcc9c7f..8a95158 100644 --- a/integration_tests/internal_payment_test.go +++ b/integration_tests/internal_payment_test.go @@ -17,6 +17,7 @@ import ( "github.com/getAlby/lndhub.go/lib/responses" "github.com/getAlby/lndhub.go/lib/service" "github.com/getAlby/lndhub.go/lib/tokens" + "github.com/getAlby/lndhub.go/lnd" "github.com/go-playground/validator/v10" "github.com/labstack/echo/v4" "github.com/lightningnetwork/lnd/lnrpc" @@ -69,7 +70,8 @@ func (suite *PaymentTestSuite) SetupSuite() { suite.aliceToken = userTokens[0] suite.bobLogin = users[1] suite.bobToken = userTokens[1] - suite.echo.Use(tokens.Middleware([]byte(suite.service.Config.JWTSecret))) + + suite.echo.Use(tokens.Middleware([]byte(suite.service.Config.JWTSecret), &lnd.Limits{})) suite.echo.GET("/balance", controllers.NewBalanceController(suite.service).Balance) suite.echo.POST("/addinvoice", controllers.NewAddInvoiceController(suite.service).AddInvoice) suite.echo.POST("/payinvoice", controllers.NewPayInvoiceController(suite.service).PayInvoice) diff --git a/integration_tests/invoice_test.go b/integration_tests/invoice_test.go index 083671c..c93d6fa 100644 --- a/integration_tests/invoice_test.go +++ b/integration_tests/invoice_test.go @@ -18,6 +18,7 @@ import ( "github.com/getAlby/lndhub.go/lib/responses" "github.com/getAlby/lndhub.go/lib/service" "github.com/getAlby/lndhub.go/lib/tokens" + "github.com/getAlby/lndhub.go/lnd" "github.com/go-playground/validator/v10" "github.com/labstack/echo/v4" "github.com/stretchr/testify/assert" @@ -51,7 +52,7 @@ func (suite *InvoiceTestSuite) SetupSuite() { suite.aliceLogin = users[0] suite.aliceToken = userTokens[0] suite.echo.POST("/invoice/:user_login", controllers.NewInvoiceController(svc).Invoice) - suite.echo.POST("/v2/invoices", v2controllers.NewInvoiceController(svc).AddInvoice, tokens.Middleware([]byte(suite.service.Config.JWTSecret))) + suite.echo.POST("/v2/invoices", v2controllers.NewInvoiceController(svc).AddInvoice, tokens.Middleware([]byte(suite.service.Config.JWTSecret), &lnd.Limits{})) } func (suite *InvoiceTestSuite) TearDownTest() { diff --git a/integration_tests/keysend_failure_test.go b/integration_tests/keysend_failure_test.go index 25659c7..880db57 100644 --- a/integration_tests/keysend_failure_test.go +++ b/integration_tests/keysend_failure_test.go @@ -13,6 +13,7 @@ import ( "github.com/getAlby/lndhub.go/lib/responses" "github.com/getAlby/lndhub.go/lib/service" "github.com/getAlby/lndhub.go/lib/tokens" + "github.com/getAlby/lndhub.go/lnd" "github.com/go-playground/validator/v10" "github.com/labstack/echo/v4" "github.com/stretchr/testify/assert" @@ -72,7 +73,7 @@ func (suite *KeySendFailureTestSuite) SetupSuite() { assert.Equal(suite.T(), 1, len(userTokens)) suite.aliceLogin = users[0] suite.aliceToken = userTokens[0] - suite.echo.Use(tokens.Middleware([]byte(suite.service.Config.JWTSecret))) + suite.echo.Use(tokens.Middleware([]byte(suite.service.Config.JWTSecret), &lnd.Limits{})) suite.echo.POST("/addinvoice", controllers.NewAddInvoiceController(suite.service).AddInvoice) suite.echo.POST("/keysend", controllers.NewKeySendController(suite.service).KeySend) } diff --git a/integration_tests/keysend_test.go b/integration_tests/keysend_test.go index 7ce7955..9409d64 100644 --- a/integration_tests/keysend_test.go +++ b/integration_tests/keysend_test.go @@ -17,6 +17,7 @@ import ( "github.com/getAlby/lndhub.go/lib/responses" "github.com/getAlby/lndhub.go/lib/service" "github.com/getAlby/lndhub.go/lib/tokens" + "github.com/getAlby/lndhub.go/lnd" "github.com/go-playground/validator/v10" "github.com/labstack/echo/v4" "github.com/stretchr/testify/assert" @@ -60,7 +61,7 @@ func (suite *KeySendTestSuite) SetupSuite() { assert.Equal(suite.T(), 1, len(userTokens)) suite.aliceLogin = users[0] suite.aliceToken = userTokens[0] - suite.echo.Use(tokens.Middleware([]byte(suite.service.Config.JWTSecret))) + suite.echo.Use(tokens.Middleware([]byte(suite.service.Config.JWTSecret), &lnd.Limits{})) suite.echo.GET("/balance", controllers.NewBalanceController(suite.service).Balance) suite.echo.POST("/addinvoice", controllers.NewAddInvoiceController(suite.service).AddInvoice) suite.echo.POST("/payinvoice", controllers.NewPayInvoiceController(suite.service).PayInvoice) diff --git a/integration_tests/payment_failure_async_test.go b/integration_tests/payment_failure_async_test.go index edc8a1f..818e6be 100644 --- a/integration_tests/payment_failure_async_test.go +++ b/integration_tests/payment_failure_async_test.go @@ -13,6 +13,7 @@ import ( "github.com/getAlby/lndhub.go/lib/responses" "github.com/getAlby/lndhub.go/lib/service" "github.com/getAlby/lndhub.go/lib/tokens" + "github.com/getAlby/lndhub.go/lnd" "github.com/go-playground/validator/v10" "github.com/labstack/echo/v4" "github.com/lightningnetwork/lnd/lnrpc" @@ -69,7 +70,7 @@ func (suite *PaymentTestAsyncErrorsSuite) SetupSuite() { assert.Equal(suite.T(), 1, len(userTokens)) suite.userLogin = users[0] suite.userToken = userTokens[0] - suite.echo.Use(tokens.Middleware([]byte(suite.service.Config.JWTSecret))) + suite.echo.Use(tokens.Middleware([]byte(suite.service.Config.JWTSecret), &lnd.Limits{})) suite.echo.GET("/balance", controllers.NewBalanceController(suite.service).Balance) suite.echo.POST("/addinvoice", controllers.NewAddInvoiceController(suite.service).AddInvoice) suite.echo.POST("/payinvoice", controllers.NewPayInvoiceController(suite.service).PayInvoice) diff --git a/integration_tests/payment_failure_test.go b/integration_tests/payment_failure_test.go index eff222a..6d38ec6 100644 --- a/integration_tests/payment_failure_test.go +++ b/integration_tests/payment_failure_test.go @@ -17,6 +17,7 @@ import ( "github.com/getAlby/lndhub.go/lib/responses" "github.com/getAlby/lndhub.go/lib/service" "github.com/getAlby/lndhub.go/lib/tokens" + "github.com/getAlby/lndhub.go/lnd" "github.com/go-playground/validator/v10" "github.com/labstack/echo/v4" "github.com/lightningnetwork/lnd/lnrpc" @@ -71,7 +72,7 @@ func (suite *PaymentTestErrorsSuite) SetupSuite() { assert.Equal(suite.T(), 1, len(userTokens)) suite.userLogin = users[0] suite.userToken = userTokens[0] - suite.echo.Use(tokens.Middleware([]byte(suite.service.Config.JWTSecret))) + suite.echo.Use(tokens.Middleware([]byte(suite.service.Config.JWTSecret), &lnd.Limits{})) suite.echo.GET("/balance", controllers.NewBalanceController(suite.service).Balance) suite.echo.POST("/addinvoice", controllers.NewAddInvoiceController(suite.service).AddInvoice) suite.echo.POST("/payinvoice", controllers.NewPayInvoiceController(suite.service).PayInvoice) diff --git a/integration_tests/rabbitmq_test.go b/integration_tests/rabbitmq_test.go index f14d1dc..85f6b73 100644 --- a/integration_tests/rabbitmq_test.go +++ b/integration_tests/rabbitmq_test.go @@ -16,6 +16,7 @@ import ( "github.com/getAlby/lndhub.go/lib/responses" "github.com/getAlby/lndhub.go/lib/service" "github.com/getAlby/lndhub.go/lib/tokens" + "github.com/getAlby/lndhub.go/lnd" "github.com/go-playground/validator/v10" "github.com/labstack/echo/v4" "github.com/lightningnetwork/lnd/lnrpc" @@ -64,7 +65,7 @@ func (suite *RabbitMQTestSuite) SetupSuite() { e.Validator = &lib.CustomValidator{Validator: validator.New()} suite.echo = e - suite.echo.Use(tokens.Middleware(suite.svc.Config.JWTSecret)) + suite.echo.Use(tokens.Middleware([]byte(suite.svc.Config.JWTSecret), &lnd.Limits{})) suite.echo.POST("/addinvoice", controllers.NewAddInvoiceController(suite.svc).AddInvoice) suite.echo.POST("/payinvoice", controllers.NewPayInvoiceController(suite.svc).PayInvoice) go func() { diff --git a/integration_tests/subscription_start_test.go b/integration_tests/subscription_start_test.go index aea1ee4..3f2c84c 100644 --- a/integration_tests/subscription_start_test.go +++ b/integration_tests/subscription_start_test.go @@ -56,7 +56,7 @@ func (suite *SubscriptionStartTestSuite) SetupSuite() { suite.echo = e suite.userLogin = users[0] suite.userToken = userTokens[0] - suite.echo.Use(tokens.Middleware([]byte(suite.service.Config.JWTSecret))) + suite.echo.Use(tokens.Middleware([]byte(suite.service.Config.JWTSecret), &lnd.Limits{})) suite.echo.POST("/addinvoice", controllers.NewAddInvoiceController(suite.service).AddInvoice) } diff --git a/integration_tests/webhook_test.go b/integration_tests/webhook_test.go index b8d7e21..b18dfed 100644 --- a/integration_tests/webhook_test.go +++ b/integration_tests/webhook_test.go @@ -15,6 +15,7 @@ import ( "github.com/getAlby/lndhub.go/lib/responses" "github.com/getAlby/lndhub.go/lib/service" "github.com/getAlby/lndhub.go/lib/tokens" + "github.com/getAlby/lndhub.go/lnd" "github.com/go-playground/validator/v10" "github.com/labstack/echo/v4" "github.com/stretchr/testify/assert" @@ -73,7 +74,7 @@ func (suite *WebHookTestSuite) SetupSuite() { suite.echo = e suite.userLogin = users[0] suite.userToken = userTokens[0] - suite.echo.Use(tokens.Middleware([]byte(suite.service.Config.JWTSecret))) + suite.echo.Use(tokens.Middleware([]byte(suite.service.Config.JWTSecret), &lnd.Limits{})) suite.echo.POST("/addinvoice", controllers.NewAddInvoiceController(suite.service).AddInvoice) } func (suite *WebHookTestSuite) TestWebHook() { diff --git a/lib/service/user.go b/lib/service/user.go index 107408c..b0739ab 100644 --- a/lib/service/user.go +++ b/lib/service/user.go @@ -13,6 +13,7 @@ import ( "github.com/getAlby/lndhub.go/lib/security" "github.com/getAlby/lndhub.go/lnd" "github.com/getsentry/sentry-go" + "github.com/labstack/echo/v4" "github.com/labstack/gommon/log" "github.com/uptrace/bun" passwordvalidator "github.com/wagslane/go-password-validator" @@ -125,9 +126,9 @@ func (svc *LndhubService) FindUserByLogin(ctx context.Context, login string) (*m return &user, nil } -func (svc *LndhubService) CheckOutgoingPaymentAllowed(ctx context.Context, lnpayReq *lnd.LNPayReq, userId int64) (result *responses.ErrorResponse, err error) { - if svc.Config.MaxSendAmount > 0 { - if lnpayReq.PayReq.NumSatoshis > svc.Config.MaxSendAmount { +func (svc *LndhubService) CheckOutgoingPaymentAllowed(ctx context.Context, lnpayReq *lnd.LNPayReq, userId int64, limits *lnd.Limits) (result *responses.ErrorResponse, err error) { + if limits.MaxSendAmount > 0 { + if lnpayReq.PayReq.NumSatoshis > limits.MaxSendAmount { svc.Logger.Errorf("Max send amount exceeded for user_id %v (amount:%v)", userId, lnpayReq.PayReq.NumSatoshis) return &responses.SendExceededError, nil } @@ -153,18 +154,18 @@ func (svc *LndhubService) CheckOutgoingPaymentAllowed(ctx context.Context, lnpay return &responses.NotEnoughBalanceError, nil } - return svc.CheckVolumeAllowed(ctx, userId, common.InvoiceTypeOutgoing) + return svc.CheckVolumeAllowed(ctx, userId, limits.MaxSendVolume, common.InvoiceTypeOutgoing) } -func (svc *LndhubService) CheckIncomingPaymentAllowed(ctx context.Context, amount, userId int64) (result *responses.ErrorResponse, err error) { - if svc.Config.MaxReceiveAmount > 0 { - if amount > svc.Config.MaxReceiveAmount { +func (svc *LndhubService) CheckIncomingPaymentAllowed(ctx context.Context, amount, userId int64, limits *lnd.Limits) (result *responses.ErrorResponse, err error) { + if limits.MaxReceiveAmount > 0 { + if amount > limits.MaxReceiveAmount { svc.Logger.Errorf("Max receive amount exceeded for user_id %d", userId) return &responses.ReceiveExceededError, nil } } - if svc.Config.MaxAccountBalance > 0 { + if limits.MaxAccountBalance > 0 { currentBalance, err := svc.CurrentUserBalance(ctx, userId) if err != nil { svc.Logger.Errorj( @@ -176,22 +177,16 @@ func (svc *LndhubService) CheckIncomingPaymentAllowed(ctx context.Context, amoun ) return nil, err } - if currentBalance+amount > svc.Config.MaxAccountBalance { + if currentBalance+amount > limits.MaxAccountBalance { svc.Logger.Errorf("Max account balance exceeded for user_id %d", userId) return &responses.BalanceExceededError, nil } } - return svc.CheckVolumeAllowed(ctx, userId, common.InvoiceTypeIncoming) + return svc.CheckVolumeAllowed(ctx, userId, limits.MaxReceiveVolume, common.InvoiceTypeIncoming) } -func (svc *LndhubService) CheckVolumeAllowed(ctx context.Context, userId int64, invoiceType string) (result *responses.ErrorResponse, err error) { - var maxVolume int64 - if invoiceType == common.InvoiceTypeIncoming { - maxVolume = svc.Config.MaxReceiveVolume - } else { - maxVolume = svc.Config.MaxSendVolume - } +func (svc *LndhubService) CheckVolumeAllowed(ctx context.Context, userId, maxVolume int64, invoiceType string) (result *responses.ErrorResponse, err error) { if maxVolume > 0 { volume, err := svc.GetVolumeOverPeriod(ctx, userId, invoiceType, time.Duration(svc.Config.MaxVolumePeriod*int64(time.Second))) if err != nil { @@ -279,3 +274,24 @@ func (svc *LndhubService) GetVolumeOverPeriod(ctx context.Context, userId int64, } return result, nil } + +func (svc *LndhubService) GetLimitsFromContext(c echo.Context) (limits *lnd.Limits) { + limits = &lnd.Limits{} + if val, ok := c.Get("MaxSendVolume").(int64); ok { + limits.MaxSendVolume = val + } + if val, ok := c.Get("MaxSendAmount").(int64); ok { + limits.MaxSendAmount = val + } + if val, ok := c.Get("MaxReceiveVolume").(int64); ok { + limits.MaxReceiveVolume = val + } + if val, ok := c.Get("MaxReceiveAmount").(int64); ok { + limits.MaxReceiveAmount = val + } + if val, ok := c.Get("MaxAccountBalance").(int64); ok { + limits.MaxAccountBalance = val + } + + return limits +} diff --git a/lib/tokens/jwt.go b/lib/tokens/jwt.go index 53cf490..144d1f3 100644 --- a/lib/tokens/jwt.go +++ b/lib/tokens/jwt.go @@ -7,6 +7,7 @@ import ( "time" "github.com/getAlby/lndhub.go/db/models" + "github.com/getAlby/lndhub.go/lnd" "github.com/getsentry/sentry-go" sentryecho "github.com/getsentry/sentry-go/echo" "github.com/golang-jwt/jwt" @@ -15,12 +16,17 @@ import ( ) type jwtCustomClaims struct { - ID int64 `json:"id"` - IsRefresh bool `json:"isRefresh"` + ID int64 `json:"id"` + IsRefresh bool `json:"isRefresh"` + MaxSendVolume int64 `json:"maxSendVolume"` + MaxSendAmount int64 `json:"maxSendAmount"` + MaxReceiveVolume int64 `json:"maxReceiveVolume"` + MaxReceiveAmount int64 `json:"maxReceiveAmount"` + MaxAccountBalance int64 `json:"maxAccountBalance"` jwt.StandardClaims } -func Middleware(secret []byte) echo.MiddlewareFunc { +func Middleware(secret []byte, limits *lnd.Limits) echo.MiddlewareFunc { config := middleware.DefaultJWTConfig config.Claims = &jwtCustomClaims{} @@ -38,6 +44,31 @@ func Middleware(secret []byte) echo.MiddlewareFunc { token := c.Get("UserJwt").(*jwt.Token) claims := token.Claims.(*jwtCustomClaims) c.Set("UserID", claims.ID) + if limits.MaxSendVolume == 0 { + c.Set("MaxSendVolume", claims.MaxSendAmount) + } else { + c.Set("MaxSendVolume", limits.MaxSendVolume) + } + if limits.MaxSendAmount == 0 { + c.Set("MaxSendAmount", claims.MaxSendAmount) + } else { + c.Set("MaxSendAmount", limits.MaxSendAmount) + } + if limits.MaxReceiveVolume == 0 { + c.Set("MaxReceiveVolume", claims.MaxReceiveVolume) + } else { + c.Set("MaxReceiveVolume", limits.MaxReceiveVolume) + } + if limits.MaxReceiveAmount == 0 { + c.Set("MaxReceiveAmount", claims.MaxReceiveAmount) + } else { + c.Set("MaxReceiveAmount", limits.MaxReceiveAmount) + } + if limits.MaxAccountBalance == 0 { + c.Set("MaxAccountBalance", claims.MaxAccountBalance) + } else { + c.Set("MaxAccountBalance", limits.MaxAccountBalance) + } // pass UserID to sentry for exception notifications if hub := sentryecho.GetHubFromContext(c); hub != nil { hub.Scope().SetUser(sentry.User{ID: strconv.FormatInt(claims.ID, 10)}) diff --git a/lnd/config.go b/lnd/config.go index 41cabce..509b1a7 100644 --- a/lnd/config.go +++ b/lnd/config.go @@ -22,6 +22,14 @@ type Config struct { LNDClusterPubkeys string `envconfig:"LND_CLUSTER_PUBKEYS"` //comma-seperated list of public keys of the cluster } +type Limits struct { + MaxSendVolume int64 + MaxSendAmount int64 + MaxReceiveVolume int64 + MaxReceiveAmount int64 + MaxAccountBalance int64 +} + func LoadConfig() (c *Config, err error) { c = &Config{} err = envconfig.Process("", c)