chore: further refactoring

This commit is contained in:
im-adithya
2023-12-07 16:44:47 +05:30
parent 0385c17d2b
commit b494219071
9 changed files with 27 additions and 33 deletions

View File

@@ -37,7 +37,6 @@ func (controller *AddInvoiceController) AddInvoice(c echo.Context) error {
func AddInvoice(c echo.Context, svc *service.LndhubService, userID int64) error { func AddInvoice(c echo.Context, svc *service.LndhubService, userID int64) error {
var body AddInvoiceRequestBody var body AddInvoiceRequestBody
limits := svc.GetLimits(c)
if err := c.Bind(&body); err != nil { if err := c.Bind(&body); err != nil {
c.Logger().Errorf("Failed to load addinvoice request body: %v", err) c.Logger().Errorf("Failed to load addinvoice request body: %v", err)
@@ -62,7 +61,7 @@ func AddInvoice(c echo.Context, svc *service.LndhubService, userID int64) error
return c.JSON(http.StatusBadRequest, responses.BadArgumentsError) return c.JSON(http.StatusBadRequest, responses.BadArgumentsError)
} }
resp, err := svc.CheckIncomingPaymentAllowed(c.Request().Context(), amount, userID, limits) resp, err := svc.CheckIncomingPaymentAllowed(c, amount, userID)
if err != nil { if err != nil {
return c.JSON(http.StatusInternalServerError, responses.GeneralServerError) return c.JSON(http.StatusInternalServerError, responses.GeneralServerError)
} }

View File

@@ -46,7 +46,6 @@ type KeySendResponseBody struct {
func (controller *KeySendController) KeySend(c echo.Context) error { func (controller *KeySendController) KeySend(c echo.Context) error {
userID := c.Get("UserID").(int64) userID := c.Get("UserID").(int64)
limits := controller.svc.GetLimits(c)
reqBody := KeySendRequestBody{} reqBody := KeySendRequestBody{}
if err := c.Bind(&reqBody); err != nil { if err := c.Bind(&reqBody); err != nil {
c.Logger().Errorf("Failed to load keysend request body: %v", err) c.Logger().Errorf("Failed to load keysend request body: %v", err)
@@ -76,7 +75,7 @@ func (controller *KeySendController) KeySend(c echo.Context) error {
}) })
} }
resp, err := controller.svc.CheckOutgoingPaymentAllowed(c.Request().Context(), lnPayReq, userID, limits) resp, err := controller.svc.CheckOutgoingPaymentAllowed(c, lnPayReq, userID)
if err != nil { if err != nil {
return c.JSON(http.StatusInternalServerError, responses.GeneralServerError) return c.JSON(http.StatusInternalServerError, responses.GeneralServerError)
} }

View File

@@ -43,7 +43,6 @@ type PayInvoiceResponseBody struct {
func (controller *PayInvoiceController) PayInvoice(c echo.Context) error { func (controller *PayInvoiceController) PayInvoice(c echo.Context) error {
userID := c.Get("UserID").(int64) userID := c.Get("UserID").(int64)
limits := controller.svc.GetLimits(c)
reqBody := PayInvoiceRequestBody{} reqBody := PayInvoiceRequestBody{}
if err := c.Bind(&reqBody); err != nil { if err := c.Bind(&reqBody); err != nil {
c.Logger().Errorf("Failed to load payinvoice request body: user_id:%v error: %v", userID, err) c.Logger().Errorf("Failed to load payinvoice request body: user_id:%v error: %v", userID, err)
@@ -91,7 +90,7 @@ func (controller *PayInvoiceController) PayInvoice(c echo.Context) error {
lnPayReq.PayReq.NumSatoshis = amt lnPayReq.PayReq.NumSatoshis = amt
} }
resp, err := controller.svc.CheckOutgoingPaymentAllowed(c.Request().Context(), lnPayReq, userID, limits) resp, err := controller.svc.CheckOutgoingPaymentAllowed(c, lnPayReq, userID)
if err != nil { if err != nil {
return c.JSON(http.StatusInternalServerError, responses.GeneralServerError) return c.JSON(http.StatusInternalServerError, responses.GeneralServerError)
} }

View File

@@ -165,7 +165,6 @@ type AddInvoiceResponseBody struct {
// @Security OAuth2Password // @Security OAuth2Password
func (controller *InvoiceController) AddInvoice(c echo.Context) error { func (controller *InvoiceController) AddInvoice(c echo.Context) error {
userID := c.Get("UserID").(int64) userID := c.Get("UserID").(int64)
limits := controller.svc.GetLimits(c)
var body AddInvoiceRequestBody var body AddInvoiceRequestBody
@@ -179,7 +178,7 @@ func (controller *InvoiceController) AddInvoice(c echo.Context) error {
return c.JSON(http.StatusBadRequest, responses.BadArgumentsError) return c.JSON(http.StatusBadRequest, responses.BadArgumentsError)
} }
resp, err := controller.svc.CheckIncomingPaymentAllowed(c.Request().Context(), body.Amount, userID, limits) resp, err := controller.svc.CheckIncomingPaymentAllowed(c, body.Amount, userID)
if err != nil { if err != nil {
return c.JSON(http.StatusInternalServerError, responses.GeneralServerError) return c.JSON(http.StatusInternalServerError, responses.GeneralServerError)
} }

View File

@@ -8,7 +8,6 @@ import (
"strconv" "strconv"
"github.com/getAlby/lndhub.go/common" "github.com/getAlby/lndhub.go/common"
"github.com/getAlby/lndhub.go/db/models"
"github.com/getAlby/lndhub.go/lib/responses" "github.com/getAlby/lndhub.go/lib/responses"
"github.com/getAlby/lndhub.go/lib/service" "github.com/getAlby/lndhub.go/lib/service"
"github.com/getAlby/lndhub.go/lnd" "github.com/getAlby/lndhub.go/lnd"
@@ -72,7 +71,6 @@ type KeySendResponseBody struct {
// @Security OAuth2Password // @Security OAuth2Password
func (controller *KeySendController) KeySend(c echo.Context) error { func (controller *KeySendController) KeySend(c echo.Context) error {
userID := c.Get("UserID").(int64) userID := c.Get("UserID").(int64)
limits := controller.svc.GetLimits(c)
reqBody := KeySendRequestBody{} reqBody := KeySendRequestBody{}
if err := c.Bind(&reqBody); err != nil { if err := c.Bind(&reqBody); err != nil {
c.Logger().Errorf("Failed to load keysend request body: %v", err) c.Logger().Errorf("Failed to load keysend request body: %v", err)
@@ -83,7 +81,7 @@ func (controller *KeySendController) KeySend(c echo.Context) error {
c.Logger().Errorf("Invalid keysend request body: %v", err) c.Logger().Errorf("Invalid keysend request body: %v", err)
return c.JSON(http.StatusBadRequest, responses.BadArgumentsError) return c.JSON(http.StatusBadRequest, responses.BadArgumentsError)
} }
errResp := controller.checkKeysendPaymentAllowed(context.Background(), reqBody.Amount, userID, limits) errResp := controller.checkKeysendPaymentAllowed(c, reqBody.Amount, userID)
if errResp != nil { if errResp != nil {
c.Logger().Errorf("Failed to send keysend: %s", errResp.Message) c.Logger().Errorf("Failed to send keysend: %s", errResp.Message)
return c.JSON(errResp.HttpStatusCode, errResp) return c.JSON(errResp.HttpStatusCode, errResp)
@@ -110,7 +108,6 @@ func (controller *KeySendController) KeySend(c echo.Context) error {
// @Security OAuth2Password // @Security OAuth2Password
func (controller *KeySendController) MultiKeySend(c echo.Context) error { func (controller *KeySendController) MultiKeySend(c echo.Context) error {
userID := c.Get("UserID").(int64) userID := c.Get("UserID").(int64)
limits := controller.svc.GetLimits(c)
reqBody := MultiKeySendRequestBody{} reqBody := MultiKeySendRequestBody{}
if err := c.Bind(&reqBody); err != nil { if err := c.Bind(&reqBody); err != nil {
c.Logger().Errorf("Failed to load keysend request body: %v", err) c.Logger().Errorf("Failed to load keysend request body: %v", err)
@@ -130,7 +127,7 @@ func (controller *KeySendController) MultiKeySend(c echo.Context) error {
for _, keysend := range reqBody.Keysends { for _, keysend := range reqBody.Keysends {
totalAmount += keysend.Amount totalAmount += keysend.Amount
} }
errResp := controller.checkKeysendPaymentAllowed(context.Background(), totalAmount, userID, limits) errResp := controller.checkKeysendPaymentAllowed(c, totalAmount, userID)
if errResp != nil { if errResp != nil {
c.Logger().Errorf("Failed to make keysend split payments: %s", errResp.Message) c.Logger().Errorf("Failed to make keysend split payments: %s", errResp.Message)
return c.JSON(errResp.HttpStatusCode, errResp) return c.JSON(errResp.HttpStatusCode, errResp)
@@ -165,14 +162,14 @@ func (controller *KeySendController) MultiKeySend(c echo.Context) error {
return c.JSON(status, result) return c.JSON(status, result)
} }
func (controller *KeySendController) checkKeysendPaymentAllowed(ctx context.Context, amount, userID int64, limits *models.Limits) (resp *responses.ErrorResponse) { func (controller *KeySendController) checkKeysendPaymentAllowed(c echo.Context, amount, userID int64) (resp *responses.ErrorResponse) {
syntheticPayReq := &lnd.LNPayReq{ syntheticPayReq := &lnd.LNPayReq{
PayReq: &lnrpc.PayReq{ PayReq: &lnrpc.PayReq{
NumSatoshis: amount, NumSatoshis: amount,
}, },
Keysend: true, Keysend: true,
} }
resp, err := controller.svc.CheckOutgoingPaymentAllowed(ctx, syntheticPayReq, userID, limits) resp, err := controller.svc.CheckOutgoingPaymentAllowed(c, syntheticPayReq, userID)
if err != nil { if err != nil {
return &responses.GeneralServerError return &responses.GeneralServerError
} }

View File

@@ -52,7 +52,6 @@ type PayInvoiceResponseBody struct {
// @Security OAuth2Password // @Security OAuth2Password
func (controller *PayInvoiceController) PayInvoice(c echo.Context) error { func (controller *PayInvoiceController) PayInvoice(c echo.Context) error {
userID := c.Get("UserID").(int64) userID := c.Get("UserID").(int64)
limits := controller.svc.GetLimits(c)
reqBody := PayInvoiceRequestBody{} reqBody := PayInvoiceRequestBody{}
if err := c.Bind(&reqBody); err != nil { if err := c.Bind(&reqBody); err != nil {
c.Logger().Errorf("Failed to load payinvoice request body: user_id:%v error: %v", userID, err) c.Logger().Errorf("Failed to load payinvoice request body: user_id:%v error: %v", userID, err)
@@ -99,7 +98,7 @@ func (controller *PayInvoiceController) PayInvoice(c echo.Context) error {
} }
lnPayReq.PayReq.NumSatoshis = amt lnPayReq.PayReq.NumSatoshis = amt
} }
resp, err := controller.svc.CheckOutgoingPaymentAllowed(c.Request().Context(), lnPayReq, userID, limits) resp, err := controller.svc.CheckOutgoingPaymentAllowed(c, lnPayReq, userID)
if err != nil { if err != nil {
return c.JSON(http.StatusBadRequest, responses.GeneralServerError) return c.JSON(http.StatusBadRequest, responses.GeneralServerError)
} }

View File

@@ -21,14 +21,6 @@ type User struct {
Deactivated bool Deactivated bool
} }
type Limits struct {
MaxSendVolume int64
MaxSendAmount int64
MaxReceiveVolume int64
MaxReceiveAmount int64
MaxAccountBalance int64
}
func (u *User) BeforeAppendModel(ctx context.Context, query bun.Query) error { func (u *User) BeforeAppendModel(ctx context.Context, query bun.Query) error {
switch query.(type) { switch query.(type) {
case *bun.UpdateQuery: case *bun.UpdateQuery:

View File

@@ -126,7 +126,8 @@ func (svc *LndhubService) FindUserByLogin(ctx context.Context, login string) (*m
return &user, nil return &user, nil
} }
func (svc *LndhubService) CheckOutgoingPaymentAllowed(ctx context.Context, lnpayReq *lnd.LNPayReq, userId int64, limits *models.Limits) (result *responses.ErrorResponse, err error) { func (svc *LndhubService) CheckOutgoingPaymentAllowed(c echo.Context, lnpayReq *lnd.LNPayReq, userId int64) (result *responses.ErrorResponse, err error) {
limits := svc.GetLimits(c)
if limits.MaxSendAmount > 0 { if limits.MaxSendAmount > 0 {
if lnpayReq.PayReq.NumSatoshis > limits.MaxSendAmount { if lnpayReq.PayReq.NumSatoshis > limits.MaxSendAmount {
svc.Logger.Errorf("Max send amount exceeded for user_id %v (amount:%v)", userId, lnpayReq.PayReq.NumSatoshis) svc.Logger.Errorf("Max send amount exceeded for user_id %v (amount:%v)", userId, lnpayReq.PayReq.NumSatoshis)
@@ -135,7 +136,7 @@ func (svc *LndhubService) CheckOutgoingPaymentAllowed(ctx context.Context, lnpay
} }
if limits.MaxSendVolume > 0 { if limits.MaxSendVolume > 0 {
volume, err := svc.GetVolumeOverPeriod(ctx, userId, common.InvoiceTypeOutgoing, time.Duration(svc.Config.MaxVolumePeriod*int64(time.Second))) volume, err := svc.GetVolumeOverPeriod(c.Request().Context(), userId, common.InvoiceTypeOutgoing, time.Duration(svc.Config.MaxVolumePeriod*int64(time.Second)))
if err != nil { if err != nil {
svc.Logger.Errorj( svc.Logger.Errorj(
log.JSON{ log.JSON{
@@ -153,7 +154,7 @@ func (svc *LndhubService) CheckOutgoingPaymentAllowed(ctx context.Context, lnpay
} }
} }
currentBalance, err := svc.CurrentUserBalance(ctx, userId) currentBalance, err := svc.CurrentUserBalance(c.Request().Context(), userId)
if err != nil { if err != nil {
svc.Logger.Errorj( svc.Logger.Errorj(
log.JSON{ log.JSON{
@@ -176,7 +177,8 @@ func (svc *LndhubService) CheckOutgoingPaymentAllowed(ctx context.Context, lnpay
return nil, nil return nil, nil
} }
func (svc *LndhubService) CheckIncomingPaymentAllowed(ctx context.Context, amount, userId int64, limits *models.Limits) (result *responses.ErrorResponse, err error) { func (svc *LndhubService) CheckIncomingPaymentAllowed(c echo.Context, amount, userId int64) (result *responses.ErrorResponse, err error) {
limits := svc.GetLimits(c)
if limits.MaxReceiveAmount > 0 { if limits.MaxReceiveAmount > 0 {
if amount > limits.MaxReceiveAmount { if amount > limits.MaxReceiveAmount {
svc.Logger.Errorf("Max receive amount exceeded for user_id %d", userId) svc.Logger.Errorf("Max receive amount exceeded for user_id %d", userId)
@@ -185,7 +187,7 @@ func (svc *LndhubService) CheckIncomingPaymentAllowed(ctx context.Context, amoun
} }
if limits.MaxReceiveVolume > 0 { if limits.MaxReceiveVolume > 0 {
volume, err := svc.GetVolumeOverPeriod(ctx, userId, common.InvoiceTypeIncoming, time.Duration(svc.Config.MaxVolumePeriod*int64(time.Second))) volume, err := svc.GetVolumeOverPeriod(c.Request().Context(), userId, common.InvoiceTypeIncoming, time.Duration(svc.Config.MaxVolumePeriod*int64(time.Second)))
if err != nil { if err != nil {
svc.Logger.Errorj( svc.Logger.Errorj(
log.JSON{ log.JSON{
@@ -204,7 +206,7 @@ func (svc *LndhubService) CheckIncomingPaymentAllowed(ctx context.Context, amoun
} }
if limits.MaxAccountBalance > 0 { if limits.MaxAccountBalance > 0 {
currentBalance, err := svc.CurrentUserBalance(ctx, userId) currentBalance, err := svc.CurrentUserBalance(c.Request().Context(), userId)
if err != nil { if err != nil {
svc.Logger.Errorj( svc.Logger.Errorj(
log.JSON{ log.JSON{
@@ -290,8 +292,8 @@ func (svc *LndhubService) GetVolumeOverPeriod(ctx context.Context, userId int64,
return result, nil return result, nil
} }
func (svc *LndhubService) GetLimits(c echo.Context) (limits *models.Limits) { func (svc *LndhubService) GetLimits(c echo.Context) (limits *lnd.Limits) {
limits = &models.Limits{ limits = &lnd.Limits{
MaxSendVolume: svc.Config.MaxSendVolume, MaxSendVolume: svc.Config.MaxSendVolume,
MaxSendAmount: svc.Config.MaxSendAmount, MaxSendAmount: svc.Config.MaxSendAmount,
MaxReceiveVolume: svc.Config.MaxReceiveVolume, MaxReceiveVolume: svc.Config.MaxReceiveVolume,

View File

@@ -22,6 +22,14 @@ type Config struct {
LNDClusterPubkeys string `envconfig:"LND_CLUSTER_PUBKEYS"` //comma-seperated list of public keys of the cluster LNDClusterPubkeys string `envconfig:"LND_CLUSTER_PUBKEYS"` //comma-seperated list of public keys of the cluster
} }
type Limits struct {
MaxSendVolume int64
MaxSendAmount int64
MaxReceiveVolume int64
MaxReceiveAmount int64
MaxAccountBalance int64
}
func LoadConfig() (c *Config, err error) { func LoadConfig() (c *Config, err error) {
c = &Config{} c = &Config{}
err = envconfig.Process("", c) err = envconfig.Process("", c)