diff --git a/asp/internal/core/application/service.go b/asp/internal/core/application/service.go index 3534645..07eff7e 100644 --- a/asp/internal/core/application/service.go +++ b/asp/internal/core/application/service.go @@ -227,7 +227,14 @@ func (s *service) startFinalization() { return } + var changes []domain.RoundEvent defer func() { + if len(changes) > 0 { + if err := s.repoManager.Events().Save(ctx, round.Id, changes...); err != nil { + log.WithError(err).Warn("failed to store new round events") + } + } + if round.IsFailed() { s.startRound() return @@ -240,14 +247,6 @@ func (s *service) startFinalization() { return } - var changes []domain.RoundEvent - defer func() { - if err := s.repoManager.Events().Save(ctx, round.Id, changes...); err != nil { - log.WithError(err).Warn("failed to store new round events") - return - } - }() - // TODO: understand how many payments must be popped from the queue and actually registered for the round num := s.paymentRequests.len() if num == 0 { @@ -274,6 +273,8 @@ func (s *service) startFinalization() { return } + log.Debugf("pool tx created for round %s", round.Id) + connectors, forfeitTxs, err := s.builder.BuildForfeitTxs(s.pubkey, signedPoolTx, payments) if err != nil { changes = round.Fail(fmt.Errorf("failed to create connectors and forfeit txs: %s", err)) @@ -281,7 +282,14 @@ func (s *service) startFinalization() { return } - events, _ := round.StartFinalization(connectors, tree, signedPoolTx) + log.Debugf("forfeit transactions created for round %s", round.Id) + + events, err := round.StartFinalization(connectors, tree, signedPoolTx) + if err != nil { + changes = round.Fail(fmt.Errorf("failed to start finalization: %s", err)) + log.WithError(err).Warn("failed to start finalization") + return + } changes = append(changes, events...) s.forfeitTxs.push(forfeitTxs) diff --git a/asp/internal/core/domain/congestion_tree.go b/asp/internal/core/domain/congestion_tree.go deleted file mode 100644 index 0b7fe4a..0000000 --- a/asp/internal/core/domain/congestion_tree.go +++ /dev/null @@ -1,44 +0,0 @@ -package domain - -type Node struct { - Txid string - Tx string - ParentTxid string - Leaf bool -} - -type CongestionTree [][]Node - -func (c CongestionTree) Leaves() []Node { - leaves := c[len(c)-1] - for _, level := range c[:len(c)-1] { - for _, node := range level { - if node.Leaf { - leaves = append(leaves, node) - } - } - } - - return leaves -} - -func (c CongestionTree) Children(nodeTxid string) []Node { - var children []Node - for _, level := range c { - for _, node := range level { - if node.ParentTxid == nodeTxid { - children = append(children, node) - } - } - } - - return children -} - -func (c CongestionTree) NumberOfNodes() int { - var count int - for _, level := range c { - count += len(level) - } - return count -} diff --git a/asp/internal/core/domain/events.go b/asp/internal/core/domain/events.go index 5426301..978fadb 100644 --- a/asp/internal/core/domain/events.go +++ b/asp/internal/core/domain/events.go @@ -1,5 +1,7 @@ package domain +import "github.com/ark-network/ark/common/tree" + type RoundEvent interface { isEvent() } @@ -17,7 +19,7 @@ type RoundStarted struct { type RoundFinalizationStarted struct { Id string - CongestionTree CongestionTree + CongestionTree tree.CongestionTree Connectors []string UnsignedForfeitTxs []string PoolTx string diff --git a/asp/internal/core/domain/round.go b/asp/internal/core/domain/round.go index 2ccd35b..7bcbfc2 100644 --- a/asp/internal/core/domain/round.go +++ b/asp/internal/core/domain/round.go @@ -4,6 +4,7 @@ import ( "fmt" "time" + "github.com/ark-network/ark/common/tree" "github.com/google/uuid" ) @@ -41,7 +42,7 @@ type Round struct { Txid string TxHex string ForfeitTxs []string - CongestionTree CongestionTree + CongestionTree tree.CongestionTree Connectors []string DustAmount uint64 Version uint @@ -143,11 +144,11 @@ func (r *Round) RegisterPayments(payments []Payment) ([]RoundEvent, error) { return []RoundEvent{event}, nil } -func (r *Round) StartFinalization(connectors []string, tree CongestionTree, poolTx string) ([]RoundEvent, error) { +func (r *Round) StartFinalization(connectors []string, congestionTree tree.CongestionTree, poolTx string) ([]RoundEvent, error) { if len(connectors) <= 0 { return nil, fmt.Errorf("missing list of connectors") } - if len(tree) <= 0 { + if len(congestionTree) <= 0 { return nil, fmt.Errorf("missing congestion tree") } if len(poolTx) <= 0 { @@ -162,7 +163,7 @@ func (r *Round) StartFinalization(connectors []string, tree CongestionTree, pool event := RoundFinalizationStarted{ Id: r.Id, - CongestionTree: tree, + CongestionTree: congestionTree, Connectors: connectors, PoolTx: poolTx, } diff --git a/asp/internal/core/domain/round_test.go b/asp/internal/core/domain/round_test.go index 70a953b..e44defd 100644 --- a/asp/internal/core/domain/round_test.go +++ b/asp/internal/core/domain/round_test.go @@ -4,6 +4,7 @@ import ( "fmt" "testing" + "github.com/ark-network/ark/common/tree" "github.com/ark-network/ark/internal/core/domain" "github.com/stretchr/testify/require" ) @@ -72,7 +73,7 @@ var ( emptyTx = "0200000000000000000000" txid = "0000000000000000000000000000000000000000000000000000000000000000" pubkey = "030000000000000000000000000000000000000000000000000000000000000001" - congestionTree = domain.CongestionTree{ + congestionTree = tree.CongestionTree{ { { Txid: txid, @@ -318,7 +319,7 @@ func testStartFinalization(t *testing.T) { fixtures := []struct { round *domain.Round connectors []string - tree domain.CongestionTree + tree tree.CongestionTree poolTx string expectedErr string }{ diff --git a/asp/internal/core/ports/tx_builder.go b/asp/internal/core/ports/tx_builder.go index 96ec5f0..1e16a58 100644 --- a/asp/internal/core/ports/tx_builder.go +++ b/asp/internal/core/ports/tx_builder.go @@ -1,6 +1,7 @@ package ports import ( + "github.com/ark-network/ark/common/tree" "github.com/ark-network/ark/internal/core/domain" "github.com/decred/dcrd/dcrec/secp256k1/v4" ) @@ -8,7 +9,7 @@ import ( type TxBuilder interface { BuildPoolTx( aspPubkey *secp256k1.PublicKey, wallet WalletService, payments []domain.Payment, minRelayFee uint64, - ) (poolTx string, congestionTree domain.CongestionTree, err error) + ) (poolTx string, congestionTree tree.CongestionTree, err error) BuildForfeitTxs( aspPubkey *secp256k1.PublicKey, poolTx string, payments []domain.Payment, ) (connectors []string, forfeitTxs []string, err error) diff --git a/asp/internal/infrastructure/db/service_test.go b/asp/internal/infrastructure/db/service_test.go index 3e41625..c40e027 100644 --- a/asp/internal/infrastructure/db/service_test.go +++ b/asp/internal/infrastructure/db/service_test.go @@ -6,6 +6,7 @@ import ( "testing" "time" + "github.com/ark-network/ark/common/tree" "github.com/ark-network/ark/internal/core/domain" "github.com/ark-network/ark/internal/core/ports" "github.com/ark-network/ark/internal/infrastructure/db" @@ -21,7 +22,7 @@ const ( pubkey = "0300000000000000000000000000000000000000000000000000000000000000001" ) -var congestionTree = [][]domain.Node{ +var congestionTree = [][]tree.Node{ { { Txid: txid, diff --git a/asp/internal/infrastructure/tx-builder/covenant/builder.go b/asp/internal/infrastructure/tx-builder/covenant/builder.go index eb05ad3..5dde242 100644 --- a/asp/internal/infrastructure/tx-builder/covenant/builder.go +++ b/asp/internal/infrastructure/tx-builder/covenant/builder.go @@ -4,7 +4,7 @@ import ( "context" "encoding/hex" - "github.com/ark-network/ark/common" + "github.com/ark-network/ark/common/tree" "github.com/ark-network/ark/internal/core/domain" "github.com/ark-network/ark/internal/core/ports" "github.com/decred/dcrd/dcrec/secp256k1/v4" @@ -165,7 +165,7 @@ func (b *txBuilder) BuildPoolTx( wallet ports.WalletService, payments []domain.Payment, minRelayFee uint64, -) (poolTx string, congestionTree domain.CongestionTree, err error) { +) (poolTx string, congestionTree tree.CongestionTree, err error) { aspScriptBytes, err := p2wpkhScript(aspPubkey, b.net) if err != nil { return @@ -226,12 +226,12 @@ func (b *txBuilder) BuildPoolTx( } func (b *txBuilder) getLeafTaprootTree(userPubkey, aspPubkey *secp256k1.PublicKey) ([]byte, *taproot.IndexedElementsTapScriptTree, error) { - sweepTaprootLeaf, err := sweepTapLeaf(aspPubkey) + sweepTaprootLeaf, err := tree.SweepScript(aspPubkey, expirationTime) if err != nil { return nil, nil, err } - vtxoLeaf, err := common.VtxoScript(userPubkey) + vtxoLeaf, err := tree.VtxoScript(userPubkey) if err != nil { return nil, nil, err } @@ -239,7 +239,7 @@ func (b *txBuilder) getLeafTaprootTree(userPubkey, aspPubkey *secp256k1.PublicKe leafTaprootTree := taproot.AssembleTaprootScriptTree(*vtxoLeaf, *sweepTaprootLeaf) root := leafTaprootTree.RootNode.TapHash() - unspendableKeyBytes, _ := hex.DecodeString(unspendablePoint) + unspendableKeyBytes, _ := hex.DecodeString(tree.UnspendablePoint) unspendableKey, _ := secp256k1.ParsePubKey(unspendableKeyBytes) taprootKey := taproot.ComputeTaprootOutputKey( diff --git a/asp/internal/infrastructure/tx-builder/covenant/builder_test.go b/asp/internal/infrastructure/tx-builder/covenant/builder_test.go index 2c81de4..af2cc69 100644 --- a/asp/internal/infrastructure/tx-builder/covenant/builder_test.go +++ b/asp/internal/infrastructure/tx-builder/covenant/builder_test.go @@ -5,17 +5,16 @@ import ( "testing" "github.com/ark-network/ark/common" + "github.com/ark-network/ark/common/tree" "github.com/ark-network/ark/internal/core/domain" "github.com/ark-network/ark/internal/core/ports" txbuilder "github.com/ark-network/ark/internal/infrastructure/tx-builder/covenant" - "github.com/btcsuite/btcd/btcec/v2/schnorr" "github.com/btcsuite/btcd/chaincfg/chainhash" secp256k1 "github.com/decred/dcrd/dcrec/secp256k1/v4" "github.com/stretchr/testify/require" "github.com/vulpemventures/go-elements/network" "github.com/vulpemventures/go-elements/payment" "github.com/vulpemventures/go-elements/psetv2" - "github.com/vulpemventures/go-elements/taproot" "github.com/vulpemventures/go-elements/transaction" ) @@ -124,9 +123,36 @@ func TestBuildCongestionTree(t *testing.T) { fixtures := []struct { payments []domain.Payment - expectedNodesNum int // 2*len(receivers)-1 + expectedNodesNum int // 2*len(receivers) -1 expectedLeavesNum int }{ + { + payments: []domain.Payment{ + { + Id: "0", + Inputs: []domain.Vtxo{ + { + VtxoKey: domain.VtxoKey{ + Txid: "fd68e3c5796cc7db0a8036d486d5f625b6b2f2c014810ac020e1ac23e82c59d6", + VOut: 0, + }, + Receiver: domain.Receiver{ + Pubkey: "020000000000000000000000000000000000000000000000000000000000000002", + Amount: 1100, + }, + }, + }, + Receivers: []domain.Receiver{ + { + Pubkey: "020000000000000000000000000000000000000000000000000000000000000002", + Amount: 1100, + }, + }, + }, + }, + expectedNodesNum: 1, + expectedLeavesNum: 1, + }, { payments: []domain.Payment{ { @@ -238,6 +264,88 @@ func TestBuildCongestionTree(t *testing.T) { }, expectedNodesNum: 5, expectedLeavesNum: 3, + }, { + payments: []domain.Payment{ + { + Id: "a242cdd8-f3d5-46c0-ae98-94135a2bee3f", + Inputs: []domain.Vtxo{ + { + VtxoKey: domain.VtxoKey{ + Txid: "755c820771284d85ea4bbcc246565b4eddadc44237a7e57a0f9cb78a840d1d41", + VOut: 0, + }, + Receiver: domain.Receiver{ + Pubkey: "02c87e5c1758df5ad42a918ec507b6e8dfcdcebf22f64f58eb4ad5804257d658a5", + Amount: 1000, + }, + }, + { + VtxoKey: domain.VtxoKey{ + Txid: "66a0df86fcdeb84b8877adfe0b2c556dba30305d72ddbd4c49355f6930355357", + VOut: 0, + }, + Receiver: domain.Receiver{ + Pubkey: "02c87e5c1758df5ad42a918ec507b6e8dfcdcebf22f64f58eb4ad5804257d658a5", + Amount: 1000, + }, + }, + { + VtxoKey: domain.VtxoKey{ + Txid: "9913159bc7aa493ca53cbb9cbc88f97ba01137c814009dc7ef520c3fafc67909", + VOut: 1, + }, + Receiver: domain.Receiver{ + Pubkey: "02c87e5c1758df5ad42a918ec507b6e8dfcdcebf22f64f58eb4ad5804257d658a5", + Amount: 500, + }, + }, + { + VtxoKey: domain.VtxoKey{ + Txid: "5e10e77a7cdedc153be5193a4b6055a7802706ded4f2a9efefe86ed2f9a6ae60", + VOut: 0, + }, + Receiver: domain.Receiver{ + Pubkey: "02c87e5c1758df5ad42a918ec507b6e8dfcdcebf22f64f58eb4ad5804257d658a5", + Amount: 1000, + }, + }, + { + VtxoKey: domain.VtxoKey{ + Txid: "5e10e77a7cdedc153be5193a4b6055a7802706ded4f2a9efefe86ed2f9a6ae60", + VOut: 1, + }, + Receiver: domain.Receiver{ + Pubkey: "02c87e5c1758df5ad42a918ec507b6e8dfcdcebf22f64f58eb4ad5804257d658a5", + Amount: 1000, + }, + }, + }, + Receivers: []domain.Receiver{ + { + Pubkey: "02c87e5c1758df5ad42a918ec507b6e8dfcdcebf22f64f58eb4ad5804257d658a5", + Amount: 1000, + }, + { + Pubkey: "02c87e5c1758df5ad42a918ec507b6e8dfcdcebf22f64f58eb4ad5804257d658a5", + Amount: 1000, + }, + { + Pubkey: "02c87e5c1758df5ad42a918ec507b6e8dfcdcebf22f64f58eb4ad5804257d658a5", + Amount: 1000, + }, + { + Pubkey: "02c87e5c1758df5ad42a918ec507b6e8dfcdcebf22f64f58eb4ad5804257d658a5", + Amount: 1000, + }, + { + Pubkey: "02c87e5c1758df5ad42a918ec507b6e8dfcdcebf22f64f58eb4ad5804257d658a5", + Amount: 500, + }, + }, + }, + }, + expectedNodesNum: 4, + expectedLeavesNum: 3, }, } @@ -246,54 +354,19 @@ func TestBuildCongestionTree(t *testing.T) { require.NotNil(t, key) for _, f := range fixtures { - poolTx, tree, err := builder.BuildPoolTx(key, &mockedWalletService{}, f.payments, 30) - + poolTx, congestionTree, err := builder.BuildPoolTx(key, &mockedWalletService{}, f.payments, 30) require.NoError(t, err) - require.Equal(t, f.expectedNodesNum, tree.NumberOfNodes()) - require.Len(t, tree.Leaves(), f.expectedLeavesNum) + require.Equal(t, f.expectedNodesNum, congestionTree.NumberOfNodes()) + require.Len(t, congestionTree.Leaves(), f.expectedLeavesNum) - poolTransaction, err := transaction.NewTxFromHex(poolTx) + // check that the pool tx has the right number of inputs and outputs + err = tree.ValidateCongestionTree( + congestionTree, + poolTx, + key, + 1209344, // 2 weeks - 8 minutes + ) require.NoError(t, err) - - poolTxID := poolTransaction.TxHash().String() - - // check the root - require.Len(t, tree[0], 1) - require.Equal(t, poolTxID, tree[0][0].ParentTxid) - - // check the nodes - for _, level := range tree { - for _, node := range level { - pset, err := psetv2.NewPsetFromBase64(node.Tx) - require.NoError(t, err) - - require.Len(t, pset.Inputs, 1) - require.Len(t, pset.Outputs, 3) - - inputTxID := chainhash.Hash(pset.Inputs[0].PreviousTxid).String() - require.Equal(t, node.ParentTxid, inputTxID) - - children := tree.Children(node.Txid) - if len(children) > 0 { - require.Len(t, children, 2) - - for i, child := range children { - childTx, err := psetv2.NewPsetFromBase64(child.Tx) - require.NoError(t, err) - - for _, leaf := range childTx.Inputs[0].TapLeafScript { - key := leaf.ControlBlock.InternalKey - rootHash := leaf.ControlBlock.RootHash(leaf.Script) - - outputScript := taproot.ComputeTaprootOutputKey(key, rootHash) - previousScriptKey := pset.Outputs[i].Script[2:] - require.Len(t, previousScriptKey, 32) - require.Equal(t, schnorr.SerializePubKey(outputScript), previousScriptKey) - } - } - } - } - } } } diff --git a/asp/internal/infrastructure/tx-builder/covenant/forfeit.go b/asp/internal/infrastructure/tx-builder/covenant/forfeit.go index 3194b55..45160ab 100644 --- a/asp/internal/infrastructure/tx-builder/covenant/forfeit.go +++ b/asp/internal/infrastructure/tx-builder/covenant/forfeit.go @@ -3,6 +3,7 @@ package txbuilder import ( "encoding/hex" + "github.com/ark-network/ark/common/tree" "github.com/btcsuite/btcd/txscript" "github.com/decred/dcrd/dcrec/secp256k1/v4" "github.com/vulpemventures/go-elements/elementsutil" @@ -51,7 +52,7 @@ func createForfeitTx( return "", err } - unspendableKeyBytes, _ := hex.DecodeString(unspendablePoint) + unspendableKeyBytes, _ := hex.DecodeString(tree.UnspendablePoint) unspendableKey, _ := secp256k1.ParsePubKey(unspendableKeyBytes) for _, proof := range vtxoTaprootTree.LeafMerkleProofs { diff --git a/asp/internal/infrastructure/tx-builder/covenant/tree.go b/asp/internal/infrastructure/tx-builder/covenant/tree.go index f8cf770..3aa1949 100644 --- a/asp/internal/infrastructure/tx-builder/covenant/tree.go +++ b/asp/internal/infrastructure/tx-builder/covenant/tree.go @@ -1,11 +1,10 @@ package txbuilder import ( - "encoding/binary" "encoding/hex" "fmt" - "github.com/ark-network/ark/common" + "github.com/ark-network/ark/common/tree" "github.com/ark-network/ark/internal/core/domain" "github.com/btcsuite/btcd/btcec/v2/schnorr" "github.com/btcsuite/btcd/chaincfg/chainhash" @@ -17,112 +16,11 @@ import ( ) const ( - OP_INSPECTOUTPUTSCRIPTPUBKEY = 0xd1 - OP_INSPECTOUTPUTVALUE = 0xcf - OP_PUSHCURRENTINPUTINDEX = 0xcd - unspendablePoint = "0250929b74c1a04954b78b4b6035e97a5e078a5a0f28ec96d547bfee9ace803ac0" - timeDelta = 60 * 60 * 24 * 14 // 14 days in seconds + expirationTime = 60 * 60 * 24 * 14 // 14 days in seconds ) // the private method buildCongestionTree returns a function letting to plug in the pool transaction output as input of the tree's root node -type pluggableCongestionTree func(outpoint psetv2.InputArgs) (domain.CongestionTree, error) - -// withOutput returns an introspection script that checks the script and the amount of the output at the given index -// verify will add an OP_EQUALVERIFY at the end of the script, otherwise it will add an OP_EQUAL -func withOutput(index byte, taprootWitnessProgram []byte, amount uint64, verify bool) []byte { - amountBuffer := make([]byte, 8) - binary.LittleEndian.PutUint64(amountBuffer, amount) - - script := []byte{ - index, - OP_INSPECTOUTPUTSCRIPTPUBKEY, - txscript.OP_1, - txscript.OP_EQUALVERIFY, - txscript.OP_DATA_32, - } - - script = append(script, taprootWitnessProgram...) - script = append(script, []byte{ - txscript.OP_EQUALVERIFY, - }...) - script = append(script, index) - script = append(script, []byte{ - OP_INSPECTOUTPUTVALUE, - txscript.OP_1, - txscript.OP_EQUALVERIFY, - txscript.OP_DATA_8, - }...) - script = append(script, amountBuffer...) - if verify { - script = append(script, []byte{ - txscript.OP_EQUALVERIFY, - }...) - } else { - script = append(script, []byte{ - txscript.OP_EQUAL, - }...) - } - - return script -} - -func checksigScript(pubkey *secp256k1.PublicKey) ([]byte, error) { - key := schnorr.SerializePubKey(pubkey) - return txscript.NewScriptBuilder().AddData(key).AddOp(txscript.OP_CHECKSIG).Script() -} - -// checkSequenceVerifyScript without checksig -func checkSequenceVerifyScript(seconds uint) ([]byte, error) { - sequence, err := common.BIP68Encode(seconds) - if err != nil { - return nil, err - } - - return append(sequence, []byte{ - txscript.OP_CHECKSEQUENCEVERIFY, - txscript.OP_DROP, - }...), nil -} - -// checkSequenceVerifyScript + checksig -func csvChecksigScript(pubkey *secp256k1.PublicKey, seconds uint) ([]byte, error) { - script, err := checksigScript(pubkey) - if err != nil { - return nil, err - } - - csvScript, err := checkSequenceVerifyScript(seconds) - if err != nil { - return nil, err - } - - return append(csvScript, script...), nil -} - -// sweepTapLeaf returns a taproot leaf letting the owner of the key to spend the output after a given timeDelta -func sweepTapLeaf(sweepKey *secp256k1.PublicKey) (*taproot.TapElementsLeaf, error) { - sweepScript, err := csvChecksigScript(sweepKey, timeDelta) - if err != nil { - return nil, err - } - - tapLeaf := taproot.NewBaseTapElementsLeaf(sweepScript) - return &tapLeaf, nil -} - -// forceSplitCoinTapLeaf returns a taproot leaf that enforces a split into two outputs -// each output (left and right) will have the given amount and the given taproot key as witness program -func forceSplitCoinTapLeaf( - leftKey, rightKey *secp256k1.PublicKey, leftAmount, rightAmount uint64, -) taproot.TapElementsLeaf { - nextScriptLeft := withOutput(txscript.OP_0, schnorr.SerializePubKey(leftKey), leftAmount, rightKey != nil) - branchScript := append([]byte{}, nextScriptLeft...) - if rightKey != nil { - nextScriptRight := withOutput(txscript.OP_1, schnorr.SerializePubKey(rightKey), rightAmount, false) - branchScript = append(branchScript, nextScriptRight...) - } - return taproot.NewBaseTapElementsLeaf(branchScript) -} +type pluggableCongestionTree func(outpoint psetv2.InputArgs) (tree.CongestionTree, error) func taprootOutputScript(taprootKey *secp256k1.PublicKey) ([]byte, error) { return txscript.NewScriptBuilder().AddOp(txscript.OP_1).AddData(schnorr.SerializePubKey(taprootKey)).Script() @@ -165,7 +63,7 @@ func buildCongestionTree( receivers []domain.Receiver, feeSatsPerNode uint64, ) (pluggableTree pluggableCongestionTree, sharedOutputScript []byte, sharedOutputAmount uint64, err error) { - unspendableKeyBytes, err := hex.DecodeString(unspendablePoint) + unspendableKeyBytes, err := hex.DecodeString(tree.UnspendablePoint) if err != nil { return nil, nil, 0, err } @@ -203,7 +101,7 @@ func buildCongestionTree( } // compute the shared output script - sweepLeaf, err := sweepTapLeaf(aspPublicKey) + sweepLeaf, err := tree.VtxoScript(aspPublicKey) if err != nil { return nil, nil, 0, err } @@ -217,7 +115,7 @@ func buildCongestionTree( var rightAmount uint64 var rightKey *secp256k1.PublicKey - if len(rootPset.Outputs) > 1 { + if len(rootPset.Outputs) > 2 { rightAmount = rootPset.Outputs[1].Value rightKey, err = schnorr.ParsePubKey(rootPset.Outputs[1].Script[2:]) if err != nil { @@ -225,7 +123,7 @@ func buildCongestionTree( } } - goToTreeScript := forceSplitCoinTapLeaf( + goToTreeScript := tree.BranchScript( leftKey, rightKey, leftOutput.Value, rightAmount, ) @@ -237,7 +135,7 @@ func buildCongestionTree( return nil, nil, 0, err } - return func(outpoint psetv2.InputArgs) (domain.CongestionTree, error) { + return func(outpoint psetv2.InputArgs) (tree.CongestionTree, error) { psets, err := nodes[0].psets(&psetArgs{ input: outpoint, taprootTree: taprootTree, @@ -253,7 +151,7 @@ func buildCongestionTree( } } - tree := make(domain.CongestionTree, maxLevel+1) + congestionTree := make(tree.CongestionTree, maxLevel+1) for _, psetWithLevel := range psets { utx, err := psetWithLevel.pset.UnsignedTx() @@ -270,7 +168,7 @@ func buildCongestionTree( parentTxid := chainhash.Hash(psetWithLevel.pset.Inputs[0].PreviousTxid).String() - tree[psetWithLevel.level] = append(tree[psetWithLevel.level], domain.Node{ + congestionTree[psetWithLevel.level] = append(congestionTree[psetWithLevel.level], tree.Node{ Txid: txid, Tx: psetB64, ParentTxid: parentTxid, @@ -278,7 +176,7 @@ func buildCongestionTree( }) } - return tree, nil + return congestionTree, nil }, outputScript, uint64(rightAmount) + leftOutput.Value + uint64(feeSatsPerNode), nil } @@ -349,7 +247,7 @@ func newBranch( } func (n *node) isLeaf() bool { - return n.left.isEmpty() && (n.right == nil || n.right.isEmpty()) + return (n.left == nil || n.left.isEmpty()) && (n.right == nil || n.right.isEmpty()) } // is it the final node of the tree @@ -398,7 +296,7 @@ func (n *node) taprootKey() (*secp256k1.PublicKey, *taproot.IndexedElementsTapSc return n._taprootKey, n._taprootTree, nil } - sweepTaprootLeaf, err := sweepTapLeaf(n.sweepKey) + sweepTaprootLeaf, err := tree.SweepScript(n.sweepKey, expirationTime) if err != nil { return nil, nil, err } @@ -414,7 +312,7 @@ func (n *node) taprootKey() (*secp256k1.PublicKey, *taproot.IndexedElementsTapSc return nil, nil, err } - vtxoLeaf, err := common.VtxoScript(pubkey) + vtxoLeaf, err := tree.VtxoScript(pubkey) if err != nil { return nil, nil, err } @@ -443,7 +341,7 @@ func (n *node) taprootKey() (*secp256k1.PublicKey, *taproot.IndexedElementsTapSc return nil, nil, err } - branchTaprootLeaf := forceSplitCoinTapLeaf( + branchTaprootLeaf := tree.BranchScript( leftKey, rightKey, n.left.amount(), n.right.amount(), ) @@ -565,10 +463,10 @@ func (n *node) psets(inputArgs *psetArgs, level int) ([]psetWithLevel, error) { } nodeResult := []psetWithLevel{ - {pset, level, n.isLeaf()}, + {pset, level, n.isLeaf() || (n.left.isEmpty() || n.right.isEmpty())}, } - if n.left.isEmpty() && (n.right == nil || n.right.isEmpty()) { + if n.isLeaf() { return nodeResult, nil } @@ -583,37 +481,46 @@ func (n *node) psets(inputArgs *psetArgs, level int) ([]psetWithLevel, error) { txID := unsignedTx.TxHash().String() - _, leftTaprootTree, err := n.left.taprootKey() - if err != nil { - return nil, err + if !n.left.isEmpty() { + _, leftTaprootTree, err := n.left.taprootKey() + if err != nil { + return nil, err + } + + psetsLeft, err := n.left.psets(&psetArgs{ + input: psetv2.InputArgs{ + Txid: txID, + TxIndex: 0, + }, + taprootTree: leftTaprootTree, + }, level+1) + if err != nil { + return nil, err + } + + nodeResult = append(nodeResult, psetsLeft...) } - psetsLeft, err := n.left.psets(&psetArgs{ - input: psetv2.InputArgs{ - Txid: txID, - TxIndex: 0, - }, - taprootTree: leftTaprootTree, - }, level+1) - if err != nil { - return nil, err + if !n.right.isEmpty() { + + _, rightTaprootTree, err := n.right.taprootKey() + if err != nil { + return nil, err + } + + psetsRight, err := n.right.psets(&psetArgs{ + input: psetv2.InputArgs{ + Txid: txID, + TxIndex: 1, + }, + taprootTree: rightTaprootTree, + }, level+1) + if err != nil { + return nil, err + } + + nodeResult = append(nodeResult, psetsRight...) } - _, rightTaprootTree, err := n.right.taprootKey() - if err != nil { - return nil, err - } - - psetsRight, err := n.right.psets(&psetArgs{ - input: psetv2.InputArgs{ - Txid: txID, - TxIndex: 1, - }, - taprootTree: rightTaprootTree, - }, level+1) - if err != nil { - return nil, err - } - - return append(nodeResult, append(psetsLeft, psetsRight...)...), nil + return nodeResult, nil } diff --git a/asp/internal/infrastructure/tx-builder/dummy/builder.go b/asp/internal/infrastructure/tx-builder/dummy/builder.go index b84e5f8..844df9d 100644 --- a/asp/internal/infrastructure/tx-builder/dummy/builder.go +++ b/asp/internal/infrastructure/tx-builder/dummy/builder.go @@ -4,6 +4,7 @@ import ( "context" "encoding/hex" + "github.com/ark-network/ark/common/tree" "github.com/ark-network/ark/internal/core/domain" "github.com/ark-network/ark/internal/core/ports" "github.com/decred/dcrd/dcrec/secp256k1/v4" @@ -92,7 +93,7 @@ func (b *txBuilder) BuildForfeitTxs( func (b *txBuilder) BuildPoolTx( aspPubkey *secp256k1.PublicKey, wallet ports.WalletService, payments []domain.Payment, minRelayFee uint64, -) (poolTx string, congestionTree domain.CongestionTree, err error) { +) (poolTx string, congestionTree tree.CongestionTree, err error) { aspScriptBytes, err := p2wpkhScript(aspPubkey, b.net) if err != nil { return "", nil, err diff --git a/asp/internal/infrastructure/tx-builder/dummy/tree.go b/asp/internal/infrastructure/tx-builder/dummy/tree.go index 0ac3b52..e1a1940 100644 --- a/asp/internal/infrastructure/tx-builder/dummy/tree.go +++ b/asp/internal/infrastructure/tx-builder/dummy/tree.go @@ -3,6 +3,7 @@ package txbuilder import ( "encoding/hex" + "github.com/ark-network/ark/common/tree" "github.com/ark-network/ark/internal/core/domain" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/decred/dcrd/dcrec/secp256k1/v4" @@ -63,7 +64,7 @@ func buildCongestionTree( net network.Network, poolTxID string, receivers []domain.Receiver, -) (congestionTree domain.CongestionTree, err error) { +) (congestionTree tree.CongestionTree, err error) { var nodes []*node for _, r := range receivers { @@ -92,7 +93,7 @@ func buildCongestionTree( } } - tree := make(domain.CongestionTree, maxLevel+1) + congestionTree = make(tree.CongestionTree, maxLevel+1) for _, psetWithLevel := range psets { utx, err := psetWithLevel.pset.UnsignedTx() @@ -109,7 +110,7 @@ func buildCongestionTree( parentTxid := chainhash.Hash(psetWithLevel.pset.Inputs[0].PreviousTxid).String() - tree[psetWithLevel.level] = append(tree[psetWithLevel.level], domain.Node{ + congestionTree[psetWithLevel.level] = append(congestionTree[psetWithLevel.level], tree.Node{ Txid: txid, Tx: psetB64, ParentTxid: parentTxid, @@ -117,7 +118,7 @@ func buildCongestionTree( }) } - return tree, nil + return congestionTree, nil } func createTreeLevel(nodes []*node) ([]*node, error) { diff --git a/asp/internal/interface/grpc/handlers/arkservice.go b/asp/internal/interface/grpc/handlers/arkservice.go index 2defd59..bf15c76 100644 --- a/asp/internal/interface/grpc/handlers/arkservice.go +++ b/asp/internal/interface/grpc/handlers/arkservice.go @@ -7,6 +7,7 @@ import ( arkv1 "github.com/ark-network/ark/api-spec/protobuf/gen/ark/v1" "github.com/ark-network/ark/common" + "github.com/ark-network/ark/common/tree" "github.com/ark-network/ark/internal/core/application" "github.com/ark-network/ark/internal/core/domain" "github.com/decred/dcrd/dcrec/secp256k1/v4" @@ -280,10 +281,10 @@ func (v vtxoList) toProto(hrp string, aspKey *secp256k1.PublicKey) []*arkv1.Vtxo return list } -// castCongestionTree converts a domain.CongestionTree to a repeated arkv1.TreeLevel -func castCongestionTree(tree domain.CongestionTree) *arkv1.Tree { - levels := make([]*arkv1.TreeLevel, 0, len(tree)) - for _, level := range tree { +// castCongestionTree converts a tree.CongestionTree to a repeated arkv1.TreeLevel +func castCongestionTree(congestionTree tree.CongestionTree) *arkv1.Tree { + levels := make([]*arkv1.TreeLevel, 0, len(congestionTree)) + for _, level := range congestionTree { levelProto := &arkv1.TreeLevel{ Nodes: make([]*arkv1.Node, 0, len(level)), } diff --git a/asp/internal/interface/grpc/handlers/utils.go b/asp/internal/interface/grpc/handlers/utils.go index 15623f1..02e04a6 100644 --- a/asp/internal/interface/grpc/handlers/utils.go +++ b/asp/internal/interface/grpc/handlers/utils.go @@ -18,7 +18,7 @@ func parseTxs(txs []string) ([]string, error) { } for _, tx := range txs { if _, err := psetv2.NewPsetFromBase64(tx); err != nil { - return nil, fmt.Errorf("invalid tx format %s", err) + return nil, fmt.Errorf("invalid tx format") } } return txs, nil diff --git a/common/tree.go b/common/tree.go deleted file mode 100644 index 8c894c8..0000000 --- a/common/tree.go +++ /dev/null @@ -1,23 +0,0 @@ -package common - -import ( - "github.com/btcsuite/btcd/btcec/v2/schnorr" - "github.com/btcsuite/btcd/txscript" - "github.com/decred/dcrd/dcrec/secp256k1/v4" - "github.com/vulpemventures/go-elements/taproot" -) - -func checksigScript(pubkey *secp256k1.PublicKey) ([]byte, error) { - key := schnorr.SerializePubKey(pubkey) - return txscript.NewScriptBuilder().AddData(key).AddOp(txscript.OP_CHECKSIG).Script() -} - -func VtxoScript(pubkey *secp256k1.PublicKey) (*taproot.TapElementsLeaf, error) { - script, err := checksigScript(pubkey) - if err != nil { - return nil, err - } - - tapLeaf := taproot.NewBaseTapElementsLeaf(script) - return &tapLeaf, nil -} diff --git a/common/tree/congestion_tree.go b/common/tree/congestion_tree.go new file mode 100644 index 0000000..3ca6438 --- /dev/null +++ b/common/tree/congestion_tree.go @@ -0,0 +1,97 @@ +package tree + +import "errors" + +// Node is a struct embedding the transaction and the parent txid of a congestion tree node +type Node struct { + Txid string + Tx string + ParentTxid string + Leaf bool +} + +var ( + ErrParentNotFound = errors.New("parent not found") + ErrLeafNotFound = errors.New("leaf not found in congestion tree") +) + +// CongestionTree is reprensented as a matrix of TreeNode struct +// the first level of the matrix is the root of the tree +type CongestionTree [][]Node + +// Leaves returns the leaves of the congestion tree (the vtxos txs) +func (c CongestionTree) Leaves() []Node { + leaves := c[len(c)-1] + for _, level := range c[:len(c)-1] { + for _, node := range level { + if node.Leaf { + leaves = append(leaves, node) + } + } + } + + return leaves +} + +// Children returns all the nodes that have the given node as parent +func (c CongestionTree) Children(nodeTxid string) []Node { + var children []Node + for _, level := range c { + for _, node := range level { + if node.ParentTxid == nodeTxid { + children = append(children, node) + } + } + } + + return children +} + +func (c CongestionTree) NumberOfNodes() int { + var count int + for _, level := range c { + count += len(level) + } + return count +} + +func (c CongestionTree) Branch(vtxoTxid string) ([]Node, error) { + branch := make([]Node, 0) + + leaves := c.Leaves() + // check if the vtxo is a leaf + found := false + for _, leaf := range leaves { + if leaf.Txid == vtxoTxid { + found = true + branch = append(branch, leaf) + break + } + } + if !found { + return nil, ErrLeafNotFound + } + + rootTxid := c[0][0].Txid + + for branch[0].Txid != rootTxid { + parent, err := branch[0].findParent(c) + if err != nil { + return nil, err + } + branch = append([]Node{parent}, branch...) + } + + return branch, nil +} + +func (n Node) findParent(tree CongestionTree) (Node, error) { + for _, level := range tree { + for _, node := range level { + if node.Txid == n.ParentTxid { + return node, nil + } + } + } + return Node{}, ErrParentNotFound +} diff --git a/common/tree/script.go b/common/tree/script.go new file mode 100644 index 0000000..775d4d2 --- /dev/null +++ b/common/tree/script.go @@ -0,0 +1,260 @@ +package tree + +import ( + "bytes" + "encoding/binary" + + "github.com/ark-network/ark/common" + "github.com/btcsuite/btcd/btcec/v2/schnorr" + "github.com/btcsuite/btcd/txscript" + "github.com/decred/dcrd/dcrec/secp256k1/v4" + "github.com/vulpemventures/go-elements/taproot" +) + +const ( + OP_INSPECTOUTPUTSCRIPTPUBKEY = 0xd1 + OP_INSPECTOUTPUTVALUE = 0xcf + OP_PUSHCURRENTINPUTINDEX = 0xcd +) + +// VtxoScript returns a simple checksig script for a given pubkey +func VtxoScript(pubkey *secp256k1.PublicKey) (*taproot.TapElementsLeaf, error) { + script, err := checksigScript(pubkey) + if err != nil { + return nil, err + } + + tapLeaf := taproot.NewBaseTapElementsLeaf(script) + return &tapLeaf, nil +} + +// SweepScript returns a taproot leaf letting the owner of the key to spend the output after a given timeDelta +func SweepScript(sweepKey *secp256k1.PublicKey, seconds uint) (*taproot.TapElementsLeaf, error) { + sweepScript, err := csvChecksigScript(sweepKey, seconds) + if err != nil { + return nil, err + } + + tapLeaf := taproot.NewBaseTapElementsLeaf(sweepScript) + return &tapLeaf, nil +} + +// BranchScript returns a taproot leaf that will split the coin in two outputs +// each output (left and right) will have the given amount and the given taproot key as witness program +func BranchScript( + leftKey, rightKey *secp256k1.PublicKey, leftAmount, rightAmount uint64, +) taproot.TapElementsLeaf { + nextScriptLeft := withOutput(txscript.OP_0, schnorr.SerializePubKey(leftKey), leftAmount, rightKey != nil) + branchScript := append([]byte{}, nextScriptLeft...) + if rightKey != nil { + nextScriptRight := withOutput(txscript.OP_1, schnorr.SerializePubKey(rightKey), rightAmount, false) + branchScript = append(branchScript, nextScriptRight...) + } + return taproot.NewBaseTapElementsLeaf(branchScript) +} + +func decodeBranchScript(script []byte) (valid bool, leftKey, rightKey *secp256k1.PublicKey, leftAmount, rightAmount uint64, err error) { + if len(script) != 52 && len(script) != 104 { + return false, nil, nil, 0, 0, nil + } + + isLeftOnly := len(script) == 52 + + validLeft, leftKey, leftAmount, err := decodeWithOutputScript(script[:52], txscript.OP_0, !isLeftOnly) + if err != nil { + return false, nil, nil, 0, 0, err + } + + if !validLeft { + return false, nil, nil, 0, 0, nil + } + + if isLeftOnly { + return true, leftKey, nil, leftAmount, 0, nil + } + + validRight, rightKey, rightAmount, err := decodeWithOutputScript(script[52:], txscript.OP_1, false) + if err != nil { + return false, nil, nil, 0, 0, err + } + + if !validRight { + return false, nil, nil, 0, 0, nil + } + + rebuilt := BranchScript(leftKey, rightKey, leftAmount, rightAmount) + + if !bytes.Equal(rebuilt.Script, script) { + return false, nil, nil, 0, 0, nil + } + + return true, leftKey, rightKey, leftAmount, rightAmount, nil +} + +func decodeWithOutputScript(script []byte, expectedIndex byte, isVerify bool) (valid bool, pubkey *secp256k1.PublicKey, amount uint64, err error) { + if len(script) != 52 { + return false, nil, 0, nil + } + + if script[0] != expectedIndex { + return false, nil, 0, nil + } + + // 32 bytes for the witness program + pubkey, err = schnorr.ParsePubKey(script[5 : 5+32]) + if err != nil { + return false, nil, 0, err + } + + inspectOutputValueIndex := bytes.IndexByte(script, OP_INSPECTOUTPUTVALUE) + if inspectOutputValueIndex == -1 { + return false, nil, 0, nil + } + + if script[inspectOutputValueIndex-1] != expectedIndex { + return false, nil, 0, nil + } + + // 8 bytes for the amount + amountBytes := script[len(script)-9 : len(script)-1] + amount = binary.LittleEndian.Uint64(amountBytes) + + rebuilt := withOutput(expectedIndex, schnorr.SerializePubKey(pubkey), amount, isVerify) + if !bytes.Equal(rebuilt, script) { + return false, nil, 0, nil + } + + return true, pubkey, amount, nil +} + +func decodeChecksigScript(script []byte) (valid bool, pubkey *secp256k1.PublicKey, err error) { + checksigIndex := bytes.Index(script, []byte{txscript.OP_CHECKSIG}) + if checksigIndex == -1 || checksigIndex == 0 { + return false, nil, nil + } + + key := script[1:checksigIndex] + if len(key) != 32 { + return false, nil, nil + } + + pubkey, err = schnorr.ParsePubKey(key) + if err != nil { + return false, nil, err + } + + rebuilt, err := checksigScript(pubkey) + if err != nil { + return false, nil, err + } + + if !bytes.Equal(rebuilt, script) { + return false, nil, nil + } + + return true, pubkey, nil +} + +func decodeSweepScript(script []byte) (valid bool, aspPubKey *secp256k1.PublicKey, seconds uint, err error) { + csvIndex := bytes.Index(script, []byte{txscript.OP_CHECKSEQUENCEVERIFY, txscript.OP_DROP}) + if csvIndex == -1 || csvIndex == 0 { + return false, nil, 0, nil + } + + sequence := script[:csvIndex] + + seconds, err = common.BIP68Decode(sequence) + if err != nil { + return false, nil, 0, err + } + + checksigScript := script[csvIndex+2:] + valid, aspPubKey, err = decodeChecksigScript(checksigScript) + if err != nil { + return false, nil, 0, err + } + + rebuilt, err := csvChecksigScript(aspPubKey, seconds) + if err != nil { + return false, nil, 0, err + } + + if !bytes.Equal(rebuilt, script) { + return false, nil, 0, nil + } + + return valid, aspPubKey, seconds, nil +} + +// checkSequenceVerifyScript without checksig +func checkSequenceVerifyScript(seconds uint) ([]byte, error) { + sequence, err := common.BIP68Encode(seconds) + if err != nil { + return nil, err + } + + return append(sequence, []byte{ + txscript.OP_CHECKSEQUENCEVERIFY, + txscript.OP_DROP, + }...), nil +} + +// checkSequenceVerifyScript + checksig +func csvChecksigScript(pubkey *secp256k1.PublicKey, seconds uint) ([]byte, error) { + script, err := checksigScript(pubkey) + if err != nil { + return nil, err + } + + csvScript, err := checkSequenceVerifyScript(seconds) + if err != nil { + return nil, err + } + + return append(csvScript, script...), nil +} + +func checksigScript(pubkey *secp256k1.PublicKey) ([]byte, error) { + key := schnorr.SerializePubKey(pubkey) + return txscript.NewScriptBuilder().AddData(key).AddOp(txscript.OP_CHECKSIG).Script() +} + +// withOutput returns an introspection script that checks the script and the amount of the output at the given index +// verify will add an OP_EQUALVERIFY at the end of the script, otherwise it will add an OP_EQUAL +// length = 52 bytes +func withOutput(index byte, taprootWitnessProgram []byte, amount uint64, verify bool) []byte { + amountBuffer := make([]byte, 8) + binary.LittleEndian.PutUint64(amountBuffer, amount) + + script := []byte{ + index, + OP_INSPECTOUTPUTSCRIPTPUBKEY, + txscript.OP_1, + txscript.OP_EQUALVERIFY, + txscript.OP_DATA_32, + } + + script = append(script, taprootWitnessProgram...) + script = append(script, []byte{ + txscript.OP_EQUALVERIFY, + }...) + script = append(script, index) + script = append(script, []byte{ + OP_INSPECTOUTPUTVALUE, + txscript.OP_1, + txscript.OP_EQUALVERIFY, + txscript.OP_DATA_8, + }...) + script = append(script, amountBuffer...) + if verify { + script = append(script, []byte{ + txscript.OP_EQUALVERIFY, + }...) + } else { + script = append(script, []byte{ + txscript.OP_EQUAL, + }...) + } + + return script +} diff --git a/common/tree/validation.go b/common/tree/validation.go new file mode 100644 index 0000000..a9677ea --- /dev/null +++ b/common/tree/validation.go @@ -0,0 +1,306 @@ +package tree + +import ( + "bytes" + "encoding/hex" + "errors" + "fmt" + + "github.com/btcsuite/btcd/btcec/v2/schnorr" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/decred/dcrd/dcrec/secp256k1/v4" + "github.com/vulpemventures/go-elements/elementsutil" + "github.com/vulpemventures/go-elements/psetv2" + "github.com/vulpemventures/go-elements/taproot" + "github.com/vulpemventures/go-elements/transaction" +) + +var ( + ErrInvalidPoolTransaction = errors.New("invalid pool transaction") + ErrEmptyTree = errors.New("empty congestion tree") + ErrInvalidRootLevel = errors.New("root level must have only one node") + ErrNoLeaves = errors.New("no leaves in the tree") + ErrNodeTransactionEmpty = errors.New("node transaction is empty") + ErrNodeTxidEmpty = errors.New("node txid is empty") + ErrNodeParentTxidEmpty = errors.New("node parent txid is empty") + ErrNodeTxidDifferent = errors.New("node txid differs from node transaction") + ErrNumberOfInputs = errors.New("node transaction should have only one input") + ErrNumberOfOutputs = errors.New("node transaction should have only three or two outputs") + ErrParentTxidInput = errors.New("parent txid should be the input of the node transaction") + ErrNumberOfChildren = errors.New("node branch transaction should have two children") + ErrLeafChildren = errors.New("leaf node should have max 1 child") + ErrInvalidChildTxid = errors.New("invalid child txid") + ErrNumberOfTapscripts = errors.New("input should have two tapscripts leaves") + ErrInternalKey = errors.New("taproot internal key is not unspendable") + ErrInvalidTaprootScript = errors.New("invalid taproot script") + ErrInvalidLeafTaprootScript = errors.New("invalid leaf taproot script") + ErrInvalidAmount = errors.New("children amount is different from parent amount") + ErrInvalidAsset = errors.New("invalid output asset") + ErrInvalidSweepSequence = errors.New("invalid sweep sequence") + ErrInvalidASP = errors.New("invalid ASP") + ErrMissingFeeOutput = errors.New("missing fee output") + ErrInvalidLeftOutput = errors.New("invalid left output") + ErrInvalidRightOutput = errors.New("invalid right output") + ErrMissingSweepTapscript = errors.New("missing sweep tapscript") + ErrMissingBranchTapscript = errors.New("missing branch tapscript") + ErrInvalidLeaf = errors.New("leaf node shouldn't have children") + ErrWrongPoolTxID = errors.New("root input should be the pool tx outpoint") +) + +const ( + UnspendablePoint = "0250929b74c1a04954b78b4b6035e97a5e078a5a0f28ec96d547bfee9ace803ac0" + sharedOutputIndex = 0 +) + +// ValidateCongestionTree checks if the given congestion tree is valid +// poolTxID & poolTxIndex & poolTxAmount are used to validate the root input outpoint +// aspPublicKey & roundLifetimeSeconds are used to validate the sweep tapscript leaves +// besides that, the function validates: +// - the number of nodes +// - the number of leaves +// - children coherence with parent +// - every control block and taproot output scripts +// - input and output amounts +func ValidateCongestionTree( + tree CongestionTree, + poolTxHex string, + aspPublicKey *secp256k1.PublicKey, + roundLifetimeSeconds uint, +) error { + unspendableKeyBytes, _ := hex.DecodeString(UnspendablePoint) + unspendableKey, _ := secp256k1.ParsePubKey(unspendableKeyBytes) + + poolTransaction, err := transaction.NewTxFromHex(poolTxHex) + if err != nil { + return ErrInvalidPoolTransaction + } + + poolTxAmount, err := elementsutil.ValueFromBytes(poolTransaction.Outputs[sharedOutputIndex].Value) + if err != nil { + return ErrInvalidPoolTransaction + } + + poolTxID := poolTransaction.TxHash().String() + + nbNodes := tree.NumberOfNodes() + if nbNodes == 0 { + return ErrEmptyTree + } + + if len(tree[0]) != 1 { + return ErrInvalidRootLevel + } + + // check that root input is connected to the pool tx + rootPsetB64 := tree[0][0].Tx + rootPset, err := psetv2.NewPsetFromBase64(rootPsetB64) + if err != nil { + return fmt.Errorf("invalid root transaction: %w", err) + } + + if len(rootPset.Inputs) != 1 { + return ErrNumberOfInputs + } + + rootInput := rootPset.Inputs[0] + if chainhash.Hash(rootInput.PreviousTxid).String() != poolTxID || rootInput.PreviousTxIndex != sharedOutputIndex { + return ErrWrongPoolTxID + } + + sumRootValue := uint64(0) + for _, output := range rootPset.Outputs { + sumRootValue += output.Value + } + + if sumRootValue != poolTxAmount { + return ErrInvalidAmount + } + + if len(tree.Leaves()) == 0 { + return ErrNoLeaves + } + + // iterates over all the nodes of the tree + for _, level := range tree { + for _, node := range level { + if err := validateNodeTransaction(node, tree, unspendableKey, aspPublicKey, roundLifetimeSeconds); err != nil { + return err + } + } + } + + return nil +} + +func validateNodeTransaction( + node Node, + tree CongestionTree, + expectedInternalKey, + expectedPublicKeyASP *secp256k1.PublicKey, + expectedSequenceSeconds uint, +) error { + if node.Tx == "" { + return ErrNodeTransactionEmpty + } + + if node.Txid == "" { + return ErrNodeTxidEmpty + } + + if node.ParentTxid == "" { + return ErrNodeParentTxidEmpty + } + + decodedPset, err := psetv2.NewPsetFromBase64(node.Tx) + if err != nil { + return fmt.Errorf("invalid node transaction: %w", err) + } + + utx, err := decodedPset.UnsignedTx() + if err != nil { + return fmt.Errorf("invalid node transaction: %w", err) + } + + if utx.TxHash().String() != node.Txid { + return ErrNodeTxidDifferent + } + + if len(decodedPset.Inputs) != 1 { + return ErrNumberOfInputs + } + + input := decodedPset.Inputs[0] + if len(input.TapLeafScript) != 2 { + return ErrNumberOfTapscripts + } + + if chainhash.Hash(decodedPset.Inputs[0].PreviousTxid).String() != node.ParentTxid { + return ErrParentTxidInput + } + + feeOutput := decodedPset.Outputs[len(decodedPset.Outputs)-1] + if len(feeOutput.Script) != 0 { + return ErrMissingFeeOutput + } + + children := tree.Children(node.Txid) + + if node.Leaf && len(children) > 1 { + return ErrLeafChildren + } + + for childIndex, child := range children { + childTx, err := psetv2.NewPsetFromBase64(child.Tx) + if err != nil { + return fmt.Errorf("invalid child transaction: %w", err) + } + + parentOutput := decodedPset.Outputs[childIndex] + previousScriptKey := parentOutput.Script[2:] + if len(previousScriptKey) != 32 { + return ErrInvalidTaprootScript + } + + sweepLeafFound := false + branchLeafFound := false + + for _, tapLeaf := range childTx.Inputs[0].TapLeafScript { + key := tapLeaf.ControlBlock.InternalKey + if !key.IsEqual(expectedInternalKey) { + return ErrInternalKey + } + + rootHash := tapLeaf.ControlBlock.RootHash(tapLeaf.Script) + outputScript := taproot.ComputeTaprootOutputKey(key, rootHash) + + if !bytes.Equal(schnorr.SerializePubKey(outputScript), previousScriptKey) { + return ErrInvalidTaprootScript + } + + isSweepLeaf, aspKey, seconds, err := decodeSweepScript(tapLeaf.Script) + if err != nil { + return fmt.Errorf("invalid sweep script: %w", err) + } + + if isSweepLeaf { + if !aspKey.IsEqual(aspKey) { + return ErrInvalidASP + } + + if seconds != expectedSequenceSeconds { + return ErrInvalidSweepSequence + } + + sweepLeafFound = true + continue + } + + isBranchLeaf, leftKey, rightKey, leftAmount, rightAmount, err := decodeBranchScript(tapLeaf.Script) + if err != nil { + return fmt.Errorf("invalid vtxo script: %w", err) + } + + if isBranchLeaf { + branchLeafFound = true + + // check outputs + nbOuts := len(childTx.Outputs) + if leftKey != nil && rightKey != nil { + if nbOuts != 3 { + return ErrNumberOfOutputs + } + } else { + if nbOuts != 2 { + return ErrNumberOfOutputs + } + } + + leftWitnessProgram := childTx.Outputs[0].Script[2:] + leftOutputAmount := childTx.Outputs[0].Value + + if !bytes.Equal(leftWitnessProgram, schnorr.SerializePubKey(leftKey)) { + return ErrInvalidLeftOutput + } + + if leftAmount != leftOutputAmount { + return ErrInvalidLeftOutput + } + + if rightKey != nil { + rightWitnessProgram := childTx.Outputs[1].Script[2:] + rightOutputAmount := childTx.Outputs[1].Value + + if !bytes.Equal(rightWitnessProgram, schnorr.SerializePubKey(rightKey)) { + return ErrInvalidRightOutput + } + + if rightAmount != rightOutputAmount { + return ErrInvalidRightOutput + } + } + } + } + + if !sweepLeafFound { + return ErrMissingSweepTapscript + } + + if !branchLeafFound { + return ErrMissingBranchTapscript + } + + sumChildAmount := uint64(0) + for _, output := range childTx.Outputs { + sumChildAmount += output.Value + if !bytes.Equal(output.Asset, parentOutput.Asset) { + return ErrInvalidAsset + } + } + + if sumChildAmount != parentOutput.Value { + return ErrInvalidAmount + } + } + + return nil +} diff --git a/noah/common.go b/noah/common.go index 0e41b1a..3dcf108 100644 --- a/noah/common.go +++ b/noah/common.go @@ -13,12 +13,18 @@ import ( arkv1 "github.com/ark-network/ark/api-spec/protobuf/gen/ark/v1" "github.com/ark-network/ark/common" + "github.com/ark-network/ark/common/tree" + "github.com/btcsuite/btcd/btcec/v2/schnorr" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/decred/dcrd/dcrec/secp256k1/v4" "github.com/urfave/cli/v2" + "github.com/vulpemventures/go-elements/address" + "github.com/vulpemventures/go-elements/elementsutil" "github.com/vulpemventures/go-elements/network" "github.com/vulpemventures/go-elements/payment" "github.com/vulpemventures/go-elements/psetv2" + "github.com/vulpemventures/go-elements/taproot" + "github.com/vulpemventures/go-elements/transaction" "golang.org/x/term" ) @@ -344,6 +350,7 @@ func handleRoundStream( paymentID string, vtxosToSign []vtxo, secKey *secp256k1.PrivateKey, + receivers []*arkv1.Output, ) (poolTxID string, err error) { stream, err := client.GetEventStream(ctx.Context, &arkv1.GetEventStreamRequest{}) if err != nil { @@ -377,6 +384,121 @@ func handleRoundStream( if event.GetRoundFinalization() != nil { // stop pinging as soon as we receive some forfeit txs pingStop() + + poolPartialTx := event.GetRoundFinalization().GetPoolPartialTx() + poolTransaction, err := transaction.NewTxFromHex(poolPartialTx) + if err != nil { + return "", err + } + + congestionTree, err := toCongestionTree(event.GetRoundFinalization().GetCongestionTree()) + if err != nil { + return "", err + } + + aspPublicKey, err := getServiceProviderPublicKey() + if err != nil { + return "", err + } + + // validate the congestion tree + if err := tree.ValidateCongestionTree( + congestionTree, + poolPartialTx, + aspPublicKey, + 1209344, // ~ 2 weeks + ); err != nil { + return "", err + } + + // validate the receivers + sweepLeaf, err := tree.SweepScript(aspPublicKey, 1209344) + if err != nil { + return "", err + } + + for _, receiver := range receivers { + isOnChain, onchainScript, userPubKey, err := decodeReceiverAddress(receiver.Address) + if err != nil { + return "", err + } + + if isOnChain { + // collaborative exit case + // search for the output in the pool tx + found := false + for _, output := range poolTransaction.Outputs { + if bytes.Equal(output.Script, onchainScript) { + outputValue, err := elementsutil.ValueFromBytes(output.Value) + if err != nil { + return "", err + } + + if outputValue != receiver.Amount { + return "", fmt.Errorf("invalid collaborative exit output amount: got %d, want %d", outputValue, receiver.Amount) + } + + found = true + break + } + } + + if !found { + return "", fmt.Errorf("collaborative exit output not found: %s", receiver.Address) + } + + continue + } + + // off-chain send case + // search for the output in congestion tree + found := false + + // compute the receiver output taproot key + vtxoScript, err := tree.VtxoScript(userPubKey) + if err != nil { + return "", err + } + + vtxoTaprootTree := taproot.AssembleTaprootScriptTree(*vtxoScript, *sweepLeaf) + root := vtxoTaprootTree.RootNode.TapHash() + unspendableKeyBytes, _ := hex.DecodeString(tree.UnspendablePoint) + unspendableKey, _ := secp256k1.ParsePubKey(unspendableKeyBytes) + vtxoTaprootKey := schnorr.SerializePubKey(taproot.ComputeTaprootOutputKey(unspendableKey, root[:])) + + leaves := congestionTree.Leaves() + for _, leaf := range leaves { + tx, err := psetv2.NewPsetFromBase64(leaf.Tx) + if err != nil { + return "", err + } + + for _, output := range tx.Outputs { + if len(output.Script) == 0 { + continue + } + if bytes.Equal(output.Script[2:], vtxoTaprootKey) { + if output.Value != receiver.Amount { + continue + } + + found = true + break + } + } + + if found { + break + } + } + + if !found { + return "", fmt.Errorf("off-chain send output not found: %s", receiver.Address) + } + } + + fmt.Println("congestion tree validated") + forfeits := event.GetRoundFinalization().GetForfeitTxs() signedForfeits := make([]string, 0) @@ -459,3 +581,50 @@ func ping(ctx *cli.Context, client arkv1.ArkServiceClient, req *arkv1.PingReques return ticker.Stop } + +func toCongestionTree(treeFromProto *arkv1.Tree) (tree.CongestionTree, error) { + levels := make(tree.CongestionTree, 0, len(treeFromProto.Levels)) + + for _, level := range treeFromProto.Levels { + nodes := make([]tree.Node, 0, len(level.Nodes)) + + for _, node := range level.Nodes { + nodes = append(nodes, tree.Node{ + Txid: node.Txid, + Tx: node.Tx, + ParentTxid: node.ParentTxid, + Leaf: false, + }) + } + + levels = append(levels, nodes) + } + + for j, treeLvl := range levels { + for i, node := range treeLvl { + if len(levels.Children(node.Txid)) < 2 { + levels[j][i].Leaf = true + } + } + } + + return levels, nil +} + +func decodeReceiverAddress(addr string) ( + isOnChainAddress bool, + onchainScript []byte, + userPubKey *secp256k1.PublicKey, + err error, +) { + outputScript, err := address.ToOutputScript(addr) + if err != nil { + _, userPubKey, _, err = common.DecodeAddress(addr) + if err != nil { + return + } + return false, nil, userPubKey, nil + } + + return true, outputScript, nil, nil +} diff --git a/noah/explorer.go b/noah/explorer.go index 74b0190..d76cf8c 100644 --- a/noah/explorer.go +++ b/noah/explorer.go @@ -49,6 +49,7 @@ func (e *explorer) Broadcast(txHex string) (string, error) { if strings.Contains(strings.ToLower(err.Error()), "transaction already in block chain") { return txid, nil } + return "", err } diff --git a/noah/redeem.go b/noah/redeem.go index dc502a3..cc21abf 100644 --- a/noah/redeem.go +++ b/noah/redeem.go @@ -10,6 +10,7 @@ import ( "time" arkv1 "github.com/ark-network/ark/api-spec/protobuf/gen/ark/v1" + "github.com/ark-network/ark/common/tree" "github.com/urfave/cli/v2" "github.com/vulpemventures/go-elements/address" "github.com/vulpemventures/go-elements/psetv2" @@ -156,6 +157,7 @@ func collaborativeRedeem(ctx *cli.Context, addr string, amount uint64) error { registerResponse.GetId(), selectedCoins, secKey, + receivers, ) if err != nil { return err @@ -212,10 +214,20 @@ func unilateralRedeem(ctx *cli.Context, addr string) error { return err } - congestionTrees := make(map[string]*arkv1.Tree, 0) + congestionTrees := make(map[string]tree.CongestionTree, 0) transactionsMap := make(map[string]struct{}, 0) transactions := make([]string, 0) + aspPublicKey, err := getServiceProviderPublicKey() + if err != nil { + return err + } + + sweepLeaf, err := tree.SweepScript(aspPublicKey, 1209344) + if err != nil { + return err + } + for _, vtxo := range vtxos { if _, ok := congestionTrees[vtxo.poolTxid]; !ok { round, err := client.GetRound(ctx.Context, &arkv1.GetRoundRequest{ @@ -225,10 +237,16 @@ func unilateralRedeem(ctx *cli.Context, addr string) error { return err } - congestionTrees[vtxo.poolTxid] = round.GetRound().GetCongestionTree() + treeFromRound := round.GetRound().GetCongestionTree() + congestionTree, err := toCongestionTree(treeFromRound) + if err != nil { + return err + } + + congestionTrees[vtxo.poolTxid] = congestionTree } - redeemBranch, err := newRedeemBranch(ctx, congestionTrees[vtxo.poolTxid], vtxo) + redeemBranch, err := newRedeemBranch(ctx, congestionTrees[vtxo.poolTxid], vtxo, sweepLeaf) if err != nil { return err } @@ -277,7 +295,7 @@ func unilateralRedeem(ctx *cli.Context, addr string) error { } vBytes := utx.VirtualSize() - feeAmount := uint64(math.Ceil(float64(vBytes) * 0.2)) + feeAmount := uint64(math.Ceil(float64(vBytes) * 0.25)) if totalVtxosAmount-feeAmount <= 0 { return fmt.Errorf("not enough VTXOs to pay the fees (%d sats), aborting unilateral exit", feeAmount) @@ -305,7 +323,7 @@ func unilateralRedeem(ctx *cli.Context, addr string) error { for { txid, err := explorer.Broadcast(txHex) if err != nil { - if strings.Contains(err.Error(), "bad-txns-inputs-missingorspent") { + if strings.Contains(strings.ToLower(err.Error()), "bad-txns-inputs-missingorspent") { time.Sleep(1 * time.Second) } else { return err @@ -341,13 +359,21 @@ func unilateralRedeem(ctx *cli.Context, addr string) error { return err } - id, err := explorer.Broadcast(hex) - if err != nil { - return err + for { + id, err := explorer.Broadcast(hex) + if err != nil { + if strings.Contains(strings.ToLower(err.Error()), "bad-txns-inputs-missingorspent") { + time.Sleep(1 * time.Second) + continue + } + return err + } + if id != "" { + fmt.Printf("(final) redeem tx %s\n", id) + break + } } - fmt.Printf("(final) redeem tx %s\n", id) - return nil } diff --git a/noah/send.go b/noah/send.go index 721a933..400c424 100644 --- a/noah/send.go +++ b/noah/send.go @@ -155,6 +155,7 @@ func sendAction(ctx *cli.Context) error { registerResponse.GetId(), selectedCoins, secKey, + receiversOutput, ) if err != nil { return err diff --git a/noah/signer.go b/noah/signer.go index 64ff1c5..bc37da1 100644 --- a/noah/signer.go +++ b/noah/signer.go @@ -4,7 +4,7 @@ import ( "bytes" "fmt" - "github.com/ark-network/ark/common" + "github.com/ark-network/ark/common/tree" "github.com/btcsuite/btcd/btcec/v2/ecdsa" "github.com/btcsuite/btcd/btcec/v2/schnorr" "github.com/btcsuite/btcd/chaincfg/chainhash" @@ -127,7 +127,7 @@ func signPset( pubkey := prvKey.PubKey() - vtxoLeaf, err := common.VtxoScript(pubkey) + vtxoLeaf, err := tree.VtxoScript(pubkey) if err != nil { return err } diff --git a/noah/unilateral_redeem.go b/noah/unilateral_redeem.go index eeed95b..05c7515 100644 --- a/noah/unilateral_redeem.go +++ b/noah/unilateral_redeem.go @@ -4,10 +4,8 @@ import ( "bytes" "fmt" - arkv1 "github.com/ark-network/ark/api-spec/protobuf/gen/ark/v1" - "github.com/ark-network/ark/common" + "github.com/ark-network/ark/common/tree" "github.com/btcsuite/btcd/btcec/v2/schnorr" - "github.com/btcsuite/btcd/txscript" "github.com/decred/dcrd/dcrec/secp256k1/v4" "github.com/urfave/cli/v2" "github.com/vulpemventures/go-elements/psetv2" @@ -30,48 +28,33 @@ type redeemBranch struct { internalKey *secp256k1.PublicKey } -func newRedeemBranch(ctx *cli.Context, tree *arkv1.Tree, vtxo vtxo) (RedeemBranch, error) { - for _, level := range tree.Levels { - for _, node := range level.Nodes { - if node.Txid == vtxo.txid { - nodes, err := findParents([]*arkv1.Node{node}, tree) - if err != nil { - return nil, err - } - - branch := make([]*psetv2.Pset, 0, len(nodes)) - for _, node := range nodes { - pset, err := psetv2.NewPsetFromBase64(node.Tx) - if err != nil { - return nil, err - } - branch = append(branch, pset) - } - - // find sweep tap leaf - sweepTapLeaf, err := findSweepLeafScript(branch[0].Inputs[0].TapLeafScript) - if err != nil { - return nil, err - } - - xOnlyKey := branch[0].Inputs[0].TapInternalKey - internalKey, err := schnorr.ParsePubKey(xOnlyKey) - if err != nil { - return nil, err - } - - return &redeemBranch{ - vtxo: &vtxo, - branch: branch, - sweepTapLeaf: sweepTapLeaf, - internalKey: internalKey, - }, nil - - } - } +func newRedeemBranch(ctx *cli.Context, congestionTree tree.CongestionTree, vtxo vtxo, sweepLeaf *taproot.TapElementsLeaf) (RedeemBranch, error) { + nodes, err := congestionTree.Branch(vtxo.txid) + if err != nil { + return nil, err } - return nil, fmt.Errorf("vtxo not found") + branch := make([]*psetv2.Pset, 0, len(nodes)) + for _, node := range nodes { + pset, err := psetv2.NewPsetFromBase64(node.Tx) + if err != nil { + return nil, err + } + branch = append(branch, pset) + } + + xOnlyKey := branch[0].Inputs[0].TapInternalKey + internalKey, err := schnorr.ParsePubKey(xOnlyKey) + if err != nil { + return nil, err + } + + return &redeemBranch{ + vtxo: &vtxo, + branch: branch, + sweepTapLeaf: sweepLeaf, + internalKey: internalKey, + }, nil } // UpdatePath checks for transactions of the branch onchain and updates the branch accordingly @@ -169,7 +152,7 @@ func (r *redeemBranch) AddVtxoInput(updater *psetv2.Updater) error { } // add taproot tree letting to spend the vtxo - checksigLeaf, err := common.VtxoScript(walletPubkey) + checksigLeaf, err := tree.VtxoScript(walletPubkey) if err != nil { return nil } @@ -193,45 +176,3 @@ func (r *redeemBranch) AddVtxoInput(updater *psetv2.Updater) error { return nil } - -// findParents is a recursive function that finds all the parents of a VTXO in a congestion tree -// it returns the branch of the tree letting to redeem the VTXO (from pool tx to leaf) -func findParents(ls []*arkv1.Node, tree *arkv1.Tree) ([]*arkv1.Node, error) { - if len(ls) == 0 { - return nil, fmt.Errorf("empty list") - } - - for levelIndex, level := range tree.Levels { - for _, node := range level.Nodes { - if node.Txid == ls[0].ParentTxid { - newTree := &arkv1.Tree{ - Levels: tree.Levels[:levelIndex], - } - - newList := append([]*arkv1.Node{node}, ls...) - if len(newTree.Levels) > 0 { - return findParents(newList, newTree) - } - - return newList, nil - } - } - } - return nil, fmt.Errorf("parent not found") -} - -// findSweepLeafScript finds the sweep leaf in a set of tap leaf scripts -func findSweepLeafScript(leaves []psetv2.TapLeafScript) (*taproot.TapElementsLeaf, error) { - for _, leaf := range leaves { - if len(leaf.Script) == 0 { - continue - } - - if bytes.Contains(leaf.Script, []byte{txscript.OP_CHECKSIG}) && bytes.Contains(leaf.Script, []byte{txscript.OP_CHECKSEQUENCEVERIFY}) { - tapLeaf := taproot.NewBaseTapElementsLeaf(leaf.Script) - return &tapLeaf, nil - } - - } - return nil, fmt.Errorf("sweep leaf not found") -}