Send unsigned pool transactions to clients (#85)

* replace Transfer by SelectUtxos

* Wallet.SignPset: handle unset WitnessUtxo

* fix linter

* renaming variables

* add witnessUtxo while creating pool transaction

* add EstimateFees in ports.Wallet

* replace createTestPoolTx by a constant pset
This commit is contained in:
Louis Singer
2024-01-24 16:08:50 +01:00
committed by GitHub
parent 5dba216a98
commit d4ee064245
11 changed files with 482 additions and 291 deletions

View File

@@ -266,7 +266,7 @@ func (s *service) startFinalization() {
return return
} }
signedPoolTx, tree, err := s.builder.BuildPoolTx(s.pubkey, s.wallet, payments, s.minRelayFee) unsignedPoolTx, tree, err := s.builder.BuildPoolTx(s.pubkey, s.wallet, payments, s.minRelayFee)
if err != nil { if err != nil {
changes = round.Fail(fmt.Errorf("failed to create pool tx: %s", err)) changes = round.Fail(fmt.Errorf("failed to create pool tx: %s", err))
log.WithError(err).Warn("failed to create pool tx") log.WithError(err).Warn("failed to create pool tx")
@@ -275,7 +275,7 @@ func (s *service) startFinalization() {
log.Debugf("pool tx created for round %s", round.Id) log.Debugf("pool tx created for round %s", round.Id)
connectors, forfeitTxs, err := s.builder.BuildForfeitTxs(s.pubkey, signedPoolTx, payments) connectors, forfeitTxs, err := s.builder.BuildForfeitTxs(s.pubkey, unsignedPoolTx, payments)
if err != nil { if err != nil {
changes = round.Fail(fmt.Errorf("failed to create connectors and forfeit txs: %s", err)) changes = round.Fail(fmt.Errorf("failed to create connectors and forfeit txs: %s", err))
log.WithError(err).Warn("failed to create connectors and forfeit txs") log.WithError(err).Warn("failed to create connectors and forfeit txs")
@@ -284,7 +284,7 @@ func (s *service) startFinalization() {
log.Debugf("forfeit transactions created for round %s", round.Id) log.Debugf("forfeit transactions created for round %s", round.Id)
events, err := round.StartFinalization(connectors, tree, signedPoolTx) events, err := round.StartFinalization(connectors, tree, unsignedPoolTx)
if err != nil { if err != nil {
changes = round.Fail(fmt.Errorf("failed to start finalization: %s", err)) changes = round.Fail(fmt.Errorf("failed to start finalization: %s", err))
log.WithError(err).Warn("failed to start finalization") log.WithError(err).Warn("failed to start finalization")
@@ -326,7 +326,15 @@ func (s *service) finalizeRound() {
return return
} }
txid, err := s.wallet.BroadcastTransaction(ctx, round.TxHex) log.Debugf("signing round transaction %s\n", round.Id)
signedPoolTx, err := s.wallet.SignPset(ctx, round.UnsignedTx, true)
if err != nil {
changes = round.Fail(fmt.Errorf("failed to sign round tx: %s", err))
log.WithError(err).Warn("failed to sign round tx")
return
}
txid, err := s.wallet.BroadcastTransaction(ctx, signedPoolTx)
if err != nil { if err != nil {
changes = round.Fail(fmt.Errorf("failed to broadcast pool tx: %s", err)) changes = round.Fail(fmt.Errorf("failed to broadcast pool tx: %s", err))
log.WithError(err).Warn("failed to broadcast pool tx") log.WithError(err).Warn("failed to broadcast pool tx")

View File

@@ -40,7 +40,7 @@ type Round struct {
Stage Stage Stage Stage
Payments map[string]Payment Payments map[string]Payment
Txid string Txid string
TxHex string UnsignedTx string
ForfeitTxs []string ForfeitTxs []string
CongestionTree tree.CongestionTree CongestionTree tree.CongestionTree
Connectors []string Connectors []string
@@ -84,7 +84,7 @@ func (r *Round) On(event RoundEvent, replayed bool) {
r.Stage.Code = FinalizationStage r.Stage.Code = FinalizationStage
r.CongestionTree = e.CongestionTree r.CongestionTree = e.CongestionTree
r.Connectors = append([]string{}, e.Connectors...) r.Connectors = append([]string{}, e.Connectors...)
r.TxHex = e.PoolTx r.UnsignedTx = e.PoolTx
case RoundFinalized: case RoundFinalized:
r.Stage.Ended = true r.Stage.Ended = true
r.Txid = e.Txid r.Txid = e.Txid

View File

@@ -13,8 +13,9 @@ type WalletService interface {
SignPset( SignPset(
ctx context.Context, pset string, extractRawTx bool, ctx context.Context, pset string, extractRawTx bool,
) (string, error) ) (string, error)
Transfer(ctx context.Context, outs []TxOutput) (string, error) SelectUtxos(ctx context.Context, asset string, amount uint64) ([]TxInput, uint64, error)
BroadcastTransaction(ctx context.Context, txHex string) (string, error) BroadcastTransaction(ctx context.Context, txHex string) (string, error)
EstimateFees(ctx context.Context, pset string) (uint64, error)
Close() Close()
} }
@@ -28,12 +29,6 @@ type TxInput interface {
GetTxid() string GetTxid() string
GetIndex() uint32 GetIndex() uint32
GetScript() string GetScript() string
GetScriptSigSize() int
GetWitnessSize() int
}
type TxOutput interface {
GetAmount() uint64
GetAsset() string GetAsset() string
GetScript() string GetValue() uint64
} }

View File

@@ -377,7 +377,7 @@ func roundsMatch(expected, got domain.Round) assert.Comparison {
if expected.Txid != got.Txid { if expected.Txid != got.Txid {
return false return false
} }
if expected.TxHex != got.TxHex { if expected.UnsignedTx != got.UnsignedTx {
return false return false
} }
if !reflect.DeepEqual(expected.ForfeitTxs, got.ForfeitTxs) { if !reflect.DeepEqual(expected.ForfeitTxs, got.ForfeitTxs) {

View File

@@ -2,14 +2,19 @@ package oceanwallet
import ( import (
"context" "context"
"encoding/hex"
"fmt" "fmt"
pb "github.com/ark-network/ark/api-spec/protobuf/gen/ocean/v1" pb "github.com/ark-network/ark/api-spec/protobuf/gen/ocean/v1"
"github.com/ark-network/ark/internal/core/ports" "github.com/ark-network/ark/internal/core/ports"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/vulpemventures/go-elements/elementsutil"
"github.com/vulpemventures/go-elements/psetv2" "github.com/vulpemventures/go-elements/psetv2"
) )
const msatsPerByte = 110 const (
zero32 = "0000000000000000000000000000000000000000000000000000000000000000"
)
func (s *service) SignPset( func (s *service) SignPset(
ctx context.Context, pset string, extractRawTx bool, ctx context.Context, pset string, extractRawTx bool,
@@ -21,29 +26,54 @@ func (s *service) SignPset(
return "", err return "", err
} }
signedPset := res.GetPset() signedPset := res.GetPset()
if !extractRawTx { if !extractRawTx {
return signedPset, nil return signedPset, nil
} }
ptx, _ := psetv2.NewPsetFromBase64(signedPset) ptx, err := psetv2.NewPsetFromBase64(signedPset)
if err := psetv2.MaybeFinalizeAll(ptx); err != nil {
return "", fmt.Errorf("failed to finalize signed pset: %s", err)
}
return ptx.ToBase64()
}
func (s *service) Transfer(
ctx context.Context, outs []ports.TxOutput,
) (string, error) {
res, err := s.txClient.Transfer(ctx, &pb.TransferRequest{
AccountName: accountLabel,
Receivers: outputList(outs).toProto(),
MillisatsPerByte: msatsPerByte,
})
if err != nil { if err != nil {
return "", err return "", err
} }
return res.GetTxHex(), nil
if err := psetv2.MaybeFinalizeAll(ptx); err != nil {
return "", fmt.Errorf("failed to finalize signed pset: %s", err)
}
extractedTx, err := psetv2.Extract(ptx)
if err != nil {
return "", fmt.Errorf("failed to extract signed pset: %s", err)
}
txHex, err := extractedTx.ToHex()
if err != nil {
return "", fmt.Errorf("failed to convert extracted tx to hex: %s", err)
}
return txHex, nil
}
func (s *service) SelectUtxos(ctx context.Context, asset string, amount uint64) ([]ports.TxInput, uint64, error) {
res, err := s.txClient.SelectUtxos(ctx, &pb.SelectUtxosRequest{
AccountName: accountLabel,
TargetAsset: asset,
TargetAmount: amount,
})
if err != nil {
return nil, 0, err
}
inputs := make([]ports.TxInput, 0, len(res.GetUtxos()))
for _, utxo := range res.GetUtxos() {
// check that the utxos are not confidential
if utxo.GetAssetBlinder() != zero32 || utxo.GetValueBlinder() != zero32 {
return nil, 0, fmt.Errorf("utxo is confidential")
}
inputs = append(inputs, utxo)
}
return inputs, res.GetChange(), nil
} }
func (s *service) BroadcastTransaction( func (s *service) BroadcastTransaction(
@@ -60,16 +90,50 @@ func (s *service) BroadcastTransaction(
return res.GetTxid(), nil return res.GetTxid(), nil
} }
type outputList []ports.TxOutput func (s *service) EstimateFees(
ctx context.Context, pset string,
) (uint64, error) {
tx, err := psetv2.NewPsetFromBase64(pset)
if err != nil {
return 0, err
}
func (l outputList) toProto() []*pb.Output { inputs := make([]*pb.Input, 0, len(tx.Inputs))
list := make([]*pb.Output, 0, len(l)) outputs := make([]*pb.Output, 0, len(tx.Outputs))
for _, out := range l {
list = append(list, &pb.Output{ for _, in := range tx.Inputs {
Amount: out.GetAmount(), if in.WitnessUtxo == nil {
Script: out.GetScript(), return 0, fmt.Errorf("missing witness utxo, cannot estimate fees")
Asset: out.GetAsset(), }
inputs = append(inputs, &pb.Input{
Txid: chainhash.Hash(in.PreviousTxid).String(),
Index: in.PreviousTxIndex,
Script: hex.EncodeToString(in.WitnessUtxo.Script),
}) })
} }
return list
for _, out := range tx.Outputs {
outputs = append(outputs, &pb.Output{
Asset: elementsutil.AssetHashFromBytes(
append([]byte{0x01}, out.Asset...),
),
Amount: out.Value,
Script: hex.EncodeToString(out.Script),
})
}
fee, err := s.txClient.EstimateFees(
ctx,
&pb.EstimateFeesRequest{
Inputs: inputs,
Outputs: outputs,
},
)
if err != nil {
return 0, fmt.Errorf("failed to estimate fees: %s", err)
}
// we add 5 sats in order to avoid min-relay-fee not met errors
return fee.GetFeeAmount() + 5, nil
} }

View File

@@ -3,10 +3,12 @@ package txbuilder
import ( import (
"context" "context"
"encoding/hex" "encoding/hex"
"fmt"
"github.com/ark-network/ark/common/tree" "github.com/ark-network/ark/common/tree"
"github.com/ark-network/ark/internal/core/domain" "github.com/ark-network/ark/internal/core/domain"
"github.com/ark-network/ark/internal/core/ports" "github.com/ark-network/ark/internal/core/ports"
"github.com/btcsuite/btcd/txscript"
"github.com/decred/dcrd/dcrec/secp256k1/v4" "github.com/decred/dcrd/dcrec/secp256k1/v4"
"github.com/vulpemventures/go-elements/address" "github.com/vulpemventures/go-elements/address"
"github.com/vulpemventures/go-elements/elementsutil" "github.com/vulpemventures/go-elements/elementsutil"
@@ -21,6 +23,8 @@ const (
connectorAmount = 450 connectorAmount = 450
) )
var emptyNonce = []byte{0x00}
type txBuilder struct { type txBuilder struct {
net *network.Network net *network.Network
} }
@@ -44,11 +48,7 @@ func p2wpkhScript(publicKey *secp256k1.PublicKey, net *network.Network) ([]byte,
func getTxid(txStr string) (string, error) { func getTxid(txStr string) (string, error) {
pset, err := psetv2.NewPsetFromBase64(txStr) pset, err := psetv2.NewPsetFromBase64(txStr)
if err != nil { if err != nil {
tx, err := transaction.NewTxFromHex(txStr) return "", err
if err != nil {
return "", err
}
return tx.TxHash().String(), nil
} }
utx, err := pset.UnsignedTx() utx, err := pset.UnsignedTx()
@@ -116,7 +116,7 @@ func (b *txBuilder) BuildForfeitTxs(
pubkeyBytes, err := hex.DecodeString(vtxo.Pubkey) pubkeyBytes, err := hex.DecodeString(vtxo.Pubkey)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, fmt.Errorf("failed to decode pubkey: %s", err)
} }
vtxoPubkey, err := secp256k1.ParsePubKey(pubkeyBytes) vtxoPubkey, err := secp256k1.ParsePubKey(pubkeyBytes)
@@ -171,8 +171,6 @@ func (b *txBuilder) BuildPoolTx(
return return
} }
aspScript := hex.EncodeToString(aspScriptBytes)
offchainReceivers, onchainReceivers := receiversFromPayments(payments) offchainReceivers, onchainReceivers := receiversFromPayments(payments)
numberOfConnectors := numberOfVTXOs(payments) numberOfConnectors := numberOfVTXOs(payments)
connectorOutputAmount := connectorAmount * numberOfConnectors connectorOutputAmount := connectorAmount * numberOfConnectors
@@ -189,38 +187,165 @@ func (b *txBuilder) BuildPoolTx(
return return
} }
sharedOutputScriptHex := hex.EncodeToString(sharedOutputScript) outputs := []psetv2.OutputArgs{
{
poolTxOuts := []ports.TxOutput{ Asset: b.net.AssetID,
newOutput(sharedOutputScriptHex, sharedOutputAmount, b.net.AssetID), Amount: sharedOutputAmount,
newOutput(aspScript, connectorOutputAmount, b.net.AssetID), Script: sharedOutputScript,
},
{
Asset: b.net.AssetID,
Amount: connectorOutputAmount,
Script: aspScriptBytes,
},
} }
targetAmount := sharedOutputAmount + connectorOutputAmount
for _, receiver := range onchainReceivers { for _, receiver := range onchainReceivers {
buf, _ := address.ToOutputScript(receiver.OnchainAddress) targetAmount += receiver.Amount
script := hex.EncodeToString(buf)
poolTxOuts = append(poolTxOuts, newOutput(script, receiver.Amount, b.net.AssetID)) receiverScript, err := address.ToOutputScript(receiver.OnchainAddress)
if err != nil {
return "", nil, err
}
outputs = append(outputs, psetv2.OutputArgs{
Asset: b.net.AssetID,
Amount: receiver.Amount,
Script: receiverScript,
})
} }
txHex, err := wallet.Transfer(ctx, poolTxOuts) utxos, change, err := wallet.SelectUtxos(ctx, b.net.AssetID, targetAmount)
if err != nil { if err != nil {
return return
} }
tx, err := transaction.NewTxFromHex(txHex) if change > 0 {
outputs = append(outputs, psetv2.OutputArgs{
Asset: b.net.AssetID,
Amount: change,
Script: aspScriptBytes,
})
}
ptx, err := psetv2.New(toInputArgs(utxos), outputs, nil)
if err != nil {
return
}
updater, err := psetv2.NewUpdater(ptx)
if err != nil {
return
}
for i, utxo := range utxos {
witnessUtxo, err := toWitnessUtxo(utxo)
if err != nil {
return "", nil, err
}
if err := updater.AddInWitnessUtxo(i, witnessUtxo); err != nil {
return "", nil, err
}
if err := updater.AddInSighashType(i, txscript.SigHashAll); err != nil {
return "", nil, err
}
}
b64, err := ptx.ToBase64()
if err != nil {
return
}
feesAmount, err := wallet.EstimateFees(ctx, b64)
if err != nil {
return
}
if feesAmount == change {
// fees = change, remove change output
updater.Pset.Outputs = ptx.Outputs[:len(ptx.Outputs)-1]
} else if feesAmount < change {
// change covers the fees, reduce change amount
updater.Pset.Outputs[len(ptx.Outputs)-1].Value = change - feesAmount
} else {
// change is not enough to cover fees, re-select utxos
if change > 0 {
// remove change output if present
updater.Pset.Outputs = ptx.Outputs[:len(ptx.Outputs)-1]
}
newUtxos, newChange, err := wallet.SelectUtxos(ctx, b.net.AssetID, feesAmount-change)
if err != nil {
return "", nil, err
}
if err := updater.AddInputs(toInputArgs(newUtxos)); err != nil {
return "", nil, err
}
if newChange > 0 {
if err := updater.AddOutputs([]psetv2.OutputArgs{
{
Asset: b.net.AssetID,
Amount: newChange,
Script: aspScriptBytes,
},
}); err != nil {
return "", nil, err
}
}
nbInputs := len(utxos)
for i, utxo := range newUtxos {
witnessUtxo, err := toWitnessUtxo(utxo)
if err != nil {
return "", nil, err
}
if err := updater.AddInWitnessUtxo(i+nbInputs, witnessUtxo); err != nil {
return "", nil, err
}
if err := updater.AddInSighashType(i+nbInputs, txscript.SigHashAll); err != nil {
return "", nil, err
}
}
}
// add fee output
if err := updater.AddOutputs([]psetv2.OutputArgs{
{
Asset: b.net.AssetID,
Amount: feesAmount,
},
}); err != nil {
return "", nil, err
}
utx, err := ptx.UnsignedTx()
if err != nil { if err != nil {
return return
} }
tree, err := makeTree(psetv2.InputArgs{ tree, err := makeTree(psetv2.InputArgs{
Txid: tx.TxHash().String(), Txid: utx.TxHash().String(),
TxIndex: 0, TxIndex: 0,
}) })
if err != nil { if err != nil {
return return
} }
poolTx = txHex poolTx, err = ptx.ToBase64()
if err != nil {
return
}
congestionTree = tree congestionTree = tree
return return
} }
@@ -321,28 +446,41 @@ func receiversFromPayments(
return return
} }
type output struct { func toInputArgs(
script string ins []ports.TxInput,
amount uint64 ) []psetv2.InputArgs {
asset string inputs := make([]psetv2.InputArgs, 0, len(ins))
} for _, in := range ins {
inputs = append(inputs, psetv2.InputArgs{
func newOutput(script string, amount uint64, asset string) ports.TxOutput { Txid: in.GetTxid(),
return &output{ TxIndex: in.GetIndex(),
script: script, })
amount: amount,
asset: asset,
} }
return inputs
} }
func (o *output) GetAsset() string { func toWitnessUtxo(in ports.TxInput) (*transaction.TxOutput, error) {
return o.asset valueBytes, err := elementsutil.ValueToBytes(in.GetValue())
} if err != nil {
return nil, fmt.Errorf("failed to convert value to bytes: %s", err)
}
func (o *output) GetAmount() uint64 { assetBytes, err := elementsutil.AssetHashToBytes(in.GetAsset())
return o.amount if err != nil {
} return nil, fmt.Errorf("failed to convert asset to bytes: %s", err)
}
func (o *output) GetScript() string { scriptBytes, err := hex.DecodeString(in.GetScript())
return o.script if err != nil {
return nil, fmt.Errorf("failed to decode script: %s", err)
}
return &transaction.TxOutput{
Asset: assetBytes,
Value: valueBytes,
Script: scriptBytes,
Nonce: emptyNonce,
RangeProof: nil,
SurjectionProof: nil,
}, err
} }

View File

@@ -2,6 +2,8 @@ package txbuilder_test
import ( import (
"context" "context"
"crypto/rand"
"encoding/hex"
"testing" "testing"
"github.com/ark-network/ark/common" "github.com/ark-network/ark/common"
@@ -13,75 +15,40 @@ import (
secp256k1 "github.com/decred/dcrd/dcrec/secp256k1/v4" secp256k1 "github.com/decred/dcrd/dcrec/secp256k1/v4"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/vulpemventures/go-elements/network" "github.com/vulpemventures/go-elements/network"
"github.com/vulpemventures/go-elements/payment"
"github.com/vulpemventures/go-elements/psetv2" "github.com/vulpemventures/go-elements/psetv2"
"github.com/vulpemventures/go-elements/transaction"
) )
const ( const (
testingKey = "apub1qgvdtj5ttpuhkldavhq8thtm5auyk0ec4dcmrfdgu0u5hgp9we22v3hrs4x" testingKey = "apub1qgvdtj5ttpuhkldavhq8thtm5auyk0ec4dcmrfdgu0u5hgp9we22v3hrs4x"
fakePoolTx = "cHNldP8BAgQCAAAAAQQBAQEFAQMBBgEDAfsEAgAAAAABDiDk7dXxh4KQzgLO8i1ABtaLCe4aPL12GVhN1E9zM1ePLwEPBAAAAAABEAT/////AAEDCOgDAAAAAAAAAQQWABSNnpy01UJqd99eTg2M1IpdKId11gf8BHBzZXQCICWyUQcOKcoZBDzzPM1zJOLdqwPsxK4LXnfE/A5c9slaB/wEcHNldAgEAAAAAAABAwh4BQAAAAAAAAEEFgAUjZ6ctNVCanffXk4NjNSKXSiHddYH/ARwc2V0AiAlslEHDinKGQQ88zzNcyTi3asD7MSuC153xPwOXPbJWgf8BHBzZXQIBAAAAAAAAQMI9AEAAAAAAAABBAAH/ARwc2V0AiAlslEHDinKGQQ88zzNcyTi3asD7MSuC153xPwOXPbJWgf8BHBzZXQIBAAAAAAA"
) )
func createTestPoolTx(sharedOutputAmount, numberOfInputs uint64) (string, error) { type mockedWalletService struct{}
_, key, err := common.DecodePubKey(testingKey)
if err != nil {
return "", err
}
payment := payment.FromPublicKey(key, &network.Testnet, nil) type input struct {
script := payment.WitnessScript txid string
vout uint32
pset, err := psetv2.New(nil, nil, nil)
if err != nil {
return "", err
}
updater, err := psetv2.NewUpdater(pset)
if err != nil {
return "", err
}
err = updater.AddInputs([]psetv2.InputArgs{
{
Txid: "2f8f5733734fd44d581976bd3c1aee098bd606402df2ce02ce908287f1d5ede4",
TxIndex: 0,
},
})
if err != nil {
return "", err
}
connectorsAmount := numberOfInputs*450 + 500
err = updater.AddOutputs([]psetv2.OutputArgs{
{
Asset: network.Regtest.AssetID,
Amount: sharedOutputAmount,
Script: script,
},
{
Asset: network.Regtest.AssetID,
Amount: connectorsAmount,
Script: script,
},
{
Asset: network.Regtest.AssetID,
Amount: 500,
},
})
if err != nil {
return "", err
}
utx, err := pset.UnsignedTx()
if err != nil {
return "", err
}
return utx.ToHex()
} }
type mockedWalletService struct{} func (i *input) GetTxid() string {
return i.txid
}
func (i *input) GetIndex() uint32 {
return i.vout
}
func (i *input) GetScript() string {
return "a914ea9f486e82efb3dd83a69fd96e3f0113757da03c87"
}
func (i *input) GetAsset() string {
return "5ac9f65c0efcc4775e0baec4ec03abdde22473cd3cf33c0419ca290e0751b225"
}
func (i *input) GetValue() uint64 {
return 1000
}
// BroadcastTransaction implements ports.WalletService. // BroadcastTransaction implements ports.WalletService.
func (*mockedWalletService) BroadcastTransaction(ctx context.Context, txHex string) (string, error) { func (*mockedWalletService) BroadcastTransaction(ctx context.Context, txHex string) (string, error) {
@@ -113,9 +80,22 @@ func (*mockedWalletService) Status(ctx context.Context) (ports.WalletStatus, err
panic("unimplemented") panic("unimplemented")
} }
// Transfer implements ports.WalletService. func (*mockedWalletService) SelectUtxos(ctx context.Context, asset string, amount uint64) ([]ports.TxInput, uint64, error) {
func (*mockedWalletService) Transfer(ctx context.Context, outs []ports.TxOutput) (string, error) { // random txid
return createTestPoolTx(outs[0].GetAmount(), 1) bytes := make([]byte, 32)
if _, err := rand.Read(bytes); err != nil {
return nil, 0, err
}
fakeInput := input{
txid: hex.EncodeToString(bytes),
vout: 0,
}
return []ports.TxInput{&fakeInput}, 0, nil
}
func (*mockedWalletService) EstimateFees(ctx context.Context, pset string) (uint64, error) {
return 100, nil
} }
func TestBuildCongestionTree(t *testing.T) { func TestBuildCongestionTree(t *testing.T) {
@@ -373,14 +353,13 @@ func TestBuildCongestionTree(t *testing.T) {
func TestBuildForfeitTxs(t *testing.T) { func TestBuildForfeitTxs(t *testing.T) {
builder := txbuilder.NewTxBuilder(network.Liquid) builder := txbuilder.NewTxBuilder(network.Liquid)
// TODO: replace with fixture. poolTx, err := psetv2.NewPsetFromBase64(fakePoolTx)
poolTxHex, err := createTestPoolTx(1000, 2)
require.NoError(t, err) require.NoError(t, err)
poolTx, err := transaction.NewTxFromHex(poolTxHex) utx, err := poolTx.UnsignedTx()
require.NoError(t, err) require.NoError(t, err)
poolTxid := poolTx.TxHash().String() poolTxid := utx.TxHash().String()
fixtures := []struct { fixtures := []struct {
payments []domain.Payment payments []domain.Payment
@@ -436,7 +415,7 @@ func TestBuildForfeitTxs(t *testing.T) {
for _, f := range fixtures { for _, f := range fixtures {
connectors, forfeitTxs, err := builder.BuildForfeitTxs( connectors, forfeitTxs, err := builder.BuildForfeitTxs(
key, poolTxHex, f.payments, key, fakePoolTx, f.payments,
) )
require.NoError(t, err) require.NoError(t, err)
require.Len(t, connectors, f.expectedNumOfConnectors) require.Len(t, connectors, f.expectedNumOfConnectors)

View File

@@ -2,7 +2,6 @@ package txbuilder
import ( import (
"context" "context"
"encoding/hex"
"github.com/ark-network/ark/common/tree" "github.com/ark-network/ark/common/tree"
"github.com/ark-network/ark/internal/core/domain" "github.com/ark-network/ark/internal/core/domain"
@@ -99,42 +98,81 @@ func (b *txBuilder) BuildPoolTx(
return "", nil, err return "", nil, err
} }
aspScript := hex.EncodeToString(aspScriptBytes)
offchainReceivers, onchainReceivers := receiversFromPayments(payments) offchainReceivers, onchainReceivers := receiversFromPayments(payments)
sharedOutputAmount := sumReceivers(offchainReceivers) sharedOutputAmount := sumReceivers(offchainReceivers)
numberOfConnectors := numberOfVTXOs(payments) numberOfConnectors := numberOfVTXOs(payments)
connectorOutputAmount := connectorAmount * numberOfConnectors connectorOutputAmount := connectorAmount * numberOfConnectors
poolTxOuts := []ports.TxOutput{
newOutput(aspScript, sharedOutputAmount, b.net.AssetID),
newOutput(aspScript, connectorOutputAmount, b.net.AssetID),
}
for _, receiver := range onchainReceivers {
buf, _ := address.ToOutputScript(receiver.OnchainAddress)
script := hex.EncodeToString(buf)
poolTxOuts = append(poolTxOuts, newOutput(script, receiver.Amount, b.net.AssetID))
}
ctx := context.Background() ctx := context.Background()
poolTx, err = wallet.Transfer(ctx, poolTxOuts) outputs := []psetv2.OutputArgs{
if err != nil { {
return "", nil, err Asset: b.net.AssetID,
Amount: sharedOutputAmount,
Script: aspScriptBytes,
},
{
Asset: b.net.AssetID,
Amount: connectorOutputAmount,
Script: aspScriptBytes,
},
} }
poolTxID, err := getTxid(poolTx) amountToSelect := sharedOutputAmount + connectorOutputAmount
for _, receiver := range onchainReceivers {
amountToSelect += receiver.Amount
receiverScript, err := address.ToOutputScript(receiver.OnchainAddress)
if err != nil {
return "", nil, err
}
outputs = append(outputs, psetv2.OutputArgs{
Asset: b.net.AssetID,
Amount: receiver.Amount,
Script: receiverScript,
})
}
utxos, change, err := wallet.SelectUtxos(ctx, b.net.AssetID, amountToSelect)
if err != nil { if err != nil {
return "", nil, err return
}
if change > 0 {
outputs = append(outputs, psetv2.OutputArgs{
Asset: b.net.AssetID,
Amount: change,
Script: aspScriptBytes,
})
}
ptx, err := psetv2.New(toInputArgs(utxos), outputs, nil)
if err != nil {
return
}
utx, err := ptx.UnsignedTx()
if err != nil {
return
} }
congestionTree, err = buildCongestionTree( congestionTree, err = buildCongestionTree(
newOutputScriptFactory(aspPubkey, b.net), newOutputScriptFactory(aspPubkey, b.net),
b.net, b.net,
poolTxID, utx.TxHash().String(),
offchainReceivers, offchainReceivers,
) )
if err != nil {
return
}
poolTx, err = ptx.ToBase64()
if err != nil {
return
}
return poolTx, congestionTree, err return poolTx, congestionTree, err
} }
@@ -219,28 +257,15 @@ func sumReceivers(receivers []domain.Receiver) uint64 {
return sum return sum
} }
type output struct { func toInputArgs(
script string ins []ports.TxInput,
amount uint64 ) []psetv2.InputArgs {
asset string inputs := make([]psetv2.InputArgs, 0, len(ins))
} for _, in := range ins {
inputs = append(inputs, psetv2.InputArgs{
func newOutput(script string, amount uint64, asset string) ports.TxOutput { Txid: in.GetTxid(),
return &output{ TxIndex: in.GetIndex(),
script: script, })
amount: amount,
asset: asset,
} }
} return inputs
func (o *output) GetAmount() uint64 {
return o.amount
}
func (o *output) GetAsset() string {
return o.asset
}
func (o *output) GetScript() string {
return o.script
} }

View File

@@ -2,6 +2,8 @@ package txbuilder_test
import ( import (
"context" "context"
"crypto/rand"
"encoding/hex"
"testing" "testing"
"github.com/ark-network/ark/common" "github.com/ark-network/ark/common"
@@ -12,66 +14,37 @@ import (
secp256k1 "github.com/decred/dcrd/dcrec/secp256k1/v4" secp256k1 "github.com/decred/dcrd/dcrec/secp256k1/v4"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/vulpemventures/go-elements/network" "github.com/vulpemventures/go-elements/network"
"github.com/vulpemventures/go-elements/payment"
"github.com/vulpemventures/go-elements/psetv2" "github.com/vulpemventures/go-elements/psetv2"
) )
const ( const (
testingKey = "apub1qgvdtj5ttpuhkldavhq8thtm5auyk0ec4dcmrfdgu0u5hgp9we22v3hrs4x" testingKey = "apub1qgvdtj5ttpuhkldavhq8thtm5auyk0ec4dcmrfdgu0u5hgp9we22v3hrs4x"
fakePoolTx = "cHNldP8BAgQCAAAAAQQBAQEFAQMBBgEDAfsEAgAAAAABDiDk7dXxh4KQzgLO8i1ABtaLCe4aPL12GVhN1E9zM1ePLwEPBAAAAAABEAT/////AAEDCOgDAAAAAAAAAQQWABSNnpy01UJqd99eTg2M1IpdKId11gf8BHBzZXQCICWyUQcOKcoZBDzzPM1zJOLdqwPsxK4LXnfE/A5c9slaB/wEcHNldAgEAAAAAAABAwh4BQAAAAAAAAEEFgAUjZ6ctNVCanffXk4NjNSKXSiHddYH/ARwc2V0AiAlslEHDinKGQQ88zzNcyTi3asD7MSuC153xPwOXPbJWgf8BHBzZXQIBAAAAAAAAQMI9AEAAAAAAAABBAAH/ARwc2V0AiAlslEHDinKGQQ88zzNcyTi3asD7MSuC153xPwOXPbJWgf8BHBzZXQIBAAAAAAA"
) )
func createTestPoolTx(sharedOutputAmount, numberOfInputs uint64) (string, error) { type input struct {
_, key, err := common.DecodePubKey(testingKey) txid string
if err != nil { vout uint32
return "", err }
}
payment := payment.FromPublicKey(key, &network.Testnet, nil) func (i *input) GetTxid() string {
script := payment.WitnessScript return i.txid
}
pset, err := psetv2.New(nil, nil, nil) func (i *input) GetIndex() uint32 {
if err != nil { return i.vout
return "", err }
}
updater, err := psetv2.NewUpdater(pset) func (i *input) GetScript() string {
if err != nil { return "a914ea9f486e82efb3dd83a69fd96e3f0113757da03c87"
return "", err }
}
err = updater.AddInputs([]psetv2.InputArgs{ func (i *input) GetAsset() string {
{ return "5ac9f65c0efcc4775e0baec4ec03abdde22473cd3cf33c0419ca290e0751b225"
Txid: "2f8f5733734fd44d581976bd3c1aee098bd606402df2ce02ce908287f1d5ede4", }
TxIndex: 0,
},
})
if err != nil {
return "", err
}
connectorsAmount := numberOfInputs * (450 + 500) func (i *input) GetValue() uint64 {
return 1000
err = updater.AddOutputs([]psetv2.OutputArgs{
{
Asset: network.Regtest.AssetID,
Amount: sharedOutputAmount,
Script: script,
},
{
Asset: network.Regtest.AssetID,
Amount: connectorsAmount,
Script: script,
},
{
Asset: network.Regtest.AssetID,
Amount: 500,
},
})
if err != nil {
return "", err
}
return pset.ToBase64()
} }
type mockedWalletService struct{} type mockedWalletService struct{}
@@ -106,9 +79,22 @@ func (*mockedWalletService) Status(ctx context.Context) (ports.WalletStatus, err
panic("unimplemented") panic("unimplemented")
} }
// Transfer implements ports.WalletService. func (*mockedWalletService) SelectUtxos(ctx context.Context, asset string, amount uint64) ([]ports.TxInput, uint64, error) {
func (*mockedWalletService) Transfer(ctx context.Context, outs []ports.TxOutput) (string, error) { // random txid
return createTestPoolTx(1000, (450+500)*1) bytes := make([]byte, 32)
if _, err := rand.Read(bytes); err != nil {
return nil, 0, err
}
fakeInput := input{
txid: hex.EncodeToString(bytes),
vout: 0,
}
return []ports.TxInput{&fakeInput}, 0, nil
}
func (*mockedWalletService) EstimateFees(ctx context.Context, pset string) (uint64, error) {
return 100, nil
} }
func TestBuildCongestionTree(t *testing.T) { func TestBuildCongestionTree(t *testing.T) {
@@ -290,10 +276,7 @@ func TestBuildCongestionTree(t *testing.T) {
func TestBuildForfeitTxs(t *testing.T) { func TestBuildForfeitTxs(t *testing.T) {
builder := txbuilder.NewTxBuilder(network.Liquid) builder := txbuilder.NewTxBuilder(network.Liquid)
poolTx, err := createTestPoolTx(1000, 450*2) poolPset, err := psetv2.NewPsetFromBase64(fakePoolTx)
require.NoError(t, err)
poolPset, err := psetv2.NewPsetFromBase64(poolTx)
require.NoError(t, err) require.NoError(t, err)
poolTxUnsigned, err := poolPset.UnsignedTx() poolTxUnsigned, err := poolPset.UnsignedTx()
@@ -355,7 +338,7 @@ func TestBuildForfeitTxs(t *testing.T) {
for _, f := range fixtures { for _, f := range fixtures {
connectors, forfeitTxs, err := builder.BuildForfeitTxs( connectors, forfeitTxs, err := builder.BuildForfeitTxs(
key, poolTx, f.payments, key, fakePoolTx, f.payments,
) )
require.NoError(t, err) require.NoError(t, err)
require.Len(t, connectors, f.expectedNumOfConnectors) require.Len(t, connectors, f.expectedNumOfConnectors)

View File

@@ -9,42 +9,41 @@ import (
"github.com/btcsuite/btcd/btcec/v2/schnorr" "github.com/btcsuite/btcd/btcec/v2/schnorr"
"github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/decred/dcrd/dcrec/secp256k1/v4" "github.com/decred/dcrd/dcrec/secp256k1/v4"
"github.com/vulpemventures/go-elements/elementsutil"
"github.com/vulpemventures/go-elements/psetv2" "github.com/vulpemventures/go-elements/psetv2"
"github.com/vulpemventures/go-elements/taproot" "github.com/vulpemventures/go-elements/taproot"
"github.com/vulpemventures/go-elements/transaction"
) )
var ( var (
ErrInvalidPoolTransaction = errors.New("invalid pool transaction") ErrInvalidPoolTransaction = errors.New("invalid pool transaction")
ErrEmptyTree = errors.New("empty congestion tree") ErrInvalidPoolTransactionOutputs = errors.New("invalid number of outputs in pool transaction")
ErrInvalidRootLevel = errors.New("root level must have only one node") ErrEmptyTree = errors.New("empty congestion tree")
ErrNoLeaves = errors.New("no leaves in the tree") ErrInvalidRootLevel = errors.New("root level must have only one node")
ErrNodeTransactionEmpty = errors.New("node transaction is empty") ErrNoLeaves = errors.New("no leaves in the tree")
ErrNodeTxidEmpty = errors.New("node txid is empty") ErrNodeTransactionEmpty = errors.New("node transaction is empty")
ErrNodeParentTxidEmpty = errors.New("node parent txid is empty") ErrNodeTxidEmpty = errors.New("node txid is empty")
ErrNodeTxidDifferent = errors.New("node txid differs from node transaction") ErrNodeParentTxidEmpty = errors.New("node parent txid is empty")
ErrNumberOfInputs = errors.New("node transaction should have only one input") ErrNodeTxidDifferent = errors.New("node txid differs from node transaction")
ErrNumberOfOutputs = errors.New("node transaction should have only three or two outputs") ErrNumberOfInputs = errors.New("node transaction should have only one input")
ErrParentTxidInput = errors.New("parent txid should be the input of the node transaction") ErrNumberOfOutputs = errors.New("node transaction should have only three or two outputs")
ErrNumberOfChildren = errors.New("node branch transaction should have two children") ErrParentTxidInput = errors.New("parent txid should be the input of the node transaction")
ErrLeafChildren = errors.New("leaf node should have max 1 child") ErrNumberOfChildren = errors.New("node branch transaction should have two children")
ErrInvalidChildTxid = errors.New("invalid child txid") ErrLeafChildren = errors.New("leaf node should have max 1 child")
ErrNumberOfTapscripts = errors.New("input should have two tapscripts leaves") ErrInvalidChildTxid = errors.New("invalid child txid")
ErrInternalKey = errors.New("taproot internal key is not unspendable") ErrNumberOfTapscripts = errors.New("input should have two tapscripts leaves")
ErrInvalidTaprootScript = errors.New("invalid taproot script") ErrInternalKey = errors.New("taproot internal key is not unspendable")
ErrInvalidLeafTaprootScript = errors.New("invalid leaf taproot script") ErrInvalidTaprootScript = errors.New("invalid taproot script")
ErrInvalidAmount = errors.New("children amount is different from parent amount") ErrInvalidLeafTaprootScript = errors.New("invalid leaf taproot script")
ErrInvalidAsset = errors.New("invalid output asset") ErrInvalidAmount = errors.New("children amount is different from parent amount")
ErrInvalidSweepSequence = errors.New("invalid sweep sequence") ErrInvalidAsset = errors.New("invalid output asset")
ErrInvalidASP = errors.New("invalid ASP") ErrInvalidSweepSequence = errors.New("invalid sweep sequence")
ErrMissingFeeOutput = errors.New("missing fee output") ErrInvalidASP = errors.New("invalid ASP")
ErrInvalidLeftOutput = errors.New("invalid left output") ErrMissingFeeOutput = errors.New("missing fee output")
ErrInvalidRightOutput = errors.New("invalid right output") ErrInvalidLeftOutput = errors.New("invalid left output")
ErrMissingSweepTapscript = errors.New("missing sweep tapscript") ErrInvalidRightOutput = errors.New("invalid right output")
ErrMissingBranchTapscript = errors.New("missing branch tapscript") ErrMissingSweepTapscript = errors.New("missing sweep tapscript")
ErrInvalidLeaf = errors.New("leaf node shouldn't have children") ErrMissingBranchTapscript = errors.New("missing branch tapscript")
ErrWrongPoolTxID = errors.New("root input should be the pool tx outpoint") ErrInvalidLeaf = errors.New("leaf node shouldn't have children")
ErrWrongPoolTxID = errors.New("root input should be the pool tx outpoint")
) )
const ( const (
@@ -63,24 +62,30 @@ const (
// - input and output amounts // - input and output amounts
func ValidateCongestionTree( func ValidateCongestionTree(
tree CongestionTree, tree CongestionTree,
poolTxHex string, poolTx string,
aspPublicKey *secp256k1.PublicKey, aspPublicKey *secp256k1.PublicKey,
roundLifetimeSeconds uint, roundLifetimeSeconds uint,
) error { ) error {
unspendableKeyBytes, _ := hex.DecodeString(UnspendablePoint) unspendableKeyBytes, _ := hex.DecodeString(UnspendablePoint)
unspendableKey, _ := secp256k1.ParsePubKey(unspendableKeyBytes) unspendableKey, _ := secp256k1.ParsePubKey(unspendableKeyBytes)
poolTransaction, err := transaction.NewTxFromHex(poolTxHex) poolTransaction, err := psetv2.NewPsetFromBase64(poolTx)
if err != nil { if err != nil {
return ErrInvalidPoolTransaction return ErrInvalidPoolTransaction
} }
poolTxAmount, err := elementsutil.ValueFromBytes(poolTransaction.Outputs[sharedOutputIndex].Value) if len(poolTransaction.Outputs) < sharedOutputIndex+1 {
return ErrInvalidPoolTransactionOutputs
}
poolTxAmount := poolTransaction.Outputs[sharedOutputIndex].Value
utx, err := poolTransaction.UnsignedTx()
if err != nil { if err != nil {
return ErrInvalidPoolTransaction return ErrInvalidPoolTransaction
} }
poolTxID := poolTransaction.TxHash().String() poolTxID := utx.TxHash().String()
nbNodes := tree.NumberOfNodes() nbNodes := tree.NumberOfNodes()
if nbNodes == 0 { if nbNodes == 0 {

View File

@@ -19,12 +19,10 @@ import (
"github.com/decred/dcrd/dcrec/secp256k1/v4" "github.com/decred/dcrd/dcrec/secp256k1/v4"
"github.com/urfave/cli/v2" "github.com/urfave/cli/v2"
"github.com/vulpemventures/go-elements/address" "github.com/vulpemventures/go-elements/address"
"github.com/vulpemventures/go-elements/elementsutil"
"github.com/vulpemventures/go-elements/network" "github.com/vulpemventures/go-elements/network"
"github.com/vulpemventures/go-elements/payment" "github.com/vulpemventures/go-elements/payment"
"github.com/vulpemventures/go-elements/psetv2" "github.com/vulpemventures/go-elements/psetv2"
"github.com/vulpemventures/go-elements/taproot" "github.com/vulpemventures/go-elements/taproot"
"github.com/vulpemventures/go-elements/transaction"
"golang.org/x/term" "golang.org/x/term"
) )
@@ -87,7 +85,7 @@ func privateKeyFromPassword() (*secp256k1.PrivateKey, error) {
encryptedPrivateKey, err := hex.DecodeString(encryptedPrivateKeyString) encryptedPrivateKey, err := hex.DecodeString(encryptedPrivateKeyString)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("invalid encrypted private key: %s", err)
} }
password, err := readPassword() password, err := readPassword()
@@ -384,9 +382,10 @@ func handleRoundStream(
if event.GetRoundFinalization() != nil { if event.GetRoundFinalization() != nil {
// stop pinging as soon as we receive some forfeit txs // stop pinging as soon as we receive some forfeit txs
pingStop() pingStop()
fmt.Println("round finalization started")
poolPartialTx := event.GetRoundFinalization().GetPoolPartialTx() poolPartialTx := event.GetRoundFinalization().GetPoolPartialTx()
poolTransaction, err := transaction.NewTxFromHex(poolPartialTx) poolTransaction, err := psetv2.NewPsetFromBase64(poolPartialTx)
if err != nil { if err != nil {
return "", err return "", err
} }
@@ -429,13 +428,8 @@ func handleRoundStream(
found := false found := false
for _, output := range poolTransaction.Outputs { for _, output := range poolTransaction.Outputs {
if bytes.Equal(output.Script, onchainScript) { if bytes.Equal(output.Script, onchainScript) {
outputValue, err := elementsutil.ValueFromBytes(output.Value) if output.Value != receiver.Amount {
if err != nil { return "", fmt.Errorf("invalid collaborative exit output amount: got %d, want %d", output.Value, receiver.Amount)
return "", err
}
if outputValue != receiver.Amount {
return "", fmt.Errorf("invalid collaborative exit output amount: got %d, want %d", outputValue, receiver.Amount)
} }
found = true found = true