CompleteAsyncPayment: validate signatures and transactions (#298)

This commit is contained in:
Louis Singer
2024-09-10 19:21:54 +02:00
committed by GitHub
parent 0fb34cb13d
commit 1387c8da7a
8 changed files with 209 additions and 72 deletions

View File

@@ -47,7 +47,7 @@ type covenantlessService struct {
currentRoundLock sync.Mutex currentRoundLock sync.Mutex
currentRound *domain.Round currentRound *domain.Round
treeSigningSessions map[string]*musigSigningSession treeSigningSessions map[string]*musigSigningSession
asyncPaymentsCache map[domain.VtxoKey]struct { asyncPaymentsCache map[string]struct { // redeem txid -> receivers
receivers []domain.Receiver receivers []domain.Receiver
expireAt int64 expireAt int64
} }
@@ -70,7 +70,7 @@ func NewCovenantlessService(
} }
sweeper := newSweeper(walletSvc, repoManager, builder, scheduler) sweeper := newSweeper(walletSvc, repoManager, builder, scheduler)
asyncPaymentsCache := make(map[domain.VtxoKey]struct { asyncPaymentsCache := make(map[string]struct {
receivers []domain.Receiver receivers []domain.Receiver
expireAt int64 expireAt int64
}) })
@@ -142,14 +142,114 @@ func (s *covenantlessService) Stop() {
func (s *covenantlessService) CompleteAsyncPayment( func (s *covenantlessService) CompleteAsyncPayment(
ctx context.Context, redeemTx string, unconditionalForfeitTxs []string, ctx context.Context, redeemTx string, unconditionalForfeitTxs []string,
) error { ) error {
// TODO check that the user signed both transactions
redeemPtx, err := psbt.NewFromRawBytes(strings.NewReader(redeemTx), true) redeemPtx, err := psbt.NewFromRawBytes(strings.NewReader(redeemTx), true)
if err != nil { if err != nil {
return fmt.Errorf("failed to parse redeem tx: %s", err) return fmt.Errorf("failed to parse redeem tx: %s", err)
} }
redeemTxid := redeemPtx.UnsignedTx.TxID() redeemTxid := redeemPtx.UnsignedTx.TxID()
asyncPayData, ok := s.asyncPaymentsCache[redeemTxid]
if !ok {
return fmt.Errorf("async payment not found")
}
txs := append([]string{redeemTx}, unconditionalForfeitTxs...)
vtxoRepo := s.repoManager.Vtxos()
for _, tx := range txs {
ptx, err := psbt.NewFromRawBytes(strings.NewReader(tx), true)
if err != nil {
return fmt.Errorf("failed to parse tx: %s", err)
}
for inputIndex, input := range ptx.Inputs {
if input.WitnessUtxo == nil {
return fmt.Errorf("missing witness utxo")
}
if len(input.TaprootLeafScript) == 0 {
return fmt.Errorf("missing tapscript leaf")
}
if len(input.TaprootScriptSpendSig) == 0 {
return fmt.Errorf("missing tapscript spend sig")
}
vtxoOutpoint := ptx.UnsignedTx.TxIn[inputIndex].PreviousOutPoint
// verify that the vtxo is spendable
vtxo, err := vtxoRepo.GetVtxos(ctx, []domain.VtxoKey{{Txid: vtxoOutpoint.Hash.String(), VOut: vtxoOutpoint.Index}})
if err != nil {
return fmt.Errorf("failed to get vtxo: %s", err)
}
if len(vtxo) == 0 {
return fmt.Errorf("vtxo not found")
}
if vtxo[0].Spent {
return fmt.Errorf("vtxo already spent")
}
if vtxo[0].Redeemed {
return fmt.Errorf("vtxo already redeemed")
}
if vtxo[0].Swept {
return fmt.Errorf("vtxo already swept")
}
// verify that the user signs the tx using the right public key
vtxoPublicKey, err := hex.DecodeString(vtxo[0].Pubkey)
if err != nil {
return fmt.Errorf("failed to decode pubkey: %s", err)
}
pubkey, err := secp256k1.ParsePubKey(vtxoPublicKey)
if err != nil {
return fmt.Errorf("failed to parse pubkey: %s", err)
}
xonlyPubkey := schnorr.SerializePubKey(pubkey)
// find signature belonging to the pubkey
found := false
for _, sig := range input.TaprootScriptSpendSig {
if bytes.Equal(sig.XOnlyPubKey, xonlyPubkey) {
found = true
break
}
}
if !found {
return fmt.Errorf("signature not found for pubkey")
}
// verify witness utxo
pkscript, err := s.builder.GetVtxoScript(pubkey, s.pubkey)
if err != nil {
return fmt.Errorf("failed to get vtxo script: %s", err)
}
if !bytes.Equal(input.WitnessUtxo.PkScript, pkscript) {
return fmt.Errorf("witness utxo script mismatch")
}
if input.WitnessUtxo.Value != int64(vtxo[0].Amount) {
return fmt.Errorf("witness utxo value mismatch")
}
}
// verify the tapscript signatures
if valid, _, err := s.builder.VerifyTapscriptPartialSigs(tx); err != nil || !valid {
return fmt.Errorf("invalid tx signature: %s", err)
}
}
spentVtxos := make([]domain.VtxoKey, 0, len(unconditionalForfeitTxs)) spentVtxos := make([]domain.VtxoKey, 0, len(unconditionalForfeitTxs))
for _, in := range redeemPtx.UnsignedTx.TxIn { for _, in := range redeemPtx.UnsignedTx.TxIn {
spentVtxos = append(spentVtxos, domain.VtxoKey{ spentVtxos = append(spentVtxos, domain.VtxoKey{
@@ -158,11 +258,6 @@ func (s *covenantlessService) CompleteAsyncPayment(
}) })
} }
asyncPayData, ok := s.asyncPaymentsCache[spentVtxos[0]]
if !ok {
return fmt.Errorf("async payment not found")
}
vtxos := make([]domain.Vtxo, 0, len(asyncPayData.receivers)) vtxos := make([]domain.Vtxo, 0, len(asyncPayData.receivers))
for i, receiver := range asyncPayData.receivers { for i, receiver := range asyncPayData.receivers {
vtxos = append(vtxos, domain.Vtxo{ vtxos = append(vtxos, domain.Vtxo{
@@ -189,7 +284,7 @@ func (s *covenantlessService) CompleteAsyncPayment(
} }
log.Infof("spent %d vtxos", len(spentVtxos)) log.Infof("spent %d vtxos", len(spentVtxos))
delete(s.asyncPaymentsCache, spentVtxos[0]) delete(s.asyncPaymentsCache, redeemTxid)
return nil return nil
} }
@@ -218,6 +313,7 @@ func (s *covenantlessService) CreateAsyncPayment(
if vtxo.Swept { if vtxo.Swept {
return "", nil, fmt.Errorf("all vtxos must be swept") return "", nil, fmt.Errorf("all vtxos must be swept")
} }
if vtxo.ExpireAt < expiration { if vtxo.ExpireAt < expiration {
expiration = vtxo.ExpireAt expiration = vtxo.ExpireAt
} }
@@ -230,7 +326,12 @@ func (s *covenantlessService) CreateAsyncPayment(
return "", nil, fmt.Errorf("failed to build async payment txs: %s", err) return "", nil, fmt.Errorf("failed to build async payment txs: %s", err)
} }
s.asyncPaymentsCache[inputs[0]] = struct { redeemTx, err := psbt.NewFromRawBytes(strings.NewReader(res.RedeemTx), true)
if err != nil {
return "", nil, fmt.Errorf("failed to parse redeem tx: %s", err)
}
s.asyncPaymentsCache[redeemTx.UnsignedTx.TxID()] = struct {
receivers []domain.Receiver receivers []domain.Receiver
expireAt int64 expireAt int64
}{ }{

View File

@@ -185,7 +185,7 @@ func (m *forfeitTxsMap) push(txs []string) {
defer m.lock.Unlock() defer m.lock.Unlock()
for _, tx := range txs { for _, tx := range txs {
signed, txid, _ := m.builder.VerifyForfeitTx(tx) signed, txid, _ := m.builder.VerifyTapscriptPartialSigs(tx)
m.forfeitTxs[txid] = &signedTx{tx, signed} m.forfeitTxs[txid] = &signedTx{tx, signed}
} }
} }
@@ -195,7 +195,7 @@ func (m *forfeitTxsMap) sign(txs []string) error {
defer m.lock.Unlock() defer m.lock.Unlock()
for _, tx := range txs { for _, tx := range txs {
valid, txid, err := m.builder.VerifyForfeitTx(tx) valid, txid, err := m.builder.VerifyTapscriptPartialSigs(tx)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -32,7 +32,7 @@ type TxBuilder interface {
BuildSweepTx(inputs []SweepInput) (signedSweepTx string, err error) BuildSweepTx(inputs []SweepInput) (signedSweepTx string, err error)
GetVtxoScript(userPubkey, aspPubkey *secp256k1.PublicKey) ([]byte, error) GetVtxoScript(userPubkey, aspPubkey *secp256k1.PublicKey) ([]byte, error)
GetSweepInput(parentblocktime int64, node tree.Node) (expirationtime int64, sweepInput SweepInput, err error) GetSweepInput(parentblocktime int64, node tree.Node) (expirationtime int64, sweepInput SweepInput, err error)
VerifyForfeitTx(tx string) (valid bool, txid string, err error) VerifyTapscriptPartialSigs(tx string) (valid bool, txid string, err error)
FinalizeAndExtractForfeit(tx string) (txhex string, err error) FinalizeAndExtractForfeit(tx string) (txhex string, err error)
// FindLeaves returns all the leaves txs that are reachable from the given outpoint // FindLeaves returns all the leaves txs that are reachable from the given outpoint
FindLeaves(congestionTree tree.CongestionTree, fromtxid string, vout uint32) (leaves []tree.Node, err error) FindLeaves(congestionTree tree.CongestionTree, fromtxid string, vout uint32) (leaves []tree.Node, err error)

View File

@@ -1,6 +1,7 @@
package txbuilder package txbuilder
import ( import (
"bytes"
"context" "context"
"encoding/hex" "encoding/hex"
"fmt" "fmt"
@@ -249,27 +250,46 @@ func (b *txBuilder) GetSweepInput(parentblocktime int64, node tree.Node) (expira
return expirationTime, sweepInput, nil return expirationTime, sweepInput, nil
} }
func (b *txBuilder) VerifyForfeitTx(tx string) (bool, string, error) { func (b *txBuilder) VerifyTapscriptPartialSigs(tx string) (bool, string, error) {
ptx, _ := psetv2.NewPsetFromBase64(tx) ptx, _ := psetv2.NewPsetFromBase64(tx)
utx, _ := ptx.UnsignedTx() utx, _ := ptx.UnsignedTx()
txid := utx.TxHash().String() txid := utx.TxHash().String()
for index, input := range ptx.Inputs { for index, input := range ptx.Inputs {
for _, tapScriptSig := range input.TapScriptSig { if len(input.TapLeafScript) == 0 {
leafHash, err := chainhash.NewHash(tapScriptSig.LeafHash) continue
}
if input.WitnessUtxo == nil {
return false, txid, fmt.Errorf("missing witness utxo for input %d, cannot verify signature", index)
}
// verify taproot leaf script
tapLeaf := input.TapLeafScript[0]
rootHash := tapLeaf.ControlBlock.RootHash(tapLeaf.Script)
tapKeyFromControlBlock := taproot.ComputeTaprootOutputKey(tree.UnspendableKey(), rootHash[:])
pkscript, err := p2trScript(tapKeyFromControlBlock)
if err != nil { if err != nil {
return false, txid, err return false, txid, err
} }
if !bytes.Equal(pkscript, input.WitnessUtxo.Script) {
return false, txid, fmt.Errorf("invalid control block for input %d", index)
}
leafHash := taproot.NewBaseTapElementsLeaf(tapLeaf.Script).TapHash()
preimage, err := b.getTaprootPreimage( preimage, err := b.getTaprootPreimage(
tx, tx,
index, index,
leafHash, &leafHash,
) )
if err != nil { if err != nil {
return false, txid, err return false, txid, err
} }
for _, tapScriptSig := range input.TapScriptSig {
sig, err := schnorr.ParseSignature(tapScriptSig.Signature) sig, err := schnorr.ParseSignature(tapScriptSig.Signature)
if err != nil { if err != nil {
return false, txid, err return false, txid, err
@@ -280,15 +300,13 @@ func (b *txBuilder) VerifyForfeitTx(tx string) (bool, string, error) {
return false, txid, err return false, txid, err
} }
if sig.Verify(preimage, pubkey) { if !sig.Verify(preimage, pubkey) {
return true, txid, nil return false, txid, fmt.Errorf("invalid signature for tx %s", txid)
} else {
return false, txid, fmt.Errorf("invalid signature")
} }
} }
} }
return false, txid, nil return true, txid, nil
} }
func (b *txBuilder) FinalizeAndExtractForfeit(tx string) (string, error) { func (b *txBuilder) FinalizeAndExtractForfeit(tx string) (string, error) {
@@ -388,7 +406,7 @@ func (b *txBuilder) getLeafScriptAndTree(
unspendableKey := tree.UnspendableKey() unspendableKey := tree.UnspendableKey()
taprootKey := taproot.ComputeTaprootOutputKey(unspendableKey, root[:]) taprootKey := taproot.ComputeTaprootOutputKey(unspendableKey, root[:])
outputScript, err := taprootOutputScript(taprootKey) outputScript, err := p2trScript(taprootKey)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }

View File

@@ -79,7 +79,7 @@ func sweepTransaction(
root := leaf.ControlBlock.RootHash(leaf.Script) root := leaf.ControlBlock.RootHash(leaf.Script)
taprootKey := taproot.ComputeTaprootOutputKey(leaf.ControlBlock.InternalKey, root) taprootKey := taproot.ComputeTaprootOutputKey(leaf.ControlBlock.InternalKey, root)
script, err := taprootOutputScript(taprootKey) script, err := p2trScript(taprootKey)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -136,7 +136,7 @@ func addInputs(
return nil return nil
} }
func taprootOutputScript(taprootKey *secp256k1.PublicKey) ([]byte, error) { func p2trScript(taprootKey *secp256k1.PublicKey) ([]byte, error) {
return txscript.NewScriptBuilder().AddOp(txscript.OP_1).AddData(schnorr.SerializePubKey(taprootKey)).Script() return txscript.NewScriptBuilder().AddOp(txscript.OP_1).AddData(schnorr.SerializePubKey(taprootKey)).Script()
} }

View File

@@ -38,22 +38,51 @@ func NewTxBuilder(
return &txBuilder{wallet, net, roundLifetime, exitDelay, boardingExitDelay} return &txBuilder{wallet, net, roundLifetime, exitDelay, boardingExitDelay}
} }
func (b *txBuilder) VerifyForfeitTx(tx string) (bool, string, error) { func (b *txBuilder) VerifyTapscriptPartialSigs(tx string) (bool, string, error) {
ptx, _ := psbt.NewFromRawBytes(strings.NewReader(tx), true) ptx, _ := psbt.NewFromRawBytes(strings.NewReader(tx), true)
txid := ptx.UnsignedTx.TxHash().String() txid := ptx.UnsignedTx.TxID()
for index, input := range ptx.Inputs { for index, input := range ptx.Inputs {
// TODO (@louisinger): verify control block if len(input.TaprootLeafScript) == 0 {
for _, tapScriptSig := range input.TaprootScriptSpendSig { continue
}
if input.WitnessUtxo == nil {
return false, txid, fmt.Errorf("missing witness utxo for input %d, cannot verify signature", index)
}
// verify taproot leaf script
tapLeaf := input.TaprootLeafScript[0]
if len(tapLeaf.ControlBlock) == 0 {
return false, txid, fmt.Errorf("missing control block for input %d", index)
}
controlBlock, err := txscript.ParseControlBlock(tapLeaf.ControlBlock)
if err != nil {
return false, txid, err
}
rootHash := controlBlock.RootHash(tapLeaf.Script)
tapKeyFromControlBlock := txscript.ComputeTaprootOutputKey(bitcointree.UnspendableKey(), rootHash[:])
pkscript, err := p2trScript(tapKeyFromControlBlock)
if err != nil {
return false, txid, err
}
if !bytes.Equal(pkscript, input.WitnessUtxo.PkScript) {
return false, txid, fmt.Errorf("invalid control block for input %d", index)
}
preimage, err := b.getTaprootPreimage( preimage, err := b.getTaprootPreimage(
tx, tx,
index, index,
input.TaprootLeafScript[0].Script, tapLeaf.Script,
) )
if err != nil { if err != nil {
return false, txid, err return false, txid, err
} }
for _, tapScriptSig := range input.TaprootScriptSpendSig {
sig, err := schnorr.ParseSignature(tapScriptSig.Signature) sig, err := schnorr.ParseSignature(tapScriptSig.Signature)
if err != nil { if err != nil {
return false, txid, err return false, txid, err
@@ -64,15 +93,13 @@ func (b *txBuilder) VerifyForfeitTx(tx string) (bool, string, error) {
return false, txid, err return false, txid, err
} }
if sig.Verify(preimage, pubkey) { if !sig.Verify(preimage, pubkey) {
return true, txid, nil
} else {
return false, txid, fmt.Errorf("invalid signature for tx %s", txid) return false, txid, fmt.Errorf("invalid signature for tx %s", txid)
} }
} }
} }
return false, txid, nil return true, txid, nil
} }
func (b *txBuilder) FinalizeAndExtractForfeit(tx string) (string, error) { func (b *txBuilder) FinalizeAndExtractForfeit(tx string) (string, error) {
@@ -359,8 +386,7 @@ func (b *txBuilder) BuildAsyncPaymentTransactions(
return nil, err return nil, err
} }
// TODO generate a fresh new address to get the forfeit funds aspScript, err := p2trScript(aspPubKey)
aspScript, err := p2trScript(aspPubKey, b.onchainNetwork())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -494,8 +520,13 @@ func (b *txBuilder) BuildAsyncPaymentTransactions(
}) })
} }
sequences := make([]uint32, len(ins))
for i := range sequences {
sequences[i] = wire.MaxTxInSequenceNum
}
redeemPtx, err := psbt.New( redeemPtx, err := psbt.New(
ins, outs, 2, 0, []uint32{wire.MaxTxInSequenceNum}, ins, outs, 2, 0, sequences,
) )
if err != nil { if err != nil {
return nil, err return nil, err
@@ -1087,8 +1118,7 @@ func (b *txBuilder) minRelayFeeForfeitTx() (uint64, error) {
func (b *txBuilder) createForfeitTxs( func (b *txBuilder) createForfeitTxs(
aspPubkey *secp256k1.PublicKey, payments []domain.Payment, connectors []*psbt.Packet, feeAmount uint64, aspPubkey *secp256k1.PublicKey, payments []domain.Payment, connectors []*psbt.Packet, feeAmount uint64,
) ([]string, error) { ) ([]string, error) {
// TODO generate a fresh new address to receive the forfeited funds aspScript, err := p2trScript(aspPubkey)
aspScript, err := p2trScript(aspPubkey, b.onchainNetwork())
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -4,24 +4,12 @@ import (
"github.com/ark-network/ark/common/bitcointree" "github.com/ark-network/ark/common/bitcointree"
"github.com/ark-network/ark/server/internal/core/domain" "github.com/ark-network/ark/server/internal/core/domain"
"github.com/btcsuite/btcd/btcec/v2/schnorr" "github.com/btcsuite/btcd/btcec/v2/schnorr"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg"
"github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/txscript"
"github.com/decred/dcrd/dcrec/secp256k1/v4" "github.com/decred/dcrd/dcrec/secp256k1/v4"
) )
func p2trScript(publicKey *secp256k1.PublicKey, net *chaincfg.Params) ([]byte, error) { func p2trScript(taprootKey *secp256k1.PublicKey) ([]byte, error) {
tapKey := txscript.ComputeTaprootKeyNoScript(publicKey) return txscript.NewScriptBuilder().AddOp(txscript.OP_1).AddData(schnorr.SerializePubKey(taprootKey)).Script()
payment, err := btcutil.NewAddressWitnessPubKeyHash(
btcutil.Hash160(tapKey.SerializeCompressed()),
net,
)
if err != nil {
return nil, err
}
return txscript.PayToAddrScript(payment)
} }
func getOnchainReceivers( func getOnchainReceivers(