diff --git a/server/internal/core/application/service.go b/server/internal/core/application/service.go index 2090509..33d5a64 100644 --- a/server/internal/core/application/service.go +++ b/server/internal/core/application/service.go @@ -5,6 +5,7 @@ import ( "context" "encoding/hex" "fmt" + "sync" "time" "github.com/ark-network/ark/common" @@ -286,7 +287,7 @@ func (s *service) startFinalization() { return } - unsignedPoolTx, tree, err := s.builder.BuildPoolTx(s.pubkey, payments, s.minRelayFee) + unsignedPoolTx, tree, connectorAddress, err := s.builder.BuildPoolTx(s.pubkey, payments, s.minRelayFee) if err != nil { changes = round.Fail(fmt.Errorf("failed to create pool tx: %s", err)) log.WithError(err).Warn("failed to create pool tx") @@ -295,7 +296,7 @@ func (s *service) startFinalization() { log.Debugf("pool tx created for round %s", round.Id) - connectors, forfeitTxs, err := s.builder.BuildForfeitTxs(s.pubkey, unsignedPoolTx, payments) + connectors, forfeitTxs, err := s.builder.BuildForfeitTxs(s.pubkey, unsignedPoolTx, payments, s.minRelayFee) if err != nil { changes = round.Fail(fmt.Errorf("failed to create connectors and forfeit txs: %s", err)) log.WithError(err).Warn("failed to create connectors and forfeit txs") @@ -304,7 +305,7 @@ func (s *service) startFinalization() { log.Debugf("forfeit transactions created for round %s", round.Id) - events, err := round.StartFinalization(connectors, tree, unsignedPoolTx) + events, err := round.StartFinalization(connectorAddress, connectors, tree, unsignedPoolTx) if err != nil { changes = round.Fail(fmt.Errorf("failed to start finalization: %s", err)) log.WithError(err).Warn("failed to start finalization") @@ -369,25 +370,162 @@ func (s *service) finalizeRound() { func (s *service) listenToRedemptions() { ctx := context.Background() chVtxos := s.scanner.GetNotificationChannel(ctx) + + mutx := &sync.Mutex{} for vtxoKeys := range chVtxos { - if len(vtxoKeys) > 0 { - for { - // TODO: make sure that the vtxos haven't been already spent, otherwise - // broadcast the corresponding forfeit tx and connector to prevent - // getting cheated. - vtxos, err := s.repoManager.Vtxos().RedeemVtxos(ctx, vtxoKeys) + go func(vtxoKeys []domain.VtxoKey) { + vtxosRepo := s.repoManager.Vtxos() + roundRepo := s.repoManager.Rounds() + + for _, v := range vtxoKeys { + vtxos, err := vtxosRepo.GetVtxos(ctx, []domain.VtxoKey{v}) if err != nil { - log.WithError(err).Warn("failed to redeem vtxos, retrying...") - time.Sleep(100 * time.Millisecond) + log.WithError(err).Warn("failed to retrieve vtxos, skipping...") continue } - if len(vtxos) > 0 { - log.Debugf("redeemed %d vtxos", len(vtxos)) + + vtxo := vtxos[0] + + if vtxo.Redeemed { + continue + } + + if _, err := s.repoManager.Vtxos().RedeemVtxos(ctx, []domain.VtxoKey{vtxo.VtxoKey}); err != nil { + log.WithError(err).Warn("failed to redeem vtxos, retrying...") + continue + } + log.Debugf("vtxo %s redeemed", vtxo.Txid) + + if !vtxo.Spent { + continue + } + + log.Debugf("fraud detected on vtxo %s", vtxo.Txid) + + round, err := roundRepo.GetRoundWithTxid(ctx, vtxo.SpentBy) + if err != nil { + log.WithError(err).Warn("failed to retrieve round") + continue + } + + mutx.Lock() + defer mutx.Unlock() + + connectorTxid, connectorVout, err := s.getNextConnector(ctx, *round) + if err != nil { + log.WithError(err).Warn("failed to retrieve next connector") + continue + } + + forfeitTx, err := findForfeitTx(round.ForfeitTxs, connectorTxid, connectorVout, vtxo.Txid) + if err != nil { + log.WithError(err).Warn("failed to retrieve forfeit tx") + continue + } + + signedForfeitTx, err := s.wallet.SignConnectorInput(ctx, forfeitTx, []int{0}, false) + if err != nil { + log.WithError(err).Warn("failed to sign connector input in forfeit tx") + continue + } + + signedForfeitTx, err = s.wallet.SignPsetWithKey(ctx, signedForfeitTx, []int{1}) + if err != nil { + log.WithError(err).Warn("failed to sign vtxo input in forfeit tx") + continue + } + + forfeitTxHex, err := finalizeAndExtractForfeit(signedForfeitTx) + if err != nil { + log.WithError(err).Warn("failed to finalize forfeit tx") + continue + } + + forfeitTxid, err := s.wallet.BroadcastTransaction(ctx, forfeitTxHex) + if err != nil { + log.WithError(err).Warn("failed to broadcast forfeit tx") + continue + } + + log.Debugf("broadcasted forfeit tx %s", forfeitTxid) + } + }(vtxoKeys) + } +} + +func (s *service) getNextConnector( + ctx context.Context, + round domain.Round, +) (string, uint32, error) { + connectorTx, err := psetv2.NewPsetFromBase64(round.Connectors[0]) + if err != nil { + return "", 0, err + } + + prevout := connectorTx.Inputs[0].WitnessUtxo + if prevout == nil { + return "", 0, fmt.Errorf("connector prevout not found") + } + + utxos, err := s.wallet.ListConnectorUtxos(ctx, round.ConnectorAddress) + if err != nil { + return "", 0, err + } + + // if we do not find any utxos, we make sure to wait for the connector outpoint to be confirmed then we retry + if len(utxos) <= 0 { + if err := s.wallet.WaitForSync(ctx, round.Txid); err != nil { + return "", 0, err + } + + utxos, err = s.wallet.ListConnectorUtxos(ctx, round.ConnectorAddress) + if err != nil { + return "", 0, err + } + } + + // search for an already existing connector + for _, u := range utxos { + if u.GetValue() == 450 { + return u.GetTxid(), u.GetIndex(), nil + } + } + + for _, u := range utxos { + if u.GetValue() > 450 { + for _, b64 := range round.Connectors { + pset, err := psetv2.NewPsetFromBase64(b64) + if err != nil { + return "", 0, err + } + + for _, i := range pset.Inputs { + if chainhash.Hash(i.PreviousTxid).String() == u.GetTxid() && i.PreviousTxIndex == u.GetIndex() { + // sign & broadcast the connector tx + signedConnectorTx, err := s.wallet.SignConnectorInput(ctx, b64, []int{0}, true) + if err != nil { + return "", 0, err + } + + connectorTxid, err := s.wallet.BroadcastTransaction(ctx, signedConnectorTx) + if err != nil { + return "", 0, err + } + log.Debugf("broadcasted connector tx %s", connectorTxid) + + // wait for the connector tx to be in the mempool + if err := s.wallet.WaitForSync(ctx, connectorTxid); err != nil { + return "", 0, err + } + + return connectorTxid, 0, nil + } } - break } } } + + return "", 0, fmt.Errorf("no connector utxos found") } func (s *service) updateVtxoSet(round *domain.Round) { @@ -401,7 +539,7 @@ func (s *service) updateVtxoSet(round *domain.Round) { spentVtxos := getSpentVtxos(round.Payments) if len(spentVtxos) > 0 { for { - if err := repo.SpendVtxos(ctx, spentVtxos); err != nil { + if err := repo.SpendVtxos(ctx, spentVtxos, round.Txid); err != nil { log.WithError(err).Warn("failed to add new vtxos, retrying soon") time.Sleep(100 * time.Millisecond) continue @@ -535,12 +673,28 @@ func (s *service) stopWatchingVtxos(vtxos []domain.Vtxo) { } func (s *service) restoreWatchingVtxos() error { - vtxos, err := s.repoManager.Vtxos().GetSpendableVtxos( - context.Background(), "", - ) + sweepableRounds, err := s.repoManager.Rounds().GetSweepableRounds(context.Background()) if err != nil { return err } + + vtxos := make([]domain.Vtxo, 0) + + for _, round := range sweepableRounds { + fromRound, err := s.repoManager.Vtxos().GetVtxosForRound( + context.Background(), round.Txid, + ) + if err != nil { + log.WithError(err).Warnf("failed to retrieve vtxos for round %s", round.Txid) + continue + } + for _, v := range fromRound { + if !v.Swept && !v.Redeemed { + vtxos = append(vtxos, v) + } + } + } + if len(vtxos) <= 0 { return nil } @@ -617,3 +771,45 @@ func getPaymentsFromOnboarding( payment := domain.NewPaymentUnsafe(nil, receivers) return []domain.Payment{*payment} } + +func finalizeAndExtractForfeit(b64 string) (string, error) { + p, err := psetv2.NewPsetFromBase64(b64) + if err != nil { + return "", err + } + + // finalize connector input + if err := psetv2.FinalizeAll(p); err != nil { + return "", err + } + + // extract the forfeit tx + extracted, err := psetv2.Extract(p) + if err != nil { + return "", err + } + + return extracted.ToHex() +} + +func findForfeitTx( + forfeits []string, connectorTxid string, connectorVout uint32, vtxoTxid string, +) (string, error) { + for _, forfeit := range forfeits { + forfeitTx, err := psetv2.NewPsetFromBase64(forfeit) + if err != nil { + return "", err + } + + connector := forfeitTx.Inputs[0] + vtxoInput := forfeitTx.Inputs[1] + + if chainhash.Hash(connector.PreviousTxid).String() == connectorTxid && + connector.PreviousTxIndex == connectorVout && + chainhash.Hash(vtxoInput.PreviousTxid).String() == vtxoTxid { + return forfeit, nil + } + } + + return "", fmt.Errorf("forfeit tx not found") +} diff --git a/server/internal/core/application/sweeper.go b/server/internal/core/application/sweeper.go index 7165743..efa9f92 100644 --- a/server/internal/core/application/sweeper.go +++ b/server/internal/core/application/sweeper.go @@ -205,9 +205,10 @@ func (s *sweeper) createTask( } } + vtxosRepository := s.repoManager.Vtxos() if len(sweepInputs) > 0 { // build the sweep transaction with all the expired non-swept shared outputs - sweepTx, err := s.builder.BuildSweepTx(s.wallet, sweepInputs) + sweepTx, err := s.builder.BuildSweepTx(sweepInputs) if err != nil { log.WithError(err).Error("error while building sweep tx") return @@ -231,7 +232,6 @@ func (s *sweeper) createTask( } if len(txid) > 0 { log.Debugln("sweep tx broadcasted:", txid) - vtxosRepository := s.repoManager.Vtxos() // mark the vtxos as swept if err := vtxosRepository.SweepVtxos(ctx, vtxoKeys); err != nil { @@ -240,37 +240,38 @@ func (s *sweeper) createTask( } log.Debugf("%d vtxos swept", len(vtxoKeys)) + } + } - roundVtxos, err := vtxosRepository.GetVtxosForRound(ctx, roundTxid) - if err != nil { - log.WithError(err).Error("error while getting vtxos for round") - return - } + roundVtxos, err := vtxosRepository.GetVtxosForRound(ctx, roundTxid) + if err != nil { + log.WithError(err).Error("error while getting vtxos for round") + return + } - allSwept := true - for _, vtxo := range roundVtxos { - allSwept = allSwept && vtxo.Swept - if !allSwept { - break - } - } + allSwept := true + for _, vtxo := range roundVtxos { + allSwept = allSwept && (vtxo.Swept || vtxo.Redeemed) + if !allSwept { + break + } + } - if allSwept { - // update the round - roundRepo := s.repoManager.Rounds() - round, err := roundRepo.GetRoundWithTxid(ctx, roundTxid) - if err != nil { - log.WithError(err).Error("error while getting round") - return - } + if allSwept { + // update the round + roundRepo := s.repoManager.Rounds() + round, err := roundRepo.GetRoundWithTxid(ctx, roundTxid) + if err != nil { + log.WithError(err).Error("error while getting round") + return + } - round.Sweep() + log.Debugf("round %s fully swept", roundTxid) + round.Sweep() - if err := roundRepo.AddOrUpdateRound(ctx, *round); err != nil { - log.WithError(err).Error("error while marking round as swept") - return - } - } + if err := roundRepo.AddOrUpdateRound(ctx, *round); err != nil { + log.WithError(err).Error("error while marking round as swept") + return } } } diff --git a/server/internal/core/application/utils.go b/server/internal/core/application/utils.go index b421a39..c1ea031 100644 --- a/server/internal/core/application/utils.go +++ b/server/internal/core/application/utils.go @@ -201,6 +201,7 @@ func (m *forfeitTxsMap) sign(txs []string) error { } if sig.Verify(preimage, pubkey) { + m.forfeitTxs[txid].tx = tx m.forfeitTxs[txid].signed = true } else { return fmt.Errorf("invalid signature") diff --git a/server/internal/core/domain/events.go b/server/internal/core/domain/events.go index 978fadb..f8da116 100644 --- a/server/internal/core/domain/events.go +++ b/server/internal/core/domain/events.go @@ -21,6 +21,7 @@ type RoundFinalizationStarted struct { Id string CongestionTree tree.CongestionTree Connectors []string + ConnectorAddress string UnsignedForfeitTxs []string PoolTx string } diff --git a/server/internal/core/domain/payment.go b/server/internal/core/domain/payment.go index 71b4cd8..2fd03d5 100644 --- a/server/internal/core/domain/payment.go +++ b/server/internal/core/domain/payment.go @@ -131,6 +131,7 @@ type Vtxo struct { VtxoKey Receiver PoolTx string + SpentBy string Spent bool Redeemed bool Swept bool diff --git a/server/internal/core/domain/round.go b/server/internal/core/domain/round.go index 87ae6a1..128354b 100644 --- a/server/internal/core/domain/round.go +++ b/server/internal/core/domain/round.go @@ -44,9 +44,10 @@ type Round struct { ForfeitTxs []string CongestionTree tree.CongestionTree Connectors []string + ConnectorAddress string DustAmount uint64 Version uint - Swept bool // true if all the vtxos are vtxo.Swept + Swept bool // true if all the vtxos are vtxo.Swept or vtxo.Redeemed changes []RoundEvent } @@ -118,6 +119,7 @@ func (r *Round) On(event RoundEvent, replayed bool) { r.Stage.Code = FinalizationStage r.CongestionTree = e.CongestionTree r.Connectors = append([]string{}, e.Connectors...) + r.ConnectorAddress = e.ConnectorAddress r.UnsignedTx = e.PoolTx case RoundFinalized: r.Stage.Ended = true @@ -178,7 +180,7 @@ func (r *Round) RegisterPayments(payments []Payment) ([]RoundEvent, error) { return []RoundEvent{event}, nil } -func (r *Round) StartFinalization(connectors []string, congestionTree tree.CongestionTree, poolTx string) ([]RoundEvent, error) { +func (r *Round) StartFinalization(connectorAddress string, connectors []string, congestionTree tree.CongestionTree, poolTx string) ([]RoundEvent, error) { if len(connectors) <= 0 { return nil, fmt.Errorf("missing list of connectors") } @@ -196,10 +198,11 @@ func (r *Round) StartFinalization(connectors []string, congestionTree tree.Conge } event := RoundFinalizationStarted{ - Id: r.Id, - CongestionTree: congestionTree, - Connectors: connectors, - PoolTx: poolTx, + Id: r.Id, + CongestionTree: congestionTree, + Connectors: connectors, + ConnectorAddress: connectorAddress, + PoolTx: poolTx, } r.raise(event) diff --git a/server/internal/core/domain/round_repo.go b/server/internal/core/domain/round_repo.go index cc8ea65..7e89c0f 100644 --- a/server/internal/core/domain/round_repo.go +++ b/server/internal/core/domain/round_repo.go @@ -19,7 +19,7 @@ type RoundRepository interface { type VtxoRepository interface { AddVtxos(ctx context.Context, vtxos []Vtxo) error - SpendVtxos(ctx context.Context, vtxos []VtxoKey) error + SpendVtxos(ctx context.Context, vtxos []VtxoKey, txid string) error RedeemVtxos(ctx context.Context, vtxos []VtxoKey) ([]Vtxo, error) GetVtxos(ctx context.Context, vtxos []VtxoKey) ([]Vtxo, error) GetVtxosForRound(ctx context.Context, txid string) ([]Vtxo, error) diff --git a/server/internal/core/domain/round_test.go b/server/internal/core/domain/round_test.go index e44defd..55de389 100644 --- a/server/internal/core/domain/round_test.go +++ b/server/internal/core/domain/round_test.go @@ -296,7 +296,7 @@ func testStartFinalization(t *testing.T) { require.NoError(t, err) require.NotEmpty(t, events) - events, err = round.StartFinalization(connectors, congestionTree, poolTx) + events, err = round.StartFinalization("", connectors, congestionTree, poolTx) require.NoError(t, err) require.Len(t, events, 1) require.True(t, round.IsStarted()) @@ -418,7 +418,8 @@ func testStartFinalization(t *testing.T) { } for _, f := range fixtures { - events, err := f.round.StartFinalization(f.connectors, f.tree, f.poolTx) + // TODO fix this + events, err := f.round.StartFinalization("", f.connectors, f.tree, f.poolTx) require.EqualError(t, err, f.expectedErr) require.Empty(t, events) } @@ -438,7 +439,7 @@ func testEndFinalization(t *testing.T) { require.NoError(t, err) require.NotEmpty(t, events) - events, err = round.StartFinalization(connectors, congestionTree, poolTx) + events, err = round.StartFinalization("", connectors, congestionTree, poolTx) require.NoError(t, err) require.NotEmpty(t, events) diff --git a/server/internal/core/ports/tx_builder.go b/server/internal/core/ports/tx_builder.go index f563659..bb2eb68 100644 --- a/server/internal/core/ports/tx_builder.go +++ b/server/internal/core/ports/tx_builder.go @@ -14,15 +14,8 @@ type SweepInput struct { } type TxBuilder interface { - BuildPoolTx( - aspPubkey *secp256k1.PublicKey, payments []domain.Payment, minRelayFee uint64, - ) (poolTx string, congestionTree tree.CongestionTree, err error) - BuildForfeitTxs( - aspPubkey *secp256k1.PublicKey, poolTx string, payments []domain.Payment, - ) (connectors []string, forfeitTxs []string, err error) - BuildSweepTx( - wallet WalletService, - inputs []SweepInput, - ) (signedSweepTx string, err error) + BuildPoolTx(aspPubkey *secp256k1.PublicKey, payments []domain.Payment, minRelayFee uint64) (poolTx string, congestionTree tree.CongestionTree, connectorAddress string, err error) + BuildForfeitTxs(aspPubkey *secp256k1.PublicKey, poolTx string, payments []domain.Payment, minRelayFee uint64) (connectors []string, forfeitTxs []string, err error) + BuildSweepTx(inputs []SweepInput) (signedSweepTx string, err error) GetVtxoScript(userPubkey, aspPubkey *secp256k1.PublicKey) ([]byte, error) } diff --git a/server/internal/core/ports/wallet.go b/server/internal/core/ports/wallet.go index 3963055..5a39b3c 100644 --- a/server/internal/core/ports/wallet.go +++ b/server/internal/core/ports/wallet.go @@ -10,6 +10,7 @@ type WalletService interface { BlockchainScanner Status(ctx context.Context) (WalletStatus, error) GetPubkey(ctx context.Context) (*secp256k1.PublicKey, error) + DeriveConnectorAddress(ctx context.Context) (string, error) DeriveAddresses(ctx context.Context, num int) ([]string, error) SignPset( ctx context.Context, pset string, extractRawTx bool, @@ -18,7 +19,10 @@ type WalletService interface { BroadcastTransaction(ctx context.Context, txHex string) (string, error) SignPsetWithKey(ctx context.Context, pset string, inputIndexes []int) (string, error) // inputIndexes == nil means sign all inputs IsTransactionConfirmed(ctx context.Context, txid string) (isConfirmed bool, blocktime int64, err error) + WaitForSync(ctx context.Context, txid string) error EstimateFees(ctx context.Context, pset string) (uint64, error) + SignConnectorInput(ctx context.Context, pset string, inputIndexes []int, extract bool) (string, error) + ListConnectorUtxos(ctx context.Context, connectorAddress string) ([]TxInput, error) Close() } diff --git a/server/internal/infrastructure/db/badger/vtxo_repo.go b/server/internal/infrastructure/db/badger/vtxo_repo.go index 2201889..fcc679f 100644 --- a/server/internal/infrastructure/db/badger/vtxo_repo.go +++ b/server/internal/infrastructure/db/badger/vtxo_repo.go @@ -53,10 +53,10 @@ func (r *vtxoRepository) AddVtxos( } func (r *vtxoRepository) SpendVtxos( - ctx context.Context, vtxoKeys []domain.VtxoKey, + ctx context.Context, vtxoKeys []domain.VtxoKey, spentBy string, ) error { for _, vtxoKey := range vtxoKeys { - if err := r.spendVtxo(ctx, vtxoKey); err != nil { + if err := r.spendVtxo(ctx, vtxoKey, spentBy); err != nil { return err } } @@ -96,7 +96,7 @@ func (r *vtxoRepository) GetVtxos( func (r *vtxoRepository) GetVtxosForRound( ctx context.Context, txid string, ) ([]domain.Vtxo, error) { - query := badgerhold.Where("Txid").Eq(txid) + query := badgerhold.Where("PoolTx").Eq(txid) return r.findVtxos(ctx, query) } @@ -161,7 +161,7 @@ func (r *vtxoRepository) getVtxo( return &vtxo, nil } -func (r *vtxoRepository) spendVtxo(ctx context.Context, vtxoKey domain.VtxoKey) error { +func (r *vtxoRepository) spendVtxo(ctx context.Context, vtxoKey domain.VtxoKey, spendBy string) error { vtxo, err := r.getVtxo(ctx, vtxoKey) if err != nil { if strings.Contains(err.Error(), "not found") { @@ -174,6 +174,7 @@ func (r *vtxoRepository) spendVtxo(ctx context.Context, vtxoKey domain.VtxoKey) } vtxo.Spent = true + vtxo.SpentBy = spendBy if ctx.Value("tx") != nil { tx := ctx.Value("tx").(*badger.Txn) err = r.store.TxUpdate(tx, vtxoKey.Hash(), *vtxo) diff --git a/server/internal/infrastructure/db/service_test.go b/server/internal/infrastructure/db/service_test.go index 6b23c99..d5e918e 100644 --- a/server/internal/infrastructure/db/service_test.go +++ b/server/internal/infrastructure/db/service_test.go @@ -362,7 +362,7 @@ func testVtxoRepository(t *testing.T, svc ports.RepoManager) { require.NoError(t, err) require.Exactly(t, userVtxos, spendableVtxos) - err = svc.Vtxos().SpendVtxos(ctx, vtxoKeys[:1]) + err = svc.Vtxos().SpendVtxos(ctx, vtxoKeys[:1], txid) require.NoError(t, err) spentVtxos, err := svc.Vtxos().GetVtxos(ctx, vtxoKeys[:1]) diff --git a/server/internal/infrastructure/ocean-wallet/account.go b/server/internal/infrastructure/ocean-wallet/account.go index 72b8828..15676d7 100644 --- a/server/internal/infrastructure/ocean-wallet/account.go +++ b/server/internal/infrastructure/ocean-wallet/account.go @@ -4,14 +4,52 @@ import ( "context" pb "github.com/ark-network/ark/api-spec/protobuf/gen/ocean/v1" + "github.com/ark-network/ark/internal/core/ports" "github.com/vulpemventures/go-elements/address" ) func (s *service) DeriveAddresses( ctx context.Context, numOfAddresses int, +) ([]string, error) { + return s.deriveAddresses(ctx, numOfAddresses, arkAccount) +} + +func (s *service) DeriveConnectorAddress(ctx context.Context) (string, error) { + addresses, err := s.deriveAddresses(ctx, 1, connectorAccount) + if err != nil { + return "", err + } + + return addresses[0], nil +} + +func (s *service) ListConnectorUtxos( + ctx context.Context, connectorAddress string, +) ([]ports.TxInput, error) { + res, err := s.accountClient.ListUtxos(ctx, &pb.ListUtxosRequest{ + AccountName: connectorAccount, + Addresses: []string{connectorAddress}, + }) + if err != nil { + return nil, err + } + + utxos := make([]ports.TxInput, 0) + for _, utxo := range res.GetSpendableUtxos().GetUtxos() { + utxos = append(utxos, utxo) + } + for _, utxo := range res.GetLockedUtxos().GetUtxos() { + utxos = append(utxos, utxo) + } + + return utxos, nil +} + +func (s *service) deriveAddresses( + ctx context.Context, numOfAddresses int, account string, ) ([]string, error) { res, err := s.accountClient.DeriveAddresses(ctx, &pb.DeriveAddressesRequest{ - AccountName: accountLabel, + AccountName: account, NumOfAddresses: uint64(numOfAddresses), }) if err != nil { diff --git a/server/internal/infrastructure/ocean-wallet/service.go b/server/internal/infrastructure/ocean-wallet/service.go index 6ddb208..6ed739d 100644 --- a/server/internal/infrastructure/ocean-wallet/service.go +++ b/server/internal/infrastructure/ocean-wallet/service.go @@ -61,16 +61,35 @@ func NewService(addr string) (ports.WalletService, error) { return nil, err } - found := false + mainAccountFound, connectorAccountFound := false, false + for _, account := range info.GetAccounts() { - if account.GetLabel() == accountLabel { - found = true + if account.GetLabel() == arkAccount { + mainAccountFound = true + continue + } + + if account.GetLabel() == connectorAccount { + connectorAccountFound = true + continue + } + + if mainAccountFound && connectorAccountFound { break } } - if !found { + if !mainAccountFound { if _, err := accountClient.CreateAccountBIP44(ctx, &pb.CreateAccountBIP44Request{ - Label: accountLabel, + Label: arkAccount, + Unconfidential: true, + }); err != nil { + return nil, err + } + } + + if !connectorAccountFound { + if _, err := accountClient.CreateAccountBIP44(ctx, &pb.CreateAccountBIP44Request{ + Label: connectorAccount, Unconfidential: true, }); err != nil { return nil, err @@ -114,9 +133,7 @@ func (s *service) listenToNotificaitons() { } vtxos := toVtxos(msg.GetUtxos()) if len(vtxos) > 0 { - go func() { - s.chVtxos <- vtxos - }() + s.chVtxos <- vtxos } } } diff --git a/server/internal/infrastructure/ocean-wallet/transaction.go b/server/internal/infrastructure/ocean-wallet/transaction.go index bfafb51..6499dd1 100644 --- a/server/internal/infrastructure/ocean-wallet/transaction.go +++ b/server/internal/infrastructure/ocean-wallet/transaction.go @@ -59,8 +59,9 @@ func (s *service) SignPset( } func (s *service) SelectUtxos(ctx context.Context, asset string, amount uint64) ([]ports.TxInput, uint64, error) { + // TODO: select coins from the connector account IF the round is swept res, err := s.txClient.SelectUtxos(ctx, &pb.SelectUtxosRequest{ - AccountName: accountLabel, + AccountName: arkAccount, TargetAsset: asset, TargetAmount: amount, }) @@ -81,22 +82,22 @@ func (s *service) SelectUtxos(ctx context.Context, asset string, amount uint64) return inputs, res.GetChange(), nil } -func (s *service) GetTransaction( +func (s *service) getTransaction( ctx context.Context, txid string, -) (string, int64, error) { +) (string, bool, int64, error) { res, err := s.txClient.GetTransaction(ctx, &pb.GetTransactionRequest{ Txid: txid, }) if err != nil { - return "", 0, err + return "", false, 0, err } if res.GetBlockDetails().GetTimestamp() > 0 { - return res.GetTxHex(), res.BlockDetails.GetTimestamp(), nil + return res.GetTxHex(), true, res.BlockDetails.GetTimestamp(), nil } - // if not confirmed, we return now + 30 secs to estimate the next blocktime - return res.GetTxHex(), time.Now().Unix() + 30, nil + // if not confirmed, we return now + 1 min to estimate the next blocktime + return res.GetTxHex(), false, time.Now().Add(time.Minute).Unix(), nil } func (s *service) BroadcastTransaction( @@ -120,15 +121,29 @@ func (s *service) BroadcastTransaction( func (s *service) IsTransactionConfirmed( ctx context.Context, txid string, ) (bool, int64, error) { - _, blocktime, err := s.GetTransaction(ctx, txid) + _, isConfirmed, blocktime, err := s.getTransaction(ctx, txid) if err != nil { if strings.Contains(strings.ToLower(err.Error()), "missing transaction") { - return false, 0, nil + return isConfirmed, 0, nil } return false, 0, err } - return true, blocktime, nil + return isConfirmed, blocktime, nil +} +func (s *service) WaitForSync(ctx context.Context, txid string) error { + for { + time.Sleep(5 * time.Second) + _, _, _, err := s.getTransaction(ctx, txid) + if err != nil { + if strings.Contains(strings.ToLower(err.Error()), "missing transaction") { + continue + } + return err + } + break + } + return nil } func (s *service) SignPsetWithKey(ctx context.Context, b64 string, indexes []int) (string, error) { @@ -206,6 +221,38 @@ func (s *service) SignPsetWithKey(ctx context.Context, b64 string, indexes []int return signedPset.GetSignedTx(), nil } +func (s *service) SignConnectorInput(ctx context.Context, pset string, inputIndexes []int, extract bool) (string, error) { + decodedTx, err := psetv2.NewPsetFromBase64(pset) + if err != nil { + return "", err + } + + utxos := make([]*pb.Input, 0, len(decodedTx.Inputs)) + + for i := range inputIndexes { + if i >= len(decodedTx.Inputs) { + return "", fmt.Errorf("input index %d out of range", i) + } + + input := decodedTx.Inputs[i] + + utxos = append(utxos, &pb.Input{ + Txid: chainhash.Hash(input.PreviousTxid).String(), + Index: input.PreviousTxIndex, + }) + } + + _, err = s.txClient.LockUtxos(ctx, &pb.LockUtxosRequest{ + AccountName: connectorAccount, + Utxos: utxos, + }) + if err != nil { + return "", err + } + + return s.SignPset(ctx, pset, extract) +} + func (s *service) EstimateFees( ctx context.Context, pset string, ) (uint64, error) { diff --git a/server/internal/infrastructure/ocean-wallet/wallet.go b/server/internal/infrastructure/ocean-wallet/wallet.go index d9f17c6..ba0e2c9 100644 --- a/server/internal/infrastructure/ocean-wallet/wallet.go +++ b/server/internal/infrastructure/ocean-wallet/wallet.go @@ -11,7 +11,10 @@ import ( "github.com/vulpemventures/go-bip32" ) -const accountLabel = "ark" +const ( + arkAccount = "ark" + connectorAccount = "ark-connector" +) var derivationPath = []uint32{0, 0} @@ -63,7 +66,7 @@ func (s *service) findAccount(ctx context.Context, label string) (*pb.AccountInf } func (s *service) getPubkey(ctx context.Context) (*secp256k1.PublicKey, *bip32.Key, error) { - account, err := s.findAccount(ctx, accountLabel) + account, err := s.findAccount(ctx, arkAccount) if err != nil { return nil, nil, err } diff --git a/server/internal/infrastructure/tx-builder/covenant/builder.go b/server/internal/infrastructure/tx-builder/covenant/builder.go index 6534021..1de9fe8 100644 --- a/server/internal/infrastructure/tx-builder/covenant/builder.go +++ b/server/internal/infrastructure/tx-builder/covenant/builder.go @@ -11,6 +11,7 @@ import ( "github.com/decred/dcrd/dcrec/secp256k1/v4" "github.com/vulpemventures/go-elements/address" "github.com/vulpemventures/go-elements/network" + "github.com/vulpemventures/go-elements/payment" "github.com/vulpemventures/go-elements/psetv2" "github.com/vulpemventures/go-elements/taproot" ) @@ -41,9 +42,9 @@ func (b *txBuilder) GetVtxoScript(userPubkey, aspPubkey *secp256k1.PublicKey) ([ return outputScript, nil } -func (b *txBuilder) BuildSweepTx(wallet ports.WalletService, inputs []ports.SweepInput) (signedSweepTx string, err error) { +func (b *txBuilder) BuildSweepTx(inputs []ports.SweepInput) (signedSweepTx string, err error) { sweepPset, err := sweepTransaction( - wallet, + b.wallet, inputs, b.net.AssetID, ) @@ -57,7 +58,7 @@ func (b *txBuilder) BuildSweepTx(wallet ports.WalletService, inputs []ports.Swee } ctx := context.Background() - signedSweepPsetB64, err := wallet.SignPsetWithKey(ctx, sweepPsetBase64, nil) + signedSweepPsetB64, err := b.wallet.SignPsetWithKey(ctx, sweepPsetBase64, nil) if err != nil { return "", err } @@ -80,9 +81,14 @@ func (b *txBuilder) BuildSweepTx(wallet ports.WalletService, inputs []ports.Swee } func (b *txBuilder) BuildForfeitTxs( - aspPubkey *secp256k1.PublicKey, poolTx string, payments []domain.Payment, + aspPubkey *secp256k1.PublicKey, poolTx string, payments []domain.Payment, minRelayFee uint64, ) (connectors []string, forfeitTxs []string, err error) { - connectorTxs, err := b.createConnectors(poolTx, payments, aspPubkey) + connectorAddress, err := b.getConnectorAddress(poolTx) + if err != nil { + return nil, nil, err + } + + connectorTxs, err := b.createConnectors(poolTx, payments, connectorAddress, minRelayFee) if err != nil { return nil, nil, err } @@ -101,7 +107,7 @@ func (b *txBuilder) BuildForfeitTxs( func (b *txBuilder) BuildPoolTx( aspPubkey *secp256k1.PublicKey, payments []domain.Payment, minRelayFee uint64, -) (poolTx string, congestionTree tree.CongestionTree, err error) { +) (poolTx string, congestionTree tree.CongestionTree, connectorAddress string, err error) { // The creation of the tree and the pool tx are tightly coupled: // - building the tree requires knowing the shared outpoint (txid:vout) // - building the pool tx requires knowing the shared output script and amount @@ -120,8 +126,13 @@ func (b *txBuilder) BuildPoolTx( return } + connectorAddress, err = b.wallet.DeriveConnectorAddress(context.Background()) + if err != nil { + return + } + ptx, err := b.createPoolTx( - sharedOutputAmount, sharedOutputScript, payments, aspPubkey, + sharedOutputAmount, sharedOutputScript, payments, aspPubkey, connectorAddress, minRelayFee, ) if err != nil { return @@ -190,15 +201,24 @@ func (b *txBuilder) getLeafScriptAndTree( func (b *txBuilder) createPoolTx( sharedOutputAmount uint64, sharedOutputScript []byte, - payments []domain.Payment, aspPubKey *secp256k1.PublicKey, + payments []domain.Payment, aspPubKey *secp256k1.PublicKey, connectorAddress string, minRelayFee uint64, ) (*psetv2.Pset, error) { aspScript, err := p2wpkhScript(aspPubKey, b.net) if err != nil { return nil, err } + connectorScript, err := address.ToOutputScript(connectorAddress) + if err != nil { + return nil, err + } + receivers := getOnchainReceivers(payments) - connectorsAmount := connectorAmount * countSpentVtxos(payments) + nbOfInputs := countSpentVtxos(payments) + connectorsAmount := (connectorAmount + minRelayFee) * nbOfInputs + if nbOfInputs > 1 { + connectorsAmount -= minRelayFee + } targetAmount := sharedOutputAmount + connectorsAmount outputs := []psetv2.OutputArgs{ @@ -210,7 +230,7 @@ func (b *txBuilder) createPoolTx( { Asset: b.net.AssetID, Amount: connectorsAmount, - Script: aspScript, + Script: connectorScript, }, } @@ -354,11 +374,11 @@ func (b *txBuilder) createPoolTx( } func (b *txBuilder) createConnectors( - poolTx string, payments []domain.Payment, aspPubkey *secp256k1.PublicKey, + poolTx string, payments []domain.Payment, connectorAddress string, minRelayFee uint64, ) ([]*psetv2.Pset, error) { txid, _ := getTxid(poolTx) - aspScript, err := p2wpkhScript(aspPubkey, b.net) + aspScript, err := address.ToOutputScript(connectorAddress) if err != nil { return nil, err } @@ -378,7 +398,7 @@ func (b *txBuilder) createConnectors( if numberOfConnectors == 1 { outputs := []psetv2.OutputArgs{connectorOutput} - connectorTx, err := craftConnectorTx(previousInput, outputs) + connectorTx, err := craftConnectorTx(previousInput, aspScript, outputs, minRelayFee) if err != nil { return nil, err } @@ -386,12 +406,16 @@ func (b *txBuilder) createConnectors( return []*psetv2.Pset{connectorTx}, nil } - totalConnectorAmount := connectorAmount * numberOfConnectors + totalConnectorAmount := (connectorAmount + minRelayFee) * numberOfConnectors + if numberOfConnectors > 1 { + totalConnectorAmount -= minRelayFee + } connectors := make([]*psetv2.Pset, 0, numberOfConnectors-1) for i := uint64(0); i < numberOfConnectors-1; i++ { outputs := []psetv2.OutputArgs{connectorOutput} totalConnectorAmount -= connectorAmount + totalConnectorAmount -= minRelayFee if totalConnectorAmount > 0 { outputs = append(outputs, psetv2.OutputArgs{ Asset: b.net.AssetID, @@ -399,7 +423,7 @@ func (b *txBuilder) createConnectors( Amount: totalConnectorAmount, }) } - connectorTx, err := craftConnectorTx(previousInput, outputs) + connectorTx, err := craftConnectorTx(previousInput, aspScript, outputs, minRelayFee) if err != nil { return nil, err } @@ -473,3 +497,23 @@ func (b *txBuilder) createForfeitTxs( } return forfeitTxs, nil } + +func (b *txBuilder) getConnectorAddress(poolTx string) (string, error) { + pset, err := psetv2.NewPsetFromBase64(poolTx) + if err != nil { + return "", err + } + + if len(pset.Outputs) < 1 { + return "", fmt.Errorf("connector output not found in pool tx") + } + + connectorOutput := pset.Outputs[1] + + pay, err := payment.FromScript(connectorOutput.Script, b.net, nil) + if err != nil { + return "", err + } + + return pay.WitnessPubKeyHash() +} diff --git a/server/internal/infrastructure/tx-builder/covenant/builder_test.go b/server/internal/infrastructure/tx-builder/covenant/builder_test.go index 4abd85e..f82deec 100644 --- a/server/internal/infrastructure/tx-builder/covenant/builder_test.go +++ b/server/internal/infrastructure/tx-builder/covenant/builder_test.go @@ -21,6 +21,7 @@ import ( const ( testingKey = "0218d5ca8b58797b7dbd65c075dd7ba7784b3f38ab71b1a5a8e3f94ba0257654a6" + connectorAddress = "tex1qekd5u0qj8jl07vy60830xy7n9qtmcx9u3s0cqc" minRelayFee = uint64(30) roundLifetime = int64(1209344) unilateralExitDelay = int64(512) @@ -37,6 +38,8 @@ func TestMain(m *testing.M) { Return(uint64(100), nil) wallet.On("SelectUtxos", mock.Anything, mock.Anything, mock.Anything). Return(randomInput, uint64(0), nil) + wallet.On("DeriveConnectorAddress", mock.Anything). + Return(connectorAddress, nil) pubkeyBytes, _ := hex.DecodeString(testingKey) pubkey, _ = secp256k1.ParsePubKey(pubkeyBytes) @@ -56,12 +59,13 @@ func TestBuildPoolTx(t *testing.T) { if len(fixtures.Valid) > 0 { t.Run("valid", func(t *testing.T) { for _, f := range fixtures.Valid { - poolTx, congestionTree, err := builder.BuildPoolTx( + poolTx, congestionTree, connAddr, err := builder.BuildPoolTx( pubkey, f.Payments, minRelayFee, ) require.NoError(t, err) require.NotEmpty(t, poolTx) require.NotEmpty(t, congestionTree) + require.Equal(t, connectorAddress, connAddr) require.Equal(t, f.ExpectedNumOfNodes, congestionTree.NumberOfNodes()) require.Len(t, congestionTree.Leaves(), f.ExpectedNumOfLeaves) @@ -76,11 +80,12 @@ func TestBuildPoolTx(t *testing.T) { if len(fixtures.Invalid) > 0 { t.Run("invalid", func(t *testing.T) { for _, f := range fixtures.Invalid { - poolTx, congestionTree, err := builder.BuildPoolTx( + poolTx, congestionTree, connAddr, err := builder.BuildPoolTx( pubkey, f.Payments, minRelayFee, ) require.EqualError(t, err, f.ExpectedErr) require.Empty(t, poolTx) + require.Empty(t, connAddr) require.Empty(t, congestionTree) } }) @@ -100,7 +105,7 @@ func TestBuildForfeitTxs(t *testing.T) { t.Run("valid", func(t *testing.T) { for _, f := range fixtures.Valid { connectors, forfeitTxs, err := builder.BuildForfeitTxs( - pubkey, f.PoolTx, f.Payments, + pubkey, f.PoolTx, f.Payments, minRelayFee, ) require.NoError(t, err) require.Len(t, connectors, f.ExpectedNumOfConnectors) @@ -114,7 +119,7 @@ func TestBuildForfeitTxs(t *testing.T) { require.NotNil(t, tx) require.Len(t, tx.Inputs, 1) - require.Len(t, tx.Outputs, 2) + require.Len(t, tx.Outputs, 3) inputTxid := chainhash.Hash(tx.Inputs[0].PreviousTxid).String() require.Equal(t, expectedInputTxid, inputTxid) @@ -138,7 +143,7 @@ func TestBuildForfeitTxs(t *testing.T) { t.Run("invalid", func(t *testing.T) { for _, f := range fixtures.Invalid { connectors, forfeitTxs, err := builder.BuildForfeitTxs( - pubkey, f.PoolTx, f.Payments, + pubkey, f.PoolTx, f.Payments, minRelayFee, ) require.EqualError(t, err, f.ExpectedErr) require.Empty(t, connectors) diff --git a/server/internal/infrastructure/tx-builder/covenant/connectors.go b/server/internal/infrastructure/tx-builder/covenant/connectors.go index 0d9d02f..c4e54d4 100644 --- a/server/internal/infrastructure/tx-builder/covenant/connectors.go +++ b/server/internal/infrastructure/tx-builder/covenant/connectors.go @@ -1,12 +1,13 @@ package txbuilder import ( + "github.com/vulpemventures/go-elements/elementsutil" "github.com/vulpemventures/go-elements/psetv2" "github.com/vulpemventures/go-elements/transaction" ) func craftConnectorTx( - input psetv2.InputArgs, outputs []psetv2.OutputArgs, + input psetv2.InputArgs, inputScript []byte, outputs []psetv2.OutputArgs, feeAmount uint64, ) (*psetv2.Pset, error) { ptx, _ := psetv2.New(nil, nil, nil) updater, _ := psetv2.NewUpdater(ptx) @@ -17,9 +18,34 @@ func craftConnectorTx( return nil, err } - // TODO: add prevout. + var asset []byte + amount := feeAmount + for _, output := range outputs { + amount += output.Amount + if asset == nil { + var err error + asset, err = elementsutil.AssetHashToBytes(output.Asset) + if err != nil { + return nil, err + } + } + } - if err := updater.AddOutputs(outputs); err != nil { + value, err := elementsutil.ValueToBytes(amount) + if err != nil { + return nil, err + } + + if err := updater.AddInWitnessUtxo(0, transaction.NewTxOutput(asset, value, inputScript)); err != nil { + return nil, err + } + + feeOutput := psetv2.OutputArgs{ + Asset: outputs[0].Asset, + Amount: feeAmount, + } + + if err := updater.AddOutputs(append(outputs, feeOutput)); err != nil { return nil, err } diff --git a/server/internal/infrastructure/tx-builder/covenant/forfeit.go b/server/internal/infrastructure/tx-builder/covenant/forfeit.go index 18a4f8e..7b83db0 100644 --- a/server/internal/infrastructure/tx-builder/covenant/forfeit.go +++ b/server/internal/infrastructure/tx-builder/covenant/forfeit.go @@ -38,11 +38,6 @@ func craftForfeitTxs( } vtxoAmount, _ := elementsutil.ValueToBytes(vtxo.Amount) - vtxoPrevout := &transaction.TxOutput{ - Asset: connectorPrevout.Asset, - Value: vtxoAmount, - Script: vtxoScript, - } if err := updater.AddInputs([]psetv2.InputArgs{connectorInput, vtxoInput}); err != nil { return nil, err @@ -56,6 +51,8 @@ func craftForfeitTxs( return nil, err } + vtxoPrevout := transaction.NewTxOutput(connectorPrevout.Asset, vtxoAmount, vtxoScript) + if err = updater.AddInWitnessUtxo(1, vtxoPrevout); err != nil { return nil, err } diff --git a/server/internal/infrastructure/tx-builder/covenant/mocks_test.go b/server/internal/infrastructure/tx-builder/covenant/mocks_test.go index 9c38874..af455d2 100644 --- a/server/internal/infrastructure/tx-builder/covenant/mocks_test.go +++ b/server/internal/infrastructure/tx-builder/covenant/mocks_test.go @@ -40,6 +40,17 @@ func (m *mockedWallet) DeriveAddresses(ctx context.Context, num int) ([]string, return res, args.Error(1) } +// DeriveConnectorAddress implements ports.WalletService. +func (m *mockedWallet) DeriveConnectorAddress(ctx context.Context) (string, error) { + args := m.Called(ctx) + + var res string + if a := args.Get(0); a != nil { + res = a.(string) + } + return res, args.Error(1) +} + // GetPubkey implements ports.WalletService. func (m *mockedWallet) GetPubkey(ctx context.Context) (*secp256k1.PublicKey, error) { args := m.Called(ctx) @@ -123,6 +134,17 @@ func (m *mockedWallet) SignPsetWithKey(ctx context.Context, pset string, inputIn return res, args.Error(1) } +func (m *mockedWallet) SignConnectorInput(ctx context.Context, pset string, inputIndexes []int, extract bool) (string, error) { + args := m.Called(ctx, pset, inputIndexes, extract) + + var res string + if a := args.Get(0); a != nil { + res = a.(string) + } + + return res, args.Error(1) +} + func (m *mockedWallet) WatchScripts( ctx context.Context, scripts []string, ) error { @@ -147,6 +169,22 @@ func (m *mockedWallet) GetNotificationChannel(ctx context.Context) chan []domain return res } +func (m *mockedWallet) ListConnectorUtxos(ctx context.Context, addr string) ([]ports.TxInput, error) { + args := m.Called(ctx, addr) + + var res []ports.TxInput + if a := args.Get(0); a != nil { + res = a.([]ports.TxInput) + } + + return res, args.Error(1) +} + +func (m *mockedWallet) WaitForSync(ctx context.Context, txid string) error { + args := m.Called(ctx, txid) + return args.Error(0) +} + type mockedInput struct { mock.Mock }