Prevent getting cheated by broadcasting forfeit transactions (#123)

* broadcast forfeit transaction in case the user is trying the cheat the ASP

* fix connector input + --cheat flag in CLI

* WIP

* cleaning and fixes

* add TODO

* sweeper.go: mark round swept if vtxo are redeemed

* fixes after reviews

* revert "--cheat" flag in client

* revert redeem.go

* optimization

* update account.go according to ocean ListUtxos new spec

* WaitForSync implementation

* ocean-wallet/service.go: remove go rountine while writing to notification channel
This commit is contained in:
Louis Singer
2024-03-04 13:58:36 +01:00
committed by GitHub
parent 6d0d03e316
commit 066e8eeabb
21 changed files with 537 additions and 120 deletions

View File

@@ -5,6 +5,7 @@ import (
"context"
"encoding/hex"
"fmt"
"sync"
"time"
"github.com/ark-network/ark/common"
@@ -286,7 +287,7 @@ func (s *service) startFinalization() {
return
}
unsignedPoolTx, tree, err := s.builder.BuildPoolTx(s.pubkey, payments, s.minRelayFee)
unsignedPoolTx, tree, connectorAddress, err := s.builder.BuildPoolTx(s.pubkey, payments, s.minRelayFee)
if err != nil {
changes = round.Fail(fmt.Errorf("failed to create pool tx: %s", err))
log.WithError(err).Warn("failed to create pool tx")
@@ -295,7 +296,7 @@ func (s *service) startFinalization() {
log.Debugf("pool tx created for round %s", round.Id)
connectors, forfeitTxs, err := s.builder.BuildForfeitTxs(s.pubkey, unsignedPoolTx, payments)
connectors, forfeitTxs, err := s.builder.BuildForfeitTxs(s.pubkey, unsignedPoolTx, payments, s.minRelayFee)
if err != nil {
changes = round.Fail(fmt.Errorf("failed to create connectors and forfeit txs: %s", err))
log.WithError(err).Warn("failed to create connectors and forfeit txs")
@@ -304,7 +305,7 @@ func (s *service) startFinalization() {
log.Debugf("forfeit transactions created for round %s", round.Id)
events, err := round.StartFinalization(connectors, tree, unsignedPoolTx)
events, err := round.StartFinalization(connectorAddress, connectors, tree, unsignedPoolTx)
if err != nil {
changes = round.Fail(fmt.Errorf("failed to start finalization: %s", err))
log.WithError(err).Warn("failed to start finalization")
@@ -369,25 +370,162 @@ func (s *service) finalizeRound() {
func (s *service) listenToRedemptions() {
ctx := context.Background()
chVtxos := s.scanner.GetNotificationChannel(ctx)
mutx := &sync.Mutex{}
for vtxoKeys := range chVtxos {
if len(vtxoKeys) > 0 {
for {
// TODO: make sure that the vtxos haven't been already spent, otherwise
// broadcast the corresponding forfeit tx and connector to prevent
// getting cheated.
vtxos, err := s.repoManager.Vtxos().RedeemVtxos(ctx, vtxoKeys)
go func(vtxoKeys []domain.VtxoKey) {
vtxosRepo := s.repoManager.Vtxos()
roundRepo := s.repoManager.Rounds()
for _, v := range vtxoKeys {
vtxos, err := vtxosRepo.GetVtxos(ctx, []domain.VtxoKey{v})
if err != nil {
log.WithError(err).Warn("failed to redeem vtxos, retrying...")
time.Sleep(100 * time.Millisecond)
log.WithError(err).Warn("failed to retrieve vtxos, skipping...")
continue
}
if len(vtxos) > 0 {
log.Debugf("redeemed %d vtxos", len(vtxos))
vtxo := vtxos[0]
if vtxo.Redeemed {
continue
}
if _, err := s.repoManager.Vtxos().RedeemVtxos(ctx, []domain.VtxoKey{vtxo.VtxoKey}); err != nil {
log.WithError(err).Warn("failed to redeem vtxos, retrying...")
continue
}
log.Debugf("vtxo %s redeemed", vtxo.Txid)
if !vtxo.Spent {
continue
}
log.Debugf("fraud detected on vtxo %s", vtxo.Txid)
round, err := roundRepo.GetRoundWithTxid(ctx, vtxo.SpentBy)
if err != nil {
log.WithError(err).Warn("failed to retrieve round")
continue
}
mutx.Lock()
defer mutx.Unlock()
connectorTxid, connectorVout, err := s.getNextConnector(ctx, *round)
if err != nil {
log.WithError(err).Warn("failed to retrieve next connector")
continue
}
forfeitTx, err := findForfeitTx(round.ForfeitTxs, connectorTxid, connectorVout, vtxo.Txid)
if err != nil {
log.WithError(err).Warn("failed to retrieve forfeit tx")
continue
}
signedForfeitTx, err := s.wallet.SignConnectorInput(ctx, forfeitTx, []int{0}, false)
if err != nil {
log.WithError(err).Warn("failed to sign connector input in forfeit tx")
continue
}
signedForfeitTx, err = s.wallet.SignPsetWithKey(ctx, signedForfeitTx, []int{1})
if err != nil {
log.WithError(err).Warn("failed to sign vtxo input in forfeit tx")
continue
}
forfeitTxHex, err := finalizeAndExtractForfeit(signedForfeitTx)
if err != nil {
log.WithError(err).Warn("failed to finalize forfeit tx")
continue
}
forfeitTxid, err := s.wallet.BroadcastTransaction(ctx, forfeitTxHex)
if err != nil {
log.WithError(err).Warn("failed to broadcast forfeit tx")
continue
}
log.Debugf("broadcasted forfeit tx %s", forfeitTxid)
}
}(vtxoKeys)
}
}
func (s *service) getNextConnector(
ctx context.Context,
round domain.Round,
) (string, uint32, error) {
connectorTx, err := psetv2.NewPsetFromBase64(round.Connectors[0])
if err != nil {
return "", 0, err
}
prevout := connectorTx.Inputs[0].WitnessUtxo
if prevout == nil {
return "", 0, fmt.Errorf("connector prevout not found")
}
utxos, err := s.wallet.ListConnectorUtxos(ctx, round.ConnectorAddress)
if err != nil {
return "", 0, err
}
// if we do not find any utxos, we make sure to wait for the connector outpoint to be confirmed then we retry
if len(utxos) <= 0 {
if err := s.wallet.WaitForSync(ctx, round.Txid); err != nil {
return "", 0, err
}
utxos, err = s.wallet.ListConnectorUtxos(ctx, round.ConnectorAddress)
if err != nil {
return "", 0, err
}
}
// search for an already existing connector
for _, u := range utxos {
if u.GetValue() == 450 {
return u.GetTxid(), u.GetIndex(), nil
}
}
for _, u := range utxos {
if u.GetValue() > 450 {
for _, b64 := range round.Connectors {
pset, err := psetv2.NewPsetFromBase64(b64)
if err != nil {
return "", 0, err
}
for _, i := range pset.Inputs {
if chainhash.Hash(i.PreviousTxid).String() == u.GetTxid() && i.PreviousTxIndex == u.GetIndex() {
// sign & broadcast the connector tx
signedConnectorTx, err := s.wallet.SignConnectorInput(ctx, b64, []int{0}, true)
if err != nil {
return "", 0, err
}
connectorTxid, err := s.wallet.BroadcastTransaction(ctx, signedConnectorTx)
if err != nil {
return "", 0, err
}
log.Debugf("broadcasted connector tx %s", connectorTxid)
// wait for the connector tx to be in the mempool
if err := s.wallet.WaitForSync(ctx, connectorTxid); err != nil {
return "", 0, err
}
return connectorTxid, 0, nil
}
}
break
}
}
}
return "", 0, fmt.Errorf("no connector utxos found")
}
func (s *service) updateVtxoSet(round *domain.Round) {
@@ -401,7 +539,7 @@ func (s *service) updateVtxoSet(round *domain.Round) {
spentVtxos := getSpentVtxos(round.Payments)
if len(spentVtxos) > 0 {
for {
if err := repo.SpendVtxos(ctx, spentVtxos); err != nil {
if err := repo.SpendVtxos(ctx, spentVtxos, round.Txid); err != nil {
log.WithError(err).Warn("failed to add new vtxos, retrying soon")
time.Sleep(100 * time.Millisecond)
continue
@@ -535,12 +673,28 @@ func (s *service) stopWatchingVtxos(vtxos []domain.Vtxo) {
}
func (s *service) restoreWatchingVtxos() error {
vtxos, err := s.repoManager.Vtxos().GetSpendableVtxos(
context.Background(), "",
)
sweepableRounds, err := s.repoManager.Rounds().GetSweepableRounds(context.Background())
if err != nil {
return err
}
vtxos := make([]domain.Vtxo, 0)
for _, round := range sweepableRounds {
fromRound, err := s.repoManager.Vtxos().GetVtxosForRound(
context.Background(), round.Txid,
)
if err != nil {
log.WithError(err).Warnf("failed to retrieve vtxos for round %s", round.Txid)
continue
}
for _, v := range fromRound {
if !v.Swept && !v.Redeemed {
vtxos = append(vtxos, v)
}
}
}
if len(vtxos) <= 0 {
return nil
}
@@ -617,3 +771,45 @@ func getPaymentsFromOnboarding(
payment := domain.NewPaymentUnsafe(nil, receivers)
return []domain.Payment{*payment}
}
func finalizeAndExtractForfeit(b64 string) (string, error) {
p, err := psetv2.NewPsetFromBase64(b64)
if err != nil {
return "", err
}
// finalize connector input
if err := psetv2.FinalizeAll(p); err != nil {
return "", err
}
// extract the forfeit tx
extracted, err := psetv2.Extract(p)
if err != nil {
return "", err
}
return extracted.ToHex()
}
func findForfeitTx(
forfeits []string, connectorTxid string, connectorVout uint32, vtxoTxid string,
) (string, error) {
for _, forfeit := range forfeits {
forfeitTx, err := psetv2.NewPsetFromBase64(forfeit)
if err != nil {
return "", err
}
connector := forfeitTx.Inputs[0]
vtxoInput := forfeitTx.Inputs[1]
if chainhash.Hash(connector.PreviousTxid).String() == connectorTxid &&
connector.PreviousTxIndex == connectorVout &&
chainhash.Hash(vtxoInput.PreviousTxid).String() == vtxoTxid {
return forfeit, nil
}
}
return "", fmt.Errorf("forfeit tx not found")
}