mirror of
https://github.com/aljazceru/ark.git
synced 2025-12-18 12:44:19 +01:00
[covenantless] Fix coin selection to build round tx (#336)
* rework createPoolTx * change address getter * rename BuildPoolTx --> BuildRoundTx
This commit is contained in:
@@ -485,7 +485,7 @@ func (s *covenantService) startFinalization() {
|
||||
return
|
||||
}
|
||||
|
||||
unsignedPoolTx, tree, connectorAddress, err := s.builder.BuildPoolTx(s.pubkey, payments, boardingInputs, sweptRounds)
|
||||
unsignedPoolTx, tree, connectorAddress, err := s.builder.BuildRoundTx(s.pubkey, payments, boardingInputs, sweptRounds)
|
||||
if err != nil {
|
||||
round.Fail(fmt.Errorf("failed to create pool tx: %s", err))
|
||||
log.WithError(err).Warn("failed to create pool tx")
|
||||
|
||||
@@ -796,7 +796,7 @@ func (s *covenantlessService) startFinalization() {
|
||||
|
||||
cosigners = append(cosigners, ephemeralKey.PubKey())
|
||||
|
||||
unsignedRoundTx, tree, connectorAddress, err := s.builder.BuildPoolTx(s.pubkey, payments, boardingInputs, sweptRounds, cosigners...)
|
||||
unsignedRoundTx, tree, connectorAddress, err := s.builder.BuildRoundTx(s.pubkey, payments, boardingInputs, sweptRounds, cosigners...)
|
||||
if err != nil {
|
||||
round.Fail(fmt.Errorf("failed to create pool tx: %s", err))
|
||||
log.WithError(err).Warn("failed to create pool tx")
|
||||
|
||||
@@ -28,10 +28,10 @@ type BoardingInput struct {
|
||||
}
|
||||
|
||||
type TxBuilder interface {
|
||||
BuildPoolTx(
|
||||
BuildRoundTx(
|
||||
aspPubkey *secp256k1.PublicKey, payments []domain.Payment, boardingInputs []BoardingInput, sweptRounds []domain.Round,
|
||||
cosigners ...*secp256k1.PublicKey,
|
||||
) (poolTx string, congestionTree tree.CongestionTree, connectorAddress string, err error)
|
||||
) (roundTx string, congestionTree tree.CongestionTree, connectorAddress string, err error)
|
||||
BuildForfeitTxs(poolTx string, payments []domain.Payment, minRelayFeeRate chainfee.SatPerKVByte) (connectors []string, forfeitTxs []string, err error)
|
||||
BuildSweepTx(inputs []SweepInput) (signedSweepTx string, err error)
|
||||
GetSweepInput(parentblocktime int64, node tree.Node) (expirationtime int64, sweepInput SweepInput, err error)
|
||||
|
||||
@@ -118,7 +118,7 @@ func (b *txBuilder) BuildForfeitTxs(
|
||||
return connectors, forfeitTxs, nil
|
||||
}
|
||||
|
||||
func (b *txBuilder) BuildPoolTx(
|
||||
func (b *txBuilder) BuildRoundTx(
|
||||
aspPubkey *secp256k1.PublicKey,
|
||||
payments []domain.Payment,
|
||||
boardingInputs []ports.BoardingInput,
|
||||
|
||||
@@ -67,7 +67,7 @@ func TestBuildPoolTx(t *testing.T) {
|
||||
if len(fixtures.Valid) > 0 {
|
||||
t.Run("valid", func(t *testing.T) {
|
||||
for _, f := range fixtures.Valid {
|
||||
poolTx, congestionTree, connAddr, err := builder.BuildPoolTx(
|
||||
poolTx, congestionTree, connAddr, err := builder.BuildRoundTx(
|
||||
pubkey, f.Payments, []ports.BoardingInput{}, []domain.Round{},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
@@ -88,7 +88,7 @@ func TestBuildPoolTx(t *testing.T) {
|
||||
if len(fixtures.Invalid) > 0 {
|
||||
t.Run("invalid", func(t *testing.T) {
|
||||
for _, f := range fixtures.Invalid {
|
||||
poolTx, congestionTree, connAddr, err := builder.BuildPoolTx(
|
||||
poolTx, congestionTree, connAddr, err := builder.BuildRoundTx(
|
||||
pubkey, f.Payments, []ports.BoardingInput{}, []domain.Round{},
|
||||
)
|
||||
require.EqualError(t, err, f.ExpectedErr)
|
||||
|
||||
@@ -257,13 +257,13 @@ func (b *txBuilder) BuildForfeitTxs(
|
||||
return connectors, forfeitTxs, nil
|
||||
}
|
||||
|
||||
func (b *txBuilder) BuildPoolTx(
|
||||
func (b *txBuilder) BuildRoundTx(
|
||||
aspPubkey *secp256k1.PublicKey,
|
||||
payments []domain.Payment,
|
||||
boardingInputs []ports.BoardingInput,
|
||||
sweptRounds []domain.Round,
|
||||
cosigners ...*secp256k1.PublicKey,
|
||||
) (poolTx string, congestionTree tree.CongestionTree, connectorAddress string, err error) {
|
||||
) (roundTx string, congestionTree tree.CongestionTree, connectorAddress string, err error) {
|
||||
var sharedOutputScript []byte
|
||||
var sharedOutputAmount int64
|
||||
|
||||
@@ -295,14 +295,14 @@ func (b *txBuilder) BuildPoolTx(
|
||||
return
|
||||
}
|
||||
|
||||
ptx, err := b.createPoolTx(
|
||||
ptx, err := b.createRoundTx(
|
||||
sharedOutputAmount, sharedOutputScript, payments, boardingInputs, connectorAddress, sweptRounds,
|
||||
)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
poolTx, err = ptx.B64Encode()
|
||||
roundTx, err = ptx.B64Encode()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -614,7 +614,7 @@ func (b *txBuilder) BuildAsyncPaymentTransactions(
|
||||
}
|
||||
|
||||
// TODO use lnd CoinSelect to craft the pool tx
|
||||
func (b *txBuilder) createPoolTx(
|
||||
func (b *txBuilder) createRoundTx(
|
||||
sharedOutputAmount int64,
|
||||
sharedOutputScript []byte,
|
||||
payments []domain.Payment,
|
||||
@@ -699,30 +699,40 @@ func (b *txBuilder) createPoolTx(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var dust uint64
|
||||
var cacheChangeScript []byte
|
||||
// avoid derivation of several change addresses
|
||||
getChange := func() ([]byte, error) {
|
||||
if len(cacheChangeScript) > 0 {
|
||||
return cacheChangeScript, nil
|
||||
}
|
||||
|
||||
changeAddresses, err := b.wallet.DeriveAddresses(ctx, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
changeAddress, err := btcutil.DecodeAddress(changeAddresses[0], b.onchainNetwork())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return txscript.PayToAddrScript(changeAddress)
|
||||
}
|
||||
|
||||
exceedingValue := uint64(0)
|
||||
if change > 0 {
|
||||
if change < dustLimit {
|
||||
dust = change
|
||||
if change <= dustLimit {
|
||||
exceedingValue = change
|
||||
change = 0
|
||||
} else {
|
||||
address, err := b.wallet.DeriveAddresses(ctx, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
addr, err := btcutil.DecodeAddress(address[0], b.onchainNetwork())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
aspScript, err := txscript.PayToAddrScript(addr)
|
||||
changeScript, err := getChange()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
outputs = append(outputs, &wire.TxOut{
|
||||
Value: int64(change),
|
||||
PkScript: aspScript,
|
||||
PkScript: changeScript,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -832,155 +842,30 @@ func (b *txBuilder) createPoolTx(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if dust > feeAmount {
|
||||
feeAmount = dust
|
||||
} else {
|
||||
feeAmount += dust
|
||||
}
|
||||
for feeAmount > exceedingValue {
|
||||
feesToPay := feeAmount - exceedingValue
|
||||
|
||||
if dust == 0 {
|
||||
if feeAmount == change {
|
||||
// fees = change, remove change output
|
||||
ptx.UnsignedTx.TxOut = ptx.UnsignedTx.TxOut[:len(ptx.UnsignedTx.TxOut)-1]
|
||||
ptx.Outputs = ptx.Outputs[:len(ptx.Outputs)-1]
|
||||
} else if feeAmount < change {
|
||||
// change covers the fees, reduce change amount
|
||||
ptx.UnsignedTx.TxOut[len(ptx.Outputs)-1].Value = int64(change - feeAmount)
|
||||
} else {
|
||||
// change is not enough to cover fees, re-select utxos
|
||||
if change > 0 {
|
||||
// remove change output if present
|
||||
// change is able to cover the remaining fees
|
||||
if change > feesToPay {
|
||||
newChange := change - (feeAmount - exceedingValue)
|
||||
// new change amount is less than dust limit, let's remove it
|
||||
if newChange <= dustLimit {
|
||||
ptx.UnsignedTx.TxOut = ptx.UnsignedTx.TxOut[:len(ptx.UnsignedTx.TxOut)-1]
|
||||
ptx.Outputs = ptx.Outputs[:len(ptx.Outputs)-1]
|
||||
}
|
||||
newUtxos, change, err := b.selectUtxos(ctx, sweptRounds, feeAmount-change)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dust = 0
|
||||
if change > 0 {
|
||||
if change < dustLimit {
|
||||
dust = change
|
||||
change = 0
|
||||
} else {
|
||||
address, err := b.wallet.DeriveAddresses(ctx, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
addr, err := btcutil.DecodeAddress(address[0], b.onchainNetwork())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
aspScript, err := txscript.PayToAddrScript(addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ptx.UnsignedTx.AddTxOut(&wire.TxOut{
|
||||
Value: int64(change),
|
||||
PkScript: aspScript,
|
||||
})
|
||||
ptx.Outputs = append(ptx.Outputs, psbt.POutput{})
|
||||
}
|
||||
}
|
||||
|
||||
for _, utxo := range newUtxos {
|
||||
txhash, err := chainhash.NewHashFromStr(utxo.GetTxid())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
outpoint := &wire.OutPoint{
|
||||
Hash: *txhash,
|
||||
Index: utxo.GetIndex(),
|
||||
}
|
||||
|
||||
ptx.UnsignedTx.AddTxIn(wire.NewTxIn(outpoint, nil, nil))
|
||||
ptx.Inputs = append(ptx.Inputs, psbt.PInput{})
|
||||
|
||||
scriptBytes, err := hex.DecodeString(utxo.GetScript())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := updater.AddInWitnessUtxo(
|
||||
&wire.TxOut{
|
||||
Value: int64(utxo.GetValue()),
|
||||
PkScript: scriptBytes,
|
||||
},
|
||||
len(ptx.UnsignedTx.TxIn)-1,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
b64, err = ptx.B64Encode()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
feeAmount, err = b.wallet.EstimateFees(ctx, b64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if dust > feeAmount {
|
||||
feeAmount = dust
|
||||
} else {
|
||||
feeAmount += dust
|
||||
ptx.UnsignedTx.TxOut[len(ptx.Outputs)-1].Value = int64(newChange)
|
||||
}
|
||||
|
||||
if dust == 0 {
|
||||
if feeAmount == change {
|
||||
// fees = change, remove change output
|
||||
ptx.UnsignedTx.TxOut = ptx.UnsignedTx.TxOut[:len(ptx.UnsignedTx.TxOut)-1]
|
||||
ptx.Outputs = ptx.Outputs[:len(ptx.Outputs)-1]
|
||||
} else if feeAmount < change {
|
||||
// change covers the fees, reduce change amount
|
||||
ptx.UnsignedTx.TxOut[len(ptx.Outputs)-1].Value = int64(change - feeAmount)
|
||||
} else {
|
||||
return nil, fmt.Errorf("change is not enough to cover fees")
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
} else if feeAmount-dust > 0 {
|
||||
newUtxos, change, err := b.selectUtxos(ctx, sweptRounds, feeAmount-dust)
|
||||
|
||||
// change is not enough to cover the remaining fees, let's re-select utxos
|
||||
newUtxos, newChange, err := b.wallet.SelectUtxos(ctx, "", feeAmount-exceedingValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dust = 0
|
||||
if change > 0 {
|
||||
if change < dustLimit {
|
||||
dust = change
|
||||
change = 0
|
||||
} else {
|
||||
address, err := b.wallet.DeriveAddresses(ctx, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
addr, err := btcutil.DecodeAddress(address[0], b.onchainNetwork())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
aspScript, err := txscript.PayToAddrScript(addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ptx.UnsignedTx.AddTxOut(&wire.TxOut{
|
||||
Value: int64(change),
|
||||
PkScript: aspScript,
|
||||
})
|
||||
ptx.Outputs = append(ptx.Outputs, psbt.POutput{})
|
||||
}
|
||||
}
|
||||
|
||||
// add new inputs
|
||||
for _, utxo := range newUtxos {
|
||||
txhash, err := chainhash.NewHashFromStr(utxo.GetTxid())
|
||||
if err != nil {
|
||||
@@ -1011,34 +896,37 @@ func (b *txBuilder) createPoolTx(
|
||||
}
|
||||
}
|
||||
|
||||
// add new change output if necessary
|
||||
if newChange > 0 {
|
||||
if newChange <= dustLimit {
|
||||
newChange = 0
|
||||
exceedingValue += newChange
|
||||
} else {
|
||||
changeScript, err := getChange()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ptx.UnsignedTx.AddTxOut(&wire.TxOut{
|
||||
Value: int64(newChange),
|
||||
PkScript: changeScript,
|
||||
})
|
||||
ptx.Outputs = append(ptx.Outputs, psbt.POutput{})
|
||||
}
|
||||
}
|
||||
|
||||
b64, err = ptx.B64Encode()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
feeAmount, err = b.wallet.EstimateFees(ctx, b64)
|
||||
newFeeAmount, err := b.wallet.EstimateFees(ctx, b64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if dust > feeAmount {
|
||||
feeAmount = dust
|
||||
} else {
|
||||
feeAmount += dust
|
||||
}
|
||||
|
||||
if dust == 0 {
|
||||
if feeAmount == change {
|
||||
// fees = change, remove change output
|
||||
ptx.UnsignedTx.TxOut = ptx.UnsignedTx.TxOut[:len(ptx.UnsignedTx.TxOut)-1]
|
||||
ptx.Outputs = ptx.Outputs[:len(ptx.Outputs)-1]
|
||||
} else if feeAmount < change {
|
||||
// change covers the fees, reduce change amount
|
||||
ptx.UnsignedTx.TxOut[len(ptx.Outputs)-1].Value = int64(change - feeAmount)
|
||||
} else {
|
||||
return nil, fmt.Errorf("change is not enough to cover fees")
|
||||
}
|
||||
}
|
||||
feeAmount = newFeeAmount
|
||||
change = newChange
|
||||
}
|
||||
|
||||
// remove input taproot leaf script
|
||||
|
||||
@@ -77,7 +77,7 @@ func TestBuildPoolTx(t *testing.T) {
|
||||
cosigners = append(cosigners, randKey.PubKey())
|
||||
}
|
||||
|
||||
poolTx, congestionTree, connAddr, err := builder.BuildPoolTx(
|
||||
poolTx, congestionTree, connAddr, err := builder.BuildRoundTx(
|
||||
pubkey, f.Payments, []ports.BoardingInput{}, []domain.Round{}, cosigners...,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
@@ -98,7 +98,7 @@ func TestBuildPoolTx(t *testing.T) {
|
||||
if len(fixtures.Invalid) > 0 {
|
||||
t.Run("invalid", func(t *testing.T) {
|
||||
for _, f := range fixtures.Invalid {
|
||||
poolTx, congestionTree, connAddr, err := builder.BuildPoolTx(
|
||||
poolTx, congestionTree, connAddr, err := builder.BuildRoundTx(
|
||||
pubkey, f.Payments, []ports.BoardingInput{}, []domain.Round{},
|
||||
)
|
||||
require.EqualError(t, err, f.ExpectedErr)
|
||||
|
||||
Reference in New Issue
Block a user