From 0da4b359d69d5f46a7494fe0f6869796a35210e1 Mon Sep 17 00:00:00 2001 From: Stefan Kostic Date: Tue, 8 Feb 2022 13:35:02 +0100 Subject: [PATCH 1/2] Pass in ctx as service methods first argument --- lib/service/invoices.go | 48 ++++++++++++++++++++--------------------- lib/service/service.go | 4 ++-- lib/service/user.go | 6 +++--- 3 files changed, 29 insertions(+), 29 deletions(-) diff --git a/lib/service/invoices.go b/lib/service/invoices.go index 5f29e34..9292fe3 100644 --- a/lib/service/invoices.go +++ b/lib/service/invoices.go @@ -34,33 +34,33 @@ type SendPaymentResponse struct { Invoice *models.Invoice } -func (svc *LndhubService) FindInvoiceByPaymentHash(userId int64, rHash string) (*models.Invoice, error) { +func (svc *LndhubService) FindInvoiceByPaymentHash(ctx context.Context, userId int64, rHash string) (*models.Invoice, error) { var invoice models.Invoice - err := svc.DB.NewSelect().Model(&invoice).Where("invoice.user_id = ? AND invoice.r_hash = ?", userId, rHash).Limit(1).Scan(context.TODO()) + err := svc.DB.NewSelect().Model(&invoice).Where("invoice.user_id = ? AND invoice.r_hash = ?", userId, rHash).Limit(1).Scan(ctx) if err != nil { return &invoice, err } return &invoice, nil } -func (svc *LndhubService) SendInternalPayment(tx *bun.Tx, invoice *models.Invoice) (SendPaymentResponse, error) { +func (svc *LndhubService) SendInternalPayment(ctx context.Context, tx *bun.Tx, invoice *models.Invoice) (SendPaymentResponse, error) { sendPaymentResponse := SendPaymentResponse{} //SendInternalPayment() // find invoice var incomingInvoice models.Invoice - err := svc.DB.NewSelect().Model(&incomingInvoice).Where("type = ? AND payment_request = ? AND state = ? ", "incoming", invoice.PaymentRequest, "open").Limit(1).Scan(context.TODO()) + err := svc.DB.NewSelect().Model(&incomingInvoice).Where("type = ? AND payment_request = ? AND state = ? ", "incoming", invoice.PaymentRequest, "open").Limit(1).Scan(ctx) if err != nil { // invoice not found or already settled // TODO: logging return sendPaymentResponse, err } // Get the user's current and incoming account for the transaction entry - recipientCreditAccount, err := svc.AccountFor(context.TODO(), "current", incomingInvoice.UserID) + recipientCreditAccount, err := svc.AccountFor(ctx, "current", incomingInvoice.UserID) if err != nil { return sendPaymentResponse, err } - recipientDebitAccount, err := svc.AccountFor(context.TODO(), "incoming", incomingInvoice.UserID) + recipientDebitAccount, err := svc.AccountFor(ctx, "incoming", incomingInvoice.UserID) if err != nil { return sendPaymentResponse, err } @@ -72,7 +72,7 @@ func (svc *LndhubService) SendInternalPayment(tx *bun.Tx, invoice *models.Invoic DebitAccountID: recipientDebitAccount.ID, Amount: invoice.Amount, } - _, err = tx.NewInsert().Model(&recipientEntry).Exec(context.TODO()) + _, err = tx.NewInsert().Model(&recipientEntry).Exec(ctx) if err != nil { return sendPaymentResponse, err } @@ -91,7 +91,7 @@ func (svc *LndhubService) SendInternalPayment(tx *bun.Tx, invoice *models.Invoic incomingInvoice.Internal = true // mark incoming invoice as internal, just for documentation/debugging incomingInvoice.State = "settled" incomingInvoice.SettledAt = schema.NullTime{Time: time.Now()} - _, err = tx.NewUpdate().Model(&incomingInvoice).WherePK().Exec(context.TODO()) + _, err = tx.NewUpdate().Model(&incomingInvoice).WherePK().Exec(ctx) if err != nil { // could not save the invoice of the recipient return sendPaymentResponse, err @@ -100,7 +100,7 @@ func (svc *LndhubService) SendInternalPayment(tx *bun.Tx, invoice *models.Invoic return sendPaymentResponse, nil } -func (svc *LndhubService) SendPaymentSync(tx *bun.Tx, invoice *models.Invoice) (SendPaymentResponse, error) { +func (svc *LndhubService) SendPaymentSync(ctx context.Context, tx *bun.Tx, invoice *models.Invoice) (SendPaymentResponse, error) { sendPaymentResponse := SendPaymentResponse{} // TODO: set dynamic fee limit feeLimit := lnrpc.FeeLimit{ @@ -120,7 +120,7 @@ func (svc *LndhubService) SendPaymentSync(tx *bun.Tx, invoice *models.Invoice) ( } // Execute the payment - sendPaymentResult, err := svc.LndClient.SendPaymentSync(context.TODO(), &sendPaymentRequest) + sendPaymentResult, err := svc.LndClient.SendPaymentSync(ctx, &sendPaymentRequest) if err != nil { return sendPaymentResponse, err } @@ -140,15 +140,15 @@ func (svc *LndhubService) SendPaymentSync(tx *bun.Tx, invoice *models.Invoice) ( return sendPaymentResponse, nil } -func (svc *LndhubService) PayInvoice(invoice *models.Invoice) (*SendPaymentResponse, error) { +func (svc *LndhubService) PayInvoice(ctx context.Context, invoice *models.Invoice) (*SendPaymentResponse, error) { userId := invoice.UserID // Get the user's current and outgoing account for the transaction entry - debitAccount, err := svc.AccountFor(context.TODO(), "current", userId) + debitAccount, err := svc.AccountFor(ctx, "current", userId) if err != nil { return nil, err } - creditAccount, err := svc.AccountFor(context.TODO(), "outgoing", userId) + creditAccount, err := svc.AccountFor(ctx, "outgoing", userId) if err != nil { return nil, err } @@ -163,14 +163,14 @@ func (svc *LndhubService) PayInvoice(invoice *models.Invoice) (*SendPaymentRespo // Start a DB transaction // We rollback anything on error (only the invoice that was passed in to the PayInvoice calls stays in the DB) - tx, err := svc.DB.BeginTx(context.TODO(), &sql.TxOptions{}) + tx, err := svc.DB.BeginTx(ctx, &sql.TxOptions{}) if err != nil { return nil, err } // The DB constraints make sure the user actually has enough balance for the transaction // If the user does not have enough balance this call fails - _, err = tx.NewInsert().Model(&entry).Exec(context.TODO()) + _, err = tx.NewInsert().Model(&entry).Exec(ctx) if err != nil { tx.Rollback() return nil, err @@ -186,13 +186,13 @@ func (svc *LndhubService) PayInvoice(invoice *models.Invoice) (*SendPaymentRespo return nil, err } if svc.IdentityPubkey.IsEqual(destinationPubkey) { - paymentResponse, err = svc.SendInternalPayment(&tx, invoice) + paymentResponse, err = svc.SendInternalPayment(ctx, &tx, invoice) if err != nil { tx.Rollback() return nil, err } } else { - paymentResponse, err = svc.SendPaymentSync(&tx, invoice) + paymentResponse, err = svc.SendPaymentSync(ctx, &tx, invoice) if err != nil { tx.Rollback() return nil, err @@ -206,7 +206,7 @@ func (svc *LndhubService) PayInvoice(invoice *models.Invoice) (*SendPaymentRespo invoice.State = "settled" invoice.SettledAt = schema.NullTime{Time: time.Now()} - _, err = tx.NewUpdate().Model(invoice).WherePK().Exec(context.TODO()) + _, err = tx.NewUpdate().Model(invoice).WherePK().Exec(ctx) if err != nil { tx.Rollback() return nil, err @@ -222,7 +222,7 @@ func (svc *LndhubService) PayInvoice(invoice *models.Invoice) (*SendPaymentRespo return &paymentResponse, err } -func (svc *LndhubService) AddOutgoingInvoice(userID int64, paymentRequest string, decodedInvoice *zpay32.Invoice) (*models.Invoice, error) { +func (svc *LndhubService) AddOutgoingInvoice(ctx context.Context, userID int64, paymentRequest string, decodedInvoice *zpay32.Invoice) (*models.Invoice, error) { // Initialize new DB invoice destinationPubkeyHex := hex.EncodeToString(decodedInvoice.Destination.SerializeCompressed()) expiresAt := decodedInvoice.Timestamp.Add(decodedInvoice.Expiry()) @@ -251,14 +251,14 @@ func (svc *LndhubService) AddOutgoingInvoice(userID int64, paymentRequest string } // Save invoice - _, err := svc.DB.NewInsert().Model(&invoice).Exec(context.TODO()) + _, err := svc.DB.NewInsert().Model(&invoice).Exec(ctx) if err != nil { return nil, err } return &invoice, nil } -func (svc *LndhubService) AddIncomingInvoice(userID int64, amount int64, memo, descriptionHashStr string) (*models.Invoice, error) { +func (svc *LndhubService) AddIncomingInvoice(ctx context.Context, userID int64, amount int64, memo, descriptionHashStr string) (*models.Invoice, error) { preimage := makePreimageHex() expiry := time.Hour * 24 // invoice expires in 24h // Initialize new DB invoice @@ -273,7 +273,7 @@ func (svc *LndhubService) AddIncomingInvoice(userID int64, amount int64, memo, d } // Save invoice - we save the invoice early to have a record in case the LN call fails - _, err := svc.DB.NewInsert().Model(&invoice).Exec(context.TODO()) + _, err := svc.DB.NewInsert().Model(&invoice).Exec(ctx) if err != nil { return nil, err } @@ -291,7 +291,7 @@ func (svc *LndhubService) AddIncomingInvoice(userID int64, amount int64, memo, d Expiry: int64(expiry.Seconds()), } // Call LND - lnInvoiceResult, err := svc.LndClient.AddInvoice(context.TODO(), &lnInvoice) + lnInvoiceResult, err := svc.LndClient.AddInvoice(ctx, &lnInvoice) if err != nil { return nil, err } @@ -304,7 +304,7 @@ func (svc *LndhubService) AddIncomingInvoice(userID int64, amount int64, memo, d invoice.DestinationPubkeyHex = svc.GetIdentPubKeyHex() // Our node pubkey for incoming invoices invoice.State = "open" - _, err = svc.DB.NewUpdate().Model(&invoice).WherePK().Exec(context.TODO()) + _, err = svc.DB.NewUpdate().Model(&invoice).WherePK().Exec(ctx) if err != nil { return nil, err } diff --git a/lib/service/service.go b/lib/service/service.go index 33dd0fe..aac2cb6 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -25,13 +25,13 @@ type LndhubService struct { IdentityPubkey *btcec.PublicKey } -func (svc *LndhubService) GenerateToken(login, password, inRefreshToken string) (accessToken, refreshToken string, err error) { +func (svc *LndhubService) GenerateToken(ctx context.Context, login, password, inRefreshToken string) (accessToken, refreshToken string, err error) { var user models.User switch { case login != "" || password != "": { - if err := svc.DB.NewSelect().Model(&user).Where("login = ?", login).Scan(context.TODO()); err != nil { + if err := svc.DB.NewSelect().Model(&user).Where("login = ?", login).Scan(ctx); err != nil { return "", "", fmt.Errorf("bad auth") } if bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) != nil { diff --git a/lib/service/user.go b/lib/service/user.go index 541246a..99c2c12 100644 --- a/lib/service/user.go +++ b/lib/service/user.go @@ -10,7 +10,7 @@ import ( "github.com/uptrace/bun" ) -func (svc *LndhubService) CreateUser() (user *models.User, err error) { +func (svc *LndhubService) CreateUser(ctx context.Context) (user *models.User, err error) { user = &models.User{} @@ -24,7 +24,7 @@ func (svc *LndhubService) CreateUser() (user *models.User, err error) { // Create user and the user's accounts // We use double-entry bookkeeping so we use 4 accounts: incoming, current, outgoing and fees // Wrapping this in a transaction in case something fails - err = svc.DB.RunInTx(context.TODO(), &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error { + err = svc.DB.RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error { if _, err := tx.NewInsert().Model(user).Exec(ctx); err != nil { return err } @@ -59,7 +59,7 @@ func (svc *LndhubService) CurrentUserBalance(ctx context.Context, userId int64) if err != nil { return balance, err } - err = svc.DB.NewSelect().Table("account_ledgers").ColumnExpr("sum(account_ledgers.amount) as balance").Where("account_ledgers.account_id = ?", account.ID).Scan(context.TODO(), &balance) + err = svc.DB.NewSelect().Table("account_ledgers").ColumnExpr("sum(account_ledgers.amount) as balance").Where("account_ledgers.account_id = ?", account.ID).Scan(ctx, &balance) return balance, err } From 2748b7f2f30d6a7cbe91af921e18bae07573efaf Mon Sep 17 00:00:00 2001 From: Stefan Kostic Date: Tue, 8 Feb 2022 13:36:50 +0100 Subject: [PATCH 2/2] Use request context in handlers --- controllers/addinvoice.ctrl.go | 2 +- controllers/auth.ctrl.go | 2 +- controllers/balance.ctrl.go | 3 +-- controllers/checkpayment.ctrl.go | 2 +- controllers/create.ctrl.go | 2 +- controllers/getinfo.ctrl.go | 3 +-- controllers/gettxs.ctrl.go | 5 ++--- controllers/home.ctrl.go | 5 ++--- controllers/payinvoice.ctrl.go | 7 +++---- 9 files changed, 13 insertions(+), 18 deletions(-) diff --git a/controllers/addinvoice.ctrl.go b/controllers/addinvoice.ctrl.go index ff31f81..02d75fd 100644 --- a/controllers/addinvoice.ctrl.go +++ b/controllers/addinvoice.ctrl.go @@ -45,7 +45,7 @@ func (controller *AddInvoiceController) AddInvoice(c echo.Context) error { } c.Logger().Infof("Adding invoice: user_id=%v memo=%s value=%v description_hash=%s", userID, body.Memo, amount, body.DescriptionHash) - invoice, err := controller.svc.AddIncomingInvoice(userID, amount, body.Memo, body.DescriptionHash) + invoice, err := controller.svc.AddIncomingInvoice(c.Request().Context(), userID, amount, body.Memo, body.DescriptionHash) if err != nil { c.Logger().Errorf("Error creating invoice: %v", err) sentry.CaptureException(err) diff --git a/controllers/auth.ctrl.go b/controllers/auth.ctrl.go index c888caf..e280f8b 100644 --- a/controllers/auth.ctrl.go +++ b/controllers/auth.ctrl.go @@ -41,7 +41,7 @@ func (controller *AuthController) Auth(c echo.Context) error { return c.JSON(http.StatusBadRequest, responses.BadArgumentsError) } - accessToken, refreshToken, err := controller.svc.GenerateToken(body.Login, body.Password, body.RefreshToken) + accessToken, refreshToken, err := controller.svc.GenerateToken(c.Request().Context(), body.Login, body.Password, body.RefreshToken) if err != nil { return err } diff --git a/controllers/balance.ctrl.go b/controllers/balance.ctrl.go index 99a1f02..9e524a3 100644 --- a/controllers/balance.ctrl.go +++ b/controllers/balance.ctrl.go @@ -1,7 +1,6 @@ package controllers import ( - "context" "net/http" "github.com/getAlby/lndhub.go/lib/service" @@ -20,7 +19,7 @@ func NewBalanceController(svc *service.LndhubService) *BalanceController { // Balance : Balance Controller func (controller *BalanceController) Balance(c echo.Context) error { userId := c.Get("UserID").(int64) - balance, err := controller.svc.CurrentUserBalance(context.TODO(), userId) + balance, err := controller.svc.CurrentUserBalance(c.Request().Context(), userId) if err != nil { return err } diff --git a/controllers/checkpayment.ctrl.go b/controllers/checkpayment.ctrl.go index c13e89f..d2068a1 100644 --- a/controllers/checkpayment.ctrl.go +++ b/controllers/checkpayment.ctrl.go @@ -22,7 +22,7 @@ func (controller *CheckPaymentController) CheckPayment(c echo.Context) error { userId := c.Get("UserID").(int64) rHash := c.Param("payment_hash") - invoice, err := controller.svc.FindInvoiceByPaymentHash(userId, rHash) + invoice, err := controller.svc.FindInvoiceByPaymentHash(c.Request().Context(), userId, rHash) // Probably we did not find the invoice if err != nil { diff --git a/controllers/create.ctrl.go b/controllers/create.ctrl.go index 618665a..268dd0b 100644 --- a/controllers/create.ctrl.go +++ b/controllers/create.ctrl.go @@ -33,7 +33,7 @@ func (controller *CreateUserController) CreateUser(c echo.Context) error { if err := c.Bind(&body); err != nil { return err } - user, err := controller.svc.CreateUser() + user, err := controller.svc.CreateUser(c.Request().Context()) //todo json response if err != nil { return err diff --git a/controllers/getinfo.ctrl.go b/controllers/getinfo.ctrl.go index d58f9f2..c4a6b33 100644 --- a/controllers/getinfo.ctrl.go +++ b/controllers/getinfo.ctrl.go @@ -1,7 +1,6 @@ package controllers import ( - "context" "net/http" "github.com/getAlby/lndhub.go/lib/service" @@ -21,7 +20,7 @@ func NewGetInfoController(svc *service.LndhubService) *GetInfoController { func (controller *GetInfoController) GetInfo(c echo.Context) error { // TODO: add some caching for this GetInfo call. No need to always hit the node - info, err := controller.svc.GetInfo(context.TODO()) + info, err := controller.svc.GetInfo(c.Request().Context()) if err != nil { return err } diff --git a/controllers/gettxs.ctrl.go b/controllers/gettxs.ctrl.go index ae8052f..b19ea2d 100644 --- a/controllers/gettxs.ctrl.go +++ b/controllers/gettxs.ctrl.go @@ -1,7 +1,6 @@ package controllers import ( - "context" "net/http" "github.com/getAlby/lndhub.go/lib" @@ -22,7 +21,7 @@ func NewGetTXSController(svc *service.LndhubService) *GetTXSController { func (controller *GetTXSController) GetTXS(c echo.Context) error { userId := c.Get("UserID").(int64) - invoices, err := controller.svc.InvoicesFor(context.TODO(), userId, "outgoing") + invoices, err := controller.svc.InvoicesFor(c.Request().Context(), userId, "outgoing") if err != nil { return err } @@ -47,7 +46,7 @@ func (controller *GetTXSController) GetTXS(c echo.Context) error { func (controller *GetTXSController) GetUserInvoices(c echo.Context) error { userId := c.Get("UserID").(int64) - invoices, err := controller.svc.InvoicesFor(context.TODO(), userId, "incoming") + invoices, err := controller.svc.InvoicesFor(c.Request().Context(), userId, "incoming") if err != nil { return err } diff --git a/controllers/home.ctrl.go b/controllers/home.ctrl.go index 2a854cf..96e546a 100644 --- a/controllers/home.ctrl.go +++ b/controllers/home.ctrl.go @@ -2,7 +2,6 @@ package controllers import ( "bytes" - "context" _ "embed" "fmt" "html/template" @@ -72,11 +71,11 @@ func (controller *HomeController) QR(c echo.Context) error { } func (controller *HomeController) Home(c echo.Context) error { - info, err := controller.svc.GetInfo(context.TODO()) + info, err := controller.svc.GetInfo(c.Request().Context()) if err != nil { return err } - channels, err := controller.svc.LndClient.ListChannels(context.TODO(), &lnrpc.ListChannelsRequest{}) + channels, err := controller.svc.LndClient.ListChannels(c.Request().Context(), &lnrpc.ListChannelsRequest{}) if err != nil { return err } diff --git a/controllers/payinvoice.ctrl.go b/controllers/payinvoice.ctrl.go index 1962be8..b8e5c07 100644 --- a/controllers/payinvoice.ctrl.go +++ b/controllers/payinvoice.ctrl.go @@ -1,7 +1,6 @@ package controllers import ( - "context" "fmt" "net/http" @@ -58,12 +57,12 @@ func (controller *PayInvoiceController) PayInvoice(c echo.Context) error { } */ - invoice, err := controller.svc.AddOutgoingInvoice(userID, paymentRequest, decodedPaymentRequest) + invoice, err := controller.svc.AddOutgoingInvoice(c.Request().Context(), userID, paymentRequest, decodedPaymentRequest) if err != nil { return err } - currentBalance, err := controller.svc.CurrentUserBalance(context.TODO(), userID) + currentBalance, err := controller.svc.CurrentUserBalance(c.Request().Context(), userID) if err != nil { return err } @@ -74,7 +73,7 @@ func (controller *PayInvoiceController) PayInvoice(c echo.Context) error { return c.JSON(http.StatusBadRequest, responses.NotEnoughBalanceError) } - sendPaymentResponse, err := controller.svc.PayInvoice(invoice) + sendPaymentResponse, err := controller.svc.PayInvoice(c.Request().Context(), invoice) if err != nil { c.Logger().Errorf("Payment failed: %v", err) sentry.CaptureException(err)