diff --git a/controllers/invoicestream.ctrl.go b/controllers/invoicestream.ctrl.go index 2ec80bc..e41ad63 100644 --- a/controllers/invoicestream.ctrl.go +++ b/controllers/invoicestream.ctrl.go @@ -33,7 +33,10 @@ func (controller *InvoiceStreamController) StreamInvoices(c echo.Context) error return err } invoiceChan := make(chan models.Invoice) - subId := controller.svc.InvoicePubSub.Subscribe(userId, invoiceChan) + subId, err := controller.svc.InvoicePubSub.Subscribe(userId, invoiceChan) + if err != nil { + return err + } upgrader := websocket.Upgrader{} upgrader.CheckOrigin = func(r *http.Request) bool { return true } ticker := time.NewTicker(30 * time.Second) diff --git a/lib/service/invoices.go b/lib/service/invoices.go index f0ccfd3..37f3890 100644 --- a/lib/service/invoices.go +++ b/lib/service/invoices.go @@ -7,7 +7,6 @@ import ( "errors" "fmt" "math" - "math/rand" "time" "github.com/getAlby/lndhub.go/common" @@ -141,7 +140,10 @@ func createLnRpcSendRequest(invoice *models.Invoice) (*lnrpc.SendRequest, error) }, nil } - preImage := makePreimageHex() + preImage, err := makePreimageHex() + if err != nil { + return nil, err + } pHash := sha256.New() pHash.Write(preImage) // Prepare the LNRPC call @@ -336,7 +338,10 @@ func (svc *LndhubService) AddOutgoingInvoice(ctx context.Context, userID int64, } func (svc *LndhubService) AddIncomingInvoice(ctx context.Context, userID int64, amount int64, memo, descriptionHashStr string) (*models.Invoice, error) { - preimage := makePreimageHex() + preimage, err := makePreimageHex() + if err != nil { + return nil, err + } expiry := time.Hour * 24 // invoice expires in 24h // Initialize new DB invoice invoice := models.Invoice{ @@ -350,7 +355,7 @@ func (svc *LndhubService) AddIncomingInvoice(ctx context.Context, userID int64, } // Save invoice - we save the invoice early to have a record in case the LN call fails - _, err := svc.DB.NewInsert().Model(&invoice).Exec(ctx) + _, err = svc.DB.NewInsert().Model(&invoice).Exec(ctx) if err != nil { return nil, err } @@ -395,10 +400,6 @@ func (svc *LndhubService) DecodePaymentRequest(ctx context.Context, bolt11 strin const hexBytes = random.Hex -func makePreimageHex() []byte { - b := make([]byte, 32) - for i := range b { - b[i] = hexBytes[rand.Intn(len(hexBytes))] - } - return b +func makePreimageHex() ([]byte, error) { + return randBytesFromStr(32, hexBytes) } diff --git a/lib/service/pubsub.go b/lib/service/pubsub.go index f71b853..09443bd 100644 --- a/lib/service/pubsub.go +++ b/lib/service/pubsub.go @@ -17,16 +17,20 @@ func NewPubsub() *Pubsub { return ps } -func (ps *Pubsub) Subscribe(topic int64, ch chan models.Invoice) (subId string) { +func (ps *Pubsub) Subscribe(topic int64, ch chan models.Invoice) (subId string, err error) { ps.mu.Lock() defer ps.mu.Unlock() if ps.subs[topic] == nil { ps.subs[topic] = make(map[string]chan models.Invoice) } //re-use preimage code for a uuid - subId = string(makePreimageHex()) + preImageHex, err := makePreimageHex() + if err != nil { + return "", err + } + subId = string(preImageHex) ps.subs[topic][subId] = ch - return subId + return subId, nil } func (ps *Pubsub) Unsubscribe(id string, topic int64) error { diff --git a/lib/service/user.go b/lib/service/user.go index e8bc4ac..82780ba 100644 --- a/lib/service/user.go +++ b/lib/service/user.go @@ -3,7 +3,6 @@ package service import ( "context" "database/sql" - "math/rand" "github.com/getAlby/lndhub.go/common" "github.com/getAlby/lndhub.go/db/models" @@ -18,11 +17,19 @@ func (svc *LndhubService) CreateUser(ctx context.Context, login string, password // generate user login/password if not provided user.Login = login if login == "" { - user.Login = randStringBytes(20) + randLoginBytes, err := randBytesFromStr(20, alphaNumBytes) + if err != nil { + return nil, err + } + user.Login = string(randLoginBytes) } if password == "" { - password = randStringBytes(20) + randPasswordBytes, err := randBytesFromStr(20, alphaNumBytes) + if err != nil { + return nil, err + } + password = string(randPasswordBytes) } // we only store the hashed password but return the initial plain text password in the HTTP response @@ -112,11 +119,3 @@ func (svc *LndhubService) InvoicesFor(ctx context.Context, userId int64, invoice } return invoices, nil } - -func randStringBytes(n int) string { - b := make([]byte, n) - for i := range b { - b[i] = alphaNumBytes[rand.Intn(len(alphaNumBytes))] - } - return string(b) -} diff --git a/lib/service/util.go b/lib/service/util.go new file mode 100644 index 0000000..b30c7ad --- /dev/null +++ b/lib/service/util.go @@ -0,0 +1,19 @@ +package service + +import ( + "crypto/rand" + "math/big" +) + +func randBytesFromStr(length int, from string) ([]byte, error) { + b := make([]byte, length) + fromLenBigInt := big.NewInt(int64(len(from))) + for i := range b { + r, err := rand.Int(rand.Reader, fromLenBigInt) + if err != nil { + return nil, err + } + b[i] = from[r.Int64()] + } + return b, nil +}