mirror of
https://github.com/aljazceru/ark.git
synced 2025-12-18 12:44:19 +01:00
Rename folders (#97)
* Rename arkd folder & drop cli * Rename ark cli folder & update docs * Update readme * Fix * scripts: add build-all * Add target to build cli for all platforms * Update build scripts --------- Co-authored-by: tiero <3596602+tiero@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
0d8c7bffb2
commit
dc00d60585
601
server/internal/core/application/service.go
Normal file
601
server/internal/core/application/service.go
Normal file
@@ -0,0 +1,601 @@
|
||||
package application
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/ark-network/ark/common"
|
||||
"github.com/ark-network/ark/internal/core/domain"
|
||||
"github.com/ark-network/ark/internal/core/ports"
|
||||
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
||||
"github.com/decred/dcrd/dcrec/secp256k1/v4"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/vulpemventures/go-elements/network"
|
||||
"github.com/vulpemventures/go-elements/psetv2"
|
||||
)
|
||||
|
||||
var (
|
||||
paymentsThreshold = int64(128)
|
||||
dustAmount = uint64(450)
|
||||
faucetVtxo = domain.VtxoKey{
|
||||
Txid: "0000000000000000000000000000000000000000000000000000000000000000",
|
||||
VOut: 0,
|
||||
}
|
||||
)
|
||||
|
||||
type Service interface {
|
||||
Start() error
|
||||
Stop()
|
||||
SpendVtxos(ctx context.Context, inputs []domain.VtxoKey) (string, error)
|
||||
ClaimVtxos(ctx context.Context, creds string, receivers []domain.Receiver) error
|
||||
SignVtxos(ctx context.Context, forfeitTxs []string) error
|
||||
FaucetVtxos(ctx context.Context, pubkey *secp256k1.PublicKey) error
|
||||
GetRoundByTxid(ctx context.Context, poolTxid string) (*domain.Round, error)
|
||||
GetEventsChannel(ctx context.Context) <-chan domain.RoundEvent
|
||||
UpdatePaymentStatus(ctx context.Context, id string) error
|
||||
ListVtxos(ctx context.Context, pubkey *secp256k1.PublicKey) ([]domain.Vtxo, error)
|
||||
GetPubkey(ctx context.Context) (string, error)
|
||||
}
|
||||
|
||||
type service struct {
|
||||
network common.Network
|
||||
onchainNework network.Network
|
||||
pubkey *secp256k1.PublicKey
|
||||
roundLifetime int64
|
||||
roundInterval int64
|
||||
minRelayFee uint64
|
||||
|
||||
wallet ports.WalletService
|
||||
repoManager ports.RepoManager
|
||||
builder ports.TxBuilder
|
||||
scanner ports.BlockchainScanner
|
||||
sweeper *sweeper
|
||||
|
||||
paymentRequests *paymentsMap
|
||||
forfeitTxs *forfeitTxsMap
|
||||
|
||||
eventsCh chan domain.RoundEvent
|
||||
}
|
||||
|
||||
func NewService(
|
||||
network common.Network, onchainNetwork network.Network,
|
||||
roundInterval, roundLifetime int64, minRelayFee uint64,
|
||||
walletSvc ports.WalletService, repoManager ports.RepoManager,
|
||||
builder ports.TxBuilder, scanner ports.BlockchainScanner,
|
||||
scheduler ports.SchedulerService,
|
||||
) (Service, error) {
|
||||
eventsCh := make(chan domain.RoundEvent)
|
||||
paymentRequests := newPaymentsMap(nil)
|
||||
|
||||
genesisHash, _ := chainhash.NewHashFromStr(onchainNetwork.GenesisBlockHash)
|
||||
forfeitTxs := newForfeitTxsMap(genesisHash)
|
||||
pubkey, err := walletSvc.GetPubkey(context.Background())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch pubkey: %s", err)
|
||||
}
|
||||
|
||||
sweeper := newSweeper(walletSvc, repoManager, builder, scheduler)
|
||||
|
||||
svc := &service{
|
||||
network, onchainNetwork, pubkey,
|
||||
roundLifetime, roundInterval, minRelayFee,
|
||||
walletSvc, repoManager, builder, scanner, sweeper,
|
||||
paymentRequests, forfeitTxs, eventsCh,
|
||||
}
|
||||
repoManager.RegisterEventsHandler(
|
||||
func(round *domain.Round) {
|
||||
svc.updateProjectionStore(round)
|
||||
svc.propagateEvents(round)
|
||||
},
|
||||
)
|
||||
|
||||
if err := svc.restoreWatchingVtxos(); err != nil {
|
||||
return nil, fmt.Errorf("failed to restore watching vtxos: %s", err)
|
||||
}
|
||||
go svc.listenToRedemptions()
|
||||
return svc, nil
|
||||
}
|
||||
|
||||
func (s *service) Start() error {
|
||||
log.Debug("starting sweeper service")
|
||||
if err := s.sweeper.start(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debug("starting app service")
|
||||
go s.start()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *service) Stop() {
|
||||
s.sweeper.stop()
|
||||
// nolint
|
||||
vtxos, _ := s.repoManager.Vtxos().GetSpendableVtxos(
|
||||
context.Background(), "",
|
||||
)
|
||||
if len(vtxos) > 0 {
|
||||
s.stopWatchingVtxos(vtxos)
|
||||
}
|
||||
|
||||
s.wallet.Close()
|
||||
log.Debug("closed connection to wallet")
|
||||
s.repoManager.Close()
|
||||
log.Debug("closed connection to db")
|
||||
}
|
||||
|
||||
func (s *service) SpendVtxos(ctx context.Context, inputs []domain.VtxoKey) (string, error) {
|
||||
vtxos, err := s.repoManager.Vtxos().GetVtxos(ctx, inputs)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
for _, v := range vtxos {
|
||||
if v.Spent {
|
||||
return "", fmt.Errorf("input %s:%d already spent", v.Txid, v.VOut)
|
||||
}
|
||||
}
|
||||
|
||||
payment, err := domain.NewPayment(vtxos)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := s.paymentRequests.push(*payment); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return payment.Id, nil
|
||||
}
|
||||
|
||||
func (s *service) ClaimVtxos(ctx context.Context, creds string, receivers []domain.Receiver) error {
|
||||
// Check credentials
|
||||
payment, ok := s.paymentRequests.view(creds)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid credentials")
|
||||
}
|
||||
|
||||
if err := payment.AddReceivers(receivers); err != nil {
|
||||
return err
|
||||
}
|
||||
return s.paymentRequests.update(*payment)
|
||||
}
|
||||
|
||||
func (s *service) UpdatePaymentStatus(_ context.Context, id string) error {
|
||||
return s.paymentRequests.updatePingTimestamp(id)
|
||||
}
|
||||
|
||||
func (s *service) FaucetVtxos(ctx context.Context, userPubkey *secp256k1.PublicKey) error {
|
||||
pubkey := hex.EncodeToString(userPubkey.SerializeCompressed())
|
||||
|
||||
payment, err := domain.NewPayment([]domain.Vtxo{
|
||||
{
|
||||
VtxoKey: faucetVtxo,
|
||||
Receiver: domain.Receiver{
|
||||
Pubkey: pubkey,
|
||||
Amount: 10000,
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := payment.AddReceivers([]domain.Receiver{
|
||||
{Pubkey: pubkey, Amount: 1000},
|
||||
{Pubkey: pubkey, Amount: 1000},
|
||||
{Pubkey: pubkey, Amount: 1000},
|
||||
{Pubkey: pubkey, Amount: 1000},
|
||||
{Pubkey: pubkey, Amount: 1000},
|
||||
{Pubkey: pubkey, Amount: 1000},
|
||||
{Pubkey: pubkey, Amount: 1000},
|
||||
{Pubkey: pubkey, Amount: 1000},
|
||||
{Pubkey: pubkey, Amount: 1000},
|
||||
{Pubkey: pubkey, Amount: 1000},
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.paymentRequests.push(*payment); err != nil {
|
||||
return err
|
||||
}
|
||||
return s.paymentRequests.updatePingTimestamp(payment.Id)
|
||||
}
|
||||
|
||||
func (s *service) SignVtxos(ctx context.Context, forfeitTxs []string) error {
|
||||
return s.forfeitTxs.sign(forfeitTxs)
|
||||
}
|
||||
|
||||
func (s *service) ListVtxos(ctx context.Context, pubkey *secp256k1.PublicKey) ([]domain.Vtxo, error) {
|
||||
pk := hex.EncodeToString(pubkey.SerializeCompressed())
|
||||
return s.repoManager.Vtxos().GetSpendableVtxos(ctx, pk)
|
||||
}
|
||||
|
||||
func (s *service) GetEventsChannel(ctx context.Context) <-chan domain.RoundEvent {
|
||||
return s.eventsCh
|
||||
}
|
||||
|
||||
func (s *service) GetRoundByTxid(ctx context.Context, poolTxid string) (*domain.Round, error) {
|
||||
return s.repoManager.Rounds().GetRoundWithTxid(ctx, poolTxid)
|
||||
}
|
||||
|
||||
func (s *service) GetPubkey(ctx context.Context) (string, error) {
|
||||
pubkey, err := common.EncodePubKey(s.network.PubKey, s.pubkey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return pubkey, nil
|
||||
}
|
||||
|
||||
func (s *service) start() {
|
||||
s.startRound()
|
||||
}
|
||||
|
||||
func (s *service) startRound() {
|
||||
round := domain.NewRound(dustAmount)
|
||||
changes, _ := round.StartRegistration()
|
||||
if err := s.repoManager.Events().Save(
|
||||
context.Background(), round.Id, changes...,
|
||||
); err != nil {
|
||||
log.WithError(err).Warn("failed to store new round events")
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
time.Sleep(time.Duration(s.roundInterval/2) * time.Second)
|
||||
s.startFinalization()
|
||||
}()
|
||||
|
||||
log.Debugf("started registration stage for new round: %s", round.Id)
|
||||
}
|
||||
|
||||
func (s *service) startFinalization() {
|
||||
ctx := context.Background()
|
||||
round, err := s.repoManager.Rounds().GetCurrentRound(ctx)
|
||||
if err != nil {
|
||||
log.WithError(err).Warn("failed to retrieve current round")
|
||||
return
|
||||
}
|
||||
|
||||
var changes []domain.RoundEvent
|
||||
defer func() {
|
||||
if len(changes) > 0 {
|
||||
if err := s.repoManager.Events().Save(ctx, round.Id, changes...); err != nil {
|
||||
log.WithError(err).Warn("failed to store new round events")
|
||||
}
|
||||
}
|
||||
|
||||
if round.IsFailed() {
|
||||
s.startRound()
|
||||
return
|
||||
}
|
||||
time.Sleep(time.Duration((s.roundInterval/2)-1) * time.Second)
|
||||
s.finalizeRound()
|
||||
}()
|
||||
|
||||
if round.IsFailed() {
|
||||
return
|
||||
}
|
||||
|
||||
// TODO: understand how many payments must be popped from the queue and actually registered for the round
|
||||
num := s.paymentRequests.len()
|
||||
if num == 0 {
|
||||
err := fmt.Errorf("no payments registered")
|
||||
changes = round.Fail(fmt.Errorf("round aborted: %s", err))
|
||||
log.WithError(err).Debugf("round %s aborted", round.Id)
|
||||
return
|
||||
}
|
||||
if num > paymentsThreshold {
|
||||
num = paymentsThreshold
|
||||
}
|
||||
payments := s.paymentRequests.pop(num)
|
||||
changes, err = round.RegisterPayments(payments)
|
||||
if err != nil {
|
||||
changes = round.Fail(fmt.Errorf("failed to register payments: %s", err))
|
||||
log.WithError(err).Warn("failed to register payments")
|
||||
return
|
||||
}
|
||||
|
||||
unsignedPoolTx, tree, err := s.builder.BuildPoolTx(s.pubkey, payments, s.minRelayFee)
|
||||
if err != nil {
|
||||
changes = round.Fail(fmt.Errorf("failed to create pool tx: %s", err))
|
||||
log.WithError(err).Warn("failed to create pool tx")
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("pool tx created for round %s", round.Id)
|
||||
|
||||
connectors, forfeitTxs, err := s.builder.BuildForfeitTxs(s.pubkey, unsignedPoolTx, payments)
|
||||
if err != nil {
|
||||
changes = round.Fail(fmt.Errorf("failed to create connectors and forfeit txs: %s", err))
|
||||
log.WithError(err).Warn("failed to create connectors and forfeit txs")
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("forfeit transactions created for round %s", round.Id)
|
||||
|
||||
events, err := round.StartFinalization(connectors, tree, unsignedPoolTx)
|
||||
if err != nil {
|
||||
changes = round.Fail(fmt.Errorf("failed to start finalization: %s", err))
|
||||
log.WithError(err).Warn("failed to start finalization")
|
||||
return
|
||||
}
|
||||
changes = append(changes, events...)
|
||||
|
||||
s.forfeitTxs.push(forfeitTxs)
|
||||
|
||||
log.Debugf("started finalization stage for round: %s", round.Id)
|
||||
}
|
||||
|
||||
func (s *service) finalizeRound() {
|
||||
defer s.startRound()
|
||||
|
||||
ctx := context.Background()
|
||||
round, err := s.repoManager.Rounds().GetCurrentRound(ctx)
|
||||
if err != nil {
|
||||
log.WithError(err).Warn("failed to retrieve current round")
|
||||
return
|
||||
}
|
||||
if round.IsFailed() {
|
||||
return
|
||||
}
|
||||
|
||||
var changes []domain.RoundEvent
|
||||
defer func() {
|
||||
if err := s.repoManager.Events().Save(ctx, round.Id, changes...); err != nil {
|
||||
log.WithError(err).Warn("failed to store new round events")
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
forfeitTxs, leftUnsigned := s.forfeitTxs.pop()
|
||||
if len(leftUnsigned) > 0 {
|
||||
err := fmt.Errorf("%d forfeit txs left to sign", len(leftUnsigned))
|
||||
changes = round.Fail(fmt.Errorf("failed to finalize round: %s", err))
|
||||
log.WithError(err).Warn("failed to finalize round")
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("signing round transaction %s\n", round.Id)
|
||||
signedPoolTx, err := s.wallet.SignPset(ctx, round.UnsignedTx, true)
|
||||
if err != nil {
|
||||
changes = round.Fail(fmt.Errorf("failed to sign round tx: %s", err))
|
||||
log.WithError(err).Warn("failed to sign round tx")
|
||||
return
|
||||
}
|
||||
|
||||
txid, err := s.wallet.BroadcastTransaction(ctx, signedPoolTx)
|
||||
if err != nil {
|
||||
changes = round.Fail(fmt.Errorf("failed to broadcast pool tx: %s", err))
|
||||
log.WithError(err).Warn("failed to broadcast pool tx")
|
||||
return
|
||||
}
|
||||
|
||||
now := time.Now().Unix()
|
||||
expirationTimestamp := now + s.roundLifetime + 30 // add 30 secs to be sure that the tx is confirmed
|
||||
|
||||
if err := s.sweeper.schedule(expirationTimestamp, txid, round.CongestionTree); err != nil {
|
||||
changes = round.Fail(fmt.Errorf("failed to schedule sweep tx: %s", err))
|
||||
log.WithError(err).Warn("failed to schedule sweep tx")
|
||||
return
|
||||
}
|
||||
|
||||
changes, _ = round.EndFinalization(forfeitTxs, txid)
|
||||
|
||||
log.Debugf("finalized round %s with pool tx %s", round.Id, round.Txid)
|
||||
}
|
||||
|
||||
func (s *service) listenToRedemptions() {
|
||||
ctx := context.Background()
|
||||
chVtxos := s.scanner.GetNotificationChannel(ctx)
|
||||
for vtxoKeys := range chVtxos {
|
||||
if len(vtxoKeys) > 0 {
|
||||
for {
|
||||
// TODO: make sure that the vtxos haven't been already spent, otherwise
|
||||
// broadcast the corresponding forfeit tx and connector to prevent
|
||||
// getting cheated.
|
||||
vtxos, err := s.repoManager.Vtxos().RedeemVtxos(ctx, vtxoKeys)
|
||||
if err != nil {
|
||||
log.WithError(err).Warn("failed to redeem vtxos, retrying...")
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
if len(vtxos) > 0 {
|
||||
log.Debugf("redeemed %d vtxos", len(vtxos))
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *service) updateProjectionStore(round *domain.Round) {
|
||||
ctx := context.Background()
|
||||
lastChange := round.Events()[len(round.Events())-1]
|
||||
// Update the vtxo set only after a round is finalized.
|
||||
if _, ok := lastChange.(domain.RoundFinalized); ok {
|
||||
repo := s.repoManager.Vtxos()
|
||||
spentVtxos := getSpentVtxos(round.Payments)
|
||||
if len(spentVtxos) > 0 {
|
||||
for {
|
||||
if err := repo.SpendVtxos(ctx, spentVtxos); err != nil {
|
||||
log.WithError(err).Warn("failed to add new vtxos, retrying soon")
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
log.Debugf("spent %d vtxos", len(spentVtxos))
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
newVtxos := s.getNewVtxos(round)
|
||||
for {
|
||||
if err := repo.AddVtxos(ctx, newVtxos); err != nil {
|
||||
log.WithError(err).Warn("failed to add new vtxos, retrying soon")
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
log.Debugf("added %d new vtxos", len(newVtxos))
|
||||
break
|
||||
}
|
||||
|
||||
go func() {
|
||||
for {
|
||||
if err := s.startWatchingVtxos(newVtxos); err != nil {
|
||||
log.WithError(err).Warn(
|
||||
"failed to start watching vtxos, retrying in a moment...",
|
||||
)
|
||||
continue
|
||||
}
|
||||
log.Debugf("started watching %d vtxos", len(newVtxos))
|
||||
return
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Always update the status of the round.
|
||||
for {
|
||||
if err := s.repoManager.Rounds().AddOrUpdateRound(ctx, *round); err != nil {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
func (s *service) propagateEvents(round *domain.Round) {
|
||||
lastEvent := round.Events()[len(round.Events())-1]
|
||||
switch e := lastEvent.(type) {
|
||||
case domain.RoundFinalizationStarted:
|
||||
forfeitTxs := s.forfeitTxs.view()
|
||||
s.eventsCh <- domain.RoundFinalizationStarted{
|
||||
Id: e.Id,
|
||||
CongestionTree: e.CongestionTree,
|
||||
Connectors: e.Connectors,
|
||||
PoolTx: e.PoolTx,
|
||||
UnsignedForfeitTxs: forfeitTxs,
|
||||
}
|
||||
case domain.RoundFinalized, domain.RoundFailed:
|
||||
s.eventsCh <- e
|
||||
}
|
||||
}
|
||||
|
||||
func (s *service) getNewVtxos(round *domain.Round) []domain.Vtxo {
|
||||
leaves := round.CongestionTree.Leaves()
|
||||
vtxos := make([]domain.Vtxo, 0)
|
||||
for _, node := range leaves {
|
||||
tx, _ := psetv2.NewPsetFromBase64(node.Tx)
|
||||
for i, out := range tx.Outputs {
|
||||
for _, p := range round.Payments {
|
||||
var pubkey string
|
||||
found := false
|
||||
for _, r := range p.Receivers {
|
||||
if r.IsOnchain() {
|
||||
continue
|
||||
}
|
||||
|
||||
buf, _ := hex.DecodeString(r.Pubkey)
|
||||
pk, _ := secp256k1.ParsePubKey(buf)
|
||||
script, _ := s.builder.GetVtxoScript(pk, s.pubkey)
|
||||
if bytes.Equal(script, out.Script) {
|
||||
found = true
|
||||
pubkey = r.Pubkey
|
||||
break
|
||||
}
|
||||
}
|
||||
if found {
|
||||
vtxos = append(vtxos, domain.Vtxo{
|
||||
VtxoKey: domain.VtxoKey{Txid: node.Txid, VOut: uint32(i)},
|
||||
Receiver: domain.Receiver{Pubkey: pubkey, Amount: out.Value},
|
||||
PoolTx: round.Txid,
|
||||
})
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return vtxos
|
||||
}
|
||||
|
||||
func (s *service) startWatchingVtxos(vtxos []domain.Vtxo) error {
|
||||
scripts, err := s.extractVtxosScripts(vtxos)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.scanner.WatchScripts(context.Background(), scripts)
|
||||
}
|
||||
|
||||
func (s *service) stopWatchingVtxos(vtxos []domain.Vtxo) {
|
||||
scripts, err := s.extractVtxosScripts(vtxos)
|
||||
if err != nil {
|
||||
log.WithError(err).Warn("failed to extract scripts from vtxos")
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
if err := s.scanner.UnwatchScripts(context.Background(), scripts); err != nil {
|
||||
log.WithError(err).Warn("failed to stop watching vtxos, retrying in a moment...")
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
log.Debugf("stopped watching %d vtxos", len(vtxos))
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
func (s *service) restoreWatchingVtxos() error {
|
||||
vtxos, err := s.repoManager.Vtxos().GetSpendableVtxos(
|
||||
context.Background(), "",
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(vtxos) <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := s.startWatchingVtxos(vtxos); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debugf("restored watching %d vtxos", len(vtxos))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *service) extractVtxosScripts(vtxos []domain.Vtxo) ([]string, error) {
|
||||
indexedScripts := make(map[string]struct{})
|
||||
for _, vtxo := range vtxos {
|
||||
buf, err := hex.DecodeString(vtxo.Pubkey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userPubkey, err := secp256k1.ParsePubKey(buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
script, err := s.builder.GetVtxoScript(userPubkey, s.pubkey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
indexedScripts[hex.EncodeToString(script)] = struct{}{}
|
||||
}
|
||||
scripts := make([]string, 0, len(indexedScripts))
|
||||
for script := range indexedScripts {
|
||||
scripts = append(scripts, script)
|
||||
}
|
||||
return scripts, nil
|
||||
}
|
||||
|
||||
func getSpentVtxos(payments map[string]domain.Payment) []domain.VtxoKey {
|
||||
vtxos := make([]domain.VtxoKey, 0)
|
||||
for _, p := range payments {
|
||||
for _, vtxo := range p.Inputs {
|
||||
if vtxo.VtxoKey == faucetVtxo {
|
||||
continue
|
||||
}
|
||||
vtxos = append(vtxos, vtxo.VtxoKey)
|
||||
}
|
||||
}
|
||||
return vtxos
|
||||
}
|
||||
579
server/internal/core/application/sweeper.go
Normal file
579
server/internal/core/application/sweeper.go
Normal file
@@ -0,0 +1,579 @@
|
||||
package application
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/ark-network/ark/common/tree"
|
||||
"github.com/ark-network/ark/internal/core/domain"
|
||||
"github.com/ark-network/ark/internal/core/ports"
|
||||
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
||||
"github.com/decred/dcrd/dcrec/secp256k1/v4"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/vulpemventures/go-elements/psetv2"
|
||||
)
|
||||
|
||||
// sweeper is an unexported service running while the main application service is started
|
||||
// it is responsible for sweeping onchain shared outputs that expired
|
||||
// it also handles delaying the sweep events in case some parts of the tree are broadcasted
|
||||
// when a round is finalized, the main application service schedules a sweep event on the newly created congestion tree
|
||||
type sweeper struct {
|
||||
wallet ports.WalletService
|
||||
repoManager ports.RepoManager
|
||||
builder ports.TxBuilder
|
||||
scheduler ports.SchedulerService
|
||||
|
||||
// cache of scheduled tasks, avoid scheduling the same sweep event multiple times
|
||||
scheduledTasks map[string]struct{}
|
||||
}
|
||||
|
||||
func newSweeper(
|
||||
wallet ports.WalletService,
|
||||
repoManager ports.RepoManager,
|
||||
builder ports.TxBuilder,
|
||||
scheduler ports.SchedulerService,
|
||||
) *sweeper {
|
||||
return &sweeper{
|
||||
wallet,
|
||||
repoManager,
|
||||
builder,
|
||||
scheduler,
|
||||
make(map[string]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *sweeper) start() error {
|
||||
s.scheduler.Start()
|
||||
|
||||
allRounds, err := s.repoManager.Rounds().GetSweepableRounds(context.Background())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, round := range allRounds {
|
||||
task := s.createTask(round.Txid, round.CongestionTree)
|
||||
task()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *sweeper) stop() {
|
||||
s.scheduler.Stop()
|
||||
}
|
||||
|
||||
// removeTask update the cached map of scheduled tasks
|
||||
func (s *sweeper) removeTask(treeRootTxid string) {
|
||||
delete(s.scheduledTasks, treeRootTxid)
|
||||
}
|
||||
|
||||
// schedule set up a task to be executed once at the given timestamp
|
||||
func (s *sweeper) schedule(
|
||||
expirationTimestamp int64, roundTxid string, congestionTree tree.CongestionTree,
|
||||
) error {
|
||||
root, err := congestionTree.Root()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, scheduled := s.scheduledTasks[root.Txid]; scheduled {
|
||||
return nil
|
||||
}
|
||||
|
||||
task := s.createTask(roundTxid, congestionTree)
|
||||
fancyTime := time.Unix(expirationTimestamp, 0).Format("2006-01-02 15:04:05")
|
||||
log.Debugf("scheduled sweep task at %s", fancyTime)
|
||||
if err := s.scheduler.ScheduleTaskOnce(expirationTimestamp, task); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.scheduledTasks[root.Txid] = struct{}{}
|
||||
return nil
|
||||
}
|
||||
|
||||
// createTask returns a function passed as handler in the scheduler
|
||||
// it tries to craft a sweep tx containing the onchain outputs of the given congestion tree
|
||||
// if some parts of the tree have been broadcasted in the meantine, it will schedule the next taskes for the remaining parts of the tree
|
||||
func (s *sweeper) createTask(
|
||||
roundTxid string, congestionTree tree.CongestionTree,
|
||||
) func() {
|
||||
return func() {
|
||||
ctx := context.Background()
|
||||
root, err := congestionTree.Root()
|
||||
if err != nil {
|
||||
log.WithError(err).Error("error while getting root node")
|
||||
return
|
||||
}
|
||||
|
||||
s.removeTask(root.Txid)
|
||||
log.Debugf("sweeper: %s", root.Txid)
|
||||
|
||||
sweepInputs := make([]ports.SweepInput, 0)
|
||||
vtxoKeys := make([]domain.VtxoKey, 0) // vtxos associated to the sweep inputs
|
||||
|
||||
// inspect the congestion tree to find onchain shared outputs
|
||||
sharedOutputs, err := s.findSweepableOutputs(ctx, congestionTree)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("error while inspecting congestion tree")
|
||||
return
|
||||
}
|
||||
|
||||
for expiredAt, inputs := range sharedOutputs {
|
||||
// if the shared outputs are not expired, schedule a sweep task for it
|
||||
if time.Unix(expiredAt, 0).After(time.Now()) {
|
||||
subtrees, err := computeSubTrees(congestionTree, inputs)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("error while computing subtrees")
|
||||
continue
|
||||
}
|
||||
|
||||
for _, subTree := range subtrees {
|
||||
// mitigate the risk to get BIP68 non-final errors by scheduling the task 30 seconds after the expiration time
|
||||
if err := s.schedule(int64(expiredAt), roundTxid, subTree); err != nil {
|
||||
log.WithError(err).Error("error while scheduling sweep task")
|
||||
continue
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// iterate over the expired shared outputs
|
||||
for _, input := range inputs {
|
||||
// sweepableVtxos related to the sweep input
|
||||
sweepableVtxos := make([]domain.VtxoKey, 0)
|
||||
|
||||
// check if input is the vtxo itself
|
||||
vtxos, _ := s.repoManager.Vtxos().GetVtxos(
|
||||
ctx,
|
||||
[]domain.VtxoKey{
|
||||
{
|
||||
Txid: input.InputArgs.Txid,
|
||||
VOut: input.InputArgs.TxIndex,
|
||||
},
|
||||
},
|
||||
)
|
||||
if len(vtxos) > 0 {
|
||||
if !vtxos[0].Swept && !vtxos[0].Redeemed {
|
||||
sweepableVtxos = append(sweepableVtxos, vtxos[0].VtxoKey)
|
||||
}
|
||||
} else {
|
||||
// if it's not a vtxo, find all the vtxos leaves reachable from that input
|
||||
vtxosLeaves, err := congestionTree.FindLeaves(input.InputArgs.Txid, input.InputArgs.TxIndex)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("error while finding vtxos leaves")
|
||||
continue
|
||||
}
|
||||
|
||||
for _, leaf := range vtxosLeaves {
|
||||
pset, err := psetv2.NewPsetFromBase64(leaf.Tx)
|
||||
if err != nil {
|
||||
log.Error(fmt.Errorf("error while decoding pset: %w", err))
|
||||
continue
|
||||
}
|
||||
|
||||
vtxo, err := extractVtxoOutpoint(pset)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
continue
|
||||
}
|
||||
|
||||
sweepableVtxos = append(sweepableVtxos, *vtxo)
|
||||
}
|
||||
|
||||
if len(sweepableVtxos) <= 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
firstVtxo, err := s.repoManager.Vtxos().GetVtxos(ctx, sweepableVtxos[:1])
|
||||
if err != nil {
|
||||
log.Error(fmt.Errorf("error while getting vtxo: %w", err))
|
||||
sweepInputs = append(sweepInputs, input) // add the input anyway in order to try to sweep it
|
||||
continue
|
||||
}
|
||||
|
||||
if firstVtxo[0].Swept || firstVtxo[0].Redeemed {
|
||||
// we assume that if the first vtxo is swept or redeemed, the shared output has been spent
|
||||
// skip, the output is already swept or spent by a unilateral redeem
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if len(sweepableVtxos) > 0 {
|
||||
vtxoKeys = append(vtxoKeys, sweepableVtxos...)
|
||||
sweepInputs = append(sweepInputs, input)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(sweepInputs) > 0 {
|
||||
// build the sweep transaction with all the expired non-swept shared outputs
|
||||
sweepTx, err := s.builder.BuildSweepTx(s.wallet, sweepInputs)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("error while building sweep tx")
|
||||
return
|
||||
}
|
||||
|
||||
err = nil
|
||||
txid := ""
|
||||
// retry until the tx is broadcasted or the error is not BIP68 final
|
||||
for len(txid) == 0 && (err == nil || err == fmt.Errorf("non-BIP68-final")) {
|
||||
if err != nil {
|
||||
log.Debugln("sweep tx not BIP68 final, retrying in 5 seconds")
|
||||
time.Sleep(5 * time.Second)
|
||||
}
|
||||
|
||||
txid, err = s.wallet.BroadcastTransaction(ctx, sweepTx)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.WithError(err).Error("error while broadcasting sweep tx")
|
||||
return
|
||||
}
|
||||
if len(txid) > 0 {
|
||||
log.Debugln("sweep tx broadcasted:", txid)
|
||||
vtxosRepository := s.repoManager.Vtxos()
|
||||
|
||||
// mark the vtxos as swept
|
||||
if err := vtxosRepository.SweepVtxos(ctx, vtxoKeys); err != nil {
|
||||
log.Error(fmt.Errorf("error while deleting vtxos: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("%d vtxos swept", len(vtxoKeys))
|
||||
|
||||
roundVtxos, err := vtxosRepository.GetVtxosForRound(ctx, roundTxid)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("error while getting vtxos for round")
|
||||
return
|
||||
}
|
||||
|
||||
allSwept := true
|
||||
for _, vtxo := range roundVtxos {
|
||||
allSwept = allSwept && vtxo.Swept
|
||||
if !allSwept {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if allSwept {
|
||||
// update the round
|
||||
roundRepo := s.repoManager.Rounds()
|
||||
round, err := roundRepo.GetRoundWithTxid(ctx, roundTxid)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("error while getting round")
|
||||
return
|
||||
}
|
||||
|
||||
round.Sweep()
|
||||
|
||||
if err := roundRepo.AddOrUpdateRound(ctx, *round); err != nil {
|
||||
log.WithError(err).Error("error while marking round as swept")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// onchainOutputs iterates over all the nodes' outputs in the congestion tree and checks their onchain state
|
||||
// returns the sweepable outputs as ports.SweepInput mapped by their expiration time
|
||||
func (s *sweeper) findSweepableOutputs(
|
||||
ctx context.Context,
|
||||
congestionTree tree.CongestionTree,
|
||||
) (map[int64][]ports.SweepInput, error) {
|
||||
sweepableOutputs := make(map[int64][]ports.SweepInput)
|
||||
blocktimeCache := make(map[string]int64) // txid -> blocktime
|
||||
nodesToCheck := congestionTree[0] // init with the root
|
||||
|
||||
for len(nodesToCheck) > 0 {
|
||||
newNodesToCheck := make([]tree.Node, 0)
|
||||
|
||||
for _, node := range nodesToCheck {
|
||||
isPublished, blocktime, err := s.wallet.IsTransactionPublished(ctx, node.Txid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var expirationTime int64
|
||||
var sweepInputs []ports.SweepInput
|
||||
|
||||
if !isPublished {
|
||||
if _, ok := blocktimeCache[node.ParentTxid]; !ok {
|
||||
isPublished, blocktime, err := s.wallet.IsTransactionPublished(ctx, node.ParentTxid)
|
||||
if !isPublished || err != nil {
|
||||
return nil, fmt.Errorf("tx %s not found", node.Txid)
|
||||
}
|
||||
|
||||
blocktimeCache[node.ParentTxid] = blocktime
|
||||
}
|
||||
|
||||
expirationTime, sweepInputs, err = s.nodeToSweepInputs(blocktimeCache[node.ParentTxid], node)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
// cache the blocktime for future use
|
||||
blocktimeCache[node.Txid] = int64(blocktime)
|
||||
|
||||
// if the tx is onchain, it means that the input is spent
|
||||
// add the children to the nodes in order to check them during the next iteration
|
||||
// We will return the error below, but are we going to schedule the tasks for the "children roots"?
|
||||
if !node.Leaf {
|
||||
children := congestionTree.Children(node.Txid)
|
||||
newNodesToCheck = append(newNodesToCheck, children...)
|
||||
continue
|
||||
}
|
||||
|
||||
// if the node is a leaf, the vtxos outputs should added as onchain outputs if they are not swept yet
|
||||
vtxoExpiration, sweepInput, err := s.leafToSweepInput(ctx, blocktime, node)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if sweepInput != nil {
|
||||
expirationTime = vtxoExpiration
|
||||
sweepInputs = []ports.SweepInput{*sweepInput}
|
||||
}
|
||||
}
|
||||
|
||||
if _, ok := sweepableOutputs[expirationTime]; !ok {
|
||||
sweepableOutputs[expirationTime] = make([]ports.SweepInput, 0)
|
||||
}
|
||||
sweepableOutputs[expirationTime] = append(sweepableOutputs[expirationTime], sweepInputs...)
|
||||
}
|
||||
|
||||
nodesToCheck = newNodesToCheck
|
||||
}
|
||||
|
||||
return sweepableOutputs, nil
|
||||
}
|
||||
|
||||
func (s *sweeper) leafToSweepInput(ctx context.Context, txBlocktime int64, node tree.Node) (int64, *ports.SweepInput, error) {
|
||||
pset, err := psetv2.NewPsetFromBase64(node.Tx)
|
||||
if err != nil {
|
||||
return -1, nil, err
|
||||
}
|
||||
|
||||
vtxo, err := extractVtxoOutpoint(pset)
|
||||
if err != nil {
|
||||
return -1, nil, err
|
||||
}
|
||||
|
||||
fromRepo, err := s.repoManager.Vtxos().GetVtxos(ctx, []domain.VtxoKey{*vtxo})
|
||||
if err != nil {
|
||||
return -1, nil, err
|
||||
}
|
||||
|
||||
if len(fromRepo) == 0 {
|
||||
return -1, nil, fmt.Errorf("vtxo not found")
|
||||
}
|
||||
|
||||
if fromRepo[0].Swept {
|
||||
return -1, nil, nil
|
||||
}
|
||||
|
||||
if fromRepo[0].Redeemed {
|
||||
return -1, nil, nil
|
||||
}
|
||||
|
||||
// if the vtxo is not swept or redeemed, add it to the onchain outputs
|
||||
pubKeyBytes, err := hex.DecodeString(fromRepo[0].Pubkey)
|
||||
if err != nil {
|
||||
return -1, nil, err
|
||||
}
|
||||
|
||||
pubKey, err := secp256k1.ParsePubKey(pubKeyBytes)
|
||||
if err != nil {
|
||||
return -1, nil, err
|
||||
}
|
||||
|
||||
sweepLeaf, lifetime, err := s.builder.GetLeafSweepClosure(node, pubKey)
|
||||
if err != nil {
|
||||
return -1, nil, err
|
||||
}
|
||||
|
||||
sweepInput := ports.SweepInput{
|
||||
InputArgs: psetv2.InputArgs{
|
||||
Txid: vtxo.Txid,
|
||||
TxIndex: vtxo.VOut,
|
||||
},
|
||||
SweepLeaf: *sweepLeaf,
|
||||
Amount: fromRepo[0].Amount,
|
||||
}
|
||||
|
||||
return txBlocktime + lifetime, &sweepInput, nil
|
||||
}
|
||||
|
||||
func (s *sweeper) nodeToSweepInputs(parentBlocktime int64, node tree.Node) (int64, []ports.SweepInput, error) {
|
||||
pset, err := psetv2.NewPsetFromBase64(node.Tx)
|
||||
if err != nil {
|
||||
return -1, nil, err
|
||||
}
|
||||
|
||||
if len(pset.Inputs) != 1 {
|
||||
return -1, nil, fmt.Errorf("invalid node pset, expect 1 input, got %d", len(pset.Inputs))
|
||||
}
|
||||
|
||||
// if the tx is not onchain, it means that the input is an existing shared output
|
||||
input := pset.Inputs[0]
|
||||
txid := chainhash.Hash(input.PreviousTxid).String()
|
||||
index := input.PreviousTxIndex
|
||||
|
||||
sweepLeaf, lifetime, err := extractSweepLeaf(input)
|
||||
if err != nil {
|
||||
return -1, nil, err
|
||||
}
|
||||
|
||||
expirationTime := parentBlocktime + lifetime
|
||||
|
||||
amount := uint64(0)
|
||||
for _, out := range pset.Outputs {
|
||||
amount += out.Value
|
||||
}
|
||||
|
||||
sweepInputs := []ports.SweepInput{
|
||||
{
|
||||
InputArgs: psetv2.InputArgs{
|
||||
Txid: txid,
|
||||
TxIndex: index,
|
||||
},
|
||||
SweepLeaf: *sweepLeaf,
|
||||
Amount: amount,
|
||||
},
|
||||
}
|
||||
|
||||
return expirationTime, sweepInputs, nil
|
||||
}
|
||||
|
||||
func computeSubTrees(congestionTree tree.CongestionTree, inputs []ports.SweepInput) ([]tree.CongestionTree, error) {
|
||||
subTrees := make(map[string]tree.CongestionTree, 0)
|
||||
|
||||
// for each sweepable input, create a sub congestion tree
|
||||
// it allows to skip the part of the tree that has been broadcasted in the next task
|
||||
for _, input := range inputs {
|
||||
subTree, err := computeSubTree(congestionTree, input.InputArgs.Txid)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("error while finding sub tree")
|
||||
continue
|
||||
}
|
||||
|
||||
root, err := subTree.Root()
|
||||
if err != nil {
|
||||
log.WithError(err).Error("error while getting root node")
|
||||
continue
|
||||
}
|
||||
|
||||
subTrees[root.Txid] = subTree
|
||||
}
|
||||
|
||||
// filter out the sub trees, remove the ones that are included in others
|
||||
filteredSubTrees := make([]tree.CongestionTree, 0)
|
||||
for i, subTree := range subTrees {
|
||||
notIncludedInOtherTrees := true
|
||||
|
||||
for j, otherSubTree := range subTrees {
|
||||
if i == j {
|
||||
continue
|
||||
}
|
||||
contains, err := containsTree(otherSubTree, subTree)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("error while checking if a tree contains another")
|
||||
continue
|
||||
}
|
||||
|
||||
if contains {
|
||||
notIncludedInOtherTrees = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if notIncludedInOtherTrees {
|
||||
filteredSubTrees = append(filteredSubTrees, subTree)
|
||||
}
|
||||
}
|
||||
|
||||
return filteredSubTrees, nil
|
||||
}
|
||||
|
||||
func computeSubTree(congestionTree tree.CongestionTree, newRoot string) (tree.CongestionTree, error) {
|
||||
for _, level := range congestionTree {
|
||||
for _, node := range level {
|
||||
if node.Txid == newRoot || node.ParentTxid == newRoot {
|
||||
newTree := make(tree.CongestionTree, 0)
|
||||
newTree = append(newTree, []tree.Node{node})
|
||||
|
||||
children := congestionTree.Children(node.Txid)
|
||||
for len(children) > 0 {
|
||||
newTree = append(newTree, children)
|
||||
newChildren := make([]tree.Node, 0)
|
||||
for _, child := range children {
|
||||
newChildren = append(newChildren, congestionTree.Children(child.Txid)...)
|
||||
}
|
||||
children = newChildren
|
||||
}
|
||||
|
||||
return newTree, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("failed to create subtree, new root not found")
|
||||
}
|
||||
|
||||
func containsTree(tr0 tree.CongestionTree, tr1 tree.CongestionTree) (bool, error) {
|
||||
tr1Root, err := tr1.Root()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
for _, level := range tr0 {
|
||||
for _, node := range level {
|
||||
if node.Txid == tr1Root.Txid {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// given a congestion tree input, searches and returns the sweep leaf and its lifetime in seconds
|
||||
func extractSweepLeaf(input psetv2.Input) (sweepLeaf *psetv2.TapLeafScript, lifetime int64, err error) {
|
||||
for _, leaf := range input.TapLeafScript {
|
||||
isSweep, _, seconds, err := tree.DecodeSweepScript(leaf.Script)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if isSweep {
|
||||
lifetime = int64(seconds)
|
||||
sweepLeaf = &leaf
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if sweepLeaf == nil {
|
||||
return nil, 0, fmt.Errorf("sweep leaf not found")
|
||||
}
|
||||
|
||||
return sweepLeaf, lifetime, nil
|
||||
}
|
||||
|
||||
// assuming the pset is a leaf in the congestion tree, returns the vtxos outputs
|
||||
func extractVtxoOutpoint(pset *psetv2.Pset) (*domain.VtxoKey, error) {
|
||||
if len(pset.Outputs) != 2 {
|
||||
return nil, fmt.Errorf("invalid leaf pset, expect 2 outputs, got %d", len(pset.Outputs))
|
||||
}
|
||||
|
||||
utx, err := pset.UnsignedTx()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &domain.VtxoKey{
|
||||
Txid: utx.TxHash().String(),
|
||||
VOut: 0,
|
||||
}, nil
|
||||
}
|
||||
254
server/internal/core/application/utils.go
Normal file
254
server/internal/core/application/utils.go
Normal file
@@ -0,0 +1,254 @@
|
||||
package application
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ark-network/ark/common"
|
||||
"github.com/ark-network/ark/internal/core/domain"
|
||||
"github.com/btcsuite/btcd/btcec/v2/schnorr"
|
||||
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
||||
"github.com/vulpemventures/go-elements/psetv2"
|
||||
)
|
||||
|
||||
type timedPayment struct {
|
||||
domain.Payment
|
||||
timestamp time.Time
|
||||
pingTimestamp time.Time
|
||||
}
|
||||
|
||||
type paymentsMap struct {
|
||||
lock *sync.RWMutex
|
||||
payments map[string]*timedPayment
|
||||
}
|
||||
|
||||
func newPaymentsMap(payments []domain.Payment) *paymentsMap {
|
||||
paymentsById := make(map[string]*timedPayment)
|
||||
for _, p := range payments {
|
||||
paymentsById[p.Id] = &timedPayment{p, time.Now(), time.Time{}}
|
||||
}
|
||||
lock := &sync.RWMutex{}
|
||||
return &paymentsMap{lock, paymentsById}
|
||||
}
|
||||
|
||||
func (m *paymentsMap) len() int64 {
|
||||
m.lock.RLock()
|
||||
defer m.lock.RUnlock()
|
||||
|
||||
count := int64(0)
|
||||
for _, p := range m.payments {
|
||||
if len(p.Receivers) > 0 {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
func (m *paymentsMap) push(payment domain.Payment) error {
|
||||
m.lock.Lock()
|
||||
defer m.lock.Unlock()
|
||||
|
||||
if _, ok := m.payments[payment.Id]; ok {
|
||||
return fmt.Errorf("duplicated inputs")
|
||||
}
|
||||
|
||||
m.payments[payment.Id] = &timedPayment{payment, time.Now(), time.Time{}}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *paymentsMap) pop(num int64) []domain.Payment {
|
||||
m.lock.Lock()
|
||||
defer m.lock.Unlock()
|
||||
|
||||
paymentsByTime := make([]timedPayment, 0, len(m.payments))
|
||||
for _, p := range m.payments {
|
||||
// Skip payments without registered receivers.
|
||||
if len(p.Receivers) <= 0 {
|
||||
continue
|
||||
}
|
||||
// Skip payments for which users didn't notify to be online in the last minute.
|
||||
if p.pingTimestamp.IsZero() || time.Since(p.pingTimestamp).Minutes() > 1 {
|
||||
continue
|
||||
}
|
||||
paymentsByTime = append(paymentsByTime, *p)
|
||||
}
|
||||
sort.SliceStable(paymentsByTime, func(i, j int) bool {
|
||||
return paymentsByTime[i].timestamp.Before(paymentsByTime[j].timestamp)
|
||||
})
|
||||
|
||||
if num < 0 || num > int64(len(paymentsByTime)) {
|
||||
num = int64(len(paymentsByTime))
|
||||
}
|
||||
|
||||
payments := make([]domain.Payment, 0, num)
|
||||
for _, p := range paymentsByTime[:num] {
|
||||
payments = append(payments, p.Payment)
|
||||
delete(m.payments, p.Id)
|
||||
}
|
||||
return payments
|
||||
}
|
||||
|
||||
func (m *paymentsMap) update(payment domain.Payment) error {
|
||||
m.lock.Lock()
|
||||
defer m.lock.Unlock()
|
||||
|
||||
p, ok := m.payments[payment.Id]
|
||||
if !ok {
|
||||
return fmt.Errorf("payment %s not found", payment.Id)
|
||||
}
|
||||
|
||||
p.Payment = payment
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *paymentsMap) updatePingTimestamp(id string) error {
|
||||
m.lock.Lock()
|
||||
defer m.lock.Unlock()
|
||||
|
||||
payment, ok := m.payments[id]
|
||||
if !ok {
|
||||
return fmt.Errorf("payment %s not found", id)
|
||||
}
|
||||
|
||||
payment.pingTimestamp = time.Now()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *paymentsMap) view(id string) (*domain.Payment, bool) {
|
||||
m.lock.RLock()
|
||||
defer m.lock.RUnlock()
|
||||
|
||||
payment, ok := m.payments[id]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return &domain.Payment{
|
||||
Id: payment.Id,
|
||||
Inputs: payment.Inputs,
|
||||
Receivers: payment.Receivers,
|
||||
}, true
|
||||
}
|
||||
|
||||
type signedTx struct {
|
||||
tx string
|
||||
signed bool
|
||||
}
|
||||
|
||||
type forfeitTxsMap struct {
|
||||
lock *sync.RWMutex
|
||||
forfeitTxs map[string]*signedTx
|
||||
genesisBlockHash *chainhash.Hash
|
||||
}
|
||||
|
||||
func newForfeitTxsMap(genesisBlockHash *chainhash.Hash) *forfeitTxsMap {
|
||||
return &forfeitTxsMap{&sync.RWMutex{}, make(map[string]*signedTx), genesisBlockHash}
|
||||
}
|
||||
|
||||
func (m *forfeitTxsMap) push(txs []string) {
|
||||
m.lock.Lock()
|
||||
defer m.lock.Unlock()
|
||||
|
||||
faucetTxID, _ := hex.DecodeString(faucetVtxo.Txid)
|
||||
|
||||
for _, tx := range txs {
|
||||
ptx, _ := psetv2.NewPsetFromBase64(tx)
|
||||
utx, _ := ptx.UnsignedTx()
|
||||
|
||||
signed := false
|
||||
|
||||
// find the faucet vtxos, and mark them as signed
|
||||
for _, input := range ptx.Inputs {
|
||||
if bytes.Equal(input.PreviousTxid, faucetTxID) {
|
||||
signed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
m.forfeitTxs[utx.TxHash().String()] = &signedTx{tx, signed}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *forfeitTxsMap) sign(txs []string) error {
|
||||
m.lock.Lock()
|
||||
defer m.lock.Unlock()
|
||||
|
||||
for _, tx := range txs {
|
||||
ptx, _ := psetv2.NewPsetFromBase64(tx)
|
||||
utx, _ := ptx.UnsignedTx()
|
||||
txid := utx.TxHash().String()
|
||||
|
||||
if _, ok := m.forfeitTxs[txid]; ok {
|
||||
for index, input := range ptx.Inputs {
|
||||
if len(input.TapScriptSig) > 0 {
|
||||
for _, tapScriptSig := range input.TapScriptSig {
|
||||
leafHash, err := chainhash.NewHash(tapScriptSig.LeafHash)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
preimage, err := common.TaprootPreimage(
|
||||
m.genesisBlockHash,
|
||||
ptx,
|
||||
index,
|
||||
leafHash,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sig, err := schnorr.ParseSignature(tapScriptSig.Signature)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pubkey, err := schnorr.ParsePubKey(tapScriptSig.PubKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if sig.Verify(preimage, pubkey) {
|
||||
m.forfeitTxs[txid].signed = true
|
||||
} else {
|
||||
return fmt.Errorf("invalid signature")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *forfeitTxsMap) pop() (signed, unsigned []string) {
|
||||
m.lock.Lock()
|
||||
defer m.lock.Unlock()
|
||||
|
||||
for _, t := range m.forfeitTxs {
|
||||
if t.signed {
|
||||
signed = append(signed, t.tx)
|
||||
} else {
|
||||
unsigned = append(unsigned, t.tx)
|
||||
}
|
||||
}
|
||||
|
||||
m.forfeitTxs = make(map[string]*signedTx)
|
||||
return signed, unsigned
|
||||
}
|
||||
|
||||
func (m *forfeitTxsMap) view() []string {
|
||||
m.lock.RLock()
|
||||
defer m.lock.RUnlock()
|
||||
|
||||
txs := make([]string, 0, len(m.forfeitTxs))
|
||||
for _, tx := range m.forfeitTxs {
|
||||
txs = append(txs, tx.tx)
|
||||
}
|
||||
return txs
|
||||
}
|
||||
Reference in New Issue
Block a user