diff --git a/asp/internal/core/application/service.go b/asp/internal/core/application/service.go index 07eff7e..113c47a 100644 --- a/asp/internal/core/application/service.go +++ b/asp/internal/core/application/service.go @@ -266,7 +266,7 @@ func (s *service) startFinalization() { return } - signedPoolTx, tree, err := s.builder.BuildPoolTx(s.pubkey, s.wallet, payments, s.minRelayFee) + unsignedPoolTx, tree, err := s.builder.BuildPoolTx(s.pubkey, s.wallet, payments, s.minRelayFee) if err != nil { changes = round.Fail(fmt.Errorf("failed to create pool tx: %s", err)) log.WithError(err).Warn("failed to create pool tx") @@ -275,7 +275,7 @@ func (s *service) startFinalization() { log.Debugf("pool tx created for round %s", round.Id) - connectors, forfeitTxs, err := s.builder.BuildForfeitTxs(s.pubkey, signedPoolTx, payments) + connectors, forfeitTxs, err := s.builder.BuildForfeitTxs(s.pubkey, unsignedPoolTx, payments) if err != nil { changes = round.Fail(fmt.Errorf("failed to create connectors and forfeit txs: %s", err)) log.WithError(err).Warn("failed to create connectors and forfeit txs") @@ -284,7 +284,7 @@ func (s *service) startFinalization() { log.Debugf("forfeit transactions created for round %s", round.Id) - events, err := round.StartFinalization(connectors, tree, signedPoolTx) + events, err := round.StartFinalization(connectors, tree, unsignedPoolTx) if err != nil { changes = round.Fail(fmt.Errorf("failed to start finalization: %s", err)) log.WithError(err).Warn("failed to start finalization") @@ -326,7 +326,15 @@ func (s *service) finalizeRound() { return } - txid, err := s.wallet.BroadcastTransaction(ctx, round.TxHex) + log.Debugf("signing round transaction %s\n", round.Id) + signedPoolTx, err := s.wallet.SignPset(ctx, round.UnsignedTx, true) + if err != nil { + changes = round.Fail(fmt.Errorf("failed to sign round tx: %s", err)) + log.WithError(err).Warn("failed to sign round tx") + return + } + + txid, err := s.wallet.BroadcastTransaction(ctx, signedPoolTx) if err != nil { changes = round.Fail(fmt.Errorf("failed to broadcast pool tx: %s", err)) log.WithError(err).Warn("failed to broadcast pool tx") diff --git a/asp/internal/core/domain/round.go b/asp/internal/core/domain/round.go index 7bcbfc2..35577ea 100644 --- a/asp/internal/core/domain/round.go +++ b/asp/internal/core/domain/round.go @@ -40,7 +40,7 @@ type Round struct { Stage Stage Payments map[string]Payment Txid string - TxHex string + UnsignedTx string ForfeitTxs []string CongestionTree tree.CongestionTree Connectors []string @@ -84,7 +84,7 @@ func (r *Round) On(event RoundEvent, replayed bool) { r.Stage.Code = FinalizationStage r.CongestionTree = e.CongestionTree r.Connectors = append([]string{}, e.Connectors...) - r.TxHex = e.PoolTx + r.UnsignedTx = e.PoolTx case RoundFinalized: r.Stage.Ended = true r.Txid = e.Txid diff --git a/asp/internal/core/ports/wallet.go b/asp/internal/core/ports/wallet.go index 9ea1021..c037939 100644 --- a/asp/internal/core/ports/wallet.go +++ b/asp/internal/core/ports/wallet.go @@ -13,8 +13,9 @@ type WalletService interface { SignPset( ctx context.Context, pset string, extractRawTx bool, ) (string, error) - Transfer(ctx context.Context, outs []TxOutput) (string, error) + SelectUtxos(ctx context.Context, asset string, amount uint64) ([]TxInput, uint64, error) BroadcastTransaction(ctx context.Context, txHex string) (string, error) + EstimateFees(ctx context.Context, pset string) (uint64, error) Close() } @@ -28,12 +29,6 @@ type TxInput interface { GetTxid() string GetIndex() uint32 GetScript() string - GetScriptSigSize() int - GetWitnessSize() int -} - -type TxOutput interface { - GetAmount() uint64 GetAsset() string - GetScript() string + GetValue() uint64 } diff --git a/asp/internal/infrastructure/db/service_test.go b/asp/internal/infrastructure/db/service_test.go index c40e027..c3771b8 100644 --- a/asp/internal/infrastructure/db/service_test.go +++ b/asp/internal/infrastructure/db/service_test.go @@ -377,7 +377,7 @@ func roundsMatch(expected, got domain.Round) assert.Comparison { if expected.Txid != got.Txid { return false } - if expected.TxHex != got.TxHex { + if expected.UnsignedTx != got.UnsignedTx { return false } if !reflect.DeepEqual(expected.ForfeitTxs, got.ForfeitTxs) { diff --git a/asp/internal/infrastructure/ocean-wallet/transaction.go b/asp/internal/infrastructure/ocean-wallet/transaction.go index 8253532..a0a8f08 100644 --- a/asp/internal/infrastructure/ocean-wallet/transaction.go +++ b/asp/internal/infrastructure/ocean-wallet/transaction.go @@ -2,14 +2,19 @@ package oceanwallet import ( "context" + "encoding/hex" "fmt" pb "github.com/ark-network/ark/api-spec/protobuf/gen/ocean/v1" "github.com/ark-network/ark/internal/core/ports" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/vulpemventures/go-elements/elementsutil" "github.com/vulpemventures/go-elements/psetv2" ) -const msatsPerByte = 110 +const ( + zero32 = "0000000000000000000000000000000000000000000000000000000000000000" +) func (s *service) SignPset( ctx context.Context, pset string, extractRawTx bool, @@ -21,29 +26,54 @@ func (s *service) SignPset( return "", err } signedPset := res.GetPset() + if !extractRawTx { return signedPset, nil } - ptx, _ := psetv2.NewPsetFromBase64(signedPset) - if err := psetv2.MaybeFinalizeAll(ptx); err != nil { - return "", fmt.Errorf("failed to finalize signed pset: %s", err) - } - return ptx.ToBase64() -} - -func (s *service) Transfer( - ctx context.Context, outs []ports.TxOutput, -) (string, error) { - res, err := s.txClient.Transfer(ctx, &pb.TransferRequest{ - AccountName: accountLabel, - Receivers: outputList(outs).toProto(), - MillisatsPerByte: msatsPerByte, - }) + ptx, err := psetv2.NewPsetFromBase64(signedPset) if err != nil { return "", err } - return res.GetTxHex(), nil + + if err := psetv2.MaybeFinalizeAll(ptx); err != nil { + return "", fmt.Errorf("failed to finalize signed pset: %s", err) + } + + extractedTx, err := psetv2.Extract(ptx) + if err != nil { + return "", fmt.Errorf("failed to extract signed pset: %s", err) + } + + txHex, err := extractedTx.ToHex() + if err != nil { + return "", fmt.Errorf("failed to convert extracted tx to hex: %s", err) + } + + return txHex, nil +} + +func (s *service) SelectUtxos(ctx context.Context, asset string, amount uint64) ([]ports.TxInput, uint64, error) { + res, err := s.txClient.SelectUtxos(ctx, &pb.SelectUtxosRequest{ + AccountName: accountLabel, + TargetAsset: asset, + TargetAmount: amount, + }) + if err != nil { + return nil, 0, err + } + + inputs := make([]ports.TxInput, 0, len(res.GetUtxos())) + for _, utxo := range res.GetUtxos() { + // check that the utxos are not confidential + if utxo.GetAssetBlinder() != zero32 || utxo.GetValueBlinder() != zero32 { + return nil, 0, fmt.Errorf("utxo is confidential") + } + + inputs = append(inputs, utxo) + } + + return inputs, res.GetChange(), nil } func (s *service) BroadcastTransaction( @@ -60,16 +90,50 @@ func (s *service) BroadcastTransaction( return res.GetTxid(), nil } -type outputList []ports.TxOutput +func (s *service) EstimateFees( + ctx context.Context, pset string, +) (uint64, error) { + tx, err := psetv2.NewPsetFromBase64(pset) + if err != nil { + return 0, err + } -func (l outputList) toProto() []*pb.Output { - list := make([]*pb.Output, 0, len(l)) - for _, out := range l { - list = append(list, &pb.Output{ - Amount: out.GetAmount(), - Script: out.GetScript(), - Asset: out.GetAsset(), + inputs := make([]*pb.Input, 0, len(tx.Inputs)) + outputs := make([]*pb.Output, 0, len(tx.Outputs)) + + for _, in := range tx.Inputs { + if in.WitnessUtxo == nil { + return 0, fmt.Errorf("missing witness utxo, cannot estimate fees") + } + + inputs = append(inputs, &pb.Input{ + Txid: chainhash.Hash(in.PreviousTxid).String(), + Index: in.PreviousTxIndex, + Script: hex.EncodeToString(in.WitnessUtxo.Script), }) } - return list + + for _, out := range tx.Outputs { + outputs = append(outputs, &pb.Output{ + Asset: elementsutil.AssetHashFromBytes( + append([]byte{0x01}, out.Asset...), + ), + Amount: out.Value, + Script: hex.EncodeToString(out.Script), + }) + } + + fee, err := s.txClient.EstimateFees( + ctx, + &pb.EstimateFeesRequest{ + Inputs: inputs, + Outputs: outputs, + }, + ) + if err != nil { + return 0, fmt.Errorf("failed to estimate fees: %s", err) + } + + // we add 5 sats in order to avoid min-relay-fee not met errors + return fee.GetFeeAmount() + 5, nil } diff --git a/asp/internal/infrastructure/tx-builder/covenant/builder.go b/asp/internal/infrastructure/tx-builder/covenant/builder.go index 5dde242..d01673e 100644 --- a/asp/internal/infrastructure/tx-builder/covenant/builder.go +++ b/asp/internal/infrastructure/tx-builder/covenant/builder.go @@ -3,10 +3,12 @@ package txbuilder import ( "context" "encoding/hex" + "fmt" "github.com/ark-network/ark/common/tree" "github.com/ark-network/ark/internal/core/domain" "github.com/ark-network/ark/internal/core/ports" + "github.com/btcsuite/btcd/txscript" "github.com/decred/dcrd/dcrec/secp256k1/v4" "github.com/vulpemventures/go-elements/address" "github.com/vulpemventures/go-elements/elementsutil" @@ -21,6 +23,8 @@ const ( connectorAmount = 450 ) +var emptyNonce = []byte{0x00} + type txBuilder struct { net *network.Network } @@ -44,11 +48,7 @@ func p2wpkhScript(publicKey *secp256k1.PublicKey, net *network.Network) ([]byte, func getTxid(txStr string) (string, error) { pset, err := psetv2.NewPsetFromBase64(txStr) if err != nil { - tx, err := transaction.NewTxFromHex(txStr) - if err != nil { - return "", err - } - return tx.TxHash().String(), nil + return "", err } utx, err := pset.UnsignedTx() @@ -116,7 +116,7 @@ func (b *txBuilder) BuildForfeitTxs( pubkeyBytes, err := hex.DecodeString(vtxo.Pubkey) if err != nil { - return nil, nil, err + return nil, nil, fmt.Errorf("failed to decode pubkey: %s", err) } vtxoPubkey, err := secp256k1.ParsePubKey(pubkeyBytes) @@ -171,8 +171,6 @@ func (b *txBuilder) BuildPoolTx( return } - aspScript := hex.EncodeToString(aspScriptBytes) - offchainReceivers, onchainReceivers := receiversFromPayments(payments) numberOfConnectors := numberOfVTXOs(payments) connectorOutputAmount := connectorAmount * numberOfConnectors @@ -189,38 +187,165 @@ func (b *txBuilder) BuildPoolTx( return } - sharedOutputScriptHex := hex.EncodeToString(sharedOutputScript) - - poolTxOuts := []ports.TxOutput{ - newOutput(sharedOutputScriptHex, sharedOutputAmount, b.net.AssetID), - newOutput(aspScript, connectorOutputAmount, b.net.AssetID), + outputs := []psetv2.OutputArgs{ + { + Asset: b.net.AssetID, + Amount: sharedOutputAmount, + Script: sharedOutputScript, + }, + { + Asset: b.net.AssetID, + Amount: connectorOutputAmount, + Script: aspScriptBytes, + }, } + targetAmount := sharedOutputAmount + connectorOutputAmount + for _, receiver := range onchainReceivers { - buf, _ := address.ToOutputScript(receiver.OnchainAddress) - script := hex.EncodeToString(buf) - poolTxOuts = append(poolTxOuts, newOutput(script, receiver.Amount, b.net.AssetID)) + targetAmount += receiver.Amount + + receiverScript, err := address.ToOutputScript(receiver.OnchainAddress) + if err != nil { + return "", nil, err + } + + outputs = append(outputs, psetv2.OutputArgs{ + Asset: b.net.AssetID, + Amount: receiver.Amount, + Script: receiverScript, + }) } - txHex, err := wallet.Transfer(ctx, poolTxOuts) + utxos, change, err := wallet.SelectUtxos(ctx, b.net.AssetID, targetAmount) if err != nil { return } - tx, err := transaction.NewTxFromHex(txHex) + if change > 0 { + outputs = append(outputs, psetv2.OutputArgs{ + Asset: b.net.AssetID, + Amount: change, + Script: aspScriptBytes, + }) + } + + ptx, err := psetv2.New(toInputArgs(utxos), outputs, nil) + if err != nil { + return + } + + updater, err := psetv2.NewUpdater(ptx) + if err != nil { + return + } + + for i, utxo := range utxos { + witnessUtxo, err := toWitnessUtxo(utxo) + if err != nil { + return "", nil, err + } + + if err := updater.AddInWitnessUtxo(i, witnessUtxo); err != nil { + return "", nil, err + } + + if err := updater.AddInSighashType(i, txscript.SigHashAll); err != nil { + return "", nil, err + } + } + + b64, err := ptx.ToBase64() + if err != nil { + return + } + + feesAmount, err := wallet.EstimateFees(ctx, b64) + if err != nil { + return + } + + if feesAmount == change { + // fees = change, remove change output + updater.Pset.Outputs = ptx.Outputs[:len(ptx.Outputs)-1] + } else if feesAmount < change { + // change covers the fees, reduce change amount + updater.Pset.Outputs[len(ptx.Outputs)-1].Value = change - feesAmount + } else { + // change is not enough to cover fees, re-select utxos + if change > 0 { + // remove change output if present + updater.Pset.Outputs = ptx.Outputs[:len(ptx.Outputs)-1] + } + newUtxos, newChange, err := wallet.SelectUtxos(ctx, b.net.AssetID, feesAmount-change) + if err != nil { + return "", nil, err + } + + if err := updater.AddInputs(toInputArgs(newUtxos)); err != nil { + return "", nil, err + } + + if newChange > 0 { + if err := updater.AddOutputs([]psetv2.OutputArgs{ + { + Asset: b.net.AssetID, + Amount: newChange, + Script: aspScriptBytes, + }, + }); err != nil { + return "", nil, err + } + } + + nbInputs := len(utxos) + + for i, utxo := range newUtxos { + + witnessUtxo, err := toWitnessUtxo(utxo) + if err != nil { + return "", nil, err + } + + if err := updater.AddInWitnessUtxo(i+nbInputs, witnessUtxo); err != nil { + return "", nil, err + } + + if err := updater.AddInSighashType(i+nbInputs, txscript.SigHashAll); err != nil { + return "", nil, err + } + } + + } + + // add fee output + if err := updater.AddOutputs([]psetv2.OutputArgs{ + { + Asset: b.net.AssetID, + Amount: feesAmount, + }, + }); err != nil { + return "", nil, err + } + + utx, err := ptx.UnsignedTx() if err != nil { return } tree, err := makeTree(psetv2.InputArgs{ - Txid: tx.TxHash().String(), + Txid: utx.TxHash().String(), TxIndex: 0, }) if err != nil { return } - poolTx = txHex + poolTx, err = ptx.ToBase64() + if err != nil { + return + } + congestionTree = tree return } @@ -321,28 +446,41 @@ func receiversFromPayments( return } -type output struct { - script string - amount uint64 - asset string -} - -func newOutput(script string, amount uint64, asset string) ports.TxOutput { - return &output{ - script: script, - amount: amount, - asset: asset, +func toInputArgs( + ins []ports.TxInput, +) []psetv2.InputArgs { + inputs := make([]psetv2.InputArgs, 0, len(ins)) + for _, in := range ins { + inputs = append(inputs, psetv2.InputArgs{ + Txid: in.GetTxid(), + TxIndex: in.GetIndex(), + }) } + return inputs } -func (o *output) GetAsset() string { - return o.asset -} +func toWitnessUtxo(in ports.TxInput) (*transaction.TxOutput, error) { + valueBytes, err := elementsutil.ValueToBytes(in.GetValue()) + if err != nil { + return nil, fmt.Errorf("failed to convert value to bytes: %s", err) + } -func (o *output) GetAmount() uint64 { - return o.amount -} + assetBytes, err := elementsutil.AssetHashToBytes(in.GetAsset()) + if err != nil { + return nil, fmt.Errorf("failed to convert asset to bytes: %s", err) + } -func (o *output) GetScript() string { - return o.script + scriptBytes, err := hex.DecodeString(in.GetScript()) + if err != nil { + return nil, fmt.Errorf("failed to decode script: %s", err) + } + + return &transaction.TxOutput{ + Asset: assetBytes, + Value: valueBytes, + Script: scriptBytes, + Nonce: emptyNonce, + RangeProof: nil, + SurjectionProof: nil, + }, err } diff --git a/asp/internal/infrastructure/tx-builder/covenant/builder_test.go b/asp/internal/infrastructure/tx-builder/covenant/builder_test.go index af2cc69..c7cfeda 100644 --- a/asp/internal/infrastructure/tx-builder/covenant/builder_test.go +++ b/asp/internal/infrastructure/tx-builder/covenant/builder_test.go @@ -2,6 +2,8 @@ package txbuilder_test import ( "context" + "crypto/rand" + "encoding/hex" "testing" "github.com/ark-network/ark/common" @@ -13,75 +15,40 @@ import ( secp256k1 "github.com/decred/dcrd/dcrec/secp256k1/v4" "github.com/stretchr/testify/require" "github.com/vulpemventures/go-elements/network" - "github.com/vulpemventures/go-elements/payment" "github.com/vulpemventures/go-elements/psetv2" - "github.com/vulpemventures/go-elements/transaction" ) const ( testingKey = "apub1qgvdtj5ttpuhkldavhq8thtm5auyk0ec4dcmrfdgu0u5hgp9we22v3hrs4x" + fakePoolTx = "cHNldP8BAgQCAAAAAQQBAQEFAQMBBgEDAfsEAgAAAAABDiDk7dXxh4KQzgLO8i1ABtaLCe4aPL12GVhN1E9zM1ePLwEPBAAAAAABEAT/////AAEDCOgDAAAAAAAAAQQWABSNnpy01UJqd99eTg2M1IpdKId11gf8BHBzZXQCICWyUQcOKcoZBDzzPM1zJOLdqwPsxK4LXnfE/A5c9slaB/wEcHNldAgEAAAAAAABAwh4BQAAAAAAAAEEFgAUjZ6ctNVCanffXk4NjNSKXSiHddYH/ARwc2V0AiAlslEHDinKGQQ88zzNcyTi3asD7MSuC153xPwOXPbJWgf8BHBzZXQIBAAAAAAAAQMI9AEAAAAAAAABBAAH/ARwc2V0AiAlslEHDinKGQQ88zzNcyTi3asD7MSuC153xPwOXPbJWgf8BHBzZXQIBAAAAAAA" ) -func createTestPoolTx(sharedOutputAmount, numberOfInputs uint64) (string, error) { - _, key, err := common.DecodePubKey(testingKey) - if err != nil { - return "", err - } +type mockedWalletService struct{} - payment := payment.FromPublicKey(key, &network.Testnet, nil) - script := payment.WitnessScript - - pset, err := psetv2.New(nil, nil, nil) - if err != nil { - return "", err - } - - updater, err := psetv2.NewUpdater(pset) - if err != nil { - return "", err - } - - err = updater.AddInputs([]psetv2.InputArgs{ - { - Txid: "2f8f5733734fd44d581976bd3c1aee098bd606402df2ce02ce908287f1d5ede4", - TxIndex: 0, - }, - }) - if err != nil { - return "", err - } - - connectorsAmount := numberOfInputs*450 + 500 - - err = updater.AddOutputs([]psetv2.OutputArgs{ - { - Asset: network.Regtest.AssetID, - Amount: sharedOutputAmount, - Script: script, - }, - { - Asset: network.Regtest.AssetID, - Amount: connectorsAmount, - Script: script, - }, - { - Asset: network.Regtest.AssetID, - Amount: 500, - }, - }) - if err != nil { - return "", err - } - - utx, err := pset.UnsignedTx() - if err != nil { - return "", err - } - - return utx.ToHex() +type input struct { + txid string + vout uint32 } -type mockedWalletService struct{} +func (i *input) GetTxid() string { + return i.txid +} + +func (i *input) GetIndex() uint32 { + return i.vout +} + +func (i *input) GetScript() string { + return "a914ea9f486e82efb3dd83a69fd96e3f0113757da03c87" +} + +func (i *input) GetAsset() string { + return "5ac9f65c0efcc4775e0baec4ec03abdde22473cd3cf33c0419ca290e0751b225" +} + +func (i *input) GetValue() uint64 { + return 1000 +} // BroadcastTransaction implements ports.WalletService. func (*mockedWalletService) BroadcastTransaction(ctx context.Context, txHex string) (string, error) { @@ -113,9 +80,22 @@ func (*mockedWalletService) Status(ctx context.Context) (ports.WalletStatus, err panic("unimplemented") } -// Transfer implements ports.WalletService. -func (*mockedWalletService) Transfer(ctx context.Context, outs []ports.TxOutput) (string, error) { - return createTestPoolTx(outs[0].GetAmount(), 1) +func (*mockedWalletService) SelectUtxos(ctx context.Context, asset string, amount uint64) ([]ports.TxInput, uint64, error) { + // random txid + bytes := make([]byte, 32) + if _, err := rand.Read(bytes); err != nil { + return nil, 0, err + } + fakeInput := input{ + txid: hex.EncodeToString(bytes), + vout: 0, + } + + return []ports.TxInput{&fakeInput}, 0, nil +} + +func (*mockedWalletService) EstimateFees(ctx context.Context, pset string) (uint64, error) { + return 100, nil } func TestBuildCongestionTree(t *testing.T) { @@ -373,14 +353,13 @@ func TestBuildCongestionTree(t *testing.T) { func TestBuildForfeitTxs(t *testing.T) { builder := txbuilder.NewTxBuilder(network.Liquid) - // TODO: replace with fixture. - poolTxHex, err := createTestPoolTx(1000, 2) + poolTx, err := psetv2.NewPsetFromBase64(fakePoolTx) require.NoError(t, err) - poolTx, err := transaction.NewTxFromHex(poolTxHex) + utx, err := poolTx.UnsignedTx() require.NoError(t, err) - poolTxid := poolTx.TxHash().String() + poolTxid := utx.TxHash().String() fixtures := []struct { payments []domain.Payment @@ -436,7 +415,7 @@ func TestBuildForfeitTxs(t *testing.T) { for _, f := range fixtures { connectors, forfeitTxs, err := builder.BuildForfeitTxs( - key, poolTxHex, f.payments, + key, fakePoolTx, f.payments, ) require.NoError(t, err) require.Len(t, connectors, f.expectedNumOfConnectors) diff --git a/asp/internal/infrastructure/tx-builder/dummy/builder.go b/asp/internal/infrastructure/tx-builder/dummy/builder.go index 844df9d..56f5585 100644 --- a/asp/internal/infrastructure/tx-builder/dummy/builder.go +++ b/asp/internal/infrastructure/tx-builder/dummy/builder.go @@ -2,7 +2,6 @@ package txbuilder import ( "context" - "encoding/hex" "github.com/ark-network/ark/common/tree" "github.com/ark-network/ark/internal/core/domain" @@ -99,42 +98,81 @@ func (b *txBuilder) BuildPoolTx( return "", nil, err } - aspScript := hex.EncodeToString(aspScriptBytes) - offchainReceivers, onchainReceivers := receiversFromPayments(payments) sharedOutputAmount := sumReceivers(offchainReceivers) numberOfConnectors := numberOfVTXOs(payments) connectorOutputAmount := connectorAmount * numberOfConnectors - poolTxOuts := []ports.TxOutput{ - newOutput(aspScript, sharedOutputAmount, b.net.AssetID), - newOutput(aspScript, connectorOutputAmount, b.net.AssetID), - } - for _, receiver := range onchainReceivers { - buf, _ := address.ToOutputScript(receiver.OnchainAddress) - script := hex.EncodeToString(buf) - poolTxOuts = append(poolTxOuts, newOutput(script, receiver.Amount, b.net.AssetID)) - } - ctx := context.Background() - poolTx, err = wallet.Transfer(ctx, poolTxOuts) - if err != nil { - return "", nil, err + outputs := []psetv2.OutputArgs{ + { + Asset: b.net.AssetID, + Amount: sharedOutputAmount, + Script: aspScriptBytes, + }, + { + Asset: b.net.AssetID, + Amount: connectorOutputAmount, + Script: aspScriptBytes, + }, } - poolTxID, err := getTxid(poolTx) + amountToSelect := sharedOutputAmount + connectorOutputAmount + + for _, receiver := range onchainReceivers { + amountToSelect += receiver.Amount + + receiverScript, err := address.ToOutputScript(receiver.OnchainAddress) + if err != nil { + return "", nil, err + } + + outputs = append(outputs, psetv2.OutputArgs{ + Asset: b.net.AssetID, + Amount: receiver.Amount, + Script: receiverScript, + }) + } + + utxos, change, err := wallet.SelectUtxos(ctx, b.net.AssetID, amountToSelect) if err != nil { - return "", nil, err + return + } + + if change > 0 { + outputs = append(outputs, psetv2.OutputArgs{ + Asset: b.net.AssetID, + Amount: change, + Script: aspScriptBytes, + }) + } + + ptx, err := psetv2.New(toInputArgs(utxos), outputs, nil) + if err != nil { + return + } + + utx, err := ptx.UnsignedTx() + if err != nil { + return } congestionTree, err = buildCongestionTree( newOutputScriptFactory(aspPubkey, b.net), b.net, - poolTxID, + utx.TxHash().String(), offchainReceivers, ) + if err != nil { + return + } + + poolTx, err = ptx.ToBase64() + if err != nil { + return + } return poolTx, congestionTree, err } @@ -219,28 +257,15 @@ func sumReceivers(receivers []domain.Receiver) uint64 { return sum } -type output struct { - script string - amount uint64 - asset string -} - -func newOutput(script string, amount uint64, asset string) ports.TxOutput { - return &output{ - script: script, - amount: amount, - asset: asset, +func toInputArgs( + ins []ports.TxInput, +) []psetv2.InputArgs { + inputs := make([]psetv2.InputArgs, 0, len(ins)) + for _, in := range ins { + inputs = append(inputs, psetv2.InputArgs{ + Txid: in.GetTxid(), + TxIndex: in.GetIndex(), + }) } -} - -func (o *output) GetAmount() uint64 { - return o.amount -} - -func (o *output) GetAsset() string { - return o.asset -} - -func (o *output) GetScript() string { - return o.script + return inputs } diff --git a/asp/internal/infrastructure/tx-builder/dummy/builder_test.go b/asp/internal/infrastructure/tx-builder/dummy/builder_test.go index 7c62f47..41ce9d1 100644 --- a/asp/internal/infrastructure/tx-builder/dummy/builder_test.go +++ b/asp/internal/infrastructure/tx-builder/dummy/builder_test.go @@ -2,6 +2,8 @@ package txbuilder_test import ( "context" + "crypto/rand" + "encoding/hex" "testing" "github.com/ark-network/ark/common" @@ -12,66 +14,37 @@ import ( secp256k1 "github.com/decred/dcrd/dcrec/secp256k1/v4" "github.com/stretchr/testify/require" "github.com/vulpemventures/go-elements/network" - "github.com/vulpemventures/go-elements/payment" "github.com/vulpemventures/go-elements/psetv2" ) const ( testingKey = "apub1qgvdtj5ttpuhkldavhq8thtm5auyk0ec4dcmrfdgu0u5hgp9we22v3hrs4x" + fakePoolTx = "cHNldP8BAgQCAAAAAQQBAQEFAQMBBgEDAfsEAgAAAAABDiDk7dXxh4KQzgLO8i1ABtaLCe4aPL12GVhN1E9zM1ePLwEPBAAAAAABEAT/////AAEDCOgDAAAAAAAAAQQWABSNnpy01UJqd99eTg2M1IpdKId11gf8BHBzZXQCICWyUQcOKcoZBDzzPM1zJOLdqwPsxK4LXnfE/A5c9slaB/wEcHNldAgEAAAAAAABAwh4BQAAAAAAAAEEFgAUjZ6ctNVCanffXk4NjNSKXSiHddYH/ARwc2V0AiAlslEHDinKGQQ88zzNcyTi3asD7MSuC153xPwOXPbJWgf8BHBzZXQIBAAAAAAAAQMI9AEAAAAAAAABBAAH/ARwc2V0AiAlslEHDinKGQQ88zzNcyTi3asD7MSuC153xPwOXPbJWgf8BHBzZXQIBAAAAAAA" ) -func createTestPoolTx(sharedOutputAmount, numberOfInputs uint64) (string, error) { - _, key, err := common.DecodePubKey(testingKey) - if err != nil { - return "", err - } +type input struct { + txid string + vout uint32 +} - payment := payment.FromPublicKey(key, &network.Testnet, nil) - script := payment.WitnessScript +func (i *input) GetTxid() string { + return i.txid +} - pset, err := psetv2.New(nil, nil, nil) - if err != nil { - return "", err - } +func (i *input) GetIndex() uint32 { + return i.vout +} - updater, err := psetv2.NewUpdater(pset) - if err != nil { - return "", err - } +func (i *input) GetScript() string { + return "a914ea9f486e82efb3dd83a69fd96e3f0113757da03c87" +} - err = updater.AddInputs([]psetv2.InputArgs{ - { - Txid: "2f8f5733734fd44d581976bd3c1aee098bd606402df2ce02ce908287f1d5ede4", - TxIndex: 0, - }, - }) - if err != nil { - return "", err - } +func (i *input) GetAsset() string { + return "5ac9f65c0efcc4775e0baec4ec03abdde22473cd3cf33c0419ca290e0751b225" +} - connectorsAmount := numberOfInputs * (450 + 500) - - err = updater.AddOutputs([]psetv2.OutputArgs{ - { - Asset: network.Regtest.AssetID, - Amount: sharedOutputAmount, - Script: script, - }, - { - Asset: network.Regtest.AssetID, - Amount: connectorsAmount, - Script: script, - }, - { - Asset: network.Regtest.AssetID, - Amount: 500, - }, - }) - if err != nil { - return "", err - } - - return pset.ToBase64() +func (i *input) GetValue() uint64 { + return 1000 } type mockedWalletService struct{} @@ -106,9 +79,22 @@ func (*mockedWalletService) Status(ctx context.Context) (ports.WalletStatus, err panic("unimplemented") } -// Transfer implements ports.WalletService. -func (*mockedWalletService) Transfer(ctx context.Context, outs []ports.TxOutput) (string, error) { - return createTestPoolTx(1000, (450+500)*1) +func (*mockedWalletService) SelectUtxos(ctx context.Context, asset string, amount uint64) ([]ports.TxInput, uint64, error) { + // random txid + bytes := make([]byte, 32) + if _, err := rand.Read(bytes); err != nil { + return nil, 0, err + } + fakeInput := input{ + txid: hex.EncodeToString(bytes), + vout: 0, + } + + return []ports.TxInput{&fakeInput}, 0, nil +} + +func (*mockedWalletService) EstimateFees(ctx context.Context, pset string) (uint64, error) { + return 100, nil } func TestBuildCongestionTree(t *testing.T) { @@ -290,10 +276,7 @@ func TestBuildCongestionTree(t *testing.T) { func TestBuildForfeitTxs(t *testing.T) { builder := txbuilder.NewTxBuilder(network.Liquid) - poolTx, err := createTestPoolTx(1000, 450*2) - require.NoError(t, err) - - poolPset, err := psetv2.NewPsetFromBase64(poolTx) + poolPset, err := psetv2.NewPsetFromBase64(fakePoolTx) require.NoError(t, err) poolTxUnsigned, err := poolPset.UnsignedTx() @@ -355,7 +338,7 @@ func TestBuildForfeitTxs(t *testing.T) { for _, f := range fixtures { connectors, forfeitTxs, err := builder.BuildForfeitTxs( - key, poolTx, f.payments, + key, fakePoolTx, f.payments, ) require.NoError(t, err) require.Len(t, connectors, f.expectedNumOfConnectors) diff --git a/common/tree/validation.go b/common/tree/validation.go index a9677ea..8376ec0 100644 --- a/common/tree/validation.go +++ b/common/tree/validation.go @@ -9,42 +9,41 @@ import ( "github.com/btcsuite/btcd/btcec/v2/schnorr" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/decred/dcrd/dcrec/secp256k1/v4" - "github.com/vulpemventures/go-elements/elementsutil" "github.com/vulpemventures/go-elements/psetv2" "github.com/vulpemventures/go-elements/taproot" - "github.com/vulpemventures/go-elements/transaction" ) var ( - ErrInvalidPoolTransaction = errors.New("invalid pool transaction") - ErrEmptyTree = errors.New("empty congestion tree") - ErrInvalidRootLevel = errors.New("root level must have only one node") - ErrNoLeaves = errors.New("no leaves in the tree") - ErrNodeTransactionEmpty = errors.New("node transaction is empty") - ErrNodeTxidEmpty = errors.New("node txid is empty") - ErrNodeParentTxidEmpty = errors.New("node parent txid is empty") - ErrNodeTxidDifferent = errors.New("node txid differs from node transaction") - ErrNumberOfInputs = errors.New("node transaction should have only one input") - ErrNumberOfOutputs = errors.New("node transaction should have only three or two outputs") - ErrParentTxidInput = errors.New("parent txid should be the input of the node transaction") - ErrNumberOfChildren = errors.New("node branch transaction should have two children") - ErrLeafChildren = errors.New("leaf node should have max 1 child") - ErrInvalidChildTxid = errors.New("invalid child txid") - ErrNumberOfTapscripts = errors.New("input should have two tapscripts leaves") - ErrInternalKey = errors.New("taproot internal key is not unspendable") - ErrInvalidTaprootScript = errors.New("invalid taproot script") - ErrInvalidLeafTaprootScript = errors.New("invalid leaf taproot script") - ErrInvalidAmount = errors.New("children amount is different from parent amount") - ErrInvalidAsset = errors.New("invalid output asset") - ErrInvalidSweepSequence = errors.New("invalid sweep sequence") - ErrInvalidASP = errors.New("invalid ASP") - ErrMissingFeeOutput = errors.New("missing fee output") - ErrInvalidLeftOutput = errors.New("invalid left output") - ErrInvalidRightOutput = errors.New("invalid right output") - ErrMissingSweepTapscript = errors.New("missing sweep tapscript") - ErrMissingBranchTapscript = errors.New("missing branch tapscript") - ErrInvalidLeaf = errors.New("leaf node shouldn't have children") - ErrWrongPoolTxID = errors.New("root input should be the pool tx outpoint") + ErrInvalidPoolTransaction = errors.New("invalid pool transaction") + ErrInvalidPoolTransactionOutputs = errors.New("invalid number of outputs in pool transaction") + ErrEmptyTree = errors.New("empty congestion tree") + ErrInvalidRootLevel = errors.New("root level must have only one node") + ErrNoLeaves = errors.New("no leaves in the tree") + ErrNodeTransactionEmpty = errors.New("node transaction is empty") + ErrNodeTxidEmpty = errors.New("node txid is empty") + ErrNodeParentTxidEmpty = errors.New("node parent txid is empty") + ErrNodeTxidDifferent = errors.New("node txid differs from node transaction") + ErrNumberOfInputs = errors.New("node transaction should have only one input") + ErrNumberOfOutputs = errors.New("node transaction should have only three or two outputs") + ErrParentTxidInput = errors.New("parent txid should be the input of the node transaction") + ErrNumberOfChildren = errors.New("node branch transaction should have two children") + ErrLeafChildren = errors.New("leaf node should have max 1 child") + ErrInvalidChildTxid = errors.New("invalid child txid") + ErrNumberOfTapscripts = errors.New("input should have two tapscripts leaves") + ErrInternalKey = errors.New("taproot internal key is not unspendable") + ErrInvalidTaprootScript = errors.New("invalid taproot script") + ErrInvalidLeafTaprootScript = errors.New("invalid leaf taproot script") + ErrInvalidAmount = errors.New("children amount is different from parent amount") + ErrInvalidAsset = errors.New("invalid output asset") + ErrInvalidSweepSequence = errors.New("invalid sweep sequence") + ErrInvalidASP = errors.New("invalid ASP") + ErrMissingFeeOutput = errors.New("missing fee output") + ErrInvalidLeftOutput = errors.New("invalid left output") + ErrInvalidRightOutput = errors.New("invalid right output") + ErrMissingSweepTapscript = errors.New("missing sweep tapscript") + ErrMissingBranchTapscript = errors.New("missing branch tapscript") + ErrInvalidLeaf = errors.New("leaf node shouldn't have children") + ErrWrongPoolTxID = errors.New("root input should be the pool tx outpoint") ) const ( @@ -63,24 +62,30 @@ const ( // - input and output amounts func ValidateCongestionTree( tree CongestionTree, - poolTxHex string, + poolTx string, aspPublicKey *secp256k1.PublicKey, roundLifetimeSeconds uint, ) error { unspendableKeyBytes, _ := hex.DecodeString(UnspendablePoint) unspendableKey, _ := secp256k1.ParsePubKey(unspendableKeyBytes) - poolTransaction, err := transaction.NewTxFromHex(poolTxHex) + poolTransaction, err := psetv2.NewPsetFromBase64(poolTx) if err != nil { return ErrInvalidPoolTransaction } - poolTxAmount, err := elementsutil.ValueFromBytes(poolTransaction.Outputs[sharedOutputIndex].Value) + if len(poolTransaction.Outputs) < sharedOutputIndex+1 { + return ErrInvalidPoolTransactionOutputs + } + + poolTxAmount := poolTransaction.Outputs[sharedOutputIndex].Value + + utx, err := poolTransaction.UnsignedTx() if err != nil { return ErrInvalidPoolTransaction } - poolTxID := poolTransaction.TxHash().String() + poolTxID := utx.TxHash().String() nbNodes := tree.NumberOfNodes() if nbNodes == 0 { diff --git a/noah/common.go b/noah/common.go index 3dcf108..1268b10 100644 --- a/noah/common.go +++ b/noah/common.go @@ -19,12 +19,10 @@ import ( "github.com/decred/dcrd/dcrec/secp256k1/v4" "github.com/urfave/cli/v2" "github.com/vulpemventures/go-elements/address" - "github.com/vulpemventures/go-elements/elementsutil" "github.com/vulpemventures/go-elements/network" "github.com/vulpemventures/go-elements/payment" "github.com/vulpemventures/go-elements/psetv2" "github.com/vulpemventures/go-elements/taproot" - "github.com/vulpemventures/go-elements/transaction" "golang.org/x/term" ) @@ -87,7 +85,7 @@ func privateKeyFromPassword() (*secp256k1.PrivateKey, error) { encryptedPrivateKey, err := hex.DecodeString(encryptedPrivateKeyString) if err != nil { - return nil, err + return nil, fmt.Errorf("invalid encrypted private key: %s", err) } password, err := readPassword() @@ -384,9 +382,10 @@ func handleRoundStream( if event.GetRoundFinalization() != nil { // stop pinging as soon as we receive some forfeit txs pingStop() + fmt.Println("round finalization started") poolPartialTx := event.GetRoundFinalization().GetPoolPartialTx() - poolTransaction, err := transaction.NewTxFromHex(poolPartialTx) + poolTransaction, err := psetv2.NewPsetFromBase64(poolPartialTx) if err != nil { return "", err } @@ -429,13 +428,8 @@ func handleRoundStream( found := false for _, output := range poolTransaction.Outputs { if bytes.Equal(output.Script, onchainScript) { - outputValue, err := elementsutil.ValueFromBytes(output.Value) - if err != nil { - return "", err - } - - if outputValue != receiver.Amount { - return "", fmt.Errorf("invalid collaborative exit output amount: got %d, want %d", outputValue, receiver.Amount) + if output.Value != receiver.Amount { + return "", fmt.Errorf("invalid collaborative exit output amount: got %d, want %d", output.Value, receiver.Amount) } found = true