Prevent getting cheated by broadcasting forfeit transactions (#123)

* broadcast forfeit transaction in case the user is trying the cheat the ASP

* fix connector input + --cheat flag in CLI

* WIP

* cleaning and fixes

* add TODO

* sweeper.go: mark round swept if vtxo are redeemed

* fixes after reviews

* revert "--cheat" flag in client

* revert redeem.go

* optimization

* update account.go according to ocean ListUtxos new spec

* WaitForSync implementation

* ocean-wallet/service.go: remove go rountine while writing to notification channel
This commit is contained in:
Louis Singer
2024-03-04 13:58:36 +01:00
committed by GitHub
parent 6d0d03e316
commit 066e8eeabb
21 changed files with 537 additions and 120 deletions

View File

@@ -5,6 +5,7 @@ import (
"context"
"encoding/hex"
"fmt"
"sync"
"time"
"github.com/ark-network/ark/common"
@@ -286,7 +287,7 @@ func (s *service) startFinalization() {
return
}
unsignedPoolTx, tree, err := s.builder.BuildPoolTx(s.pubkey, payments, s.minRelayFee)
unsignedPoolTx, tree, connectorAddress, err := s.builder.BuildPoolTx(s.pubkey, payments, s.minRelayFee)
if err != nil {
changes = round.Fail(fmt.Errorf("failed to create pool tx: %s", err))
log.WithError(err).Warn("failed to create pool tx")
@@ -295,7 +296,7 @@ func (s *service) startFinalization() {
log.Debugf("pool tx created for round %s", round.Id)
connectors, forfeitTxs, err := s.builder.BuildForfeitTxs(s.pubkey, unsignedPoolTx, payments)
connectors, forfeitTxs, err := s.builder.BuildForfeitTxs(s.pubkey, unsignedPoolTx, payments, s.minRelayFee)
if err != nil {
changes = round.Fail(fmt.Errorf("failed to create connectors and forfeit txs: %s", err))
log.WithError(err).Warn("failed to create connectors and forfeit txs")
@@ -304,7 +305,7 @@ func (s *service) startFinalization() {
log.Debugf("forfeit transactions created for round %s", round.Id)
events, err := round.StartFinalization(connectors, tree, unsignedPoolTx)
events, err := round.StartFinalization(connectorAddress, connectors, tree, unsignedPoolTx)
if err != nil {
changes = round.Fail(fmt.Errorf("failed to start finalization: %s", err))
log.WithError(err).Warn("failed to start finalization")
@@ -369,25 +370,162 @@ func (s *service) finalizeRound() {
func (s *service) listenToRedemptions() {
ctx := context.Background()
chVtxos := s.scanner.GetNotificationChannel(ctx)
mutx := &sync.Mutex{}
for vtxoKeys := range chVtxos {
if len(vtxoKeys) > 0 {
for {
// TODO: make sure that the vtxos haven't been already spent, otherwise
// broadcast the corresponding forfeit tx and connector to prevent
// getting cheated.
vtxos, err := s.repoManager.Vtxos().RedeemVtxos(ctx, vtxoKeys)
go func(vtxoKeys []domain.VtxoKey) {
vtxosRepo := s.repoManager.Vtxos()
roundRepo := s.repoManager.Rounds()
for _, v := range vtxoKeys {
vtxos, err := vtxosRepo.GetVtxos(ctx, []domain.VtxoKey{v})
if err != nil {
log.WithError(err).Warn("failed to redeem vtxos, retrying...")
time.Sleep(100 * time.Millisecond)
log.WithError(err).Warn("failed to retrieve vtxos, skipping...")
continue
}
if len(vtxos) > 0 {
log.Debugf("redeemed %d vtxos", len(vtxos))
vtxo := vtxos[0]
if vtxo.Redeemed {
continue
}
break
if _, err := s.repoManager.Vtxos().RedeemVtxos(ctx, []domain.VtxoKey{vtxo.VtxoKey}); err != nil {
log.WithError(err).Warn("failed to redeem vtxos, retrying...")
continue
}
log.Debugf("vtxo %s redeemed", vtxo.Txid)
if !vtxo.Spent {
continue
}
log.Debugf("fraud detected on vtxo %s", vtxo.Txid)
round, err := roundRepo.GetRoundWithTxid(ctx, vtxo.SpentBy)
if err != nil {
log.WithError(err).Warn("failed to retrieve round")
continue
}
mutx.Lock()
defer mutx.Unlock()
connectorTxid, connectorVout, err := s.getNextConnector(ctx, *round)
if err != nil {
log.WithError(err).Warn("failed to retrieve next connector")
continue
}
forfeitTx, err := findForfeitTx(round.ForfeitTxs, connectorTxid, connectorVout, vtxo.Txid)
if err != nil {
log.WithError(err).Warn("failed to retrieve forfeit tx")
continue
}
signedForfeitTx, err := s.wallet.SignConnectorInput(ctx, forfeitTx, []int{0}, false)
if err != nil {
log.WithError(err).Warn("failed to sign connector input in forfeit tx")
continue
}
signedForfeitTx, err = s.wallet.SignPsetWithKey(ctx, signedForfeitTx, []int{1})
if err != nil {
log.WithError(err).Warn("failed to sign vtxo input in forfeit tx")
continue
}
forfeitTxHex, err := finalizeAndExtractForfeit(signedForfeitTx)
if err != nil {
log.WithError(err).Warn("failed to finalize forfeit tx")
continue
}
forfeitTxid, err := s.wallet.BroadcastTransaction(ctx, forfeitTxHex)
if err != nil {
log.WithError(err).Warn("failed to broadcast forfeit tx")
continue
}
log.Debugf("broadcasted forfeit tx %s", forfeitTxid)
}
}(vtxoKeys)
}
}
func (s *service) getNextConnector(
ctx context.Context,
round domain.Round,
) (string, uint32, error) {
connectorTx, err := psetv2.NewPsetFromBase64(round.Connectors[0])
if err != nil {
return "", 0, err
}
prevout := connectorTx.Inputs[0].WitnessUtxo
if prevout == nil {
return "", 0, fmt.Errorf("connector prevout not found")
}
utxos, err := s.wallet.ListConnectorUtxos(ctx, round.ConnectorAddress)
if err != nil {
return "", 0, err
}
// if we do not find any utxos, we make sure to wait for the connector outpoint to be confirmed then we retry
if len(utxos) <= 0 {
if err := s.wallet.WaitForSync(ctx, round.Txid); err != nil {
return "", 0, err
}
utxos, err = s.wallet.ListConnectorUtxos(ctx, round.ConnectorAddress)
if err != nil {
return "", 0, err
}
}
// search for an already existing connector
for _, u := range utxos {
if u.GetValue() == 450 {
return u.GetTxid(), u.GetIndex(), nil
}
}
for _, u := range utxos {
if u.GetValue() > 450 {
for _, b64 := range round.Connectors {
pset, err := psetv2.NewPsetFromBase64(b64)
if err != nil {
return "", 0, err
}
for _, i := range pset.Inputs {
if chainhash.Hash(i.PreviousTxid).String() == u.GetTxid() && i.PreviousTxIndex == u.GetIndex() {
// sign & broadcast the connector tx
signedConnectorTx, err := s.wallet.SignConnectorInput(ctx, b64, []int{0}, true)
if err != nil {
return "", 0, err
}
connectorTxid, err := s.wallet.BroadcastTransaction(ctx, signedConnectorTx)
if err != nil {
return "", 0, err
}
log.Debugf("broadcasted connector tx %s", connectorTxid)
// wait for the connector tx to be in the mempool
if err := s.wallet.WaitForSync(ctx, connectorTxid); err != nil {
return "", 0, err
}
return connectorTxid, 0, nil
}
}
}
}
}
return "", 0, fmt.Errorf("no connector utxos found")
}
func (s *service) updateVtxoSet(round *domain.Round) {
@@ -401,7 +539,7 @@ func (s *service) updateVtxoSet(round *domain.Round) {
spentVtxos := getSpentVtxos(round.Payments)
if len(spentVtxos) > 0 {
for {
if err := repo.SpendVtxos(ctx, spentVtxos); err != nil {
if err := repo.SpendVtxos(ctx, spentVtxos, round.Txid); err != nil {
log.WithError(err).Warn("failed to add new vtxos, retrying soon")
time.Sleep(100 * time.Millisecond)
continue
@@ -535,12 +673,28 @@ func (s *service) stopWatchingVtxos(vtxos []domain.Vtxo) {
}
func (s *service) restoreWatchingVtxos() error {
vtxos, err := s.repoManager.Vtxos().GetSpendableVtxos(
context.Background(), "",
)
sweepableRounds, err := s.repoManager.Rounds().GetSweepableRounds(context.Background())
if err != nil {
return err
}
vtxos := make([]domain.Vtxo, 0)
for _, round := range sweepableRounds {
fromRound, err := s.repoManager.Vtxos().GetVtxosForRound(
context.Background(), round.Txid,
)
if err != nil {
log.WithError(err).Warnf("failed to retrieve vtxos for round %s", round.Txid)
continue
}
for _, v := range fromRound {
if !v.Swept && !v.Redeemed {
vtxos = append(vtxos, v)
}
}
}
if len(vtxos) <= 0 {
return nil
}
@@ -617,3 +771,45 @@ func getPaymentsFromOnboarding(
payment := domain.NewPaymentUnsafe(nil, receivers)
return []domain.Payment{*payment}
}
func finalizeAndExtractForfeit(b64 string) (string, error) {
p, err := psetv2.NewPsetFromBase64(b64)
if err != nil {
return "", err
}
// finalize connector input
if err := psetv2.FinalizeAll(p); err != nil {
return "", err
}
// extract the forfeit tx
extracted, err := psetv2.Extract(p)
if err != nil {
return "", err
}
return extracted.ToHex()
}
func findForfeitTx(
forfeits []string, connectorTxid string, connectorVout uint32, vtxoTxid string,
) (string, error) {
for _, forfeit := range forfeits {
forfeitTx, err := psetv2.NewPsetFromBase64(forfeit)
if err != nil {
return "", err
}
connector := forfeitTx.Inputs[0]
vtxoInput := forfeitTx.Inputs[1]
if chainhash.Hash(connector.PreviousTxid).String() == connectorTxid &&
connector.PreviousTxIndex == connectorVout &&
chainhash.Hash(vtxoInput.PreviousTxid).String() == vtxoTxid {
return forfeit, nil
}
}
return "", fmt.Errorf("forfeit tx not found")
}

View File

@@ -205,9 +205,10 @@ func (s *sweeper) createTask(
}
}
vtxosRepository := s.repoManager.Vtxos()
if len(sweepInputs) > 0 {
// build the sweep transaction with all the expired non-swept shared outputs
sweepTx, err := s.builder.BuildSweepTx(s.wallet, sweepInputs)
sweepTx, err := s.builder.BuildSweepTx(sweepInputs)
if err != nil {
log.WithError(err).Error("error while building sweep tx")
return
@@ -231,7 +232,6 @@ func (s *sweeper) createTask(
}
if len(txid) > 0 {
log.Debugln("sweep tx broadcasted:", txid)
vtxosRepository := s.repoManager.Vtxos()
// mark the vtxos as swept
if err := vtxosRepository.SweepVtxos(ctx, vtxoKeys); err != nil {
@@ -240,6 +240,8 @@ func (s *sweeper) createTask(
}
log.Debugf("%d vtxos swept", len(vtxoKeys))
}
}
roundVtxos, err := vtxosRepository.GetVtxosForRound(ctx, roundTxid)
if err != nil {
@@ -249,7 +251,7 @@ func (s *sweeper) createTask(
allSwept := true
for _, vtxo := range roundVtxos {
allSwept = allSwept && vtxo.Swept
allSwept = allSwept && (vtxo.Swept || vtxo.Redeemed)
if !allSwept {
break
}
@@ -264,6 +266,7 @@ func (s *sweeper) createTask(
return
}
log.Debugf("round %s fully swept", roundTxid)
round.Sweep()
if err := roundRepo.AddOrUpdateRound(ctx, *round); err != nil {
@@ -272,8 +275,6 @@ func (s *sweeper) createTask(
}
}
}
}
}
}
// onchainOutputs iterates over all the nodes' outputs in the congestion tree and checks their onchain state

View File

@@ -201,6 +201,7 @@ func (m *forfeitTxsMap) sign(txs []string) error {
}
if sig.Verify(preimage, pubkey) {
m.forfeitTxs[txid].tx = tx
m.forfeitTxs[txid].signed = true
} else {
return fmt.Errorf("invalid signature")

View File

@@ -21,6 +21,7 @@ type RoundFinalizationStarted struct {
Id string
CongestionTree tree.CongestionTree
Connectors []string
ConnectorAddress string
UnsignedForfeitTxs []string
PoolTx string
}

View File

@@ -131,6 +131,7 @@ type Vtxo struct {
VtxoKey
Receiver
PoolTx string
SpentBy string
Spent bool
Redeemed bool
Swept bool

View File

@@ -44,9 +44,10 @@ type Round struct {
ForfeitTxs []string
CongestionTree tree.CongestionTree
Connectors []string
ConnectorAddress string
DustAmount uint64
Version uint
Swept bool // true if all the vtxos are vtxo.Swept
Swept bool // true if all the vtxos are vtxo.Swept or vtxo.Redeemed
changes []RoundEvent
}
@@ -118,6 +119,7 @@ func (r *Round) On(event RoundEvent, replayed bool) {
r.Stage.Code = FinalizationStage
r.CongestionTree = e.CongestionTree
r.Connectors = append([]string{}, e.Connectors...)
r.ConnectorAddress = e.ConnectorAddress
r.UnsignedTx = e.PoolTx
case RoundFinalized:
r.Stage.Ended = true
@@ -178,7 +180,7 @@ func (r *Round) RegisterPayments(payments []Payment) ([]RoundEvent, error) {
return []RoundEvent{event}, nil
}
func (r *Round) StartFinalization(connectors []string, congestionTree tree.CongestionTree, poolTx string) ([]RoundEvent, error) {
func (r *Round) StartFinalization(connectorAddress string, connectors []string, congestionTree tree.CongestionTree, poolTx string) ([]RoundEvent, error) {
if len(connectors) <= 0 {
return nil, fmt.Errorf("missing list of connectors")
}
@@ -199,6 +201,7 @@ func (r *Round) StartFinalization(connectors []string, congestionTree tree.Conge
Id: r.Id,
CongestionTree: congestionTree,
Connectors: connectors,
ConnectorAddress: connectorAddress,
PoolTx: poolTx,
}
r.raise(event)

View File

@@ -19,7 +19,7 @@ type RoundRepository interface {
type VtxoRepository interface {
AddVtxos(ctx context.Context, vtxos []Vtxo) error
SpendVtxos(ctx context.Context, vtxos []VtxoKey) error
SpendVtxos(ctx context.Context, vtxos []VtxoKey, txid string) error
RedeemVtxos(ctx context.Context, vtxos []VtxoKey) ([]Vtxo, error)
GetVtxos(ctx context.Context, vtxos []VtxoKey) ([]Vtxo, error)
GetVtxosForRound(ctx context.Context, txid string) ([]Vtxo, error)

View File

@@ -296,7 +296,7 @@ func testStartFinalization(t *testing.T) {
require.NoError(t, err)
require.NotEmpty(t, events)
events, err = round.StartFinalization(connectors, congestionTree, poolTx)
events, err = round.StartFinalization("", connectors, congestionTree, poolTx)
require.NoError(t, err)
require.Len(t, events, 1)
require.True(t, round.IsStarted())
@@ -418,7 +418,8 @@ func testStartFinalization(t *testing.T) {
}
for _, f := range fixtures {
events, err := f.round.StartFinalization(f.connectors, f.tree, f.poolTx)
// TODO fix this
events, err := f.round.StartFinalization("", f.connectors, f.tree, f.poolTx)
require.EqualError(t, err, f.expectedErr)
require.Empty(t, events)
}
@@ -438,7 +439,7 @@ func testEndFinalization(t *testing.T) {
require.NoError(t, err)
require.NotEmpty(t, events)
events, err = round.StartFinalization(connectors, congestionTree, poolTx)
events, err = round.StartFinalization("", connectors, congestionTree, poolTx)
require.NoError(t, err)
require.NotEmpty(t, events)

View File

@@ -14,15 +14,8 @@ type SweepInput struct {
}
type TxBuilder interface {
BuildPoolTx(
aspPubkey *secp256k1.PublicKey, payments []domain.Payment, minRelayFee uint64,
) (poolTx string, congestionTree tree.CongestionTree, err error)
BuildForfeitTxs(
aspPubkey *secp256k1.PublicKey, poolTx string, payments []domain.Payment,
) (connectors []string, forfeitTxs []string, err error)
BuildSweepTx(
wallet WalletService,
inputs []SweepInput,
) (signedSweepTx string, err error)
BuildPoolTx(aspPubkey *secp256k1.PublicKey, payments []domain.Payment, minRelayFee uint64) (poolTx string, congestionTree tree.CongestionTree, connectorAddress string, err error)
BuildForfeitTxs(aspPubkey *secp256k1.PublicKey, poolTx string, payments []domain.Payment, minRelayFee uint64) (connectors []string, forfeitTxs []string, err error)
BuildSweepTx(inputs []SweepInput) (signedSweepTx string, err error)
GetVtxoScript(userPubkey, aspPubkey *secp256k1.PublicKey) ([]byte, error)
}

View File

@@ -10,6 +10,7 @@ type WalletService interface {
BlockchainScanner
Status(ctx context.Context) (WalletStatus, error)
GetPubkey(ctx context.Context) (*secp256k1.PublicKey, error)
DeriveConnectorAddress(ctx context.Context) (string, error)
DeriveAddresses(ctx context.Context, num int) ([]string, error)
SignPset(
ctx context.Context, pset string, extractRawTx bool,
@@ -18,7 +19,10 @@ type WalletService interface {
BroadcastTransaction(ctx context.Context, txHex string) (string, error)
SignPsetWithKey(ctx context.Context, pset string, inputIndexes []int) (string, error) // inputIndexes == nil means sign all inputs
IsTransactionConfirmed(ctx context.Context, txid string) (isConfirmed bool, blocktime int64, err error)
WaitForSync(ctx context.Context, txid string) error
EstimateFees(ctx context.Context, pset string) (uint64, error)
SignConnectorInput(ctx context.Context, pset string, inputIndexes []int, extract bool) (string, error)
ListConnectorUtxos(ctx context.Context, connectorAddress string) ([]TxInput, error)
Close()
}

View File

@@ -53,10 +53,10 @@ func (r *vtxoRepository) AddVtxos(
}
func (r *vtxoRepository) SpendVtxos(
ctx context.Context, vtxoKeys []domain.VtxoKey,
ctx context.Context, vtxoKeys []domain.VtxoKey, spentBy string,
) error {
for _, vtxoKey := range vtxoKeys {
if err := r.spendVtxo(ctx, vtxoKey); err != nil {
if err := r.spendVtxo(ctx, vtxoKey, spentBy); err != nil {
return err
}
}
@@ -96,7 +96,7 @@ func (r *vtxoRepository) GetVtxos(
func (r *vtxoRepository) GetVtxosForRound(
ctx context.Context, txid string,
) ([]domain.Vtxo, error) {
query := badgerhold.Where("Txid").Eq(txid)
query := badgerhold.Where("PoolTx").Eq(txid)
return r.findVtxos(ctx, query)
}
@@ -161,7 +161,7 @@ func (r *vtxoRepository) getVtxo(
return &vtxo, nil
}
func (r *vtxoRepository) spendVtxo(ctx context.Context, vtxoKey domain.VtxoKey) error {
func (r *vtxoRepository) spendVtxo(ctx context.Context, vtxoKey domain.VtxoKey, spendBy string) error {
vtxo, err := r.getVtxo(ctx, vtxoKey)
if err != nil {
if strings.Contains(err.Error(), "not found") {
@@ -174,6 +174,7 @@ func (r *vtxoRepository) spendVtxo(ctx context.Context, vtxoKey domain.VtxoKey)
}
vtxo.Spent = true
vtxo.SpentBy = spendBy
if ctx.Value("tx") != nil {
tx := ctx.Value("tx").(*badger.Txn)
err = r.store.TxUpdate(tx, vtxoKey.Hash(), *vtxo)

View File

@@ -362,7 +362,7 @@ func testVtxoRepository(t *testing.T, svc ports.RepoManager) {
require.NoError(t, err)
require.Exactly(t, userVtxos, spendableVtxos)
err = svc.Vtxos().SpendVtxos(ctx, vtxoKeys[:1])
err = svc.Vtxos().SpendVtxos(ctx, vtxoKeys[:1], txid)
require.NoError(t, err)
spentVtxos, err := svc.Vtxos().GetVtxos(ctx, vtxoKeys[:1])

View File

@@ -4,14 +4,52 @@ import (
"context"
pb "github.com/ark-network/ark/api-spec/protobuf/gen/ocean/v1"
"github.com/ark-network/ark/internal/core/ports"
"github.com/vulpemventures/go-elements/address"
)
func (s *service) DeriveAddresses(
ctx context.Context, numOfAddresses int,
) ([]string, error) {
return s.deriveAddresses(ctx, numOfAddresses, arkAccount)
}
func (s *service) DeriveConnectorAddress(ctx context.Context) (string, error) {
addresses, err := s.deriveAddresses(ctx, 1, connectorAccount)
if err != nil {
return "", err
}
return addresses[0], nil
}
func (s *service) ListConnectorUtxos(
ctx context.Context, connectorAddress string,
) ([]ports.TxInput, error) {
res, err := s.accountClient.ListUtxos(ctx, &pb.ListUtxosRequest{
AccountName: connectorAccount,
Addresses: []string{connectorAddress},
})
if err != nil {
return nil, err
}
utxos := make([]ports.TxInput, 0)
for _, utxo := range res.GetSpendableUtxos().GetUtxos() {
utxos = append(utxos, utxo)
}
for _, utxo := range res.GetLockedUtxos().GetUtxos() {
utxos = append(utxos, utxo)
}
return utxos, nil
}
func (s *service) deriveAddresses(
ctx context.Context, numOfAddresses int, account string,
) ([]string, error) {
res, err := s.accountClient.DeriveAddresses(ctx, &pb.DeriveAddressesRequest{
AccountName: accountLabel,
AccountName: account,
NumOfAddresses: uint64(numOfAddresses),
})
if err != nil {

View File

@@ -61,16 +61,35 @@ func NewService(addr string) (ports.WalletService, error) {
return nil, err
}
found := false
mainAccountFound, connectorAccountFound := false, false
for _, account := range info.GetAccounts() {
if account.GetLabel() == accountLabel {
found = true
if account.GetLabel() == arkAccount {
mainAccountFound = true
continue
}
if account.GetLabel() == connectorAccount {
connectorAccountFound = true
continue
}
if mainAccountFound && connectorAccountFound {
break
}
}
if !found {
if !mainAccountFound {
if _, err := accountClient.CreateAccountBIP44(ctx, &pb.CreateAccountBIP44Request{
Label: accountLabel,
Label: arkAccount,
Unconfidential: true,
}); err != nil {
return nil, err
}
}
if !connectorAccountFound {
if _, err := accountClient.CreateAccountBIP44(ctx, &pb.CreateAccountBIP44Request{
Label: connectorAccount,
Unconfidential: true,
}); err != nil {
return nil, err
@@ -114,9 +133,7 @@ func (s *service) listenToNotificaitons() {
}
vtxos := toVtxos(msg.GetUtxos())
if len(vtxos) > 0 {
go func() {
s.chVtxos <- vtxos
}()
}
}
}

View File

@@ -59,8 +59,9 @@ func (s *service) SignPset(
}
func (s *service) SelectUtxos(ctx context.Context, asset string, amount uint64) ([]ports.TxInput, uint64, error) {
// TODO: select coins from the connector account IF the round is swept
res, err := s.txClient.SelectUtxos(ctx, &pb.SelectUtxosRequest{
AccountName: accountLabel,
AccountName: arkAccount,
TargetAsset: asset,
TargetAmount: amount,
})
@@ -81,22 +82,22 @@ func (s *service) SelectUtxos(ctx context.Context, asset string, amount uint64)
return inputs, res.GetChange(), nil
}
func (s *service) GetTransaction(
func (s *service) getTransaction(
ctx context.Context, txid string,
) (string, int64, error) {
) (string, bool, int64, error) {
res, err := s.txClient.GetTransaction(ctx, &pb.GetTransactionRequest{
Txid: txid,
})
if err != nil {
return "", 0, err
return "", false, 0, err
}
if res.GetBlockDetails().GetTimestamp() > 0 {
return res.GetTxHex(), res.BlockDetails.GetTimestamp(), nil
return res.GetTxHex(), true, res.BlockDetails.GetTimestamp(), nil
}
// if not confirmed, we return now + 30 secs to estimate the next blocktime
return res.GetTxHex(), time.Now().Unix() + 30, nil
// if not confirmed, we return now + 1 min to estimate the next blocktime
return res.GetTxHex(), false, time.Now().Add(time.Minute).Unix(), nil
}
func (s *service) BroadcastTransaction(
@@ -120,15 +121,29 @@ func (s *service) BroadcastTransaction(
func (s *service) IsTransactionConfirmed(
ctx context.Context, txid string,
) (bool, int64, error) {
_, blocktime, err := s.GetTransaction(ctx, txid)
_, isConfirmed, blocktime, err := s.getTransaction(ctx, txid)
if err != nil {
if strings.Contains(strings.ToLower(err.Error()), "missing transaction") {
return false, 0, nil
return isConfirmed, 0, nil
}
return false, 0, err
}
return true, blocktime, nil
return isConfirmed, blocktime, nil
}
func (s *service) WaitForSync(ctx context.Context, txid string) error {
for {
time.Sleep(5 * time.Second)
_, _, _, err := s.getTransaction(ctx, txid)
if err != nil {
if strings.Contains(strings.ToLower(err.Error()), "missing transaction") {
continue
}
return err
}
break
}
return nil
}
func (s *service) SignPsetWithKey(ctx context.Context, b64 string, indexes []int) (string, error) {
@@ -206,6 +221,38 @@ func (s *service) SignPsetWithKey(ctx context.Context, b64 string, indexes []int
return signedPset.GetSignedTx(), nil
}
func (s *service) SignConnectorInput(ctx context.Context, pset string, inputIndexes []int, extract bool) (string, error) {
decodedTx, err := psetv2.NewPsetFromBase64(pset)
if err != nil {
return "", err
}
utxos := make([]*pb.Input, 0, len(decodedTx.Inputs))
for i := range inputIndexes {
if i >= len(decodedTx.Inputs) {
return "", fmt.Errorf("input index %d out of range", i)
}
input := decodedTx.Inputs[i]
utxos = append(utxos, &pb.Input{
Txid: chainhash.Hash(input.PreviousTxid).String(),
Index: input.PreviousTxIndex,
})
}
_, err = s.txClient.LockUtxos(ctx, &pb.LockUtxosRequest{
AccountName: connectorAccount,
Utxos: utxos,
})
if err != nil {
return "", err
}
return s.SignPset(ctx, pset, extract)
}
func (s *service) EstimateFees(
ctx context.Context, pset string,
) (uint64, error) {

View File

@@ -11,7 +11,10 @@ import (
"github.com/vulpemventures/go-bip32"
)
const accountLabel = "ark"
const (
arkAccount = "ark"
connectorAccount = "ark-connector"
)
var derivationPath = []uint32{0, 0}
@@ -63,7 +66,7 @@ func (s *service) findAccount(ctx context.Context, label string) (*pb.AccountInf
}
func (s *service) getPubkey(ctx context.Context) (*secp256k1.PublicKey, *bip32.Key, error) {
account, err := s.findAccount(ctx, accountLabel)
account, err := s.findAccount(ctx, arkAccount)
if err != nil {
return nil, nil, err
}

View File

@@ -11,6 +11,7 @@ import (
"github.com/decred/dcrd/dcrec/secp256k1/v4"
"github.com/vulpemventures/go-elements/address"
"github.com/vulpemventures/go-elements/network"
"github.com/vulpemventures/go-elements/payment"
"github.com/vulpemventures/go-elements/psetv2"
"github.com/vulpemventures/go-elements/taproot"
)
@@ -41,9 +42,9 @@ func (b *txBuilder) GetVtxoScript(userPubkey, aspPubkey *secp256k1.PublicKey) ([
return outputScript, nil
}
func (b *txBuilder) BuildSweepTx(wallet ports.WalletService, inputs []ports.SweepInput) (signedSweepTx string, err error) {
func (b *txBuilder) BuildSweepTx(inputs []ports.SweepInput) (signedSweepTx string, err error) {
sweepPset, err := sweepTransaction(
wallet,
b.wallet,
inputs,
b.net.AssetID,
)
@@ -57,7 +58,7 @@ func (b *txBuilder) BuildSweepTx(wallet ports.WalletService, inputs []ports.Swee
}
ctx := context.Background()
signedSweepPsetB64, err := wallet.SignPsetWithKey(ctx, sweepPsetBase64, nil)
signedSweepPsetB64, err := b.wallet.SignPsetWithKey(ctx, sweepPsetBase64, nil)
if err != nil {
return "", err
}
@@ -80,9 +81,14 @@ func (b *txBuilder) BuildSweepTx(wallet ports.WalletService, inputs []ports.Swee
}
func (b *txBuilder) BuildForfeitTxs(
aspPubkey *secp256k1.PublicKey, poolTx string, payments []domain.Payment,
aspPubkey *secp256k1.PublicKey, poolTx string, payments []domain.Payment, minRelayFee uint64,
) (connectors []string, forfeitTxs []string, err error) {
connectorTxs, err := b.createConnectors(poolTx, payments, aspPubkey)
connectorAddress, err := b.getConnectorAddress(poolTx)
if err != nil {
return nil, nil, err
}
connectorTxs, err := b.createConnectors(poolTx, payments, connectorAddress, minRelayFee)
if err != nil {
return nil, nil, err
}
@@ -101,7 +107,7 @@ func (b *txBuilder) BuildForfeitTxs(
func (b *txBuilder) BuildPoolTx(
aspPubkey *secp256k1.PublicKey, payments []domain.Payment, minRelayFee uint64,
) (poolTx string, congestionTree tree.CongestionTree, err error) {
) (poolTx string, congestionTree tree.CongestionTree, connectorAddress string, err error) {
// The creation of the tree and the pool tx are tightly coupled:
// - building the tree requires knowing the shared outpoint (txid:vout)
// - building the pool tx requires knowing the shared output script and amount
@@ -120,8 +126,13 @@ func (b *txBuilder) BuildPoolTx(
return
}
connectorAddress, err = b.wallet.DeriveConnectorAddress(context.Background())
if err != nil {
return
}
ptx, err := b.createPoolTx(
sharedOutputAmount, sharedOutputScript, payments, aspPubkey,
sharedOutputAmount, sharedOutputScript, payments, aspPubkey, connectorAddress, minRelayFee,
)
if err != nil {
return
@@ -190,15 +201,24 @@ func (b *txBuilder) getLeafScriptAndTree(
func (b *txBuilder) createPoolTx(
sharedOutputAmount uint64, sharedOutputScript []byte,
payments []domain.Payment, aspPubKey *secp256k1.PublicKey,
payments []domain.Payment, aspPubKey *secp256k1.PublicKey, connectorAddress string, minRelayFee uint64,
) (*psetv2.Pset, error) {
aspScript, err := p2wpkhScript(aspPubKey, b.net)
if err != nil {
return nil, err
}
connectorScript, err := address.ToOutputScript(connectorAddress)
if err != nil {
return nil, err
}
receivers := getOnchainReceivers(payments)
connectorsAmount := connectorAmount * countSpentVtxos(payments)
nbOfInputs := countSpentVtxos(payments)
connectorsAmount := (connectorAmount + minRelayFee) * nbOfInputs
if nbOfInputs > 1 {
connectorsAmount -= minRelayFee
}
targetAmount := sharedOutputAmount + connectorsAmount
outputs := []psetv2.OutputArgs{
@@ -210,7 +230,7 @@ func (b *txBuilder) createPoolTx(
{
Asset: b.net.AssetID,
Amount: connectorsAmount,
Script: aspScript,
Script: connectorScript,
},
}
@@ -354,11 +374,11 @@ func (b *txBuilder) createPoolTx(
}
func (b *txBuilder) createConnectors(
poolTx string, payments []domain.Payment, aspPubkey *secp256k1.PublicKey,
poolTx string, payments []domain.Payment, connectorAddress string, minRelayFee uint64,
) ([]*psetv2.Pset, error) {
txid, _ := getTxid(poolTx)
aspScript, err := p2wpkhScript(aspPubkey, b.net)
aspScript, err := address.ToOutputScript(connectorAddress)
if err != nil {
return nil, err
}
@@ -378,7 +398,7 @@ func (b *txBuilder) createConnectors(
if numberOfConnectors == 1 {
outputs := []psetv2.OutputArgs{connectorOutput}
connectorTx, err := craftConnectorTx(previousInput, outputs)
connectorTx, err := craftConnectorTx(previousInput, aspScript, outputs, minRelayFee)
if err != nil {
return nil, err
}
@@ -386,12 +406,16 @@ func (b *txBuilder) createConnectors(
return []*psetv2.Pset{connectorTx}, nil
}
totalConnectorAmount := connectorAmount * numberOfConnectors
totalConnectorAmount := (connectorAmount + minRelayFee) * numberOfConnectors
if numberOfConnectors > 1 {
totalConnectorAmount -= minRelayFee
}
connectors := make([]*psetv2.Pset, 0, numberOfConnectors-1)
for i := uint64(0); i < numberOfConnectors-1; i++ {
outputs := []psetv2.OutputArgs{connectorOutput}
totalConnectorAmount -= connectorAmount
totalConnectorAmount -= minRelayFee
if totalConnectorAmount > 0 {
outputs = append(outputs, psetv2.OutputArgs{
Asset: b.net.AssetID,
@@ -399,7 +423,7 @@ func (b *txBuilder) createConnectors(
Amount: totalConnectorAmount,
})
}
connectorTx, err := craftConnectorTx(previousInput, outputs)
connectorTx, err := craftConnectorTx(previousInput, aspScript, outputs, minRelayFee)
if err != nil {
return nil, err
}
@@ -473,3 +497,23 @@ func (b *txBuilder) createForfeitTxs(
}
return forfeitTxs, nil
}
func (b *txBuilder) getConnectorAddress(poolTx string) (string, error) {
pset, err := psetv2.NewPsetFromBase64(poolTx)
if err != nil {
return "", err
}
if len(pset.Outputs) < 1 {
return "", fmt.Errorf("connector output not found in pool tx")
}
connectorOutput := pset.Outputs[1]
pay, err := payment.FromScript(connectorOutput.Script, b.net, nil)
if err != nil {
return "", err
}
return pay.WitnessPubKeyHash()
}

View File

@@ -21,6 +21,7 @@ import (
const (
testingKey = "0218d5ca8b58797b7dbd65c075dd7ba7784b3f38ab71b1a5a8e3f94ba0257654a6"
connectorAddress = "tex1qekd5u0qj8jl07vy60830xy7n9qtmcx9u3s0cqc"
minRelayFee = uint64(30)
roundLifetime = int64(1209344)
unilateralExitDelay = int64(512)
@@ -37,6 +38,8 @@ func TestMain(m *testing.M) {
Return(uint64(100), nil)
wallet.On("SelectUtxos", mock.Anything, mock.Anything, mock.Anything).
Return(randomInput, uint64(0), nil)
wallet.On("DeriveConnectorAddress", mock.Anything).
Return(connectorAddress, nil)
pubkeyBytes, _ := hex.DecodeString(testingKey)
pubkey, _ = secp256k1.ParsePubKey(pubkeyBytes)
@@ -56,12 +59,13 @@ func TestBuildPoolTx(t *testing.T) {
if len(fixtures.Valid) > 0 {
t.Run("valid", func(t *testing.T) {
for _, f := range fixtures.Valid {
poolTx, congestionTree, err := builder.BuildPoolTx(
poolTx, congestionTree, connAddr, err := builder.BuildPoolTx(
pubkey, f.Payments, minRelayFee,
)
require.NoError(t, err)
require.NotEmpty(t, poolTx)
require.NotEmpty(t, congestionTree)
require.Equal(t, connectorAddress, connAddr)
require.Equal(t, f.ExpectedNumOfNodes, congestionTree.NumberOfNodes())
require.Len(t, congestionTree.Leaves(), f.ExpectedNumOfLeaves)
@@ -76,11 +80,12 @@ func TestBuildPoolTx(t *testing.T) {
if len(fixtures.Invalid) > 0 {
t.Run("invalid", func(t *testing.T) {
for _, f := range fixtures.Invalid {
poolTx, congestionTree, err := builder.BuildPoolTx(
poolTx, congestionTree, connAddr, err := builder.BuildPoolTx(
pubkey, f.Payments, minRelayFee,
)
require.EqualError(t, err, f.ExpectedErr)
require.Empty(t, poolTx)
require.Empty(t, connAddr)
require.Empty(t, congestionTree)
}
})
@@ -100,7 +105,7 @@ func TestBuildForfeitTxs(t *testing.T) {
t.Run("valid", func(t *testing.T) {
for _, f := range fixtures.Valid {
connectors, forfeitTxs, err := builder.BuildForfeitTxs(
pubkey, f.PoolTx, f.Payments,
pubkey, f.PoolTx, f.Payments, minRelayFee,
)
require.NoError(t, err)
require.Len(t, connectors, f.ExpectedNumOfConnectors)
@@ -114,7 +119,7 @@ func TestBuildForfeitTxs(t *testing.T) {
require.NotNil(t, tx)
require.Len(t, tx.Inputs, 1)
require.Len(t, tx.Outputs, 2)
require.Len(t, tx.Outputs, 3)
inputTxid := chainhash.Hash(tx.Inputs[0].PreviousTxid).String()
require.Equal(t, expectedInputTxid, inputTxid)
@@ -138,7 +143,7 @@ func TestBuildForfeitTxs(t *testing.T) {
t.Run("invalid", func(t *testing.T) {
for _, f := range fixtures.Invalid {
connectors, forfeitTxs, err := builder.BuildForfeitTxs(
pubkey, f.PoolTx, f.Payments,
pubkey, f.PoolTx, f.Payments, minRelayFee,
)
require.EqualError(t, err, f.ExpectedErr)
require.Empty(t, connectors)

View File

@@ -1,12 +1,13 @@
package txbuilder
import (
"github.com/vulpemventures/go-elements/elementsutil"
"github.com/vulpemventures/go-elements/psetv2"
"github.com/vulpemventures/go-elements/transaction"
)
func craftConnectorTx(
input psetv2.InputArgs, outputs []psetv2.OutputArgs,
input psetv2.InputArgs, inputScript []byte, outputs []psetv2.OutputArgs, feeAmount uint64,
) (*psetv2.Pset, error) {
ptx, _ := psetv2.New(nil, nil, nil)
updater, _ := psetv2.NewUpdater(ptx)
@@ -17,9 +18,34 @@ func craftConnectorTx(
return nil, err
}
// TODO: add prevout.
var asset []byte
amount := feeAmount
for _, output := range outputs {
amount += output.Amount
if asset == nil {
var err error
asset, err = elementsutil.AssetHashToBytes(output.Asset)
if err != nil {
return nil, err
}
}
}
if err := updater.AddOutputs(outputs); err != nil {
value, err := elementsutil.ValueToBytes(amount)
if err != nil {
return nil, err
}
if err := updater.AddInWitnessUtxo(0, transaction.NewTxOutput(asset, value, inputScript)); err != nil {
return nil, err
}
feeOutput := psetv2.OutputArgs{
Asset: outputs[0].Asset,
Amount: feeAmount,
}
if err := updater.AddOutputs(append(outputs, feeOutput)); err != nil {
return nil, err
}

View File

@@ -38,11 +38,6 @@ func craftForfeitTxs(
}
vtxoAmount, _ := elementsutil.ValueToBytes(vtxo.Amount)
vtxoPrevout := &transaction.TxOutput{
Asset: connectorPrevout.Asset,
Value: vtxoAmount,
Script: vtxoScript,
}
if err := updater.AddInputs([]psetv2.InputArgs{connectorInput, vtxoInput}); err != nil {
return nil, err
@@ -56,6 +51,8 @@ func craftForfeitTxs(
return nil, err
}
vtxoPrevout := transaction.NewTxOutput(connectorPrevout.Asset, vtxoAmount, vtxoScript)
if err = updater.AddInWitnessUtxo(1, vtxoPrevout); err != nil {
return nil, err
}

View File

@@ -40,6 +40,17 @@ func (m *mockedWallet) DeriveAddresses(ctx context.Context, num int) ([]string,
return res, args.Error(1)
}
// DeriveConnectorAddress implements ports.WalletService.
func (m *mockedWallet) DeriveConnectorAddress(ctx context.Context) (string, error) {
args := m.Called(ctx)
var res string
if a := args.Get(0); a != nil {
res = a.(string)
}
return res, args.Error(1)
}
// GetPubkey implements ports.WalletService.
func (m *mockedWallet) GetPubkey(ctx context.Context) (*secp256k1.PublicKey, error) {
args := m.Called(ctx)
@@ -123,6 +134,17 @@ func (m *mockedWallet) SignPsetWithKey(ctx context.Context, pset string, inputIn
return res, args.Error(1)
}
func (m *mockedWallet) SignConnectorInput(ctx context.Context, pset string, inputIndexes []int, extract bool) (string, error) {
args := m.Called(ctx, pset, inputIndexes, extract)
var res string
if a := args.Get(0); a != nil {
res = a.(string)
}
return res, args.Error(1)
}
func (m *mockedWallet) WatchScripts(
ctx context.Context, scripts []string,
) error {
@@ -147,6 +169,22 @@ func (m *mockedWallet) GetNotificationChannel(ctx context.Context) chan []domain
return res
}
func (m *mockedWallet) ListConnectorUtxos(ctx context.Context, addr string) ([]ports.TxInput, error) {
args := m.Called(ctx, addr)
var res []ports.TxInput
if a := args.Get(0); a != nil {
res = a.([]ports.TxInput)
}
return res, args.Error(1)
}
func (m *mockedWallet) WaitForSync(ctx context.Context, txid string) error {
args := m.Called(ctx, txid)
return args.Error(0)
}
type mockedInput struct {
mock.Mock
}