diff --git a/go.mod b/go.mod index 6590f71..4bc805a 100644 --- a/go.mod +++ b/go.mod @@ -11,8 +11,10 @@ require ( github.com/google/uuid v1.3.1 github.com/sirupsen/logrus v1.9.3 github.com/spf13/viper v1.17.0 + github.com/stretchr/testify v1.8.4 github.com/urfave/cli/v2 v2.25.7 github.com/vulpemventures/go-elements v0.4.7 + golang.org/x/crypto v0.14.0 golang.org/x/term v0.13.0 google.golang.org/grpc v1.59.0 google.golang.org/protobuf v1.31.0 @@ -22,9 +24,10 @@ require ( github.com/btcsuite/btcd v0.23.1 // indirect github.com/btcsuite/btcd/btcutil v1.1.3 // indirect github.com/btcsuite/btcd/btcutil/psbt v1.1.4 // indirect - github.com/btcsuite/btcd/chaincfg/chainhash v1.0.1 // indirect + github.com/btcsuite/btcd/chaincfg/chainhash v1.0.1 github.com/btcsuite/btclog v0.0.0-20170628155309-84c8d2346e9f // indirect github.com/cpuguy83/go-md2man/v2 v2.0.2 // indirect + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/decred/dcrd/crypto/blake256 v1.0.1 // indirect github.com/fsnotify/fsnotify v1.6.0 // indirect github.com/golang/protobuf v1.5.3 // indirect @@ -33,6 +36,7 @@ require ( github.com/magiconair/properties v1.8.7 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/pelletier/go-toml/v2 v2.1.0 // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect github.com/sagikazarmark/locafero v0.3.0 // indirect github.com/sagikazarmark/slog-shim v0.1.0 // indirect @@ -45,7 +49,6 @@ require ( github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect go.uber.org/atomic v1.9.0 // indirect go.uber.org/multierr v1.9.0 // indirect - golang.org/x/crypto v0.14.0 // indirect golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect golang.org/x/net v0.17.0 // indirect golang.org/x/sys v0.13.0 // indirect diff --git a/internal/infrastructure/tx-builder/dummy/builder.go b/internal/infrastructure/tx-builder/dummy/builder.go new file mode 100644 index 0000000..8f4f720 --- /dev/null +++ b/internal/infrastructure/tx-builder/dummy/builder.go @@ -0,0 +1,223 @@ +package txbuilder + +import ( + "context" + "encoding/hex" + + "github.com/ark-network/ark/common" + "github.com/ark-network/ark/internal/core/domain" + "github.com/ark-network/ark/internal/core/ports" + "github.com/decred/dcrd/dcrec/secp256k1/v4" + "github.com/vulpemventures/go-elements/network" + "github.com/vulpemventures/go-elements/psetv2" +) + +const ( + connectorAmount = 450 +) + +type txBuilder struct { + net *network.Network + aspPublicKey *secp256k1.PublicKey +} + +func toElementsNetwork(net common.Network) *network.Network { + switch net { + case common.MainNet: + return &network.Liquid + case common.TestNet: + return &network.Testnet + default: + return nil + } +} + +func NewTxBuilder(aspPublicKey *secp256k1.PublicKey, net common.Network) ports.TxBuilder { + return &txBuilder{ + aspPublicKey: aspPublicKey, + net: toElementsNetwork(net), + } +} + +// BuildCongestionTree implements ports.TxBuilder. +func (b *txBuilder) BuildCongestionTree(poolTx string, payments []domain.Payment) (congestionTree []string, err error) { + poolTxID, err := getTxID(poolTx) + if err != nil { + return nil, err + } + + receivers := receiversFromPayments(payments) + + return buildCongestionTree( + newOutputScriptFactory(b.aspPublicKey, b.net), + b.net, + poolTxID, + receivers, + ) +} + +// BuildForfeitTxs implements ports.TxBuilder. +func (b *txBuilder) BuildForfeitTxs(poolTx string, payments []domain.Payment) (connectors []string, forfeitTxs []string, err error) { + poolTxID, err := getTxID(poolTx) + if err != nil { + return nil, nil, err + } + + aspScript, err := p2wpkhScript(b.aspPublicKey, b.net) + if err != nil { + return nil, nil, err + } + + numberOfConnectors := numberOfVTXOs(payments) + + connectors, err = createConnectors( + poolTxID, + 1, + psetv2.OutputArgs{ + Asset: b.net.AssetID, + Amount: connectorAmount, + Script: aspScript, + }, + aspScript, + numberOfConnectors, + ) + if err != nil { + return nil, nil, err + } + + connectorsAsInputs, err := connectorsToInputArgs(connectors) + if err != nil { + return nil, nil, err + } + + forfeitTxs = make([]string, 0) + for _, payment := range payments { + for _, vtxo := range payment.Inputs { + for _, connector := range connectorsAsInputs { + forfeitTx, err := createForfeitTx( + connector, + psetv2.InputArgs{ + Txid: vtxo.Txid, + TxIndex: vtxo.VOut, + }, + vtxo.Amount, + aspScript, + b.net, + ) + if err != nil { + return nil, nil, err + } + + forfeitTxs = append(forfeitTxs, forfeitTx) + } + } + } + + return connectors, forfeitTxs, nil +} + +// BuildPoolTx implements ports.TxBuilder. +func (b *txBuilder) BuildPoolTx(wallet ports.WalletService, payments []domain.Payment) (poolTx string, err error) { + aspScriptBytes, err := p2wpkhScript(b.aspPublicKey, b.net) + if err != nil { + return "", err + } + + aspScript := hex.EncodeToString(aspScriptBytes) + + receivers := receiversFromPayments(payments) + sharedOutputAmount := sumReceivers(receivers) + + numberOfConnectors := numberOfVTXOs(payments) + connectorOutputAmount := connectorAmount * numberOfConnectors + + ctx := context.Background() + + return wallet.Transaction().Transfer(ctx, []ports.TxOutput{ + newOutput(aspScript, sharedOutputAmount), + newOutput(aspScript, connectorOutputAmount), + }) +} + +func connectorsToInputArgs(connectors []string) ([]psetv2.InputArgs, error) { + inputs := make([]psetv2.InputArgs, 0, len(connectors)+1) + for i, psetb64 := range connectors { + txID, err := getTxID(psetb64) + if err != nil { + return nil, err + } + + input := psetv2.InputArgs{ + Txid: txID, + TxIndex: 0, + } + inputs = append(inputs, input) + + if i == len(connectors)-1 { + input := psetv2.InputArgs{ + Txid: txID, + TxIndex: 1, + } + inputs = append(inputs, input) + } + } + return inputs, nil +} + +func getTxID(psetBase64 string) (string, error) { + pset, err := psetv2.NewPsetFromBase64(psetBase64) + if err != nil { + return "", err + } + + utx, err := pset.UnsignedTx() + if err != nil { + return "", err + } + + return utx.TxHash().String(), nil +} + +func numberOfVTXOs(payments []domain.Payment) uint64 { + var sum uint64 + for _, payment := range payments { + sum += uint64(len(payment.Inputs)) + } + return sum +} + +func receiversFromPayments(payments []domain.Payment) []domain.Receiver { + receivers := make([]domain.Receiver, 0) + for _, payment := range payments { + receivers = append(receivers, payment.Receivers...) + } + return receivers +} + +func sumReceivers(receivers []domain.Receiver) uint64 { + var sum uint64 + for _, r := range receivers { + sum += r.Amount + } + return sum +} + +type output struct { + script string + amount uint64 +} + +func newOutput(script string, amount uint64) ports.TxOutput { + return &output{ + script: script, + amount: amount, + } +} + +func (o *output) GetAmount() uint64 { + return o.amount +} + +func (o *output) GetScript() string { + return o.script +} diff --git a/internal/infrastructure/tx-builder/dummy/builder_test.go b/internal/infrastructure/tx-builder/dummy/builder_test.go new file mode 100644 index 0000000..60cc8e5 --- /dev/null +++ b/internal/infrastructure/tx-builder/dummy/builder_test.go @@ -0,0 +1,307 @@ +package txbuilder_test + +import ( + "testing" + + "github.com/ark-network/ark/common" + "github.com/ark-network/ark/internal/core/domain" + "github.com/ark-network/ark/internal/core/ports" + txbuilder "github.com/ark-network/ark/internal/infrastructure/tx-builder/dummy" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/stretchr/testify/require" + "github.com/vulpemventures/go-elements/address" + "github.com/vulpemventures/go-elements/network" + "github.com/vulpemventures/go-elements/payment" + "github.com/vulpemventures/go-elements/psetv2" +) + +const ( + testingKey = "apub1qgvdtj5ttpuhkldavhq8thtm5auyk0ec4dcmrfdgu0u5hgp9we22v3hrs4x" +) + +func createTestTxBuilder() (ports.TxBuilder, error) { + _, key, err := common.DecodePubKey(testingKey) + if err != nil { + return nil, err + } + + return txbuilder.NewTxBuilder(key, common.MainNet), nil +} + +func createTestPoolTx(sharedOutputAmount, numberOfInputs uint64) (string, error) { + _, key, err := common.DecodePubKey(testingKey) + if err != nil { + return "", err + } + + payment := payment.FromPublicKey(key, &network.Regtest, nil) + addr, err := payment.WitnessPubKeyHash() + if err != nil { + return "", err + } + + script, err := address.ToOutputScript(addr) + if err != nil { + return "", err + } + + pset, err := psetv2.New(nil, nil, nil) + if err != nil { + return "", err + } + + updater, err := psetv2.NewUpdater(pset) + if err != nil { + return "", err + } + + err = updater.AddInputs([]psetv2.InputArgs{ + { + Txid: "2f8f5733734fd44d581976bd3c1aee098bd606402df2ce02ce908287f1d5ede4", + TxIndex: 0, + }, + }) + if err != nil { + return "", err + } + + connectorsAmount := numberOfInputs * (450 + 500) + + err = updater.AddOutputs([]psetv2.OutputArgs{ + { + Asset: network.Regtest.AssetID, + Amount: sharedOutputAmount, + Script: script, + }, + { + Asset: network.Regtest.AssetID, + Amount: connectorsAmount, + Script: script, + }, + { + Asset: network.Regtest.AssetID, + Amount: 500, + }, + }) + if err != nil { + return "", err + } + + return pset.ToBase64() +} + +func TestBuildCongestionTree(t *testing.T) { + builder, err := createTestTxBuilder() + require.NoError(t, err) + + poolTx, err := createTestPoolTx(1000, (450+500)*1) + require.NoError(t, err) + + poolPset, err := psetv2.NewPsetFromBase64(poolTx) + require.NoError(t, err) + + poolTxUnsigned, err := poolPset.UnsignedTx() + require.NoError(t, err) + + poolTxID := poolTxUnsigned.TxHash().String() + + fixtures := []struct { + payments []domain.Payment + expectedNodesNum int // 2*len(receivers)-1 + }{ + { + payments: []domain.Payment{ + { + Id: "0", + Inputs: []domain.Vtxo{ + { + VtxoKey: domain.VtxoKey{ + Txid: "fd68e3c5796cc7db0a8036d486d5f625b6b2f2c014810ac020e1ac23e82c59d6", + VOut: 0, + }, + Receiver: domain.Receiver{ + Pubkey: "apub1qgvdtj5ttpuhkldavhq8thtm5auyk0ec4dcmrfdgu0u5hgp9we22v3hrs4x", + Amount: 600, + }, + }, + }, + Receivers: []domain.Receiver{ + { + Pubkey: "apub1qgvdtj5ttpuhkldavhq8thtm5auyk0ec4dcmrfdgu0u5hgp9we22v3hrs4x", + Amount: 600, + }, + { + Pubkey: "apub1qgvdtj5ttpuhkldavhq8thtm5auyk0ec4dcmrfdgu0u5hgp9we22v3hrs4x", + Amount: 400, + }, + }, + }, + }, + expectedNodesNum: 3, + }, + } + + for _, f := range fixtures { + tree, err := builder.BuildCongestionTree(poolTx, f.payments) + require.NoError(t, err) + require.Len(t, tree, f.expectedNodesNum) + + // decode all psbt + psets := make([]*psetv2.Pset, 0, f.expectedNodesNum) + + for _, pset := range tree { + pset, err := psetv2.NewPsetFromBase64(pset) + require.NoError(t, err) + require.NotNil(t, pset) + + psets = append(psets, pset) + } + + require.Len(t, psets[0].Inputs, 1) + require.Len(t, psets[0].Outputs, 2) + + // first tx input should be the pool tx shared output + inputTxID0, err := chainhash.NewHash(psets[0].Inputs[0].PreviousTxid) + require.NoError(t, err) + require.Equal(t, poolTxID, inputTxID0.String()) + require.Equal(t, uint32(0), psets[0].Inputs[0].PreviousTxIndex) + + unsignedTx0, err := psets[0].UnsignedTx() + require.NoError(t, err) + + txID0 := unsignedTx0.TxHash().String() + + // first tx input should be the first tx0 output + require.Len(t, psets[1].Inputs, 1) + require.Len(t, psets[1].Outputs, 1) + inputTxID1, err := chainhash.NewHash(psets[1].Inputs[0].PreviousTxid) + require.NoError(t, err) + require.Equal(t, txID0, inputTxID1.String()) + require.Equal(t, uint32(0), psets[1].Inputs[0].PreviousTxIndex) + // check the output amount (should be 600, the first receiver amount) + require.Equal(t, uint64(600), psets[1].Outputs[0].Value) + + // second tx input should be the second tx0 output + require.Len(t, psets[2].Inputs, 1) + require.Len(t, psets[2].Outputs, 1) + + inputTxID2, err := chainhash.NewHash(psets[2].Inputs[0].PreviousTxid) + require.NoError(t, err) + require.Equal(t, txID0, inputTxID2.String()) + require.Equal(t, uint32(1), psets[2].Inputs[0].PreviousTxIndex) + // check the output amount (should be 400, the second receiver amount) + require.Equal(t, uint64(400), psets[2].Outputs[0].Value) + } +} + +func TestBuildForfeitTxs(t *testing.T) { + builder, err := createTestTxBuilder() + require.NoError(t, err) + + poolTx, err := createTestPoolTx(1000, 450*2) + require.NoError(t, err) + + poolPset, err := psetv2.NewPsetFromBase64(poolTx) + require.NoError(t, err) + + poolTxUnsigned, err := poolPset.UnsignedTx() + require.NoError(t, err) + + poolTxID := poolTxUnsigned.TxHash().String() + + fixtures := []struct { + payments []domain.Payment + expectedNumOfForfeitTxs int + expectedNumOfConnectors int + }{ + { + payments: []domain.Payment{ + { + Id: "0", + Inputs: []domain.Vtxo{ + { + VtxoKey: domain.VtxoKey{ + Txid: "fd68e3c5796cc7db0a8036d486d5f625b6b2f2c014810ac020e1ac23e82c59d6", + VOut: 0, + }, + Receiver: domain.Receiver{ + Pubkey: "apub1qgvdtj5ttpuhkldavhq8thtm5auyk0ec4dcmrfdgu0u5hgp9we22v3hrs4x", + Amount: 600, + }, + }, + { + VtxoKey: domain.VtxoKey{ + Txid: "fd68e3c5796cc7db0a8036d486d5f625b6b2f2c014810ac020e1ac23e82c59d6", + VOut: 1, + }, + Receiver: domain.Receiver{ + Pubkey: "apub1qgvdtj5ttpuhkldavhq8thtm5auyk0ec4dcmrfdgu0u5hgp9we22v3hrs4x", + Amount: 400, + }, + }, + }, + Receivers: []domain.Receiver{ + { + Pubkey: "apub1qgvdtj5ttpuhkldavhq8thtm5auyk0ec4dcmrfdgu0u5hgp9we22v3hrs4x", + Amount: 600, + }, + { + Pubkey: "apub1qgvdtj5ttpuhkldavhq8thtm5auyk0ec4dcmrfdgu0u5hgp9we22v3hrs4x", + Amount: 400, + }, + }, + }, + }, + expectedNumOfForfeitTxs: 4, + expectedNumOfConnectors: 1, + }, + } + + for _, f := range fixtures { + connectors, forfeitTxs, err := builder.BuildForfeitTxs(poolTx, f.payments) + require.NoError(t, err) + + require.Len(t, connectors, f.expectedNumOfConnectors) + require.Len(t, forfeitTxs, f.expectedNumOfForfeitTxs) + + // decode and check connectors + connectorsPsets := make([]*psetv2.Pset, 0, f.expectedNumOfConnectors) + for _, pset := range connectors { + p, err := psetv2.NewPsetFromBase64(pset) + require.NoError(t, err) + connectorsPsets = append(connectorsPsets, p) + } + + for i, pset := range connectorsPsets { + require.Len(t, pset.Inputs, 1) + require.Len(t, pset.Outputs, 2) + + expectedInputTxid := poolTxID + expectedInputVout := uint32(1) + if i > 0 { + tx, err := connectorsPsets[i-1].UnsignedTx() + require.NoError(t, err) + require.NotNil(t, tx) + expectedInputTxid = tx.TxHash().String() + } + + inputTxid := chainhash.Hash(pset.Inputs[0].PreviousTxid).String() + require.Equal(t, expectedInputTxid, inputTxid) + require.Equal(t, expectedInputVout, pset.Inputs[0].PreviousTxIndex) + } + + // decode and check forfeit txs + forfeitTxsPsets := make([]*psetv2.Pset, 0, f.expectedNumOfForfeitTxs) + for _, pset := range forfeitTxs { + p, err := psetv2.NewPsetFromBase64(pset) + require.NoError(t, err) + forfeitTxsPsets = append(forfeitTxsPsets, p) + } + + // each forfeit tx should have 2 inputs and 2 outputs + for _, pset := range forfeitTxsPsets { + require.Len(t, pset.Inputs, 2) + require.Len(t, pset.Outputs, 1) + } + } +} diff --git a/internal/infrastructure/tx-builder/dummy/connectors.go b/internal/infrastructure/tx-builder/dummy/connectors.go new file mode 100644 index 0000000..f56d1b8 --- /dev/null +++ b/internal/infrastructure/tx-builder/dummy/connectors.go @@ -0,0 +1,75 @@ +package txbuilder + +import ( + "github.com/vulpemventures/go-elements/psetv2" +) + +func createConnectors( + poolTxID string, + connectorOutputIndex uint32, + connectorOutput psetv2.OutputArgs, + changeScript []byte, + numberOfConnectors uint64, +) (connectorsPsets []string, err error) { + previousInput := psetv2.InputArgs{ + Txid: poolTxID, + TxIndex: connectorOutputIndex, + } + + // compute the initial amount of the connectors output in pool transaction + remainingAmount := connectorAmount * numberOfConnectors + + connectorsPset := make([]string, 0, numberOfConnectors-1) + for i := uint64(0); i < numberOfConnectors-1; i++ { + // create a new pset + pset, err := psetv2.New(nil, nil, nil) + if err != nil { + return nil, err + } + + updater, err := psetv2.NewUpdater(pset) + if err != nil { + return nil, err + } + + err = updater.AddInputs([]psetv2.InputArgs{previousInput}) + if err != nil { + return nil, err + } + + err = updater.AddOutputs([]psetv2.OutputArgs{connectorOutput}) + if err != nil { + return nil, err + } + + changeAmount := remainingAmount - connectorOutput.Amount + if changeAmount > 0 { + changeOutput := psetv2.OutputArgs{ + Asset: connectorOutput.Asset, + Amount: changeAmount, + Script: changeScript, + } + err = updater.AddOutputs([]psetv2.OutputArgs{changeOutput}) + if err != nil { + return nil, err + } + tx, _ := pset.UnsignedTx() + txid := tx.TxHash().String() + + // make the change the next previousInput + previousInput = psetv2.InputArgs{ + Txid: txid, + TxIndex: 1, + } + } + + base64, err := pset.ToBase64() + if err != nil { + return nil, err + } + + connectorsPset = append(connectorsPset, base64) + } + + return connectorsPset, nil +} diff --git a/internal/infrastructure/tx-builder/dummy/forfeit.go b/internal/infrastructure/tx-builder/dummy/forfeit.go new file mode 100644 index 0000000..6db4749 --- /dev/null +++ b/internal/infrastructure/tx-builder/dummy/forfeit.go @@ -0,0 +1,42 @@ +package txbuilder + +import ( + "github.com/vulpemventures/go-elements/network" + "github.com/vulpemventures/go-elements/psetv2" +) + +func createForfeitTx( + connectorInput psetv2.InputArgs, + vtxoInput psetv2.InputArgs, + vtxoAmount uint64, + aspScript []byte, + net *network.Network, +) (forfeitTx string, err error) { + pset, err := psetv2.New(nil, nil, nil) + if err != nil { + return "", err + } + + updater, err := psetv2.NewUpdater(pset) + if err != nil { + return "", err + } + + err = updater.AddInputs([]psetv2.InputArgs{connectorInput, vtxoInput}) + if err != nil { + return "", err + } + + err = updater.AddOutputs([]psetv2.OutputArgs{ + { + Asset: net.AssetID, + Amount: vtxoAmount, + Script: aspScript, + }, + }) + if err != nil { + return "", err + } + + return pset.ToBase64() +} diff --git a/internal/infrastructure/tx-builder/dummy/tree.go b/internal/infrastructure/tx-builder/dummy/tree.go new file mode 100644 index 0000000..b1cc430 --- /dev/null +++ b/internal/infrastructure/tx-builder/dummy/tree.go @@ -0,0 +1,270 @@ +package txbuilder + +import ( + "github.com/ark-network/ark/common" + "github.com/ark-network/ark/internal/core/domain" + "github.com/decred/dcrd/dcrec/secp256k1/v4" + "github.com/vulpemventures/go-elements/address" + "github.com/vulpemventures/go-elements/network" + "github.com/vulpemventures/go-elements/payment" + "github.com/vulpemventures/go-elements/psetv2" +) + +const ( + sharedOutputIndex = 0 +) + +type outputScriptFactory func(leaves []domain.Receiver) ([]byte, error) + +func p2wpkhScript(publicKey *secp256k1.PublicKey, net *network.Network) ([]byte, error) { + payment := payment.FromPublicKey(publicKey, net, nil) + addr, err := payment.WitnessPubKeyHash() + if err != nil { + return nil, err + } + + return address.ToOutputScript(addr) +} + +// newOtputScriptFactory returns an output script factory func that lock funds using the ASP public key only on all branches psbt. The leaves are instead locked by the leaf public key. +func newOutputScriptFactory(aspPublicKey *secp256k1.PublicKey, net *network.Network) outputScriptFactory { + return func(leaves []domain.Receiver) ([]byte, error) { + aspScript, err := p2wpkhScript(aspPublicKey, net) + if err != nil { + return nil, err + } + + switch len(leaves) { + case 0: + return nil, nil + case 1: // it's a leaf + _, key, err := common.DecodePubKey(leaves[0].Pubkey) + if err != nil { + return nil, err + } + + return p2wpkhScript(key, net) + default: // it's a branch, lock funds with ASP public key + return aspScript, nil + } + } +} + +// congestionTree builder iteratively creates a binary tree of Pset from a set of receivers +// it also expect createOutputScript func managing the output script creation and the network to use (mainly for L-BTC asset id) +func buildCongestionTree( + createOutputScript outputScriptFactory, + net *network.Network, + poolTxID string, + receivers []domain.Receiver, +) (congestionTree []string, err error) { + var nodes []*node + + for _, r := range receivers { + nodes = append(nodes, newLeaf(createOutputScript, net, r)) + } + + for len(nodes) > 1 { + nodes, err = createTreeLevel(nodes) + if err != nil { + return nil, err + } + } + + var tree []string + + psets, err := nodes[0].psets(psetv2.InputArgs{ + Txid: poolTxID, + TxIndex: sharedOutputIndex, + }) + if err != nil { + return nil, err + } + + for _, pset := range psets { + psetB64, err := pset.ToBase64() + if err != nil { + return nil, err + } + tree = append(tree, psetB64) + } + + return tree, nil +} + +func createTreeLevel(nodes []*node) ([]*node, error) { + if len(nodes)%2 != 0 { + last := nodes[len(nodes)-1] + pairs, err := createTreeLevel(nodes[:len(nodes)-1]) + if err != nil { + return nil, err + } + + return append(pairs, last), nil + } + + pairs := make([]*node, 0, len(nodes)/2) + for i := 0; i < len(nodes); i += 2 { + pairs = append(pairs, newBranch(nodes[i], nodes[i+1])) + } + return pairs, nil +} + +// internal struct to build a binary tree of Pset +type node struct { + receivers []domain.Receiver + left *node + right *node + createOutputScript outputScriptFactory + network *network.Network +} + +// create a node from a single receiver +func newLeaf( + createOutputScript outputScriptFactory, + network *network.Network, + receiver domain.Receiver, +) *node { + return &node{ + receivers: []domain.Receiver{receiver}, + createOutputScript: createOutputScript, + network: network, + left: nil, + right: nil, + } +} + +// aggregate two nodes into a branch node +func newBranch( + left *node, + right *node, +) *node { + return &node{ + receivers: append(left.receivers, right.receivers...), + createOutputScript: left.createOutputScript, + network: left.network, + left: left, + right: right, + } +} + +// is it the final node of the tree +func (n *node) isLeaf() bool { + return len(n.receivers) == 1 +} + +// compute the output amount of a node +func (n *node) amount() uint64 { + var amount uint64 + for _, r := range n.receivers { + amount += r.Amount + } + return amount +} + +// compute the output script of a node +func (n *node) script() ([]byte, error) { + return n.createOutputScript(n.receivers) +} + +// use script & amount to create OutputArgs +func (n *node) output() (*psetv2.OutputArgs, error) { + script, err := n.script() + if err != nil { + return nil, err + } + + return &psetv2.OutputArgs{ + Asset: n.network.AssetID, + Amount: n.amount(), + Script: script, + }, nil +} + +// create the node Pset from the previous node Pset represented by input arg +// if node is a branch, it adds two outputs to the Pset, one for the left branch and one for the right branch +// if node is a leaf, it only adds one output to the Pset (the node output) +func (n *node) pset(input psetv2.InputArgs) (*psetv2.Pset, error) { + pset, err := psetv2.New(nil, nil, nil) + if err != nil { + return nil, err + } + + updater, err := psetv2.NewUpdater(pset) + if err != nil { + return nil, err + } + + err = updater.AddInputs([]psetv2.InputArgs{input}) + if err != nil { + return nil, err + } + + if n.isLeaf() { + output, err := n.output() + if err != nil { + return nil, err + } + + err = updater.AddOutputs([]psetv2.OutputArgs{*output}) + if err != nil { + return nil, err + } + return pset, nil + } + + outputLeft, err := n.left.output() + if err != nil { + return nil, err + } + + outputRight, err := n.right.output() + if err != nil { + return nil, err + } + + err = updater.AddOutputs([]psetv2.OutputArgs{*outputLeft, *outputRight}) + if err != nil { + return nil, err + } + + return pset, nil +} + +// create the node pset and all the psets of its children recursively, updating the input arg at each step +// the function stops when it reaches a leaf node +func (n *node) psets(input psetv2.InputArgs) ([]*psetv2.Pset, error) { + pset, err := n.pset(input) + if err != nil { + return nil, err + } + + if n.isLeaf() { + return []*psetv2.Pset{pset}, nil + } + + unsignedTx, err := pset.UnsignedTx() + if err != nil { + return nil, err + } + + txID := unsignedTx.TxHash().String() + + psetsLeft, err := n.left.psets(psetv2.InputArgs{ + Txid: txID, + TxIndex: 0, + }) + if err != nil { + return nil, err + } + + psetsRight, err := n.right.psets(psetv2.InputArgs{ + Txid: txID, + TxIndex: 1, + }) + if err != nil { + return nil, err + } + + return append([]*psetv2.Pset{pset}, append(psetsLeft, psetsRight...)...), nil +}