diff --git a/controllers/balance.ctrl.go b/controllers/balance.ctrl.go index e4a1676..1fd66ce 100644 --- a/controllers/balance.ctrl.go +++ b/controllers/balance.ctrl.go @@ -6,7 +6,6 @@ import ( "github.com/getAlby/lndhub.go/lib" "github.com/labstack/echo/v4" - "github.com/lightningnetwork/lnd/lnrpc" ) // BalanceController : BalanceController struct @@ -16,16 +15,17 @@ type BalanceController struct{} func (BalanceController) Balance(c echo.Context) error { ctx := c.(*lib.LndhubContext) c.Logger().Warn(ctx.User.ID) - lndClient := *ctx.LndClient - getInfo, err := lndClient.GetInfo(context.TODO(), &lnrpc.GetInfoRequest{}) + + db := ctx.DB + + balance, err := ctx.User.CurrentBalance(context.TODO(), db) if err != nil { - panic(err) + return err } - c.Logger().Infof("Connected to LND: %s - %s", getInfo.Alias, getInfo.IdentityPubkey) return c.JSON(http.StatusOK, echo.Map{ "BTC": echo.Map{ - "AvailableBalance": 1, + "AvailableBalance": balance, }, }) } diff --git a/controllers/create.ctrl.go b/controllers/create.ctrl.go index 88f9b59..cd473df 100644 --- a/controllers/create.ctrl.go +++ b/controllers/create.ctrl.go @@ -2,6 +2,7 @@ package controllers import ( "context" + "database/sql" "math/rand" "net/http" @@ -10,6 +11,7 @@ import ( "github.com/getAlby/lndhub.go/lib/security" "github.com/labstack/echo/v4" "github.com/labstack/gommon/random" + "github.com/uptrace/bun" ) const alphaNumBytes = random.Alphanumeric @@ -20,6 +22,8 @@ type CreateUserController struct{} // CreateUser : Create user Controller func (CreateUserController) CreateUser(c echo.Context) error { ctx := c.(*lib.LndhubContext) + + // optional parameters that we currently do not use type RequestBody struct { PartnerID string `json:"partnerid"` AccountType string `json:"accounttype"` @@ -34,14 +38,35 @@ func (CreateUserController) CreateUser(c echo.Context) error { user := models.User{} + // generate user login/password (TODO: allow the user to choose a login/password?) user.Login = randStringBytes(8) password := randStringBytes(15) + // we only store the hashed password but return the initial plain text password in the HTTP response hashedPassword := security.HashPassword(password) user.Password = hashedPassword - if _, err := db.NewInsert().Model(&user).Exec(context.TODO()); err != nil { + // Create user and the user's accounts + // We use double-entry bookkeeping so we use 4 accounts: incoming, current, outgoing and fees + // Wrapping this in a transaction in case something fails + err := db.RunInTx(context.TODO(), &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error { + if _, err := tx.NewInsert().Model(&user).Exec(ctx); err != nil { + return err + } + accountTypes := []string{"incoming", "current", "outgoing", "fees"} + for _, accountType := range accountTypes { + account := models.Account{UserID: user.ID, Type: accountType} + if _, err := db.NewInsert().Model(&account).Exec(ctx); err != nil { + return err + } + } + return nil + }) + + // Was the DB transaction successful? + if err != nil { return err } + var ResponseBody struct { Login string `json:"login"` Password string `json:"password"` diff --git a/controllers/payinvoice.ctrl.go b/controllers/payinvoice.ctrl.go index 4bbdeba..1ad77b5 100644 --- a/controllers/payinvoice.ctrl.go +++ b/controllers/payinvoice.ctrl.go @@ -1,8 +1,11 @@ package controllers import ( + "context" "net/http" + "github.com/getAlby/lndhub.go/db/models" + "github.com/getAlby/lndhub.go/lib" "github.com/labstack/echo/v4" ) @@ -11,9 +14,10 @@ type PayInvoiceController struct{} // PayInvoice : Pay invoice Controller func (PayInvoiceController) PayInvoice(c echo.Context) error { + ctx := c.(*lib.LndhubContext) var reqBody struct { Invoice string `json:"invoice" validate:"required"` - Amount int `json:"amount" validate:"gt=0"` + Amount int `json:"amount" validate:"omitempty,gte=0"` } if err := c.Bind(&reqBody); err != nil { @@ -28,5 +32,25 @@ func (PayInvoiceController) PayInvoice(c echo.Context) error { }) } + db := ctx.DB + debitAccount, err := ctx.User.AccountFor("current", context.TODO(), db) + if err != nil { + return err + } + creditAccount, err := ctx.User.AccountFor("outgoing", context.TODO(), db) + if err != nil { + return err + } + + entry := models.TransactionEntry{ + UserID: ctx.User.ID, + CreditAccountID: creditAccount.ID, + DebitAccountID: debitAccount.ID, + Amount: 1000, + } + if _, err := db.NewInsert().Model(&entry).Exec(context.TODO()); err != nil { + return err + } + return nil } diff --git a/db/migrations/20220119021600_balances.up.sql b/db/migrations/20220119021600_balances.up.sql new file mode 100644 index 0000000..307c566 --- /dev/null +++ b/db/migrations/20220119021600_balances.up.sql @@ -0,0 +1,18 @@ +CREATE VIEW account_ledgers( + account_id, + transaction_entry_id, + amount +) AS + SELECT + transaction_entries.credit_account_id, + transaction_entries.id, + transaction_entries.amount + FROM + transaction_entries + UNION ALL + SELECT + transaction_entries.debit_account_id, + transaction_entries.id, + (0 - transaction_entries.amount) + FROM + transaction_entries; diff --git a/db/migrations/main.go b/db/migrations/main.go index ca4148a..7dfd5e0 100644 --- a/db/migrations/main.go +++ b/db/migrations/main.go @@ -8,9 +8,7 @@ import ( var Migrations = migrate.NewMigrations() -// remove the first of 3 slashes to enable sql migrations -// probably not needed as we are targeting several dialects -///go:embed *.sql +//go:embed *.sql var sqlMigrations embed.FS func init() { diff --git a/db/models/account.go b/db/models/account.go index 3a9a1f5..9214eb3 100644 --- a/db/models/account.go +++ b/db/models/account.go @@ -2,6 +2,8 @@ package models // Account : Account Model type Account struct { - UserID int64 - Type string + ID int64 `bun:",pk,autoincrement"` + UserID int64 `bun:",notnull"` + User *User `bun:"rel:belongs-to,join:user_id=id"` + Type string `bun:",notnull"` } diff --git a/db/models/invoice.go b/db/models/invoice.go index 3d28c73..391fec8 100644 --- a/db/models/invoice.go +++ b/db/models/invoice.go @@ -12,6 +12,7 @@ type Invoice struct { ID uint `json:"id" bun:",pk,autoincrement"` Type string `json:"type"` UserID int64 `json:"user_id"` + User *User `bun:"rel:belongs-to,join:user_id=id"` TransactionEntryID uint `json:"transaction_entry_id"` Amount uint `json:"amount"` Memo string `json:"memo"` diff --git a/db/models/transactionentries.go b/db/models/transactionentries.go deleted file mode 100644 index 419af3e..0000000 --- a/db/models/transactionentries.go +++ /dev/null @@ -1,15 +0,0 @@ -package models - -import ( - "time" -) - -// TransactionEntry : Transaction Entries Model -type TransactionEntry struct { - UserID uint - InvoiceID uint - CreditAccountID uint - DebitAccountID uint - Amount uint64 - CreatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"` -} diff --git a/db/models/transactionentry.go b/db/models/transactionentry.go new file mode 100644 index 0000000..f429e83 --- /dev/null +++ b/db/models/transactionentry.go @@ -0,0 +1,22 @@ +package models + +import ( + "time" +) + +// TransactionEntry : Transaction Entries Model +type TransactionEntry struct { + ID int64 `bun:",pk,autoincrement"` + UserID int64 `bun:",notnull"` + User *User `bun:"rel:belongs-to,join:user_id=id"` + InvoiceID int64 `bun:",notnull"` + Invoice *Invoice `bun:"rel:belongs-to,join:invoice_id=id"` + ParentID int64 + Parent *TransactionEntry `bun:"rel:belongs-to"` + CreditAccountID int64 `bun:",notnull"` + CreditAccount *Account `bun:"rel:belongs-to,join:credit_account_id=id"` + DebitAccountID int64 `bun:",notnull"` + DebitAccount *Account `bun:"rel:belongs-to,join:debit_account_id=id"` + Amount int64 `bun:",notnull"` + CreatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"` +} diff --git a/db/models/user.go b/db/models/user.go index 9f06791..cbcec60 100644 --- a/db/models/user.go +++ b/db/models/user.go @@ -16,6 +16,8 @@ type User struct { Password string `bun:",notnull"` CreatedAt time.Time `bun:",nullzero,notnull,default:current_timestamp"` UpdatedAt bun.NullTime + Invoices []*Invoice `bun:"rel:has-many,join:id=user_id"` + Accounts []*Account `bun:"rel:has-many,join:id=user_id"` } func (u *User) BeforeAppendModel(ctx context.Context, query bun.Query) error { @@ -26,4 +28,21 @@ func (u *User) BeforeAppendModel(ctx context.Context, query bun.Query) error { return nil } +func (u *User) AccountFor(accountType string, ctx context.Context, db bun.IDB) (Account, error) { + account := Account{} + err := db.NewSelect().Model(&account).Where("user_id = ? AND type= ?", u.ID, accountType).Limit(1).Scan(ctx) + return account, err +} + +func (u *User) CurrentBalance(ctx context.Context, db bun.IDB) (int64, error) { + var balance int64 + + account, err := u.AccountFor("current", ctx, db) + if err != nil { + return balance, err + } + err = db.NewSelect().Table("account_ledgers").ColumnExpr("sum(account_ledgers.amount) as balance").Where("account_ledgers.account_id = ?", account.ID).Scan(context.TODO(), &balance) + return balance, err +} + var _ bun.BeforeAppendModelHook = (*User)(nil)