Prevent getting cheated by broadcasting forfeit transactions (#123)

* broadcast forfeit transaction in case the user is trying the cheat the ASP

* fix connector input + --cheat flag in CLI

* WIP

* cleaning and fixes

* add TODO

* sweeper.go: mark round swept if vtxo are redeemed

* fixes after reviews

* revert "--cheat" flag in client

* revert redeem.go

* optimization

* update account.go according to ocean ListUtxos new spec

* WaitForSync implementation

* ocean-wallet/service.go: remove go rountine while writing to notification channel
This commit is contained in:
Louis Singer
2024-03-04 13:58:36 +01:00
committed by GitHub
parent 6d0d03e316
commit 066e8eeabb
21 changed files with 537 additions and 120 deletions

View File

@@ -5,6 +5,7 @@ import (
"context" "context"
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"sync"
"time" "time"
"github.com/ark-network/ark/common" "github.com/ark-network/ark/common"
@@ -286,7 +287,7 @@ func (s *service) startFinalization() {
return return
} }
unsignedPoolTx, tree, err := s.builder.BuildPoolTx(s.pubkey, payments, s.minRelayFee) unsignedPoolTx, tree, connectorAddress, err := s.builder.BuildPoolTx(s.pubkey, payments, s.minRelayFee)
if err != nil { if err != nil {
changes = round.Fail(fmt.Errorf("failed to create pool tx: %s", err)) changes = round.Fail(fmt.Errorf("failed to create pool tx: %s", err))
log.WithError(err).Warn("failed to create pool tx") log.WithError(err).Warn("failed to create pool tx")
@@ -295,7 +296,7 @@ func (s *service) startFinalization() {
log.Debugf("pool tx created for round %s", round.Id) log.Debugf("pool tx created for round %s", round.Id)
connectors, forfeitTxs, err := s.builder.BuildForfeitTxs(s.pubkey, unsignedPoolTx, payments) connectors, forfeitTxs, err := s.builder.BuildForfeitTxs(s.pubkey, unsignedPoolTx, payments, s.minRelayFee)
if err != nil { if err != nil {
changes = round.Fail(fmt.Errorf("failed to create connectors and forfeit txs: %s", err)) changes = round.Fail(fmt.Errorf("failed to create connectors and forfeit txs: %s", err))
log.WithError(err).Warn("failed to create connectors and forfeit txs") log.WithError(err).Warn("failed to create connectors and forfeit txs")
@@ -304,7 +305,7 @@ func (s *service) startFinalization() {
log.Debugf("forfeit transactions created for round %s", round.Id) log.Debugf("forfeit transactions created for round %s", round.Id)
events, err := round.StartFinalization(connectors, tree, unsignedPoolTx) events, err := round.StartFinalization(connectorAddress, connectors, tree, unsignedPoolTx)
if err != nil { if err != nil {
changes = round.Fail(fmt.Errorf("failed to start finalization: %s", err)) changes = round.Fail(fmt.Errorf("failed to start finalization: %s", err))
log.WithError(err).Warn("failed to start finalization") log.WithError(err).Warn("failed to start finalization")
@@ -369,26 +370,163 @@ func (s *service) finalizeRound() {
func (s *service) listenToRedemptions() { func (s *service) listenToRedemptions() {
ctx := context.Background() ctx := context.Background()
chVtxos := s.scanner.GetNotificationChannel(ctx) chVtxos := s.scanner.GetNotificationChannel(ctx)
mutx := &sync.Mutex{}
for vtxoKeys := range chVtxos { for vtxoKeys := range chVtxos {
if len(vtxoKeys) > 0 { go func(vtxoKeys []domain.VtxoKey) {
for { vtxosRepo := s.repoManager.Vtxos()
// TODO: make sure that the vtxos haven't been already spent, otherwise roundRepo := s.repoManager.Rounds()
// broadcast the corresponding forfeit tx and connector to prevent
// getting cheated. for _, v := range vtxoKeys {
vtxos, err := s.repoManager.Vtxos().RedeemVtxos(ctx, vtxoKeys) vtxos, err := vtxosRepo.GetVtxos(ctx, []domain.VtxoKey{v})
if err != nil { if err != nil {
log.WithError(err).Warn("failed to redeem vtxos, retrying...") log.WithError(err).Warn("failed to retrieve vtxos, skipping...")
time.Sleep(100 * time.Millisecond)
continue continue
} }
if len(vtxos) > 0 {
log.Debugf("redeemed %d vtxos", len(vtxos)) vtxo := vtxos[0]
if vtxo.Redeemed {
continue
} }
break
if _, err := s.repoManager.Vtxos().RedeemVtxos(ctx, []domain.VtxoKey{vtxo.VtxoKey}); err != nil {
log.WithError(err).Warn("failed to redeem vtxos, retrying...")
continue
}
log.Debugf("vtxo %s redeemed", vtxo.Txid)
if !vtxo.Spent {
continue
}
log.Debugf("fraud detected on vtxo %s", vtxo.Txid)
round, err := roundRepo.GetRoundWithTxid(ctx, vtxo.SpentBy)
if err != nil {
log.WithError(err).Warn("failed to retrieve round")
continue
}
mutx.Lock()
defer mutx.Unlock()
connectorTxid, connectorVout, err := s.getNextConnector(ctx, *round)
if err != nil {
log.WithError(err).Warn("failed to retrieve next connector")
continue
}
forfeitTx, err := findForfeitTx(round.ForfeitTxs, connectorTxid, connectorVout, vtxo.Txid)
if err != nil {
log.WithError(err).Warn("failed to retrieve forfeit tx")
continue
}
signedForfeitTx, err := s.wallet.SignConnectorInput(ctx, forfeitTx, []int{0}, false)
if err != nil {
log.WithError(err).Warn("failed to sign connector input in forfeit tx")
continue
}
signedForfeitTx, err = s.wallet.SignPsetWithKey(ctx, signedForfeitTx, []int{1})
if err != nil {
log.WithError(err).Warn("failed to sign vtxo input in forfeit tx")
continue
}
forfeitTxHex, err := finalizeAndExtractForfeit(signedForfeitTx)
if err != nil {
log.WithError(err).Warn("failed to finalize forfeit tx")
continue
}
forfeitTxid, err := s.wallet.BroadcastTransaction(ctx, forfeitTxHex)
if err != nil {
log.WithError(err).Warn("failed to broadcast forfeit tx")
continue
}
log.Debugf("broadcasted forfeit tx %s", forfeitTxid)
}
}(vtxoKeys)
}
}
func (s *service) getNextConnector(
ctx context.Context,
round domain.Round,
) (string, uint32, error) {
connectorTx, err := psetv2.NewPsetFromBase64(round.Connectors[0])
if err != nil {
return "", 0, err
}
prevout := connectorTx.Inputs[0].WitnessUtxo
if prevout == nil {
return "", 0, fmt.Errorf("connector prevout not found")
}
utxos, err := s.wallet.ListConnectorUtxos(ctx, round.ConnectorAddress)
if err != nil {
return "", 0, err
}
// if we do not find any utxos, we make sure to wait for the connector outpoint to be confirmed then we retry
if len(utxos) <= 0 {
if err := s.wallet.WaitForSync(ctx, round.Txid); err != nil {
return "", 0, err
}
utxos, err = s.wallet.ListConnectorUtxos(ctx, round.ConnectorAddress)
if err != nil {
return "", 0, err
}
}
// search for an already existing connector
for _, u := range utxos {
if u.GetValue() == 450 {
return u.GetTxid(), u.GetIndex(), nil
}
}
for _, u := range utxos {
if u.GetValue() > 450 {
for _, b64 := range round.Connectors {
pset, err := psetv2.NewPsetFromBase64(b64)
if err != nil {
return "", 0, err
}
for _, i := range pset.Inputs {
if chainhash.Hash(i.PreviousTxid).String() == u.GetTxid() && i.PreviousTxIndex == u.GetIndex() {
// sign & broadcast the connector tx
signedConnectorTx, err := s.wallet.SignConnectorInput(ctx, b64, []int{0}, true)
if err != nil {
return "", 0, err
}
connectorTxid, err := s.wallet.BroadcastTransaction(ctx, signedConnectorTx)
if err != nil {
return "", 0, err
}
log.Debugf("broadcasted connector tx %s", connectorTxid)
// wait for the connector tx to be in the mempool
if err := s.wallet.WaitForSync(ctx, connectorTxid); err != nil {
return "", 0, err
}
return connectorTxid, 0, nil
} }
} }
} }
} }
}
return "", 0, fmt.Errorf("no connector utxos found")
}
func (s *service) updateVtxoSet(round *domain.Round) { func (s *service) updateVtxoSet(round *domain.Round) {
// Update the vtxo set only after a round is finalized. // Update the vtxo set only after a round is finalized.
@@ -401,7 +539,7 @@ func (s *service) updateVtxoSet(round *domain.Round) {
spentVtxos := getSpentVtxos(round.Payments) spentVtxos := getSpentVtxos(round.Payments)
if len(spentVtxos) > 0 { if len(spentVtxos) > 0 {
for { for {
if err := repo.SpendVtxos(ctx, spentVtxos); err != nil { if err := repo.SpendVtxos(ctx, spentVtxos, round.Txid); err != nil {
log.WithError(err).Warn("failed to add new vtxos, retrying soon") log.WithError(err).Warn("failed to add new vtxos, retrying soon")
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
continue continue
@@ -535,12 +673,28 @@ func (s *service) stopWatchingVtxos(vtxos []domain.Vtxo) {
} }
func (s *service) restoreWatchingVtxos() error { func (s *service) restoreWatchingVtxos() error {
vtxos, err := s.repoManager.Vtxos().GetSpendableVtxos( sweepableRounds, err := s.repoManager.Rounds().GetSweepableRounds(context.Background())
context.Background(), "",
)
if err != nil { if err != nil {
return err return err
} }
vtxos := make([]domain.Vtxo, 0)
for _, round := range sweepableRounds {
fromRound, err := s.repoManager.Vtxos().GetVtxosForRound(
context.Background(), round.Txid,
)
if err != nil {
log.WithError(err).Warnf("failed to retrieve vtxos for round %s", round.Txid)
continue
}
for _, v := range fromRound {
if !v.Swept && !v.Redeemed {
vtxos = append(vtxos, v)
}
}
}
if len(vtxos) <= 0 { if len(vtxos) <= 0 {
return nil return nil
} }
@@ -617,3 +771,45 @@ func getPaymentsFromOnboarding(
payment := domain.NewPaymentUnsafe(nil, receivers) payment := domain.NewPaymentUnsafe(nil, receivers)
return []domain.Payment{*payment} return []domain.Payment{*payment}
} }
func finalizeAndExtractForfeit(b64 string) (string, error) {
p, err := psetv2.NewPsetFromBase64(b64)
if err != nil {
return "", err
}
// finalize connector input
if err := psetv2.FinalizeAll(p); err != nil {
return "", err
}
// extract the forfeit tx
extracted, err := psetv2.Extract(p)
if err != nil {
return "", err
}
return extracted.ToHex()
}
func findForfeitTx(
forfeits []string, connectorTxid string, connectorVout uint32, vtxoTxid string,
) (string, error) {
for _, forfeit := range forfeits {
forfeitTx, err := psetv2.NewPsetFromBase64(forfeit)
if err != nil {
return "", err
}
connector := forfeitTx.Inputs[0]
vtxoInput := forfeitTx.Inputs[1]
if chainhash.Hash(connector.PreviousTxid).String() == connectorTxid &&
connector.PreviousTxIndex == connectorVout &&
chainhash.Hash(vtxoInput.PreviousTxid).String() == vtxoTxid {
return forfeit, nil
}
}
return "", fmt.Errorf("forfeit tx not found")
}

View File

@@ -205,9 +205,10 @@ func (s *sweeper) createTask(
} }
} }
vtxosRepository := s.repoManager.Vtxos()
if len(sweepInputs) > 0 { if len(sweepInputs) > 0 {
// build the sweep transaction with all the expired non-swept shared outputs // build the sweep transaction with all the expired non-swept shared outputs
sweepTx, err := s.builder.BuildSweepTx(s.wallet, sweepInputs) sweepTx, err := s.builder.BuildSweepTx(sweepInputs)
if err != nil { if err != nil {
log.WithError(err).Error("error while building sweep tx") log.WithError(err).Error("error while building sweep tx")
return return
@@ -231,7 +232,6 @@ func (s *sweeper) createTask(
} }
if len(txid) > 0 { if len(txid) > 0 {
log.Debugln("sweep tx broadcasted:", txid) log.Debugln("sweep tx broadcasted:", txid)
vtxosRepository := s.repoManager.Vtxos()
// mark the vtxos as swept // mark the vtxos as swept
if err := vtxosRepository.SweepVtxos(ctx, vtxoKeys); err != nil { if err := vtxosRepository.SweepVtxos(ctx, vtxoKeys); err != nil {
@@ -240,6 +240,8 @@ func (s *sweeper) createTask(
} }
log.Debugf("%d vtxos swept", len(vtxoKeys)) log.Debugf("%d vtxos swept", len(vtxoKeys))
}
}
roundVtxos, err := vtxosRepository.GetVtxosForRound(ctx, roundTxid) roundVtxos, err := vtxosRepository.GetVtxosForRound(ctx, roundTxid)
if err != nil { if err != nil {
@@ -249,7 +251,7 @@ func (s *sweeper) createTask(
allSwept := true allSwept := true
for _, vtxo := range roundVtxos { for _, vtxo := range roundVtxos {
allSwept = allSwept && vtxo.Swept allSwept = allSwept && (vtxo.Swept || vtxo.Redeemed)
if !allSwept { if !allSwept {
break break
} }
@@ -264,6 +266,7 @@ func (s *sweeper) createTask(
return return
} }
log.Debugf("round %s fully swept", roundTxid)
round.Sweep() round.Sweep()
if err := roundRepo.AddOrUpdateRound(ctx, *round); err != nil { if err := roundRepo.AddOrUpdateRound(ctx, *round); err != nil {
@@ -273,8 +276,6 @@ func (s *sweeper) createTask(
} }
} }
} }
}
}
// onchainOutputs iterates over all the nodes' outputs in the congestion tree and checks their onchain state // onchainOutputs iterates over all the nodes' outputs in the congestion tree and checks their onchain state
// returns the sweepable outputs as ports.SweepInput mapped by their expiration time // returns the sweepable outputs as ports.SweepInput mapped by their expiration time

View File

@@ -201,6 +201,7 @@ func (m *forfeitTxsMap) sign(txs []string) error {
} }
if sig.Verify(preimage, pubkey) { if sig.Verify(preimage, pubkey) {
m.forfeitTxs[txid].tx = tx
m.forfeitTxs[txid].signed = true m.forfeitTxs[txid].signed = true
} else { } else {
return fmt.Errorf("invalid signature") return fmt.Errorf("invalid signature")

View File

@@ -21,6 +21,7 @@ type RoundFinalizationStarted struct {
Id string Id string
CongestionTree tree.CongestionTree CongestionTree tree.CongestionTree
Connectors []string Connectors []string
ConnectorAddress string
UnsignedForfeitTxs []string UnsignedForfeitTxs []string
PoolTx string PoolTx string
} }

View File

@@ -131,6 +131,7 @@ type Vtxo struct {
VtxoKey VtxoKey
Receiver Receiver
PoolTx string PoolTx string
SpentBy string
Spent bool Spent bool
Redeemed bool Redeemed bool
Swept bool Swept bool

View File

@@ -44,9 +44,10 @@ type Round struct {
ForfeitTxs []string ForfeitTxs []string
CongestionTree tree.CongestionTree CongestionTree tree.CongestionTree
Connectors []string Connectors []string
ConnectorAddress string
DustAmount uint64 DustAmount uint64
Version uint Version uint
Swept bool // true if all the vtxos are vtxo.Swept Swept bool // true if all the vtxos are vtxo.Swept or vtxo.Redeemed
changes []RoundEvent changes []RoundEvent
} }
@@ -118,6 +119,7 @@ func (r *Round) On(event RoundEvent, replayed bool) {
r.Stage.Code = FinalizationStage r.Stage.Code = FinalizationStage
r.CongestionTree = e.CongestionTree r.CongestionTree = e.CongestionTree
r.Connectors = append([]string{}, e.Connectors...) r.Connectors = append([]string{}, e.Connectors...)
r.ConnectorAddress = e.ConnectorAddress
r.UnsignedTx = e.PoolTx r.UnsignedTx = e.PoolTx
case RoundFinalized: case RoundFinalized:
r.Stage.Ended = true r.Stage.Ended = true
@@ -178,7 +180,7 @@ func (r *Round) RegisterPayments(payments []Payment) ([]RoundEvent, error) {
return []RoundEvent{event}, nil return []RoundEvent{event}, nil
} }
func (r *Round) StartFinalization(connectors []string, congestionTree tree.CongestionTree, poolTx string) ([]RoundEvent, error) { func (r *Round) StartFinalization(connectorAddress string, connectors []string, congestionTree tree.CongestionTree, poolTx string) ([]RoundEvent, error) {
if len(connectors) <= 0 { if len(connectors) <= 0 {
return nil, fmt.Errorf("missing list of connectors") return nil, fmt.Errorf("missing list of connectors")
} }
@@ -199,6 +201,7 @@ func (r *Round) StartFinalization(connectors []string, congestionTree tree.Conge
Id: r.Id, Id: r.Id,
CongestionTree: congestionTree, CongestionTree: congestionTree,
Connectors: connectors, Connectors: connectors,
ConnectorAddress: connectorAddress,
PoolTx: poolTx, PoolTx: poolTx,
} }
r.raise(event) r.raise(event)

View File

@@ -19,7 +19,7 @@ type RoundRepository interface {
type VtxoRepository interface { type VtxoRepository interface {
AddVtxos(ctx context.Context, vtxos []Vtxo) error AddVtxos(ctx context.Context, vtxos []Vtxo) error
SpendVtxos(ctx context.Context, vtxos []VtxoKey) error SpendVtxos(ctx context.Context, vtxos []VtxoKey, txid string) error
RedeemVtxos(ctx context.Context, vtxos []VtxoKey) ([]Vtxo, error) RedeemVtxos(ctx context.Context, vtxos []VtxoKey) ([]Vtxo, error)
GetVtxos(ctx context.Context, vtxos []VtxoKey) ([]Vtxo, error) GetVtxos(ctx context.Context, vtxos []VtxoKey) ([]Vtxo, error)
GetVtxosForRound(ctx context.Context, txid string) ([]Vtxo, error) GetVtxosForRound(ctx context.Context, txid string) ([]Vtxo, error)

View File

@@ -296,7 +296,7 @@ func testStartFinalization(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.NotEmpty(t, events) require.NotEmpty(t, events)
events, err = round.StartFinalization(connectors, congestionTree, poolTx) events, err = round.StartFinalization("", connectors, congestionTree, poolTx)
require.NoError(t, err) require.NoError(t, err)
require.Len(t, events, 1) require.Len(t, events, 1)
require.True(t, round.IsStarted()) require.True(t, round.IsStarted())
@@ -418,7 +418,8 @@ func testStartFinalization(t *testing.T) {
} }
for _, f := range fixtures { for _, f := range fixtures {
events, err := f.round.StartFinalization(f.connectors, f.tree, f.poolTx) // TODO fix this
events, err := f.round.StartFinalization("", f.connectors, f.tree, f.poolTx)
require.EqualError(t, err, f.expectedErr) require.EqualError(t, err, f.expectedErr)
require.Empty(t, events) require.Empty(t, events)
} }
@@ -438,7 +439,7 @@ func testEndFinalization(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.NotEmpty(t, events) require.NotEmpty(t, events)
events, err = round.StartFinalization(connectors, congestionTree, poolTx) events, err = round.StartFinalization("", connectors, congestionTree, poolTx)
require.NoError(t, err) require.NoError(t, err)
require.NotEmpty(t, events) require.NotEmpty(t, events)

View File

@@ -14,15 +14,8 @@ type SweepInput struct {
} }
type TxBuilder interface { type TxBuilder interface {
BuildPoolTx( BuildPoolTx(aspPubkey *secp256k1.PublicKey, payments []domain.Payment, minRelayFee uint64) (poolTx string, congestionTree tree.CongestionTree, connectorAddress string, err error)
aspPubkey *secp256k1.PublicKey, payments []domain.Payment, minRelayFee uint64, BuildForfeitTxs(aspPubkey *secp256k1.PublicKey, poolTx string, payments []domain.Payment, minRelayFee uint64) (connectors []string, forfeitTxs []string, err error)
) (poolTx string, congestionTree tree.CongestionTree, err error) BuildSweepTx(inputs []SweepInput) (signedSweepTx string, err error)
BuildForfeitTxs(
aspPubkey *secp256k1.PublicKey, poolTx string, payments []domain.Payment,
) (connectors []string, forfeitTxs []string, err error)
BuildSweepTx(
wallet WalletService,
inputs []SweepInput,
) (signedSweepTx string, err error)
GetVtxoScript(userPubkey, aspPubkey *secp256k1.PublicKey) ([]byte, error) GetVtxoScript(userPubkey, aspPubkey *secp256k1.PublicKey) ([]byte, error)
} }

View File

@@ -10,6 +10,7 @@ type WalletService interface {
BlockchainScanner BlockchainScanner
Status(ctx context.Context) (WalletStatus, error) Status(ctx context.Context) (WalletStatus, error)
GetPubkey(ctx context.Context) (*secp256k1.PublicKey, error) GetPubkey(ctx context.Context) (*secp256k1.PublicKey, error)
DeriveConnectorAddress(ctx context.Context) (string, error)
DeriveAddresses(ctx context.Context, num int) ([]string, error) DeriveAddresses(ctx context.Context, num int) ([]string, error)
SignPset( SignPset(
ctx context.Context, pset string, extractRawTx bool, ctx context.Context, pset string, extractRawTx bool,
@@ -18,7 +19,10 @@ type WalletService interface {
BroadcastTransaction(ctx context.Context, txHex string) (string, error) BroadcastTransaction(ctx context.Context, txHex string) (string, error)
SignPsetWithKey(ctx context.Context, pset string, inputIndexes []int) (string, error) // inputIndexes == nil means sign all inputs SignPsetWithKey(ctx context.Context, pset string, inputIndexes []int) (string, error) // inputIndexes == nil means sign all inputs
IsTransactionConfirmed(ctx context.Context, txid string) (isConfirmed bool, blocktime int64, err error) IsTransactionConfirmed(ctx context.Context, txid string) (isConfirmed bool, blocktime int64, err error)
WaitForSync(ctx context.Context, txid string) error
EstimateFees(ctx context.Context, pset string) (uint64, error) EstimateFees(ctx context.Context, pset string) (uint64, error)
SignConnectorInput(ctx context.Context, pset string, inputIndexes []int, extract bool) (string, error)
ListConnectorUtxos(ctx context.Context, connectorAddress string) ([]TxInput, error)
Close() Close()
} }

View File

@@ -53,10 +53,10 @@ func (r *vtxoRepository) AddVtxos(
} }
func (r *vtxoRepository) SpendVtxos( func (r *vtxoRepository) SpendVtxos(
ctx context.Context, vtxoKeys []domain.VtxoKey, ctx context.Context, vtxoKeys []domain.VtxoKey, spentBy string,
) error { ) error {
for _, vtxoKey := range vtxoKeys { for _, vtxoKey := range vtxoKeys {
if err := r.spendVtxo(ctx, vtxoKey); err != nil { if err := r.spendVtxo(ctx, vtxoKey, spentBy); err != nil {
return err return err
} }
} }
@@ -96,7 +96,7 @@ func (r *vtxoRepository) GetVtxos(
func (r *vtxoRepository) GetVtxosForRound( func (r *vtxoRepository) GetVtxosForRound(
ctx context.Context, txid string, ctx context.Context, txid string,
) ([]domain.Vtxo, error) { ) ([]domain.Vtxo, error) {
query := badgerhold.Where("Txid").Eq(txid) query := badgerhold.Where("PoolTx").Eq(txid)
return r.findVtxos(ctx, query) return r.findVtxos(ctx, query)
} }
@@ -161,7 +161,7 @@ func (r *vtxoRepository) getVtxo(
return &vtxo, nil return &vtxo, nil
} }
func (r *vtxoRepository) spendVtxo(ctx context.Context, vtxoKey domain.VtxoKey) error { func (r *vtxoRepository) spendVtxo(ctx context.Context, vtxoKey domain.VtxoKey, spendBy string) error {
vtxo, err := r.getVtxo(ctx, vtxoKey) vtxo, err := r.getVtxo(ctx, vtxoKey)
if err != nil { if err != nil {
if strings.Contains(err.Error(), "not found") { if strings.Contains(err.Error(), "not found") {
@@ -174,6 +174,7 @@ func (r *vtxoRepository) spendVtxo(ctx context.Context, vtxoKey domain.VtxoKey)
} }
vtxo.Spent = true vtxo.Spent = true
vtxo.SpentBy = spendBy
if ctx.Value("tx") != nil { if ctx.Value("tx") != nil {
tx := ctx.Value("tx").(*badger.Txn) tx := ctx.Value("tx").(*badger.Txn)
err = r.store.TxUpdate(tx, vtxoKey.Hash(), *vtxo) err = r.store.TxUpdate(tx, vtxoKey.Hash(), *vtxo)

View File

@@ -362,7 +362,7 @@ func testVtxoRepository(t *testing.T, svc ports.RepoManager) {
require.NoError(t, err) require.NoError(t, err)
require.Exactly(t, userVtxos, spendableVtxos) require.Exactly(t, userVtxos, spendableVtxos)
err = svc.Vtxos().SpendVtxos(ctx, vtxoKeys[:1]) err = svc.Vtxos().SpendVtxos(ctx, vtxoKeys[:1], txid)
require.NoError(t, err) require.NoError(t, err)
spentVtxos, err := svc.Vtxos().GetVtxos(ctx, vtxoKeys[:1]) spentVtxos, err := svc.Vtxos().GetVtxos(ctx, vtxoKeys[:1])

View File

@@ -4,14 +4,52 @@ import (
"context" "context"
pb "github.com/ark-network/ark/api-spec/protobuf/gen/ocean/v1" pb "github.com/ark-network/ark/api-spec/protobuf/gen/ocean/v1"
"github.com/ark-network/ark/internal/core/ports"
"github.com/vulpemventures/go-elements/address" "github.com/vulpemventures/go-elements/address"
) )
func (s *service) DeriveAddresses( func (s *service) DeriveAddresses(
ctx context.Context, numOfAddresses int, ctx context.Context, numOfAddresses int,
) ([]string, error) {
return s.deriveAddresses(ctx, numOfAddresses, arkAccount)
}
func (s *service) DeriveConnectorAddress(ctx context.Context) (string, error) {
addresses, err := s.deriveAddresses(ctx, 1, connectorAccount)
if err != nil {
return "", err
}
return addresses[0], nil
}
func (s *service) ListConnectorUtxos(
ctx context.Context, connectorAddress string,
) ([]ports.TxInput, error) {
res, err := s.accountClient.ListUtxos(ctx, &pb.ListUtxosRequest{
AccountName: connectorAccount,
Addresses: []string{connectorAddress},
})
if err != nil {
return nil, err
}
utxos := make([]ports.TxInput, 0)
for _, utxo := range res.GetSpendableUtxos().GetUtxos() {
utxos = append(utxos, utxo)
}
for _, utxo := range res.GetLockedUtxos().GetUtxos() {
utxos = append(utxos, utxo)
}
return utxos, nil
}
func (s *service) deriveAddresses(
ctx context.Context, numOfAddresses int, account string,
) ([]string, error) { ) ([]string, error) {
res, err := s.accountClient.DeriveAddresses(ctx, &pb.DeriveAddressesRequest{ res, err := s.accountClient.DeriveAddresses(ctx, &pb.DeriveAddressesRequest{
AccountName: accountLabel, AccountName: account,
NumOfAddresses: uint64(numOfAddresses), NumOfAddresses: uint64(numOfAddresses),
}) })
if err != nil { if err != nil {

View File

@@ -61,16 +61,35 @@ func NewService(addr string) (ports.WalletService, error) {
return nil, err return nil, err
} }
found := false mainAccountFound, connectorAccountFound := false, false
for _, account := range info.GetAccounts() { for _, account := range info.GetAccounts() {
if account.GetLabel() == accountLabel { if account.GetLabel() == arkAccount {
found = true mainAccountFound = true
continue
}
if account.GetLabel() == connectorAccount {
connectorAccountFound = true
continue
}
if mainAccountFound && connectorAccountFound {
break break
} }
} }
if !found { if !mainAccountFound {
if _, err := accountClient.CreateAccountBIP44(ctx, &pb.CreateAccountBIP44Request{ if _, err := accountClient.CreateAccountBIP44(ctx, &pb.CreateAccountBIP44Request{
Label: accountLabel, Label: arkAccount,
Unconfidential: true,
}); err != nil {
return nil, err
}
}
if !connectorAccountFound {
if _, err := accountClient.CreateAccountBIP44(ctx, &pb.CreateAccountBIP44Request{
Label: connectorAccount,
Unconfidential: true, Unconfidential: true,
}); err != nil { }); err != nil {
return nil, err return nil, err
@@ -114,9 +133,7 @@ func (s *service) listenToNotificaitons() {
} }
vtxos := toVtxos(msg.GetUtxos()) vtxos := toVtxos(msg.GetUtxos())
if len(vtxos) > 0 { if len(vtxos) > 0 {
go func() {
s.chVtxos <- vtxos s.chVtxos <- vtxos
}()
} }
} }
} }

View File

@@ -59,8 +59,9 @@ func (s *service) SignPset(
} }
func (s *service) SelectUtxos(ctx context.Context, asset string, amount uint64) ([]ports.TxInput, uint64, error) { func (s *service) SelectUtxos(ctx context.Context, asset string, amount uint64) ([]ports.TxInput, uint64, error) {
// TODO: select coins from the connector account IF the round is swept
res, err := s.txClient.SelectUtxos(ctx, &pb.SelectUtxosRequest{ res, err := s.txClient.SelectUtxos(ctx, &pb.SelectUtxosRequest{
AccountName: accountLabel, AccountName: arkAccount,
TargetAsset: asset, TargetAsset: asset,
TargetAmount: amount, TargetAmount: amount,
}) })
@@ -81,22 +82,22 @@ func (s *service) SelectUtxos(ctx context.Context, asset string, amount uint64)
return inputs, res.GetChange(), nil return inputs, res.GetChange(), nil
} }
func (s *service) GetTransaction( func (s *service) getTransaction(
ctx context.Context, txid string, ctx context.Context, txid string,
) (string, int64, error) { ) (string, bool, int64, error) {
res, err := s.txClient.GetTransaction(ctx, &pb.GetTransactionRequest{ res, err := s.txClient.GetTransaction(ctx, &pb.GetTransactionRequest{
Txid: txid, Txid: txid,
}) })
if err != nil { if err != nil {
return "", 0, err return "", false, 0, err
} }
if res.GetBlockDetails().GetTimestamp() > 0 { if res.GetBlockDetails().GetTimestamp() > 0 {
return res.GetTxHex(), res.BlockDetails.GetTimestamp(), nil return res.GetTxHex(), true, res.BlockDetails.GetTimestamp(), nil
} }
// if not confirmed, we return now + 30 secs to estimate the next blocktime // if not confirmed, we return now + 1 min to estimate the next blocktime
return res.GetTxHex(), time.Now().Unix() + 30, nil return res.GetTxHex(), false, time.Now().Add(time.Minute).Unix(), nil
} }
func (s *service) BroadcastTransaction( func (s *service) BroadcastTransaction(
@@ -120,15 +121,29 @@ func (s *service) BroadcastTransaction(
func (s *service) IsTransactionConfirmed( func (s *service) IsTransactionConfirmed(
ctx context.Context, txid string, ctx context.Context, txid string,
) (bool, int64, error) { ) (bool, int64, error) {
_, blocktime, err := s.GetTransaction(ctx, txid) _, isConfirmed, blocktime, err := s.getTransaction(ctx, txid)
if err != nil { if err != nil {
if strings.Contains(strings.ToLower(err.Error()), "missing transaction") { if strings.Contains(strings.ToLower(err.Error()), "missing transaction") {
return false, 0, nil return isConfirmed, 0, nil
} }
return false, 0, err return false, 0, err
} }
return true, blocktime, nil return isConfirmed, blocktime, nil
}
func (s *service) WaitForSync(ctx context.Context, txid string) error {
for {
time.Sleep(5 * time.Second)
_, _, _, err := s.getTransaction(ctx, txid)
if err != nil {
if strings.Contains(strings.ToLower(err.Error()), "missing transaction") {
continue
}
return err
}
break
}
return nil
} }
func (s *service) SignPsetWithKey(ctx context.Context, b64 string, indexes []int) (string, error) { func (s *service) SignPsetWithKey(ctx context.Context, b64 string, indexes []int) (string, error) {
@@ -206,6 +221,38 @@ func (s *service) SignPsetWithKey(ctx context.Context, b64 string, indexes []int
return signedPset.GetSignedTx(), nil return signedPset.GetSignedTx(), nil
} }
func (s *service) SignConnectorInput(ctx context.Context, pset string, inputIndexes []int, extract bool) (string, error) {
decodedTx, err := psetv2.NewPsetFromBase64(pset)
if err != nil {
return "", err
}
utxos := make([]*pb.Input, 0, len(decodedTx.Inputs))
for i := range inputIndexes {
if i >= len(decodedTx.Inputs) {
return "", fmt.Errorf("input index %d out of range", i)
}
input := decodedTx.Inputs[i]
utxos = append(utxos, &pb.Input{
Txid: chainhash.Hash(input.PreviousTxid).String(),
Index: input.PreviousTxIndex,
})
}
_, err = s.txClient.LockUtxos(ctx, &pb.LockUtxosRequest{
AccountName: connectorAccount,
Utxos: utxos,
})
if err != nil {
return "", err
}
return s.SignPset(ctx, pset, extract)
}
func (s *service) EstimateFees( func (s *service) EstimateFees(
ctx context.Context, pset string, ctx context.Context, pset string,
) (uint64, error) { ) (uint64, error) {

View File

@@ -11,7 +11,10 @@ import (
"github.com/vulpemventures/go-bip32" "github.com/vulpemventures/go-bip32"
) )
const accountLabel = "ark" const (
arkAccount = "ark"
connectorAccount = "ark-connector"
)
var derivationPath = []uint32{0, 0} var derivationPath = []uint32{0, 0}
@@ -63,7 +66,7 @@ func (s *service) findAccount(ctx context.Context, label string) (*pb.AccountInf
} }
func (s *service) getPubkey(ctx context.Context) (*secp256k1.PublicKey, *bip32.Key, error) { func (s *service) getPubkey(ctx context.Context) (*secp256k1.PublicKey, *bip32.Key, error) {
account, err := s.findAccount(ctx, accountLabel) account, err := s.findAccount(ctx, arkAccount)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }

View File

@@ -11,6 +11,7 @@ import (
"github.com/decred/dcrd/dcrec/secp256k1/v4" "github.com/decred/dcrd/dcrec/secp256k1/v4"
"github.com/vulpemventures/go-elements/address" "github.com/vulpemventures/go-elements/address"
"github.com/vulpemventures/go-elements/network" "github.com/vulpemventures/go-elements/network"
"github.com/vulpemventures/go-elements/payment"
"github.com/vulpemventures/go-elements/psetv2" "github.com/vulpemventures/go-elements/psetv2"
"github.com/vulpemventures/go-elements/taproot" "github.com/vulpemventures/go-elements/taproot"
) )
@@ -41,9 +42,9 @@ func (b *txBuilder) GetVtxoScript(userPubkey, aspPubkey *secp256k1.PublicKey) ([
return outputScript, nil return outputScript, nil
} }
func (b *txBuilder) BuildSweepTx(wallet ports.WalletService, inputs []ports.SweepInput) (signedSweepTx string, err error) { func (b *txBuilder) BuildSweepTx(inputs []ports.SweepInput) (signedSweepTx string, err error) {
sweepPset, err := sweepTransaction( sweepPset, err := sweepTransaction(
wallet, b.wallet,
inputs, inputs,
b.net.AssetID, b.net.AssetID,
) )
@@ -57,7 +58,7 @@ func (b *txBuilder) BuildSweepTx(wallet ports.WalletService, inputs []ports.Swee
} }
ctx := context.Background() ctx := context.Background()
signedSweepPsetB64, err := wallet.SignPsetWithKey(ctx, sweepPsetBase64, nil) signedSweepPsetB64, err := b.wallet.SignPsetWithKey(ctx, sweepPsetBase64, nil)
if err != nil { if err != nil {
return "", err return "", err
} }
@@ -80,9 +81,14 @@ func (b *txBuilder) BuildSweepTx(wallet ports.WalletService, inputs []ports.Swee
} }
func (b *txBuilder) BuildForfeitTxs( func (b *txBuilder) BuildForfeitTxs(
aspPubkey *secp256k1.PublicKey, poolTx string, payments []domain.Payment, aspPubkey *secp256k1.PublicKey, poolTx string, payments []domain.Payment, minRelayFee uint64,
) (connectors []string, forfeitTxs []string, err error) { ) (connectors []string, forfeitTxs []string, err error) {
connectorTxs, err := b.createConnectors(poolTx, payments, aspPubkey) connectorAddress, err := b.getConnectorAddress(poolTx)
if err != nil {
return nil, nil, err
}
connectorTxs, err := b.createConnectors(poolTx, payments, connectorAddress, minRelayFee)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@@ -101,7 +107,7 @@ func (b *txBuilder) BuildForfeitTxs(
func (b *txBuilder) BuildPoolTx( func (b *txBuilder) BuildPoolTx(
aspPubkey *secp256k1.PublicKey, payments []domain.Payment, minRelayFee uint64, aspPubkey *secp256k1.PublicKey, payments []domain.Payment, minRelayFee uint64,
) (poolTx string, congestionTree tree.CongestionTree, err error) { ) (poolTx string, congestionTree tree.CongestionTree, connectorAddress string, err error) {
// The creation of the tree and the pool tx are tightly coupled: // The creation of the tree and the pool tx are tightly coupled:
// - building the tree requires knowing the shared outpoint (txid:vout) // - building the tree requires knowing the shared outpoint (txid:vout)
// - building the pool tx requires knowing the shared output script and amount // - building the pool tx requires knowing the shared output script and amount
@@ -120,8 +126,13 @@ func (b *txBuilder) BuildPoolTx(
return return
} }
connectorAddress, err = b.wallet.DeriveConnectorAddress(context.Background())
if err != nil {
return
}
ptx, err := b.createPoolTx( ptx, err := b.createPoolTx(
sharedOutputAmount, sharedOutputScript, payments, aspPubkey, sharedOutputAmount, sharedOutputScript, payments, aspPubkey, connectorAddress, minRelayFee,
) )
if err != nil { if err != nil {
return return
@@ -190,15 +201,24 @@ func (b *txBuilder) getLeafScriptAndTree(
func (b *txBuilder) createPoolTx( func (b *txBuilder) createPoolTx(
sharedOutputAmount uint64, sharedOutputScript []byte, sharedOutputAmount uint64, sharedOutputScript []byte,
payments []domain.Payment, aspPubKey *secp256k1.PublicKey, payments []domain.Payment, aspPubKey *secp256k1.PublicKey, connectorAddress string, minRelayFee uint64,
) (*psetv2.Pset, error) { ) (*psetv2.Pset, error) {
aspScript, err := p2wpkhScript(aspPubKey, b.net) aspScript, err := p2wpkhScript(aspPubKey, b.net)
if err != nil { if err != nil {
return nil, err return nil, err
} }
connectorScript, err := address.ToOutputScript(connectorAddress)
if err != nil {
return nil, err
}
receivers := getOnchainReceivers(payments) receivers := getOnchainReceivers(payments)
connectorsAmount := connectorAmount * countSpentVtxos(payments) nbOfInputs := countSpentVtxos(payments)
connectorsAmount := (connectorAmount + minRelayFee) * nbOfInputs
if nbOfInputs > 1 {
connectorsAmount -= minRelayFee
}
targetAmount := sharedOutputAmount + connectorsAmount targetAmount := sharedOutputAmount + connectorsAmount
outputs := []psetv2.OutputArgs{ outputs := []psetv2.OutputArgs{
@@ -210,7 +230,7 @@ func (b *txBuilder) createPoolTx(
{ {
Asset: b.net.AssetID, Asset: b.net.AssetID,
Amount: connectorsAmount, Amount: connectorsAmount,
Script: aspScript, Script: connectorScript,
}, },
} }
@@ -354,11 +374,11 @@ func (b *txBuilder) createPoolTx(
} }
func (b *txBuilder) createConnectors( func (b *txBuilder) createConnectors(
poolTx string, payments []domain.Payment, aspPubkey *secp256k1.PublicKey, poolTx string, payments []domain.Payment, connectorAddress string, minRelayFee uint64,
) ([]*psetv2.Pset, error) { ) ([]*psetv2.Pset, error) {
txid, _ := getTxid(poolTx) txid, _ := getTxid(poolTx)
aspScript, err := p2wpkhScript(aspPubkey, b.net) aspScript, err := address.ToOutputScript(connectorAddress)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -378,7 +398,7 @@ func (b *txBuilder) createConnectors(
if numberOfConnectors == 1 { if numberOfConnectors == 1 {
outputs := []psetv2.OutputArgs{connectorOutput} outputs := []psetv2.OutputArgs{connectorOutput}
connectorTx, err := craftConnectorTx(previousInput, outputs) connectorTx, err := craftConnectorTx(previousInput, aspScript, outputs, minRelayFee)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -386,12 +406,16 @@ func (b *txBuilder) createConnectors(
return []*psetv2.Pset{connectorTx}, nil return []*psetv2.Pset{connectorTx}, nil
} }
totalConnectorAmount := connectorAmount * numberOfConnectors totalConnectorAmount := (connectorAmount + minRelayFee) * numberOfConnectors
if numberOfConnectors > 1 {
totalConnectorAmount -= minRelayFee
}
connectors := make([]*psetv2.Pset, 0, numberOfConnectors-1) connectors := make([]*psetv2.Pset, 0, numberOfConnectors-1)
for i := uint64(0); i < numberOfConnectors-1; i++ { for i := uint64(0); i < numberOfConnectors-1; i++ {
outputs := []psetv2.OutputArgs{connectorOutput} outputs := []psetv2.OutputArgs{connectorOutput}
totalConnectorAmount -= connectorAmount totalConnectorAmount -= connectorAmount
totalConnectorAmount -= minRelayFee
if totalConnectorAmount > 0 { if totalConnectorAmount > 0 {
outputs = append(outputs, psetv2.OutputArgs{ outputs = append(outputs, psetv2.OutputArgs{
Asset: b.net.AssetID, Asset: b.net.AssetID,
@@ -399,7 +423,7 @@ func (b *txBuilder) createConnectors(
Amount: totalConnectorAmount, Amount: totalConnectorAmount,
}) })
} }
connectorTx, err := craftConnectorTx(previousInput, outputs) connectorTx, err := craftConnectorTx(previousInput, aspScript, outputs, minRelayFee)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -473,3 +497,23 @@ func (b *txBuilder) createForfeitTxs(
} }
return forfeitTxs, nil return forfeitTxs, nil
} }
func (b *txBuilder) getConnectorAddress(poolTx string) (string, error) {
pset, err := psetv2.NewPsetFromBase64(poolTx)
if err != nil {
return "", err
}
if len(pset.Outputs) < 1 {
return "", fmt.Errorf("connector output not found in pool tx")
}
connectorOutput := pset.Outputs[1]
pay, err := payment.FromScript(connectorOutput.Script, b.net, nil)
if err != nil {
return "", err
}
return pay.WitnessPubKeyHash()
}

View File

@@ -21,6 +21,7 @@ import (
const ( const (
testingKey = "0218d5ca8b58797b7dbd65c075dd7ba7784b3f38ab71b1a5a8e3f94ba0257654a6" testingKey = "0218d5ca8b58797b7dbd65c075dd7ba7784b3f38ab71b1a5a8e3f94ba0257654a6"
connectorAddress = "tex1qekd5u0qj8jl07vy60830xy7n9qtmcx9u3s0cqc"
minRelayFee = uint64(30) minRelayFee = uint64(30)
roundLifetime = int64(1209344) roundLifetime = int64(1209344)
unilateralExitDelay = int64(512) unilateralExitDelay = int64(512)
@@ -37,6 +38,8 @@ func TestMain(m *testing.M) {
Return(uint64(100), nil) Return(uint64(100), nil)
wallet.On("SelectUtxos", mock.Anything, mock.Anything, mock.Anything). wallet.On("SelectUtxos", mock.Anything, mock.Anything, mock.Anything).
Return(randomInput, uint64(0), nil) Return(randomInput, uint64(0), nil)
wallet.On("DeriveConnectorAddress", mock.Anything).
Return(connectorAddress, nil)
pubkeyBytes, _ := hex.DecodeString(testingKey) pubkeyBytes, _ := hex.DecodeString(testingKey)
pubkey, _ = secp256k1.ParsePubKey(pubkeyBytes) pubkey, _ = secp256k1.ParsePubKey(pubkeyBytes)
@@ -56,12 +59,13 @@ func TestBuildPoolTx(t *testing.T) {
if len(fixtures.Valid) > 0 { if len(fixtures.Valid) > 0 {
t.Run("valid", func(t *testing.T) { t.Run("valid", func(t *testing.T) {
for _, f := range fixtures.Valid { for _, f := range fixtures.Valid {
poolTx, congestionTree, err := builder.BuildPoolTx( poolTx, congestionTree, connAddr, err := builder.BuildPoolTx(
pubkey, f.Payments, minRelayFee, pubkey, f.Payments, minRelayFee,
) )
require.NoError(t, err) require.NoError(t, err)
require.NotEmpty(t, poolTx) require.NotEmpty(t, poolTx)
require.NotEmpty(t, congestionTree) require.NotEmpty(t, congestionTree)
require.Equal(t, connectorAddress, connAddr)
require.Equal(t, f.ExpectedNumOfNodes, congestionTree.NumberOfNodes()) require.Equal(t, f.ExpectedNumOfNodes, congestionTree.NumberOfNodes())
require.Len(t, congestionTree.Leaves(), f.ExpectedNumOfLeaves) require.Len(t, congestionTree.Leaves(), f.ExpectedNumOfLeaves)
@@ -76,11 +80,12 @@ func TestBuildPoolTx(t *testing.T) {
if len(fixtures.Invalid) > 0 { if len(fixtures.Invalid) > 0 {
t.Run("invalid", func(t *testing.T) { t.Run("invalid", func(t *testing.T) {
for _, f := range fixtures.Invalid { for _, f := range fixtures.Invalid {
poolTx, congestionTree, err := builder.BuildPoolTx( poolTx, congestionTree, connAddr, err := builder.BuildPoolTx(
pubkey, f.Payments, minRelayFee, pubkey, f.Payments, minRelayFee,
) )
require.EqualError(t, err, f.ExpectedErr) require.EqualError(t, err, f.ExpectedErr)
require.Empty(t, poolTx) require.Empty(t, poolTx)
require.Empty(t, connAddr)
require.Empty(t, congestionTree) require.Empty(t, congestionTree)
} }
}) })
@@ -100,7 +105,7 @@ func TestBuildForfeitTxs(t *testing.T) {
t.Run("valid", func(t *testing.T) { t.Run("valid", func(t *testing.T) {
for _, f := range fixtures.Valid { for _, f := range fixtures.Valid {
connectors, forfeitTxs, err := builder.BuildForfeitTxs( connectors, forfeitTxs, err := builder.BuildForfeitTxs(
pubkey, f.PoolTx, f.Payments, pubkey, f.PoolTx, f.Payments, minRelayFee,
) )
require.NoError(t, err) require.NoError(t, err)
require.Len(t, connectors, f.ExpectedNumOfConnectors) require.Len(t, connectors, f.ExpectedNumOfConnectors)
@@ -114,7 +119,7 @@ func TestBuildForfeitTxs(t *testing.T) {
require.NotNil(t, tx) require.NotNil(t, tx)
require.Len(t, tx.Inputs, 1) require.Len(t, tx.Inputs, 1)
require.Len(t, tx.Outputs, 2) require.Len(t, tx.Outputs, 3)
inputTxid := chainhash.Hash(tx.Inputs[0].PreviousTxid).String() inputTxid := chainhash.Hash(tx.Inputs[0].PreviousTxid).String()
require.Equal(t, expectedInputTxid, inputTxid) require.Equal(t, expectedInputTxid, inputTxid)
@@ -138,7 +143,7 @@ func TestBuildForfeitTxs(t *testing.T) {
t.Run("invalid", func(t *testing.T) { t.Run("invalid", func(t *testing.T) {
for _, f := range fixtures.Invalid { for _, f := range fixtures.Invalid {
connectors, forfeitTxs, err := builder.BuildForfeitTxs( connectors, forfeitTxs, err := builder.BuildForfeitTxs(
pubkey, f.PoolTx, f.Payments, pubkey, f.PoolTx, f.Payments, minRelayFee,
) )
require.EqualError(t, err, f.ExpectedErr) require.EqualError(t, err, f.ExpectedErr)
require.Empty(t, connectors) require.Empty(t, connectors)

View File

@@ -1,12 +1,13 @@
package txbuilder package txbuilder
import ( import (
"github.com/vulpemventures/go-elements/elementsutil"
"github.com/vulpemventures/go-elements/psetv2" "github.com/vulpemventures/go-elements/psetv2"
"github.com/vulpemventures/go-elements/transaction" "github.com/vulpemventures/go-elements/transaction"
) )
func craftConnectorTx( func craftConnectorTx(
input psetv2.InputArgs, outputs []psetv2.OutputArgs, input psetv2.InputArgs, inputScript []byte, outputs []psetv2.OutputArgs, feeAmount uint64,
) (*psetv2.Pset, error) { ) (*psetv2.Pset, error) {
ptx, _ := psetv2.New(nil, nil, nil) ptx, _ := psetv2.New(nil, nil, nil)
updater, _ := psetv2.NewUpdater(ptx) updater, _ := psetv2.NewUpdater(ptx)
@@ -17,9 +18,34 @@ func craftConnectorTx(
return nil, err return nil, err
} }
// TODO: add prevout. var asset []byte
amount := feeAmount
for _, output := range outputs {
amount += output.Amount
if asset == nil {
var err error
asset, err = elementsutil.AssetHashToBytes(output.Asset)
if err != nil {
return nil, err
}
}
}
if err := updater.AddOutputs(outputs); err != nil { value, err := elementsutil.ValueToBytes(amount)
if err != nil {
return nil, err
}
if err := updater.AddInWitnessUtxo(0, transaction.NewTxOutput(asset, value, inputScript)); err != nil {
return nil, err
}
feeOutput := psetv2.OutputArgs{
Asset: outputs[0].Asset,
Amount: feeAmount,
}
if err := updater.AddOutputs(append(outputs, feeOutput)); err != nil {
return nil, err return nil, err
} }

View File

@@ -38,11 +38,6 @@ func craftForfeitTxs(
} }
vtxoAmount, _ := elementsutil.ValueToBytes(vtxo.Amount) vtxoAmount, _ := elementsutil.ValueToBytes(vtxo.Amount)
vtxoPrevout := &transaction.TxOutput{
Asset: connectorPrevout.Asset,
Value: vtxoAmount,
Script: vtxoScript,
}
if err := updater.AddInputs([]psetv2.InputArgs{connectorInput, vtxoInput}); err != nil { if err := updater.AddInputs([]psetv2.InputArgs{connectorInput, vtxoInput}); err != nil {
return nil, err return nil, err
@@ -56,6 +51,8 @@ func craftForfeitTxs(
return nil, err return nil, err
} }
vtxoPrevout := transaction.NewTxOutput(connectorPrevout.Asset, vtxoAmount, vtxoScript)
if err = updater.AddInWitnessUtxo(1, vtxoPrevout); err != nil { if err = updater.AddInWitnessUtxo(1, vtxoPrevout); err != nil {
return nil, err return nil, err
} }

View File

@@ -40,6 +40,17 @@ func (m *mockedWallet) DeriveAddresses(ctx context.Context, num int) ([]string,
return res, args.Error(1) return res, args.Error(1)
} }
// DeriveConnectorAddress implements ports.WalletService.
func (m *mockedWallet) DeriveConnectorAddress(ctx context.Context) (string, error) {
args := m.Called(ctx)
var res string
if a := args.Get(0); a != nil {
res = a.(string)
}
return res, args.Error(1)
}
// GetPubkey implements ports.WalletService. // GetPubkey implements ports.WalletService.
func (m *mockedWallet) GetPubkey(ctx context.Context) (*secp256k1.PublicKey, error) { func (m *mockedWallet) GetPubkey(ctx context.Context) (*secp256k1.PublicKey, error) {
args := m.Called(ctx) args := m.Called(ctx)
@@ -123,6 +134,17 @@ func (m *mockedWallet) SignPsetWithKey(ctx context.Context, pset string, inputIn
return res, args.Error(1) return res, args.Error(1)
} }
func (m *mockedWallet) SignConnectorInput(ctx context.Context, pset string, inputIndexes []int, extract bool) (string, error) {
args := m.Called(ctx, pset, inputIndexes, extract)
var res string
if a := args.Get(0); a != nil {
res = a.(string)
}
return res, args.Error(1)
}
func (m *mockedWallet) WatchScripts( func (m *mockedWallet) WatchScripts(
ctx context.Context, scripts []string, ctx context.Context, scripts []string,
) error { ) error {
@@ -147,6 +169,22 @@ func (m *mockedWallet) GetNotificationChannel(ctx context.Context) chan []domain
return res return res
} }
func (m *mockedWallet) ListConnectorUtxos(ctx context.Context, addr string) ([]ports.TxInput, error) {
args := m.Called(ctx, addr)
var res []ports.TxInput
if a := args.Get(0); a != nil {
res = a.([]ports.TxInput)
}
return res, args.Error(1)
}
func (m *mockedWallet) WaitForSync(ctx context.Context, txid string) error {
args := m.Called(ctx, txid)
return args.Error(0)
}
type mockedInput struct { type mockedInput struct {
mock.Mock mock.Mock
} }