Rename folders (#97)

* Rename arkd folder & drop cli

* Rename ark cli folder & update docs

* Update readme

* Fix

* scripts: add build-all

* Add target to build cli for all platforms

* Update build scripts

---------

Co-authored-by: tiero <3596602+tiero@users.noreply.github.com>
Authored by Pietralberto Mazza on 2024-02-09 19:32:58 +01:00, committed by GitHub
parent 0d8c7bffb2
commit dc00d60585
119 changed files with 154 additions and 449 deletions


@@ -0,0 +1,601 @@
package application
import (
"bytes"
"context"
"encoding/hex"
"fmt"
"time"
"github.com/ark-network/ark/common"
"github.com/ark-network/ark/internal/core/domain"
"github.com/ark-network/ark/internal/core/ports"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/decred/dcrd/dcrec/secp256k1/v4"
log "github.com/sirupsen/logrus"
"github.com/vulpemventures/go-elements/network"
"github.com/vulpemventures/go-elements/psetv2"
)
var (
paymentsThreshold = int64(128)
dustAmount = uint64(450)
faucetVtxo = domain.VtxoKey{
Txid: "0000000000000000000000000000000000000000000000000000000000000000",
VOut: 0,
}
)
type Service interface {
Start() error
Stop()
SpendVtxos(ctx context.Context, inputs []domain.VtxoKey) (string, error)
ClaimVtxos(ctx context.Context, creds string, receivers []domain.Receiver) error
SignVtxos(ctx context.Context, forfeitTxs []string) error
FaucetVtxos(ctx context.Context, pubkey *secp256k1.PublicKey) error
GetRoundByTxid(ctx context.Context, poolTxid string) (*domain.Round, error)
GetEventsChannel(ctx context.Context) <-chan domain.RoundEvent
UpdatePaymentStatus(ctx context.Context, id string) error
ListVtxos(ctx context.Context, pubkey *secp256k1.PublicKey) ([]domain.Vtxo, error)
GetPubkey(ctx context.Context) (string, error)
}
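// Illustrative client flow (a sketch, not part of this change): the id
// returned by SpendVtxos doubles as the credentials expected by ClaimVtxos
// and as the payment id to ping via UpdatePaymentStatus:
//
//	id, _ := svc.SpendVtxos(ctx, inputs)
//	_ = svc.ClaimVtxos(ctx, id, receivers)
//	_ = svc.UpdatePaymentStatus(ctx, id) // keep pinging to stay in the round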
type service struct {
network common.Network
onchainNetwork network.Network
pubkey *secp256k1.PublicKey
roundLifetime int64
roundInterval int64
minRelayFee uint64
wallet ports.WalletService
repoManager ports.RepoManager
builder ports.TxBuilder
scanner ports.BlockchainScanner
sweeper *sweeper
paymentRequests *paymentsMap
forfeitTxs *forfeitTxsMap
eventsCh chan domain.RoundEvent
}
func NewService(
network common.Network, onchainNetwork network.Network,
roundInterval, roundLifetime int64, minRelayFee uint64,
walletSvc ports.WalletService, repoManager ports.RepoManager,
builder ports.TxBuilder, scanner ports.BlockchainScanner,
scheduler ports.SchedulerService,
) (Service, error) {
eventsCh := make(chan domain.RoundEvent)
paymentRequests := newPaymentsMap(nil)
genesisHash, _ := chainhash.NewHashFromStr(onchainNetwork.GenesisBlockHash)
forfeitTxs := newForfeitTxsMap(genesisHash)
pubkey, err := walletSvc.GetPubkey(context.Background())
if err != nil {
return nil, fmt.Errorf("failed to fetch pubkey: %s", err)
}
sweeper := newSweeper(walletSvc, repoManager, builder, scheduler)
svc := &service{
network, onchainNetwork, pubkey,
roundLifetime, roundInterval, minRelayFee,
walletSvc, repoManager, builder, scanner, sweeper,
paymentRequests, forfeitTxs, eventsCh,
}
repoManager.RegisterEventsHandler(
func(round *domain.Round) {
svc.updateProjectionStore(round)
svc.propagateEvents(round)
},
)
if err := svc.restoreWatchingVtxos(); err != nil {
return nil, fmt.Errorf("failed to restore watching vtxos: %s", err)
}
go svc.listenToRedemptions()
return svc, nil
}
func (s *service) Start() error {
log.Debug("starting sweeper service")
if err := s.sweeper.start(); err != nil {
return err
}
log.Debug("starting app service")
go s.start()
return nil
}
func (s *service) Stop() {
s.sweeper.stop()
// nolint
vtxos, _ := s.repoManager.Vtxos().GetSpendableVtxos(
context.Background(), "",
)
if len(vtxos) > 0 {
s.stopWatchingVtxos(vtxos)
}
s.wallet.Close()
log.Debug("closed connection to wallet")
s.repoManager.Close()
log.Debug("closed connection to db")
}
func (s *service) SpendVtxos(ctx context.Context, inputs []domain.VtxoKey) (string, error) {
vtxos, err := s.repoManager.Vtxos().GetVtxos(ctx, inputs)
if err != nil {
return "", err
}
for _, v := range vtxos {
if v.Spent {
return "", fmt.Errorf("input %s:%d already spent", v.Txid, v.VOut)
}
}
payment, err := domain.NewPayment(vtxos)
if err != nil {
return "", err
}
if err := s.paymentRequests.push(*payment); err != nil {
return "", err
}
return payment.Id, nil
}
func (s *service) ClaimVtxos(ctx context.Context, creds string, receivers []domain.Receiver) error {
// Check credentials
payment, ok := s.paymentRequests.view(creds)
if !ok {
return fmt.Errorf("invalid credentials")
}
if err := payment.AddReceivers(receivers); err != nil {
return err
}
return s.paymentRequests.update(*payment)
}
func (s *service) UpdatePaymentStatus(_ context.Context, id string) error {
return s.paymentRequests.updatePingTimestamp(id)
}
func (s *service) FaucetVtxos(ctx context.Context, userPubkey *secp256k1.PublicKey) error {
pubkey := hex.EncodeToString(userPubkey.SerializeCompressed())
payment, err := domain.NewPayment([]domain.Vtxo{
{
VtxoKey: faucetVtxo,
Receiver: domain.Receiver{
Pubkey: pubkey,
Amount: 10000,
},
},
})
if err != nil {
return err
}
if err := payment.AddReceivers([]domain.Receiver{
{Pubkey: pubkey, Amount: 1000},
{Pubkey: pubkey, Amount: 1000},
{Pubkey: pubkey, Amount: 1000},
{Pubkey: pubkey, Amount: 1000},
{Pubkey: pubkey, Amount: 1000},
{Pubkey: pubkey, Amount: 1000},
{Pubkey: pubkey, Amount: 1000},
{Pubkey: pubkey, Amount: 1000},
{Pubkey: pubkey, Amount: 1000},
{Pubkey: pubkey, Amount: 1000},
}); err != nil {
return err
}
if err := s.paymentRequests.push(*payment); err != nil {
return err
}
return s.paymentRequests.updatePingTimestamp(payment.Id)
}
func (s *service) SignVtxos(ctx context.Context, forfeitTxs []string) error {
return s.forfeitTxs.sign(forfeitTxs)
}
func (s *service) ListVtxos(ctx context.Context, pubkey *secp256k1.PublicKey) ([]domain.Vtxo, error) {
pk := hex.EncodeToString(pubkey.SerializeCompressed())
return s.repoManager.Vtxos().GetSpendableVtxos(ctx, pk)
}
func (s *service) GetEventsChannel(ctx context.Context) <-chan domain.RoundEvent {
return s.eventsCh
}
func (s *service) GetRoundByTxid(ctx context.Context, poolTxid string) (*domain.Round, error) {
return s.repoManager.Rounds().GetRoundWithTxid(ctx, poolTxid)
}
func (s *service) GetPubkey(ctx context.Context) (string, error) {
pubkey, err := common.EncodePubKey(s.network.PubKey, s.pubkey)
if err != nil {
return "", err
}
return pubkey, nil
}
func (s *service) start() {
s.startRound()
}
func (s *service) startRound() {
round := domain.NewRound(dustAmount)
changes, _ := round.StartRegistration()
if err := s.repoManager.Events().Save(
context.Background(), round.Id, changes...,
); err != nil {
log.WithError(err).Warn("failed to store new round events")
return
}
defer func() {
time.Sleep(time.Duration(s.roundInterval/2) * time.Second)
s.startFinalization()
}()
log.Debugf("started registration stage for new round: %s", round.Id)
}
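// Illustrative timing (derived from the sleeps in startRound and startFinalization):
// with roundInterval = 10, registration lasts ~5s (interval/2), finalization
// ~4s (interval/2 - 1), and finalizeRound immediately starts the next round.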
func (s *service) startFinalization() {
ctx := context.Background()
round, err := s.repoManager.Rounds().GetCurrentRound(ctx)
if err != nil {
log.WithError(err).Warn("failed to retrieve current round")
return
}
var changes []domain.RoundEvent
defer func() {
if len(changes) > 0 {
if err := s.repoManager.Events().Save(ctx, round.Id, changes...); err != nil {
log.WithError(err).Warn("failed to store new round events")
}
}
if round.IsFailed() {
s.startRound()
return
}
time.Sleep(time.Duration((s.roundInterval/2)-1) * time.Second)
s.finalizeRound()
}()
if round.IsFailed() {
return
}
// TODO: understand how many payments must be popped from the queue and actually registered for the round
num := s.paymentRequests.len()
if num == 0 {
err := fmt.Errorf("no payments registered")
changes = round.Fail(fmt.Errorf("round aborted: %s", err))
log.WithError(err).Debugf("round %s aborted", round.Id)
return
}
if num > paymentsThreshold {
num = paymentsThreshold
}
payments := s.paymentRequests.pop(num)
changes, err = round.RegisterPayments(payments)
if err != nil {
changes = round.Fail(fmt.Errorf("failed to register payments: %s", err))
log.WithError(err).Warn("failed to register payments")
return
}
unsignedPoolTx, tree, err := s.builder.BuildPoolTx(s.pubkey, payments, s.minRelayFee)
if err != nil {
changes = round.Fail(fmt.Errorf("failed to create pool tx: %s", err))
log.WithError(err).Warn("failed to create pool tx")
return
}
log.Debugf("pool tx created for round %s", round.Id)
connectors, forfeitTxs, err := s.builder.BuildForfeitTxs(s.pubkey, unsignedPoolTx, payments)
if err != nil {
changes = round.Fail(fmt.Errorf("failed to create connectors and forfeit txs: %s", err))
log.WithError(err).Warn("failed to create connectors and forfeit txs")
return
}
log.Debugf("forfeit transactions created for round %s", round.Id)
events, err := round.StartFinalization(connectors, tree, unsignedPoolTx)
if err != nil {
changes = round.Fail(fmt.Errorf("failed to start finalization: %s", err))
log.WithError(err).Warn("failed to start finalization")
return
}
changes = append(changes, events...)
s.forfeitTxs.push(forfeitTxs)
log.Debugf("started finalization stage for round: %s", round.Id)
}
func (s *service) finalizeRound() {
defer s.startRound()
ctx := context.Background()
round, err := s.repoManager.Rounds().GetCurrentRound(ctx)
if err != nil {
log.WithError(err).Warn("failed to retrieve current round")
return
}
if round.IsFailed() {
return
}
var changes []domain.RoundEvent
defer func() {
if err := s.repoManager.Events().Save(ctx, round.Id, changes...); err != nil {
log.WithError(err).Warn("failed to store new round events")
return
}
}()
forfeitTxs, leftUnsigned := s.forfeitTxs.pop()
if len(leftUnsigned) > 0 {
err := fmt.Errorf("%d forfeit txs left to sign", len(leftUnsigned))
changes = round.Fail(fmt.Errorf("failed to finalize round: %s", err))
log.WithError(err).Warn("failed to finalize round")
return
}
log.Debugf("signing round transaction %s\n", round.Id)
signedPoolTx, err := s.wallet.SignPset(ctx, round.UnsignedTx, true)
if err != nil {
changes = round.Fail(fmt.Errorf("failed to sign round tx: %s", err))
log.WithError(err).Warn("failed to sign round tx")
return
}
txid, err := s.wallet.BroadcastTransaction(ctx, signedPoolTx)
if err != nil {
changes = round.Fail(fmt.Errorf("failed to broadcast pool tx: %s", err))
log.WithError(err).Warn("failed to broadcast pool tx")
return
}
now := time.Now().Unix()
expirationTimestamp := now + s.roundLifetime + 30 // add 30 secs to be sure that the tx is confirmed
if err := s.sweeper.schedule(expirationTimestamp, txid, round.CongestionTree); err != nil {
changes = round.Fail(fmt.Errorf("failed to schedule sweep tx: %s", err))
log.WithError(err).Warn("failed to schedule sweep tx")
return
}
changes, _ = round.EndFinalization(forfeitTxs, txid)
log.Debugf("finalized round %s with pool tx %s", round.Id, round.Txid)
}
func (s *service) listenToRedemptions() {
ctx := context.Background()
chVtxos := s.scanner.GetNotificationChannel(ctx)
for vtxoKeys := range chVtxos {
if len(vtxoKeys) > 0 {
for {
// TODO: make sure that the vtxos haven't been already spent, otherwise
// broadcast the corresponding forfeit tx and connector to prevent
// getting cheated.
vtxos, err := s.repoManager.Vtxos().RedeemVtxos(ctx, vtxoKeys)
if err != nil {
log.WithError(err).Warn("failed to redeem vtxos, retrying...")
time.Sleep(100 * time.Millisecond)
continue
}
if len(vtxos) > 0 {
log.Debugf("redeemed %d vtxos", len(vtxos))
}
break
}
}
}
}
func (s *service) updateProjectionStore(round *domain.Round) {
ctx := context.Background()
lastChange := round.Events()[len(round.Events())-1]
// Update the vtxo set only after a round is finalized.
if _, ok := lastChange.(domain.RoundFinalized); ok {
repo := s.repoManager.Vtxos()
spentVtxos := getSpentVtxos(round.Payments)
if len(spentVtxos) > 0 {
for {
if err := repo.SpendVtxos(ctx, spentVtxos); err != nil {
log.WithError(err).Warn("failed to add new vtxos, retrying soon")
time.Sleep(100 * time.Millisecond)
continue
}
log.Debugf("spent %d vtxos", len(spentVtxos))
break
}
}
newVtxos := s.getNewVtxos(round)
for {
if err := repo.AddVtxos(ctx, newVtxos); err != nil {
log.WithError(err).Warn("failed to add new vtxos, retrying soon")
time.Sleep(100 * time.Millisecond)
continue
}
log.Debugf("added %d new vtxos", len(newVtxos))
break
}
go func() {
for {
if err := s.startWatchingVtxos(newVtxos); err != nil {
log.WithError(err).Warn(
"failed to start watching vtxos, retrying in a moment...",
)
time.Sleep(100 * time.Millisecond)
continue
}
log.Debugf("started watching %d vtxos", len(newVtxos))
return
}
}()
}
// Always update the status of the round.
for {
if err := s.repoManager.Rounds().AddOrUpdateRound(ctx, *round); err != nil {
log.WithError(err).Warn("failed to store round, retrying soon")
time.Sleep(100 * time.Millisecond)
continue
}
break
}
}
func (s *service) propagateEvents(round *domain.Round) {
lastEvent := round.Events()[len(round.Events())-1]
switch e := lastEvent.(type) {
case domain.RoundFinalizationStarted:
forfeitTxs := s.forfeitTxs.view()
s.eventsCh <- domain.RoundFinalizationStarted{
Id: e.Id,
CongestionTree: e.CongestionTree,
Connectors: e.Connectors,
PoolTx: e.PoolTx,
UnsignedForfeitTxs: forfeitTxs,
}
case domain.RoundFinalized, domain.RoundFailed:
s.eventsCh <- e
}
}
func (s *service) getNewVtxos(round *domain.Round) []domain.Vtxo {
leaves := round.CongestionTree.Leaves()
vtxos := make([]domain.Vtxo, 0)
for _, node := range leaves {
tx, _ := psetv2.NewPsetFromBase64(node.Tx)
for i, out := range tx.Outputs {
for _, p := range round.Payments {
var pubkey string
found := false
for _, r := range p.Receivers {
if r.IsOnchain() {
continue
}
buf, _ := hex.DecodeString(r.Pubkey)
pk, _ := secp256k1.ParsePubKey(buf)
script, _ := s.builder.GetVtxoScript(pk, s.pubkey)
if bytes.Equal(script, out.Script) {
found = true
pubkey = r.Pubkey
break
}
}
if found {
vtxos = append(vtxos, domain.Vtxo{
VtxoKey: domain.VtxoKey{Txid: node.Txid, VOut: uint32(i)},
Receiver: domain.Receiver{Pubkey: pubkey, Amount: out.Value},
PoolTx: round.Txid,
})
break
}
}
}
}
return vtxos
}
func (s *service) startWatchingVtxos(vtxos []domain.Vtxo) error {
scripts, err := s.extractVtxosScripts(vtxos)
if err != nil {
return err
}
return s.scanner.WatchScripts(context.Background(), scripts)
}
func (s *service) stopWatchingVtxos(vtxos []domain.Vtxo) {
scripts, err := s.extractVtxosScripts(vtxos)
if err != nil {
log.WithError(err).Warn("failed to extract scripts from vtxos")
return
}
for {
if err := s.scanner.UnwatchScripts(context.Background(), scripts); err != nil {
log.WithError(err).Warn("failed to stop watching vtxos, retrying in a moment...")
time.Sleep(100 * time.Millisecond)
continue
}
log.Debugf("stopped watching %d vtxos", len(vtxos))
break
}
}
func (s *service) restoreWatchingVtxos() error {
vtxos, err := s.repoManager.Vtxos().GetSpendableVtxos(
context.Background(), "",
)
if err != nil {
return err
}
if len(vtxos) <= 0 {
return nil
}
if err := s.startWatchingVtxos(vtxos); err != nil {
return err
}
log.Debugf("restored watching %d vtxos", len(vtxos))
return nil
}
func (s *service) extractVtxosScripts(vtxos []domain.Vtxo) ([]string, error) {
indexedScripts := make(map[string]struct{})
for _, vtxo := range vtxos {
buf, err := hex.DecodeString(vtxo.Pubkey)
if err != nil {
return nil, err
}
userPubkey, err := secp256k1.ParsePubKey(buf)
if err != nil {
return nil, err
}
script, err := s.builder.GetVtxoScript(userPubkey, s.pubkey)
if err != nil {
return nil, err
}
indexedScripts[hex.EncodeToString(script)] = struct{}{}
}
scripts := make([]string, 0, len(indexedScripts))
for script := range indexedScripts {
scripts = append(scripts, script)
}
return scripts, nil
}
func getSpentVtxos(payments map[string]domain.Payment) []domain.VtxoKey {
vtxos := make([]domain.VtxoKey, 0)
for _, p := range payments {
for _, vtxo := range p.Inputs {
if vtxo.VtxoKey == faucetVtxo {
continue
}
vtxos = append(vtxos, vtxo.VtxoKey)
}
}
return vtxos
}


@@ -0,0 +1,579 @@
package application
import (
"context"
"encoding/hex"
"fmt"
"time"
"github.com/ark-network/ark/common/tree"
"github.com/ark-network/ark/internal/core/domain"
"github.com/ark-network/ark/internal/core/ports"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/decred/dcrd/dcrec/secp256k1/v4"
log "github.com/sirupsen/logrus"
"github.com/vulpemventures/go-elements/psetv2"
)
// sweeper is an unexported service running while the main application service is started.
// It is responsible for sweeping onchain shared outputs that have expired.
// It also handles delaying the sweep events in case some parts of the tree have been broadcast.
// When a round is finalized, the main application service schedules a sweep event on the newly created congestion tree.
type sweeper struct {
wallet ports.WalletService
repoManager ports.RepoManager
builder ports.TxBuilder
scheduler ports.SchedulerService
// cache of scheduled tasks, to avoid scheduling the same sweep event multiple times
scheduledTasks map[string]struct{}
}
func newSweeper(
wallet ports.WalletService,
repoManager ports.RepoManager,
builder ports.TxBuilder,
scheduler ports.SchedulerService,
) *sweeper {
return &sweeper{
wallet,
repoManager,
builder,
scheduler,
make(map[string]struct{}),
}
}
func (s *sweeper) start() error {
s.scheduler.Start()
allRounds, err := s.repoManager.Rounds().GetSweepableRounds(context.Background())
if err != nil {
return err
}
for _, round := range allRounds {
task := s.createTask(round.Txid, round.CongestionTree)
task()
}
return nil
}
func (s *sweeper) stop() {
s.scheduler.Stop()
}
// removeTask updates the cached map of scheduled tasks
func (s *sweeper) removeTask(treeRootTxid string) {
delete(s.scheduledTasks, treeRootTxid)
}
// schedule sets up a task to be executed once at the given timestamp
func (s *sweeper) schedule(
expirationTimestamp int64, roundTxid string, congestionTree tree.CongestionTree,
) error {
root, err := congestionTree.Root()
if err != nil {
return err
}
if _, scheduled := s.scheduledTasks[root.Txid]; scheduled {
return nil
}
task := s.createTask(roundTxid, congestionTree)
fancyTime := time.Unix(expirationTimestamp, 0).Format("2006-01-02 15:04:05")
log.Debugf("scheduled sweep task at %s", fancyTime)
if err := s.scheduler.ScheduleTaskOnce(expirationTimestamp, task); err != nil {
return err
}
s.scheduledTasks[root.Txid] = struct{}{}
return nil
}
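// A minimal ports.SchedulerService behind ScheduleTaskOnce could be built on
// time.AfterFunc (an illustrative sketch; the real interface lives in the
// ports package and may differ):
//
//	type timeScheduler struct{}
//
//	func (s *timeScheduler) ScheduleTaskOnce(at int64, task func()) error {
//		delay := time.Until(time.Unix(at, 0))
//		if delay < 0 {
//			return fmt.Errorf("cannot schedule a task in the past")
//		}
//		time.AfterFunc(delay, task) // fires once, on its own goroutine
//		return nil
//	}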
// createTask returns a function to be passed as handler to the scheduler.
// It tries to craft a sweep tx spending the onchain shared outputs of the given congestion tree.
// If some parts of the tree have been broadcast in the meantime, it schedules the next tasks for the remaining parts of the tree.
func (s *sweeper) createTask(
roundTxid string, congestionTree tree.CongestionTree,
) func() {
return func() {
ctx := context.Background()
root, err := congestionTree.Root()
if err != nil {
log.WithError(err).Error("error while getting root node")
return
}
s.removeTask(root.Txid)
log.Debugf("sweeper: %s", root.Txid)
sweepInputs := make([]ports.SweepInput, 0)
vtxoKeys := make([]domain.VtxoKey, 0) // vtxos associated to the sweep inputs
// inspect the congestion tree to find onchain shared outputs
sharedOutputs, err := s.findSweepableOutputs(ctx, congestionTree)
if err != nil {
log.WithError(err).Error("error while inspecting congestion tree")
return
}
for expiredAt, inputs := range sharedOutputs {
// if the shared output is not expired yet, schedule a sweep task for it
if time.Unix(expiredAt, 0).After(time.Now()) {
subtrees, err := computeSubTrees(congestionTree, inputs)
if err != nil {
log.WithError(err).Error("error while computing subtrees")
continue
}
for _, subTree := range subtrees {
// schedule the task at the expiration time; this mitigates the risk of BIP68 non-final errors
if err := s.schedule(int64(expiredAt), roundTxid, subTree); err != nil {
log.WithError(err).Error("error while scheduling sweep task")
continue
}
}
continue
}
// iterate over the expired shared outputs
for _, input := range inputs {
// sweepableVtxos related to the sweep input
sweepableVtxos := make([]domain.VtxoKey, 0)
// check if input is the vtxo itself
vtxos, _ := s.repoManager.Vtxos().GetVtxos(
ctx,
[]domain.VtxoKey{
{
Txid: input.InputArgs.Txid,
VOut: input.InputArgs.TxIndex,
},
},
)
if len(vtxos) > 0 {
if !vtxos[0].Swept && !vtxos[0].Redeemed {
sweepableVtxos = append(sweepableVtxos, vtxos[0].VtxoKey)
}
} else {
// if it's not a vtxo, find all the vtxo leaves reachable from that input
vtxosLeaves, err := congestionTree.FindLeaves(input.InputArgs.Txid, input.InputArgs.TxIndex)
if err != nil {
log.WithError(err).Error("error while finding vtxos leaves")
continue
}
for _, leaf := range vtxosLeaves {
pset, err := psetv2.NewPsetFromBase64(leaf.Tx)
if err != nil {
log.Error(fmt.Errorf("error while decoding pset: %w", err))
continue
}
vtxo, err := extractVtxoOutpoint(pset)
if err != nil {
log.Error(err)
continue
}
sweepableVtxos = append(sweepableVtxos, *vtxo)
}
if len(sweepableVtxos) <= 0 {
continue
}
firstVtxo, err := s.repoManager.Vtxos().GetVtxos(ctx, sweepableVtxos[:1])
if err != nil {
log.Error(fmt.Errorf("error while getting vtxo: %w", err))
sweepInputs = append(sweepInputs, input) // add the input anyway in order to try to sweep it
continue
}
if firstVtxo[0].Swept || firstVtxo[0].Redeemed {
// we assume that if the first vtxo is swept or redeemed, the shared output has been spent
// skip, the output is already swept or spent by a unilateral redeem
continue
}
}
if len(sweepableVtxos) > 0 {
vtxoKeys = append(vtxoKeys, sweepableVtxos...)
sweepInputs = append(sweepInputs, input)
}
}
}
if len(sweepInputs) > 0 {
// build the sweep transaction with all the expired non-swept shared outputs
sweepTx, err := s.builder.BuildSweepTx(s.wallet, sweepInputs)
if err != nil {
log.WithError(err).Error("error while building sweep tx")
return
}
err = nil
txid := ""
// retry until the tx is broadcast or the error is something other than non-BIP68-final
for len(txid) == 0 && (err == nil || err.Error() == "non-BIP68-final") {
if err != nil {
log.Debugln("sweep tx not BIP68 final, retrying in 5 seconds")
time.Sleep(5 * time.Second)
}
txid, err = s.wallet.BroadcastTransaction(ctx, sweepTx)
}
if err != nil {
log.WithError(err).Error("error while broadcasting sweep tx")
return
}
if len(txid) > 0 {
log.Debugln("sweep tx broadcasted:", txid)
vtxosRepository := s.repoManager.Vtxos()
// mark the vtxos as swept
if err := vtxosRepository.SweepVtxos(ctx, vtxoKeys); err != nil {
log.Error(fmt.Errorf("error while deleting vtxos: %w", err))
return
}
log.Debugf("%d vtxos swept", len(vtxoKeys))
roundVtxos, err := vtxosRepository.GetVtxosForRound(ctx, roundTxid)
if err != nil {
log.WithError(err).Error("error while getting vtxos for round")
return
}
allSwept := true
for _, vtxo := range roundVtxos {
allSwept = allSwept && vtxo.Swept
if !allSwept {
break
}
}
if allSwept {
// update the round
roundRepo := s.repoManager.Rounds()
round, err := roundRepo.GetRoundWithTxid(ctx, roundTxid)
if err != nil {
log.WithError(err).Error("error while getting round")
return
}
round.Sweep()
if err := roundRepo.AddOrUpdateRound(ctx, *round); err != nil {
log.WithError(err).Error("error while marking round as swept")
return
}
}
}
}
}
}
// findSweepableOutputs iterates over all the nodes' outputs in the congestion tree and checks their onchain state.
// It returns the sweepable outputs as ports.SweepInput mapped by their expiration time.
func (s *sweeper) findSweepableOutputs(
ctx context.Context,
congestionTree tree.CongestionTree,
) (map[int64][]ports.SweepInput, error) {
sweepableOutputs := make(map[int64][]ports.SweepInput)
blocktimeCache := make(map[string]int64) // txid -> blocktime
nodesToCheck := congestionTree[0] // init with the root
for len(nodesToCheck) > 0 {
newNodesToCheck := make([]tree.Node, 0)
for _, node := range nodesToCheck {
isPublished, blocktime, err := s.wallet.IsTransactionPublished(ctx, node.Txid)
if err != nil {
return nil, err
}
var expirationTime int64
var sweepInputs []ports.SweepInput
if !isPublished {
if _, ok := blocktimeCache[node.ParentTxid]; !ok {
isPublished, blocktime, err := s.wallet.IsTransactionPublished(ctx, node.ParentTxid)
if !isPublished || err != nil {
return nil, fmt.Errorf("tx %s not found", node.Txid)
}
blocktimeCache[node.ParentTxid] = blocktime
}
expirationTime, sweepInputs, err = s.nodeToSweepInputs(blocktimeCache[node.ParentTxid], node)
if err != nil {
return nil, err
}
} else {
// cache the blocktime for future use
blocktimeCache[node.Txid] = int64(blocktime)
// if the tx is onchain, it means that its input has been spent:
// add the children to the list of nodes to check during the next iteration
if !node.Leaf {
children := congestionTree.Children(node.Txid)
newNodesToCheck = append(newNodesToCheck, children...)
continue
}
// if the node is a leaf, its vtxo output should be added as a sweepable output if not swept yet
vtxoExpiration, sweepInput, err := s.leafToSweepInput(ctx, blocktime, node)
if err != nil {
return nil, err
}
if sweepInput != nil {
expirationTime = vtxoExpiration
sweepInputs = []ports.SweepInput{*sweepInput}
}
}
if _, ok := sweepableOutputs[expirationTime]; !ok {
sweepableOutputs[expirationTime] = make([]ports.SweepInput, 0)
}
sweepableOutputs[expirationTime] = append(sweepableOutputs[expirationTime], sweepInputs...)
}
nodesToCheck = newNodesToCheck
}
return sweepableOutputs, nil
}
func (s *sweeper) leafToSweepInput(ctx context.Context, txBlocktime int64, node tree.Node) (int64, *ports.SweepInput, error) {
pset, err := psetv2.NewPsetFromBase64(node.Tx)
if err != nil {
return -1, nil, err
}
vtxo, err := extractVtxoOutpoint(pset)
if err != nil {
return -1, nil, err
}
fromRepo, err := s.repoManager.Vtxos().GetVtxos(ctx, []domain.VtxoKey{*vtxo})
if err != nil {
return -1, nil, err
}
if len(fromRepo) == 0 {
return -1, nil, fmt.Errorf("vtxo not found")
}
if fromRepo[0].Swept {
return -1, nil, nil
}
if fromRepo[0].Redeemed {
return -1, nil, nil
}
// if the vtxo is not swept or redeemed, add it to the onchain outputs
pubKeyBytes, err := hex.DecodeString(fromRepo[0].Pubkey)
if err != nil {
return -1, nil, err
}
pubKey, err := secp256k1.ParsePubKey(pubKeyBytes)
if err != nil {
return -1, nil, err
}
sweepLeaf, lifetime, err := s.builder.GetLeafSweepClosure(node, pubKey)
if err != nil {
return -1, nil, err
}
sweepInput := ports.SweepInput{
InputArgs: psetv2.InputArgs{
Txid: vtxo.Txid,
TxIndex: vtxo.VOut,
},
SweepLeaf: *sweepLeaf,
Amount: fromRepo[0].Amount,
}
return txBlocktime + lifetime, &sweepInput, nil
}
func (s *sweeper) nodeToSweepInputs(parentBlocktime int64, node tree.Node) (int64, []ports.SweepInput, error) {
pset, err := psetv2.NewPsetFromBase64(node.Tx)
if err != nil {
return -1, nil, err
}
if len(pset.Inputs) != 1 {
return -1, nil, fmt.Errorf("invalid node pset, expect 1 input, got %d", len(pset.Inputs))
}
// if the tx is not onchain, it means that the input is an existing shared output
input := pset.Inputs[0]
txid := chainhash.Hash(input.PreviousTxid).String()
index := input.PreviousTxIndex
sweepLeaf, lifetime, err := extractSweepLeaf(input)
if err != nil {
return -1, nil, err
}
expirationTime := parentBlocktime + lifetime
amount := uint64(0)
for _, out := range pset.Outputs {
amount += out.Value
}
sweepInputs := []ports.SweepInput{
{
InputArgs: psetv2.InputArgs{
Txid: txid,
TxIndex: index,
},
SweepLeaf: *sweepLeaf,
Amount: amount,
},
}
return expirationTime, sweepInputs, nil
}
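// Worked example: a node whose parent confirmed at blocktime 1700000000 with a
// sweep leaf lifetime of 604800 seconds (one week) yields an expiration time of
// 1700604800, the earliest moment the sweep path becomes spendable under BIP68.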
func computeSubTrees(congestionTree tree.CongestionTree, inputs []ports.SweepInput) ([]tree.CongestionTree, error) {
subTrees := make(map[string]tree.CongestionTree, 0)
// for each sweepable input, create a sub congestion tree;
// this lets the next task skip the parts of the tree that have already been broadcast
for _, input := range inputs {
subTree, err := computeSubTree(congestionTree, input.InputArgs.Txid)
if err != nil {
log.WithError(err).Error("error while finding sub tree")
continue
}
root, err := subTree.Root()
if err != nil {
log.WithError(err).Error("error while getting root node")
continue
}
subTrees[root.Txid] = subTree
}
// filter the sub trees, removing the ones included in others
filteredSubTrees := make([]tree.CongestionTree, 0)
for i, subTree := range subTrees {
notIncludedInOtherTrees := true
for j, otherSubTree := range subTrees {
if i == j {
continue
}
contains, err := containsTree(otherSubTree, subTree)
if err != nil {
log.WithError(err).Error("error while checking if a tree contains another")
continue
}
if contains {
notIncludedInOtherTrees = false
break
}
}
if notIncludedInOtherTrees {
filteredSubTrees = append(filteredSubTrees, subTree)
}
}
return filteredSubTrees, nil
}
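// Illustrative example: if two sweep inputs point at sibling branches of the
// tree, computeSubTree yields one subtree per branch; the containment filter
// above keeps both, but would drop either one if it were nested in the other.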
func computeSubTree(congestionTree tree.CongestionTree, newRoot string) (tree.CongestionTree, error) {
for _, level := range congestionTree {
for _, node := range level {
if node.Txid == newRoot || node.ParentTxid == newRoot {
newTree := make(tree.CongestionTree, 0)
newTree = append(newTree, []tree.Node{node})
children := congestionTree.Children(node.Txid)
for len(children) > 0 {
newTree = append(newTree, children)
newChildren := make([]tree.Node, 0)
for _, child := range children {
newChildren = append(newChildren, congestionTree.Children(child.Txid)...)
}
children = newChildren
}
return newTree, nil
}
}
}
return nil, fmt.Errorf("failed to create subtree, new root not found")
}
func containsTree(tr0 tree.CongestionTree, tr1 tree.CongestionTree) (bool, error) {
tr1Root, err := tr1.Root()
if err != nil {
return false, err
}
for _, level := range tr0 {
for _, node := range level {
if node.Txid == tr1Root.Txid {
return true, nil
}
}
}
return false, nil
}
// extractSweepLeaf searches the given congestion tree input for the sweep leaf and returns it along with its lifetime in seconds
func extractSweepLeaf(input psetv2.Input) (sweepLeaf *psetv2.TapLeafScript, lifetime int64, err error) {
for _, leaf := range input.TapLeafScript {
isSweep, _, seconds, err := tree.DecodeSweepScript(leaf.Script)
if err != nil {
return nil, 0, err
}
if isSweep {
lifetime = int64(seconds)
sweepLeaf = &leaf
break
}
}
if sweepLeaf == nil {
return nil, 0, fmt.Errorf("sweep leaf not found")
}
return sweepLeaf, lifetime, nil
}
// assuming the pset is a leaf in the congestion tree, returns the vtxo outpoint (output 0 of the unsigned tx)
func extractVtxoOutpoint(pset *psetv2.Pset) (*domain.VtxoKey, error) {
if len(pset.Outputs) != 2 {
return nil, fmt.Errorf("invalid leaf pset, expect 2 outputs, got %d", len(pset.Outputs))
}
utx, err := pset.UnsignedTx()
if err != nil {
return nil, err
}
return &domain.VtxoKey{
Txid: utx.TxHash().String(),
VOut: 0,
}, nil
}


@@ -0,0 +1,254 @@
package application
import (
"bytes"
"encoding/hex"
"fmt"
"sort"
"sync"
"time"
"github.com/ark-network/ark/common"
"github.com/ark-network/ark/internal/core/domain"
"github.com/btcsuite/btcd/btcec/v2/schnorr"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/vulpemventures/go-elements/psetv2"
)
type timedPayment struct {
domain.Payment
timestamp time.Time
pingTimestamp time.Time
}
type paymentsMap struct {
lock *sync.RWMutex
payments map[string]*timedPayment
}
func newPaymentsMap(payments []domain.Payment) *paymentsMap {
paymentsById := make(map[string]*timedPayment)
for _, p := range payments {
paymentsById[p.Id] = &timedPayment{p, time.Now(), time.Time{}}
}
lock := &sync.RWMutex{}
return &paymentsMap{lock, paymentsById}
}
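// Illustrative lifecycle (a sketch of how the application service drives this
// map, per the service code in this commit):
//
//	m := newPaymentsMap(nil)
//	_ = m.push(payment)                   // SpendVtxos: register inputs
//	_ = m.update(paymentWithReceivers)    // ClaimVtxos: attach receivers
//	_ = m.updatePingTimestamp(payment.Id) // UpdatePaymentStatus: stay online
//	batch := m.pop(paymentsThreshold)     // startFinalization: oldest first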
func (m *paymentsMap) len() int64 {
m.lock.RLock()
defer m.lock.RUnlock()
count := int64(0)
for _, p := range m.payments {
if len(p.Receivers) > 0 {
count++
}
}
return count
}
func (m *paymentsMap) push(payment domain.Payment) error {
m.lock.Lock()
defer m.lock.Unlock()
if _, ok := m.payments[payment.Id]; ok {
return fmt.Errorf("duplicated inputs")
}
m.payments[payment.Id] = &timedPayment{payment, time.Now(), time.Time{}}
return nil
}
func (m *paymentsMap) pop(num int64) []domain.Payment {
m.lock.Lock()
defer m.lock.Unlock()
paymentsByTime := make([]timedPayment, 0, len(m.payments))
for _, p := range m.payments {
// Skip payments without registered receivers.
if len(p.Receivers) <= 0 {
continue
}
// Skip payments whose owner didn't signal to be online within the last minute.
if p.pingTimestamp.IsZero() || time.Since(p.pingTimestamp).Minutes() > 1 {
continue
}
paymentsByTime = append(paymentsByTime, *p)
}
sort.SliceStable(paymentsByTime, func(i, j int) bool {
return paymentsByTime[i].timestamp.Before(paymentsByTime[j].timestamp)
})
if num < 0 || num > int64(len(paymentsByTime)) {
num = int64(len(paymentsByTime))
}
payments := make([]domain.Payment, 0, num)
for _, p := range paymentsByTime[:num] {
payments = append(payments, p.Payment)
delete(m.payments, p.Id)
}
return payments
}
func (m *paymentsMap) update(payment domain.Payment) error {
m.lock.Lock()
defer m.lock.Unlock()
p, ok := m.payments[payment.Id]
if !ok {
return fmt.Errorf("payment %s not found", payment.Id)
}
p.Payment = payment
return nil
}
func (m *paymentsMap) updatePingTimestamp(id string) error {
m.lock.Lock()
defer m.lock.Unlock()
payment, ok := m.payments[id]
if !ok {
return fmt.Errorf("payment %s not found", id)
}
payment.pingTimestamp = time.Now()
return nil
}
func (m *paymentsMap) view(id string) (*domain.Payment, bool) {
m.lock.RLock()
defer m.lock.RUnlock()
payment, ok := m.payments[id]
if !ok {
return nil, false
}
return &domain.Payment{
Id: payment.Id,
Inputs: payment.Inputs,
Receivers: payment.Receivers,
}, true
}
type signedTx struct {
tx string
signed bool
}
type forfeitTxsMap struct {
lock *sync.RWMutex
forfeitTxs map[string]*signedTx
genesisBlockHash *chainhash.Hash
}
func newForfeitTxsMap(genesisBlockHash *chainhash.Hash) *forfeitTxsMap {
return &forfeitTxsMap{&sync.RWMutex{}, make(map[string]*signedTx), genesisBlockHash}
}
func (m *forfeitTxsMap) push(txs []string) {
m.lock.Lock()
defer m.lock.Unlock()
faucetTxID, _ := hex.DecodeString(faucetVtxo.Txid)
for _, tx := range txs {
ptx, _ := psetv2.NewPsetFromBase64(tx)
utx, _ := ptx.UnsignedTx()
signed := false
// find the forfeit txs spending the faucet vtxo and mark them as signed
for _, input := range ptx.Inputs {
if bytes.Equal(input.PreviousTxid, faucetTxID) {
signed = true
break
}
}
m.forfeitTxs[utx.TxHash().String()] = &signedTx{tx, signed}
}
}
func (m *forfeitTxsMap) sign(txs []string) error {
m.lock.Lock()
defer m.lock.Unlock()
for _, tx := range txs {
ptx, _ := psetv2.NewPsetFromBase64(tx)
utx, _ := ptx.UnsignedTx()
txid := utx.TxHash().String()
if _, ok := m.forfeitTxs[txid]; ok {
for index, input := range ptx.Inputs {
if len(input.TapScriptSig) > 0 {
for _, tapScriptSig := range input.TapScriptSig {
leafHash, err := chainhash.NewHash(tapScriptSig.LeafHash)
if err != nil {
return err
}
preimage, err := common.TaprootPreimage(
m.genesisBlockHash,
ptx,
index,
leafHash,
)
if err != nil {
return err
}
sig, err := schnorr.ParseSignature(tapScriptSig.Signature)
if err != nil {
return err
}
pubkey, err := schnorr.ParsePubKey(tapScriptSig.PubKey)
if err != nil {
return err
}
if sig.Verify(preimage, pubkey) {
m.forfeitTxs[txid].signed = true
} else {
return fmt.Errorf("invalid signature")
}
}
}
}
}
}
return nil
}
func (m *forfeitTxsMap) pop() (signed, unsigned []string) {
m.lock.Lock()
defer m.lock.Unlock()
for _, t := range m.forfeitTxs {
if t.signed {
signed = append(signed, t.tx)
} else {
unsigned = append(unsigned, t.tx)
}
}
m.forfeitTxs = make(map[string]*signedTx)
return signed, unsigned
}
func (m *forfeitTxsMap) view() []string {
m.lock.RLock()
defer m.lock.RUnlock()
txs := make([]string, 0, len(m.forfeitTxs))
for _, tx := range m.forfeitTxs {
txs = append(txs, tx.tx)
}
return txs
}