From 287db4e08ac38424511a3b6707dfb61df70f1437 Mon Sep 17 00:00:00 2001 From: Louis Singer <41042567+louisinger@users.noreply.github.com> Date: Thu, 8 Feb 2024 16:58:04 +0100 Subject: [PATCH] Support round expiration and sweep vtxos (#70) * sweeper base implementation * sweeper service final implementation * fixes * fix CSV script * RoundSwept event fix & test * remove Vtxos after a sweep transaction * ARK_ROUND_LIFETIME config * remove TxBuilder.GetLifetime * refactor sweeper * use GetTransaction blocktime * polish and comments * fix linting * pair programming fixes * several fixes * clean Println * fixes * linter fixes * remove infrastructure deps from application layer * Fixes --------- Co-authored-by: altafan <18440657+altafan@users.noreply.github.com> --- asp/cmd/arkd/main.go | 8 + asp/go.mod | 8 +- asp/go.sum | 16 +- asp/internal/app-config/config.go | 46 +- asp/internal/config/config.go | 5 + asp/internal/core/application/service.go | 40 +- asp/internal/core/application/sweeper.go | 579 ++++++++++++++++++ asp/internal/core/domain/payment.go | 1 + asp/internal/core/domain/round.go | 5 + asp/internal/core/domain/round_repo.go | 7 +- asp/internal/core/ports/scheduler.go | 1 + asp/internal/core/ports/tx_builder.go | 12 + asp/internal/core/ports/wallet.go | 2 + .../infrastructure/db/badger/round_repo.go | 12 + .../infrastructure/db/badger/vtxo_repo.go | 42 +- .../infrastructure/ocean-wallet/service.go | 1 + .../ocean-wallet/transaction.go | 145 ++++- .../infrastructure/ocean-wallet/wallet.go | 74 ++- .../scheduler/gocron/service.go | 11 + .../tx-builder/covenant/builder.go | 111 +++- .../tx-builder/covenant/builder_test.go | 4 +- .../tx-builder/covenant/mocks_test.go | 26 + .../tx-builder/covenant/sweep.go | 135 ++++ .../tx-builder/covenant/tree.go | 48 +- .../tx-builder/dummy/builder.go | 12 + .../tx-builder/dummy/builder_test.go | 8 + asp/internal/interface/grpc/service.go | 1 + common/bip68.go | 16 +- common/tree/congestion_tree.go | 53 +- common/tree/script.go | 26 +- common/tree/validation.go | 2 +- noah/go.mod | 2 +- noah/go.sum | 4 +- noah/redeem.go | 12 +- noah/unilateral_redeem.go | 42 +- 35 files changed, 1403 insertions(+), 114 deletions(-) create mode 100644 asp/internal/core/application/sweeper.go create mode 100644 asp/internal/infrastructure/tx-builder/covenant/sweep.go diff --git a/asp/cmd/arkd/main.go b/asp/cmd/arkd/main.go index ce1e607..c96f63d 100755 --- a/asp/cmd/arkd/main.go +++ b/asp/cmd/arkd/main.go @@ -30,6 +30,13 @@ func main() { Port: cfg.Port, NoTLS: cfg.NoTLS, } + + if cfg.RoundLifetime%512 != 0 { + setLifetime := cfg.RoundLifetime + cfg.RoundLifetime = cfg.RoundLifetime - (cfg.RoundLifetime % 512) + log.Infof("round lifetime must be a multiple of 512, %d -> %d", setLifetime, cfg.RoundLifetime) + } + appConfig := &appconfig.Config{ DbType: cfg.DbType, DbDir: cfg.DbDir, @@ -40,6 +47,7 @@ func main() { BlockchainScannerType: cfg.BlockchainScannerType, WalletAddr: cfg.WalletAddr, MinRelayFee: cfg.MinRelayFee, + RoundLifetime: cfg.RoundLifetime, } svc, err := grpcservice.NewService(svcConfig, appConfig) if err != nil { diff --git a/asp/go.mod b/asp/go.mod index 694cff1..9fa6e17 100644 --- a/asp/go.mod +++ b/asp/go.mod @@ -18,7 +18,7 @@ require ( github.com/stretchr/testify v1.8.4 github.com/timshannon/badgerhold/v4 v4.0.3 github.com/urfave/cli/v2 v2.26.0 - github.com/vulpemventures/go-elements v0.5.2 + github.com/vulpemventures/go-elements v0.5.3 google.golang.org/genproto/googleapis/api v0.0.0-20231106174013-bbf56f31fb17 google.golang.org/grpc v1.59.0 google.golang.org/protobuf v1.31.0 @@ -26,6 +26,11 @@ require ( require github.com/stretchr/objx v0.5.0 // indirect +require ( + github.com/FactomProject/basen v0.0.0-20150613233007-fe3947df716e // indirect + github.com/FactomProject/btcutilecc v0.0.0-20130527213604-d3a63a5752ec // indirect +) + require ( github.com/btcsuite/btcd v0.23.1 github.com/btcsuite/btcd/btcec/v2 v2.3.2 @@ -62,6 +67,7 @@ require ( github.com/spf13/pflag v1.0.5 // indirect github.com/subosito/gotenv v1.6.0 // indirect github.com/vulpemventures/fastsha256 v0.0.0-20160815193821-637e65642941 // indirect + github.com/vulpemventures/go-bip32 v0.0.0-20200624192635-867c159da4d7 github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect go.opencensus.io v0.24.0 // indirect go.uber.org/atomic v1.9.0 // indirect diff --git a/asp/go.sum b/asp/go.sum index 1ac5ef4..d9c0609 100644 --- a/asp/go.sum +++ b/asp/go.sum @@ -38,6 +38,10 @@ cloud.google.com/go/storage v1.14.0/go.mod h1:GrKmX003DSIwi9o29oFT7YDnHYwZoctc3f dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= +github.com/FactomProject/basen v0.0.0-20150613233007-fe3947df716e h1:ahyvB3q25YnZWly5Gq1ekg6jcmWaGj/vG/MhF4aisoc= +github.com/FactomProject/basen v0.0.0-20150613233007-fe3947df716e/go.mod h1:kGUqhHd//musdITWjFvNTHn90WG9bMLBEPQZ17Cmlpw= +github.com/FactomProject/btcutilecc v0.0.0-20130527213604-d3a63a5752ec h1:1Qb69mGp/UtRPn422BH4/Y4Q3SLUrD9KHuDkm8iodFc= +github.com/FactomProject/btcutilecc v0.0.0-20130527213604-d3a63a5752ec/go.mod h1:CD8UlnlLDiqb36L110uqiP2iSflVjx9g/3U9hCI4q2U= github.com/aead/siphash v1.0.1/go.mod h1:Nywa3cDsYNNK3gaciGTWPwHt0wlpNV15vwmswBAUSII= github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5doyWs3UAsr3K4I6qtAmlQcZDesFNEHPZAzj8= github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= @@ -78,6 +82,8 @@ github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWR github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= +github.com/cmars/basen v0.0.0-20150613233007-fe3947df716e h1:0XBUw73chJ1VYSsfvcPvVT7auykAJce9FpRr10L6Qhw= +github.com/cmars/basen v0.0.0-20150613233007-fe3947df716e/go.mod h1:P13beTBKr5Q18lJe1rIoLUqjM+CB1zYrRg44ZqGuQSA= github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/cncf/udpa/go v0.0.0-20200629203442-efcf912fb354/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= @@ -307,6 +313,7 @@ github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.1.5-0.20170601210322-f6abca593680/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= @@ -328,8 +335,10 @@ github.com/urfave/cli/v2 v2.26.0 h1:3f3AMg3HpThFNT4I++TKOejZO8yU55t3JnnSr4S4QEI= github.com/urfave/cli/v2 v2.26.0/go.mod h1:8qnjx1vcq5s2/wpsqoZFndg2CE5tNFyrTvS6SinrnYQ= github.com/vulpemventures/fastsha256 v0.0.0-20160815193821-637e65642941 h1:CTcw80hz/Sw8hqlKX5ZYvBUF5gAHSHwdjXxRf/cjDcI= github.com/vulpemventures/fastsha256 v0.0.0-20160815193821-637e65642941/go.mod h1:GXBJykxW2kUcktGdsgyay7uwwWvkljASfljNcT0mbh8= -github.com/vulpemventures/go-elements v0.5.2 h1:vIDzVpRXG5PnlzHA8tCnr2Tn7raIV5cHy7bRyDrbuM4= -github.com/vulpemventures/go-elements v0.5.2/go.mod h1:aBGuWXHaiAIUIcwqCdtEh2iQ3kJjKwHU9ywvhlcRSeU= +github.com/vulpemventures/go-bip32 v0.0.0-20200624192635-867c159da4d7 h1:X7DtNv+YWy76kELMZB/xVkIJ7YNp2vpgMFVsDcQA40U= +github.com/vulpemventures/go-bip32 v0.0.0-20200624192635-867c159da4d7/go.mod h1:Zrvx8XgpWvSPdz1lXnuN083CkoZnzwxBLEB03S8et1I= +github.com/vulpemventures/go-elements v0.5.3 h1:zaC/ynHFwCAzFSOMfzb6BcbD6FXASppSiGMycc95WVA= +github.com/vulpemventures/go-elements v0.5.3/go.mod h1:aBGuWXHaiAIUIcwqCdtEh2iQ3kJjKwHU9ywvhlcRSeU= github.com/vulpemventures/go-secp256k1-zkp v1.1.6 h1:BmsrmXRLUibwa75Qkk8yELjpzCzlAjYFGLiLiOdq7Xo= github.com/vulpemventures/go-secp256k1-zkp v1.1.6/go.mod h1:zo7CpgkuPgoe7fAV+inyxsI9IhGmcoFgyD8nqZaPSOM= github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q= @@ -356,6 +365,7 @@ go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9i go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI= go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ= go.uber.org/zap v1.18.1/go.mod h1:xg/QME4nWcxGxrpdeYfq7UvYrLh66cuVKdrbD1XF/NI= +golang.org/x/crypto v0.0.0-20170613210332-850760c427c5/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20170930174604-9419663f5a44/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= @@ -725,6 +735,8 @@ honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWh honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= honnef.co/go/tools v0.0.1-2020.1.3/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= honnef.co/go/tools v0.0.1-2020.1.4/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= +launchpad.net/gocheck v0.0.0-20140225173054-000000000087 h1:Izowp2XBH6Ya6rv+hqbceQyw/gSGoXfH/UPoTGduL54= +launchpad.net/gocheck v0.0.0-20140225173054-000000000087/go.mod h1:hj7XX3B/0A+80Vse0e+BUHsHMTEhd0O4cpUHr/e/BUM= rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8= rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0= rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA= diff --git a/asp/internal/app-config/config.go b/asp/internal/app-config/config.go index ffb7867..98c3fc6 100644 --- a/asp/internal/app-config/config.go +++ b/asp/internal/app-config/config.go @@ -9,6 +9,7 @@ import ( "github.com/ark-network/ark/internal/core/ports" "github.com/ark-network/ark/internal/infrastructure/db" oceanwallet "github.com/ark-network/ark/internal/infrastructure/ocean-wallet" + scheduler "github.com/ark-network/ark/internal/infrastructure/scheduler/gocron" txbuilder "github.com/ark-network/ark/internal/infrastructure/tx-builder/covenant" txbuilderdummy "github.com/ark-network/ark/internal/infrastructure/tx-builder/dummy" log "github.com/sirupsen/logrus" @@ -41,12 +42,14 @@ type Config struct { BlockchainScannerType string WalletAddr string MinRelayFee uint64 + RoundLifetime int64 repo ports.RepoManager svc application.Service wallet ports.WalletService txBuilder ports.TxBuilder scanner ports.BlockchainScanner + scheduler ports.SchedulerService } func (c *Config) Validate() error { @@ -86,9 +89,30 @@ func (c *Config) Validate() error { if err := c.scannerService(); err != nil { return err } + if err := c.schedulerService(); err != nil { + return err + } if err := c.appService(); err != nil { return err } + // round life time must be a multiple of 512 + if c.RoundLifetime <= 0 || c.RoundLifetime%512 != 0 { + return fmt.Errorf("invalid round lifetime, must be greater than 0 and a multiple of 512") + } + seq, err := common.BIP68Encode(uint(c.RoundLifetime)) + if err != nil { + return fmt.Errorf("invalid round lifetime, %s", err) + } + + seconds, err := common.BIP68Decode(seq) + if err != nil { + return fmt.Errorf("invalid round lifetime, %s", err) + } + + if seconds != uint(c.RoundLifetime) { + return fmt.Errorf("invalid round lifetime, must be a multiple of 512") + } + return nil } @@ -141,7 +165,7 @@ func (c *Config) txBuilderService() error { case "dummy": svc = txbuilderdummy.NewTxBuilder(c.wallet, net) case "covenant": - svc = txbuilder.NewTxBuilder(c.wallet, net) + svc = txbuilder.NewTxBuilder(c.wallet, net, c.RoundLifetime) default: err = fmt.Errorf("unknown tx builder type") } @@ -170,10 +194,28 @@ func (c *Config) scannerService() error { return nil } +func (c *Config) schedulerService() error { + var svc ports.SchedulerService + var err error + switch c.SchedulerType { + case "gocron": + svc = scheduler.NewScheduler() + default: + err = fmt.Errorf("unknown scheduler type") + } + if err != nil { + return err + } + + c.scheduler = svc + return nil +} + func (c *Config) appService() error { net := c.mainChain() svc, err := application.NewService( - c.RoundInterval, c.Network, net, c.wallet, c.repo, c.txBuilder, c.scanner, c.MinRelayFee, + c.Network, net, c.RoundInterval, c.RoundLifetime, c.MinRelayFee, + c.wallet, c.repo, c.txBuilder, c.scanner, c.scheduler, ) if err != nil { return err diff --git a/asp/internal/config/config.go b/asp/internal/config/config.go index db9e76b..707a8e5 100644 --- a/asp/internal/config/config.go +++ b/asp/internal/config/config.go @@ -23,6 +23,7 @@ type Config struct { Network common.Network LogLevel int MinRelayFee uint64 + RoundLifetime int64 } var ( @@ -38,6 +39,7 @@ var ( LogLevel = "LOG_LEVEL" Network = "NETWORK" MinRelayFee = "MIN_RELAY_FEE" + RoundLifetime = "ROUND_LIFETIME" defaultDatadir = common.AppDataDir("arkd", false) defaultRoundInterval = 60 @@ -50,6 +52,7 @@ var ( defaultNetwork = "testnet" defaultLogLevel = 5 defaultMinRelayFee = 30 + defaultRoundLifetime = 512 ) func LoadConfig() (*Config, error) { @@ -66,6 +69,7 @@ func LoadConfig() (*Config, error) { viper.SetDefault(Insecure, defaultInsecure) viper.SetDefault(LogLevel, defaultLogLevel) viper.SetDefault(Network, defaultNetwork) + viper.SetDefault(RoundLifetime, defaultRoundLifetime) viper.SetDefault(MinRelayFee, defaultMinRelayFee) net, err := getNetwork() @@ -90,6 +94,7 @@ func LoadConfig() (*Config, error) { LogLevel: viper.GetInt(LogLevel), Network: net, MinRelayFee: viper.GetUint64(MinRelayFee), + RoundLifetime: viper.GetInt64(RoundLifetime), }, nil } diff --git a/asp/internal/core/application/service.go b/asp/internal/core/application/service.go index 8d52c76..7a2f1fe 100644 --- a/asp/internal/core/application/service.go +++ b/asp/internal/core/application/service.go @@ -41,27 +41,31 @@ type Service interface { } type service struct { - minRelayFee uint64 - roundInterval int64 network common.Network onchainNework network.Network pubkey *secp256k1.PublicKey + roundLifetime int64 + roundInterval int64 + minRelayFee uint64 wallet ports.WalletService repoManager ports.RepoManager builder ports.TxBuilder scanner ports.BlockchainScanner + sweeper *sweeper paymentRequests *paymentsMap forfeitTxs *forfeitTxsMap - eventsCh chan domain.RoundEvent + + eventsCh chan domain.RoundEvent } func NewService( - interval int64, network common.Network, onchainNetwork network.Network, + network common.Network, onchainNetwork network.Network, + roundInterval, roundLifetime int64, minRelayFee uint64, walletSvc ports.WalletService, repoManager ports.RepoManager, builder ports.TxBuilder, scanner ports.BlockchainScanner, - minRelayFee uint64, + scheduler ports.SchedulerService, ) (Service, error) { eventsCh := make(chan domain.RoundEvent) paymentRequests := newPaymentsMap(nil) @@ -72,10 +76,14 @@ func NewService( if err != nil { return nil, fmt.Errorf("failed to fetch pubkey: %s", err) } + + sweeper := newSweeper(walletSvc, repoManager, builder, scheduler) + svc := &service{ - minRelayFee, interval, network, onchainNetwork, pubkey, - walletSvc, repoManager, builder, scanner, paymentRequests, forfeitTxs, - eventsCh, + network, onchainNetwork, pubkey, + roundLifetime, roundInterval, minRelayFee, + walletSvc, repoManager, builder, scanner, sweeper, + paymentRequests, forfeitTxs, eventsCh, } repoManager.RegisterEventsHandler( func(round *domain.Round) { @@ -92,12 +100,18 @@ func NewService( } func (s *service) Start() error { + log.Debug("starting sweeper service") + if err := s.sweeper.start(); err != nil { + return err + } + log.Debug("starting app service") go s.start() return nil } func (s *service) Stop() { + s.sweeper.stop() // nolint vtxos, _ := s.repoManager.Vtxos().GetSpendableVtxos( context.Background(), "", @@ -356,7 +370,17 @@ func (s *service) finalizeRound() { return } + now := time.Now().Unix() + expirationTimestamp := now + s.roundLifetime + 30 // add 30 secs to be sure that the tx is confirmed + + if err := s.sweeper.schedule(expirationTimestamp, txid, round.CongestionTree); err != nil { + changes = round.Fail(fmt.Errorf("failed to schedule sweep tx: %s", err)) + log.WithError(err).Warn("failed to schedule sweep tx") + return + } + changes, _ = round.EndFinalization(forfeitTxs, txid) + log.Debugf("finalized round %s with pool tx %s", round.Id, round.Txid) } diff --git a/asp/internal/core/application/sweeper.go b/asp/internal/core/application/sweeper.go new file mode 100644 index 0000000..b5d00b7 --- /dev/null +++ b/asp/internal/core/application/sweeper.go @@ -0,0 +1,579 @@ +package application + +import ( + "context" + "encoding/hex" + "fmt" + "time" + + "github.com/ark-network/ark/common/tree" + "github.com/ark-network/ark/internal/core/domain" + "github.com/ark-network/ark/internal/core/ports" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/decred/dcrd/dcrec/secp256k1/v4" + log "github.com/sirupsen/logrus" + "github.com/vulpemventures/go-elements/psetv2" +) + +// sweeper is an unexported service running while the main application service is started +// it is responsible for sweeping onchain shared outputs that expired +// it also handles delaying the sweep events in case some parts of the tree are broadcasted +// when a round is finalized, the main application service schedules a sweep event on the newly created congestion tree +type sweeper struct { + wallet ports.WalletService + repoManager ports.RepoManager + builder ports.TxBuilder + scheduler ports.SchedulerService + + // cache of scheduled tasks, avoid scheduling the same sweep event multiple times + scheduledTasks map[string]struct{} +} + +func newSweeper( + wallet ports.WalletService, + repoManager ports.RepoManager, + builder ports.TxBuilder, + scheduler ports.SchedulerService, +) *sweeper { + return &sweeper{ + wallet, + repoManager, + builder, + scheduler, + make(map[string]struct{}), + } +} + +func (s *sweeper) start() error { + s.scheduler.Start() + + allRounds, err := s.repoManager.Rounds().GetSweepableRounds(context.Background()) + if err != nil { + return err + } + + for _, round := range allRounds { + task := s.createTask(round.Txid, round.CongestionTree) + task() + } + + return nil +} + +func (s *sweeper) stop() { + s.scheduler.Stop() +} + +// removeTask update the cached map of scheduled tasks +func (s *sweeper) removeTask(treeRootTxid string) { + delete(s.scheduledTasks, treeRootTxid) +} + +// schedule set up a task to be executed once at the given timestamp +func (s *sweeper) schedule( + expirationTimestamp int64, roundTxid string, congestionTree tree.CongestionTree, +) error { + root, err := congestionTree.Root() + if err != nil { + return err + } + + if _, scheduled := s.scheduledTasks[root.Txid]; scheduled { + return nil + } + + task := s.createTask(roundTxid, congestionTree) + fancyTime := time.Unix(expirationTimestamp, 0).Format("2006-01-02 15:04:05") + log.Debugf("scheduled sweep task at %s", fancyTime) + if err := s.scheduler.ScheduleTaskOnce(expirationTimestamp, task); err != nil { + return err + } + + s.scheduledTasks[root.Txid] = struct{}{} + return nil +} + +// createTask returns a function passed as handler in the scheduler +// it tries to craft a sweep tx containing the onchain outputs of the given congestion tree +// if some parts of the tree have been broadcasted in the meantine, it will schedule the next taskes for the remaining parts of the tree +func (s *sweeper) createTask( + roundTxid string, congestionTree tree.CongestionTree, +) func() { + return func() { + ctx := context.Background() + root, err := congestionTree.Root() + if err != nil { + log.WithError(err).Error("error while getting root node") + return + } + + s.removeTask(root.Txid) + log.Debugf("sweeper: %s", root.Txid) + + sweepInputs := make([]ports.SweepInput, 0) + vtxoKeys := make([]domain.VtxoKey, 0) // vtxos associated to the sweep inputs + + // inspect the congestion tree to find onchain shared outputs + sharedOutputs, err := s.findSweepableOutputs(ctx, congestionTree) + if err != nil { + log.WithError(err).Error("error while inspecting congestion tree") + return + } + + for expiredAt, inputs := range sharedOutputs { + // if the shared outputs are not expired, schedule a sweep task for it + if time.Unix(expiredAt, 0).After(time.Now()) { + subtrees, err := computeSubTrees(congestionTree, inputs) + if err != nil { + log.WithError(err).Error("error while computing subtrees") + continue + } + + for _, subTree := range subtrees { + // mitigate the risk to get BIP68 non-final errors by scheduling the task 30 seconds after the expiration time + if err := s.schedule(int64(expiredAt), roundTxid, subTree); err != nil { + log.WithError(err).Error("error while scheduling sweep task") + continue + } + } + continue + } + + // iterate over the expired shared outputs + for _, input := range inputs { + // sweepableVtxos related to the sweep input + sweepableVtxos := make([]domain.VtxoKey, 0) + + // check if input is the vtxo itself + vtxos, _ := s.repoManager.Vtxos().GetVtxos( + ctx, + []domain.VtxoKey{ + { + Txid: input.InputArgs.Txid, + VOut: input.InputArgs.TxIndex, + }, + }, + ) + if len(vtxos) > 0 { + if !vtxos[0].Swept && !vtxos[0].Redeemed { + sweepableVtxos = append(sweepableVtxos, vtxos[0].VtxoKey) + } + } else { + // if it's not a vtxo, find all the vtxos leaves reachable from that input + vtxosLeaves, err := congestionTree.FindLeaves(input.InputArgs.Txid, input.InputArgs.TxIndex) + if err != nil { + log.WithError(err).Error("error while finding vtxos leaves") + continue + } + + for _, leaf := range vtxosLeaves { + pset, err := psetv2.NewPsetFromBase64(leaf.Tx) + if err != nil { + log.Error(fmt.Errorf("error while decoding pset: %w", err)) + continue + } + + vtxo, err := extractVtxoOutpoint(pset) + if err != nil { + log.Error(err) + continue + } + + sweepableVtxos = append(sweepableVtxos, *vtxo) + } + + if len(sweepableVtxos) <= 0 { + continue + } + + firstVtxo, err := s.repoManager.Vtxos().GetVtxos(ctx, sweepableVtxos[1:]) + if err != nil { + log.Error(fmt.Errorf("error while getting vtxo: %w", err)) + sweepInputs = append(sweepInputs, input) // add the input anyway in order to try to sweep it + continue + } + + if firstVtxo[0].Swept || firstVtxo[0].Redeemed { + // we assume that if the first vtxo is swept or redeemed, the shared output has been spent + // skip, the output is already swept or spent by a unilateral redeem + continue + } + } + + if len(sweepableVtxos) > 0 { + vtxoKeys = append(vtxoKeys, sweepableVtxos...) + sweepInputs = append(sweepInputs, input) + } + } + } + + if len(sweepInputs) > 0 { + // build the sweep transaction with all the expired non-swept shared outputs + sweepTx, err := s.builder.BuildSweepTx(s.wallet, sweepInputs) + if err != nil { + log.WithError(err).Error("error while building sweep tx") + return + } + + err = nil + txid := "" + // retry until the tx is broadcasted or the error is not BIP68 final + for len(txid) == 0 && (err == nil || err == fmt.Errorf("non-BIP68-final")) { + if err != nil { + log.Debugln("sweep tx not BIP68 final, retrying in 5 seconds") + time.Sleep(5 * time.Second) + } + + txid, err = s.wallet.BroadcastTransaction(ctx, sweepTx) + } + + if err != nil { + log.WithError(err).Error("error while broadcasting sweep tx") + return + } + if len(txid) > 0 { + log.Debugln("sweep tx broadcasted:", txid) + vtxosRepository := s.repoManager.Vtxos() + + // mark the vtxos as swept + if err := vtxosRepository.SweepVtxos(ctx, vtxoKeys); err != nil { + log.Error(fmt.Errorf("error while deleting vtxos: %w", err)) + return + } + + log.Debugf("%d vtxos swept", len(vtxoKeys)) + + roundVtxos, err := vtxosRepository.GetVtxosForRound(ctx, roundTxid) + if err != nil { + log.WithError(err).Error("error while getting vtxos for round") + return + } + + allSwept := true + for _, vtxo := range roundVtxos { + allSwept = allSwept && vtxo.Swept + if !allSwept { + break + } + } + + if allSwept { + // update the round + roundRepo := s.repoManager.Rounds() + round, err := roundRepo.GetRoundWithTxid(ctx, roundTxid) + if err != nil { + log.WithError(err).Error("error while getting round") + return + } + + round.Sweep() + + if err := roundRepo.AddOrUpdateRound(ctx, *round); err != nil { + log.WithError(err).Error("error while marking round as swept") + return + } + } + } + } + } +} + +// onchainOutputs iterates over all the nodes' outputs in the congestion tree and checks their onchain state +// returns the sweepable outputs as ports.SweepInput mapped by their expiration time +func (s *sweeper) findSweepableOutputs( + ctx context.Context, + congestionTree tree.CongestionTree, +) (map[int64][]ports.SweepInput, error) { + sweepableOutputs := make(map[int64][]ports.SweepInput) + blocktimeCache := make(map[string]int64) // txid -> blocktime + nodesToCheck := congestionTree[0] // init with the root + + for len(nodesToCheck) > 0 { + newNodesToCheck := make([]tree.Node, 0) + + for _, node := range nodesToCheck { + isPublished, blocktime, err := s.wallet.IsTransactionPublished(ctx, node.Txid) + if err != nil { + return nil, err + } + + var expirationTime int64 + var sweepInputs []ports.SweepInput + + if !isPublished { + if _, ok := blocktimeCache[node.ParentTxid]; !ok { + isPublished, blocktime, err := s.wallet.IsTransactionPublished(ctx, node.ParentTxid) + if !isPublished || err != nil { + return nil, fmt.Errorf("tx %s not found", node.Txid) + } + + blocktimeCache[node.ParentTxid] = blocktime + } + + expirationTime, sweepInputs, err = s.nodeToSweepInputs(blocktimeCache[node.ParentTxid], node) + if err != nil { + return nil, err + } + } else { + // cache the blocktime for future use + blocktimeCache[node.Txid] = int64(blocktime) + + // if the tx is onchain, it means that the input is spent + // add the children to the nodes in order to check them during the next iteration + // We will return the error below, but are we going to schedule the tasks for the "children roots"? + if !node.Leaf { + children := congestionTree.Children(node.Txid) + newNodesToCheck = append(newNodesToCheck, children...) + continue + } + + // if the node is a leaf, the vtxos outputs should added as onchain outputs if they are not swept yet + vtxoExpiration, sweepInput, err := s.leafToSweepInput(ctx, blocktime, node) + if err != nil { + return nil, err + } + + if sweepInput != nil { + expirationTime = vtxoExpiration + sweepInputs = []ports.SweepInput{*sweepInput} + } + } + + if _, ok := sweepableOutputs[expirationTime]; !ok { + sweepableOutputs[expirationTime] = make([]ports.SweepInput, 0) + } + sweepableOutputs[expirationTime] = append(sweepableOutputs[expirationTime], sweepInputs...) + } + + nodesToCheck = newNodesToCheck + } + + return sweepableOutputs, nil +} + +func (s *sweeper) leafToSweepInput(ctx context.Context, txBlocktime int64, node tree.Node) (int64, *ports.SweepInput, error) { + pset, err := psetv2.NewPsetFromBase64(node.Tx) + if err != nil { + return -1, nil, err + } + + vtxo, err := extractVtxoOutpoint(pset) + if err != nil { + return -1, nil, err + } + + fromRepo, err := s.repoManager.Vtxos().GetVtxos(ctx, []domain.VtxoKey{*vtxo}) + if err != nil { + return -1, nil, err + } + + if len(fromRepo) == 0 { + return -1, nil, fmt.Errorf("vtxo not found") + } + + if fromRepo[0].Swept { + return -1, nil, nil + } + + if fromRepo[0].Redeemed { + return -1, nil, nil + } + + // if the vtxo is not swept or redeemed, add it to the onchain outputs + pubKeyBytes, err := hex.DecodeString(fromRepo[0].Pubkey) + if err != nil { + return -1, nil, err + } + + pubKey, err := secp256k1.ParsePubKey(pubKeyBytes) + if err != nil { + return -1, nil, err + } + + sweepLeaf, lifetime, err := s.builder.GetLeafSweepClosure(node, pubKey) + if err != nil { + return -1, nil, err + } + + sweepInput := ports.SweepInput{ + InputArgs: psetv2.InputArgs{ + Txid: vtxo.Txid, + TxIndex: vtxo.VOut, + }, + SweepLeaf: *sweepLeaf, + Amount: fromRepo[0].Amount, + } + + return txBlocktime + lifetime, &sweepInput, nil +} + +func (s *sweeper) nodeToSweepInputs(parentBlocktime int64, node tree.Node) (int64, []ports.SweepInput, error) { + pset, err := psetv2.NewPsetFromBase64(node.Tx) + if err != nil { + return -1, nil, err + } + + if len(pset.Inputs) != 1 { + return -1, nil, fmt.Errorf("invalid node pset, expect 1 input, got %d", len(pset.Inputs)) + } + + // if the tx is not onchain, it means that the input is an existing shared output + input := pset.Inputs[0] + txid := chainhash.Hash(input.PreviousTxid).String() + index := input.PreviousTxIndex + + sweepLeaf, lifetime, err := extractSweepLeaf(input) + if err != nil { + return -1, nil, err + } + + expirationTime := parentBlocktime + lifetime + + amount := uint64(0) + for _, out := range pset.Outputs { + amount += out.Value + } + + sweepInputs := []ports.SweepInput{ + { + InputArgs: psetv2.InputArgs{ + Txid: txid, + TxIndex: index, + }, + SweepLeaf: *sweepLeaf, + Amount: amount, + }, + } + + return expirationTime, sweepInputs, nil +} + +func computeSubTrees(congestionTree tree.CongestionTree, inputs []ports.SweepInput) ([]tree.CongestionTree, error) { + subTrees := make(map[string]tree.CongestionTree, 0) + + // for each sweepable input, create a sub congestion tree + // it allows to skip the part of the tree that has been broadcasted in the next task + for _, input := range inputs { + subTree, err := computeSubTree(congestionTree, input.InputArgs.Txid) + if err != nil { + log.WithError(err).Error("error while finding sub tree") + continue + } + + root, err := subTree.Root() + if err != nil { + log.WithError(err).Error("error while getting root node") + continue + } + + subTrees[root.Txid] = subTree + } + + // filter out the sub trees, remove the ones that are included in others + filteredSubTrees := make([]tree.CongestionTree, 0) + for i, subTree := range subTrees { + notIncludedInOtherTrees := true + + for j, otherSubTree := range subTrees { + if i == j { + continue + } + contains, err := containsTree(otherSubTree, subTree) + if err != nil { + log.WithError(err).Error("error while checking if a tree contains another") + continue + } + + if contains { + notIncludedInOtherTrees = false + break + } + } + + if notIncludedInOtherTrees { + filteredSubTrees = append(filteredSubTrees, subTree) + } + } + + return filteredSubTrees, nil +} + +func computeSubTree(congestionTree tree.CongestionTree, newRoot string) (tree.CongestionTree, error) { + for _, level := range congestionTree { + for _, node := range level { + if node.Txid == newRoot || node.ParentTxid == newRoot { + newTree := make(tree.CongestionTree, 0) + newTree = append(newTree, []tree.Node{node}) + + children := congestionTree.Children(node.Txid) + for len(children) > 0 { + newTree = append(newTree, children) + newChildren := make([]tree.Node, 0) + for _, child := range children { + newChildren = append(newChildren, congestionTree.Children(child.Txid)...) + } + children = newChildren + } + + return newTree, nil + } + } + } + + return nil, fmt.Errorf("failed to create subtree, new root not found") +} + +func containsTree(tr0 tree.CongestionTree, tr1 tree.CongestionTree) (bool, error) { + tr1Root, err := tr1.Root() + if err != nil { + return false, err + } + + for _, level := range tr0 { + for _, node := range level { + if node.Txid == tr1Root.Txid { + return true, nil + } + } + } + + return false, nil +} + +// given a congestion tree input, searches and returns the sweep leaf and its lifetime in seconds +func extractSweepLeaf(input psetv2.Input) (sweepLeaf *psetv2.TapLeafScript, lifetime int64, err error) { + for _, leaf := range input.TapLeafScript { + isSweep, _, seconds, err := tree.DecodeSweepScript(leaf.Script) + if err != nil { + return nil, 0, err + } + if isSweep { + lifetime = int64(seconds) + sweepLeaf = &leaf + break + } + } + + if sweepLeaf == nil { + return nil, 0, fmt.Errorf("sweep leaf not found") + } + + return sweepLeaf, lifetime, nil +} + +// assuming the pset is a leaf in the congestion tree, returns the vtxos outputs +func extractVtxoOutpoint(pset *psetv2.Pset) (*domain.VtxoKey, error) { + if len(pset.Outputs) != 2 { + return nil, fmt.Errorf("invalid leaf pset, expect 2 outputs, got %d", len(pset.Outputs)) + } + + utx, err := pset.UnsignedTx() + if err != nil { + return nil, err + } + + return &domain.VtxoKey{ + Txid: utx.TxHash().String(), + VOut: 0, + }, nil +} diff --git a/asp/internal/core/domain/payment.go b/asp/internal/core/domain/payment.go index 9e76167..b8c3207 100644 --- a/asp/internal/core/domain/payment.go +++ b/asp/internal/core/domain/payment.go @@ -125,4 +125,5 @@ type Vtxo struct { PoolTx string Spent bool Redeemed bool + Swept bool } diff --git a/asp/internal/core/domain/round.go b/asp/internal/core/domain/round.go index 35577ea..ca64634 100644 --- a/asp/internal/core/domain/round.go +++ b/asp/internal/core/domain/round.go @@ -46,6 +46,7 @@ type Round struct { Connectors []string DustAmount uint64 Version uint + Swept bool // true if all the vtxos are vtxo.Swept changes []RoundEvent } @@ -239,6 +240,10 @@ func (r *Round) TotalOutputAmount() uint64 { return tot } +func (r *Round) Sweep() { + r.Swept = true +} + func (r *Round) raise(event RoundEvent) { if r.changes == nil { r.changes = make([]RoundEvent, 0) diff --git a/asp/internal/core/domain/round_repo.go b/asp/internal/core/domain/round_repo.go index b26e1f6..4f78193 100644 --- a/asp/internal/core/domain/round_repo.go +++ b/asp/internal/core/domain/round_repo.go @@ -1,6 +1,8 @@ package domain -import "context" +import ( + "context" +) type RoundEventRepository interface { Save(ctx context.Context, id string, events ...RoundEvent) error @@ -12,6 +14,7 @@ type RoundRepository interface { GetCurrentRound(ctx context.Context) (*Round, error) GetRoundWithId(ctx context.Context, id string) (*Round, error) GetRoundWithTxid(ctx context.Context, txid string) (*Round, error) + GetSweepableRounds(ctx context.Context) ([]Round, error) } type VtxoRepository interface { @@ -19,5 +22,7 @@ type VtxoRepository interface { SpendVtxos(ctx context.Context, vtxos []VtxoKey) error RedeemVtxos(ctx context.Context, vtxos []VtxoKey) ([]Vtxo, error) GetVtxos(ctx context.Context, vtxos []VtxoKey) ([]Vtxo, error) + GetVtxosForRound(ctx context.Context, txid string) ([]Vtxo, error) + SweepVtxos(ctx context.Context, vtxos []VtxoKey) error GetSpendableVtxos(ctx context.Context, pubkey string) ([]Vtxo, error) } diff --git a/asp/internal/core/ports/scheduler.go b/asp/internal/core/ports/scheduler.go index 09c1248..c1333fd 100644 --- a/asp/internal/core/ports/scheduler.go +++ b/asp/internal/core/ports/scheduler.go @@ -5,4 +5,5 @@ type SchedulerService interface { Stop() ScheduleTask(interval int64, immediate bool, task func()) error + ScheduleTaskOnce(delay int64, task func()) error } diff --git a/asp/internal/core/ports/tx_builder.go b/asp/internal/core/ports/tx_builder.go index cb4ff53..056c62b 100644 --- a/asp/internal/core/ports/tx_builder.go +++ b/asp/internal/core/ports/tx_builder.go @@ -4,8 +4,15 @@ import ( "github.com/ark-network/ark/common/tree" "github.com/ark-network/ark/internal/core/domain" "github.com/decred/dcrd/dcrec/secp256k1/v4" + "github.com/vulpemventures/go-elements/psetv2" ) +type SweepInput struct { + InputArgs psetv2.InputArgs + SweepLeaf psetv2.TapLeafScript + Amount uint64 +} + type TxBuilder interface { BuildPoolTx( aspPubkey *secp256k1.PublicKey, payments []domain.Payment, minRelayFee uint64, @@ -13,5 +20,10 @@ type TxBuilder interface { BuildForfeitTxs( aspPubkey *secp256k1.PublicKey, poolTx string, payments []domain.Payment, ) (connectors []string, forfeitTxs []string, err error) + BuildSweepTx( + wallet WalletService, + inputs []SweepInput, + ) (signedSweepTx string, err error) + GetLeafSweepClosure(node tree.Node, userPubKey *secp256k1.PublicKey) (*psetv2.TapLeafScript, int64, error) GetVtxoScript(userPubkey, aspPubkey *secp256k1.PublicKey) ([]byte, error) } diff --git a/asp/internal/core/ports/wallet.go b/asp/internal/core/ports/wallet.go index 3f10f7f..baa15b2 100644 --- a/asp/internal/core/ports/wallet.go +++ b/asp/internal/core/ports/wallet.go @@ -16,6 +16,8 @@ type WalletService interface { ) (string, error) SelectUtxos(ctx context.Context, asset string, amount uint64) ([]TxInput, uint64, error) BroadcastTransaction(ctx context.Context, txHex string) (string, error) + SignPsetWithKey(ctx context.Context, pset string, inputIndexes []int) (string, error) // inputIndexes == nil means sign all inputs + IsTransactionPublished(ctx context.Context, txid string) (isPublished bool, blocktime int64, err error) EstimateFees(ctx context.Context, pset string) (uint64, error) Close() } diff --git a/asp/internal/infrastructure/db/badger/round_repo.go b/asp/internal/infrastructure/db/badger/round_repo.go index fec9d51..f0f0c8b 100644 --- a/asp/internal/infrastructure/db/badger/round_repo.go +++ b/asp/internal/infrastructure/db/badger/round_repo.go @@ -95,6 +95,18 @@ func (r *roundRepository) GetRoundWithTxid( return round, nil } +func (r *roundRepository) GetSweepableRounds( + ctx context.Context, +) ([]domain.Round, error) { + query := badgerhold.Where("Stage.Code").Eq(domain.FinalizationStage). + And("Stage.Ended").Eq(true).And("Swept").Eq(false) + rounds, err := r.findRound(ctx, query) + if err != nil { + return nil, err + } + return rounds, nil +} + func (r *roundRepository) Close() { r.store.Close() } diff --git a/asp/internal/infrastructure/db/badger/vtxo_repo.go b/asp/internal/infrastructure/db/badger/vtxo_repo.go index 3bc63e8..2201889 100644 --- a/asp/internal/infrastructure/db/badger/vtxo_repo.go +++ b/asp/internal/infrastructure/db/badger/vtxo_repo.go @@ -93,16 +93,34 @@ func (r *vtxoRepository) GetVtxos( return vtxos, nil } +func (r *vtxoRepository) GetVtxosForRound( + ctx context.Context, txid string, +) ([]domain.Vtxo, error) { + query := badgerhold.Where("Txid").Eq(txid) + return r.findVtxos(ctx, query) +} + func (r *vtxoRepository) GetSpendableVtxos( ctx context.Context, pubkey string, ) ([]domain.Vtxo, error) { - query := badgerhold.Where("Spent").Eq(false).And("Redeemed").Eq(false) + query := badgerhold.Where("Spent").Eq(false).And("Redeemed").Eq(false).And("Swept").Eq(false) if len(pubkey) > 0 { query = query.And("Pubkey").Eq(pubkey) } return r.findVtxos(ctx, query) } +func (r *vtxoRepository) SweepVtxos( + ctx context.Context, vtxoKeys []domain.VtxoKey, +) error { + for _, vtxoKey := range vtxoKeys { + if err := r.sweepVtxo(ctx, vtxoKey); err != nil { + return err + } + } + return nil +} + func (r *vtxoRepository) Close() { r.store.Close() } @@ -203,3 +221,25 @@ func (r *vtxoRepository) findVtxos(ctx context.Context, query *badgerhold.Query) return vtxos, err } + +func (r *vtxoRepository) sweepVtxo(ctx context.Context, vtxoKey domain.VtxoKey) error { + vtxo, err := r.getVtxo(ctx, vtxoKey) + if err != nil { + return err + } + if vtxo.Swept { + return nil + } + + vtxo.Swept = true + if ctx.Value("tx") != nil { + tx := ctx.Value("tx").(*badger.Txn) + err = r.store.TxUpdate(tx, vtxoKey.Hash(), *vtxo) + } else { + err = r.store.Update(vtxoKey.Hash(), *vtxo) + } + if err != nil { + return err + } + return nil +} diff --git a/asp/internal/infrastructure/ocean-wallet/service.go b/asp/internal/infrastructure/ocean-wallet/service.go index 96557a8..6ddb208 100644 --- a/asp/internal/infrastructure/ocean-wallet/service.go +++ b/asp/internal/infrastructure/ocean-wallet/service.go @@ -60,6 +60,7 @@ func NewService(addr string) (ports.WalletService, error) { if err != nil { return nil, err } + found := false for _, account := range info.GetAccounts() { if account.GetLabel() == accountLabel { diff --git a/asp/internal/infrastructure/ocean-wallet/transaction.go b/asp/internal/infrastructure/ocean-wallet/transaction.go index a0a8f08..77ac531 100644 --- a/asp/internal/infrastructure/ocean-wallet/transaction.go +++ b/asp/internal/infrastructure/ocean-wallet/transaction.go @@ -2,12 +2,17 @@ package oceanwallet import ( "context" + "encoding/binary" "encoding/hex" "fmt" + "strings" + "time" pb "github.com/ark-network/ark/api-spec/protobuf/gen/ocean/v1" + "github.com/ark-network/ark/common/tree" "github.com/ark-network/ark/internal/core/ports" "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/txscript" "github.com/vulpemventures/go-elements/elementsutil" "github.com/vulpemventures/go-elements/psetv2" ) @@ -76,6 +81,24 @@ func (s *service) SelectUtxos(ctx context.Context, asset string, amount uint64) return inputs, res.GetChange(), nil } +func (s *service) GetTransaction( + ctx context.Context, txid string, +) (string, int64, error) { + res, err := s.txClient.GetTransaction(ctx, &pb.GetTransactionRequest{ + Txid: txid, + }) + if err != nil { + return "", 0, err + } + + if res.GetBlockDetails().GetTimestamp() > 0 { + return res.GetTxHex(), res.BlockDetails.GetTimestamp(), nil + } + + // if not confirmed, we return now + 30 secs to estimate the next blocktime + return res.GetTxHex(), time.Now().Unix() + 30, nil +} + func (s *service) BroadcastTransaction( ctx context.Context, txHex string, ) (string, error) { @@ -85,11 +108,104 @@ func (s *service) BroadcastTransaction( }, ) if err != nil { + if strings.Contains(err.Error(), "non-BIP68-final") { + return "", fmt.Errorf("non-BIP68-final") + } + return "", err } return res.GetTxid(), nil } +func (s *service) IsTransactionPublished( + ctx context.Context, txid string, +) (bool, int64, error) { + _, blocktime, err := s.GetTransaction(ctx, txid) + if err != nil { + if strings.Contains(strings.ToLower(err.Error()), "missing transaction") { + return false, 0, nil + } + return false, 0, err + } + + return true, blocktime, nil +} + +func (s *service) SignPsetWithKey(ctx context.Context, b64 string, indexes []int) (string, error) { + pset, err := psetv2.NewPsetFromBase64(b64) + if err != nil { + return "", err + } + + if indexes == nil { + for i := 0; i < len(pset.Inputs); i++ { + indexes = append(indexes, i) + } + } + + key, masterKey, err := s.getPubkey(ctx) + if err != nil { + return "", err + } + + fingerprint := binary.LittleEndian.Uint32(masterKey.FingerPrint) + extendedKey, err := masterKey.Serialize() + if err != nil { + return "", err + } + + pset.Global.Xpubs = []psetv2.Xpub{{ + ExtendedKey: extendedKey[:len(extendedKey)-4], + MasterFingerprint: fingerprint, + DerivationPath: derivationPath, + }} + + updater, err := psetv2.NewUpdater(pset) + if err != nil { + return "", err + } + + bip32derivation := psetv2.DerivationPathWithPubKey{ + PubKey: key.SerializeCompressed(), + MasterKeyFingerprint: fingerprint, + Bip32Path: derivationPath, + } + + for _, i := range indexes { + if len(pset.Inputs[i].TapLeafScript) == 0 { + return "", fmt.Errorf("no tap leaf script found for input %d", i) + } + + leafHash := pset.Inputs[i].TapLeafScript[0].TapHash() + + if err := updater.AddInTapBip32Derivation(i, psetv2.TapDerivationPathWithPubKey{ + DerivationPathWithPubKey: bip32derivation, + LeafHashes: [][]byte{leafHash[:]}, + }); err != nil { + return "", err + } + + if err := updater.AddInSighashType(i, txscript.SigHashDefault); err != nil { + return "", err + } + } + + unsignedPset, err := pset.ToBase64() + if err != nil { + return "", err + } + + signedPset, err := s.txClient.SignPsetWithSchnorrKey(ctx, &pb.SignPsetWithSchnorrKeyRequest{ + Tx: unsignedPset, + SighashType: uint32(txscript.SigHashDefault), + }) + if err != nil { + return "", err + } + + return signedPset.GetSignedTx(), nil +} + func (s *service) EstimateFees( ctx context.Context, pset string, ) (uint64, error) { @@ -102,15 +218,30 @@ func (s *service) EstimateFees( outputs := make([]*pb.Output, 0, len(tx.Outputs)) for _, in := range tx.Inputs { - if in.WitnessUtxo == nil { - return 0, fmt.Errorf("missing witness utxo, cannot estimate fees") + pbInput := &pb.Input{ + Txid: chainhash.Hash(in.PreviousTxid).String(), + Index: in.PreviousTxIndex, } - inputs = append(inputs, &pb.Input{ - Txid: chainhash.Hash(in.PreviousTxid).String(), - Index: in.PreviousTxIndex, - Script: hex.EncodeToString(in.WitnessUtxo.Script), - }) + if len(in.TapLeafScript) == 1 { + isSweep, _, _, err := tree.DecodeSweepScript(in.TapLeafScript[0].Script) + if err != nil { + return 0, err + } + + if isSweep { + pbInput.WitnessSize = 64 + pbInput.ScriptsigSize = 0 + } + } else { + if in.WitnessUtxo == nil { + return 0, fmt.Errorf("missing witness utxo, cannot estimate fees") + } + + pbInput.Script = hex.EncodeToString(in.WitnessUtxo.Script) + } + + inputs = append(inputs, pbInput) } for _, out := range tx.Outputs { diff --git a/asp/internal/infrastructure/ocean-wallet/wallet.go b/asp/internal/infrastructure/ocean-wallet/wallet.go index 1513875..d9f17c6 100644 --- a/asp/internal/infrastructure/ocean-wallet/wallet.go +++ b/asp/internal/infrastructure/ocean-wallet/wallet.go @@ -8,30 +8,16 @@ import ( "github.com/ark-network/ark/internal/core/ports" "github.com/btcsuite/btcd/btcutil/hdkeychain" "github.com/decred/dcrd/dcrec/secp256k1/v4" + "github.com/vulpemventures/go-bip32" ) const accountLabel = "ark" +var derivationPath = []uint32{0, 0} + func (s *service) GetPubkey(ctx context.Context) (*secp256k1.PublicKey, error) { - res, err := s.walletClient.GetInfo(ctx, &pb.GetInfoRequest{}) - if err != nil { - return nil, err - } - if len(res.GetAccounts()) <= 0 { - return nil, fmt.Errorf("wallet is locked") - } - xpub := res.GetAccounts()[0].GetXpubs()[0] - node, err := hdkeychain.NewKeyFromString(xpub) - if err != nil { - return nil, err - } - for i := 0; i < 2; i++ { - node, err = node.Derive(0) - if err != nil { - return nil, err - } - } - return node.ECPubKey() + key, _, err := s.getPubkey(ctx) + return key, err } func (s *service) Status( @@ -57,3 +43,53 @@ func (w walletStatus) IsUnlocked() bool { func (w walletStatus) IsSynced() bool { return w.StatusResponse.GetSynced() } + +func (s *service) findAccount(ctx context.Context, label string) (*pb.AccountInfo, error) { + res, err := s.walletClient.GetInfo(ctx, &pb.GetInfoRequest{}) + if err != nil { + return nil, err + } + if len(res.GetAccounts()) <= 0 { + return nil, fmt.Errorf("wallet is locked") + } + + for _, account := range res.GetAccounts() { + if account.GetLabel() == label { + return account, nil + } + } + + return nil, fmt.Errorf("account not found") +} + +func (s *service) getPubkey(ctx context.Context) (*secp256k1.PublicKey, *bip32.Key, error) { + account, err := s.findAccount(ctx, accountLabel) + if err != nil { + return nil, nil, err + } + + xpub := account.GetXpubs()[0] + + node, err := hdkeychain.NewKeyFromString(xpub) + if err != nil { + return nil, nil, err + } + for _, i := range derivationPath { + node, err = node.Derive(i) + if err != nil { + return nil, nil, err + } + } + key, err := node.ECPubKey() + + if err != nil { + return nil, nil, err + } + + masterKey, err := bip32.B58Deserialize(xpub) + if err != nil { + return nil, nil, err + } + + return key, masterKey, nil +} diff --git a/asp/internal/infrastructure/scheduler/gocron/service.go b/asp/internal/infrastructure/scheduler/gocron/service.go index 391ef8e..f2e0c5a 100644 --- a/asp/internal/infrastructure/scheduler/gocron/service.go +++ b/asp/internal/infrastructure/scheduler/gocron/service.go @@ -1,6 +1,7 @@ package scheduler import ( + "fmt" "time" "github.com/ark-network/ark/internal/core/ports" @@ -32,3 +33,13 @@ func (s *service) ScheduleTask(interval int64, immediate bool, task func()) erro _, err := s.scheduler.Every(int(interval)).Seconds().WaitForSchedule().Do(task) return err } + +func (s *service) ScheduleTaskOnce(at int64, task func()) error { + delay := at - time.Now().Unix() + if delay < 0 { + return fmt.Errorf("cannot schedule task in the past") + } + + _, err := s.scheduler.Every(int(delay)).Seconds().WaitForSchedule().LimitRunsTo(1).Do(task) + return err +} diff --git a/asp/internal/infrastructure/tx-builder/covenant/builder.go b/asp/internal/infrastructure/tx-builder/covenant/builder.go index b1a8dad..0e2c90a 100644 --- a/asp/internal/infrastructure/tx-builder/covenant/builder.go +++ b/asp/internal/infrastructure/tx-builder/covenant/builder.go @@ -21,14 +21,15 @@ const ( ) type txBuilder struct { - wallet ports.WalletService - net *network.Network + wallet ports.WalletService + net *network.Network + roundLifetime int64 // in seconds } func NewTxBuilder( - wallet ports.WalletService, net network.Network, + wallet ports.WalletService, net network.Network, roundLifetime int64, ) ports.TxBuilder { - return &txBuilder{wallet, &net} + return &txBuilder{wallet, &net, roundLifetime} } func (b *txBuilder) GetVtxoScript(userPubkey, aspPubkey *secp256k1.PublicKey) ([]byte, error) { @@ -39,6 +40,44 @@ func (b *txBuilder) GetVtxoScript(userPubkey, aspPubkey *secp256k1.PublicKey) ([ return outputScript, nil } +func (b *txBuilder) BuildSweepTx(wallet ports.WalletService, inputs []ports.SweepInput) (signedSweepTx string, err error) { + sweepPset, err := sweepTransaction( + wallet, + inputs, + b.net.AssetID, + ) + if err != nil { + return "", err + } + + sweepPsetBase64, err := sweepPset.ToBase64() + if err != nil { + return "", err + } + + ctx := context.Background() + signedSweepPsetB64, err := wallet.SignPsetWithKey(ctx, sweepPsetBase64, nil) + if err != nil { + return "", err + } + + signedPset, err := psetv2.NewPsetFromBase64(signedSweepPsetB64) + if err != nil { + return "", err + } + + if err := psetv2.FinalizeAll(signedPset); err != nil { + return "", err + } + + extractedTx, err := psetv2.Extract(signedPset) + if err != nil { + return "", err + } + + return extractedTx.ToHex() +} + func (b *txBuilder) BuildForfeitTxs( aspPubkey *secp256k1.PublicKey, poolTx string, payments []domain.Payment, ) (connectors []string, forfeitTxs []string, err error) { @@ -74,7 +113,7 @@ func (b *txBuilder) BuildPoolTx( // This is safe as the memory allocated for `craftCongestionTree` is freed // only after `BuildPoolTx` returns. treeFactoryFn, sharedOutputScript, sharedOutputAmount, err := craftCongestionTree( - b.net.AssetID, aspPubkey, payments, minRelayFee, + b.net.AssetID, aspPubkey, payments, minRelayFee, b.roundLifetime, ) if err != nil { return @@ -109,6 +148,45 @@ func (b *txBuilder) BuildPoolTx( return } +func (b *txBuilder) GetLeafSweepClosure( + node tree.Node, userPubKey *secp256k1.PublicKey, +) (*psetv2.TapLeafScript, int64, error) { + if !node.Leaf { + return nil, 0, fmt.Errorf("node is not a leaf") + } + + pset, err := psetv2.NewPsetFromBase64(node.Tx) + if err != nil { + return nil, 0, err + } + + input := pset.Inputs[0] + + sweepLeaf, lifetime, err := extractSweepLeaf(input) + if err != nil { + return nil, 0, err + } + + // craft the vtxo taproot tree + vtxoScript, err := tree.VtxoScript(userPubKey) + if err != nil { + return nil, 0, err + } + + vtxoTaprootTree := taproot.AssembleTaprootScriptTree( + *vtxoScript, + sweepLeaf.TapElementsLeaf, + ) + + proofIndex := vtxoTaprootTree.LeafProofIndex[sweepLeaf.TapHash()] + proof := vtxoTaprootTree.LeafMerkleProofs[proofIndex] + + return &psetv2.TapLeafScript{ + TapElementsLeaf: proof.TapElementsLeaf, + ControlBlock: proof.ToControlBlock(sweepLeaf.ControlBlock.InternalKey), + }, lifetime, nil +} + func (b *txBuilder) getLeafScriptAndTree( userPubkey, aspPubkey *secp256k1.PublicKey, ) ([]byte, *taproot.IndexedElementsTapScriptTree, error) { @@ -117,7 +195,7 @@ func (b *txBuilder) getLeafScriptAndTree( return nil, nil, err } - sweepClosure, err := tree.SweepScript(aspPubkey, expirationTime) + sweepClosure, err := tree.SweepScript(aspPubkey, uint(b.roundLifetime)) if err != nil { return nil, nil, err } @@ -382,3 +460,24 @@ func (b *txBuilder) createForfeitTxs( } return forfeitTxs, nil } + +// given a congestion tree input, searches and returns the sweep leaf and its lifetime in seconds +func extractSweepLeaf(input psetv2.Input) (sweepLeaf *psetv2.TapLeafScript, lifetime int64, err error) { + for _, leaf := range input.TapLeafScript { + isSweep, _, seconds, err := tree.DecodeSweepScript(leaf.Script) + if err != nil { + return nil, 0, err + } + if isSweep { + lifetime = int64(seconds) + sweepLeaf = &leaf + break + } + } + + if sweepLeaf == nil { + return nil, 0, fmt.Errorf("sweep leaf not found") + } + + return sweepLeaf, lifetime, nil +} diff --git a/asp/internal/infrastructure/tx-builder/covenant/builder_test.go b/asp/internal/infrastructure/tx-builder/covenant/builder_test.go index 89cd625..5231e2e 100644 --- a/asp/internal/infrastructure/tx-builder/covenant/builder_test.go +++ b/asp/internal/infrastructure/tx-builder/covenant/builder_test.go @@ -44,7 +44,7 @@ func TestMain(m *testing.M) { } func TestBuildPoolTx(t *testing.T) { - builder := txbuilder.NewTxBuilder(wallet, network.Liquid) + builder := txbuilder.NewTxBuilder(wallet, network.Liquid, roundLifetime) fixtures, err := parsePoolTxFixtures() require.NoError(t, err) @@ -79,7 +79,7 @@ func TestBuildPoolTx(t *testing.T) { } func TestBuildForfeitTxs(t *testing.T) { - builder := txbuilder.NewTxBuilder(wallet, network.Liquid) + builder := txbuilder.NewTxBuilder(wallet, network.Liquid, 1209344) fixtures, err := parseForfeitTxsFixtures() require.NoError(t, err) diff --git a/asp/internal/infrastructure/tx-builder/covenant/mocks_test.go b/asp/internal/infrastructure/tx-builder/covenant/mocks_test.go index d8d4fdf..574d251 100644 --- a/asp/internal/infrastructure/tx-builder/covenant/mocks_test.go +++ b/asp/internal/infrastructure/tx-builder/covenant/mocks_test.go @@ -97,6 +97,32 @@ func (m *mockedWallet) EstimateFees(ctx context.Context, pset string) (uint64, e return res, args.Error(1) } +func (m *mockedWallet) IsTransactionPublished(ctx context.Context, txid string) (bool, int64, error) { + args := m.Called(ctx, txid) + + var res bool + if a := args.Get(0); a != nil { + res = a.(bool) + } + + var blocktime int64 + if b := args.Get(1); b != nil { + blocktime = b.(int64) + } + + return res, blocktime, args.Error(2) +} + +func (m *mockedWallet) SignPsetWithKey(ctx context.Context, pset string, inputIndexes []int) (string, error) { + args := m.Called(ctx, pset, inputIndexes) + + var res string + if a := args.Get(0); a != nil { + res = a.(string) + } + return res, args.Error(1) +} + func (m *mockedWallet) WatchScripts( ctx context.Context, scripts []string, ) error { diff --git a/asp/internal/infrastructure/tx-builder/covenant/sweep.go b/asp/internal/infrastructure/tx-builder/covenant/sweep.go new file mode 100644 index 0000000..9a9fed6 --- /dev/null +++ b/asp/internal/infrastructure/tx-builder/covenant/sweep.go @@ -0,0 +1,135 @@ +package txbuilder + +import ( + "context" + "fmt" + + "github.com/ark-network/ark/common" + "github.com/ark-network/ark/common/tree" + "github.com/ark-network/ark/internal/core/ports" + "github.com/vulpemventures/go-elements/address" + "github.com/vulpemventures/go-elements/elementsutil" + "github.com/vulpemventures/go-elements/psetv2" + "github.com/vulpemventures/go-elements/taproot" + "github.com/vulpemventures/go-elements/transaction" +) + +func sweepTransaction( + wallet ports.WalletService, + sweepInputs []ports.SweepInput, + lbtc string, +) (*psetv2.Pset, error) { + sweepPset, err := psetv2.New(nil, nil, nil) + if err != nil { + return nil, err + } + + updater, err := psetv2.NewUpdater(sweepPset) + if err != nil { + return nil, err + } + + amount := uint64(0) + + for i, input := range sweepInputs { + leaf := input.SweepLeaf + isSweep, _, lifetime, err := tree.DecodeSweepScript(leaf.Script) + if err != nil { + return nil, err + } + + if isSweep { + amount += input.Amount + + if err := updater.AddInputs([]psetv2.InputArgs{input.InputArgs}); err != nil { + return nil, err + } + + if err := updater.AddInTapLeafScript(i, leaf); err != nil { + return nil, err + } + + assetHash, err := elementsutil.AssetHashToBytes(lbtc) + if err != nil { + return nil, err + } + + value, err := elementsutil.ValueToBytes(input.Amount) + if err != nil { + return nil, err + } + + root := leaf.ControlBlock.RootHash(leaf.Script) + taprootKey := taproot.ComputeTaprootOutputKey(leaf.ControlBlock.InternalKey, root) + script, err := taprootOutputScript(taprootKey) + if err != nil { + return nil, err + } + + witnessUtxo := transaction.NewTxOutput(assetHash, value, script) + + if err := updater.AddInWitnessUtxo(i, witnessUtxo); err != nil { + return nil, err + } + + sequence, err := common.BIP68EncodeAsNumber(lifetime) + if err != nil { + return nil, err + } + + updater.Pset.Inputs[i].Sequence = sequence + continue + } + + return nil, fmt.Errorf("invalid sweep script") + } + + ctx := context.Background() + + sweepAddress, err := wallet.DeriveAddresses(ctx, 1) + if err != nil { + return nil, err + } + + script, err := address.ToOutputScript(sweepAddress[0]) + if err != nil { + return nil, err + } + + if err := updater.AddOutputs([]psetv2.OutputArgs{ + { + Asset: lbtc, + Amount: amount, + Script: script, + }, + }); err != nil { + return nil, err + } + + b64, err := sweepPset.ToBase64() + if err != nil { + return nil, err + } + + fees, err := wallet.EstimateFees(ctx, b64) + if err != nil { + return nil, err + } + + if amount < fees { + return nil, fmt.Errorf("insufficient funds (%d) to cover fees (%d) for sweep transaction", amount, fees) + } + + updater.Pset.Outputs[0].Value = amount - fees + + if err := updater.AddOutputs([]psetv2.OutputArgs{ + { + Asset: lbtc, + Amount: fees, + }, + }); err != nil { + return nil, err + } + + return sweepPset, nil +} diff --git a/asp/internal/infrastructure/tx-builder/covenant/tree.go b/asp/internal/infrastructure/tx-builder/covenant/tree.go index d049803..c3755c4 100644 --- a/asp/internal/infrastructure/tx-builder/covenant/tree.go +++ b/asp/internal/infrastructure/tx-builder/covenant/tree.go @@ -12,19 +12,16 @@ import ( "github.com/vulpemventures/go-elements/taproot" ) -const ( - expirationTime = 60 * 60 * 24 * 14 // 14 days in seconds -) - type treeFactory func(outpoint psetv2.InputArgs) (tree.CongestionTree, error) type node struct { - sweepKey *secp256k1.PublicKey - receivers []domain.Receiver - left *node - right *node - asset string - feeSats uint64 + sweepKey *secp256k1.PublicKey + receivers []domain.Receiver + left *node + right *node + asset string + feeSats uint64 + roundLifetime int64 _inputTaprootKey *secp256k1.PublicKey _inputTaprootTree *taproot.IndexedElementsTapScriptTree @@ -133,7 +130,7 @@ func (n *node) getWitnessData() ( return n._inputTaprootKey, n._inputTaprootTree, nil } - sweepClosure, err := tree.SweepScript(n.sweepKey, expirationTime) + sweepClosure, err := tree.SweepScript(n.sweepKey, uint(n.roundLifetime)) if err != nil { return nil, nil, err } @@ -203,7 +200,7 @@ func (n *node) getVtxoWitnessData() ( return nil, nil, fmt.Errorf("cannot call vtxoWitness on a non-leaf node") } - sweepClosure, err := tree.SweepScript(n.sweepKey, expirationTime) + sweepClosure, err := tree.SweepScript(n.sweepKey, uint(n.roundLifetime)) if err != nil { return nil, nil, err } @@ -360,14 +357,14 @@ func (n *node) createFinalCongestionTree() treeFactory { func craftCongestionTree( asset string, aspPublicKey *secp256k1.PublicKey, - payments []domain.Payment, feeSatsPerNode uint64, + payments []domain.Payment, feeSatsPerNode uint64, roundLifetime int64, ) ( buildCongestionTree treeFactory, sharedOutputScript []byte, sharedOutputAmount uint64, err error, ) { receivers := getOffchainReceivers(payments) root, err := createPartialCongestionTree( - receivers, aspPublicKey, asset, feeSatsPerNode, + receivers, aspPublicKey, asset, feeSatsPerNode, roundLifetime, ) if err != nil { return @@ -393,6 +390,7 @@ func createPartialCongestionTree( aspPublicKey *secp256k1.PublicKey, asset string, feeSatsPerNode uint64, + roundLifetime int64, ) (root *node, err error) { if len(receivers) == 0 { return nil, fmt.Errorf("no receivers provided") @@ -401,10 +399,11 @@ func createPartialCongestionTree( nodes := make([]*node, 0, len(receivers)) for _, r := range receivers { leafNode := &node{ - sweepKey: aspPublicKey, - receivers: []domain.Receiver{r}, - asset: asset, - feeSats: feeSatsPerNode, + sweepKey: aspPublicKey, + receivers: []domain.Receiver{r}, + asset: asset, + feeSats: feeSatsPerNode, + roundLifetime: roundLifetime, } nodes = append(nodes, leafNode) } @@ -435,12 +434,13 @@ func createUpperLevel(nodes []*node) ([]*node, error) { left := nodes[i] right := nodes[i+1] branchNode := &node{ - sweepKey: left.sweepKey, - receivers: append(left.receivers, right.receivers...), - left: left, - right: right, - asset: left.asset, - feeSats: left.feeSats, + sweepKey: left.sweepKey, + receivers: append(left.receivers, right.receivers...), + left: left, + right: right, + asset: left.asset, + feeSats: left.feeSats, + roundLifetime: left.roundLifetime, } pairs = append(pairs, branchNode) } diff --git a/asp/internal/infrastructure/tx-builder/dummy/builder.go b/asp/internal/infrastructure/tx-builder/dummy/builder.go index 6d12597..59a98b1 100644 --- a/asp/internal/infrastructure/tx-builder/dummy/builder.go +++ b/asp/internal/infrastructure/tx-builder/dummy/builder.go @@ -16,6 +16,7 @@ import ( const ( connectorAmount = 450 + sevenDays = 7 * 24 * 60 * 60 ) type txBuilder struct { @@ -29,6 +30,11 @@ func NewTxBuilder( return &txBuilder{wallet, net} } +// BuildSweepTx implements ports.TxBuilder. +func (*txBuilder) BuildSweepTx(wallet ports.WalletService, inputs []ports.SweepInput) (signedSweepTx string, err error) { + panic("unimplemented") +} + // BuildForfeitTxs implements ports.TxBuilder. func (b *txBuilder) BuildForfeitTxs( aspPubkey *secp256k1.PublicKey, poolTx string, payments []domain.Payment, @@ -185,6 +191,12 @@ func (b *txBuilder) GetVtxoScript(userPubkey, _ *secp256k1.PublicKey) ([]byte, e return address.ToOutputScript(addr) } +func (b *txBuilder) GetLeafSweepClosure( + node tree.Node, userPubKey *secp256k1.PublicKey, +) (*psetv2.TapLeafScript, int64, error) { + panic("unimplemented") +} + func connectorsToInputArgs(connectors []string) ([]psetv2.InputArgs, error) { inputs := make([]psetv2.InputArgs, 0, len(connectors)+1) for i, psetb64 := range connectors { diff --git a/asp/internal/infrastructure/tx-builder/dummy/builder_test.go b/asp/internal/infrastructure/tx-builder/dummy/builder_test.go index 6f03960..f185610 100644 --- a/asp/internal/infrastructure/tx-builder/dummy/builder_test.go +++ b/asp/internal/infrastructure/tx-builder/dummy/builder_test.go @@ -109,6 +109,14 @@ func (*mockedWalletService) EstimateFees(ctx context.Context, pset string) (uint return 100, nil } +func (*mockedWalletService) SignPsetWithKey(ctx context.Context, pset string, inputIndex []int) (string, error) { + panic("unimplemented") +} + +func (*mockedWalletService) IsTransactionPublished(ctx context.Context, txid string) (bool, int64, error) { + panic("unimplemented") +} + func TestBuildCongestionTree(t *testing.T) { builder := txbuilder.NewTxBuilder(&mockedWalletService{}, network.Liquid) diff --git a/asp/internal/interface/grpc/service.go b/asp/internal/interface/grpc/service.go index 3f29952..0632ea1 100644 --- a/asp/internal/interface/grpc/service.go +++ b/asp/internal/interface/grpc/service.go @@ -52,6 +52,7 @@ func (s *service) Start() error { return fmt.Errorf("failed to start app service: %s", err) } log.Info("started app service") + return nil } diff --git a/common/bip68.go b/common/bip68.go index c5e496e..333dc45 100644 --- a/common/bip68.go +++ b/common/bip68.go @@ -18,17 +18,25 @@ func closerToModulo512(x uint) uint { return x - (x % 512) } -// BIP68Encode returns the encoded sequence locktime for the given number of seconds. -func BIP68Encode(seconds uint) ([]byte, error) { +func BIP68EncodeAsNumber(seconds uint) (uint32, error) { seconds = closerToModulo512(seconds) if seconds > SECONDS_MAX { - return nil, fmt.Errorf("seconds too large, max is %d", SECONDS_MAX) + return 0, fmt.Errorf("seconds too large, max is %d", SECONDS_MAX) } if seconds%SECONDS_MOD != 0 { - return nil, fmt.Errorf("seconds must be a multiple of %d", SECONDS_MOD) + return 0, fmt.Errorf("seconds must be a multiple of %d", SECONDS_MOD) } asNumber := SEQUENCE_LOCKTIME_TYPE_FLAG | (seconds >> SEQUENCE_LOCKTIME_GRANULARITY) + return uint32(asNumber), nil +} + +// BIP68Encode returns the encoded sequence locktime for the given number of seconds. +func BIP68Encode(seconds uint) ([]byte, error) { + asNumber, err := BIP68EncodeAsNumber(seconds) + if err != nil { + return nil, err + } hexString := fmt.Sprintf("%x", asNumber) reversed, err := hex.DecodeString(hexString) if err != nil { diff --git a/common/tree/congestion_tree.go b/common/tree/congestion_tree.go index 3ca6438..545fff2 100644 --- a/common/tree/congestion_tree.go +++ b/common/tree/congestion_tree.go @@ -1,6 +1,11 @@ package tree -import "errors" +import ( + "errors" + + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/vulpemventures/go-elements/psetv2" +) // Node is a struct embedding the transaction and the parent txid of a congestion tree node type Node struct { @@ -19,6 +24,19 @@ var ( // the first level of the matrix is the root of the tree type CongestionTree [][]Node +// Root returns the root node of the congestion tree +func (c CongestionTree) Root() (Node, error) { + if len(c) <= 0 { + return Node{}, errors.New("empty congestion tree") + } + + if len(c[0]) <= 0 { + return Node{}, errors.New("empty congestion tree") + } + + return c[0][0], nil +} + // Leaves returns the leaves of the congestion tree (the vtxos txs) func (c CongestionTree) Leaves() []Node { leaves := c[len(c)-1] @@ -47,6 +65,7 @@ func (c CongestionTree) Children(nodeTxid string) []Node { return children } +// NumberOfNodes returns the total number of pset in the congestion tree func (c CongestionTree) NumberOfNodes() int { var count int for _, level := range c { @@ -55,6 +74,7 @@ func (c CongestionTree) NumberOfNodes() int { return count } +// Branch returns the branch of the given vtxo txid from root to leaf in the order of the congestion tree func (c CongestionTree) Branch(vtxoTxid string) ([]Node, error) { branch := make([]Node, 0) @@ -85,6 +105,37 @@ func (c CongestionTree) Branch(vtxoTxid string) ([]Node, error) { return branch, nil } +// FindLeaves returns all the leaves that are reachable from the given node output +func (c CongestionTree) FindLeaves(fromtxid string, vout uint32) ([]Node, error) { + allLeaves := c.Leaves() + foundLeaves := make([]Node, 0) + + for _, leaf := range allLeaves { + branch, err := c.Branch(leaf.Txid) + if err != nil { + return nil, err + } + + for _, node := range branch { + pset, err := psetv2.NewPsetFromBase64(node.Tx) + if err != nil { + return nil, err + } + + input := pset.Inputs[0] + txid := chainhash.Hash(input.PreviousTxid).String() + index := input.PreviousTxIndex + + if txid == fromtxid && index == vout { + foundLeaves = append(foundLeaves, leaf) + break + } + } + } + + return foundLeaves, nil +} + func (n Node) findParent(tree CongestionTree) (Node, error) { for _, level := range tree { for _, node := range level { diff --git a/common/tree/script.go b/common/tree/script.go index 775d4d2..cd6019b 100644 --- a/common/tree/script.go +++ b/common/tree/script.go @@ -106,12 +106,8 @@ func decodeWithOutputScript(script []byte, expectedIndex byte, isVerify bool) (v return false, nil, 0, err } - inspectOutputValueIndex := bytes.IndexByte(script, OP_INSPECTOUTPUTVALUE) - if inspectOutputValueIndex == -1 { - return false, nil, 0, nil - } - - if script[inspectOutputValueIndex-1] != expectedIndex { + // verify the index of INSPECTVALUE + if script[38] != expectedIndex { return false, nil, 0, nil } @@ -128,12 +124,12 @@ func decodeWithOutputScript(script []byte, expectedIndex byte, isVerify bool) (v } func decodeChecksigScript(script []byte) (valid bool, pubkey *secp256k1.PublicKey, err error) { - checksigIndex := bytes.Index(script, []byte{txscript.OP_CHECKSIG}) - if checksigIndex == -1 || checksigIndex == 0 { + data32Index := bytes.Index(script, []byte{txscript.OP_DATA_32}) + if data32Index == -1 { return false, nil, nil } - key := script[1:checksigIndex] + key := script[data32Index+1 : data32Index+33] if len(key) != 32 { return false, nil, nil } @@ -155,13 +151,13 @@ func decodeChecksigScript(script []byte) (valid bool, pubkey *secp256k1.PublicKe return true, pubkey, nil } -func decodeSweepScript(script []byte) (valid bool, aspPubKey *secp256k1.PublicKey, seconds uint, err error) { +func DecodeSweepScript(script []byte) (valid bool, aspPubKey *secp256k1.PublicKey, seconds uint, err error) { csvIndex := bytes.Index(script, []byte{txscript.OP_CHECKSEQUENCEVERIFY, txscript.OP_DROP}) if csvIndex == -1 || csvIndex == 0 { return false, nil, 0, nil } - sequence := script[:csvIndex] + sequence := script[1:csvIndex] seconds, err = common.BIP68Decode(sequence) if err != nil { @@ -174,6 +170,10 @@ func decodeSweepScript(script []byte) (valid bool, aspPubKey *secp256k1.PublicKe return false, nil, 0, err } + if !valid { + return false, nil, 0, nil + } + rebuilt, err := csvChecksigScript(aspPubKey, seconds) if err != nil { return false, nil, 0, err @@ -193,10 +193,10 @@ func checkSequenceVerifyScript(seconds uint) ([]byte, error) { return nil, err } - return append(sequence, []byte{ + return txscript.NewScriptBuilder().AddData(sequence).AddOps([]byte{ txscript.OP_CHECKSEQUENCEVERIFY, txscript.OP_DROP, - }...), nil + }).Script() } // checkSequenceVerifyScript + checksig diff --git a/common/tree/validation.go b/common/tree/validation.go index 1f23a91..c5201db 100644 --- a/common/tree/validation.go +++ b/common/tree/validation.go @@ -229,7 +229,7 @@ func validateNodeTransaction( return ErrInvalidTaprootScript } - isSweepLeaf, aspKey, seconds, err := decodeSweepScript(tapLeaf.Script) + isSweepLeaf, aspKey, seconds, err := DecodeSweepScript(tapLeaf.Script) if err != nil { return fmt.Errorf("invalid sweep script: %w", err) } diff --git a/noah/go.mod b/noah/go.mod index f99e6c1..22c2ac5 100644 --- a/noah/go.mod +++ b/noah/go.mod @@ -31,7 +31,7 @@ require ( github.com/golang/protobuf v1.5.3 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.18.1 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect - github.com/vulpemventures/go-elements v0.5.2 + github.com/vulpemventures/go-elements v0.5.3 github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect golang.org/x/net v0.19.0 // indirect golang.org/x/sys v0.15.0 // indirect diff --git a/noah/go.sum b/noah/go.sum index 0b28329..2161edd 100644 --- a/noah/go.sum +++ b/noah/go.sum @@ -91,8 +91,8 @@ github.com/urfave/cli/v2 v2.26.0 h1:3f3AMg3HpThFNT4I++TKOejZO8yU55t3JnnSr4S4QEI= github.com/urfave/cli/v2 v2.26.0/go.mod h1:8qnjx1vcq5s2/wpsqoZFndg2CE5tNFyrTvS6SinrnYQ= github.com/vulpemventures/fastsha256 v0.0.0-20160815193821-637e65642941 h1:CTcw80hz/Sw8hqlKX5ZYvBUF5gAHSHwdjXxRf/cjDcI= github.com/vulpemventures/fastsha256 v0.0.0-20160815193821-637e65642941/go.mod h1:GXBJykxW2kUcktGdsgyay7uwwWvkljASfljNcT0mbh8= -github.com/vulpemventures/go-elements v0.5.2 h1:vIDzVpRXG5PnlzHA8tCnr2Tn7raIV5cHy7bRyDrbuM4= -github.com/vulpemventures/go-elements v0.5.2/go.mod h1:aBGuWXHaiAIUIcwqCdtEh2iQ3kJjKwHU9ywvhlcRSeU= +github.com/vulpemventures/go-elements v0.5.3 h1:zaC/ynHFwCAzFSOMfzb6BcbD6FXASppSiGMycc95WVA= +github.com/vulpemventures/go-elements v0.5.3/go.mod h1:aBGuWXHaiAIUIcwqCdtEh2iQ3kJjKwHU9ywvhlcRSeU= github.com/vulpemventures/go-secp256k1-zkp v1.1.6 h1:BmsrmXRLUibwa75Qkk8yELjpzCzlAjYFGLiLiOdq7Xo= github.com/vulpemventures/go-secp256k1-zkp v1.1.6/go.mod h1:zo7CpgkuPgoe7fAV+inyxsI9IhGmcoFgyD8nqZaPSOM= github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 h1:bAn7/zixMGCfxrRTfdpNzjtPYqr8smhKouy9mxVdGPU= diff --git a/noah/redeem.go b/noah/redeem.go index cc21abf..321e43d 100644 --- a/noah/redeem.go +++ b/noah/redeem.go @@ -218,16 +218,6 @@ func unilateralRedeem(ctx *cli.Context, addr string) error { transactionsMap := make(map[string]struct{}, 0) transactions := make([]string, 0) - aspPublicKey, err := getServiceProviderPublicKey() - if err != nil { - return err - } - - sweepLeaf, err := tree.SweepScript(aspPublicKey, 1209344) - if err != nil { - return err - } - for _, vtxo := range vtxos { if _, ok := congestionTrees[vtxo.poolTxid]; !ok { round, err := client.GetRound(ctx.Context, &arkv1.GetRoundRequest{ @@ -246,7 +236,7 @@ func unilateralRedeem(ctx *cli.Context, addr string) error { congestionTrees[vtxo.poolTxid] = congestionTree } - redeemBranch, err := newRedeemBranch(ctx, congestionTrees[vtxo.poolTxid], vtxo, sweepLeaf) + redeemBranch, err := newRedeemBranch(ctx, congestionTrees[vtxo.poolTxid], vtxo) if err != nil { return err } diff --git a/noah/unilateral_redeem.go b/noah/unilateral_redeem.go index 05c7515..4733b28 100644 --- a/noah/unilateral_redeem.go +++ b/noah/unilateral_redeem.go @@ -1,7 +1,6 @@ package main import ( - "bytes" "fmt" "github.com/ark-network/ark/common/tree" @@ -24,16 +23,40 @@ type RedeemBranch interface { type redeemBranch struct { vtxo *vtxo branch []*psetv2.Pset - sweepTapLeaf *taproot.TapElementsLeaf internalKey *secp256k1.PublicKey + sweepClosure *taproot.TapElementsLeaf } -func newRedeemBranch(ctx *cli.Context, congestionTree tree.CongestionTree, vtxo vtxo, sweepLeaf *taproot.TapElementsLeaf) (RedeemBranch, error) { +func newRedeemBranch(ctx *cli.Context, congestionTree tree.CongestionTree, vtxo vtxo) (RedeemBranch, error) { nodes, err := congestionTree.Branch(vtxo.txid) if err != nil { return nil, err } + // find the sweep closure + tx, err := psetv2.NewPsetFromBase64(nodes[0].Tx) + if err != nil { + return nil, err + } + + var sweepClosure *taproot.TapElementsLeaf + + for _, tapLeaf := range tx.Inputs[0].TapLeafScript { + isSweep, _, _, err := tree.DecodeSweepScript(tapLeaf.Script) + if err != nil { + continue + } + + if isSweep { + sweepClosure = &tapLeaf.TapElementsLeaf + break + } + } + + if sweepClosure == nil { + return nil, fmt.Errorf("sweep closure not found") + } + branch := make([]*psetv2.Pset, 0, len(nodes)) for _, node := range nodes { pset, err := psetv2.NewPsetFromBase64(node.Tx) @@ -52,8 +75,8 @@ func newRedeemBranch(ctx *cli.Context, congestionTree tree.CongestionTree, vtxo return &redeemBranch{ vtxo: &vtxo, branch: branch, - sweepTapLeaf: sweepLeaf, internalKey: internalKey, + sweepClosure: sweepClosure, }, nil } @@ -96,10 +119,13 @@ func (r *redeemBranch) RedeemPath() ([]string, error) { return nil, fmt.Errorf("tap leaf script not found on input #%d", i) } - sweepTapLeafScript := r.sweepTapLeaf.Script - for _, leaf := range input.TapLeafScript { - if bytes.Equal(leaf.Script, sweepTapLeafScript) { + isSweep, _, _, err := tree.DecodeSweepScript(leaf.Script) + if err != nil { + return nil, err + } + + if isSweep { continue } @@ -159,7 +185,7 @@ func (r *redeemBranch) AddVtxoInput(updater *psetv2.Updater) error { vtxoTaprootTree := taproot.AssembleTaprootScriptTree( *checksigLeaf, - *r.sweepTapLeaf, + *r.sweepClosure, ) proofIndex := vtxoTaprootTree.LeafProofIndex[checksigLeaf.TapHash()]