Add integration tests for sweeping rounds (#339)

* add "block" scheduler type + sweep integration test

* increase timeout in integrationtests

* remove config logs

* rename scheduler package name

* rename package

* rename packages
This commit is contained in:
Louis Singer
2024-10-05 16:12:46 +02:00
committed by GitHub
parent 7606b4cd00
commit 0d39bb6b9f
37 changed files with 477 additions and 279 deletions

View File

@@ -1,8 +1,10 @@
package common package common
import ( import (
"encoding/hex"
"fmt" "fmt"
"github.com/btcsuite/btcd/blockchain"
"github.com/btcsuite/btcd/txscript"
) )
const ( const (
@@ -18,42 +20,33 @@ func closerToModulo512(x uint) uint {
return x - (x % 512) return x - (x % 512)
} }
func BIP68EncodeAsNumber(seconds uint) (uint32, error) { func BIP68Sequence(locktime uint) (uint32, error) {
seconds = closerToModulo512(seconds) isSeconds := locktime >= 512
if seconds > SECONDS_MAX { if isSeconds {
return 0, fmt.Errorf("seconds too large, max is %d", SECONDS_MAX) locktime = closerToModulo512(locktime)
} if locktime > SECONDS_MAX {
if seconds%SECONDS_MOD != 0 { return 0, fmt.Errorf("seconds too large, max is %d", SECONDS_MAX)
return 0, fmt.Errorf("seconds must be a multiple of %d", SECONDS_MOD) }
if locktime%SECONDS_MOD != 0 {
return 0, fmt.Errorf("seconds must be a multiple of %d", SECONDS_MOD)
}
} }
asNumber := SEQUENCE_LOCKTIME_TYPE_FLAG | (seconds >> SEQUENCE_LOCKTIME_GRANULARITY) return blockchain.LockTimeToSequence(isSeconds, uint32(locktime)), nil
return uint32(asNumber), nil
} }
// BIP68Encode returns the encoded sequence locktime for the given number of seconds. func BIP68DecodeSequence(sequence []byte) (uint, error) {
func BIP68Encode(seconds uint) ([]byte, error) { scriptNumber, err := txscript.MakeScriptNum(sequence, true, len(sequence))
asNumber, err := BIP68EncodeAsNumber(seconds)
if err != nil { if err != nil {
return nil, err return 0, err
} }
hexString := fmt.Sprintf("%x", asNumber)
reversed, err := hex.DecodeString(hexString)
if err != nil {
return nil, err
}
for i, j := 0, len(reversed)-1; i < j; i, j = i+1, j-1 {
reversed[i], reversed[j] = reversed[j], reversed[i]
}
return reversed, nil
}
func BIP68Decode(sequence []byte) (uint, error) { if scriptNumber >= txscript.OP_1 && scriptNumber <= txscript.OP_16 {
var asNumber int64 scriptNumber = scriptNumber - (txscript.OP_1 - 1)
for i := len(sequence) - 1; i >= 0; i-- {
asNumber = asNumber<<8 | int64(sequence[i])
} }
asNumber := int64(scriptNumber)
if asNumber&SEQUENCE_LOCKTIME_DISABLE_FLAG != 0 { if asNumber&SEQUENCE_LOCKTIME_DISABLE_FLAG != 0 {
return 0, fmt.Errorf("sequence is disabled") return 0, fmt.Errorf("sequence is disabled")
} }
@@ -61,5 +54,6 @@ func BIP68Decode(sequence []byte) (uint, error) {
seconds := asNumber & SEQUENCE_LOCKTIME_MASK << SEQUENCE_LOCKTIME_GRANULARITY seconds := asNumber & SEQUENCE_LOCKTIME_MASK << SEQUENCE_LOCKTIME_GRANULARITY
return uint(seconds), nil return uint(seconds), nil
} }
return 0, fmt.Errorf("sequence is encoded as block number")
return uint(asNumber), nil
} }

View File

@@ -1,43 +0,0 @@
package common_test
import (
"encoding/json"
"os"
"testing"
sdk "github.com/ark-network/ark/common"
"github.com/stretchr/testify/require"
)
func TestBIP68(t *testing.T) {
data, err := os.ReadFile("fixtures/bip68.json")
require.NoError(t, err)
var testCases []struct {
Input uint `json:"seconds"`
Expected int64 `json:"sequence"`
Desc string `json:"description"`
}
err = json.Unmarshal(data, &testCases)
require.NoError(t, err)
require.NotEmpty(t, testCases)
for _, tc := range testCases {
t.Run(tc.Desc, func(t *testing.T) {
actual, err := sdk.BIP68Encode(tc.Input)
require.NoError(t, err)
var asNumber int64
for i := len(actual) - 1; i >= 0; i-- {
asNumber = asNumber<<8 | int64(actual[i])
}
require.Equal(t, tc.Expected, asNumber)
decoded, err := sdk.BIP68Decode(actual)
require.NoError(t, err)
require.Equal(t, tc.Input, decoded)
})
}
}

View File

@@ -109,9 +109,12 @@ func (d *CSVSigClosure) Decode(script []byte) (bool, error) {
return false, nil return false, nil
} }
sequence := script[1:csvIndex] sequence := script[:csvIndex]
if len(sequence) > 1 {
sequence = sequence[1:]
}
seconds, err := common.BIP68Decode(sequence) seconds, err := common.BIP68DecodeSequence(sequence)
if err != nil { if err != nil {
return false, err return false, err
} }
@@ -162,15 +165,18 @@ func decodeChecksigScript(script []byte) (bool, *secp256k1.PublicKey, error) {
// checkSequenceVerifyScript without checksig // checkSequenceVerifyScript without checksig
func encodeCsvScript(seconds uint) ([]byte, error) { func encodeCsvScript(seconds uint) ([]byte, error) {
sequence, err := common.BIP68Encode(seconds) sequence, err := common.BIP68Sequence(seconds)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return txscript.NewScriptBuilder().AddData(sequence).AddOps([]byte{ return txscript.NewScriptBuilder().
txscript.OP_CHECKSEQUENCEVERIFY, AddInt64(int64(sequence)).
txscript.OP_DROP, AddOps([]byte{
}).Script() txscript.OP_CHECKSEQUENCEVERIFY,
txscript.OP_DROP,
}).
Script()
} }
// checkSequenceVerifyScript + checksig // checkSequenceVerifyScript + checksig

View File

@@ -0,0 +1,30 @@
package bitcointree_test
import (
"testing"
"github.com/ark-network/ark/common/bitcointree"
"github.com/decred/dcrd/dcrec/secp256k1/v4"
"github.com/stretchr/testify/require"
)
func TestRoundTripCSV(t *testing.T) {
seckey, err := secp256k1.GeneratePrivateKey()
require.NoError(t, err)
csvSig := &bitcointree.CSVSigClosure{
Pubkey: seckey.PubKey(),
Seconds: 1024,
}
leaf, err := csvSig.Leaf()
require.NoError(t, err)
var cl bitcointree.CSVSigClosure
valid, err := cl.Decode(leaf.Script)
require.NoError(t, err)
require.True(t, valid)
require.Equal(t, csvSig.Seconds, cl.Seconds)
}

View File

@@ -130,15 +130,18 @@ func (e *Older) Parse(policy string) error {
} }
func (e *Older) Script(bool) (string, error) { func (e *Older) Script(bool) (string, error) {
sequence, err := common.BIP68Encode(e.Timeout) sequence, err := common.BIP68Sequence(e.Timeout)
if err != nil { if err != nil {
return "", err return "", err
} }
script, err := txscript.NewScriptBuilder().AddData(sequence).AddOps([]byte{ script, err := txscript.NewScriptBuilder().
txscript.OP_CHECKSEQUENCEVERIFY, AddInt64(int64(sequence)).
txscript.OP_DROP, AddOps([]byte{
}).Script() txscript.OP_CHECKSEQUENCEVERIFY,
txscript.OP_DROP,
}).
Script()
if err != nil { if err != nil {
return "", err return "", err
} }

View File

@@ -1,67 +0,0 @@
[
{
"description": "0x00400000 (00000000010000000000000000000000)",
"seconds": 0,
"sequence": 4194304
},
{
"description": "0x00400001 (00000000010000000000000000000001)",
"seconds": 512,
"sequence": 4194305
},
{
"description": "0x00400002 (00000000010000000000000000000010)",
"seconds": 1024,
"sequence": 4194306
},
{
"description": "0x00400003 (00000000010000000000000000000011)",
"seconds": 1536,
"sequence": 4194307
},
{
"description": "0x00400004 (00000000010000000000000000000100)",
"seconds": 2048,
"sequence": 4194308
},
{
"description": "0x00400005 (00000000010000000000000000000101)",
"seconds": 2560,
"sequence": 4194309
},
{
"description": "0x00400006 (00000000010000000000000000000110)",
"seconds": 3072,
"sequence": 4194310
},
{
"description": "0x00400007 (00000000010000000000000000000111)",
"seconds": 3584,
"sequence": 4194311
},
{
"description": "0x00400008 (00000000010000000000000000001000)",
"seconds": 4096,
"sequence": 4194312
},
{
"description": "0x00400009 (00000000010000000000000000001001)",
"seconds": 4608,
"sequence": 4194313
},
{
"description": "0x0040000a (00000000010000000000000000001010)",
"seconds": 5120,
"sequence": 4194314
},
{
"description": "0x0040000b (00000000010000000000000000001011)",
"seconds": 5632,
"sequence": 4194315
},
{
"description": "0x0040000c (00000000010000000000000000001100)",
"seconds": 6144,
"sequence": 4194316
}
]

View File

@@ -130,9 +130,12 @@ func (d *CSVSigClosure) Decode(script []byte) (bool, error) {
return false, nil return false, nil
} }
sequence := script[1:csvIndex] sequence := script[:csvIndex]
if len(sequence) > 1 {
sequence = sequence[1:]
}
seconds, err := common.BIP68Decode(sequence) seconds, err := common.BIP68DecodeSequence(sequence)
if err != nil { if err != nil {
return false, err return false, err
} }
@@ -369,15 +372,18 @@ func decodeChecksigScript(script []byte) (bool, *secp256k1.PublicKey, error) {
// checkSequenceVerifyScript without checksig // checkSequenceVerifyScript without checksig
func encodeCsvScript(seconds uint) ([]byte, error) { func encodeCsvScript(seconds uint) ([]byte, error) {
sequence, err := common.BIP68Encode(seconds) sequence, err := common.BIP68Sequence(seconds)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return txscript.NewScriptBuilder().AddData(sequence).AddOps([]byte{ return txscript.NewScriptBuilder().
txscript.OP_CHECKSEQUENCEVERIFY, AddInt64(int64(sequence)).
txscript.OP_DROP, AddOps([]byte{
}).Script() txscript.OP_CHECKSEQUENCEVERIFY,
txscript.OP_DROP,
}).
Script()
} }
// checkSequenceVerifyScript + checksig // checkSequenceVerifyScript + checksig

View File

@@ -9,12 +9,13 @@ services:
- ARK_ROUND_INTERVAL=10 - ARK_ROUND_INTERVAL=10
- ARK_NETWORK=regtest - ARK_NETWORK=regtest
- ARK_LOG_LEVEL=5 - ARK_LOG_LEVEL=5
- ARK_ROUND_LIFETIME=512 - ARK_ROUND_LIFETIME=20
- ARK_TX_BUILDER_TYPE=covenantless - ARK_TX_BUILDER_TYPE=covenantless
- ARK_ESPLORA_URL=http://chopsticks:3000 - ARK_ESPLORA_URL=http://chopsticks:3000
- ARK_BITCOIND_RPC_USER=admin1 - ARK_BITCOIND_RPC_USER=admin1
- ARK_BITCOIND_RPC_PASS=123 - ARK_BITCOIND_RPC_PASS=123
- ARK_BITCOIND_RPC_HOST=bitcoin:18443 - ARK_BITCOIND_RPC_HOST=bitcoin:18443
- ARK_SCHEDULER_TYPE=block
- ARK_NO_TLS=true - ARK_NO_TLS=true
- ARK_NO_MACAROONS=true - ARK_NO_MACAROONS=true
- ARK_DATADIR=/app/data - ARK_DATADIR=/app/data

View File

@@ -31,7 +31,9 @@ services:
- ARK_ROUND_INTERVAL=10 - ARK_ROUND_INTERVAL=10
- ARK_NETWORK=liquidregtest - ARK_NETWORK=liquidregtest
- ARK_LOG_LEVEL=5 - ARK_LOG_LEVEL=5
- ARK_ROUND_LIFETIME=512 - ARK_ESPLORA_URL=http://chopsticks-liquid:3000
- ARK_ROUND_LIFETIME=20
- ARK_SCHEDULER_TYPE=block
- ARK_DB_TYPE=sqlite - ARK_DB_TYPE=sqlite
- ARK_TX_BUILDER_TYPE=covenant - ARK_TX_BUILDER_TYPE=covenant
- ARK_PORT=6060 - ARK_PORT=6060

View File

@@ -35,7 +35,7 @@ type Utxo struct {
} }
func (u *Utxo) Sequence() (uint32, error) { func (u *Utxo) Sequence() (uint32, error) {
return common.BIP68EncodeAsNumber(u.Delay) return common.BIP68Sequence(u.Delay)
} }
func newUtxo(explorerUtxo ExplorerUtxo, delay uint) Utxo { func newUtxo(explorerUtxo ExplorerUtxo, delay uint) Utxo {

View File

@@ -23,8 +23,8 @@ help:
## intergrationtest: runs integration tests ## intergrationtest: runs integration tests
integrationtest: integrationtest:
@echo "Running integration tests..." @echo "Running integration tests..."
@go test -v -count 1 -timeout 300s github.com/ark-network/ark/server/test/e2e/covenant @go test -v -count 1 -timeout 400s github.com/ark-network/ark/server/test/e2e/covenant
@go test -v -count 1 -timeout 300s github.com/ark-network/ark/server/test/e2e/covenantless @go test -v -count 1 -timeout 400s github.com/ark-network/ark/server/test/e2e/covenantless
## lint: lint codebase ## lint: lint codebase
lint: lint:

View File

@@ -8,7 +8,8 @@ import (
"github.com/ark-network/ark/server/internal/core/application" "github.com/ark-network/ark/server/internal/core/application"
"github.com/ark-network/ark/server/internal/core/ports" "github.com/ark-network/ark/server/internal/core/ports"
"github.com/ark-network/ark/server/internal/infrastructure/db" "github.com/ark-network/ark/server/internal/infrastructure/db"
scheduler "github.com/ark-network/ark/server/internal/infrastructure/scheduler/gocron" blockscheduler "github.com/ark-network/ark/server/internal/infrastructure/scheduler/block"
timescheduler "github.com/ark-network/ark/server/internal/infrastructure/scheduler/gocron"
txbuilder "github.com/ark-network/ark/server/internal/infrastructure/tx-builder/covenant" txbuilder "github.com/ark-network/ark/server/internal/infrastructure/tx-builder/covenant"
cltxbuilder "github.com/ark-network/ark/server/internal/infrastructure/tx-builder/covenantless" cltxbuilder "github.com/ark-network/ark/server/internal/infrastructure/tx-builder/covenantless"
fileunlocker "github.com/ark-network/ark/server/internal/infrastructure/unlocker/file" fileunlocker "github.com/ark-network/ark/server/internal/infrastructure/unlocker/file"
@@ -29,6 +30,7 @@ var (
} }
supportedSchedulers = supportedType{ supportedSchedulers = supportedType{
"gocron": {}, "gocron": {},
"block": {},
} }
supportedTxBuilders = supportedType{ supportedTxBuilders = supportedType{
"covenant": {}, "covenant": {},
@@ -115,11 +117,23 @@ func (c *Config) Validate() error {
if len(c.WalletAddr) <= 0 { if len(c.WalletAddr) <= 0 {
return fmt.Errorf("missing onchain wallet address") return fmt.Errorf("missing onchain wallet address")
} }
// round life time must be a multiple of 512
if c.RoundLifetime < minAllowedSequence { if c.RoundLifetime < minAllowedSequence {
return fmt.Errorf( if c.SchedulerType != "block" {
"invalid round lifetime, must be a at least %d", minAllowedSequence, return fmt.Errorf("scheduler type must be block if round lifetime is expressed in blocks")
) }
} else {
if c.SchedulerType != "gocron" {
return fmt.Errorf("scheduler type must be gocron if round lifetime is expressed in seconds")
}
// round life time must be a multiple of 512 if expressed in seconds
if c.RoundLifetime%minAllowedSequence != 0 {
c.RoundLifetime -= c.RoundLifetime % minAllowedSequence
log.Infof(
"round lifetime must be a multiple of %d, rounded to %d",
minAllowedSequence, c.RoundLifetime,
)
}
} }
if c.UnilateralExitDelay < minAllowedSequence { if c.UnilateralExitDelay < minAllowedSequence {
@@ -134,14 +148,6 @@ func (c *Config) Validate() error {
) )
} }
if c.RoundLifetime%minAllowedSequence != 0 {
c.RoundLifetime -= c.RoundLifetime % minAllowedSequence
log.Infof(
"round lifetime must be a multiple of %d, rounded to %d",
minAllowedSequence, c.RoundLifetime,
)
}
if c.UnilateralExitDelay%minAllowedSequence != 0 { if c.UnilateralExitDelay%minAllowedSequence != 0 {
c.UnilateralExitDelay -= c.UnilateralExitDelay % minAllowedSequence c.UnilateralExitDelay -= c.UnilateralExitDelay % minAllowedSequence
log.Infof( log.Infof(
@@ -328,7 +334,9 @@ func (c *Config) schedulerService() error {
var err error var err error
switch c.SchedulerType { switch c.SchedulerType {
case "gocron": case "gocron":
svc = scheduler.NewScheduler() svc = timescheduler.NewScheduler()
case "block":
svc, err = blockscheduler.NewScheduler(c.EsploraURL)
default: default:
err = fmt.Errorf("unknown scheduler type") err = fmt.Errorf("unknown scheduler type")
} }
@@ -367,7 +375,12 @@ func (c *Config) appService() error {
} }
func (c *Config) adminService() error { func (c *Config) adminService() error {
c.adminSvc = application.NewAdminService(c.wallet, c.repo, c.txBuilder) unit := ports.UnixTime
if c.RoundLifetime < minAllowedSequence {
unit = ports.BlockHeight
}
c.adminSvc = application.NewAdminService(c.wallet, c.repo, c.txBuilder, unit)
return nil return nil
} }

View File

@@ -50,16 +50,18 @@ type AdminService interface {
} }
type adminService struct { type adminService struct {
walletSvc ports.WalletService walletSvc ports.WalletService
repoManager ports.RepoManager repoManager ports.RepoManager
txBuilder ports.TxBuilder txBuilder ports.TxBuilder
sweeperTimeUnit ports.TimeUnit
} }
func NewAdminService(walletSvc ports.WalletService, repoManager ports.RepoManager, txBuilder ports.TxBuilder) AdminService { func NewAdminService(walletSvc ports.WalletService, repoManager ports.RepoManager, txBuilder ports.TxBuilder, timeUnit ports.TimeUnit) AdminService {
return &adminService{ return &adminService{
walletSvc: walletSvc, walletSvc: walletSvc,
repoManager: repoManager, repoManager: repoManager,
txBuilder: txBuilder, txBuilder: txBuilder,
sweeperTimeUnit: timeUnit,
} }
} }
@@ -130,7 +132,7 @@ func (a *adminService) GetScheduledSweeps(ctx context.Context) ([]ScheduledSweep
for _, round := range sweepableRounds { for _, round := range sweepableRounds {
sweepable, err := findSweepableOutputs( sweepable, err := findSweepableOutputs(
ctx, a.walletSvc, a.txBuilder, round.CongestionTree, ctx, a.walletSvc, a.txBuilder, a.sweeperTimeUnit, round.CongestionTree,
) )
if err != nil { if err != nil {
return nil, err return nil, err

View File

@@ -177,7 +177,7 @@ func (s *covenantService) SpendVtxos(ctx context.Context, inputs []ports.Input)
return "", fmt.Errorf("failed to parse tx %s: %s", input.Txid, err) return "", fmt.Errorf("failed to parse tx %s: %s", input.Txid, err)
} }
confirmed, blocktime, err := s.wallet.IsTransactionConfirmed(ctx, input.Txid) confirmed, _, blocktime, err := s.wallet.IsTransactionConfirmed(ctx, input.Txid)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to check tx %s: %s", input.Txid, err) return "", fmt.Errorf("failed to check tx %s: %s", input.Txid, err)
} }
@@ -910,12 +910,10 @@ func (s *covenantService) scheduleSweepVtxosForRound(round *domain.Round) {
return return
} }
expirationTimestamp := time.Now().Add( expirationTime := s.sweeper.scheduler.AddNow(s.roundLifetime)
time.Duration(s.roundLifetime+30) * time.Second,
)
if err := s.sweeper.schedule( if err := s.sweeper.schedule(
expirationTimestamp.Unix(), round.Txid, round.CongestionTree, expirationTime, round.Txid, round.CongestionTree,
); err != nil { ); err != nil {
log.WithError(err).Warn("failed to schedule sweep tx") log.WithError(err).Warn("failed to schedule sweep tx")
} }

View File

@@ -421,7 +421,7 @@ func (s *covenantlessService) SpendVtxos(ctx context.Context, inputs []ports.Inp
return "", fmt.Errorf("failed to deserialize tx %s: %s", input.Txid, err) return "", fmt.Errorf("failed to deserialize tx %s: %s", input.Txid, err)
} }
confirmed, blocktime, err := s.wallet.IsTransactionConfirmed(ctx, input.Txid) confirmed, _, blocktime, err := s.wallet.IsTransactionConfirmed(ctx, input.Txid)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to check tx %s: %s", input.Txid, err) return "", fmt.Errorf("failed to check tx %s: %s", input.Txid, err)
} }
@@ -1316,13 +1316,9 @@ func (s *covenantlessService) scheduleSweepVtxosForRound(round *domain.Round) {
return return
} }
expirationTimestamp := time.Now().Add( expirationTimestamp := s.sweeper.scheduler.AddNow(s.roundLifetime)
time.Duration(s.roundLifetime+30) * time.Second,
)
if err := s.sweeper.schedule( if err := s.sweeper.schedule(expirationTimestamp, round.Txid, round.CongestionTree); err != nil {
expirationTimestamp.Unix(), round.Txid, round.CongestionTree,
); err != nil {
log.WithError(err).Warn("failed to schedule sweep tx") log.WithError(err).Warn("failed to schedule sweep tx")
} }
} }

View File

@@ -3,6 +3,7 @@ package application
import ( import (
"context" "context"
"fmt" "fmt"
"sync"
"time" "time"
"github.com/ark-network/ark/common/tree" "github.com/ark-network/ark/common/tree"
@@ -22,6 +23,7 @@ type sweeper struct {
scheduler ports.SchedulerService scheduler ports.SchedulerService
// cache of scheduled tasks, avoid scheduling the same sweep event multiple times // cache of scheduled tasks, avoid scheduling the same sweep event multiple times
locker sync.Locker
scheduledTasks map[string]struct{} scheduledTasks map[string]struct{}
} }
@@ -36,6 +38,7 @@ func newSweeper(
repoManager, repoManager,
builder, builder,
scheduler, scheduler,
&sync.Mutex{},
make(map[string]struct{}), make(map[string]struct{}),
} }
} }
@@ -62,6 +65,8 @@ func (s *sweeper) stop() {
// removeTask update the cached map of scheduled tasks // removeTask update the cached map of scheduled tasks
func (s *sweeper) removeTask(treeRootTxid string) { func (s *sweeper) removeTask(treeRootTxid string) {
s.locker.Lock()
defer s.locker.Unlock()
delete(s.scheduledTasks, treeRootTxid) delete(s.scheduledTasks, treeRootTxid)
} }
@@ -84,13 +89,22 @@ func (s *sweeper) schedule(
} }
task := s.createTask(roundTxid, congestionTree) task := s.createTask(roundTxid, congestionTree)
fancyTime := time.Unix(expirationTimestamp, 0).Format("2006-01-02 15:04:05")
var fancyTime string
if s.scheduler.Unit() == ports.UnixTime {
fancyTime = time.Unix(expirationTimestamp, 0).Format("2006-01-02 15:04:05")
} else {
fancyTime = fmt.Sprintf("block %d", expirationTimestamp)
}
log.Debugf("scheduled sweep for round %s at %s", roundTxid, fancyTime) log.Debugf("scheduled sweep for round %s at %s", roundTxid, fancyTime)
if err := s.scheduler.ScheduleTaskOnce(expirationTimestamp, task); err != nil { if err := s.scheduler.ScheduleTaskOnce(expirationTimestamp, task); err != nil {
return err return err
} }
s.locker.Lock()
s.scheduledTasks[root.Txid] = struct{}{} s.scheduledTasks[root.Txid] = struct{}{}
s.locker.Unlock()
if err := s.updateVtxoExpirationTime(congestionTree, expirationTimestamp); err != nil { if err := s.updateVtxoExpirationTime(congestionTree, expirationTimestamp); err != nil {
log.WithError(err).Error("error while updating vtxo expiration time") log.WithError(err).Error("error while updating vtxo expiration time")
@@ -120,7 +134,7 @@ func (s *sweeper) createTask(
vtxoKeys := make([]domain.VtxoKey, 0) // vtxos associated to the sweep inputs vtxoKeys := make([]domain.VtxoKey, 0) // vtxos associated to the sweep inputs
// inspect the congestion tree to find onchain shared outputs // inspect the congestion tree to find onchain shared outputs
sharedOutputs, err := findSweepableOutputs(ctx, s.wallet, s.builder, congestionTree) sharedOutputs, err := findSweepableOutputs(ctx, s.wallet, s.builder, s.scheduler.Unit(), congestionTree)
if err != nil { if err != nil {
log.WithError(err).Error("error while inspecting congestion tree") log.WithError(err).Error("error while inspecting congestion tree")
return return
@@ -128,7 +142,7 @@ func (s *sweeper) createTask(
for expiredAt, inputs := range sharedOutputs { for expiredAt, inputs := range sharedOutputs {
// if the shared outputs are not expired, schedule a sweep task for it // if the shared outputs are not expired, schedule a sweep task for it
if time.Unix(expiredAt, 0).After(time.Now()) { if s.scheduler.AfterNow(expiredAt) {
subtrees, err := computeSubTrees(congestionTree, inputs) subtrees, err := computeSubTrees(congestionTree, inputs)
if err != nil { if err != nil {
log.WithError(err).Error("error while computing subtrees") log.WithError(err).Error("error while computing subtrees")
@@ -136,8 +150,7 @@ func (s *sweeper) createTask(
} }
for _, subTree := range subtrees { for _, subTree := range subtrees {
// mitigate the risk to get BIP68 non-final errors by scheduling the task 30 seconds after the expiration time if err := s.schedule(expiredAt, roundTxid, subTree); err != nil {
if err := s.schedule(int64(expiredAt), roundTxid, subTree); err != nil {
log.WithError(err).Error("error while scheduling sweep task") log.WithError(err).Error("error while scheduling sweep task")
continue continue
} }

View File

@@ -259,17 +259,18 @@ func findSweepableOutputs(
ctx context.Context, ctx context.Context,
walletSvc ports.WalletService, walletSvc ports.WalletService,
txbuilder ports.TxBuilder, txbuilder ports.TxBuilder,
schedulerUnit ports.TimeUnit,
congestionTree tree.CongestionTree, congestionTree tree.CongestionTree,
) (map[int64][]ports.SweepInput, error) { ) (map[int64][]ports.SweepInput, error) {
sweepableOutputs := make(map[int64][]ports.SweepInput) sweepableOutputs := make(map[int64][]ports.SweepInput)
blocktimeCache := make(map[string]int64) // txid -> blocktime blocktimeCache := make(map[string]int64) // txid -> blocktime / blockheight
nodesToCheck := congestionTree[0] // init with the root nodesToCheck := congestionTree[0] // init with the root
for len(nodesToCheck) > 0 { for len(nodesToCheck) > 0 {
newNodesToCheck := make([]tree.Node, 0) newNodesToCheck := make([]tree.Node, 0)
for _, node := range nodesToCheck { for _, node := range nodesToCheck {
isConfirmed, blocktime, err := walletSvc.IsTransactionConfirmed(ctx, node.Txid) isConfirmed, height, blocktime, err := walletSvc.IsTransactionConfirmed(ctx, node.Txid)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -279,21 +280,31 @@ func findSweepableOutputs(
if !isConfirmed { if !isConfirmed {
if _, ok := blocktimeCache[node.ParentTxid]; !ok { if _, ok := blocktimeCache[node.ParentTxid]; !ok {
isConfirmed, blocktime, err := walletSvc.IsTransactionConfirmed(ctx, node.ParentTxid) isConfirmed, height, blocktime, err := walletSvc.IsTransactionConfirmed(ctx, node.ParentTxid)
if !isConfirmed || err != nil { if !isConfirmed || err != nil {
return nil, fmt.Errorf("tx %s not found", node.ParentTxid) return nil, fmt.Errorf("tx %s not found", node.ParentTxid)
} }
blocktimeCache[node.ParentTxid] = blocktime if schedulerUnit == ports.BlockHeight {
blocktimeCache[node.ParentTxid] = height
} else {
blocktimeCache[node.ParentTxid] = blocktime
}
} }
expirationTime, sweepInput, err = txbuilder.GetSweepInput(blocktimeCache[node.ParentTxid], node) var lifetime int64
lifetime, sweepInput, err = txbuilder.GetSweepInput(node)
if err != nil { if err != nil {
return nil, err return nil, err
} }
expirationTime = blocktimeCache[node.ParentTxid] + lifetime
} else { } else {
// cache the blocktime for future use // cache the blocktime for future use
blocktimeCache[node.Txid] = int64(blocktime) if schedulerUnit == ports.BlockHeight {
blocktimeCache[node.Txid] = height
} else {
blocktimeCache[node.Txid] = blocktime
}
// if the tx is onchain, it means that the input is spent // if the tx is onchain, it means that the input is spent
// add the children to the nodes in order to check them during the next iteration // add the children to the nodes in order to check them during the next iteration

View File

@@ -14,5 +14,5 @@ type BlockchainScanner interface {
WatchScripts(ctx context.Context, scripts []string) error WatchScripts(ctx context.Context, scripts []string) error
UnwatchScripts(ctx context.Context, scripts []string) error UnwatchScripts(ctx context.Context, scripts []string) error
GetNotificationChannel(ctx context.Context) <-chan map[string][]VtxoWithValue GetNotificationChannel(ctx context.Context) <-chan map[string][]VtxoWithValue
IsTransactionConfirmed(ctx context.Context, txid string) (isConfirmed bool, blocktime int64, err error) IsTransactionConfirmed(ctx context.Context, txid string) (isConfirmed bool, blocknumber int64, blocktime int64, err error)
} }

View File

@@ -1,9 +1,18 @@
package ports package ports
type TimeUnit int
const (
UnixTime TimeUnit = iota
BlockHeight
)
type SchedulerService interface { type SchedulerService interface {
Start() Start()
Stop() Stop()
ScheduleTask(interval int64, immediate bool, task func()) error Unit() TimeUnit
ScheduleTaskOnce(delay int64, task func()) error AddNow(lifetime int64) int64
AfterNow(expiry int64) bool
ScheduleTaskOnce(at int64, task func()) error
} }

View File

@@ -34,7 +34,7 @@ type TxBuilder interface {
) (roundTx string, congestionTree tree.CongestionTree, connectorAddress string, err error) ) (roundTx string, congestionTree tree.CongestionTree, connectorAddress string, err error)
BuildForfeitTxs(poolTx string, payments []domain.Payment, minRelayFeeRate chainfee.SatPerKVByte) (connectors []string, forfeitTxs []string, err error) BuildForfeitTxs(poolTx string, payments []domain.Payment, minRelayFeeRate chainfee.SatPerKVByte) (connectors []string, forfeitTxs []string, err error)
BuildSweepTx(inputs []SweepInput) (signedSweepTx string, err error) BuildSweepTx(inputs []SweepInput) (signedSweepTx string, err error)
GetSweepInput(parentblocktime int64, node tree.Node) (expirationtime int64, sweepInput SweepInput, err error) GetSweepInput(node tree.Node) (lifetime int64, sweepInput SweepInput, err error)
FinalizeAndExtract(tx string) (txhex string, err error) FinalizeAndExtract(tx string) (txhex string, err error)
VerifyTapscriptPartialSigs(tx string) (valid bool, txid string, err error) VerifyTapscriptPartialSigs(tx string) (valid bool, txid string, err error)
// FindLeaves returns all the leaves txs that are reachable from the given outpoint // FindLeaves returns all the leaves txs that are reachable from the given outpoint

View File

@@ -117,7 +117,7 @@ func (r *vtxoRepository) GetAllVtxos(
spentVtxos := make([]domain.Vtxo, 0, len(vtxos)) spentVtxos := make([]domain.Vtxo, 0, len(vtxos))
unspentVtxos := make([]domain.Vtxo, 0, len(vtxos)) unspentVtxos := make([]domain.Vtxo, 0, len(vtxos))
for _, vtxo := range vtxos { for _, vtxo := range vtxos {
if vtxo.Spent { if vtxo.Spent || vtxo.Swept {
spentVtxos = append(spentVtxos, vtxo) spentVtxos = append(spentVtxos, vtxo)
} else { } else {
unspentVtxos = append(unspentVtxos, vtxo) unspentVtxos = append(unspentVtxos, vtxo)

View File

@@ -113,7 +113,7 @@ func (v *vxtoRepository) GetAllVtxos(ctx context.Context, pubkey string) ([]doma
spentVtxos := make([]domain.Vtxo, 0) spentVtxos := make([]domain.Vtxo, 0)
for _, vtxo := range vtxos { for _, vtxo := range vtxos {
if vtxo.Spent { if vtxo.Spent || vtxo.Swept {
spentVtxos = append(spentVtxos, vtxo) spentVtxos = append(spentVtxos, vtxo)
} else { } else {
unspentVtxos = append(unspentVtxos, vtxo) unspentVtxos = append(unspentVtxos, vtxo)

View File

@@ -0,0 +1,147 @@
package blockscheduler
import (
"fmt"
"net/http"
"net/url"
"sync"
"time"
"github.com/ark-network/ark/server/internal/core/ports"
"github.com/sirupsen/logrus"
)
const tipHeightEndpoit = "/blocks/tip/height"
type service struct {
tipURL string
lock sync.Locker
taskes map[int64][]func()
stopCh chan struct{}
}
func NewScheduler(esploraURL string) (ports.SchedulerService, error) {
if len(esploraURL) == 0 {
return nil, fmt.Errorf("esplora URL is required")
}
tipURL, err := url.JoinPath(esploraURL, tipHeightEndpoit)
if err != nil {
return nil, err
}
return &service{
tipURL,
&sync.Mutex{},
make(map[int64][]func()),
make(chan struct{}),
}, nil
}
func (s *service) Start() {
go func() {
for {
select {
case <-s.stopCh:
return
default:
time.Sleep(10 * time.Second)
taskes, err := s.popTaskes()
if err != nil {
fmt.Println("error fetching tasks:", err)
continue
}
logrus.Debugf("fetched %d tasks", len(taskes))
for _, task := range taskes {
go task()
}
}
}
}()
}
func (s *service) Stop() {
s.stopCh <- struct{}{}
close(s.stopCh)
}
func (s *service) Unit() ports.TimeUnit {
return ports.BlockHeight
}
func (s *service) AddNow(lifetime int64) int64 {
tip, err := s.fetchTipHeight()
if err != nil {
return 0
}
return tip + lifetime
}
func (s *service) AfterNow(expiry int64) bool {
tip, err := s.fetchTipHeight()
if err != nil {
return false
}
return expiry > tip
}
func (s *service) ScheduleTaskOnce(at int64, task func()) error {
s.lock.Lock()
defer s.lock.Unlock()
if _, ok := s.taskes[at]; !ok {
s.taskes[at] = make([]func(), 0)
}
s.taskes[at] = append(s.taskes[at], task)
return nil
}
func (s *service) popTaskes() ([]func(), error) {
s.lock.Lock()
defer s.lock.Unlock()
tip, err := s.fetchTipHeight()
if err != nil {
return nil, err
}
taskes := make([]func(), 0)
for height, tasks := range s.taskes {
if height > tip {
continue
}
taskes = append(taskes, tasks...)
delete(s.taskes, height)
}
return taskes, nil
}
func (s *service) fetchTipHeight() (int64, error) {
resp, err := http.Get(s.tipURL)
if err != nil {
return 0, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return 0, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
}
var tip int64
if _, err := fmt.Fscanf(resp.Body, "%d", &tip); err != nil {
return 0, err
}
logrus.Debugf("fetching tip height from %s, got %d", s.tipURL, tip)
return tip, nil
}

View File

@@ -1,4 +1,4 @@
package scheduler package timescheduler
import ( import (
"fmt" "fmt"
@@ -17,6 +17,18 @@ func NewScheduler() ports.SchedulerService {
return &service{svc} return &service{svc}
} }
func (s *service) Unit() ports.TimeUnit {
return ports.UnixTime
}
func (s *service) AddNow(lifetime int64) int64 {
return time.Now().Add(time.Duration(lifetime) * time.Second).Unix()
}
func (s *service) AfterNow(expiry int64) bool {
return time.Unix(expiry, 0).After(time.Now())
}
func (s *service) Start() { func (s *service) Start() {
s.scheduler.StartAsync() s.scheduler.StartAsync()
} }
@@ -25,15 +37,6 @@ func (s *service) Stop() {
s.scheduler.Stop() s.scheduler.Stop()
} }
func (s *service) ScheduleTask(interval int64, immediate bool, task func()) error {
if immediate {
_, err := s.scheduler.Every(int(interval)).Seconds().Do(task)
return err
}
_, err := s.scheduler.Every(int(interval)).Seconds().WaitForSchedule().Do(task)
return err
}
func (s *service) ScheduleTaskOnce(at int64, task func()) error { func (s *service) ScheduleTaskOnce(at int64, task func()) error {
delay := at - time.Now().Unix() delay := at - time.Now().Unix()
if delay < 0 { if delay < 0 {

View File

@@ -195,7 +195,7 @@ func (b *txBuilder) BuildRoundTx(
return return
} }
func (b *txBuilder) GetSweepInput(parentblocktime int64, node tree.Node) (expirationtime int64, sweepInput ports.SweepInput, err error) { func (b *txBuilder) GetSweepInput(node tree.Node) (lifetime int64, sweepInput ports.SweepInput, err error) {
pset, err := psetv2.NewPsetFromBase64(node.Tx) pset, err := psetv2.NewPsetFromBase64(node.Tx)
if err != nil { if err != nil {
return -1, nil, err return -1, nil, err
@@ -215,8 +215,6 @@ func (b *txBuilder) GetSweepInput(parentblocktime int64, node tree.Node) (expira
return -1, nil, err return -1, nil, err
} }
expirationTime := parentblocktime + lifetime
txhex, err := b.wallet.GetTransaction(context.Background(), txid) txhex, err := b.wallet.GetTransaction(context.Background(), txid)
if err != nil { if err != nil {
return -1, nil, err return -1, nil, err
@@ -241,7 +239,7 @@ func (b *txBuilder) GetSweepInput(parentblocktime int64, node tree.Node) (expira
amount: inputValue, amount: inputValue,
} }
return expirationTime, sweepInput, nil return lifetime, sweepInput, nil
} }
func (b *txBuilder) VerifyTapscriptPartialSigs(tx string) (bool, string, error) { func (b *txBuilder) VerifyTapscriptPartialSigs(tx string) (bool, string, error) {

View File

@@ -151,7 +151,7 @@ func (m *mockedWallet) GetDustAmount(ctx context.Context) (uint64, error) {
return res, args.Error(1) return res, args.Error(1)
} }
func (m *mockedWallet) IsTransactionConfirmed(ctx context.Context, txid string) (bool, int64, error) { func (m *mockedWallet) IsTransactionConfirmed(ctx context.Context, txid string) (bool, int64, int64, error) {
args := m.Called(ctx, txid) args := m.Called(ctx, txid)
var res bool var res bool
@@ -159,12 +159,17 @@ func (m *mockedWallet) IsTransactionConfirmed(ctx context.Context, txid string)
res = a.(bool) res = a.(bool)
} }
var height int64
if h := args.Get(1); h != nil {
height = h.(int64)
}
var blocktime int64 var blocktime int64
if b := args.Get(1); b != nil { if b := args.Get(1); b != nil {
blocktime = b.(int64) blocktime = b.(int64)
} }
return res, blocktime, args.Error(2) return res, height, blocktime, args.Error(2)
} }
func (m *mockedWallet) SignTransactionTapscript(ctx context.Context, pset string, inputIndexes []int) (string, error) { func (m *mockedWallet) SignTransactionTapscript(ctx context.Context, pset string, inputIndexes []int) (string, error) {

View File

@@ -90,7 +90,7 @@ func sweepTransaction(
return nil, err return nil, err
} }
sequence, err := common.BIP68EncodeAsNumber(sweepClosure.Seconds) sequence, err := common.BIP68Sequence(sweepClosure.Seconds)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -324,7 +324,7 @@ func (b *txBuilder) BuildRoundTx(
return return
} }
func (b *txBuilder) GetSweepInput(parentblocktime int64, node tree.Node) (expirationtime int64, sweepInput ports.SweepInput, err error) { func (b *txBuilder) GetSweepInput(node tree.Node) (lifetime int64, sweepInput ports.SweepInput, err error) {
partialTx, err := psbt.NewFromRawBytes(strings.NewReader(node.Tx), true) partialTx, err := psbt.NewFromRawBytes(strings.NewReader(node.Tx), true)
if err != nil { if err != nil {
return -1, nil, err return -1, nil, err
@@ -343,8 +343,6 @@ func (b *txBuilder) GetSweepInput(parentblocktime int64, node tree.Node) (expira
return -1, nil, err return -1, nil, err
} }
expirationTime := parentblocktime + lifetime
txhex, err := b.wallet.GetTransaction(context.Background(), txid.String()) txhex, err := b.wallet.GetTransaction(context.Background(), txid.String())
if err != nil { if err != nil {
return -1, nil, err return -1, nil, err
@@ -365,7 +363,7 @@ func (b *txBuilder) GetSweepInput(parentblocktime int64, node tree.Node) (expira
amount: tx.TxOut[index].Value, amount: tx.TxOut[index].Value,
} }
return expirationTime, sweepInput, nil return lifetime, sweepInput, nil
} }
func (b *txBuilder) FindLeaves(congestionTree tree.CongestionTree, fromtxid string, vout uint32) ([]tree.Node, error) { func (b *txBuilder) FindLeaves(congestionTree tree.CongestionTree, fromtxid string, vout uint32) ([]tree.Node, error) {
@@ -1220,6 +1218,8 @@ func extractSweepLeaf(input psbt.PInput) (sweepLeaf *psbt.TaprootTapLeafScript,
if err != nil { if err != nil {
return nil, nil, 0, err return nil, nil, 0, err
} }
fmt.Println("closure", valid)
if valid && closure.Seconds > 0 { if valid && closure.Seconds > 0 {
sweepLeaf = leaf sweepLeaf = leaf
lifetime = int64(closure.Seconds) lifetime = int64(closure.Seconds)

View File

@@ -172,7 +172,7 @@ func (m *mockedWallet) GetDustAmount(ctx context.Context) (uint64, error) {
return res, args.Error(1) return res, args.Error(1)
} }
func (m *mockedWallet) IsTransactionConfirmed(ctx context.Context, txid string) (bool, int64, error) { func (m *mockedWallet) IsTransactionConfirmed(ctx context.Context, txid string) (bool, int64, int64, error) {
args := m.Called(ctx, txid) args := m.Called(ctx, txid)
var res bool var res bool
@@ -180,12 +180,17 @@ func (m *mockedWallet) IsTransactionConfirmed(ctx context.Context, txid string)
res = a.(bool) res = a.(bool)
} }
var height int64
if h := args.Get(1); h != nil {
height = h.(int64)
}
var blocktime int64 var blocktime int64
if b := args.Get(1); b != nil { if b := args.Get(1); b != nil {
blocktime = b.(int64) blocktime = b.(int64)
} }
return res, blocktime, args.Error(2) return res, height, blocktime, args.Error(2)
} }
func (m *mockedWallet) SignTransactionTapscript(ctx context.Context, pset string, inputIndexes []int) (string, error) { func (m *mockedWallet) SignTransactionTapscript(ctx context.Context, pset string, inputIndexes []int) (string, error) {

View File

@@ -37,7 +37,7 @@ func sweepTransaction(
return nil, fmt.Errorf("invalid csv script") return nil, fmt.Errorf("invalid csv script")
} }
sequence, err := common.BIP68EncodeAsNumber(sweepClosure.Seconds) sequence, err := common.BIP68Sequence(sweepClosure.Seconds)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -33,21 +33,31 @@ func (b *bitcoindRPCClient) broadcast(txhex string) error {
return nil return nil
} }
func (b *bitcoindRPCClient) getTxStatus(txid string) (isConfirmed bool, blocktime int64, err error) { func (b *bitcoindRPCClient) getTxStatus(txid string) (isConfirmed bool, height, blocktime int64, err error) {
txhash, err := chainhash.NewHashFromStr(txid) txhash, err := chainhash.NewHashFromStr(txid)
if err != nil { if err != nil {
return false, 0, err return false, 0, 0, err
} }
tx, err := b.chainClient.GetRawTransactionVerbose(txhash) tx, err := b.chainClient.GetRawTransactionVerbose(txhash)
if err != nil { if err != nil {
if strings.Contains(err.Error(), "No such mempool or blockchain transaction") { if strings.Contains(err.Error(), "No such mempool or blockchain transaction") {
return false, 0, nil return false, 0, 0, nil
} }
return false, 0, err return false, 0, 0, err
} }
return tx.Confirmations > 0, tx.Blocktime, nil blockHash, err := chainhash.NewHashFromStr(tx.BlockHash)
if err != nil {
return false, 0, 0, err
}
blockHeight, err := b.chainClient.GetBlockHeight(blockHash)
if err != nil {
return false, 0, 0, err
}
return tx.Confirmations > 0, int64(blockHeight), tx.Blocktime, nil
} }
func (b *bitcoindRPCClient) getTx(txid string) (*wire.MsgTx, error) { func (b *bitcoindRPCClient) getTx(txid string) (*wire.MsgTx, error) {

View File

@@ -20,8 +20,9 @@ type esploraClient struct {
type esploraTx struct { type esploraTx struct {
Status struct { Status struct {
Confirmed bool `json:"confirmed"` Confirmed bool `json:"confirmed"`
BlockTime int64 `json:"block_time"` BlockTime int64 `json:"block_time"`
BlockNumber int64 `json:"block_height"`
} `json:"status"` } `json:"status"`
} }
@@ -79,30 +80,30 @@ func (f *esploraClient) getTx(txid string) (*wire.MsgTx, error) {
return &tx, nil return &tx, nil
} }
func (f *esploraClient) getTxStatus(txid string) (isConfirmed bool, blocktime int64, err error) { func (f *esploraClient) getTxStatus(txid string) (isConfirmed bool, blocknumber, blocktime int64, err error) {
endpoint, err := url.JoinPath(f.url, "tx", txid) endpoint, err := url.JoinPath(f.url, "tx", txid)
if err != nil { if err != nil {
return false, 0, err return false, 0, 0, err
} }
resp, err := http.DefaultClient.Get(endpoint) resp, err := http.DefaultClient.Get(endpoint)
if err != nil { if err != nil {
return false, 0, err return false, 0, 0, err
} }
defer resp.Body.Close() defer resp.Body.Close()
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
return false, 0, err return false, 0, 0, err
} }
var response esploraTx var response esploraTx
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil { if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
return false, 0, err return false, 0, 0, err
} }
return response.Status.Confirmed, response.Status.BlockTime, nil return response.Status.Confirmed, response.Status.BlockNumber, response.Status.BlockTime, nil
} }
// GetFeeMap returns a map of sat/vbyte fees for different confirmation targets // GetFeeMap returns a map of sat/vbyte fees for different confirmation targets

View File

@@ -92,7 +92,7 @@ var (
// add additional chain API not supported by the chain.Interface type // add additional chain API not supported by the chain.Interface type
type extraChainAPI interface { type extraChainAPI interface {
getTx(txid string) (*wire.MsgTx, error) getTx(txid string) (*wire.MsgTx, error)
getTxStatus(txid string) (isConfirmed bool, blocktime int64, err error) getTxStatus(txid string) (isConfirmed bool, blockHeight, blocktime int64, err error)
broadcast(txHex string) error broadcast(txHex string) error
} }
@@ -957,7 +957,7 @@ func (s *service) GetNotificationChannel(
func (s *service) IsTransactionConfirmed( func (s *service) IsTransactionConfirmed(
ctx context.Context, txid string, ctx context.Context, txid string,
) (isConfirmed bool, blocktime int64, err error) { ) (isConfirmed bool, blocknumber int64, blocktime int64, err error) {
return s.extraAPI.getTxStatus(txid) return s.extraAPI.getTxStatus(txid)
} }

View File

@@ -158,21 +158,21 @@ func (s *service) BroadcastTransaction(
func (s *service) IsTransactionConfirmed( func (s *service) IsTransactionConfirmed(
ctx context.Context, txid string, ctx context.Context, txid string,
) (bool, int64, error) { ) (bool, int64, int64, error) {
_, isConfirmed, blocktime, err := s.getTransaction(ctx, txid) _, isConfirmed, blockheight, blocktime, err := s.getTransaction(ctx, txid)
if err != nil { if err != nil {
if strings.Contains(strings.ToLower(err.Error()), "missing transaction") { if strings.Contains(strings.ToLower(err.Error()), "missing transaction") {
return isConfirmed, 0, nil return isConfirmed, 0, 0, nil
} }
return false, 0, err return false, 0, 0, err
} }
return isConfirmed, blocktime, nil return isConfirmed, blockheight, blocktime, nil
} }
func (s *service) WaitForSync(ctx context.Context, txid string) error { func (s *service) WaitForSync(ctx context.Context, txid string) error {
for { for {
time.Sleep(5 * time.Second) time.Sleep(5 * time.Second)
_, _, _, err := s.getTransaction(ctx, txid) _, _, _, _, err := s.getTransaction(ctx, txid)
if err != nil { if err != nil {
if strings.Contains(strings.ToLower(err.Error()), "missing transaction") { if strings.Contains(strings.ToLower(err.Error()), "missing transaction") {
continue continue
@@ -351,7 +351,7 @@ func (s *service) EstimateFees(
} }
func (s *service) GetTransaction(ctx context.Context, txid string) (string, error) { func (s *service) GetTransaction(ctx context.Context, txid string) (string, error) {
txHex, _, _, err := s.getTransaction(ctx, txid) txHex, _, _, _, err := s.getTransaction(ctx, txid)
if err != nil { if err != nil {
return "", err return "", err
} }
@@ -361,18 +361,18 @@ func (s *service) GetTransaction(ctx context.Context, txid string) (string, erro
func (s *service) getTransaction( func (s *service) getTransaction(
ctx context.Context, txid string, ctx context.Context, txid string,
) (string, bool, int64, error) { ) (string, bool, int64, int64, error) {
res, err := s.txClient.GetTransaction(ctx, &pb.GetTransactionRequest{ res, err := s.txClient.GetTransaction(ctx, &pb.GetTransactionRequest{
Txid: txid, Txid: txid,
}) })
if err != nil { if err != nil {
return "", false, 0, err return "", false, 0, 0, err
} }
if res.GetBlockDetails().GetTimestamp() > 0 { if res.GetBlockDetails().GetTimestamp() > 0 {
return res.GetTxHex(), true, res.BlockDetails.GetTimestamp(), nil return res.GetTxHex(), true, int64(res.GetBlockDetails().GetHeight()), res.BlockDetails.GetTimestamp(), nil
} }
// if not confirmed, we return now + 1 min to estimate the next blocktime // if not confirmed, we return now + 1 min to estimate the next blocktime
return res.GetTxHex(), false, time.Now().Add(time.Minute).Unix(), nil return res.GetTxHex(), false, 0, time.Now().Add(time.Minute).Unix(), nil
} }

View File

@@ -202,6 +202,36 @@ func TestReactToSpentVtxosRedemption(t *testing.T) {
require.Empty(t, balance.OnchainBalance.LockedAmount) require.Empty(t, balance.OnchainBalance.LockedAmount)
} }
func TestSweep(t *testing.T) {
var receive utils.ArkReceive
receiveStr, err := runArkCommand("receive")
require.NoError(t, err)
err = json.Unmarshal([]byte(receiveStr), &receive)
require.NoError(t, err)
_, err = utils.RunCommand("nigiri", "faucet", "--liquid", receive.Boarding)
require.NoError(t, err)
time.Sleep(5 * time.Second)
_, err = runArkCommand("claim", "--password", utils.Password)
require.NoError(t, err)
time.Sleep(3 * time.Second)
_, err = utils.RunCommand("nigiri", "rpc", "--liquid", "generatetoaddress", "100", "el1qqwk722tghgkgmh3r2ph4d2apwj0dy9xnzlenzklx8jg3z299fpaw56trre9gpk6wmw0u4qycajqeva3t7lzp7wnacvwxha59r")
require.NoError(t, err)
time.Sleep(40 * time.Second)
var balance utils.ArkBalance
balanceStr, err := runArkCommand("balance")
require.NoError(t, err)
require.NoError(t, json.Unmarshal([]byte(balanceStr), &balance))
require.Zero(t, balance.Offchain.Total) // all funds should be swept
}
func runArkCommand(arg ...string) (string, error) { func runArkCommand(arg ...string) (string, error) {
args := append([]string{"exec", "-t", "arkd", "ark"}, arg...) args := append([]string{"exec", "-t", "arkd", "ark"}, arg...)
return utils.RunCommand("docker", args...) return utils.RunCommand("docker", args...)
@@ -276,20 +306,12 @@ func setupAspWallet() error {
return fmt.Errorf("failed to parse response: %s", err) return fmt.Errorf("failed to parse response: %s", err)
} }
if _, err := utils.RunCommand("nigiri", "faucet", "--liquid", addr.Address); err != nil { const numberOfFaucet = 6
return fmt.Errorf("failed to fund wallet: %s", err)
} for i := 0; i < numberOfFaucet; i++ {
if _, err := utils.RunCommand("nigiri", "faucet", "--liquid", addr.Address); err != nil { if _, err := utils.RunCommand("nigiri", "faucet", "--liquid", addr.Address); err != nil {
return fmt.Errorf("failed to fund wallet: %s", err) return fmt.Errorf("failed to fund wallet: %s with address %s", err, addr.Address)
} }
if _, err := utils.RunCommand("nigiri", "faucet", "--liquid", addr.Address); err != nil {
return fmt.Errorf("failed to fund wallet: %s", err)
}
if _, err := utils.RunCommand("nigiri", "faucet", "--liquid", addr.Address); err != nil {
return fmt.Errorf("failed to fund wallet: %s", err)
}
if _, err := utils.RunCommand("nigiri", "faucet", "--liquid", addr.Address); err != nil {
return fmt.Errorf("failed to fund wallet: %s", err)
} }
time.Sleep(5 * time.Second) time.Sleep(5 * time.Second)

View File

@@ -362,6 +362,36 @@ func TestAliceSeveralPaymentsToBob(t *testing.T) {
} }
func TestSweep(t *testing.T) {
var receive utils.ArkReceive
receiveStr, err := runClarkCommand("receive")
require.NoError(t, err)
err = json.Unmarshal([]byte(receiveStr), &receive)
require.NoError(t, err)
_, err = utils.RunCommand("nigiri", "faucet", receive.Boarding)
require.NoError(t, err)
time.Sleep(5 * time.Second)
_, err = runClarkCommand("claim", "--password", utils.Password)
require.NoError(t, err)
time.Sleep(3 * time.Second)
_, err = utils.RunCommand("nigiri", "rpc", "generatetoaddress", "100", "bcrt1qe8eelqalnch946nzhefd5ajhgl2afjw5aegc59")
require.NoError(t, err)
time.Sleep(40 * time.Second)
var balance utils.ArkBalance
balanceStr, err := runClarkCommand("balance")
require.NoError(t, err)
require.NoError(t, json.Unmarshal([]byte(balanceStr), &balance))
require.Zero(t, balance.Offchain.Total) // all funds should be swept
}
func runClarkCommand(arg ...string) (string, error) { func runClarkCommand(arg ...string) (string, error) {
args := append([]string{"exec", "-t", "clarkd", "ark"}, arg...) args := append([]string{"exec", "-t", "clarkd", "ark"}, arg...)
return utils.RunCommand("docker", args...) return utils.RunCommand("docker", args...)

View File

@@ -35,6 +35,9 @@ func GenerateBlock() error {
if _, err := RunCommand("nigiri", "rpc", "--liquid", "generatetoaddress", "1", "el1qqwk722tghgkgmh3r2ph4d2apwj0dy9xnzlenzklx8jg3z299fpaw56trre9gpk6wmw0u4qycajqeva3t7lzp7wnacvwxha59r"); err != nil { if _, err := RunCommand("nigiri", "rpc", "--liquid", "generatetoaddress", "1", "el1qqwk722tghgkgmh3r2ph4d2apwj0dy9xnzlenzklx8jg3z299fpaw56trre9gpk6wmw0u4qycajqeva3t7lzp7wnacvwxha59r"); err != nil {
return err return err
} }
if _, err := RunCommand("nigiri", "rpc", "generatetoaddress", "1", "bcrt1qe8eelqalnch946nzhefd5ajhgl2afjw5aegc59"); err != nil {
return err
}
time.Sleep(6 * time.Second) time.Sleep(6 * time.Second)
return nil return nil