diff --git a/controllers/addinvoice.ctrl.go b/controllers/addinvoice.ctrl.go index 00ba960..576f430 100644 --- a/controllers/addinvoice.ctrl.go +++ b/controllers/addinvoice.ctrl.go @@ -37,7 +37,6 @@ func (controller *AddInvoiceController) AddInvoice(c echo.Context) error { func AddInvoice(c echo.Context, svc *service.LndhubService, userID int64) error { var body AddInvoiceRequestBody - limits := svc.GetLimits(c) if err := c.Bind(&body); err != nil { c.Logger().Errorf("Failed to load addinvoice request body: %v", err) @@ -62,7 +61,7 @@ func AddInvoice(c echo.Context, svc *service.LndhubService, userID int64) error return c.JSON(http.StatusBadRequest, responses.BadArgumentsError) } - resp, err := svc.CheckIncomingPaymentAllowed(c.Request().Context(), amount, userID, limits) + resp, err := svc.CheckIncomingPaymentAllowed(c, amount, userID) if err != nil { return c.JSON(http.StatusInternalServerError, responses.GeneralServerError) } diff --git a/controllers/keysend.ctrl.go b/controllers/keysend.ctrl.go index 4ee3d70..421d5bc 100644 --- a/controllers/keysend.ctrl.go +++ b/controllers/keysend.ctrl.go @@ -46,7 +46,6 @@ type KeySendResponseBody struct { func (controller *KeySendController) KeySend(c echo.Context) error { userID := c.Get("UserID").(int64) - limits := controller.svc.GetLimits(c) reqBody := KeySendRequestBody{} if err := c.Bind(&reqBody); err != nil { c.Logger().Errorf("Failed to load keysend request body: %v", err) @@ -76,7 +75,7 @@ func (controller *KeySendController) KeySend(c echo.Context) error { }) } - resp, err := controller.svc.CheckOutgoingPaymentAllowed(c.Request().Context(), lnPayReq, userID, limits) + resp, err := controller.svc.CheckOutgoingPaymentAllowed(c, lnPayReq, userID) if err != nil { return c.JSON(http.StatusInternalServerError, responses.GeneralServerError) } diff --git a/controllers/payinvoice.ctrl.go b/controllers/payinvoice.ctrl.go index 7f8e511..e0f5311 100644 --- a/controllers/payinvoice.ctrl.go +++ b/controllers/payinvoice.ctrl.go @@ -43,7 +43,6 @@ type PayInvoiceResponseBody struct { func (controller *PayInvoiceController) PayInvoice(c echo.Context) error { userID := c.Get("UserID").(int64) - limits := controller.svc.GetLimits(c) reqBody := PayInvoiceRequestBody{} if err := c.Bind(&reqBody); err != nil { c.Logger().Errorf("Failed to load payinvoice request body: user_id:%v error: %v", userID, err) @@ -91,7 +90,7 @@ func (controller *PayInvoiceController) PayInvoice(c echo.Context) error { lnPayReq.PayReq.NumSatoshis = amt } - resp, err := controller.svc.CheckOutgoingPaymentAllowed(c.Request().Context(), lnPayReq, userID, limits) + resp, err := controller.svc.CheckOutgoingPaymentAllowed(c, lnPayReq, userID) if err != nil { return c.JSON(http.StatusInternalServerError, responses.GeneralServerError) } diff --git a/controllers_v2/invoice.ctrl.go b/controllers_v2/invoice.ctrl.go index af20331..ed8537b 100644 --- a/controllers_v2/invoice.ctrl.go +++ b/controllers_v2/invoice.ctrl.go @@ -165,7 +165,6 @@ type AddInvoiceResponseBody struct { // @Security OAuth2Password func (controller *InvoiceController) AddInvoice(c echo.Context) error { userID := c.Get("UserID").(int64) - limits := controller.svc.GetLimits(c) var body AddInvoiceRequestBody @@ -179,7 +178,7 @@ func (controller *InvoiceController) AddInvoice(c echo.Context) error { return c.JSON(http.StatusBadRequest, responses.BadArgumentsError) } - resp, err := controller.svc.CheckIncomingPaymentAllowed(c.Request().Context(), body.Amount, userID, limits) + resp, err := controller.svc.CheckIncomingPaymentAllowed(c, body.Amount, userID) if err != nil { return c.JSON(http.StatusInternalServerError, responses.GeneralServerError) } diff --git a/controllers_v2/keysend.ctrl.go b/controllers_v2/keysend.ctrl.go index 38e9a2e..884b906 100644 --- a/controllers_v2/keysend.ctrl.go +++ b/controllers_v2/keysend.ctrl.go @@ -8,7 +8,6 @@ import ( "strconv" "github.com/getAlby/lndhub.go/common" - "github.com/getAlby/lndhub.go/db/models" "github.com/getAlby/lndhub.go/lib/responses" "github.com/getAlby/lndhub.go/lib/service" "github.com/getAlby/lndhub.go/lnd" @@ -72,7 +71,6 @@ type KeySendResponseBody struct { // @Security OAuth2Password func (controller *KeySendController) KeySend(c echo.Context) error { userID := c.Get("UserID").(int64) - limits := controller.svc.GetLimits(c) reqBody := KeySendRequestBody{} if err := c.Bind(&reqBody); err != nil { c.Logger().Errorf("Failed to load keysend request body: %v", err) @@ -83,7 +81,7 @@ func (controller *KeySendController) KeySend(c echo.Context) error { c.Logger().Errorf("Invalid keysend request body: %v", err) return c.JSON(http.StatusBadRequest, responses.BadArgumentsError) } - errResp := controller.checkKeysendPaymentAllowed(context.Background(), reqBody.Amount, userID, limits) + errResp := controller.checkKeysendPaymentAllowed(c, reqBody.Amount, userID) if errResp != nil { c.Logger().Errorf("Failed to send keysend: %s", errResp.Message) return c.JSON(errResp.HttpStatusCode, errResp) @@ -110,7 +108,6 @@ func (controller *KeySendController) KeySend(c echo.Context) error { // @Security OAuth2Password func (controller *KeySendController) MultiKeySend(c echo.Context) error { userID := c.Get("UserID").(int64) - limits := controller.svc.GetLimits(c) reqBody := MultiKeySendRequestBody{} if err := c.Bind(&reqBody); err != nil { c.Logger().Errorf("Failed to load keysend request body: %v", err) @@ -130,7 +127,7 @@ func (controller *KeySendController) MultiKeySend(c echo.Context) error { for _, keysend := range reqBody.Keysends { totalAmount += keysend.Amount } - errResp := controller.checkKeysendPaymentAllowed(context.Background(), totalAmount, userID, limits) + errResp := controller.checkKeysendPaymentAllowed(c, totalAmount, userID) if errResp != nil { c.Logger().Errorf("Failed to make keysend split payments: %s", errResp.Message) return c.JSON(errResp.HttpStatusCode, errResp) @@ -165,14 +162,14 @@ func (controller *KeySendController) MultiKeySend(c echo.Context) error { return c.JSON(status, result) } -func (controller *KeySendController) checkKeysendPaymentAllowed(ctx context.Context, amount, userID int64, limits *models.Limits) (resp *responses.ErrorResponse) { +func (controller *KeySendController) checkKeysendPaymentAllowed(c echo.Context, amount, userID int64) (resp *responses.ErrorResponse) { syntheticPayReq := &lnd.LNPayReq{ PayReq: &lnrpc.PayReq{ NumSatoshis: amount, }, Keysend: true, } - resp, err := controller.svc.CheckOutgoingPaymentAllowed(ctx, syntheticPayReq, userID, limits) + resp, err := controller.svc.CheckOutgoingPaymentAllowed(c, syntheticPayReq, userID) if err != nil { return &responses.GeneralServerError } diff --git a/controllers_v2/payinvoice.ctrl.go b/controllers_v2/payinvoice.ctrl.go index 9d0647d..bf16737 100644 --- a/controllers_v2/payinvoice.ctrl.go +++ b/controllers_v2/payinvoice.ctrl.go @@ -52,7 +52,6 @@ type PayInvoiceResponseBody struct { // @Security OAuth2Password func (controller *PayInvoiceController) PayInvoice(c echo.Context) error { userID := c.Get("UserID").(int64) - limits := controller.svc.GetLimits(c) reqBody := PayInvoiceRequestBody{} if err := c.Bind(&reqBody); err != nil { c.Logger().Errorf("Failed to load payinvoice request body: user_id:%v error: %v", userID, err) @@ -99,7 +98,7 @@ func (controller *PayInvoiceController) PayInvoice(c echo.Context) error { } lnPayReq.PayReq.NumSatoshis = amt } - resp, err := controller.svc.CheckOutgoingPaymentAllowed(c.Request().Context(), lnPayReq, userID, limits) + resp, err := controller.svc.CheckOutgoingPaymentAllowed(c, lnPayReq, userID) if err != nil { return c.JSON(http.StatusBadRequest, responses.GeneralServerError) } diff --git a/db/models/user.go b/db/models/user.go index a9f55af..98939f0 100644 --- a/db/models/user.go +++ b/db/models/user.go @@ -21,14 +21,6 @@ type User struct { Deactivated bool } -type Limits struct { - MaxSendVolume int64 - MaxSendAmount int64 - MaxReceiveVolume int64 - MaxReceiveAmount int64 - MaxAccountBalance int64 -} - func (u *User) BeforeAppendModel(ctx context.Context, query bun.Query) error { switch query.(type) { case *bun.UpdateQuery: diff --git a/lib/service/user.go b/lib/service/user.go index 196a9e5..edf0bbe 100644 --- a/lib/service/user.go +++ b/lib/service/user.go @@ -126,7 +126,8 @@ func (svc *LndhubService) FindUserByLogin(ctx context.Context, login string) (*m return &user, nil } -func (svc *LndhubService) CheckOutgoingPaymentAllowed(ctx context.Context, lnpayReq *lnd.LNPayReq, userId int64, limits *models.Limits) (result *responses.ErrorResponse, err error) { +func (svc *LndhubService) CheckOutgoingPaymentAllowed(c echo.Context, lnpayReq *lnd.LNPayReq, userId int64) (result *responses.ErrorResponse, err error) { + limits := svc.GetLimits(c) if limits.MaxSendAmount > 0 { if lnpayReq.PayReq.NumSatoshis > limits.MaxSendAmount { svc.Logger.Errorf("Max send amount exceeded for user_id %v (amount:%v)", userId, lnpayReq.PayReq.NumSatoshis) @@ -135,7 +136,7 @@ func (svc *LndhubService) CheckOutgoingPaymentAllowed(ctx context.Context, lnpay } if limits.MaxSendVolume > 0 { - volume, err := svc.GetVolumeOverPeriod(ctx, userId, common.InvoiceTypeOutgoing, time.Duration(svc.Config.MaxVolumePeriod*int64(time.Second))) + volume, err := svc.GetVolumeOverPeriod(c.Request().Context(), userId, common.InvoiceTypeOutgoing, time.Duration(svc.Config.MaxVolumePeriod*int64(time.Second))) if err != nil { svc.Logger.Errorj( log.JSON{ @@ -153,7 +154,7 @@ func (svc *LndhubService) CheckOutgoingPaymentAllowed(ctx context.Context, lnpay } } - currentBalance, err := svc.CurrentUserBalance(ctx, userId) + currentBalance, err := svc.CurrentUserBalance(c.Request().Context(), userId) if err != nil { svc.Logger.Errorj( log.JSON{ @@ -176,7 +177,8 @@ func (svc *LndhubService) CheckOutgoingPaymentAllowed(ctx context.Context, lnpay return nil, nil } -func (svc *LndhubService) CheckIncomingPaymentAllowed(ctx context.Context, amount, userId int64, limits *models.Limits) (result *responses.ErrorResponse, err error) { +func (svc *LndhubService) CheckIncomingPaymentAllowed(c echo.Context, amount, userId int64) (result *responses.ErrorResponse, err error) { + limits := svc.GetLimits(c) if limits.MaxReceiveAmount > 0 { if amount > limits.MaxReceiveAmount { svc.Logger.Errorf("Max receive amount exceeded for user_id %d", userId) @@ -185,7 +187,7 @@ func (svc *LndhubService) CheckIncomingPaymentAllowed(ctx context.Context, amoun } if limits.MaxReceiveVolume > 0 { - volume, err := svc.GetVolumeOverPeriod(ctx, userId, common.InvoiceTypeIncoming, time.Duration(svc.Config.MaxVolumePeriod*int64(time.Second))) + volume, err := svc.GetVolumeOverPeriod(c.Request().Context(), userId, common.InvoiceTypeIncoming, time.Duration(svc.Config.MaxVolumePeriod*int64(time.Second))) if err != nil { svc.Logger.Errorj( log.JSON{ @@ -204,7 +206,7 @@ func (svc *LndhubService) CheckIncomingPaymentAllowed(ctx context.Context, amoun } if limits.MaxAccountBalance > 0 { - currentBalance, err := svc.CurrentUserBalance(ctx, userId) + currentBalance, err := svc.CurrentUserBalance(c.Request().Context(), userId) if err != nil { svc.Logger.Errorj( log.JSON{ @@ -290,8 +292,8 @@ func (svc *LndhubService) GetVolumeOverPeriod(ctx context.Context, userId int64, return result, nil } -func (svc *LndhubService) GetLimits(c echo.Context) (limits *models.Limits) { - limits = &models.Limits{ +func (svc *LndhubService) GetLimits(c echo.Context) (limits *lnd.Limits) { + limits = &lnd.Limits{ MaxSendVolume: svc.Config.MaxSendVolume, MaxSendAmount: svc.Config.MaxSendAmount, MaxReceiveVolume: svc.Config.MaxReceiveVolume, diff --git a/lnd/config.go b/lnd/config.go index 41cabce..509b1a7 100644 --- a/lnd/config.go +++ b/lnd/config.go @@ -22,6 +22,14 @@ type Config struct { LNDClusterPubkeys string `envconfig:"LND_CLUSTER_PUBKEYS"` //comma-seperated list of public keys of the cluster } +type Limits struct { + MaxSendVolume int64 + MaxSendAmount int64 + MaxReceiveVolume int64 + MaxReceiveAmount int64 + MaxAccountBalance int64 +} + func LoadConfig() (c *Config, err error) { c = &Config{} err = envconfig.Process("", c)