Add payinvoice endpoint

This commit is contained in:
Michael Bumann
2022-01-20 01:57:31 +01:00
parent 1db6f77dd9
commit d4183c100b
4 changed files with 186 additions and 20 deletions

View File

@@ -47,7 +47,7 @@ func (controller *AddInvoiceController) AddInvoice(c echo.Context) error {
c.Logger().Infof("Adding invoice: user_id=%v memo=%s value=%v description_hash=%s", userID, body.Memo, body.Amt, body.DescriptionHash) c.Logger().Infof("Adding invoice: user_id=%v memo=%s value=%v description_hash=%s", userID, body.Memo, body.Amt, body.DescriptionHash)
invoice, err := controller.svc.AddInvoice(userID, body.Amt, body.Memo, body.DescriptionHash) invoice, err := controller.svc.AddIncomingInvoice(userID, body.Amt, body.Memo, body.DescriptionHash)
if err != nil { if err != nil {
c.Logger().Errorf("Error creating invoice: %v", err) c.Logger().Errorf("Error creating invoice: %v", err)
// TODO: sentry notification // TODO: sentry notification

View File

@@ -1,6 +1,8 @@
package controllers package controllers
import ( import (
"context"
"fmt"
"net/http" "net/http"
"github.com/getAlby/lndhub.go/lib/service" "github.com/getAlby/lndhub.go/lib/service"
@@ -18,23 +20,71 @@ func NewPayInvoiceController(svc *service.LndhubService) *PayInvoiceController {
// PayInvoice : Pay invoice Controller // PayInvoice : Pay invoice Controller
func (controller *PayInvoiceController) PayInvoice(c echo.Context) error { func (controller *PayInvoiceController) PayInvoice(c echo.Context) error {
userId := c.Get("UserID").(int64) userID := c.Get("UserID").(int64)
var reqBody struct { var reqBody struct {
Invoice string `json:"invoice" validate:"required"` Invoice string `json:"invoice" validate:"required"`
Amount int `json:"amount" validate:"omitempty,gte=0"` Amount int `json:"amount" validate:"omitempty,gte=0"`
} }
if err := c.Bind(&reqBody); err != nil { if err := c.Bind(&reqBody); err != nil {
c.Logger().Errorf("Failed to load payinvoice request body: %v", err)
return c.JSON(http.StatusBadRequest, echo.Map{ return c.JSON(http.StatusBadRequest, echo.Map{
"message": "failed to bind json", "error": true,
"code": 8,
"message": "Bad arguments",
}) })
} }
if err := c.Validate(&reqBody); err != nil { if err := c.Validate(&reqBody); err != nil {
c.Logger().Errorf("Invalid payinvoice request body: %v", err)
return c.JSON(http.StatusBadRequest, echo.Map{ return c.JSON(http.StatusBadRequest, echo.Map{
"message": "invalid request", "error": true,
"code": 8,
"message": "Bad arguments",
}) })
} }
//TODO json response
return controller.svc.Payinvoice(userId, reqBody.Invoice) paymentRequest := reqBody.Invoice
decodedPaymentRequest, err := controller.svc.DecodePaymentRequest(paymentRequest)
if err != nil {
return c.JSON(http.StatusBadRequest, echo.Map{
"error": true,
"code": 8,
"message": "Bad arguments",
})
}
c.Logger().Info("%v", decodedPaymentRequest)
invoice, err := controller.svc.AddOutgoingInvoice(userID, paymentRequest, *decodedPaymentRequest)
if err != nil {
c.Logger().Errorf("Error creating invoice: %v", err)
// TODO: sentry notification
return c.JSON(http.StatusInternalServerError, nil)
}
currentBalance, err := controller.svc.CurrentUserBalance(context.TODO(), userID)
if err != nil {
return err
}
if currentBalance < invoice.Amount {
c.Logger().Errorf("User does not have enough balance invoice_id=%v user_id=%v balance=%v amount=%v", invoice.ID, userID, currentBalance, invoice.Amount)
return c.JSON(http.StatusBadRequest, echo.Map{
"error": true,
"code": 2,
"message": fmt.Sprintf("not enough balance (%v). Make sure you have at least 1%% reserved for potential fees", currentBalance),
})
}
entry, err := controller.svc.PayInvoice(invoice)
if err != nil {
c.Logger().Errorf("Failed: %v", err)
return c.JSON(http.StatusBadRequest, echo.Map{
"error": true,
"code": 10,
"message": fmt.Sprintf("Payment failed. Does the receiver have enough inbound capacity? (%v)", err),
})
}
return c.JSON(http.StatusOK, &entry)
} }

View File

@@ -9,17 +9,18 @@ import (
// Invoice : Invoice Model // Invoice : Invoice Model
type Invoice struct { type Invoice struct {
ID uint `json:"id" bun:",pk,autoincrement"` ID int64 `json:"id" bun:",pk,autoincrement"`
Type string `json:"type"` Type string `json:"type" validate:"required"`
UserID int64 `json:"user_id"` UserID int64 `json:"user_id" validate:"required"`
User *User `bun:"rel:belongs-to,join:user_id=id"` User *User `bun:"rel:belongs-to,join:user_id=id"`
Amount int64 `json:"amount"` Amount int64 `json:"amount" validate:"gte=0"`
Memo string `json:"memo"` Memo string `json:"memo"`
DescriptionHash string `json:"description_hash"` DescriptionHash string `json:"description_hash" bun:",nullzero"`
PaymentRequest string `json:"payment_request"` PaymentRequest string `json:"payment_request"`
RHash string `json:"r_hash"` RHash string `json:"r_hash"`
State string `json:"state"` Preimage string `json:"preimage" bun:",nullzero"`
AddIndex uint64 `json:"add_index"` State string `json:"state" bun:",default:'initialized'"`
AddIndex uint64 `json:"add_index" bun:",nullzero"`
CreatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"` CreatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"`
UpdatedAt bun.NullTime `json:"updated_at"` UpdatedAt bun.NullTime `json:"updated_at"`
SettledAt bun.NullTime `json:"settled_at"` SettledAt bun.NullTime `json:"settled_at"`

View File

@@ -2,12 +2,19 @@ package service
import ( import (
"context" "context"
"database/sql"
"encoding/hex" "encoding/hex"
"errors"
"math/rand" "math/rand"
"strings"
"time"
"github.com/btcsuite/btcd/chaincfg"
"github.com/getAlby/lndhub.go/db/models" "github.com/getAlby/lndhub.go/db/models"
"github.com/labstack/gommon/random" "github.com/labstack/gommon/random"
"github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lnrpc"
"github.com/lightningnetwork/lnd/zpay32"
"github.com/uptrace/bun/schema"
) )
func (svc *LndhubService) FindInvoiceByPaymentHash(userId int64, rHash string) (*models.Invoice, error) { func (svc *LndhubService) FindInvoiceByPaymentHash(userId int64, rHash string) (*models.Invoice, error) {
@@ -20,28 +27,120 @@ func (svc *LndhubService) FindInvoiceByPaymentHash(userId int64, rHash string) (
return &invoice, nil return &invoice, nil
} }
func (svc *LndhubService) Payinvoice(userId int64, invoice string) error { func (svc *LndhubService) PayInvoice(invoice *models.Invoice) (*models.TransactionEntry, error) {
userId := invoice.UserID
// Get the user's current and outgoing account for the transaction entry
debitAccount, err := svc.AccountFor(context.TODO(), "current", userId) debitAccount, err := svc.AccountFor(context.TODO(), "current", userId)
if err != nil { if err != nil {
return err return nil, err
} }
creditAccount, err := svc.AccountFor(context.TODO(), "outgoing", userId) creditAccount, err := svc.AccountFor(context.TODO(), "outgoing", userId)
if err != nil { if err != nil {
return err return nil, err
} }
entry := models.TransactionEntry{ entry := models.TransactionEntry{
UserID: userId, UserID: userId,
InvoiceID: invoice.ID,
CreditAccountID: creditAccount.ID, CreditAccountID: creditAccount.ID,
DebitAccountID: debitAccount.ID, DebitAccountID: debitAccount.ID,
Amount: 1000, Amount: invoice.Amount,
}
_, err = svc.DB.NewInsert().Model(&entry).Exec(context.TODO())
return err
} }
func (svc *LndhubService) AddInvoice(userID int64, amount int64, memo, descriptionHash string) (*models.Invoice, error) { // Start a DB transaction
// We rollback anything on error (only the invoice that was passed in to the PayInvoice calls stays in the DB)
tx, err := svc.DB.BeginTx(context.TODO(), &sql.TxOptions{})
if err != nil {
return &entry, err
}
_, err = tx.NewInsert().Model(&entry).Exec(context.TODO())
if err != nil {
tx.Rollback()
return &entry, err
}
// TODO: set fee limit
feeLimit := lnrpc.FeeLimit{
Limit: &lnrpc.FeeLimit_Percent{
Percent: 2,
},
}
// Prepare the LNRPC call
sendPaymentRequest := lnrpc.SendRequest{
PaymentRequest: invoice.PaymentRequest,
Amt: invoice.Amount,
FeeLimit: &feeLimit,
}
lndClient := *svc.LndClient
// Execute the payment
sendPaymentResult, err := lndClient.SendPaymentSync(context.TODO(), &sendPaymentRequest)
if err != nil {
tx.Rollback()
return &entry, err
}
// If there was a payment error we rollback and return an error
if sendPaymentResult.GetPaymentError() != "" || sendPaymentResult.GetPaymentPreimage() == nil {
tx.Rollback()
return &entry, errors.New(sendPaymentResult.GetPaymentError())
}
// The payment was successful.
// We store the preimage and mark the invoice as settled
preimage := sendPaymentResult.GetPaymentPreimage()
invoice.Preimage = hex.EncodeToString(preimage[:])
invoice.State = "settled"
invoice.SettledAt = schema.NullTime{Time: time.Now()}
_, err = svc.DB.NewUpdate().Model(invoice).WherePK().Exec(context.TODO())
if err != nil {
tx.Rollback()
return &entry, err
}
// Commit the DB transaction. Done, everything worked
err = tx.Commit()
if err != nil {
return &entry, err
}
return &entry, err
}
func (svc *LndhubService) AddOutgoingInvoice(userID int64, paymentRequest string, decodedInvoice zpay32.Invoice) (*models.Invoice, error) {
// Initialize new DB invoice
invoice := models.Invoice{
Type: "outgoing",
UserID: userID,
Memo: *decodedInvoice.Description,
PaymentRequest: paymentRequest,
State: "initialized",
}
if decodedInvoice.DescriptionHash != nil {
dh := *decodedInvoice.DescriptionHash
invoice.DescriptionHash = hex.EncodeToString(dh[:])
}
if decodedInvoice.PaymentHash != nil {
ph := *decodedInvoice.PaymentHash
invoice.RHash = hex.EncodeToString(ph[:])
}
if decodedInvoice.MilliSat != nil {
msat := decodedInvoice.MilliSat
invoice.Amount = int64(msat.ToSatoshis())
}
// Save invoice
_, err := svc.DB.NewInsert().Model(&invoice).Exec(context.TODO())
if err != nil {
return nil, err
}
return &invoice, nil
}
func (svc *LndhubService) AddIncomingInvoice(userID int64, amount int64, memo, descriptionHash string) (*models.Invoice, error) {
// Initialize new DB invoice // Initialize new DB invoice
invoice := models.Invoice{ invoice := models.Invoice{
Type: "incoming", Type: "incoming",
@@ -86,6 +185,10 @@ func (svc *LndhubService) AddInvoice(userID int64, amount int64, memo, descripti
return &invoice, nil return &invoice, nil
} }
func (svc *LndhubService) DecodePaymentRequest(bolt11 string) (*zpay32.Invoice, error) {
return zpay32.Decode(bolt11, ChainFromCurrency(bolt11[2:]))
}
const hexBytes = random.Hex const hexBytes = random.Hex
func makePreimageHex() []byte { func makePreimageHex() []byte {
@@ -95,3 +198,15 @@ func makePreimageHex() []byte {
} }
return b return b
} }
func ChainFromCurrency(currency string) *chaincfg.Params {
if strings.HasPrefix(currency, "bcrt") {
return &chaincfg.RegressionNetParams
} else if strings.HasPrefix(currency, "tb") {
return &chaincfg.TestNet3Params
} else if strings.HasPrefix(currency, "sb") {
return &chaincfg.SimNetParams
} else {
return &chaincfg.MainNetParams
}
}