Implements SQLite repositories (#180)

* add sqlite db

* add .vscode to gitignore

* add vtxo repo

* add sqlite repos implementations

* add sqlite in db/service

* update go.mod

* fix sqlite

* move sqlite tests to service_test.go + fixes

* integration tests using sqlite + properly close statements

* implement GetRoundsIds

* add "tx" table to store forfeits, connectors and congestion trees

* add db max conn = 1

* upsert VTXO + fix onboarding

* remove json tags

* Fixes

* Fix

* fix lint

* fix config.go

* Fix rm config & open db only once

* Update makefile

---------

Co-authored-by: altafan <18440657+altafan@users.noreply.github.com>
This commit is contained in:
Louis Singer
2024-06-19 18:16:31 +02:00
committed by GitHub
parent 584131764f
commit bb208ec995
18 changed files with 1469 additions and 119 deletions

View File

@@ -18,8 +18,12 @@ import (
const minAllowedSequence = 512
var (
supportedEventDbs = supportedType{
"badger": {},
}
supportedDbs = supportedType{
"badger": {},
"sqlite": {},
}
supportedSchedulers = supportedType{
"gocron": {},
@@ -34,7 +38,9 @@ var (
type Config struct {
DbType string
EventDbType string
DbDir string
EventDbDir string
RoundInterval int64
Network common.Network
SchedulerType string
@@ -55,6 +61,9 @@ type Config struct {
}
func (c *Config) Validate() error {
if !supportedEventDbs.supports(c.EventDbType) {
return fmt.Errorf("event db type not supported, please select one of: %s", supportedEventDbs)
}
if !supportedDbs.supports(c.DbType) {
return fmt.Errorf("db type not supported, please select one of: %s", supportedDbs)
}
@@ -143,21 +152,33 @@ func (c *Config) AdminService() application.AdminService {
func (c *Config) repoManager() error {
var svc ports.RepoManager
var err error
var eventStoreConfig []interface{}
var dataStoreConfig []interface{}
logger := log.New()
switch c.EventDbType {
case "badger":
eventStoreConfig = []interface{}{c.EventDbDir, logger}
default:
return fmt.Errorf("unknown event db type")
}
switch c.DbType {
case "badger":
logger := log.New()
svc, err = db.NewService(db.ServiceConfig{
EventStoreType: c.DbType,
RoundStoreType: c.DbType,
VtxoStoreType: c.DbType,
EventStoreConfig: []interface{}{c.DbDir, logger},
RoundStoreConfig: []interface{}{c.DbDir, logger},
VtxoStoreConfig: []interface{}{c.DbDir, logger},
})
dataStoreConfig = []interface{}{c.DbDir, logger}
case "sqlite":
dataStoreConfig = []interface{}{c.DbDir}
default:
return fmt.Errorf("unknown db type")
}
svc, err = db.NewService(db.ServiceConfig{
EventStoreType: c.EventDbType,
DataStoreType: c.DbType,
EventStoreConfig: eventStoreConfig,
DataStoreConfig: dataStoreConfig,
})
if err != nil {
return err
}

View File

@@ -14,6 +14,7 @@ type Config struct {
WalletAddr string
RoundInterval int64
Port uint32
EventDbType string
DbType string
DbDir string
SchedulerType string
@@ -34,6 +35,7 @@ var (
WalletAddr = "WALLET_ADDR"
RoundInterval = "ROUND_INTERVAL"
Port = "PORT"
EventDbType = "EVENT_DB_TYPE"
DbType = "DB_TYPE"
SchedulerType = "SCHEDULER_TYPE"
TxBuilderType = "TX_BUILDER_TYPE"
@@ -51,7 +53,8 @@ var (
defaultRoundInterval = 5
defaultPort = 6000
defaultWalletAddr = "localhost:18000"
defaultDbType = "badger"
defaultDbType = "sqlite"
defaultEventDbType = "badger"
defaultSchedulerType = "gocron"
defaultTxBuilderType = "covenant"
defaultBlockchainScannerType = "ocean"
@@ -80,6 +83,7 @@ func LoadConfig() (*Config, error) {
viper.SetDefault(RoundInterval, defaultRoundInterval)
viper.SetDefault(RoundLifetime, defaultRoundLifetime)
viper.SetDefault(SchedulerType, defaultSchedulerType)
viper.SetDefault(EventDbType, defaultEventDbType)
viper.SetDefault(TxBuilderType, defaultTxBuilderType)
viper.SetDefault(UnilateralExitDelay, defaultUnilateralExitDelay)
viper.SetDefault(BlockchainScannerType, defaultBlockchainScannerType)
@@ -99,6 +103,7 @@ func LoadConfig() (*Config, error) {
WalletAddr: viper.GetString(WalletAddr),
RoundInterval: viper.GetInt64(RoundInterval),
Port: viper.GetUint32(Port),
EventDbType: viper.GetString(EventDbType),
DbType: viper.GetString(DbType),
SchedulerType: viper.GetString(SchedulerType),
TxBuilderType: viper.GetString(TxBuilderType),

View File

@@ -267,10 +267,13 @@ func (s *service) Onboard(
log.Debugf("broadcasted boarding tx %s", txid)
s.onboardingCh <- onboarding{
tx: boardingTx,
congestionTree: congestionTree,
userPubkey: userPubkey,
sharedOutputScript := hex.EncodeToString(extracted.Outputs[0].Script)
if _, ok := s.trustedOnboardingScripts[sharedOutputScript]; !ok {
s.onboardingCh <- onboarding{
tx: boardingTx,
congestionTree: congestionTree,
userPubkey: userPubkey,
}
}
return nil
@@ -594,7 +597,7 @@ func (s *service) listenToScannerNotifications() {
continue
}
if _, err := s.repoManager.Vtxos().RedeemVtxos(ctx, []domain.VtxoKey{vtxo.VtxoKey}); err != nil {
if err := s.repoManager.Vtxos().RedeemVtxos(ctx, []domain.VtxoKey{vtxo.VtxoKey}); err != nil {
log.WithError(err).Warn("failed to redeem vtxos, retrying...")
continue
}

View File

@@ -7,6 +7,8 @@ import (
type RoundEventRepository interface {
Save(ctx context.Context, id string, events ...RoundEvent) (*Round, error)
Load(ctx context.Context, id string) (*Round, error)
RegisterEventsHandler(func(*Round))
Close()
}
type RoundRepository interface {
@@ -17,16 +19,18 @@ type RoundRepository interface {
GetSweepableRounds(ctx context.Context) ([]Round, error)
GetRoundsIds(ctx context.Context, startedAfter int64, startedBefore int64) ([]string, error)
GetSweptRounds(ctx context.Context) ([]Round, error)
Close()
}
type VtxoRepository interface {
AddVtxos(ctx context.Context, vtxos []Vtxo) error
SpendVtxos(ctx context.Context, vtxos []VtxoKey, txid string) error
RedeemVtxos(ctx context.Context, vtxos []VtxoKey) ([]Vtxo, error)
RedeemVtxos(ctx context.Context, vtxos []VtxoKey) error
GetVtxos(ctx context.Context, vtxos []VtxoKey) ([]Vtxo, error)
GetVtxosForRound(ctx context.Context, txid string) ([]Vtxo, error)
SweepVtxos(ctx context.Context, vtxos []VtxoKey) error
GetAllVtxos(ctx context.Context, pubkey string) ([]Vtxo, []Vtxo, error)
GetAllSweepableVtxos(ctx context.Context) ([]Vtxo, error)
UpdateExpireAt(ctx context.Context, vtxos []VtxoKey, expireAt int64) error
Close()
}

View File

@@ -7,7 +7,6 @@ import (
"sync"
"github.com/ark-network/ark/internal/core/domain"
dbtypes "github.com/ark-network/ark/internal/infrastructure/db/types"
"github.com/dgraph-io/badger/v4"
"github.com/timshannon/badgerhold/v4"
)
@@ -25,7 +24,7 @@ type eventRepository struct {
handler func(round *domain.Round)
}
func NewRoundEventRepository(config ...interface{}) (dbtypes.EventStore, error) {
func NewRoundEventRepository(config ...interface{}) (domain.RoundEventRepository, error) {
if len(config) != 2 {
return nil, fmt.Errorf("invalid config")
}

View File

@@ -6,7 +6,6 @@ import (
"path/filepath"
"github.com/ark-network/ark/internal/core/domain"
dbtypes "github.com/ark-network/ark/internal/infrastructure/db/types"
"github.com/dgraph-io/badger/v4"
"github.com/timshannon/badgerhold/v4"
)
@@ -17,7 +16,7 @@ type roundRepository struct {
store *badgerhold.Store
}
func NewRoundRepository(config ...interface{}) (dbtypes.RoundStore, error) {
func NewRoundRepository(config ...interface{}) (domain.RoundRepository, error) {
if len(config) != 2 {
return nil, fmt.Errorf("invalid config")
}

View File

@@ -7,7 +7,6 @@ import (
"strings"
"github.com/ark-network/ark/internal/core/domain"
dbtypes "github.com/ark-network/ark/internal/infrastructure/db/types"
"github.com/dgraph-io/badger/v4"
"github.com/timshannon/badgerhold/v4"
)
@@ -18,7 +17,7 @@ type vtxoRepository struct {
store *badgerhold.Store
}
func NewVtxoRepository(config ...interface{}) (dbtypes.VtxoStore, error) {
func NewVtxoRepository(config ...interface{}) (domain.VtxoRepository, error) {
if len(config) != 2 {
return nil, fmt.Errorf("invalid config")
}
@@ -65,18 +64,14 @@ func (r *vtxoRepository) SpendVtxos(
func (r *vtxoRepository) RedeemVtxos(
ctx context.Context, vtxoKeys []domain.VtxoKey,
) ([]domain.Vtxo, error) {
vtxos := make([]domain.Vtxo, 0, len(vtxoKeys))
) error {
for _, vtxoKey := range vtxoKeys {
vtxo, err := r.redeemVtxo(ctx, vtxoKey)
_, err := r.redeemVtxo(ctx, vtxoKey)
if err != nil {
return nil, err
}
if vtxo != nil {
vtxos = append(vtxos, *vtxo)
return err
}
}
return vtxos, nil
return nil
}
func (r *vtxoRepository) GetVtxos(
@@ -248,7 +243,7 @@ func (r *vtxoRepository) redeemVtxo(ctx context.Context, vtxoKey domain.VtxoKey)
}
func (r *vtxoRepository) findVtxos(ctx context.Context, query *badgerhold.Query) ([]domain.Vtxo, error) {
var vtxos []domain.Vtxo
vtxos := make([]domain.Vtxo, 0)
var err error
if ctx.Value("tx") != nil {

View File

@@ -2,39 +2,44 @@ package db
import (
"fmt"
"path/filepath"
"github.com/ark-network/ark/internal/core/domain"
"github.com/ark-network/ark/internal/core/ports"
badgerdb "github.com/ark-network/ark/internal/infrastructure/db/badger"
dbtypes "github.com/ark-network/ark/internal/infrastructure/db/types"
sqlitedb "github.com/ark-network/ark/internal/infrastructure/db/sqlite"
)
var (
eventStoreTypes = map[string]func(...interface{}) (dbtypes.EventStore, error){
eventStoreTypes = map[string]func(...interface{}) (domain.RoundEventRepository, error){
"badger": badgerdb.NewRoundEventRepository,
}
roundStoreTypes = map[string]func(...interface{}) (dbtypes.RoundStore, error){
roundStoreTypes = map[string]func(...interface{}) (domain.RoundRepository, error){
"badger": badgerdb.NewRoundRepository,
"sqlite": sqlitedb.NewRoundRepository,
}
vtxoStoreTypes = map[string]func(...interface{}) (dbtypes.VtxoStore, error){
vtxoStoreTypes = map[string]func(...interface{}) (domain.VtxoRepository, error){
"badger": badgerdb.NewVtxoRepository,
"sqlite": sqlitedb.NewVtxoRepository,
}
)
const (
sqliteDbFile = "sqlite.db"
)
type ServiceConfig struct {
EventStoreType string
RoundStoreType string
VtxoStoreType string
DataStoreType string
EventStoreConfig []interface{}
RoundStoreConfig []interface{}
VtxoStoreConfig []interface{}
DataStoreConfig []interface{}
}
type service struct {
eventStore dbtypes.EventStore
roundStore dbtypes.RoundStore
vtxoStore dbtypes.VtxoStore
eventStore domain.RoundEventRepository
roundStore domain.RoundRepository
vtxoStore domain.VtxoRepository
}
func NewService(config ServiceConfig) (ports.RepoManager, error) {
@@ -42,26 +47,62 @@ func NewService(config ServiceConfig) (ports.RepoManager, error) {
if !ok {
return nil, fmt.Errorf("event store type not supported")
}
roundStoreFactory, ok := roundStoreTypes[config.RoundStoreType]
roundStoreFactory, ok := roundStoreTypes[config.DataStoreType]
if !ok {
return nil, fmt.Errorf("round store type not supported")
}
vtxoStoreFactory, ok := vtxoStoreTypes[config.VtxoStoreType]
vtxoStoreFactory, ok := vtxoStoreTypes[config.DataStoreType]
if !ok {
return nil, fmt.Errorf("vtxo store type not supported")
}
eventStore, err := eventStoreFactory(config.EventStoreConfig...)
if err != nil {
return nil, fmt.Errorf("failed to open event store: %s", err)
var eventStore domain.RoundEventRepository
var roundStore domain.RoundRepository
var vtxoStore domain.VtxoRepository
var err error
switch config.EventStoreType {
case "badger":
eventStore, err = eventStoreFactory(config.EventStoreConfig...)
if err != nil {
return nil, fmt.Errorf("failed to open event store: %s", err)
}
default:
return nil, fmt.Errorf("unknown event store db type")
}
roundStore, err := roundStoreFactory(config.RoundStoreConfig...)
if err != nil {
return nil, fmt.Errorf("failed to open round store: %s", err)
}
vtxoStore, err := vtxoStoreFactory(config.VtxoStoreConfig...)
if err != nil {
return nil, fmt.Errorf("failed to open vtxo store: %s", err)
switch config.DataStoreType {
case "badger":
roundStore, err = roundStoreFactory(config.DataStoreConfig...)
if err != nil {
return nil, fmt.Errorf("failed to open round store: %s", err)
}
vtxoStore, err = vtxoStoreFactory(config.DataStoreConfig...)
if err != nil {
return nil, fmt.Errorf("failed to open vtxo store: %s", err)
}
case "sqlite":
if len(config.DataStoreConfig) != 1 {
return nil, fmt.Errorf("invalid data store config")
}
baseDir, ok := config.DataStoreConfig[0].(string)
if !ok {
return nil, fmt.Errorf("invalid base directory")
}
db, err := sqlitedb.OpenDb(filepath.Join(baseDir, sqliteDbFile))
if err != nil {
return nil, err
}
roundStore, err = roundStoreFactory(db)
if err != nil {
return nil, fmt.Errorf("failed to open round store: %s", err)
}
vtxoStore, err = vtxoStoreFactory(db)
if err != nil {
return nil, fmt.Errorf("failed to open vtxo store: %s", err)
}
}
return &service{eventStore, roundStore, vtxoStore}, nil

View File

@@ -2,7 +2,11 @@ package db_test
import (
"context"
"crypto/rand"
"encoding/hex"
"os"
"reflect"
"sort"
"testing"
"time"
@@ -18,7 +22,6 @@ import (
const (
emptyPtx = "cHNldP8BAgQCAAAAAQQBAAEFAQABBgEDAfsEAgAAAAA="
emptyTx = "0200000000000000000000"
txid = "00000000000000000000000000000000000000000000000000000000000000000"
pubkey1 = "0300000000000000000000000000000000000000000000000000000000000000001"
pubkey2 = "0200000000000000000000000000000000000000000000000000000000000000002"
)
@@ -26,48 +29,54 @@ const (
var congestionTree = [][]tree.Node{
{
{
Txid: txid,
Txid: randomString(32),
Tx: emptyPtx,
ParentTxid: txid,
ParentTxid: randomString(32),
},
},
{
{
Txid: txid,
Txid: randomString(32),
Tx: emptyPtx,
ParentTxid: txid,
ParentTxid: randomString(32),
},
{
Txid: txid,
Txid: randomString(32),
Tx: emptyPtx,
ParentTxid: txid,
ParentTxid: randomString(32),
},
},
{
{
Txid: txid,
Txid: randomString(32),
Tx: emptyPtx,
ParentTxid: txid,
ParentTxid: randomString(32),
},
{
Txid: txid,
Txid: randomString(32),
Tx: emptyPtx,
ParentTxid: txid,
ParentTxid: randomString(32),
},
{
Txid: txid,
Txid: randomString(32),
Tx: emptyPtx,
ParentTxid: txid,
ParentTxid: randomString(32),
},
{
Txid: txid,
Txid: randomString(32),
Tx: emptyPtx,
ParentTxid: txid,
ParentTxid: randomString(32),
},
},
}
func TestMain(m *testing.M) {
m.Run()
_ = os.Remove("test.db")
}
func TestService(t *testing.T) {
dbDir := t.TempDir()
tests := []struct {
name string
config db.ServiceConfig
@@ -76,11 +85,18 @@ func TestService(t *testing.T) {
name: "repo_manager_with_badger_stores",
config: db.ServiceConfig{
EventStoreType: "badger",
RoundStoreType: "badger",
VtxoStoreType: "badger",
DataStoreType: "badger",
EventStoreConfig: []interface{}{"", nil},
RoundStoreConfig: []interface{}{"", nil},
VtxoStoreConfig: []interface{}{"", nil},
DataStoreConfig: []interface{}{"", nil},
},
},
{
name: "repo_manager_with_sqlite_stores",
config: db.ServiceConfig{
EventStoreType: "badger",
DataStoreType: "sqlite",
EventStoreConfig: []interface{}{"", nil},
DataStoreConfig: []interface{}{dbDir},
},
},
}
@@ -160,7 +176,7 @@ func testRoundEventRepository(t *testing.T, svc ports.RepoManager) {
},
domain.RoundFinalized{
Id: "7578231e-428d-45ae-aaa4-e62c77ad5cec",
Txid: txid,
Txid: randomString(32),
ForfeitTxs: []string{emptyPtx, emptyPtx, emptyPtx, emptyPtx},
Timestamp: 1701190300,
},
@@ -229,14 +245,52 @@ func testRoundRepository(t *testing.T, svc ports.RepoManager) {
Id: roundId,
Payments: []domain.Payment{
{
Id: uuid.New().String(),
Inputs: []domain.Vtxo{{}},
Receivers: []domain.Receiver{{}},
Id: uuid.New().String(),
Inputs: []domain.Vtxo{
{
VtxoKey: domain.VtxoKey{
Txid: randomString(32),
VOut: 0,
},
PoolTx: randomString(32),
ExpireAt: 7980322,
Receiver: domain.Receiver{
Pubkey: randomString(36),
Amount: 300,
},
},
},
Receivers: []domain.Receiver{{
Pubkey: randomString(36),
Amount: 300,
}},
},
{
Id: uuid.New().String(),
Inputs: []domain.Vtxo{{}},
Receivers: []domain.Receiver{{}, {}, {}},
Id: uuid.New().String(),
Inputs: []domain.Vtxo{
{
VtxoKey: domain.VtxoKey{
Txid: randomString(32),
VOut: 0,
},
PoolTx: randomString(32),
ExpireAt: 7980322,
Receiver: domain.Receiver{
Pubkey: randomString(36),
Amount: 600,
},
},
},
Receivers: []domain.Receiver{
{
Pubkey: randomString(36),
Amount: 400,
},
{
Pubkey: randomString(34),
Amount: 200,
},
},
},
},
},
@@ -249,6 +303,10 @@ func testRoundRepository(t *testing.T, svc ports.RepoManager) {
}
events = append(events, newEvents...)
updatedRound := domain.NewRoundFromEvents(events)
for _, pay := range updatedRound.Payments {
err = svc.Vtxos().AddVtxos(ctx, pay.Inputs)
require.NoError(t, err)
}
err = svc.Rounds().AddOrUpdateRound(ctx, *updatedRound)
require.NoError(t, err)
@@ -263,6 +321,7 @@ func testRoundRepository(t *testing.T, svc ports.RepoManager) {
require.NotNil(t, currentRound)
require.Condition(t, roundsMatch(*updatedRound, *roundById))
txid := randomString(32)
newEvents = []domain.RoundEvent{
domain.RoundFinalized{
Id: roundId,
@@ -300,7 +359,7 @@ func testVtxoRepository(t *testing.T, svc ports.RepoManager) {
userVtxos := []domain.Vtxo{
{
VtxoKey: domain.VtxoKey{
Txid: txid,
Txid: randomString(32),
VOut: 0,
},
Receiver: domain.Receiver{
@@ -310,7 +369,7 @@ func testVtxoRepository(t *testing.T, svc ports.RepoManager) {
},
{
VtxoKey: domain.VtxoKey{
Txid: txid,
Txid: randomString(32),
VOut: 1,
},
Receiver: domain.Receiver{
@@ -321,7 +380,7 @@ func testVtxoRepository(t *testing.T, svc ports.RepoManager) {
}
newVtxos := append(userVtxos, domain.Vtxo{
VtxoKey: domain.VtxoKey{
Txid: txid,
Txid: randomString(32),
VOut: 1,
},
Receiver: domain.Receiver{
@@ -346,8 +405,8 @@ func testVtxoRepository(t *testing.T, svc ports.RepoManager) {
spendableVtxos, spentVtxos, err = svc.Vtxos().GetAllVtxos(ctx, "")
require.NoError(t, err)
require.Empty(t, spendableVtxos)
require.Empty(t, spentVtxos)
numberOfVtxos := len(spendableVtxos) + len(spentVtxos)
err = svc.Vtxos().AddVtxos(ctx, newVtxos)
require.NoError(t, err)
@@ -358,15 +417,21 @@ func testVtxoRepository(t *testing.T, svc ports.RepoManager) {
spendableVtxos, spentVtxos, err = svc.Vtxos().GetAllVtxos(ctx, pubkey1)
require.NoError(t, err)
require.Exactly(t, vtxos, spendableVtxos)
sortedVtxos := sortVtxos(userVtxos)
sort.Sort(sortedVtxos)
sortedSpendableVtxos := sortVtxos(spendableVtxos)
sort.Sort(sortedSpendableVtxos)
require.Exactly(t, sortedSpendableVtxos, sortedVtxos)
require.Empty(t, spentVtxos)
spendableVtxos, spentVtxos, err = svc.Vtxos().GetAllVtxos(ctx, "")
require.NoError(t, err)
require.Exactly(t, userVtxos, spendableVtxos)
require.Empty(t, spentVtxos)
require.Len(t, append(spendableVtxos, spentVtxos...), numberOfVtxos+len(newVtxos))
err = svc.Vtxos().SpendVtxos(ctx, vtxoKeys[:1], txid)
err = svc.Vtxos().SpendVtxos(ctx, vtxoKeys[:1], randomString(32))
require.NoError(t, err)
spentVtxos, err = svc.Vtxos().GetVtxos(ctx, vtxoKeys[:1])
@@ -397,27 +462,92 @@ func roundsMatch(expected, got domain.Round) assert.Comparison {
if expected.Stage != got.Stage {
return false
}
if !reflect.DeepEqual(expected.Payments, got.Payments) {
return false
for k, v := range expected.Payments {
gotValue, ok := got.Payments[k]
if !ok {
return false
}
expectedVtxos := sortVtxos(v.Inputs)
gotVtxos := sortVtxos(gotValue.Inputs)
sort.Sort(expectedVtxos)
sort.Sort(gotVtxos)
expectedReceivers := sortReceivers(v.Receivers)
gotReceivers := sortReceivers(gotValue.Receivers)
sort.Sort(expectedReceivers)
sort.Sort(gotReceivers)
if !reflect.DeepEqual(expectedReceivers, gotReceivers) {
return false
}
if !reflect.DeepEqual(expectedVtxos, gotVtxos) {
return false
}
}
if expected.Txid != got.Txid {
return false
}
if expected.UnsignedTx != got.UnsignedTx {
return false
}
if !reflect.DeepEqual(expected.ForfeitTxs, got.ForfeitTxs) {
return false
if len(expected.ForfeitTxs) > 0 {
expectedForfeits := sortStrings(expected.ForfeitTxs)
gotForfeits := sortStrings(got.ForfeitTxs)
sort.Sort(expectedForfeits)
sort.Sort(gotForfeits)
if !reflect.DeepEqual(expectedForfeits, gotForfeits) {
return false
}
}
if !reflect.DeepEqual(expected.CongestionTree, got.CongestionTree) {
return false
}
if !reflect.DeepEqual(expected.Connectors, got.Connectors) {
return false
if len(expected.Connectors) > 0 {
expectedConnectors := sortStrings(expected.Connectors)
gotConnectors := sortStrings(got.Connectors)
sort.Sort(expectedConnectors)
sort.Sort(gotConnectors)
if !reflect.DeepEqual(expectedConnectors, gotConnectors) {
return false
}
}
if expected.Version != got.Version {
return false
}
return true
return expected.Version == got.Version
}
}
func randomString(len int) string {
buf := make([]byte, len)
// nolint
rand.Read(buf)
return hex.EncodeToString(buf)
}
type sortVtxos []domain.Vtxo
func (a sortVtxos) Len() int { return len(a) }
func (a sortVtxos) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (a sortVtxos) Less(i, j int) bool { return a[i].Txid < a[j].Txid }
type sortReceivers []domain.Receiver
func (a sortReceivers) Len() int { return len(a) }
func (a sortReceivers) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (a sortReceivers) Less(i, j int) bool { return a[i].Pubkey < a[j].Pubkey }
type sortStrings []string
func (a sortStrings) Len() int { return len(a) }
func (a sortStrings) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (a sortStrings) Less(i, j int) bool { return a[i] < a[j] }

View File

@@ -0,0 +1,691 @@
package sqlitedb
import (
"context"
"database/sql"
"errors"
"fmt"
"github.com/ark-network/ark/common/tree"
"github.com/ark-network/ark/internal/core/domain"
)
const (
createReceiverTable = `
CREATE TABLE IF NOT EXISTS receiver (
payment_id TEXT NOT NULL,
pubkey TEXT NOT NULL,
amount INTEGER NOT NULL,
onchain_address TEXT NOT NULL,
FOREIGN KEY (payment_id) REFERENCES payment(id)
PRIMARY KEY (payment_id, pubkey)
);
`
createPaymentTable = `
CREATE TABLE IF NOT EXISTS payment (
id TEXT PRIMARY KEY,
round_id TEXT NOT NULL,
FOREIGN KEY (round_id) REFERENCES round(id)
);
`
createRoundTable = `
CREATE TABLE IF NOT EXISTS round (
id TEXT PRIMARY KEY,
starting_timestamp INTEGER NOT NULL,
ending_timestamp INTEGER NOT NULL,
ended BOOLEAN NOT NULL,
failed BOOLEAN NOT NULL,
stage_code INTEGER NOT NULL,
txid TEXT NOT NULL,
unsigned_tx TEXT NOT NULL,
connector_address TEXT NOT NULL,
dust_amount INTEGER NOT NULL,
version INTEGER NOT NULL,
swept BOOLEAN NOT NULL
);
`
createTransactionTable = `
CREATE TABLE IF NOT EXISTS tx (
id INTEGER PRIMARY KEY AUTOINCREMENT,
tx TEXT NOT NULL,
round_id TEXT NOT NULL,
type TEXT NOT NULL,
position INTEGER NOT NULL,
txid TEXT,
tree_level INTEGER,
parent_txid TEXT,
is_leaf BOOLEAN,
FOREIGN KEY (round_id) REFERENCES round(id)
);
`
upsertTransaction = `
INSERT INTO tx (
tx, round_id, type, position, txid, tree_level, parent_txid, is_leaf
) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(id) DO UPDATE SET
tx = EXCLUDED.tx,
round_id = EXCLUDED.round_id,
type = EXCLUDED.type,
position = EXCLUDED.position,
txid = EXCLUDED.txid,
tree_level = EXCLUDED.tree_level,
parent_txid = EXCLUDED.parent_txid,
is_leaf = EXCLUDED.is_leaf;
`
upsertRound = `
INSERT INTO round (
id,
starting_timestamp,
ending_timestamp,
ended, failed,
stage_code,
txid,
unsigned_tx,
connector_address,
dust_amount,
version,
swept
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(id) DO UPDATE SET
starting_timestamp = EXCLUDED.starting_timestamp,
ending_timestamp = EXCLUDED.ending_timestamp,
ended = EXCLUDED.ended,
failed = EXCLUDED.failed,
stage_code = EXCLUDED.stage_code,
txid = EXCLUDED.txid,
unsigned_tx = EXCLUDED.unsigned_tx,
connector_address = EXCLUDED.connector_address,
dust_amount = EXCLUDED.dust_amount,
version = EXCLUDED.version,
swept = EXCLUDED.swept;
`
upsertPayment = `
INSERT INTO payment (id, round_id) VALUES (?, ?)
ON CONFLICT(id) DO UPDATE SET round_id = EXCLUDED.round_id;
`
upsertReceiver = `
INSERT INTO receiver (payment_id, pubkey, amount, onchain_address) VALUES (?, ?, ?, ?)
ON CONFLICT(payment_id, pubkey) DO UPDATE SET
amount = EXCLUDED.amount,
onchain_address = EXCLUDED.onchain_address,
pubkey = EXCLUDED.pubkey;
`
updateVtxoPaymentId = `
UPDATE vtxo SET payment_id = ? WHERE txid = ? AND vout = ?
`
selectRound = `
SELECT round.id, round.starting_timestamp, round.ending_timestamp, round.ended, round.failed, round.stage_code, round.txid,
round.unsigned_tx, round.connector_address, round.dust_amount, round.version, round.swept, payment.id, receiver.payment_id,
receiver.pubkey, receiver.amount, receiver.onchain_address, vtxo.txid, vtxo.vout, vtxo.pubkey, vtxo.amount,
vtxo.pool_tx, vtxo.spent_by, vtxo.spent, vtxo.redeemed, vtxo.swept, vtxo.expire_at, vtxo.payment_id,
tx.tx, tx.type, tx.position, tx.txid,
tx.tree_level, tx.parent_txid, tx.is_leaf
FROM round
LEFT OUTER JOIN payment ON round.id=payment.round_id
LEFT OUTER JOIN tx ON round.id=tx.round_id
LEFT OUTER JOIN receiver ON payment.id=receiver.payment_id
LEFT OUTER JOIN vtxo ON payment.id=vtxo.payment_id
`
selectCurrentRound = selectRound + " WHERE round.ended = false AND round.failed = false;"
selectRoundWithId = selectRound + " WHERE round.id = ?;"
selectRoundWithTxId = selectRound + " WHERE round.txid = ?;"
selectSweepableRounds = selectRound + " WHERE round.swept = false AND round.ended = true AND round.failed = false;"
selectSweptRounds = selectRound + " WHERE round.swept = true AND round.failed = false AND round.ended = true;"
selectRoundIdsInRange = `
SELECT id FROM round WHERE starting_timestamp > ? AND starting_timestamp < ?;
`
selectRoundIds = `
SELECT id FROM round;
`
)
type receiverRow struct {
paymentId *string
pubkey *string
amount *uint64
onchainAddress *string
}
type paymentRow struct {
id *string
}
type transactionRow struct {
tx *string
txType *string
position *int
txid *string
treeLevel *int
parentTxid *string
isLeaf *bool
}
type roundRow struct {
id *string
startingTimestamp *int64
endingTimestamp *int64
ended *bool
failed *bool
stageCode *domain.RoundStage
txid *string
unsignedTx *string
connectorAddress *string
dustAmount *uint64
version *uint
swept *bool
}
type roundRepository struct {
db *sql.DB
}
func NewRoundRepository(config ...interface{}) (domain.RoundRepository, error) {
if len(config) != 1 {
return nil, fmt.Errorf("invalid config")
}
db, ok := config[0].(*sql.DB)
if !ok {
return nil, fmt.Errorf("cannot open round repository: invalid config")
}
return newRoundRepository(db)
}
func newRoundRepository(db *sql.DB) (*roundRepository, error) {
if _, err := db.Exec(createRoundTable); err != nil {
return nil, err
}
if _, err := db.Exec(createPaymentTable); err != nil {
return nil, err
}
if _, err := db.Exec(createReceiverTable); err != nil {
return nil, err
}
if _, err := db.Exec(createTransactionTable); err != nil {
return nil, err
}
return &roundRepository{db}, nil
}
func (r *roundRepository) Close() {
_ = r.db.Close()
}
func (r *roundRepository) GetRoundsIds(ctx context.Context, startedAfter int64, startedBefore int64) ([]string, error) {
var rows *sql.Rows
if startedAfter == 0 && startedBefore == 0 {
stmt, err := r.db.Prepare(selectRoundIds)
if err != nil {
return nil, err
}
defer stmt.Close()
rows, err = stmt.Query()
if err != nil {
return nil, err
}
} else {
stmt, err := r.db.Prepare(selectRoundIdsInRange)
if err != nil {
return nil, err
}
defer stmt.Close()
rows, err = stmt.Query(startedAfter, startedBefore)
if err != nil {
return nil, err
}
}
defer rows.Close()
ids := make([]string, 0)
for rows.Next() {
var id string
if err := rows.Scan(&id); err != nil {
return nil, err
}
ids = append(ids, id)
}
return ids, nil
}
func (r *roundRepository) AddOrUpdateRound(ctx context.Context, round domain.Round) error {
tx, err := r.db.Begin()
if err != nil {
return err
}
stmt, err := tx.Prepare(upsertRound)
if err != nil {
return err
}
defer stmt.Close()
// insert round row
_, err = stmt.Exec(
round.Id,
round.StartingTimestamp,
round.EndingTimestamp,
round.Stage.Ended,
round.Stage.Failed,
round.Stage.Code,
round.Txid,
round.UnsignedTx,
round.ConnectorAddress,
round.DustAmount,
round.Version,
round.Swept,
)
if err != nil {
return err
}
// insert transactions rows
if len(round.ForfeitTxs) > 0 || len(round.Connectors) > 0 || len(round.CongestionTree) > 0 {
stmt, err = tx.Prepare(upsertTransaction)
if err != nil {
return err
}
defer stmt.Close()
for pos, tx := range round.ForfeitTxs {
_, err := stmt.Exec(tx, round.Id, "forfeit", pos, nil, nil, nil, nil)
if err != nil {
return err
}
}
for pos, tx := range round.Connectors {
_, err := stmt.Exec(tx, round.Id, "connector", pos, nil, nil, nil, nil)
if err != nil {
return err
}
}
for level, levelTxs := range round.CongestionTree {
for pos, tx := range levelTxs {
_, err := stmt.Exec(tx.Tx, round.Id, "tree", pos, tx.Txid, level, tx.ParentTxid, tx.Leaf)
if err != nil {
return err
}
}
}
}
// insert payments rows
if len(round.Payments) > 0 {
stmtUpsertPayment, err := tx.Prepare(upsertPayment)
if err != nil {
return err
}
defer stmtUpsertPayment.Close()
for _, payment := range round.Payments {
_, err = stmtUpsertPayment.Exec(payment.Id, round.Id)
if err != nil {
return err
}
stmtUpsertReceiver, err := tx.Prepare(upsertReceiver)
if err != nil {
return err
}
defer stmtUpsertReceiver.Close()
for _, receiver := range payment.Receivers {
_, err := stmtUpsertReceiver.Exec(payment.Id, receiver.Pubkey, receiver.Amount, receiver.OnchainAddress)
if err != nil {
return err
}
}
stmtUpdatePaymentId, err := tx.Prepare(updateVtxoPaymentId)
if err != nil {
return err
}
defer stmtUpdatePaymentId.Close()
for _, input := range payment.Inputs {
_, err := stmtUpdatePaymentId.Exec(payment.Id, input.Txid, input.VOut)
if err != nil {
return err
}
}
}
}
return tx.Commit()
}
func (r *roundRepository) GetCurrentRound(ctx context.Context) (*domain.Round, error) {
stmt, err := r.db.Prepare(selectCurrentRound)
if err != nil {
return nil, err
}
defer stmt.Close()
rows, err := stmt.Query()
if err != nil {
return nil, err
}
rounds, err := readRoundRows(rows)
if err != nil {
return nil, err
}
if len(rounds) == 0 {
return nil, errors.New("no current round")
}
return rounds[0], nil
}
func (r *roundRepository) GetRoundWithId(ctx context.Context, id string) (*domain.Round, error) {
stmt, err := r.db.Prepare(selectRoundWithId)
if err != nil {
return nil, err
}
defer stmt.Close()
rows, err := stmt.Query(id)
if err != nil {
return nil, err
}
rounds, err := readRoundRows(rows)
if err != nil {
return nil, err
}
if len(rounds) > 0 {
return rounds[0], nil
}
return nil, errors.New("round not found")
}
func (r *roundRepository) GetRoundWithTxid(ctx context.Context, txid string) (*domain.Round, error) {
stmt, err := r.db.Prepare(selectRoundWithTxId)
if err != nil {
return nil, err
}
defer stmt.Close()
rows, err := stmt.Query(txid)
if err != nil {
return nil, err
}
rounds, err := readRoundRows(rows)
if err != nil {
return nil, err
}
if len(rounds) > 0 {
return rounds[0], nil
}
return nil, errors.New("round not found")
}
func (r *roundRepository) GetSweepableRounds(ctx context.Context) ([]domain.Round, error) {
stmt, err := r.db.Prepare(selectSweepableRounds)
if err != nil {
return nil, err
}
defer stmt.Close()
rows, err := stmt.Query()
if err != nil {
return nil, err
}
rounds, err := readRoundRows(rows)
if err != nil {
return nil, err
}
res := make([]domain.Round, 0)
for _, round := range rounds {
res = append(res, *round)
}
return res, nil
}
func (r *roundRepository) GetSweptRounds(ctx context.Context) ([]domain.Round, error) {
stmt, err := r.db.Prepare(selectSweptRounds)
if err != nil {
return nil, err
}
defer stmt.Close()
rows, err := stmt.Query()
if err != nil {
return nil, err
}
rounds, err := readRoundRows(rows)
if err != nil {
return nil, err
}
res := make([]domain.Round, 0)
for _, round := range rounds {
res = append(res, *round)
}
return res, nil
}
func rowToReceiver(row receiverRow) domain.Receiver {
return domain.Receiver{
Pubkey: *row.pubkey,
Amount: *row.amount,
OnchainAddress: *row.onchainAddress,
}
}
func readRoundRows(rows *sql.Rows) ([]*domain.Round, error) {
defer rows.Close()
rounds := make(map[string]*domain.Round)
for rows.Next() {
var roundRow roundRow
var paymentRow paymentRow
var receiverRow receiverRow
var vtxoRow vtxoRow
var transactionRow transactionRow
if err := rows.Scan(
&roundRow.id,
&roundRow.startingTimestamp,
&roundRow.endingTimestamp,
&roundRow.ended,
&roundRow.failed,
&roundRow.stageCode,
&roundRow.txid,
&roundRow.unsignedTx,
&roundRow.connectorAddress,
&roundRow.dustAmount,
&roundRow.version,
&roundRow.swept,
&paymentRow.id,
&receiverRow.paymentId,
&receiverRow.pubkey,
&receiverRow.amount,
&receiverRow.onchainAddress,
&vtxoRow.txid,
&vtxoRow.vout,
&vtxoRow.pubkey,
&vtxoRow.amount,
&vtxoRow.poolTx,
&vtxoRow.spentBy,
&vtxoRow.spent,
&vtxoRow.redeemed,
&vtxoRow.swept,
&vtxoRow.expireAt,
&vtxoRow.paymentID,
&transactionRow.tx,
&transactionRow.txType,
&transactionRow.position,
&transactionRow.txid,
&transactionRow.treeLevel,
&transactionRow.parentTxid,
&transactionRow.isLeaf,
); err != nil {
return nil, err
}
var round *domain.Round
var ok bool
if roundRow.id == nil {
continue
}
round, ok = rounds[*roundRow.id]
if !ok {
round = &domain.Round{
Id: *roundRow.id,
StartingTimestamp: *roundRow.startingTimestamp,
EndingTimestamp: *roundRow.endingTimestamp,
Stage: domain.Stage{
Ended: *roundRow.ended,
Failed: *roundRow.failed,
Code: *roundRow.stageCode,
},
Txid: *roundRow.txid,
UnsignedTx: *roundRow.unsignedTx,
ConnectorAddress: *roundRow.connectorAddress,
DustAmount: *roundRow.dustAmount,
Version: *roundRow.version,
Swept: *roundRow.swept,
Payments: make(map[string]domain.Payment),
}
}
if paymentRow.id != nil {
payment, ok := round.Payments[*paymentRow.id]
if !ok {
payment = domain.Payment{
Id: *paymentRow.id,
Inputs: make([]domain.Vtxo, 0),
Receivers: make([]domain.Receiver, 0),
}
round.Payments[*paymentRow.id] = payment
}
if vtxoRow.paymentID != nil {
payment, ok = round.Payments[*vtxoRow.paymentID]
if !ok {
payment = domain.Payment{
Id: *vtxoRow.paymentID,
Inputs: make([]domain.Vtxo, 0),
Receivers: make([]domain.Receiver, 0),
}
}
vtxo := rowToVtxo(vtxoRow)
found := false
for _, v := range payment.Inputs {
if vtxo.Txid == v.Txid && vtxo.VOut == v.VOut {
found = true
break
}
}
if !found {
payment.Inputs = append(payment.Inputs, rowToVtxo(vtxoRow))
round.Payments[*vtxoRow.paymentID] = payment
}
}
if receiverRow.paymentId != nil {
payment, ok = round.Payments[*receiverRow.paymentId]
if !ok {
payment = domain.Payment{
Id: *receiverRow.paymentId,
Inputs: make([]domain.Vtxo, 0),
Receivers: make([]domain.Receiver, 0),
}
}
rcv := rowToReceiver(receiverRow)
found := false
for _, rcv := range payment.Receivers {
if rcv.Pubkey == *receiverRow.pubkey && rcv.Amount == *receiverRow.amount {
found = true
break
}
}
if !found {
payment.Receivers = append(payment.Receivers, rcv)
round.Payments[*receiverRow.paymentId] = payment
}
}
}
if transactionRow.tx != nil {
position := *transactionRow.position
switch *transactionRow.txType {
case "forfeit":
round.ForfeitTxs = extendArray(round.ForfeitTxs, position)
round.ForfeitTxs[position] = *transactionRow.tx
case "connector":
round.Connectors = extendArray(round.Connectors, position)
round.Connectors[position] = *transactionRow.tx
case "tree":
level := *transactionRow.treeLevel
round.CongestionTree = extendArray(round.CongestionTree, level)
round.CongestionTree[level] = extendArray(round.CongestionTree[level], position)
if round.CongestionTree[level][position] == (tree.Node{}) {
round.CongestionTree[level][position] = tree.Node{
Tx: *transactionRow.tx,
Txid: *transactionRow.txid,
ParentTxid: *transactionRow.parentTxid,
Leaf: *transactionRow.isLeaf,
}
}
}
}
rounds[*roundRow.id] = round
}
var result []*domain.Round
for _, round := range rounds {
result = append(result, round)
}
return result, nil
}

View File

@@ -0,0 +1,45 @@
package sqlitedb
import (
"database/sql"
"fmt"
"os"
"path/filepath"
_ "modernc.org/sqlite"
)
const (
driverName = "sqlite"
)
func OpenDb(dbPath string) (*sql.DB, error) {
dir := filepath.Dir(dbPath)
if _, err := os.Stat(dir); os.IsNotExist(err) {
err = os.MkdirAll(dir, 0755)
if err != nil {
return nil, fmt.Errorf("failed to create directory: %v", err)
}
}
db, err := sql.Open(driverName, dbPath)
if err != nil {
return nil, fmt.Errorf("failed to open db: %w", err)
}
db.SetMaxOpenConns(1) // prevent concurrent writes
return db, nil
}
func extendArray[T any](arr []T, position int) []T {
if arr == nil {
return make([]T, position+1)
}
if len(arr) <= position {
return append(arr, make([]T, position-len(arr)+1)...)
}
return arr
}

View File

@@ -0,0 +1,377 @@
package sqlitedb
import (
"context"
"database/sql"
"fmt"
"github.com/ark-network/ark/internal/core/domain"
)
const (
createVtxoTable = `
CREATE TABLE IF NOT EXISTS vtxo (
txid TEXT NOT NULL PRIMARY KEY,
vout INTEGER NOT NULL,
pubkey TEXT NOT NULL,
amount INTEGER NOT NULL,
pool_tx TEXT NOT NULL,
spent_by TEXT NOT NULL,
spent BOOLEAN NOT NULL,
redeemed BOOLEAN NOT NULL,
swept BOOLEAN NOT NULL,
expire_at INTEGER NOT NULL,
payment_id TEXT,
FOREIGN KEY (payment_id) REFERENCES payment(id)
);
`
upsertVtxos = `
INSERT INTO vtxo (txid, vout, pubkey, amount, pool_tx, spent_by, spent, redeemed, swept, expire_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) ON CONFLICT(txid) DO UPDATE SET
vout = excluded.vout,
pubkey = excluded.pubkey,
amount = excluded.amount,
pool_tx = excluded.pool_tx,
spent_by = excluded.spent_by,
spent = excluded.spent,
redeemed = excluded.redeemed,
swept = excluded.swept,
expire_at = excluded.expire_at;
`
selectSweepableVtxos = `
SELECT * FROM vtxo WHERE redeemed = false AND swept = false
`
selectNotRedeemedVtxos = `
SELECT * FROM vtxo WHERE redeemed = false
`
selectNotRedeemedVtxosWithPubkey = `
SELECT * FROM vtxo WHERE redeemed = false AND pubkey = ?
`
selectVtxoByOutpoint = `
SELECT * FROM vtxo WHERE txid = ? AND vout = ?
`
selectVtxosByPoolTxid = `
SELECT * FROM vtxo WHERE pool_tx = ?
`
markVtxoAsRedeemed = `
UPDATE vtxo SET redeemed = true WHERE txid = ? AND vout = ?
`
markVtxoAsSwept = `
UPDATE vtxo SET swept = true WHERE txid = ? AND vout = ?
`
markVtxoAsSpent = `
UPDATE vtxo SET spent = true, spent_by = ? WHERE txid = ? AND vout = ?
`
updateVtxoExpireAt = `
UPDATE vtxo SET expire_at = ? WHERE txid = ? AND vout = ?
`
)
type vtxoRow struct {
txid *string
vout *uint32
pubkey *string
amount *uint64
poolTx *string
spentBy *string
spent *bool
redeemed *bool
swept *bool
expireAt *int64
paymentID *string
}
type vxtoRepository struct {
db *sql.DB
}
func NewVtxoRepository(config ...interface{}) (domain.VtxoRepository, error) {
if len(config) != 1 {
return nil, fmt.Errorf("invalid config")
}
db, ok := config[0].(*sql.DB)
if !ok {
return nil, fmt.Errorf("cannot open vtxo repository: invalid config")
}
return newVtxoRepository(db)
}
func newVtxoRepository(db *sql.DB) (*vxtoRepository, error) {
_, err := db.Exec(createVtxoTable)
if err != nil {
return nil, err
}
return &vxtoRepository{db}, nil
}
func (v *vxtoRepository) Close() {
_ = v.db.Close()
}
func (v *vxtoRepository) AddVtxos(ctx context.Context, vtxos []domain.Vtxo) error {
tx, err := v.db.Begin()
if err != nil {
return err
}
stmt, err := tx.Prepare(upsertVtxos)
if err != nil {
return err
}
defer stmt.Close()
for _, vtxo := range vtxos {
_, err := stmt.Exec(
vtxo.Txid,
vtxo.VOut,
vtxo.Pubkey,
vtxo.Amount,
vtxo.PoolTx,
vtxo.SpentBy,
vtxo.Spent,
vtxo.Redeemed,
vtxo.Swept,
vtxo.ExpireAt,
)
if err != nil {
return err
}
}
return tx.Commit()
}
func (v *vxtoRepository) GetAllSweepableVtxos(ctx context.Context) ([]domain.Vtxo, error) {
rows, err := v.db.Query(selectSweepableVtxos)
if err != nil {
return nil, err
}
return readRows(rows)
}
func (v *vxtoRepository) GetAllVtxos(ctx context.Context, pubkey string) ([]domain.Vtxo, []domain.Vtxo, error) {
withPubkey := len(pubkey) > 0
var rows *sql.Rows
var err error
if withPubkey {
rows, err = v.db.Query(selectNotRedeemedVtxosWithPubkey, pubkey)
} else {
rows, err = v.db.Query(selectNotRedeemedVtxos)
}
if err != nil {
return nil, nil, err
}
vtxos, err := readRows(rows)
if err != nil {
return nil, nil, err
}
unspentVtxos := make([]domain.Vtxo, 0)
spentVtxos := make([]domain.Vtxo, 0)
for _, vtxo := range vtxos {
if vtxo.Spent {
spentVtxos = append(spentVtxos, vtxo)
} else {
unspentVtxos = append(unspentVtxos, vtxo)
}
}
return unspentVtxos, spentVtxos, nil
}
func (v *vxtoRepository) GetVtxos(ctx context.Context, outpoints []domain.VtxoKey) ([]domain.Vtxo, error) {
stmt, err := v.db.Prepare(selectVtxoByOutpoint)
if err != nil {
return nil, err
}
defer stmt.Close()
vtxos := make([]domain.Vtxo, 0, len(outpoints))
for _, outpoint := range outpoints {
rows, err := stmt.Query(outpoint.Txid, outpoint.VOut)
if err != nil {
return nil, err
}
result, err := readRows(rows)
if err != nil {
return nil, err
}
if len(result) == 0 {
return nil, fmt.Errorf("vtxo not found")
}
vtxos = append(vtxos, result[0])
}
return vtxos, nil
}
func (v *vxtoRepository) GetVtxosForRound(ctx context.Context, txid string) ([]domain.Vtxo, error) {
rows, err := v.db.Query(selectVtxosByPoolTxid, txid)
if err != nil {
return nil, err
}
return readRows(rows)
}
func (v *vxtoRepository) RedeemVtxos(ctx context.Context, vtxos []domain.VtxoKey) error {
tx, err := v.db.Begin()
if err != nil {
return err
}
stmt, err := tx.Prepare(markVtxoAsRedeemed)
if err != nil {
return err
}
defer stmt.Close()
for _, vtxo := range vtxos {
_, err := stmt.Exec(vtxo.Txid, vtxo.VOut)
if err != nil {
return err
}
}
return tx.Commit()
}
func (v *vxtoRepository) SpendVtxos(ctx context.Context, vtxos []domain.VtxoKey, txid string) error {
tx, err := v.db.Begin()
if err != nil {
return err
}
stmt, err := tx.Prepare(markVtxoAsSpent)
if err != nil {
return err
}
defer stmt.Close()
for _, vtxo := range vtxos {
_, err := stmt.Exec(txid, vtxo.Txid, vtxo.VOut)
if err != nil {
return err
}
}
return tx.Commit()
}
func (v *vxtoRepository) SweepVtxos(ctx context.Context, vtxos []domain.VtxoKey) error {
tx, err := v.db.Begin()
if err != nil {
return err
}
stmt, err := tx.Prepare(markVtxoAsSwept)
if err != nil {
return err
}
defer stmt.Close()
for _, vtxo := range vtxos {
_, err := stmt.Exec(vtxo.Txid, vtxo.VOut)
if err != nil {
return err
}
}
return tx.Commit()
}
func (v *vxtoRepository) UpdateExpireAt(ctx context.Context, vtxos []domain.VtxoKey, expireAt int64) error {
tx, err := v.db.Begin()
if err != nil {
return err
}
stmt, err := tx.Prepare(updateVtxoExpireAt)
if err != nil {
return err
}
defer stmt.Close()
for _, vtxo := range vtxos {
_, err := stmt.Exec(expireAt, vtxo.Txid, vtxo.VOut)
if err != nil {
return err
}
}
return tx.Commit()
}
func rowToVtxo(row vtxoRow) domain.Vtxo {
return domain.Vtxo{
VtxoKey: domain.VtxoKey{
Txid: *row.txid,
VOut: *row.vout,
},
Receiver: domain.Receiver{
Pubkey: *row.pubkey,
Amount: *row.amount,
},
PoolTx: *row.poolTx,
SpentBy: *row.spentBy,
Spent: *row.spent,
Redeemed: *row.redeemed,
Swept: *row.swept,
ExpireAt: *row.expireAt,
}
}
func readRows(rows *sql.Rows) ([]domain.Vtxo, error) {
defer rows.Close()
vtxos := make([]domain.Vtxo, 0)
for rows.Next() {
var row vtxoRow
if err := rows.Scan(
&row.txid,
&row.vout,
&row.pubkey,
&row.amount,
&row.poolTx,
&row.spentBy,
&row.spent,
&row.redeemed,
&row.swept,
&row.expireAt,
&row.paymentID,
); err != nil {
return nil, err
}
vtxos = append(vtxos, rowToVtxo(row))
}
return vtxos, nil
}

View File

@@ -1,19 +0,0 @@
package dbtypes
import "github.com/ark-network/ark/internal/core/domain"
type EventStore interface {
domain.RoundEventRepository
RegisterEventsHandler(func(*domain.Round))
Close()
}
type RoundStore interface {
domain.RoundRepository
Close()
}
type VtxoStore interface {
domain.VtxoRepository
Close()
}