diff --git a/controllers/addinvoice.ctrl.go b/controllers/addinvoice.ctrl.go index c1867d0..0347767 100644 --- a/controllers/addinvoice.ctrl.go +++ b/controllers/addinvoice.ctrl.go @@ -47,7 +47,7 @@ func (controller *AddInvoiceController) AddInvoice(c echo.Context) error { c.Logger().Infof("Adding invoice: user_id=%v memo=%s value=%v description_hash=%s", userID, body.Memo, body.Amt, body.DescriptionHash) - invoice, err := controller.svc.AddInvoice(userID, body.Amt, body.Memo, body.DescriptionHash) + invoice, err := controller.svc.AddIncomingInvoice(userID, body.Amt, body.Memo, body.DescriptionHash) if err != nil { c.Logger().Errorf("Error creating invoice: %v", err) // TODO: sentry notification diff --git a/controllers/payinvoice.ctrl.go b/controllers/payinvoice.ctrl.go index adf9eb9..3ce9628 100644 --- a/controllers/payinvoice.ctrl.go +++ b/controllers/payinvoice.ctrl.go @@ -1,6 +1,8 @@ package controllers import ( + "context" + "fmt" "net/http" "github.com/getAlby/lndhub.go/lib/service" @@ -18,23 +20,71 @@ func NewPayInvoiceController(svc *service.LndhubService) *PayInvoiceController { // PayInvoice : Pay invoice Controller func (controller *PayInvoiceController) PayInvoice(c echo.Context) error { - userId := c.Get("UserID").(int64) + userID := c.Get("UserID").(int64) var reqBody struct { Invoice string `json:"invoice" validate:"required"` Amount int `json:"amount" validate:"omitempty,gte=0"` } if err := c.Bind(&reqBody); err != nil { + c.Logger().Errorf("Failed to load payinvoice request body: %v", err) return c.JSON(http.StatusBadRequest, echo.Map{ - "message": "failed to bind json", + "error": true, + "code": 8, + "message": "Bad arguments", }) } if err := c.Validate(&reqBody); err != nil { + c.Logger().Errorf("Invalid payinvoice request body: %v", err) return c.JSON(http.StatusBadRequest, echo.Map{ - "message": "invalid request", + "error": true, + "code": 8, + "message": "Bad arguments", }) } - //TODO json response - return controller.svc.Payinvoice(userId, reqBody.Invoice) + + paymentRequest := reqBody.Invoice + decodedPaymentRequest, err := controller.svc.DecodePaymentRequest(paymentRequest) + if err != nil { + return c.JSON(http.StatusBadRequest, echo.Map{ + "error": true, + "code": 8, + "message": "Bad arguments", + }) + } + c.Logger().Info("%v", decodedPaymentRequest) + + invoice, err := controller.svc.AddOutgoingInvoice(userID, paymentRequest, *decodedPaymentRequest) + if err != nil { + c.Logger().Errorf("Error creating invoice: %v", err) + // TODO: sentry notification + return c.JSON(http.StatusInternalServerError, nil) + } + + currentBalance, err := controller.svc.CurrentUserBalance(context.TODO(), userID) + if err != nil { + return err + } + + if currentBalance < invoice.Amount { + c.Logger().Errorf("User does not have enough balance invoice_id=%v user_id=%v balance=%v amount=%v", invoice.ID, userID, currentBalance, invoice.Amount) + + return c.JSON(http.StatusBadRequest, echo.Map{ + "error": true, + "code": 2, + "message": fmt.Sprintf("not enough balance (%v). Make sure you have at least 1%% reserved for potential fees", currentBalance), + }) + } + + entry, err := controller.svc.PayInvoice(invoice) + if err != nil { + c.Logger().Errorf("Failed: %v", err) + return c.JSON(http.StatusBadRequest, echo.Map{ + "error": true, + "code": 10, + "message": fmt.Sprintf("Payment failed. Does the receiver have enough inbound capacity? (%v)", err), + }) + } + return c.JSON(http.StatusOK, &entry) } diff --git a/db/models/invoice.go b/db/models/invoice.go index 90f7fbd..afc61a4 100644 --- a/db/models/invoice.go +++ b/db/models/invoice.go @@ -9,17 +9,18 @@ import ( // Invoice : Invoice Model type Invoice struct { - ID uint `json:"id" bun:",pk,autoincrement"` - Type string `json:"type"` - UserID int64 `json:"user_id"` + ID int64 `json:"id" bun:",pk,autoincrement"` + Type string `json:"type" validate:"required"` + UserID int64 `json:"user_id" validate:"required"` User *User `bun:"rel:belongs-to,join:user_id=id"` - Amount int64 `json:"amount"` + Amount int64 `json:"amount" validate:"gte=0"` Memo string `json:"memo"` - DescriptionHash string `json:"description_hash"` + DescriptionHash string `json:"description_hash" bun:",nullzero"` PaymentRequest string `json:"payment_request"` RHash string `json:"r_hash"` - State string `json:"state"` - AddIndex uint64 `json:"add_index"` + Preimage string `json:"preimage" bun:",nullzero"` + State string `json:"state" bun:",default:'initialized'"` + AddIndex uint64 `json:"add_index" bun:",nullzero"` CreatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"` UpdatedAt bun.NullTime `json:"updated_at"` SettledAt bun.NullTime `json:"settled_at"` diff --git a/lib/service/invoices.go b/lib/service/invoices.go index a9a456d..aa6e6d6 100644 --- a/lib/service/invoices.go +++ b/lib/service/invoices.go @@ -2,12 +2,19 @@ package service import ( "context" + "database/sql" "encoding/hex" + "errors" "math/rand" + "strings" + "time" + "github.com/btcsuite/btcd/chaincfg" "github.com/getAlby/lndhub.go/db/models" "github.com/labstack/gommon/random" "github.com/lightningnetwork/lnd/lnrpc" + "github.com/lightningnetwork/lnd/zpay32" + "github.com/uptrace/bun/schema" ) func (svc *LndhubService) FindInvoiceByPaymentHash(userId int64, rHash string) (*models.Invoice, error) { @@ -20,28 +27,120 @@ func (svc *LndhubService) FindInvoiceByPaymentHash(userId int64, rHash string) ( return &invoice, nil } -func (svc *LndhubService) Payinvoice(userId int64, invoice string) error { +func (svc *LndhubService) PayInvoice(invoice *models.Invoice) (*models.TransactionEntry, error) { + userId := invoice.UserID + + // Get the user's current and outgoing account for the transaction entry debitAccount, err := svc.AccountFor(context.TODO(), "current", userId) if err != nil { - return err + return nil, err } creditAccount, err := svc.AccountFor(context.TODO(), "outgoing", userId) if err != nil { - return err + return nil, err } entry := models.TransactionEntry{ UserID: userId, + InvoiceID: invoice.ID, CreditAccountID: creditAccount.ID, DebitAccountID: debitAccount.ID, - Amount: 1000, + Amount: invoice.Amount, } - _, err = svc.DB.NewInsert().Model(&entry).Exec(context.TODO()) - return err + // Start a DB transaction + // We rollback anything on error (only the invoice that was passed in to the PayInvoice calls stays in the DB) + tx, err := svc.DB.BeginTx(context.TODO(), &sql.TxOptions{}) + if err != nil { + return &entry, err + } + _, err = tx.NewInsert().Model(&entry).Exec(context.TODO()) + if err != nil { + tx.Rollback() + return &entry, err + } + + // TODO: set fee limit + feeLimit := lnrpc.FeeLimit{ + Limit: &lnrpc.FeeLimit_Percent{ + Percent: 2, + }, + } + + // Prepare the LNRPC call + sendPaymentRequest := lnrpc.SendRequest{ + PaymentRequest: invoice.PaymentRequest, + Amt: invoice.Amount, + FeeLimit: &feeLimit, + } + lndClient := *svc.LndClient + + // Execute the payment + sendPaymentResult, err := lndClient.SendPaymentSync(context.TODO(), &sendPaymentRequest) + if err != nil { + tx.Rollback() + return &entry, err + } + + // If there was a payment error we rollback and return an error + if sendPaymentResult.GetPaymentError() != "" || sendPaymentResult.GetPaymentPreimage() == nil { + tx.Rollback() + return &entry, errors.New(sendPaymentResult.GetPaymentError()) + } + + // The payment was successful. + // We store the preimage and mark the invoice as settled + preimage := sendPaymentResult.GetPaymentPreimage() + invoice.Preimage = hex.EncodeToString(preimage[:]) + invoice.State = "settled" + invoice.SettledAt = schema.NullTime{Time: time.Now()} + + _, err = svc.DB.NewUpdate().Model(invoice).WherePK().Exec(context.TODO()) + if err != nil { + tx.Rollback() + return &entry, err + } + + // Commit the DB transaction. Done, everything worked + err = tx.Commit() + if err != nil { + return &entry, err + } + + return &entry, err } -func (svc *LndhubService) AddInvoice(userID int64, amount int64, memo, descriptionHash string) (*models.Invoice, error) { +func (svc *LndhubService) AddOutgoingInvoice(userID int64, paymentRequest string, decodedInvoice zpay32.Invoice) (*models.Invoice, error) { + // Initialize new DB invoice + invoice := models.Invoice{ + Type: "outgoing", + UserID: userID, + Memo: *decodedInvoice.Description, + PaymentRequest: paymentRequest, + State: "initialized", + } + if decodedInvoice.DescriptionHash != nil { + dh := *decodedInvoice.DescriptionHash + invoice.DescriptionHash = hex.EncodeToString(dh[:]) + } + if decodedInvoice.PaymentHash != nil { + ph := *decodedInvoice.PaymentHash + invoice.RHash = hex.EncodeToString(ph[:]) + } + if decodedInvoice.MilliSat != nil { + msat := decodedInvoice.MilliSat + invoice.Amount = int64(msat.ToSatoshis()) + } + + // Save invoice + _, err := svc.DB.NewInsert().Model(&invoice).Exec(context.TODO()) + if err != nil { + return nil, err + } + return &invoice, nil +} + +func (svc *LndhubService) AddIncomingInvoice(userID int64, amount int64, memo, descriptionHash string) (*models.Invoice, error) { // Initialize new DB invoice invoice := models.Invoice{ Type: "incoming", @@ -86,6 +185,10 @@ func (svc *LndhubService) AddInvoice(userID int64, amount int64, memo, descripti return &invoice, nil } +func (svc *LndhubService) DecodePaymentRequest(bolt11 string) (*zpay32.Invoice, error) { + return zpay32.Decode(bolt11, ChainFromCurrency(bolt11[2:])) +} + const hexBytes = random.Hex func makePreimageHex() []byte { @@ -95,3 +198,15 @@ func makePreimageHex() []byte { } return b } + +func ChainFromCurrency(currency string) *chaincfg.Params { + if strings.HasPrefix(currency, "bcrt") { + return &chaincfg.RegressionNetParams + } else if strings.HasPrefix(currency, "tb") { + return &chaincfg.TestNet3Params + } else if strings.HasPrefix(currency, "sb") { + return &chaincfg.SimNetParams + } else { + return &chaincfg.MainNetParams + } +}