Add integration tests for sweeping rounds (#339)

* add "block" scheduler type + sweep integration test

* increase timeout in integrationtests

* remove config logs

* rename scheduler package name

* rename package

* rename packages
This commit is contained in:
Louis Singer
2024-10-05 16:12:46 +02:00
committed by GitHub
parent 7606b4cd00
commit 0d39bb6b9f
37 changed files with 477 additions and 279 deletions

View File

@@ -1,8 +1,10 @@
package common
import (
"encoding/hex"
"fmt"
"github.com/btcsuite/btcd/blockchain"
"github.com/btcsuite/btcd/txscript"
)
const (
@@ -18,42 +20,33 @@ func closerToModulo512(x uint) uint {
return x - (x % 512)
}
func BIP68EncodeAsNumber(seconds uint) (uint32, error) {
seconds = closerToModulo512(seconds)
if seconds > SECONDS_MAX {
return 0, fmt.Errorf("seconds too large, max is %d", SECONDS_MAX)
}
if seconds%SECONDS_MOD != 0 {
return 0, fmt.Errorf("seconds must be a multiple of %d", SECONDS_MOD)
func BIP68Sequence(locktime uint) (uint32, error) {
isSeconds := locktime >= 512
if isSeconds {
locktime = closerToModulo512(locktime)
if locktime > SECONDS_MAX {
return 0, fmt.Errorf("seconds too large, max is %d", SECONDS_MAX)
}
if locktime%SECONDS_MOD != 0 {
return 0, fmt.Errorf("seconds must be a multiple of %d", SECONDS_MOD)
}
}
asNumber := SEQUENCE_LOCKTIME_TYPE_FLAG | (seconds >> SEQUENCE_LOCKTIME_GRANULARITY)
return uint32(asNumber), nil
return blockchain.LockTimeToSequence(isSeconds, uint32(locktime)), nil
}
// BIP68Encode returns the encoded sequence locktime for the given number of seconds.
func BIP68Encode(seconds uint) ([]byte, error) {
asNumber, err := BIP68EncodeAsNumber(seconds)
func BIP68DecodeSequence(sequence []byte) (uint, error) {
scriptNumber, err := txscript.MakeScriptNum(sequence, true, len(sequence))
if err != nil {
return nil, err
return 0, err
}
hexString := fmt.Sprintf("%x", asNumber)
reversed, err := hex.DecodeString(hexString)
if err != nil {
return nil, err
}
for i, j := 0, len(reversed)-1; i < j; i, j = i+1, j-1 {
reversed[i], reversed[j] = reversed[j], reversed[i]
}
return reversed, nil
}
func BIP68Decode(sequence []byte) (uint, error) {
var asNumber int64
for i := len(sequence) - 1; i >= 0; i-- {
asNumber = asNumber<<8 | int64(sequence[i])
if scriptNumber >= txscript.OP_1 && scriptNumber <= txscript.OP_16 {
scriptNumber = scriptNumber - (txscript.OP_1 - 1)
}
asNumber := int64(scriptNumber)
if asNumber&SEQUENCE_LOCKTIME_DISABLE_FLAG != 0 {
return 0, fmt.Errorf("sequence is disabled")
}
@@ -61,5 +54,6 @@ func BIP68Decode(sequence []byte) (uint, error) {
seconds := asNumber & SEQUENCE_LOCKTIME_MASK << SEQUENCE_LOCKTIME_GRANULARITY
return uint(seconds), nil
}
return 0, fmt.Errorf("sequence is encoded as block number")
return uint(asNumber), nil
}

View File

@@ -1,43 +0,0 @@
package common_test
import (
"encoding/json"
"os"
"testing"
sdk "github.com/ark-network/ark/common"
"github.com/stretchr/testify/require"
)
func TestBIP68(t *testing.T) {
data, err := os.ReadFile("fixtures/bip68.json")
require.NoError(t, err)
var testCases []struct {
Input uint `json:"seconds"`
Expected int64 `json:"sequence"`
Desc string `json:"description"`
}
err = json.Unmarshal(data, &testCases)
require.NoError(t, err)
require.NotEmpty(t, testCases)
for _, tc := range testCases {
t.Run(tc.Desc, func(t *testing.T) {
actual, err := sdk.BIP68Encode(tc.Input)
require.NoError(t, err)
var asNumber int64
for i := len(actual) - 1; i >= 0; i-- {
asNumber = asNumber<<8 | int64(actual[i])
}
require.Equal(t, tc.Expected, asNumber)
decoded, err := sdk.BIP68Decode(actual)
require.NoError(t, err)
require.Equal(t, tc.Input, decoded)
})
}
}

View File

@@ -109,9 +109,12 @@ func (d *CSVSigClosure) Decode(script []byte) (bool, error) {
return false, nil
}
sequence := script[1:csvIndex]
sequence := script[:csvIndex]
if len(sequence) > 1 {
sequence = sequence[1:]
}
seconds, err := common.BIP68Decode(sequence)
seconds, err := common.BIP68DecodeSequence(sequence)
if err != nil {
return false, err
}
@@ -162,15 +165,18 @@ func decodeChecksigScript(script []byte) (bool, *secp256k1.PublicKey, error) {
// checkSequenceVerifyScript without checksig
func encodeCsvScript(seconds uint) ([]byte, error) {
sequence, err := common.BIP68Encode(seconds)
sequence, err := common.BIP68Sequence(seconds)
if err != nil {
return nil, err
}
return txscript.NewScriptBuilder().AddData(sequence).AddOps([]byte{
txscript.OP_CHECKSEQUENCEVERIFY,
txscript.OP_DROP,
}).Script()
return txscript.NewScriptBuilder().
AddInt64(int64(sequence)).
AddOps([]byte{
txscript.OP_CHECKSEQUENCEVERIFY,
txscript.OP_DROP,
}).
Script()
}
// checkSequenceVerifyScript + checksig

View File

@@ -0,0 +1,30 @@
package bitcointree_test
import (
"testing"
"github.com/ark-network/ark/common/bitcointree"
"github.com/decred/dcrd/dcrec/secp256k1/v4"
"github.com/stretchr/testify/require"
)
func TestRoundTripCSV(t *testing.T) {
seckey, err := secp256k1.GeneratePrivateKey()
require.NoError(t, err)
csvSig := &bitcointree.CSVSigClosure{
Pubkey: seckey.PubKey(),
Seconds: 1024,
}
leaf, err := csvSig.Leaf()
require.NoError(t, err)
var cl bitcointree.CSVSigClosure
valid, err := cl.Decode(leaf.Script)
require.NoError(t, err)
require.True(t, valid)
require.Equal(t, csvSig.Seconds, cl.Seconds)
}

View File

@@ -130,15 +130,18 @@ func (e *Older) Parse(policy string) error {
}
func (e *Older) Script(bool) (string, error) {
sequence, err := common.BIP68Encode(e.Timeout)
sequence, err := common.BIP68Sequence(e.Timeout)
if err != nil {
return "", err
}
script, err := txscript.NewScriptBuilder().AddData(sequence).AddOps([]byte{
txscript.OP_CHECKSEQUENCEVERIFY,
txscript.OP_DROP,
}).Script()
script, err := txscript.NewScriptBuilder().
AddInt64(int64(sequence)).
AddOps([]byte{
txscript.OP_CHECKSEQUENCEVERIFY,
txscript.OP_DROP,
}).
Script()
if err != nil {
return "", err
}

View File

@@ -1,67 +0,0 @@
[
{
"description": "0x00400000 (00000000010000000000000000000000)",
"seconds": 0,
"sequence": 4194304
},
{
"description": "0x00400001 (00000000010000000000000000000001)",
"seconds": 512,
"sequence": 4194305
},
{
"description": "0x00400002 (00000000010000000000000000000010)",
"seconds": 1024,
"sequence": 4194306
},
{
"description": "0x00400003 (00000000010000000000000000000011)",
"seconds": 1536,
"sequence": 4194307
},
{
"description": "0x00400004 (00000000010000000000000000000100)",
"seconds": 2048,
"sequence": 4194308
},
{
"description": "0x00400005 (00000000010000000000000000000101)",
"seconds": 2560,
"sequence": 4194309
},
{
"description": "0x00400006 (00000000010000000000000000000110)",
"seconds": 3072,
"sequence": 4194310
},
{
"description": "0x00400007 (00000000010000000000000000000111)",
"seconds": 3584,
"sequence": 4194311
},
{
"description": "0x00400008 (00000000010000000000000000001000)",
"seconds": 4096,
"sequence": 4194312
},
{
"description": "0x00400009 (00000000010000000000000000001001)",
"seconds": 4608,
"sequence": 4194313
},
{
"description": "0x0040000a (00000000010000000000000000001010)",
"seconds": 5120,
"sequence": 4194314
},
{
"description": "0x0040000b (00000000010000000000000000001011)",
"seconds": 5632,
"sequence": 4194315
},
{
"description": "0x0040000c (00000000010000000000000000001100)",
"seconds": 6144,
"sequence": 4194316
}
]

View File

@@ -130,9 +130,12 @@ func (d *CSVSigClosure) Decode(script []byte) (bool, error) {
return false, nil
}
sequence := script[1:csvIndex]
sequence := script[:csvIndex]
if len(sequence) > 1 {
sequence = sequence[1:]
}
seconds, err := common.BIP68Decode(sequence)
seconds, err := common.BIP68DecodeSequence(sequence)
if err != nil {
return false, err
}
@@ -369,15 +372,18 @@ func decodeChecksigScript(script []byte) (bool, *secp256k1.PublicKey, error) {
// checkSequenceVerifyScript without checksig
func encodeCsvScript(seconds uint) ([]byte, error) {
sequence, err := common.BIP68Encode(seconds)
sequence, err := common.BIP68Sequence(seconds)
if err != nil {
return nil, err
}
return txscript.NewScriptBuilder().AddData(sequence).AddOps([]byte{
txscript.OP_CHECKSEQUENCEVERIFY,
txscript.OP_DROP,
}).Script()
return txscript.NewScriptBuilder().
AddInt64(int64(sequence)).
AddOps([]byte{
txscript.OP_CHECKSEQUENCEVERIFY,
txscript.OP_DROP,
}).
Script()
}
// checkSequenceVerifyScript + checksig

View File

@@ -9,12 +9,13 @@ services:
- ARK_ROUND_INTERVAL=10
- ARK_NETWORK=regtest
- ARK_LOG_LEVEL=5
- ARK_ROUND_LIFETIME=512
- ARK_ROUND_LIFETIME=20
- ARK_TX_BUILDER_TYPE=covenantless
- ARK_ESPLORA_URL=http://chopsticks:3000
- ARK_BITCOIND_RPC_USER=admin1
- ARK_BITCOIND_RPC_PASS=123
- ARK_BITCOIND_RPC_HOST=bitcoin:18443
- ARK_SCHEDULER_TYPE=block
- ARK_NO_TLS=true
- ARK_NO_MACAROONS=true
- ARK_DATADIR=/app/data

View File

@@ -31,7 +31,9 @@ services:
- ARK_ROUND_INTERVAL=10
- ARK_NETWORK=liquidregtest
- ARK_LOG_LEVEL=5
- ARK_ROUND_LIFETIME=512
- ARK_ESPLORA_URL=http://chopsticks-liquid:3000
- ARK_ROUND_LIFETIME=20
- ARK_SCHEDULER_TYPE=block
- ARK_DB_TYPE=sqlite
- ARK_TX_BUILDER_TYPE=covenant
- ARK_PORT=6060

View File

@@ -35,7 +35,7 @@ type Utxo struct {
}
func (u *Utxo) Sequence() (uint32, error) {
return common.BIP68EncodeAsNumber(u.Delay)
return common.BIP68Sequence(u.Delay)
}
func newUtxo(explorerUtxo ExplorerUtxo, delay uint) Utxo {

View File

@@ -23,8 +23,8 @@ help:
## intergrationtest: runs integration tests
integrationtest:
@echo "Running integration tests..."
@go test -v -count 1 -timeout 300s github.com/ark-network/ark/server/test/e2e/covenant
@go test -v -count 1 -timeout 300s github.com/ark-network/ark/server/test/e2e/covenantless
@go test -v -count 1 -timeout 400s github.com/ark-network/ark/server/test/e2e/covenant
@go test -v -count 1 -timeout 400s github.com/ark-network/ark/server/test/e2e/covenantless
## lint: lint codebase
lint:

View File

@@ -8,7 +8,8 @@ import (
"github.com/ark-network/ark/server/internal/core/application"
"github.com/ark-network/ark/server/internal/core/ports"
"github.com/ark-network/ark/server/internal/infrastructure/db"
scheduler "github.com/ark-network/ark/server/internal/infrastructure/scheduler/gocron"
blockscheduler "github.com/ark-network/ark/server/internal/infrastructure/scheduler/block"
timescheduler "github.com/ark-network/ark/server/internal/infrastructure/scheduler/gocron"
txbuilder "github.com/ark-network/ark/server/internal/infrastructure/tx-builder/covenant"
cltxbuilder "github.com/ark-network/ark/server/internal/infrastructure/tx-builder/covenantless"
fileunlocker "github.com/ark-network/ark/server/internal/infrastructure/unlocker/file"
@@ -29,6 +30,7 @@ var (
}
supportedSchedulers = supportedType{
"gocron": {},
"block": {},
}
supportedTxBuilders = supportedType{
"covenant": {},
@@ -115,11 +117,23 @@ func (c *Config) Validate() error {
if len(c.WalletAddr) <= 0 {
return fmt.Errorf("missing onchain wallet address")
}
// round life time must be a multiple of 512
if c.RoundLifetime < minAllowedSequence {
return fmt.Errorf(
"invalid round lifetime, must be a at least %d", minAllowedSequence,
)
if c.SchedulerType != "block" {
return fmt.Errorf("scheduler type must be block if round lifetime is expressed in blocks")
}
} else {
if c.SchedulerType != "gocron" {
return fmt.Errorf("scheduler type must be gocron if round lifetime is expressed in seconds")
}
// round life time must be a multiple of 512 if expressed in seconds
if c.RoundLifetime%minAllowedSequence != 0 {
c.RoundLifetime -= c.RoundLifetime % minAllowedSequence
log.Infof(
"round lifetime must be a multiple of %d, rounded to %d",
minAllowedSequence, c.RoundLifetime,
)
}
}
if c.UnilateralExitDelay < minAllowedSequence {
@@ -134,14 +148,6 @@ func (c *Config) Validate() error {
)
}
if c.RoundLifetime%minAllowedSequence != 0 {
c.RoundLifetime -= c.RoundLifetime % minAllowedSequence
log.Infof(
"round lifetime must be a multiple of %d, rounded to %d",
minAllowedSequence, c.RoundLifetime,
)
}
if c.UnilateralExitDelay%minAllowedSequence != 0 {
c.UnilateralExitDelay -= c.UnilateralExitDelay % minAllowedSequence
log.Infof(
@@ -328,7 +334,9 @@ func (c *Config) schedulerService() error {
var err error
switch c.SchedulerType {
case "gocron":
svc = scheduler.NewScheduler()
svc = timescheduler.NewScheduler()
case "block":
svc, err = blockscheduler.NewScheduler(c.EsploraURL)
default:
err = fmt.Errorf("unknown scheduler type")
}
@@ -367,7 +375,12 @@ func (c *Config) appService() error {
}
func (c *Config) adminService() error {
c.adminSvc = application.NewAdminService(c.wallet, c.repo, c.txBuilder)
unit := ports.UnixTime
if c.RoundLifetime < minAllowedSequence {
unit = ports.BlockHeight
}
c.adminSvc = application.NewAdminService(c.wallet, c.repo, c.txBuilder, unit)
return nil
}

View File

@@ -50,16 +50,18 @@ type AdminService interface {
}
type adminService struct {
walletSvc ports.WalletService
repoManager ports.RepoManager
txBuilder ports.TxBuilder
walletSvc ports.WalletService
repoManager ports.RepoManager
txBuilder ports.TxBuilder
sweeperTimeUnit ports.TimeUnit
}
func NewAdminService(walletSvc ports.WalletService, repoManager ports.RepoManager, txBuilder ports.TxBuilder) AdminService {
func NewAdminService(walletSvc ports.WalletService, repoManager ports.RepoManager, txBuilder ports.TxBuilder, timeUnit ports.TimeUnit) AdminService {
return &adminService{
walletSvc: walletSvc,
repoManager: repoManager,
txBuilder: txBuilder,
walletSvc: walletSvc,
repoManager: repoManager,
txBuilder: txBuilder,
sweeperTimeUnit: timeUnit,
}
}
@@ -130,7 +132,7 @@ func (a *adminService) GetScheduledSweeps(ctx context.Context) ([]ScheduledSweep
for _, round := range sweepableRounds {
sweepable, err := findSweepableOutputs(
ctx, a.walletSvc, a.txBuilder, round.CongestionTree,
ctx, a.walletSvc, a.txBuilder, a.sweeperTimeUnit, round.CongestionTree,
)
if err != nil {
return nil, err

View File

@@ -177,7 +177,7 @@ func (s *covenantService) SpendVtxos(ctx context.Context, inputs []ports.Input)
return "", fmt.Errorf("failed to parse tx %s: %s", input.Txid, err)
}
confirmed, blocktime, err := s.wallet.IsTransactionConfirmed(ctx, input.Txid)
confirmed, _, blocktime, err := s.wallet.IsTransactionConfirmed(ctx, input.Txid)
if err != nil {
return "", fmt.Errorf("failed to check tx %s: %s", input.Txid, err)
}
@@ -910,12 +910,10 @@ func (s *covenantService) scheduleSweepVtxosForRound(round *domain.Round) {
return
}
expirationTimestamp := time.Now().Add(
time.Duration(s.roundLifetime+30) * time.Second,
)
expirationTime := s.sweeper.scheduler.AddNow(s.roundLifetime)
if err := s.sweeper.schedule(
expirationTimestamp.Unix(), round.Txid, round.CongestionTree,
expirationTime, round.Txid, round.CongestionTree,
); err != nil {
log.WithError(err).Warn("failed to schedule sweep tx")
}

View File

@@ -421,7 +421,7 @@ func (s *covenantlessService) SpendVtxos(ctx context.Context, inputs []ports.Inp
return "", fmt.Errorf("failed to deserialize tx %s: %s", input.Txid, err)
}
confirmed, blocktime, err := s.wallet.IsTransactionConfirmed(ctx, input.Txid)
confirmed, _, blocktime, err := s.wallet.IsTransactionConfirmed(ctx, input.Txid)
if err != nil {
return "", fmt.Errorf("failed to check tx %s: %s", input.Txid, err)
}
@@ -1316,13 +1316,9 @@ func (s *covenantlessService) scheduleSweepVtxosForRound(round *domain.Round) {
return
}
expirationTimestamp := time.Now().Add(
time.Duration(s.roundLifetime+30) * time.Second,
)
expirationTimestamp := s.sweeper.scheduler.AddNow(s.roundLifetime)
if err := s.sweeper.schedule(
expirationTimestamp.Unix(), round.Txid, round.CongestionTree,
); err != nil {
if err := s.sweeper.schedule(expirationTimestamp, round.Txid, round.CongestionTree); err != nil {
log.WithError(err).Warn("failed to schedule sweep tx")
}
}

View File

@@ -3,6 +3,7 @@ package application
import (
"context"
"fmt"
"sync"
"time"
"github.com/ark-network/ark/common/tree"
@@ -22,6 +23,7 @@ type sweeper struct {
scheduler ports.SchedulerService
// cache of scheduled tasks, avoid scheduling the same sweep event multiple times
locker sync.Locker
scheduledTasks map[string]struct{}
}
@@ -36,6 +38,7 @@ func newSweeper(
repoManager,
builder,
scheduler,
&sync.Mutex{},
make(map[string]struct{}),
}
}
@@ -62,6 +65,8 @@ func (s *sweeper) stop() {
// removeTask update the cached map of scheduled tasks
func (s *sweeper) removeTask(treeRootTxid string) {
s.locker.Lock()
defer s.locker.Unlock()
delete(s.scheduledTasks, treeRootTxid)
}
@@ -84,13 +89,22 @@ func (s *sweeper) schedule(
}
task := s.createTask(roundTxid, congestionTree)
fancyTime := time.Unix(expirationTimestamp, 0).Format("2006-01-02 15:04:05")
var fancyTime string
if s.scheduler.Unit() == ports.UnixTime {
fancyTime = time.Unix(expirationTimestamp, 0).Format("2006-01-02 15:04:05")
} else {
fancyTime = fmt.Sprintf("block %d", expirationTimestamp)
}
log.Debugf("scheduled sweep for round %s at %s", roundTxid, fancyTime)
if err := s.scheduler.ScheduleTaskOnce(expirationTimestamp, task); err != nil {
return err
}
s.locker.Lock()
s.scheduledTasks[root.Txid] = struct{}{}
s.locker.Unlock()
if err := s.updateVtxoExpirationTime(congestionTree, expirationTimestamp); err != nil {
log.WithError(err).Error("error while updating vtxo expiration time")
@@ -120,7 +134,7 @@ func (s *sweeper) createTask(
vtxoKeys := make([]domain.VtxoKey, 0) // vtxos associated to the sweep inputs
// inspect the congestion tree to find onchain shared outputs
sharedOutputs, err := findSweepableOutputs(ctx, s.wallet, s.builder, congestionTree)
sharedOutputs, err := findSweepableOutputs(ctx, s.wallet, s.builder, s.scheduler.Unit(), congestionTree)
if err != nil {
log.WithError(err).Error("error while inspecting congestion tree")
return
@@ -128,7 +142,7 @@ func (s *sweeper) createTask(
for expiredAt, inputs := range sharedOutputs {
// if the shared outputs are not expired, schedule a sweep task for it
if time.Unix(expiredAt, 0).After(time.Now()) {
if s.scheduler.AfterNow(expiredAt) {
subtrees, err := computeSubTrees(congestionTree, inputs)
if err != nil {
log.WithError(err).Error("error while computing subtrees")
@@ -136,8 +150,7 @@ func (s *sweeper) createTask(
}
for _, subTree := range subtrees {
// mitigate the risk to get BIP68 non-final errors by scheduling the task 30 seconds after the expiration time
if err := s.schedule(int64(expiredAt), roundTxid, subTree); err != nil {
if err := s.schedule(expiredAt, roundTxid, subTree); err != nil {
log.WithError(err).Error("error while scheduling sweep task")
continue
}

View File

@@ -259,17 +259,18 @@ func findSweepableOutputs(
ctx context.Context,
walletSvc ports.WalletService,
txbuilder ports.TxBuilder,
schedulerUnit ports.TimeUnit,
congestionTree tree.CongestionTree,
) (map[int64][]ports.SweepInput, error) {
sweepableOutputs := make(map[int64][]ports.SweepInput)
blocktimeCache := make(map[string]int64) // txid -> blocktime
blocktimeCache := make(map[string]int64) // txid -> blocktime / blockheight
nodesToCheck := congestionTree[0] // init with the root
for len(nodesToCheck) > 0 {
newNodesToCheck := make([]tree.Node, 0)
for _, node := range nodesToCheck {
isConfirmed, blocktime, err := walletSvc.IsTransactionConfirmed(ctx, node.Txid)
isConfirmed, height, blocktime, err := walletSvc.IsTransactionConfirmed(ctx, node.Txid)
if err != nil {
return nil, err
}
@@ -279,21 +280,31 @@ func findSweepableOutputs(
if !isConfirmed {
if _, ok := blocktimeCache[node.ParentTxid]; !ok {
isConfirmed, blocktime, err := walletSvc.IsTransactionConfirmed(ctx, node.ParentTxid)
isConfirmed, height, blocktime, err := walletSvc.IsTransactionConfirmed(ctx, node.ParentTxid)
if !isConfirmed || err != nil {
return nil, fmt.Errorf("tx %s not found", node.ParentTxid)
}
blocktimeCache[node.ParentTxid] = blocktime
if schedulerUnit == ports.BlockHeight {
blocktimeCache[node.ParentTxid] = height
} else {
blocktimeCache[node.ParentTxid] = blocktime
}
}
expirationTime, sweepInput, err = txbuilder.GetSweepInput(blocktimeCache[node.ParentTxid], node)
var lifetime int64
lifetime, sweepInput, err = txbuilder.GetSweepInput(node)
if err != nil {
return nil, err
}
expirationTime = blocktimeCache[node.ParentTxid] + lifetime
} else {
// cache the blocktime for future use
blocktimeCache[node.Txid] = int64(blocktime)
if schedulerUnit == ports.BlockHeight {
blocktimeCache[node.Txid] = height
} else {
blocktimeCache[node.Txid] = blocktime
}
// if the tx is onchain, it means that the input is spent
// add the children to the nodes in order to check them during the next iteration

View File

@@ -14,5 +14,5 @@ type BlockchainScanner interface {
WatchScripts(ctx context.Context, scripts []string) error
UnwatchScripts(ctx context.Context, scripts []string) error
GetNotificationChannel(ctx context.Context) <-chan map[string][]VtxoWithValue
IsTransactionConfirmed(ctx context.Context, txid string) (isConfirmed bool, blocktime int64, err error)
IsTransactionConfirmed(ctx context.Context, txid string) (isConfirmed bool, blocknumber int64, blocktime int64, err error)
}

View File

@@ -1,9 +1,18 @@
package ports
type TimeUnit int
const (
UnixTime TimeUnit = iota
BlockHeight
)
type SchedulerService interface {
Start()
Stop()
ScheduleTask(interval int64, immediate bool, task func()) error
ScheduleTaskOnce(delay int64, task func()) error
Unit() TimeUnit
AddNow(lifetime int64) int64
AfterNow(expiry int64) bool
ScheduleTaskOnce(at int64, task func()) error
}

View File

@@ -34,7 +34,7 @@ type TxBuilder interface {
) (roundTx string, congestionTree tree.CongestionTree, connectorAddress string, err error)
BuildForfeitTxs(poolTx string, payments []domain.Payment, minRelayFeeRate chainfee.SatPerKVByte) (connectors []string, forfeitTxs []string, err error)
BuildSweepTx(inputs []SweepInput) (signedSweepTx string, err error)
GetSweepInput(parentblocktime int64, node tree.Node) (expirationtime int64, sweepInput SweepInput, err error)
GetSweepInput(node tree.Node) (lifetime int64, sweepInput SweepInput, err error)
FinalizeAndExtract(tx string) (txhex string, err error)
VerifyTapscriptPartialSigs(tx string) (valid bool, txid string, err error)
// FindLeaves returns all the leaves txs that are reachable from the given outpoint

View File

@@ -117,7 +117,7 @@ func (r *vtxoRepository) GetAllVtxos(
spentVtxos := make([]domain.Vtxo, 0, len(vtxos))
unspentVtxos := make([]domain.Vtxo, 0, len(vtxos))
for _, vtxo := range vtxos {
if vtxo.Spent {
if vtxo.Spent || vtxo.Swept {
spentVtxos = append(spentVtxos, vtxo)
} else {
unspentVtxos = append(unspentVtxos, vtxo)

View File

@@ -113,7 +113,7 @@ func (v *vxtoRepository) GetAllVtxos(ctx context.Context, pubkey string) ([]doma
spentVtxos := make([]domain.Vtxo, 0)
for _, vtxo := range vtxos {
if vtxo.Spent {
if vtxo.Spent || vtxo.Swept {
spentVtxos = append(spentVtxos, vtxo)
} else {
unspentVtxos = append(unspentVtxos, vtxo)

View File

@@ -0,0 +1,147 @@
package blockscheduler
import (
"fmt"
"net/http"
"net/url"
"sync"
"time"
"github.com/ark-network/ark/server/internal/core/ports"
"github.com/sirupsen/logrus"
)
const tipHeightEndpoit = "/blocks/tip/height"
type service struct {
tipURL string
lock sync.Locker
taskes map[int64][]func()
stopCh chan struct{}
}
func NewScheduler(esploraURL string) (ports.SchedulerService, error) {
if len(esploraURL) == 0 {
return nil, fmt.Errorf("esplora URL is required")
}
tipURL, err := url.JoinPath(esploraURL, tipHeightEndpoit)
if err != nil {
return nil, err
}
return &service{
tipURL,
&sync.Mutex{},
make(map[int64][]func()),
make(chan struct{}),
}, nil
}
func (s *service) Start() {
go func() {
for {
select {
case <-s.stopCh:
return
default:
time.Sleep(10 * time.Second)
taskes, err := s.popTaskes()
if err != nil {
fmt.Println("error fetching tasks:", err)
continue
}
logrus.Debugf("fetched %d tasks", len(taskes))
for _, task := range taskes {
go task()
}
}
}
}()
}
func (s *service) Stop() {
s.stopCh <- struct{}{}
close(s.stopCh)
}
func (s *service) Unit() ports.TimeUnit {
return ports.BlockHeight
}
func (s *service) AddNow(lifetime int64) int64 {
tip, err := s.fetchTipHeight()
if err != nil {
return 0
}
return tip + lifetime
}
func (s *service) AfterNow(expiry int64) bool {
tip, err := s.fetchTipHeight()
if err != nil {
return false
}
return expiry > tip
}
func (s *service) ScheduleTaskOnce(at int64, task func()) error {
s.lock.Lock()
defer s.lock.Unlock()
if _, ok := s.taskes[at]; !ok {
s.taskes[at] = make([]func(), 0)
}
s.taskes[at] = append(s.taskes[at], task)
return nil
}
func (s *service) popTaskes() ([]func(), error) {
s.lock.Lock()
defer s.lock.Unlock()
tip, err := s.fetchTipHeight()
if err != nil {
return nil, err
}
taskes := make([]func(), 0)
for height, tasks := range s.taskes {
if height > tip {
continue
}
taskes = append(taskes, tasks...)
delete(s.taskes, height)
}
return taskes, nil
}
func (s *service) fetchTipHeight() (int64, error) {
resp, err := http.Get(s.tipURL)
if err != nil {
return 0, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return 0, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
}
var tip int64
if _, err := fmt.Fscanf(resp.Body, "%d", &tip); err != nil {
return 0, err
}
logrus.Debugf("fetching tip height from %s, got %d", s.tipURL, tip)
return tip, nil
}

View File

@@ -1,4 +1,4 @@
package scheduler
package timescheduler
import (
"fmt"
@@ -17,6 +17,18 @@ func NewScheduler() ports.SchedulerService {
return &service{svc}
}
func (s *service) Unit() ports.TimeUnit {
return ports.UnixTime
}
func (s *service) AddNow(lifetime int64) int64 {
return time.Now().Add(time.Duration(lifetime) * time.Second).Unix()
}
func (s *service) AfterNow(expiry int64) bool {
return time.Unix(expiry, 0).After(time.Now())
}
func (s *service) Start() {
s.scheduler.StartAsync()
}
@@ -25,15 +37,6 @@ func (s *service) Stop() {
s.scheduler.Stop()
}
func (s *service) ScheduleTask(interval int64, immediate bool, task func()) error {
if immediate {
_, err := s.scheduler.Every(int(interval)).Seconds().Do(task)
return err
}
_, err := s.scheduler.Every(int(interval)).Seconds().WaitForSchedule().Do(task)
return err
}
func (s *service) ScheduleTaskOnce(at int64, task func()) error {
delay := at - time.Now().Unix()
if delay < 0 {

View File

@@ -195,7 +195,7 @@ func (b *txBuilder) BuildRoundTx(
return
}
func (b *txBuilder) GetSweepInput(parentblocktime int64, node tree.Node) (expirationtime int64, sweepInput ports.SweepInput, err error) {
func (b *txBuilder) GetSweepInput(node tree.Node) (lifetime int64, sweepInput ports.SweepInput, err error) {
pset, err := psetv2.NewPsetFromBase64(node.Tx)
if err != nil {
return -1, nil, err
@@ -215,8 +215,6 @@ func (b *txBuilder) GetSweepInput(parentblocktime int64, node tree.Node) (expira
return -1, nil, err
}
expirationTime := parentblocktime + lifetime
txhex, err := b.wallet.GetTransaction(context.Background(), txid)
if err != nil {
return -1, nil, err
@@ -241,7 +239,7 @@ func (b *txBuilder) GetSweepInput(parentblocktime int64, node tree.Node) (expira
amount: inputValue,
}
return expirationTime, sweepInput, nil
return lifetime, sweepInput, nil
}
func (b *txBuilder) VerifyTapscriptPartialSigs(tx string) (bool, string, error) {

View File

@@ -151,7 +151,7 @@ func (m *mockedWallet) GetDustAmount(ctx context.Context) (uint64, error) {
return res, args.Error(1)
}
func (m *mockedWallet) IsTransactionConfirmed(ctx context.Context, txid string) (bool, int64, error) {
func (m *mockedWallet) IsTransactionConfirmed(ctx context.Context, txid string) (bool, int64, int64, error) {
args := m.Called(ctx, txid)
var res bool
@@ -159,12 +159,17 @@ func (m *mockedWallet) IsTransactionConfirmed(ctx context.Context, txid string)
res = a.(bool)
}
var height int64
if h := args.Get(1); h != nil {
height = h.(int64)
}
var blocktime int64
if b := args.Get(1); b != nil {
blocktime = b.(int64)
}
return res, blocktime, args.Error(2)
return res, height, blocktime, args.Error(2)
}
func (m *mockedWallet) SignTransactionTapscript(ctx context.Context, pset string, inputIndexes []int) (string, error) {

View File

@@ -90,7 +90,7 @@ func sweepTransaction(
return nil, err
}
sequence, err := common.BIP68EncodeAsNumber(sweepClosure.Seconds)
sequence, err := common.BIP68Sequence(sweepClosure.Seconds)
if err != nil {
return nil, err
}

View File

@@ -324,7 +324,7 @@ func (b *txBuilder) BuildRoundTx(
return
}
func (b *txBuilder) GetSweepInput(parentblocktime int64, node tree.Node) (expirationtime int64, sweepInput ports.SweepInput, err error) {
func (b *txBuilder) GetSweepInput(node tree.Node) (lifetime int64, sweepInput ports.SweepInput, err error) {
partialTx, err := psbt.NewFromRawBytes(strings.NewReader(node.Tx), true)
if err != nil {
return -1, nil, err
@@ -343,8 +343,6 @@ func (b *txBuilder) GetSweepInput(parentblocktime int64, node tree.Node) (expira
return -1, nil, err
}
expirationTime := parentblocktime + lifetime
txhex, err := b.wallet.GetTransaction(context.Background(), txid.String())
if err != nil {
return -1, nil, err
@@ -365,7 +363,7 @@ func (b *txBuilder) GetSweepInput(parentblocktime int64, node tree.Node) (expira
amount: tx.TxOut[index].Value,
}
return expirationTime, sweepInput, nil
return lifetime, sweepInput, nil
}
func (b *txBuilder) FindLeaves(congestionTree tree.CongestionTree, fromtxid string, vout uint32) ([]tree.Node, error) {
@@ -1220,6 +1218,8 @@ func extractSweepLeaf(input psbt.PInput) (sweepLeaf *psbt.TaprootTapLeafScript,
if err != nil {
return nil, nil, 0, err
}
fmt.Println("closure", valid)
if valid && closure.Seconds > 0 {
sweepLeaf = leaf
lifetime = int64(closure.Seconds)

View File

@@ -172,7 +172,7 @@ func (m *mockedWallet) GetDustAmount(ctx context.Context) (uint64, error) {
return res, args.Error(1)
}
func (m *mockedWallet) IsTransactionConfirmed(ctx context.Context, txid string) (bool, int64, error) {
func (m *mockedWallet) IsTransactionConfirmed(ctx context.Context, txid string) (bool, int64, int64, error) {
args := m.Called(ctx, txid)
var res bool
@@ -180,12 +180,17 @@ func (m *mockedWallet) IsTransactionConfirmed(ctx context.Context, txid string)
res = a.(bool)
}
var height int64
if h := args.Get(1); h != nil {
height = h.(int64)
}
var blocktime int64
if b := args.Get(1); b != nil {
blocktime = b.(int64)
}
return res, blocktime, args.Error(2)
return res, height, blocktime, args.Error(2)
}
func (m *mockedWallet) SignTransactionTapscript(ctx context.Context, pset string, inputIndexes []int) (string, error) {

View File

@@ -37,7 +37,7 @@ func sweepTransaction(
return nil, fmt.Errorf("invalid csv script")
}
sequence, err := common.BIP68EncodeAsNumber(sweepClosure.Seconds)
sequence, err := common.BIP68Sequence(sweepClosure.Seconds)
if err != nil {
return nil, err
}

View File

@@ -33,21 +33,31 @@ func (b *bitcoindRPCClient) broadcast(txhex string) error {
return nil
}
func (b *bitcoindRPCClient) getTxStatus(txid string) (isConfirmed bool, blocktime int64, err error) {
func (b *bitcoindRPCClient) getTxStatus(txid string) (isConfirmed bool, height, blocktime int64, err error) {
txhash, err := chainhash.NewHashFromStr(txid)
if err != nil {
return false, 0, err
return false, 0, 0, err
}
tx, err := b.chainClient.GetRawTransactionVerbose(txhash)
if err != nil {
if strings.Contains(err.Error(), "No such mempool or blockchain transaction") {
return false, 0, nil
return false, 0, 0, nil
}
return false, 0, err
return false, 0, 0, err
}
return tx.Confirmations > 0, tx.Blocktime, nil
blockHash, err := chainhash.NewHashFromStr(tx.BlockHash)
if err != nil {
return false, 0, 0, err
}
blockHeight, err := b.chainClient.GetBlockHeight(blockHash)
if err != nil {
return false, 0, 0, err
}
return tx.Confirmations > 0, int64(blockHeight), tx.Blocktime, nil
}
func (b *bitcoindRPCClient) getTx(txid string) (*wire.MsgTx, error) {

View File

@@ -20,8 +20,9 @@ type esploraClient struct {
type esploraTx struct {
Status struct {
Confirmed bool `json:"confirmed"`
BlockTime int64 `json:"block_time"`
Confirmed bool `json:"confirmed"`
BlockTime int64 `json:"block_time"`
BlockNumber int64 `json:"block_height"`
} `json:"status"`
}
@@ -79,30 +80,30 @@ func (f *esploraClient) getTx(txid string) (*wire.MsgTx, error) {
return &tx, nil
}
func (f *esploraClient) getTxStatus(txid string) (isConfirmed bool, blocktime int64, err error) {
func (f *esploraClient) getTxStatus(txid string) (isConfirmed bool, blocknumber, blocktime int64, err error) {
endpoint, err := url.JoinPath(f.url, "tx", txid)
if err != nil {
return false, 0, err
return false, 0, 0, err
}
resp, err := http.DefaultClient.Get(endpoint)
if err != nil {
return false, 0, err
return false, 0, 0, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return false, 0, err
return false, 0, 0, err
}
var response esploraTx
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
return false, 0, err
return false, 0, 0, err
}
return response.Status.Confirmed, response.Status.BlockTime, nil
return response.Status.Confirmed, response.Status.BlockNumber, response.Status.BlockTime, nil
}
// GetFeeMap returns a map of sat/vbyte fees for different confirmation targets

View File

@@ -92,7 +92,7 @@ var (
// add additional chain API not supported by the chain.Interface type
type extraChainAPI interface {
getTx(txid string) (*wire.MsgTx, error)
getTxStatus(txid string) (isConfirmed bool, blocktime int64, err error)
getTxStatus(txid string) (isConfirmed bool, blockHeight, blocktime int64, err error)
broadcast(txHex string) error
}
@@ -957,7 +957,7 @@ func (s *service) GetNotificationChannel(
func (s *service) IsTransactionConfirmed(
ctx context.Context, txid string,
) (isConfirmed bool, blocktime int64, err error) {
) (isConfirmed bool, blocknumber int64, blocktime int64, err error) {
return s.extraAPI.getTxStatus(txid)
}

View File

@@ -158,21 +158,21 @@ func (s *service) BroadcastTransaction(
func (s *service) IsTransactionConfirmed(
ctx context.Context, txid string,
) (bool, int64, error) {
_, isConfirmed, blocktime, err := s.getTransaction(ctx, txid)
) (bool, int64, int64, error) {
_, isConfirmed, blockheight, blocktime, err := s.getTransaction(ctx, txid)
if err != nil {
if strings.Contains(strings.ToLower(err.Error()), "missing transaction") {
return isConfirmed, 0, nil
return isConfirmed, 0, 0, nil
}
return false, 0, err
return false, 0, 0, err
}
return isConfirmed, blocktime, nil
return isConfirmed, blockheight, blocktime, nil
}
func (s *service) WaitForSync(ctx context.Context, txid string) error {
for {
time.Sleep(5 * time.Second)
_, _, _, err := s.getTransaction(ctx, txid)
_, _, _, _, err := s.getTransaction(ctx, txid)
if err != nil {
if strings.Contains(strings.ToLower(err.Error()), "missing transaction") {
continue
@@ -351,7 +351,7 @@ func (s *service) EstimateFees(
}
func (s *service) GetTransaction(ctx context.Context, txid string) (string, error) {
txHex, _, _, err := s.getTransaction(ctx, txid)
txHex, _, _, _, err := s.getTransaction(ctx, txid)
if err != nil {
return "", err
}
@@ -361,18 +361,18 @@ func (s *service) GetTransaction(ctx context.Context, txid string) (string, erro
func (s *service) getTransaction(
ctx context.Context, txid string,
) (string, bool, int64, error) {
) (string, bool, int64, int64, error) {
res, err := s.txClient.GetTransaction(ctx, &pb.GetTransactionRequest{
Txid: txid,
})
if err != nil {
return "", false, 0, err
return "", false, 0, 0, err
}
if res.GetBlockDetails().GetTimestamp() > 0 {
return res.GetTxHex(), true, res.BlockDetails.GetTimestamp(), nil
return res.GetTxHex(), true, int64(res.GetBlockDetails().GetHeight()), res.BlockDetails.GetTimestamp(), nil
}
// if not confirmed, we return now + 1 min to estimate the next blocktime
return res.GetTxHex(), false, time.Now().Add(time.Minute).Unix(), nil
return res.GetTxHex(), false, 0, time.Now().Add(time.Minute).Unix(), nil
}

View File

@@ -202,6 +202,36 @@ func TestReactToSpentVtxosRedemption(t *testing.T) {
require.Empty(t, balance.OnchainBalance.LockedAmount)
}
func TestSweep(t *testing.T) {
var receive utils.ArkReceive
receiveStr, err := runArkCommand("receive")
require.NoError(t, err)
err = json.Unmarshal([]byte(receiveStr), &receive)
require.NoError(t, err)
_, err = utils.RunCommand("nigiri", "faucet", "--liquid", receive.Boarding)
require.NoError(t, err)
time.Sleep(5 * time.Second)
_, err = runArkCommand("claim", "--password", utils.Password)
require.NoError(t, err)
time.Sleep(3 * time.Second)
_, err = utils.RunCommand("nigiri", "rpc", "--liquid", "generatetoaddress", "100", "el1qqwk722tghgkgmh3r2ph4d2apwj0dy9xnzlenzklx8jg3z299fpaw56trre9gpk6wmw0u4qycajqeva3t7lzp7wnacvwxha59r")
require.NoError(t, err)
time.Sleep(40 * time.Second)
var balance utils.ArkBalance
balanceStr, err := runArkCommand("balance")
require.NoError(t, err)
require.NoError(t, json.Unmarshal([]byte(balanceStr), &balance))
require.Zero(t, balance.Offchain.Total) // all funds should be swept
}
func runArkCommand(arg ...string) (string, error) {
args := append([]string{"exec", "-t", "arkd", "ark"}, arg...)
return utils.RunCommand("docker", args...)
@@ -276,20 +306,12 @@ func setupAspWallet() error {
return fmt.Errorf("failed to parse response: %s", err)
}
if _, err := utils.RunCommand("nigiri", "faucet", "--liquid", addr.Address); err != nil {
return fmt.Errorf("failed to fund wallet: %s", err)
}
if _, err := utils.RunCommand("nigiri", "faucet", "--liquid", addr.Address); err != nil {
return fmt.Errorf("failed to fund wallet: %s", err)
}
if _, err := utils.RunCommand("nigiri", "faucet", "--liquid", addr.Address); err != nil {
return fmt.Errorf("failed to fund wallet: %s", err)
}
if _, err := utils.RunCommand("nigiri", "faucet", "--liquid", addr.Address); err != nil {
return fmt.Errorf("failed to fund wallet: %s", err)
}
if _, err := utils.RunCommand("nigiri", "faucet", "--liquid", addr.Address); err != nil {
return fmt.Errorf("failed to fund wallet: %s", err)
const numberOfFaucet = 6
for i := 0; i < numberOfFaucet; i++ {
if _, err := utils.RunCommand("nigiri", "faucet", "--liquid", addr.Address); err != nil {
return fmt.Errorf("failed to fund wallet: %s with address %s", err, addr.Address)
}
}
time.Sleep(5 * time.Second)

View File

@@ -362,6 +362,36 @@ func TestAliceSeveralPaymentsToBob(t *testing.T) {
}
func TestSweep(t *testing.T) {
var receive utils.ArkReceive
receiveStr, err := runClarkCommand("receive")
require.NoError(t, err)
err = json.Unmarshal([]byte(receiveStr), &receive)
require.NoError(t, err)
_, err = utils.RunCommand("nigiri", "faucet", receive.Boarding)
require.NoError(t, err)
time.Sleep(5 * time.Second)
_, err = runClarkCommand("claim", "--password", utils.Password)
require.NoError(t, err)
time.Sleep(3 * time.Second)
_, err = utils.RunCommand("nigiri", "rpc", "generatetoaddress", "100", "bcrt1qe8eelqalnch946nzhefd5ajhgl2afjw5aegc59")
require.NoError(t, err)
time.Sleep(40 * time.Second)
var balance utils.ArkBalance
balanceStr, err := runClarkCommand("balance")
require.NoError(t, err)
require.NoError(t, json.Unmarshal([]byte(balanceStr), &balance))
require.Zero(t, balance.Offchain.Total) // all funds should be swept
}
func runClarkCommand(arg ...string) (string, error) {
args := append([]string{"exec", "-t", "clarkd", "ark"}, arg...)
return utils.RunCommand("docker", args...)

View File

@@ -35,6 +35,9 @@ func GenerateBlock() error {
if _, err := RunCommand("nigiri", "rpc", "--liquid", "generatetoaddress", "1", "el1qqwk722tghgkgmh3r2ph4d2apwj0dy9xnzlenzklx8jg3z299fpaw56trre9gpk6wmw0u4qycajqeva3t7lzp7wnacvwxha59r"); err != nil {
return err
}
if _, err := RunCommand("nigiri", "rpc", "generatetoaddress", "1", "bcrt1qe8eelqalnch946nzhefd5ajhgl2afjw5aegc59"); err != nil {
return err
}
time.Sleep(6 * time.Second)
return nil