From ab6ae36eb50b6070d83dc85ec96c2d6194884f4a Mon Sep 17 00:00:00 2001 From: Louis Singer <41042567+louisinger@users.noreply.github.com> Date: Fri, 27 Sep 2024 16:09:37 +0200 Subject: [PATCH] [covenantless] Fix coin selection to build round tx (#336) * rework createPoolTx * change address getter * rename BuildPoolTx --> BuildRoundTx --- server/internal/core/application/covenant.go | 2 +- .../internal/core/application/covenantless.go | 2 +- server/internal/core/ports/tx_builder.go | 4 +- .../tx-builder/covenant/builder.go | 2 +- .../tx-builder/covenant/builder_test.go | 4 +- .../tx-builder/covenantless/builder.go | 242 +++++------------- .../tx-builder/covenantless/builder_test.go | 4 +- 7 files changed, 74 insertions(+), 186 deletions(-) diff --git a/server/internal/core/application/covenant.go b/server/internal/core/application/covenant.go index 8f53808..7033ffe 100644 --- a/server/internal/core/application/covenant.go +++ b/server/internal/core/application/covenant.go @@ -485,7 +485,7 @@ func (s *covenantService) startFinalization() { return } - unsignedPoolTx, tree, connectorAddress, err := s.builder.BuildPoolTx(s.pubkey, payments, boardingInputs, sweptRounds) + unsignedPoolTx, tree, connectorAddress, err := s.builder.BuildRoundTx(s.pubkey, payments, boardingInputs, sweptRounds) if err != nil { round.Fail(fmt.Errorf("failed to create pool tx: %s", err)) log.WithError(err).Warn("failed to create pool tx") diff --git a/server/internal/core/application/covenantless.go b/server/internal/core/application/covenantless.go index dcb6286..801db28 100644 --- a/server/internal/core/application/covenantless.go +++ b/server/internal/core/application/covenantless.go @@ -796,7 +796,7 @@ func (s *covenantlessService) startFinalization() { cosigners = append(cosigners, ephemeralKey.PubKey()) - unsignedRoundTx, tree, connectorAddress, err := s.builder.BuildPoolTx(s.pubkey, payments, boardingInputs, sweptRounds, cosigners...) + unsignedRoundTx, tree, connectorAddress, err := s.builder.BuildRoundTx(s.pubkey, payments, boardingInputs, sweptRounds, cosigners...) if err != nil { round.Fail(fmt.Errorf("failed to create pool tx: %s", err)) log.WithError(err).Warn("failed to create pool tx") diff --git a/server/internal/core/ports/tx_builder.go b/server/internal/core/ports/tx_builder.go index ff0803b..46a610e 100644 --- a/server/internal/core/ports/tx_builder.go +++ b/server/internal/core/ports/tx_builder.go @@ -28,10 +28,10 @@ type BoardingInput struct { } type TxBuilder interface { - BuildPoolTx( + BuildRoundTx( aspPubkey *secp256k1.PublicKey, payments []domain.Payment, boardingInputs []BoardingInput, sweptRounds []domain.Round, cosigners ...*secp256k1.PublicKey, - ) (poolTx string, congestionTree tree.CongestionTree, connectorAddress string, err error) + ) (roundTx string, congestionTree tree.CongestionTree, connectorAddress string, err error) BuildForfeitTxs(poolTx string, payments []domain.Payment, minRelayFeeRate chainfee.SatPerKVByte) (connectors []string, forfeitTxs []string, err error) BuildSweepTx(inputs []SweepInput) (signedSweepTx string, err error) GetSweepInput(parentblocktime int64, node tree.Node) (expirationtime int64, sweepInput SweepInput, err error) diff --git a/server/internal/infrastructure/tx-builder/covenant/builder.go b/server/internal/infrastructure/tx-builder/covenant/builder.go index 8c932ea..db0c872 100644 --- a/server/internal/infrastructure/tx-builder/covenant/builder.go +++ b/server/internal/infrastructure/tx-builder/covenant/builder.go @@ -118,7 +118,7 @@ func (b *txBuilder) BuildForfeitTxs( return connectors, forfeitTxs, nil } -func (b *txBuilder) BuildPoolTx( +func (b *txBuilder) BuildRoundTx( aspPubkey *secp256k1.PublicKey, payments []domain.Payment, boardingInputs []ports.BoardingInput, diff --git a/server/internal/infrastructure/tx-builder/covenant/builder_test.go b/server/internal/infrastructure/tx-builder/covenant/builder_test.go index 80f5c1b..30aed55 100644 --- a/server/internal/infrastructure/tx-builder/covenant/builder_test.go +++ b/server/internal/infrastructure/tx-builder/covenant/builder_test.go @@ -67,7 +67,7 @@ func TestBuildPoolTx(t *testing.T) { if len(fixtures.Valid) > 0 { t.Run("valid", func(t *testing.T) { for _, f := range fixtures.Valid { - poolTx, congestionTree, connAddr, err := builder.BuildPoolTx( + poolTx, congestionTree, connAddr, err := builder.BuildRoundTx( pubkey, f.Payments, []ports.BoardingInput{}, []domain.Round{}, ) require.NoError(t, err) @@ -88,7 +88,7 @@ func TestBuildPoolTx(t *testing.T) { if len(fixtures.Invalid) > 0 { t.Run("invalid", func(t *testing.T) { for _, f := range fixtures.Invalid { - poolTx, congestionTree, connAddr, err := builder.BuildPoolTx( + poolTx, congestionTree, connAddr, err := builder.BuildRoundTx( pubkey, f.Payments, []ports.BoardingInput{}, []domain.Round{}, ) require.EqualError(t, err, f.ExpectedErr) diff --git a/server/internal/infrastructure/tx-builder/covenantless/builder.go b/server/internal/infrastructure/tx-builder/covenantless/builder.go index ebbd9ca..99c2849 100644 --- a/server/internal/infrastructure/tx-builder/covenantless/builder.go +++ b/server/internal/infrastructure/tx-builder/covenantless/builder.go @@ -257,13 +257,13 @@ func (b *txBuilder) BuildForfeitTxs( return connectors, forfeitTxs, nil } -func (b *txBuilder) BuildPoolTx( +func (b *txBuilder) BuildRoundTx( aspPubkey *secp256k1.PublicKey, payments []domain.Payment, boardingInputs []ports.BoardingInput, sweptRounds []domain.Round, cosigners ...*secp256k1.PublicKey, -) (poolTx string, congestionTree tree.CongestionTree, connectorAddress string, err error) { +) (roundTx string, congestionTree tree.CongestionTree, connectorAddress string, err error) { var sharedOutputScript []byte var sharedOutputAmount int64 @@ -295,14 +295,14 @@ func (b *txBuilder) BuildPoolTx( return } - ptx, err := b.createPoolTx( + ptx, err := b.createRoundTx( sharedOutputAmount, sharedOutputScript, payments, boardingInputs, connectorAddress, sweptRounds, ) if err != nil { return } - poolTx, err = ptx.B64Encode() + roundTx, err = ptx.B64Encode() if err != nil { return } @@ -614,7 +614,7 @@ func (b *txBuilder) BuildAsyncPaymentTransactions( } // TODO use lnd CoinSelect to craft the pool tx -func (b *txBuilder) createPoolTx( +func (b *txBuilder) createRoundTx( sharedOutputAmount int64, sharedOutputScript []byte, payments []domain.Payment, @@ -699,30 +699,40 @@ func (b *txBuilder) createPoolTx( return nil, err } - var dust uint64 + var cacheChangeScript []byte + // avoid derivation of several change addresses + getChange := func() ([]byte, error) { + if len(cacheChangeScript) > 0 { + return cacheChangeScript, nil + } + + changeAddresses, err := b.wallet.DeriveAddresses(ctx, 1) + if err != nil { + return nil, err + } + + changeAddress, err := btcutil.DecodeAddress(changeAddresses[0], b.onchainNetwork()) + if err != nil { + return nil, err + } + + return txscript.PayToAddrScript(changeAddress) + } + + exceedingValue := uint64(0) if change > 0 { - if change < dustLimit { - dust = change + if change <= dustLimit { + exceedingValue = change change = 0 } else { - address, err := b.wallet.DeriveAddresses(ctx, 1) - if err != nil { - return nil, err - } - - addr, err := btcutil.DecodeAddress(address[0], b.onchainNetwork()) - if err != nil { - return nil, err - } - - aspScript, err := txscript.PayToAddrScript(addr) + changeScript, err := getChange() if err != nil { return nil, err } outputs = append(outputs, &wire.TxOut{ Value: int64(change), - PkScript: aspScript, + PkScript: changeScript, }) } } @@ -832,155 +842,30 @@ func (b *txBuilder) createPoolTx( return nil, err } - if dust > feeAmount { - feeAmount = dust - } else { - feeAmount += dust - } + for feeAmount > exceedingValue { + feesToPay := feeAmount - exceedingValue - if dust == 0 { - if feeAmount == change { - // fees = change, remove change output - ptx.UnsignedTx.TxOut = ptx.UnsignedTx.TxOut[:len(ptx.UnsignedTx.TxOut)-1] - ptx.Outputs = ptx.Outputs[:len(ptx.Outputs)-1] - } else if feeAmount < change { - // change covers the fees, reduce change amount - ptx.UnsignedTx.TxOut[len(ptx.Outputs)-1].Value = int64(change - feeAmount) - } else { - // change is not enough to cover fees, re-select utxos - if change > 0 { - // remove change output if present + // change is able to cover the remaining fees + if change > feesToPay { + newChange := change - (feeAmount - exceedingValue) + // new change amount is less than dust limit, let's remove it + if newChange <= dustLimit { ptx.UnsignedTx.TxOut = ptx.UnsignedTx.TxOut[:len(ptx.UnsignedTx.TxOut)-1] ptx.Outputs = ptx.Outputs[:len(ptx.Outputs)-1] - } - newUtxos, change, err := b.selectUtxos(ctx, sweptRounds, feeAmount-change) - if err != nil { - return nil, err - } - - dust = 0 - if change > 0 { - if change < dustLimit { - dust = change - change = 0 - } else { - address, err := b.wallet.DeriveAddresses(ctx, 1) - if err != nil { - return nil, err - } - - addr, err := btcutil.DecodeAddress(address[0], b.onchainNetwork()) - if err != nil { - return nil, err - } - - aspScript, err := txscript.PayToAddrScript(addr) - if err != nil { - return nil, err - } - - ptx.UnsignedTx.AddTxOut(&wire.TxOut{ - Value: int64(change), - PkScript: aspScript, - }) - ptx.Outputs = append(ptx.Outputs, psbt.POutput{}) - } - } - - for _, utxo := range newUtxos { - txhash, err := chainhash.NewHashFromStr(utxo.GetTxid()) - if err != nil { - return nil, err - } - - outpoint := &wire.OutPoint{ - Hash: *txhash, - Index: utxo.GetIndex(), - } - - ptx.UnsignedTx.AddTxIn(wire.NewTxIn(outpoint, nil, nil)) - ptx.Inputs = append(ptx.Inputs, psbt.PInput{}) - - scriptBytes, err := hex.DecodeString(utxo.GetScript()) - if err != nil { - return nil, err - } - - if err := updater.AddInWitnessUtxo( - &wire.TxOut{ - Value: int64(utxo.GetValue()), - PkScript: scriptBytes, - }, - len(ptx.UnsignedTx.TxIn)-1, - ); err != nil { - return nil, err - } - } - - b64, err = ptx.B64Encode() - if err != nil { - return nil, err - } - - feeAmount, err = b.wallet.EstimateFees(ctx, b64) - if err != nil { - return nil, err - } - - if dust > feeAmount { - feeAmount = dust } else { - feeAmount += dust + ptx.UnsignedTx.TxOut[len(ptx.Outputs)-1].Value = int64(newChange) } - if dust == 0 { - if feeAmount == change { - // fees = change, remove change output - ptx.UnsignedTx.TxOut = ptx.UnsignedTx.TxOut[:len(ptx.UnsignedTx.TxOut)-1] - ptx.Outputs = ptx.Outputs[:len(ptx.Outputs)-1] - } else if feeAmount < change { - // change covers the fees, reduce change amount - ptx.UnsignedTx.TxOut[len(ptx.Outputs)-1].Value = int64(change - feeAmount) - } else { - return nil, fmt.Errorf("change is not enough to cover fees") - } - } + break } - } else if feeAmount-dust > 0 { - newUtxos, change, err := b.selectUtxos(ctx, sweptRounds, feeAmount-dust) + + // change is not enough to cover the remaining fees, let's re-select utxos + newUtxos, newChange, err := b.wallet.SelectUtxos(ctx, "", feeAmount-exceedingValue) if err != nil { return nil, err } - dust = 0 - if change > 0 { - if change < dustLimit { - dust = change - change = 0 - } else { - address, err := b.wallet.DeriveAddresses(ctx, 1) - if err != nil { - return nil, err - } - - addr, err := btcutil.DecodeAddress(address[0], b.onchainNetwork()) - if err != nil { - return nil, err - } - - aspScript, err := txscript.PayToAddrScript(addr) - if err != nil { - return nil, err - } - - ptx.UnsignedTx.AddTxOut(&wire.TxOut{ - Value: int64(change), - PkScript: aspScript, - }) - ptx.Outputs = append(ptx.Outputs, psbt.POutput{}) - } - } - + // add new inputs for _, utxo := range newUtxos { txhash, err := chainhash.NewHashFromStr(utxo.GetTxid()) if err != nil { @@ -1011,34 +896,37 @@ func (b *txBuilder) createPoolTx( } } + // add new change output if necessary + if newChange > 0 { + if newChange <= dustLimit { + newChange = 0 + exceedingValue += newChange + } else { + changeScript, err := getChange() + if err != nil { + return nil, err + } + + ptx.UnsignedTx.AddTxOut(&wire.TxOut{ + Value: int64(newChange), + PkScript: changeScript, + }) + ptx.Outputs = append(ptx.Outputs, psbt.POutput{}) + } + } + b64, err = ptx.B64Encode() if err != nil { return nil, err } - feeAmount, err = b.wallet.EstimateFees(ctx, b64) + newFeeAmount, err := b.wallet.EstimateFees(ctx, b64) if err != nil { return nil, err } - if dust > feeAmount { - feeAmount = dust - } else { - feeAmount += dust - } - - if dust == 0 { - if feeAmount == change { - // fees = change, remove change output - ptx.UnsignedTx.TxOut = ptx.UnsignedTx.TxOut[:len(ptx.UnsignedTx.TxOut)-1] - ptx.Outputs = ptx.Outputs[:len(ptx.Outputs)-1] - } else if feeAmount < change { - // change covers the fees, reduce change amount - ptx.UnsignedTx.TxOut[len(ptx.Outputs)-1].Value = int64(change - feeAmount) - } else { - return nil, fmt.Errorf("change is not enough to cover fees") - } - } + feeAmount = newFeeAmount + change = newChange } // remove input taproot leaf script diff --git a/server/internal/infrastructure/tx-builder/covenantless/builder_test.go b/server/internal/infrastructure/tx-builder/covenantless/builder_test.go index 762d5d8..7411507 100644 --- a/server/internal/infrastructure/tx-builder/covenantless/builder_test.go +++ b/server/internal/infrastructure/tx-builder/covenantless/builder_test.go @@ -77,7 +77,7 @@ func TestBuildPoolTx(t *testing.T) { cosigners = append(cosigners, randKey.PubKey()) } - poolTx, congestionTree, connAddr, err := builder.BuildPoolTx( + poolTx, congestionTree, connAddr, err := builder.BuildRoundTx( pubkey, f.Payments, []ports.BoardingInput{}, []domain.Round{}, cosigners..., ) require.NoError(t, err) @@ -98,7 +98,7 @@ func TestBuildPoolTx(t *testing.T) { if len(fixtures.Invalid) > 0 { t.Run("invalid", func(t *testing.T) { for _, f := range fixtures.Invalid { - poolTx, congestionTree, connAddr, err := builder.BuildPoolTx( + poolTx, congestionTree, connAddr, err := builder.BuildRoundTx( pubkey, f.Payments, []ports.BoardingInput{}, []domain.Round{}, ) require.EqualError(t, err, f.ExpectedErr)