Implements SQLite repositories (#180)

* add sqlite db

* add .vscode to gitignore

* add vtxo repo

* add sqlite repos implementations

* add sqlite in db/service

* update go.mod

* fix sqlite

* move sqlite tests to service_test.go + fixes

* integration tests using sqlite + properly close statements

* implement GetRoundsIds

* add "tx" table to store forfeits, connectors and congestion trees

* add db max conn = 1

* upsert VTXO + fix onboarding

* remove json tags

* Fixes

* Fix

* fix lint

* fix config.go

* Fix rm config & open db only once

* Update makefile

---------

Co-authored-by: altafan <18440657+altafan@users.noreply.github.com>
This commit is contained in:
Louis Singer
2024-06-19 18:16:31 +02:00
committed by GitHub
parent 584131764f
commit bb208ec995
18 changed files with 1469 additions and 119 deletions

View File

@@ -0,0 +1,691 @@
package sqlitedb
import (
"context"
"database/sql"
"errors"
"fmt"
"github.com/ark-network/ark/common/tree"
"github.com/ark-network/ark/internal/core/domain"
)
const (
createReceiverTable = `
CREATE TABLE IF NOT EXISTS receiver (
payment_id TEXT NOT NULL,
pubkey TEXT NOT NULL,
amount INTEGER NOT NULL,
onchain_address TEXT NOT NULL,
FOREIGN KEY (payment_id) REFERENCES payment(id)
PRIMARY KEY (payment_id, pubkey)
);
`
createPaymentTable = `
CREATE TABLE IF NOT EXISTS payment (
id TEXT PRIMARY KEY,
round_id TEXT NOT NULL,
FOREIGN KEY (round_id) REFERENCES round(id)
);
`
createRoundTable = `
CREATE TABLE IF NOT EXISTS round (
id TEXT PRIMARY KEY,
starting_timestamp INTEGER NOT NULL,
ending_timestamp INTEGER NOT NULL,
ended BOOLEAN NOT NULL,
failed BOOLEAN NOT NULL,
stage_code INTEGER NOT NULL,
txid TEXT NOT NULL,
unsigned_tx TEXT NOT NULL,
connector_address TEXT NOT NULL,
dust_amount INTEGER NOT NULL,
version INTEGER NOT NULL,
swept BOOLEAN NOT NULL
);
`
createTransactionTable = `
CREATE TABLE IF NOT EXISTS tx (
id INTEGER PRIMARY KEY AUTOINCREMENT,
tx TEXT NOT NULL,
round_id TEXT NOT NULL,
type TEXT NOT NULL,
position INTEGER NOT NULL,
txid TEXT,
tree_level INTEGER,
parent_txid TEXT,
is_leaf BOOLEAN,
FOREIGN KEY (round_id) REFERENCES round(id)
);
`
upsertTransaction = `
INSERT INTO tx (
tx, round_id, type, position, txid, tree_level, parent_txid, is_leaf
) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(id) DO UPDATE SET
tx = EXCLUDED.tx,
round_id = EXCLUDED.round_id,
type = EXCLUDED.type,
position = EXCLUDED.position,
txid = EXCLUDED.txid,
tree_level = EXCLUDED.tree_level,
parent_txid = EXCLUDED.parent_txid,
is_leaf = EXCLUDED.is_leaf;
`
upsertRound = `
INSERT INTO round (
id,
starting_timestamp,
ending_timestamp,
ended, failed,
stage_code,
txid,
unsigned_tx,
connector_address,
dust_amount,
version,
swept
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(id) DO UPDATE SET
starting_timestamp = EXCLUDED.starting_timestamp,
ending_timestamp = EXCLUDED.ending_timestamp,
ended = EXCLUDED.ended,
failed = EXCLUDED.failed,
stage_code = EXCLUDED.stage_code,
txid = EXCLUDED.txid,
unsigned_tx = EXCLUDED.unsigned_tx,
connector_address = EXCLUDED.connector_address,
dust_amount = EXCLUDED.dust_amount,
version = EXCLUDED.version,
swept = EXCLUDED.swept;
`
upsertPayment = `
INSERT INTO payment (id, round_id) VALUES (?, ?)
ON CONFLICT(id) DO UPDATE SET round_id = EXCLUDED.round_id;
`
upsertReceiver = `
INSERT INTO receiver (payment_id, pubkey, amount, onchain_address) VALUES (?, ?, ?, ?)
ON CONFLICT(payment_id, pubkey) DO UPDATE SET
amount = EXCLUDED.amount,
onchain_address = EXCLUDED.onchain_address,
pubkey = EXCLUDED.pubkey;
`
updateVtxoPaymentId = `
UPDATE vtxo SET payment_id = ? WHERE txid = ? AND vout = ?
`
selectRound = `
SELECT round.id, round.starting_timestamp, round.ending_timestamp, round.ended, round.failed, round.stage_code, round.txid,
round.unsigned_tx, round.connector_address, round.dust_amount, round.version, round.swept, payment.id, receiver.payment_id,
receiver.pubkey, receiver.amount, receiver.onchain_address, vtxo.txid, vtxo.vout, vtxo.pubkey, vtxo.amount,
vtxo.pool_tx, vtxo.spent_by, vtxo.spent, vtxo.redeemed, vtxo.swept, vtxo.expire_at, vtxo.payment_id,
tx.tx, tx.type, tx.position, tx.txid,
tx.tree_level, tx.parent_txid, tx.is_leaf
FROM round
LEFT OUTER JOIN payment ON round.id=payment.round_id
LEFT OUTER JOIN tx ON round.id=tx.round_id
LEFT OUTER JOIN receiver ON payment.id=receiver.payment_id
LEFT OUTER JOIN vtxo ON payment.id=vtxo.payment_id
`
selectCurrentRound = selectRound + " WHERE round.ended = false AND round.failed = false;"
selectRoundWithId = selectRound + " WHERE round.id = ?;"
selectRoundWithTxId = selectRound + " WHERE round.txid = ?;"
selectSweepableRounds = selectRound + " WHERE round.swept = false AND round.ended = true AND round.failed = false;"
selectSweptRounds = selectRound + " WHERE round.swept = true AND round.failed = false AND round.ended = true;"
selectRoundIdsInRange = `
SELECT id FROM round WHERE starting_timestamp > ? AND starting_timestamp < ?;
`
selectRoundIds = `
SELECT id FROM round;
`
)
type receiverRow struct {
paymentId *string
pubkey *string
amount *uint64
onchainAddress *string
}
type paymentRow struct {
id *string
}
type transactionRow struct {
tx *string
txType *string
position *int
txid *string
treeLevel *int
parentTxid *string
isLeaf *bool
}
type roundRow struct {
id *string
startingTimestamp *int64
endingTimestamp *int64
ended *bool
failed *bool
stageCode *domain.RoundStage
txid *string
unsignedTx *string
connectorAddress *string
dustAmount *uint64
version *uint
swept *bool
}
type roundRepository struct {
db *sql.DB
}
func NewRoundRepository(config ...interface{}) (domain.RoundRepository, error) {
if len(config) != 1 {
return nil, fmt.Errorf("invalid config")
}
db, ok := config[0].(*sql.DB)
if !ok {
return nil, fmt.Errorf("cannot open round repository: invalid config")
}
return newRoundRepository(db)
}
func newRoundRepository(db *sql.DB) (*roundRepository, error) {
if _, err := db.Exec(createRoundTable); err != nil {
return nil, err
}
if _, err := db.Exec(createPaymentTable); err != nil {
return nil, err
}
if _, err := db.Exec(createReceiverTable); err != nil {
return nil, err
}
if _, err := db.Exec(createTransactionTable); err != nil {
return nil, err
}
return &roundRepository{db}, nil
}
func (r *roundRepository) Close() {
_ = r.db.Close()
}
func (r *roundRepository) GetRoundsIds(ctx context.Context, startedAfter int64, startedBefore int64) ([]string, error) {
var rows *sql.Rows
if startedAfter == 0 && startedBefore == 0 {
stmt, err := r.db.Prepare(selectRoundIds)
if err != nil {
return nil, err
}
defer stmt.Close()
rows, err = stmt.Query()
if err != nil {
return nil, err
}
} else {
stmt, err := r.db.Prepare(selectRoundIdsInRange)
if err != nil {
return nil, err
}
defer stmt.Close()
rows, err = stmt.Query(startedAfter, startedBefore)
if err != nil {
return nil, err
}
}
defer rows.Close()
ids := make([]string, 0)
for rows.Next() {
var id string
if err := rows.Scan(&id); err != nil {
return nil, err
}
ids = append(ids, id)
}
return ids, nil
}
func (r *roundRepository) AddOrUpdateRound(ctx context.Context, round domain.Round) error {
tx, err := r.db.Begin()
if err != nil {
return err
}
stmt, err := tx.Prepare(upsertRound)
if err != nil {
return err
}
defer stmt.Close()
// insert round row
_, err = stmt.Exec(
round.Id,
round.StartingTimestamp,
round.EndingTimestamp,
round.Stage.Ended,
round.Stage.Failed,
round.Stage.Code,
round.Txid,
round.UnsignedTx,
round.ConnectorAddress,
round.DustAmount,
round.Version,
round.Swept,
)
if err != nil {
return err
}
// insert transactions rows
if len(round.ForfeitTxs) > 0 || len(round.Connectors) > 0 || len(round.CongestionTree) > 0 {
stmt, err = tx.Prepare(upsertTransaction)
if err != nil {
return err
}
defer stmt.Close()
for pos, tx := range round.ForfeitTxs {
_, err := stmt.Exec(tx, round.Id, "forfeit", pos, nil, nil, nil, nil)
if err != nil {
return err
}
}
for pos, tx := range round.Connectors {
_, err := stmt.Exec(tx, round.Id, "connector", pos, nil, nil, nil, nil)
if err != nil {
return err
}
}
for level, levelTxs := range round.CongestionTree {
for pos, tx := range levelTxs {
_, err := stmt.Exec(tx.Tx, round.Id, "tree", pos, tx.Txid, level, tx.ParentTxid, tx.Leaf)
if err != nil {
return err
}
}
}
}
// insert payments rows
if len(round.Payments) > 0 {
stmtUpsertPayment, err := tx.Prepare(upsertPayment)
if err != nil {
return err
}
defer stmtUpsertPayment.Close()
for _, payment := range round.Payments {
_, err = stmtUpsertPayment.Exec(payment.Id, round.Id)
if err != nil {
return err
}
stmtUpsertReceiver, err := tx.Prepare(upsertReceiver)
if err != nil {
return err
}
defer stmtUpsertReceiver.Close()
for _, receiver := range payment.Receivers {
_, err := stmtUpsertReceiver.Exec(payment.Id, receiver.Pubkey, receiver.Amount, receiver.OnchainAddress)
if err != nil {
return err
}
}
stmtUpdatePaymentId, err := tx.Prepare(updateVtxoPaymentId)
if err != nil {
return err
}
defer stmtUpdatePaymentId.Close()
for _, input := range payment.Inputs {
_, err := stmtUpdatePaymentId.Exec(payment.Id, input.Txid, input.VOut)
if err != nil {
return err
}
}
}
}
return tx.Commit()
}
func (r *roundRepository) GetCurrentRound(ctx context.Context) (*domain.Round, error) {
stmt, err := r.db.Prepare(selectCurrentRound)
if err != nil {
return nil, err
}
defer stmt.Close()
rows, err := stmt.Query()
if err != nil {
return nil, err
}
rounds, err := readRoundRows(rows)
if err != nil {
return nil, err
}
if len(rounds) == 0 {
return nil, errors.New("no current round")
}
return rounds[0], nil
}
func (r *roundRepository) GetRoundWithId(ctx context.Context, id string) (*domain.Round, error) {
stmt, err := r.db.Prepare(selectRoundWithId)
if err != nil {
return nil, err
}
defer stmt.Close()
rows, err := stmt.Query(id)
if err != nil {
return nil, err
}
rounds, err := readRoundRows(rows)
if err != nil {
return nil, err
}
if len(rounds) > 0 {
return rounds[0], nil
}
return nil, errors.New("round not found")
}
func (r *roundRepository) GetRoundWithTxid(ctx context.Context, txid string) (*domain.Round, error) {
stmt, err := r.db.Prepare(selectRoundWithTxId)
if err != nil {
return nil, err
}
defer stmt.Close()
rows, err := stmt.Query(txid)
if err != nil {
return nil, err
}
rounds, err := readRoundRows(rows)
if err != nil {
return nil, err
}
if len(rounds) > 0 {
return rounds[0], nil
}
return nil, errors.New("round not found")
}
func (r *roundRepository) GetSweepableRounds(ctx context.Context) ([]domain.Round, error) {
stmt, err := r.db.Prepare(selectSweepableRounds)
if err != nil {
return nil, err
}
defer stmt.Close()
rows, err := stmt.Query()
if err != nil {
return nil, err
}
rounds, err := readRoundRows(rows)
if err != nil {
return nil, err
}
res := make([]domain.Round, 0)
for _, round := range rounds {
res = append(res, *round)
}
return res, nil
}
func (r *roundRepository) GetSweptRounds(ctx context.Context) ([]domain.Round, error) {
stmt, err := r.db.Prepare(selectSweptRounds)
if err != nil {
return nil, err
}
defer stmt.Close()
rows, err := stmt.Query()
if err != nil {
return nil, err
}
rounds, err := readRoundRows(rows)
if err != nil {
return nil, err
}
res := make([]domain.Round, 0)
for _, round := range rounds {
res = append(res, *round)
}
return res, nil
}
func rowToReceiver(row receiverRow) domain.Receiver {
return domain.Receiver{
Pubkey: *row.pubkey,
Amount: *row.amount,
OnchainAddress: *row.onchainAddress,
}
}
func readRoundRows(rows *sql.Rows) ([]*domain.Round, error) {
defer rows.Close()
rounds := make(map[string]*domain.Round)
for rows.Next() {
var roundRow roundRow
var paymentRow paymentRow
var receiverRow receiverRow
var vtxoRow vtxoRow
var transactionRow transactionRow
if err := rows.Scan(
&roundRow.id,
&roundRow.startingTimestamp,
&roundRow.endingTimestamp,
&roundRow.ended,
&roundRow.failed,
&roundRow.stageCode,
&roundRow.txid,
&roundRow.unsignedTx,
&roundRow.connectorAddress,
&roundRow.dustAmount,
&roundRow.version,
&roundRow.swept,
&paymentRow.id,
&receiverRow.paymentId,
&receiverRow.pubkey,
&receiverRow.amount,
&receiverRow.onchainAddress,
&vtxoRow.txid,
&vtxoRow.vout,
&vtxoRow.pubkey,
&vtxoRow.amount,
&vtxoRow.poolTx,
&vtxoRow.spentBy,
&vtxoRow.spent,
&vtxoRow.redeemed,
&vtxoRow.swept,
&vtxoRow.expireAt,
&vtxoRow.paymentID,
&transactionRow.tx,
&transactionRow.txType,
&transactionRow.position,
&transactionRow.txid,
&transactionRow.treeLevel,
&transactionRow.parentTxid,
&transactionRow.isLeaf,
); err != nil {
return nil, err
}
var round *domain.Round
var ok bool
if roundRow.id == nil {
continue
}
round, ok = rounds[*roundRow.id]
if !ok {
round = &domain.Round{
Id: *roundRow.id,
StartingTimestamp: *roundRow.startingTimestamp,
EndingTimestamp: *roundRow.endingTimestamp,
Stage: domain.Stage{
Ended: *roundRow.ended,
Failed: *roundRow.failed,
Code: *roundRow.stageCode,
},
Txid: *roundRow.txid,
UnsignedTx: *roundRow.unsignedTx,
ConnectorAddress: *roundRow.connectorAddress,
DustAmount: *roundRow.dustAmount,
Version: *roundRow.version,
Swept: *roundRow.swept,
Payments: make(map[string]domain.Payment),
}
}
if paymentRow.id != nil {
payment, ok := round.Payments[*paymentRow.id]
if !ok {
payment = domain.Payment{
Id: *paymentRow.id,
Inputs: make([]domain.Vtxo, 0),
Receivers: make([]domain.Receiver, 0),
}
round.Payments[*paymentRow.id] = payment
}
if vtxoRow.paymentID != nil {
payment, ok = round.Payments[*vtxoRow.paymentID]
if !ok {
payment = domain.Payment{
Id: *vtxoRow.paymentID,
Inputs: make([]domain.Vtxo, 0),
Receivers: make([]domain.Receiver, 0),
}
}
vtxo := rowToVtxo(vtxoRow)
found := false
for _, v := range payment.Inputs {
if vtxo.Txid == v.Txid && vtxo.VOut == v.VOut {
found = true
break
}
}
if !found {
payment.Inputs = append(payment.Inputs, rowToVtxo(vtxoRow))
round.Payments[*vtxoRow.paymentID] = payment
}
}
if receiverRow.paymentId != nil {
payment, ok = round.Payments[*receiverRow.paymentId]
if !ok {
payment = domain.Payment{
Id: *receiverRow.paymentId,
Inputs: make([]domain.Vtxo, 0),
Receivers: make([]domain.Receiver, 0),
}
}
rcv := rowToReceiver(receiverRow)
found := false
for _, rcv := range payment.Receivers {
if rcv.Pubkey == *receiverRow.pubkey && rcv.Amount == *receiverRow.amount {
found = true
break
}
}
if !found {
payment.Receivers = append(payment.Receivers, rcv)
round.Payments[*receiverRow.paymentId] = payment
}
}
}
if transactionRow.tx != nil {
position := *transactionRow.position
switch *transactionRow.txType {
case "forfeit":
round.ForfeitTxs = extendArray(round.ForfeitTxs, position)
round.ForfeitTxs[position] = *transactionRow.tx
case "connector":
round.Connectors = extendArray(round.Connectors, position)
round.Connectors[position] = *transactionRow.tx
case "tree":
level := *transactionRow.treeLevel
round.CongestionTree = extendArray(round.CongestionTree, level)
round.CongestionTree[level] = extendArray(round.CongestionTree[level], position)
if round.CongestionTree[level][position] == (tree.Node{}) {
round.CongestionTree[level][position] = tree.Node{
Tx: *transactionRow.tx,
Txid: *transactionRow.txid,
ParentTxid: *transactionRow.parentTxid,
Leaf: *transactionRow.isLeaf,
}
}
}
}
rounds[*roundRow.id] = round
}
var result []*domain.Round
for _, round := range rounds {
result = append(result, round)
}
return result, nil
}