diff --git a/controllers/payinvoice.ctrl.go b/controllers/payinvoice.ctrl.go index 3ce9628..ce7ee33 100644 --- a/controllers/payinvoice.ctrl.go +++ b/controllers/payinvoice.ctrl.go @@ -79,7 +79,8 @@ func (controller *PayInvoiceController) PayInvoice(c echo.Context) error { entry, err := controller.svc.PayInvoice(invoice) if err != nil { - c.Logger().Errorf("Failed: %v", err) + c.Logger().Errorf("Payment failed: %v", err) + // TODO: sentry notification return c.JSON(http.StatusBadRequest, echo.Map{ "error": true, "code": 10, diff --git a/db/migrations/20220120000700_add_constraints.up.go b/db/migrations/20220120000700_add_constraints.up.go new file mode 100644 index 0000000..3dfcf7b --- /dev/null +++ b/db/migrations/20220120000700_add_constraints.up.go @@ -0,0 +1,59 @@ +package migrations + +import ( + "context" + "fmt" + + "github.com/uptrace/bun" +) + +func init() { + Migrations.MustRegister(func(ctx context.Context, db *bun.DB) error { + + if db.Dialect().Name().String() != "pg" { + fmt.Printf("\033[1;31m%s\033[0m", "You are not using PostgreSQL. DB level checks can not be enabled!\n") + return nil + } + sql := ` + -- make sure transfers happen from one account to another one + alter table transaction_entries + ADD CONSTRAINT check_not_same_account + CHECK (debit_account_id != credit_account_id); + + -- make sure that account balances >= 0 (except for incoming account) + CREATE OR REPLACE FUNCTION check_balance() + RETURNS TRIGGER AS $$ + DECLARE + sum BIGINT; + debit_account_type VARCHAR; + BEGIN + SELECT INTO debit_account_type type + FROM accounts + WHERE id = NEW.debit_account_id; + + SELECT INTO sum SUM(amount) + FROM account_ledgers + WHERE account_ledgers.account_id = NEW.debit_account_id; + + -- the incoming account can have a negative balance + -- all other accounts must have a positive balance + IF sum < 0 AND debit_account_type != 'incoming' + THEN + RAISE EXCEPTION 'invalid balance [user_id:%] [debit_account_id:%] balance [%]', + NEW.user_id, + NEW.debit_account_id, + sum; + END IF; + RETURN NEW; + END; + $$ LANGUAGE plpgsql; + CREATE TRIGGER check_balance + AFTER INSERT OR UPDATE ON transaction_entries + FOR EACH ROW EXECUTE PROCEDURE check_balance(); + ` + if _, err := db.Exec(sql); err != nil { + return err + } + return nil + }, nil) +} diff --git a/lib/service/invoices.go b/lib/service/invoices.go index aa6e6d6..bdc13a8 100644 --- a/lib/service/invoices.go +++ b/lib/service/invoices.go @@ -54,6 +54,9 @@ func (svc *LndhubService) PayInvoice(invoice *models.Invoice) (*models.Transacti if err != nil { return &entry, err } + + // The DB constraints make sure the user actually has enough balance for the transaction + // If the user does not have enough balance this call fails _, err = tx.NewInsert().Model(&entry).Exec(context.TODO()) if err != nil { tx.Rollback()