Support round expiration and sweep vtxos (#70)

* sweeper base implementation

* sweeper service final implementation

* fixes

* fix CSV script

* RoundSwept event fix & test

* remove Vtxos after a sweep transaction

* ARK_ROUND_LIFETIME config

* remove TxBuilder.GetLifetime

* refactor sweeper

* use GetTransaction blocktime

* polish and comments

* fix linting

* pair programming fixes

* several fixes

* clean Println

* fixes

* linter fixes

* remove infrastructure deps from application layer

* Fixes

---------

Co-authored-by: altafan <18440657+altafan@users.noreply.github.com>
This commit is contained in:
Louis Singer
2024-02-08 16:58:04 +01:00
committed by GitHub
parent 58aa36b7e7
commit 287db4e08a
35 changed files with 1403 additions and 114 deletions

View File

@@ -30,6 +30,13 @@ func main() {
Port: cfg.Port, Port: cfg.Port,
NoTLS: cfg.NoTLS, NoTLS: cfg.NoTLS,
} }
if cfg.RoundLifetime%512 != 0 {
setLifetime := cfg.RoundLifetime
cfg.RoundLifetime = cfg.RoundLifetime - (cfg.RoundLifetime % 512)
log.Infof("round lifetime must be a multiple of 512, %d -> %d", setLifetime, cfg.RoundLifetime)
}
appConfig := &appconfig.Config{ appConfig := &appconfig.Config{
DbType: cfg.DbType, DbType: cfg.DbType,
DbDir: cfg.DbDir, DbDir: cfg.DbDir,
@@ -40,6 +47,7 @@ func main() {
BlockchainScannerType: cfg.BlockchainScannerType, BlockchainScannerType: cfg.BlockchainScannerType,
WalletAddr: cfg.WalletAddr, WalletAddr: cfg.WalletAddr,
MinRelayFee: cfg.MinRelayFee, MinRelayFee: cfg.MinRelayFee,
RoundLifetime: cfg.RoundLifetime,
} }
svc, err := grpcservice.NewService(svcConfig, appConfig) svc, err := grpcservice.NewService(svcConfig, appConfig)
if err != nil { if err != nil {

View File

@@ -18,7 +18,7 @@ require (
github.com/stretchr/testify v1.8.4 github.com/stretchr/testify v1.8.4
github.com/timshannon/badgerhold/v4 v4.0.3 github.com/timshannon/badgerhold/v4 v4.0.3
github.com/urfave/cli/v2 v2.26.0 github.com/urfave/cli/v2 v2.26.0
github.com/vulpemventures/go-elements v0.5.2 github.com/vulpemventures/go-elements v0.5.3
google.golang.org/genproto/googleapis/api v0.0.0-20231106174013-bbf56f31fb17 google.golang.org/genproto/googleapis/api v0.0.0-20231106174013-bbf56f31fb17
google.golang.org/grpc v1.59.0 google.golang.org/grpc v1.59.0
google.golang.org/protobuf v1.31.0 google.golang.org/protobuf v1.31.0
@@ -26,6 +26,11 @@ require (
require github.com/stretchr/objx v0.5.0 // indirect require github.com/stretchr/objx v0.5.0 // indirect
require (
github.com/FactomProject/basen v0.0.0-20150613233007-fe3947df716e // indirect
github.com/FactomProject/btcutilecc v0.0.0-20130527213604-d3a63a5752ec // indirect
)
require ( require (
github.com/btcsuite/btcd v0.23.1 github.com/btcsuite/btcd v0.23.1
github.com/btcsuite/btcd/btcec/v2 v2.3.2 github.com/btcsuite/btcd/btcec/v2 v2.3.2
@@ -62,6 +67,7 @@ require (
github.com/spf13/pflag v1.0.5 // indirect github.com/spf13/pflag v1.0.5 // indirect
github.com/subosito/gotenv v1.6.0 // indirect github.com/subosito/gotenv v1.6.0 // indirect
github.com/vulpemventures/fastsha256 v0.0.0-20160815193821-637e65642941 // indirect github.com/vulpemventures/fastsha256 v0.0.0-20160815193821-637e65642941 // indirect
github.com/vulpemventures/go-bip32 v0.0.0-20200624192635-867c159da4d7
github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect
go.opencensus.io v0.24.0 // indirect go.opencensus.io v0.24.0 // indirect
go.uber.org/atomic v1.9.0 // indirect go.uber.org/atomic v1.9.0 // indirect

View File

@@ -38,6 +38,10 @@ cloud.google.com/go/storage v1.14.0/go.mod h1:GrKmX003DSIwi9o29oFT7YDnHYwZoctc3f
dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
github.com/FactomProject/basen v0.0.0-20150613233007-fe3947df716e h1:ahyvB3q25YnZWly5Gq1ekg6jcmWaGj/vG/MhF4aisoc=
github.com/FactomProject/basen v0.0.0-20150613233007-fe3947df716e/go.mod h1:kGUqhHd//musdITWjFvNTHn90WG9bMLBEPQZ17Cmlpw=
github.com/FactomProject/btcutilecc v0.0.0-20130527213604-d3a63a5752ec h1:1Qb69mGp/UtRPn422BH4/Y4Q3SLUrD9KHuDkm8iodFc=
github.com/FactomProject/btcutilecc v0.0.0-20130527213604-d3a63a5752ec/go.mod h1:CD8UlnlLDiqb36L110uqiP2iSflVjx9g/3U9hCI4q2U=
github.com/aead/siphash v1.0.1/go.mod h1:Nywa3cDsYNNK3gaciGTWPwHt0wlpNV15vwmswBAUSII= github.com/aead/siphash v1.0.1/go.mod h1:Nywa3cDsYNNK3gaciGTWPwHt0wlpNV15vwmswBAUSII=
github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5doyWs3UAsr3K4I6qtAmlQcZDesFNEHPZAzj8= github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5doyWs3UAsr3K4I6qtAmlQcZDesFNEHPZAzj8=
github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
@@ -78,6 +82,8 @@ github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWR
github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI=
github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU=
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
github.com/cmars/basen v0.0.0-20150613233007-fe3947df716e h1:0XBUw73chJ1VYSsfvcPvVT7auykAJce9FpRr10L6Qhw=
github.com/cmars/basen v0.0.0-20150613233007-fe3947df716e/go.mod h1:P13beTBKr5Q18lJe1rIoLUqjM+CB1zYrRg44ZqGuQSA=
github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc=
github.com/cncf/udpa/go v0.0.0-20200629203442-efcf912fb354/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= github.com/cncf/udpa/go v0.0.0-20200629203442-efcf912fb354/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk=
github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk=
@@ -307,6 +313,7 @@ github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.1.5-0.20170601210322-f6abca593680/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
@@ -328,8 +335,10 @@ github.com/urfave/cli/v2 v2.26.0 h1:3f3AMg3HpThFNT4I++TKOejZO8yU55t3JnnSr4S4QEI=
github.com/urfave/cli/v2 v2.26.0/go.mod h1:8qnjx1vcq5s2/wpsqoZFndg2CE5tNFyrTvS6SinrnYQ= github.com/urfave/cli/v2 v2.26.0/go.mod h1:8qnjx1vcq5s2/wpsqoZFndg2CE5tNFyrTvS6SinrnYQ=
github.com/vulpemventures/fastsha256 v0.0.0-20160815193821-637e65642941 h1:CTcw80hz/Sw8hqlKX5ZYvBUF5gAHSHwdjXxRf/cjDcI= github.com/vulpemventures/fastsha256 v0.0.0-20160815193821-637e65642941 h1:CTcw80hz/Sw8hqlKX5ZYvBUF5gAHSHwdjXxRf/cjDcI=
github.com/vulpemventures/fastsha256 v0.0.0-20160815193821-637e65642941/go.mod h1:GXBJykxW2kUcktGdsgyay7uwwWvkljASfljNcT0mbh8= github.com/vulpemventures/fastsha256 v0.0.0-20160815193821-637e65642941/go.mod h1:GXBJykxW2kUcktGdsgyay7uwwWvkljASfljNcT0mbh8=
github.com/vulpemventures/go-elements v0.5.2 h1:vIDzVpRXG5PnlzHA8tCnr2Tn7raIV5cHy7bRyDrbuM4= github.com/vulpemventures/go-bip32 v0.0.0-20200624192635-867c159da4d7 h1:X7DtNv+YWy76kELMZB/xVkIJ7YNp2vpgMFVsDcQA40U=
github.com/vulpemventures/go-elements v0.5.2/go.mod h1:aBGuWXHaiAIUIcwqCdtEh2iQ3kJjKwHU9ywvhlcRSeU= github.com/vulpemventures/go-bip32 v0.0.0-20200624192635-867c159da4d7/go.mod h1:Zrvx8XgpWvSPdz1lXnuN083CkoZnzwxBLEB03S8et1I=
github.com/vulpemventures/go-elements v0.5.3 h1:zaC/ynHFwCAzFSOMfzb6BcbD6FXASppSiGMycc95WVA=
github.com/vulpemventures/go-elements v0.5.3/go.mod h1:aBGuWXHaiAIUIcwqCdtEh2iQ3kJjKwHU9ywvhlcRSeU=
github.com/vulpemventures/go-secp256k1-zkp v1.1.6 h1:BmsrmXRLUibwa75Qkk8yELjpzCzlAjYFGLiLiOdq7Xo= github.com/vulpemventures/go-secp256k1-zkp v1.1.6 h1:BmsrmXRLUibwa75Qkk8yELjpzCzlAjYFGLiLiOdq7Xo=
github.com/vulpemventures/go-secp256k1-zkp v1.1.6/go.mod h1:zo7CpgkuPgoe7fAV+inyxsI9IhGmcoFgyD8nqZaPSOM= github.com/vulpemventures/go-secp256k1-zkp v1.1.6/go.mod h1:zo7CpgkuPgoe7fAV+inyxsI9IhGmcoFgyD8nqZaPSOM=
github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q= github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q=
@@ -356,6 +365,7 @@ go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9i
go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI= go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI=
go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ= go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ=
go.uber.org/zap v1.18.1/go.mod h1:xg/QME4nWcxGxrpdeYfq7UvYrLh66cuVKdrbD1XF/NI= go.uber.org/zap v1.18.1/go.mod h1:xg/QME4nWcxGxrpdeYfq7UvYrLh66cuVKdrbD1XF/NI=
golang.org/x/crypto v0.0.0-20170613210332-850760c427c5/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20170930174604-9419663f5a44/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20170930174604-9419663f5a44/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
@@ -725,6 +735,8 @@ honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWh
honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg=
honnef.co/go/tools v0.0.1-2020.1.3/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= honnef.co/go/tools v0.0.1-2020.1.3/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k=
honnef.co/go/tools v0.0.1-2020.1.4/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= honnef.co/go/tools v0.0.1-2020.1.4/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k=
launchpad.net/gocheck v0.0.0-20140225173054-000000000087 h1:Izowp2XBH6Ya6rv+hqbceQyw/gSGoXfH/UPoTGduL54=
launchpad.net/gocheck v0.0.0-20140225173054-000000000087/go.mod h1:hj7XX3B/0A+80Vse0e+BUHsHMTEhd0O4cpUHr/e/BUM=
rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8= rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8=
rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0= rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0=
rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA= rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA=

View File

@@ -9,6 +9,7 @@ import (
"github.com/ark-network/ark/internal/core/ports" "github.com/ark-network/ark/internal/core/ports"
"github.com/ark-network/ark/internal/infrastructure/db" "github.com/ark-network/ark/internal/infrastructure/db"
oceanwallet "github.com/ark-network/ark/internal/infrastructure/ocean-wallet" oceanwallet "github.com/ark-network/ark/internal/infrastructure/ocean-wallet"
scheduler "github.com/ark-network/ark/internal/infrastructure/scheduler/gocron"
txbuilder "github.com/ark-network/ark/internal/infrastructure/tx-builder/covenant" txbuilder "github.com/ark-network/ark/internal/infrastructure/tx-builder/covenant"
txbuilderdummy "github.com/ark-network/ark/internal/infrastructure/tx-builder/dummy" txbuilderdummy "github.com/ark-network/ark/internal/infrastructure/tx-builder/dummy"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -41,12 +42,14 @@ type Config struct {
BlockchainScannerType string BlockchainScannerType string
WalletAddr string WalletAddr string
MinRelayFee uint64 MinRelayFee uint64
RoundLifetime int64
repo ports.RepoManager repo ports.RepoManager
svc application.Service svc application.Service
wallet ports.WalletService wallet ports.WalletService
txBuilder ports.TxBuilder txBuilder ports.TxBuilder
scanner ports.BlockchainScanner scanner ports.BlockchainScanner
scheduler ports.SchedulerService
} }
func (c *Config) Validate() error { func (c *Config) Validate() error {
@@ -86,9 +89,30 @@ func (c *Config) Validate() error {
if err := c.scannerService(); err != nil { if err := c.scannerService(); err != nil {
return err return err
} }
if err := c.schedulerService(); err != nil {
return err
}
if err := c.appService(); err != nil { if err := c.appService(); err != nil {
return err return err
} }
// round life time must be a multiple of 512
if c.RoundLifetime <= 0 || c.RoundLifetime%512 != 0 {
return fmt.Errorf("invalid round lifetime, must be greater than 0 and a multiple of 512")
}
seq, err := common.BIP68Encode(uint(c.RoundLifetime))
if err != nil {
return fmt.Errorf("invalid round lifetime, %s", err)
}
seconds, err := common.BIP68Decode(seq)
if err != nil {
return fmt.Errorf("invalid round lifetime, %s", err)
}
if seconds != uint(c.RoundLifetime) {
return fmt.Errorf("invalid round lifetime, must be a multiple of 512")
}
return nil return nil
} }
@@ -141,7 +165,7 @@ func (c *Config) txBuilderService() error {
case "dummy": case "dummy":
svc = txbuilderdummy.NewTxBuilder(c.wallet, net) svc = txbuilderdummy.NewTxBuilder(c.wallet, net)
case "covenant": case "covenant":
svc = txbuilder.NewTxBuilder(c.wallet, net) svc = txbuilder.NewTxBuilder(c.wallet, net, c.RoundLifetime)
default: default:
err = fmt.Errorf("unknown tx builder type") err = fmt.Errorf("unknown tx builder type")
} }
@@ -170,10 +194,28 @@ func (c *Config) scannerService() error {
return nil return nil
} }
func (c *Config) schedulerService() error {
var svc ports.SchedulerService
var err error
switch c.SchedulerType {
case "gocron":
svc = scheduler.NewScheduler()
default:
err = fmt.Errorf("unknown scheduler type")
}
if err != nil {
return err
}
c.scheduler = svc
return nil
}
func (c *Config) appService() error { func (c *Config) appService() error {
net := c.mainChain() net := c.mainChain()
svc, err := application.NewService( svc, err := application.NewService(
c.RoundInterval, c.Network, net, c.wallet, c.repo, c.txBuilder, c.scanner, c.MinRelayFee, c.Network, net, c.RoundInterval, c.RoundLifetime, c.MinRelayFee,
c.wallet, c.repo, c.txBuilder, c.scanner, c.scheduler,
) )
if err != nil { if err != nil {
return err return err

View File

@@ -23,6 +23,7 @@ type Config struct {
Network common.Network Network common.Network
LogLevel int LogLevel int
MinRelayFee uint64 MinRelayFee uint64
RoundLifetime int64
} }
var ( var (
@@ -38,6 +39,7 @@ var (
LogLevel = "LOG_LEVEL" LogLevel = "LOG_LEVEL"
Network = "NETWORK" Network = "NETWORK"
MinRelayFee = "MIN_RELAY_FEE" MinRelayFee = "MIN_RELAY_FEE"
RoundLifetime = "ROUND_LIFETIME"
defaultDatadir = common.AppDataDir("arkd", false) defaultDatadir = common.AppDataDir("arkd", false)
defaultRoundInterval = 60 defaultRoundInterval = 60
@@ -50,6 +52,7 @@ var (
defaultNetwork = "testnet" defaultNetwork = "testnet"
defaultLogLevel = 5 defaultLogLevel = 5
defaultMinRelayFee = 30 defaultMinRelayFee = 30
defaultRoundLifetime = 512
) )
func LoadConfig() (*Config, error) { func LoadConfig() (*Config, error) {
@@ -66,6 +69,7 @@ func LoadConfig() (*Config, error) {
viper.SetDefault(Insecure, defaultInsecure) viper.SetDefault(Insecure, defaultInsecure)
viper.SetDefault(LogLevel, defaultLogLevel) viper.SetDefault(LogLevel, defaultLogLevel)
viper.SetDefault(Network, defaultNetwork) viper.SetDefault(Network, defaultNetwork)
viper.SetDefault(RoundLifetime, defaultRoundLifetime)
viper.SetDefault(MinRelayFee, defaultMinRelayFee) viper.SetDefault(MinRelayFee, defaultMinRelayFee)
net, err := getNetwork() net, err := getNetwork()
@@ -90,6 +94,7 @@ func LoadConfig() (*Config, error) {
LogLevel: viper.GetInt(LogLevel), LogLevel: viper.GetInt(LogLevel),
Network: net, Network: net,
MinRelayFee: viper.GetUint64(MinRelayFee), MinRelayFee: viper.GetUint64(MinRelayFee),
RoundLifetime: viper.GetInt64(RoundLifetime),
}, nil }, nil
} }

View File

@@ -41,27 +41,31 @@ type Service interface {
} }
type service struct { type service struct {
minRelayFee uint64
roundInterval int64
network common.Network network common.Network
onchainNework network.Network onchainNework network.Network
pubkey *secp256k1.PublicKey pubkey *secp256k1.PublicKey
roundLifetime int64
roundInterval int64
minRelayFee uint64
wallet ports.WalletService wallet ports.WalletService
repoManager ports.RepoManager repoManager ports.RepoManager
builder ports.TxBuilder builder ports.TxBuilder
scanner ports.BlockchainScanner scanner ports.BlockchainScanner
sweeper *sweeper
paymentRequests *paymentsMap paymentRequests *paymentsMap
forfeitTxs *forfeitTxsMap forfeitTxs *forfeitTxsMap
eventsCh chan domain.RoundEvent
eventsCh chan domain.RoundEvent
} }
func NewService( func NewService(
interval int64, network common.Network, onchainNetwork network.Network, network common.Network, onchainNetwork network.Network,
roundInterval, roundLifetime int64, minRelayFee uint64,
walletSvc ports.WalletService, repoManager ports.RepoManager, walletSvc ports.WalletService, repoManager ports.RepoManager,
builder ports.TxBuilder, scanner ports.BlockchainScanner, builder ports.TxBuilder, scanner ports.BlockchainScanner,
minRelayFee uint64, scheduler ports.SchedulerService,
) (Service, error) { ) (Service, error) {
eventsCh := make(chan domain.RoundEvent) eventsCh := make(chan domain.RoundEvent)
paymentRequests := newPaymentsMap(nil) paymentRequests := newPaymentsMap(nil)
@@ -72,10 +76,14 @@ func NewService(
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to fetch pubkey: %s", err) return nil, fmt.Errorf("failed to fetch pubkey: %s", err)
} }
sweeper := newSweeper(walletSvc, repoManager, builder, scheduler)
svc := &service{ svc := &service{
minRelayFee, interval, network, onchainNetwork, pubkey, network, onchainNetwork, pubkey,
walletSvc, repoManager, builder, scanner, paymentRequests, forfeitTxs, roundLifetime, roundInterval, minRelayFee,
eventsCh, walletSvc, repoManager, builder, scanner, sweeper,
paymentRequests, forfeitTxs, eventsCh,
} }
repoManager.RegisterEventsHandler( repoManager.RegisterEventsHandler(
func(round *domain.Round) { func(round *domain.Round) {
@@ -92,12 +100,18 @@ func NewService(
} }
func (s *service) Start() error { func (s *service) Start() error {
log.Debug("starting sweeper service")
if err := s.sweeper.start(); err != nil {
return err
}
log.Debug("starting app service") log.Debug("starting app service")
go s.start() go s.start()
return nil return nil
} }
func (s *service) Stop() { func (s *service) Stop() {
s.sweeper.stop()
// nolint // nolint
vtxos, _ := s.repoManager.Vtxos().GetSpendableVtxos( vtxos, _ := s.repoManager.Vtxos().GetSpendableVtxos(
context.Background(), "", context.Background(), "",
@@ -356,7 +370,17 @@ func (s *service) finalizeRound() {
return return
} }
now := time.Now().Unix()
expirationTimestamp := now + s.roundLifetime + 30 // add 30 secs to be sure that the tx is confirmed
if err := s.sweeper.schedule(expirationTimestamp, txid, round.CongestionTree); err != nil {
changes = round.Fail(fmt.Errorf("failed to schedule sweep tx: %s", err))
log.WithError(err).Warn("failed to schedule sweep tx")
return
}
changes, _ = round.EndFinalization(forfeitTxs, txid) changes, _ = round.EndFinalization(forfeitTxs, txid)
log.Debugf("finalized round %s with pool tx %s", round.Id, round.Txid) log.Debugf("finalized round %s with pool tx %s", round.Id, round.Txid)
} }

View File

@@ -0,0 +1,579 @@
package application
import (
"context"
"encoding/hex"
"fmt"
"time"
"github.com/ark-network/ark/common/tree"
"github.com/ark-network/ark/internal/core/domain"
"github.com/ark-network/ark/internal/core/ports"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/decred/dcrd/dcrec/secp256k1/v4"
log "github.com/sirupsen/logrus"
"github.com/vulpemventures/go-elements/psetv2"
)
// sweeper is an unexported service running while the main application service is started
// it is responsible for sweeping onchain shared outputs that expired
// it also handles delaying the sweep events in case some parts of the tree are broadcasted
// when a round is finalized, the main application service schedules a sweep event on the newly created congestion tree
type sweeper struct {
wallet ports.WalletService
repoManager ports.RepoManager
builder ports.TxBuilder
scheduler ports.SchedulerService
// cache of scheduled tasks, avoid scheduling the same sweep event multiple times
scheduledTasks map[string]struct{}
}
func newSweeper(
wallet ports.WalletService,
repoManager ports.RepoManager,
builder ports.TxBuilder,
scheduler ports.SchedulerService,
) *sweeper {
return &sweeper{
wallet,
repoManager,
builder,
scheduler,
make(map[string]struct{}),
}
}
func (s *sweeper) start() error {
s.scheduler.Start()
allRounds, err := s.repoManager.Rounds().GetSweepableRounds(context.Background())
if err != nil {
return err
}
for _, round := range allRounds {
task := s.createTask(round.Txid, round.CongestionTree)
task()
}
return nil
}
func (s *sweeper) stop() {
s.scheduler.Stop()
}
// removeTask update the cached map of scheduled tasks
func (s *sweeper) removeTask(treeRootTxid string) {
delete(s.scheduledTasks, treeRootTxid)
}
// schedule set up a task to be executed once at the given timestamp
func (s *sweeper) schedule(
expirationTimestamp int64, roundTxid string, congestionTree tree.CongestionTree,
) error {
root, err := congestionTree.Root()
if err != nil {
return err
}
if _, scheduled := s.scheduledTasks[root.Txid]; scheduled {
return nil
}
task := s.createTask(roundTxid, congestionTree)
fancyTime := time.Unix(expirationTimestamp, 0).Format("2006-01-02 15:04:05")
log.Debugf("scheduled sweep task at %s", fancyTime)
if err := s.scheduler.ScheduleTaskOnce(expirationTimestamp, task); err != nil {
return err
}
s.scheduledTasks[root.Txid] = struct{}{}
return nil
}
// createTask returns a function passed as handler in the scheduler
// it tries to craft a sweep tx containing the onchain outputs of the given congestion tree
// if some parts of the tree have been broadcasted in the meantine, it will schedule the next taskes for the remaining parts of the tree
func (s *sweeper) createTask(
roundTxid string, congestionTree tree.CongestionTree,
) func() {
return func() {
ctx := context.Background()
root, err := congestionTree.Root()
if err != nil {
log.WithError(err).Error("error while getting root node")
return
}
s.removeTask(root.Txid)
log.Debugf("sweeper: %s", root.Txid)
sweepInputs := make([]ports.SweepInput, 0)
vtxoKeys := make([]domain.VtxoKey, 0) // vtxos associated to the sweep inputs
// inspect the congestion tree to find onchain shared outputs
sharedOutputs, err := s.findSweepableOutputs(ctx, congestionTree)
if err != nil {
log.WithError(err).Error("error while inspecting congestion tree")
return
}
for expiredAt, inputs := range sharedOutputs {
// if the shared outputs are not expired, schedule a sweep task for it
if time.Unix(expiredAt, 0).After(time.Now()) {
subtrees, err := computeSubTrees(congestionTree, inputs)
if err != nil {
log.WithError(err).Error("error while computing subtrees")
continue
}
for _, subTree := range subtrees {
// mitigate the risk to get BIP68 non-final errors by scheduling the task 30 seconds after the expiration time
if err := s.schedule(int64(expiredAt), roundTxid, subTree); err != nil {
log.WithError(err).Error("error while scheduling sweep task")
continue
}
}
continue
}
// iterate over the expired shared outputs
for _, input := range inputs {
// sweepableVtxos related to the sweep input
sweepableVtxos := make([]domain.VtxoKey, 0)
// check if input is the vtxo itself
vtxos, _ := s.repoManager.Vtxos().GetVtxos(
ctx,
[]domain.VtxoKey{
{
Txid: input.InputArgs.Txid,
VOut: input.InputArgs.TxIndex,
},
},
)
if len(vtxos) > 0 {
if !vtxos[0].Swept && !vtxos[0].Redeemed {
sweepableVtxos = append(sweepableVtxos, vtxos[0].VtxoKey)
}
} else {
// if it's not a vtxo, find all the vtxos leaves reachable from that input
vtxosLeaves, err := congestionTree.FindLeaves(input.InputArgs.Txid, input.InputArgs.TxIndex)
if err != nil {
log.WithError(err).Error("error while finding vtxos leaves")
continue
}
for _, leaf := range vtxosLeaves {
pset, err := psetv2.NewPsetFromBase64(leaf.Tx)
if err != nil {
log.Error(fmt.Errorf("error while decoding pset: %w", err))
continue
}
vtxo, err := extractVtxoOutpoint(pset)
if err != nil {
log.Error(err)
continue
}
sweepableVtxos = append(sweepableVtxos, *vtxo)
}
if len(sweepableVtxos) <= 0 {
continue
}
firstVtxo, err := s.repoManager.Vtxos().GetVtxos(ctx, sweepableVtxos[1:])
if err != nil {
log.Error(fmt.Errorf("error while getting vtxo: %w", err))
sweepInputs = append(sweepInputs, input) // add the input anyway in order to try to sweep it
continue
}
if firstVtxo[0].Swept || firstVtxo[0].Redeemed {
// we assume that if the first vtxo is swept or redeemed, the shared output has been spent
// skip, the output is already swept or spent by a unilateral redeem
continue
}
}
if len(sweepableVtxos) > 0 {
vtxoKeys = append(vtxoKeys, sweepableVtxos...)
sweepInputs = append(sweepInputs, input)
}
}
}
if len(sweepInputs) > 0 {
// build the sweep transaction with all the expired non-swept shared outputs
sweepTx, err := s.builder.BuildSweepTx(s.wallet, sweepInputs)
if err != nil {
log.WithError(err).Error("error while building sweep tx")
return
}
err = nil
txid := ""
// retry until the tx is broadcasted or the error is not BIP68 final
for len(txid) == 0 && (err == nil || err == fmt.Errorf("non-BIP68-final")) {
if err != nil {
log.Debugln("sweep tx not BIP68 final, retrying in 5 seconds")
time.Sleep(5 * time.Second)
}
txid, err = s.wallet.BroadcastTransaction(ctx, sweepTx)
}
if err != nil {
log.WithError(err).Error("error while broadcasting sweep tx")
return
}
if len(txid) > 0 {
log.Debugln("sweep tx broadcasted:", txid)
vtxosRepository := s.repoManager.Vtxos()
// mark the vtxos as swept
if err := vtxosRepository.SweepVtxos(ctx, vtxoKeys); err != nil {
log.Error(fmt.Errorf("error while deleting vtxos: %w", err))
return
}
log.Debugf("%d vtxos swept", len(vtxoKeys))
roundVtxos, err := vtxosRepository.GetVtxosForRound(ctx, roundTxid)
if err != nil {
log.WithError(err).Error("error while getting vtxos for round")
return
}
allSwept := true
for _, vtxo := range roundVtxos {
allSwept = allSwept && vtxo.Swept
if !allSwept {
break
}
}
if allSwept {
// update the round
roundRepo := s.repoManager.Rounds()
round, err := roundRepo.GetRoundWithTxid(ctx, roundTxid)
if err != nil {
log.WithError(err).Error("error while getting round")
return
}
round.Sweep()
if err := roundRepo.AddOrUpdateRound(ctx, *round); err != nil {
log.WithError(err).Error("error while marking round as swept")
return
}
}
}
}
}
}
// onchainOutputs iterates over all the nodes' outputs in the congestion tree and checks their onchain state
// returns the sweepable outputs as ports.SweepInput mapped by their expiration time
func (s *sweeper) findSweepableOutputs(
ctx context.Context,
congestionTree tree.CongestionTree,
) (map[int64][]ports.SweepInput, error) {
sweepableOutputs := make(map[int64][]ports.SweepInput)
blocktimeCache := make(map[string]int64) // txid -> blocktime
nodesToCheck := congestionTree[0] // init with the root
for len(nodesToCheck) > 0 {
newNodesToCheck := make([]tree.Node, 0)
for _, node := range nodesToCheck {
isPublished, blocktime, err := s.wallet.IsTransactionPublished(ctx, node.Txid)
if err != nil {
return nil, err
}
var expirationTime int64
var sweepInputs []ports.SweepInput
if !isPublished {
if _, ok := blocktimeCache[node.ParentTxid]; !ok {
isPublished, blocktime, err := s.wallet.IsTransactionPublished(ctx, node.ParentTxid)
if !isPublished || err != nil {
return nil, fmt.Errorf("tx %s not found", node.Txid)
}
blocktimeCache[node.ParentTxid] = blocktime
}
expirationTime, sweepInputs, err = s.nodeToSweepInputs(blocktimeCache[node.ParentTxid], node)
if err != nil {
return nil, err
}
} else {
// cache the blocktime for future use
blocktimeCache[node.Txid] = int64(blocktime)
// if the tx is onchain, it means that the input is spent
// add the children to the nodes in order to check them during the next iteration
// We will return the error below, but are we going to schedule the tasks for the "children roots"?
if !node.Leaf {
children := congestionTree.Children(node.Txid)
newNodesToCheck = append(newNodesToCheck, children...)
continue
}
// if the node is a leaf, the vtxos outputs should added as onchain outputs if they are not swept yet
vtxoExpiration, sweepInput, err := s.leafToSweepInput(ctx, blocktime, node)
if err != nil {
return nil, err
}
if sweepInput != nil {
expirationTime = vtxoExpiration
sweepInputs = []ports.SweepInput{*sweepInput}
}
}
if _, ok := sweepableOutputs[expirationTime]; !ok {
sweepableOutputs[expirationTime] = make([]ports.SweepInput, 0)
}
sweepableOutputs[expirationTime] = append(sweepableOutputs[expirationTime], sweepInputs...)
}
nodesToCheck = newNodesToCheck
}
return sweepableOutputs, nil
}
func (s *sweeper) leafToSweepInput(ctx context.Context, txBlocktime int64, node tree.Node) (int64, *ports.SweepInput, error) {
pset, err := psetv2.NewPsetFromBase64(node.Tx)
if err != nil {
return -1, nil, err
}
vtxo, err := extractVtxoOutpoint(pset)
if err != nil {
return -1, nil, err
}
fromRepo, err := s.repoManager.Vtxos().GetVtxos(ctx, []domain.VtxoKey{*vtxo})
if err != nil {
return -1, nil, err
}
if len(fromRepo) == 0 {
return -1, nil, fmt.Errorf("vtxo not found")
}
if fromRepo[0].Swept {
return -1, nil, nil
}
if fromRepo[0].Redeemed {
return -1, nil, nil
}
// if the vtxo is not swept or redeemed, add it to the onchain outputs
pubKeyBytes, err := hex.DecodeString(fromRepo[0].Pubkey)
if err != nil {
return -1, nil, err
}
pubKey, err := secp256k1.ParsePubKey(pubKeyBytes)
if err != nil {
return -1, nil, err
}
sweepLeaf, lifetime, err := s.builder.GetLeafSweepClosure(node, pubKey)
if err != nil {
return -1, nil, err
}
sweepInput := ports.SweepInput{
InputArgs: psetv2.InputArgs{
Txid: vtxo.Txid,
TxIndex: vtxo.VOut,
},
SweepLeaf: *sweepLeaf,
Amount: fromRepo[0].Amount,
}
return txBlocktime + lifetime, &sweepInput, nil
}
func (s *sweeper) nodeToSweepInputs(parentBlocktime int64, node tree.Node) (int64, []ports.SweepInput, error) {
pset, err := psetv2.NewPsetFromBase64(node.Tx)
if err != nil {
return -1, nil, err
}
if len(pset.Inputs) != 1 {
return -1, nil, fmt.Errorf("invalid node pset, expect 1 input, got %d", len(pset.Inputs))
}
// if the tx is not onchain, it means that the input is an existing shared output
input := pset.Inputs[0]
txid := chainhash.Hash(input.PreviousTxid).String()
index := input.PreviousTxIndex
sweepLeaf, lifetime, err := extractSweepLeaf(input)
if err != nil {
return -1, nil, err
}
expirationTime := parentBlocktime + lifetime
amount := uint64(0)
for _, out := range pset.Outputs {
amount += out.Value
}
sweepInputs := []ports.SweepInput{
{
InputArgs: psetv2.InputArgs{
Txid: txid,
TxIndex: index,
},
SweepLeaf: *sweepLeaf,
Amount: amount,
},
}
return expirationTime, sweepInputs, nil
}
func computeSubTrees(congestionTree tree.CongestionTree, inputs []ports.SweepInput) ([]tree.CongestionTree, error) {
subTrees := make(map[string]tree.CongestionTree, 0)
// for each sweepable input, create a sub congestion tree
// it allows to skip the part of the tree that has been broadcasted in the next task
for _, input := range inputs {
subTree, err := computeSubTree(congestionTree, input.InputArgs.Txid)
if err != nil {
log.WithError(err).Error("error while finding sub tree")
continue
}
root, err := subTree.Root()
if err != nil {
log.WithError(err).Error("error while getting root node")
continue
}
subTrees[root.Txid] = subTree
}
// filter out the sub trees, remove the ones that are included in others
filteredSubTrees := make([]tree.CongestionTree, 0)
for i, subTree := range subTrees {
notIncludedInOtherTrees := true
for j, otherSubTree := range subTrees {
if i == j {
continue
}
contains, err := containsTree(otherSubTree, subTree)
if err != nil {
log.WithError(err).Error("error while checking if a tree contains another")
continue
}
if contains {
notIncludedInOtherTrees = false
break
}
}
if notIncludedInOtherTrees {
filteredSubTrees = append(filteredSubTrees, subTree)
}
}
return filteredSubTrees, nil
}
func computeSubTree(congestionTree tree.CongestionTree, newRoot string) (tree.CongestionTree, error) {
for _, level := range congestionTree {
for _, node := range level {
if node.Txid == newRoot || node.ParentTxid == newRoot {
newTree := make(tree.CongestionTree, 0)
newTree = append(newTree, []tree.Node{node})
children := congestionTree.Children(node.Txid)
for len(children) > 0 {
newTree = append(newTree, children)
newChildren := make([]tree.Node, 0)
for _, child := range children {
newChildren = append(newChildren, congestionTree.Children(child.Txid)...)
}
children = newChildren
}
return newTree, nil
}
}
}
return nil, fmt.Errorf("failed to create subtree, new root not found")
}
func containsTree(tr0 tree.CongestionTree, tr1 tree.CongestionTree) (bool, error) {
tr1Root, err := tr1.Root()
if err != nil {
return false, err
}
for _, level := range tr0 {
for _, node := range level {
if node.Txid == tr1Root.Txid {
return true, nil
}
}
}
return false, nil
}
// given a congestion tree input, searches and returns the sweep leaf and its lifetime in seconds
func extractSweepLeaf(input psetv2.Input) (sweepLeaf *psetv2.TapLeafScript, lifetime int64, err error) {
for _, leaf := range input.TapLeafScript {
isSweep, _, seconds, err := tree.DecodeSweepScript(leaf.Script)
if err != nil {
return nil, 0, err
}
if isSweep {
lifetime = int64(seconds)
sweepLeaf = &leaf
break
}
}
if sweepLeaf == nil {
return nil, 0, fmt.Errorf("sweep leaf not found")
}
return sweepLeaf, lifetime, nil
}
// assuming the pset is a leaf in the congestion tree, returns the vtxos outputs
func extractVtxoOutpoint(pset *psetv2.Pset) (*domain.VtxoKey, error) {
if len(pset.Outputs) != 2 {
return nil, fmt.Errorf("invalid leaf pset, expect 2 outputs, got %d", len(pset.Outputs))
}
utx, err := pset.UnsignedTx()
if err != nil {
return nil, err
}
return &domain.VtxoKey{
Txid: utx.TxHash().String(),
VOut: 0,
}, nil
}

View File

@@ -125,4 +125,5 @@ type Vtxo struct {
PoolTx string PoolTx string
Spent bool Spent bool
Redeemed bool Redeemed bool
Swept bool
} }

View File

@@ -46,6 +46,7 @@ type Round struct {
Connectors []string Connectors []string
DustAmount uint64 DustAmount uint64
Version uint Version uint
Swept bool // true if all the vtxos are vtxo.Swept
changes []RoundEvent changes []RoundEvent
} }
@@ -239,6 +240,10 @@ func (r *Round) TotalOutputAmount() uint64 {
return tot return tot
} }
func (r *Round) Sweep() {
r.Swept = true
}
func (r *Round) raise(event RoundEvent) { func (r *Round) raise(event RoundEvent) {
if r.changes == nil { if r.changes == nil {
r.changes = make([]RoundEvent, 0) r.changes = make([]RoundEvent, 0)

View File

@@ -1,6 +1,8 @@
package domain package domain
import "context" import (
"context"
)
type RoundEventRepository interface { type RoundEventRepository interface {
Save(ctx context.Context, id string, events ...RoundEvent) error Save(ctx context.Context, id string, events ...RoundEvent) error
@@ -12,6 +14,7 @@ type RoundRepository interface {
GetCurrentRound(ctx context.Context) (*Round, error) GetCurrentRound(ctx context.Context) (*Round, error)
GetRoundWithId(ctx context.Context, id string) (*Round, error) GetRoundWithId(ctx context.Context, id string) (*Round, error)
GetRoundWithTxid(ctx context.Context, txid string) (*Round, error) GetRoundWithTxid(ctx context.Context, txid string) (*Round, error)
GetSweepableRounds(ctx context.Context) ([]Round, error)
} }
type VtxoRepository interface { type VtxoRepository interface {
@@ -19,5 +22,7 @@ type VtxoRepository interface {
SpendVtxos(ctx context.Context, vtxos []VtxoKey) error SpendVtxos(ctx context.Context, vtxos []VtxoKey) error
RedeemVtxos(ctx context.Context, vtxos []VtxoKey) ([]Vtxo, error) RedeemVtxos(ctx context.Context, vtxos []VtxoKey) ([]Vtxo, error)
GetVtxos(ctx context.Context, vtxos []VtxoKey) ([]Vtxo, error) GetVtxos(ctx context.Context, vtxos []VtxoKey) ([]Vtxo, error)
GetVtxosForRound(ctx context.Context, txid string) ([]Vtxo, error)
SweepVtxos(ctx context.Context, vtxos []VtxoKey) error
GetSpendableVtxos(ctx context.Context, pubkey string) ([]Vtxo, error) GetSpendableVtxos(ctx context.Context, pubkey string) ([]Vtxo, error)
} }

View File

@@ -5,4 +5,5 @@ type SchedulerService interface {
Stop() Stop()
ScheduleTask(interval int64, immediate bool, task func()) error ScheduleTask(interval int64, immediate bool, task func()) error
ScheduleTaskOnce(delay int64, task func()) error
} }

View File

@@ -4,8 +4,15 @@ import (
"github.com/ark-network/ark/common/tree" "github.com/ark-network/ark/common/tree"
"github.com/ark-network/ark/internal/core/domain" "github.com/ark-network/ark/internal/core/domain"
"github.com/decred/dcrd/dcrec/secp256k1/v4" "github.com/decred/dcrd/dcrec/secp256k1/v4"
"github.com/vulpemventures/go-elements/psetv2"
) )
type SweepInput struct {
InputArgs psetv2.InputArgs
SweepLeaf psetv2.TapLeafScript
Amount uint64
}
type TxBuilder interface { type TxBuilder interface {
BuildPoolTx( BuildPoolTx(
aspPubkey *secp256k1.PublicKey, payments []domain.Payment, minRelayFee uint64, aspPubkey *secp256k1.PublicKey, payments []domain.Payment, minRelayFee uint64,
@@ -13,5 +20,10 @@ type TxBuilder interface {
BuildForfeitTxs( BuildForfeitTxs(
aspPubkey *secp256k1.PublicKey, poolTx string, payments []domain.Payment, aspPubkey *secp256k1.PublicKey, poolTx string, payments []domain.Payment,
) (connectors []string, forfeitTxs []string, err error) ) (connectors []string, forfeitTxs []string, err error)
BuildSweepTx(
wallet WalletService,
inputs []SweepInput,
) (signedSweepTx string, err error)
GetLeafSweepClosure(node tree.Node, userPubKey *secp256k1.PublicKey) (*psetv2.TapLeafScript, int64, error)
GetVtxoScript(userPubkey, aspPubkey *secp256k1.PublicKey) ([]byte, error) GetVtxoScript(userPubkey, aspPubkey *secp256k1.PublicKey) ([]byte, error)
} }

View File

@@ -16,6 +16,8 @@ type WalletService interface {
) (string, error) ) (string, error)
SelectUtxos(ctx context.Context, asset string, amount uint64) ([]TxInput, uint64, error) SelectUtxos(ctx context.Context, asset string, amount uint64) ([]TxInput, uint64, error)
BroadcastTransaction(ctx context.Context, txHex string) (string, error) BroadcastTransaction(ctx context.Context, txHex string) (string, error)
SignPsetWithKey(ctx context.Context, pset string, inputIndexes []int) (string, error) // inputIndexes == nil means sign all inputs
IsTransactionPublished(ctx context.Context, txid string) (isPublished bool, blocktime int64, err error)
EstimateFees(ctx context.Context, pset string) (uint64, error) EstimateFees(ctx context.Context, pset string) (uint64, error)
Close() Close()
} }

View File

@@ -95,6 +95,18 @@ func (r *roundRepository) GetRoundWithTxid(
return round, nil return round, nil
} }
func (r *roundRepository) GetSweepableRounds(
ctx context.Context,
) ([]domain.Round, error) {
query := badgerhold.Where("Stage.Code").Eq(domain.FinalizationStage).
And("Stage.Ended").Eq(true).And("Swept").Eq(false)
rounds, err := r.findRound(ctx, query)
if err != nil {
return nil, err
}
return rounds, nil
}
func (r *roundRepository) Close() { func (r *roundRepository) Close() {
r.store.Close() r.store.Close()
} }

View File

@@ -93,16 +93,34 @@ func (r *vtxoRepository) GetVtxos(
return vtxos, nil return vtxos, nil
} }
func (r *vtxoRepository) GetVtxosForRound(
ctx context.Context, txid string,
) ([]domain.Vtxo, error) {
query := badgerhold.Where("Txid").Eq(txid)
return r.findVtxos(ctx, query)
}
func (r *vtxoRepository) GetSpendableVtxos( func (r *vtxoRepository) GetSpendableVtxos(
ctx context.Context, pubkey string, ctx context.Context, pubkey string,
) ([]domain.Vtxo, error) { ) ([]domain.Vtxo, error) {
query := badgerhold.Where("Spent").Eq(false).And("Redeemed").Eq(false) query := badgerhold.Where("Spent").Eq(false).And("Redeemed").Eq(false).And("Swept").Eq(false)
if len(pubkey) > 0 { if len(pubkey) > 0 {
query = query.And("Pubkey").Eq(pubkey) query = query.And("Pubkey").Eq(pubkey)
} }
return r.findVtxos(ctx, query) return r.findVtxos(ctx, query)
} }
func (r *vtxoRepository) SweepVtxos(
ctx context.Context, vtxoKeys []domain.VtxoKey,
) error {
for _, vtxoKey := range vtxoKeys {
if err := r.sweepVtxo(ctx, vtxoKey); err != nil {
return err
}
}
return nil
}
func (r *vtxoRepository) Close() { func (r *vtxoRepository) Close() {
r.store.Close() r.store.Close()
} }
@@ -203,3 +221,25 @@ func (r *vtxoRepository) findVtxos(ctx context.Context, query *badgerhold.Query)
return vtxos, err return vtxos, err
} }
func (r *vtxoRepository) sweepVtxo(ctx context.Context, vtxoKey domain.VtxoKey) error {
vtxo, err := r.getVtxo(ctx, vtxoKey)
if err != nil {
return err
}
if vtxo.Swept {
return nil
}
vtxo.Swept = true
if ctx.Value("tx") != nil {
tx := ctx.Value("tx").(*badger.Txn)
err = r.store.TxUpdate(tx, vtxoKey.Hash(), *vtxo)
} else {
err = r.store.Update(vtxoKey.Hash(), *vtxo)
}
if err != nil {
return err
}
return nil
}

View File

@@ -60,6 +60,7 @@ func NewService(addr string) (ports.WalletService, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
found := false found := false
for _, account := range info.GetAccounts() { for _, account := range info.GetAccounts() {
if account.GetLabel() == accountLabel { if account.GetLabel() == accountLabel {

View File

@@ -2,12 +2,17 @@ package oceanwallet
import ( import (
"context" "context"
"encoding/binary"
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"strings"
"time"
pb "github.com/ark-network/ark/api-spec/protobuf/gen/ocean/v1" pb "github.com/ark-network/ark/api-spec/protobuf/gen/ocean/v1"
"github.com/ark-network/ark/common/tree"
"github.com/ark-network/ark/internal/core/ports" "github.com/ark-network/ark/internal/core/ports"
"github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/txscript"
"github.com/vulpemventures/go-elements/elementsutil" "github.com/vulpemventures/go-elements/elementsutil"
"github.com/vulpemventures/go-elements/psetv2" "github.com/vulpemventures/go-elements/psetv2"
) )
@@ -76,6 +81,24 @@ func (s *service) SelectUtxos(ctx context.Context, asset string, amount uint64)
return inputs, res.GetChange(), nil return inputs, res.GetChange(), nil
} }
func (s *service) GetTransaction(
ctx context.Context, txid string,
) (string, int64, error) {
res, err := s.txClient.GetTransaction(ctx, &pb.GetTransactionRequest{
Txid: txid,
})
if err != nil {
return "", 0, err
}
if res.GetBlockDetails().GetTimestamp() > 0 {
return res.GetTxHex(), res.BlockDetails.GetTimestamp(), nil
}
// if not confirmed, we return now + 30 secs to estimate the next blocktime
return res.GetTxHex(), time.Now().Unix() + 30, nil
}
func (s *service) BroadcastTransaction( func (s *service) BroadcastTransaction(
ctx context.Context, txHex string, ctx context.Context, txHex string,
) (string, error) { ) (string, error) {
@@ -85,11 +108,104 @@ func (s *service) BroadcastTransaction(
}, },
) )
if err != nil { if err != nil {
if strings.Contains(err.Error(), "non-BIP68-final") {
return "", fmt.Errorf("non-BIP68-final")
}
return "", err return "", err
} }
return res.GetTxid(), nil return res.GetTxid(), nil
} }
func (s *service) IsTransactionPublished(
ctx context.Context, txid string,
) (bool, int64, error) {
_, blocktime, err := s.GetTransaction(ctx, txid)
if err != nil {
if strings.Contains(strings.ToLower(err.Error()), "missing transaction") {
return false, 0, nil
}
return false, 0, err
}
return true, blocktime, nil
}
func (s *service) SignPsetWithKey(ctx context.Context, b64 string, indexes []int) (string, error) {
pset, err := psetv2.NewPsetFromBase64(b64)
if err != nil {
return "", err
}
if indexes == nil {
for i := 0; i < len(pset.Inputs); i++ {
indexes = append(indexes, i)
}
}
key, masterKey, err := s.getPubkey(ctx)
if err != nil {
return "", err
}
fingerprint := binary.LittleEndian.Uint32(masterKey.FingerPrint)
extendedKey, err := masterKey.Serialize()
if err != nil {
return "", err
}
pset.Global.Xpubs = []psetv2.Xpub{{
ExtendedKey: extendedKey[:len(extendedKey)-4],
MasterFingerprint: fingerprint,
DerivationPath: derivationPath,
}}
updater, err := psetv2.NewUpdater(pset)
if err != nil {
return "", err
}
bip32derivation := psetv2.DerivationPathWithPubKey{
PubKey: key.SerializeCompressed(),
MasterKeyFingerprint: fingerprint,
Bip32Path: derivationPath,
}
for _, i := range indexes {
if len(pset.Inputs[i].TapLeafScript) == 0 {
return "", fmt.Errorf("no tap leaf script found for input %d", i)
}
leafHash := pset.Inputs[i].TapLeafScript[0].TapHash()
if err := updater.AddInTapBip32Derivation(i, psetv2.TapDerivationPathWithPubKey{
DerivationPathWithPubKey: bip32derivation,
LeafHashes: [][]byte{leafHash[:]},
}); err != nil {
return "", err
}
if err := updater.AddInSighashType(i, txscript.SigHashDefault); err != nil {
return "", err
}
}
unsignedPset, err := pset.ToBase64()
if err != nil {
return "", err
}
signedPset, err := s.txClient.SignPsetWithSchnorrKey(ctx, &pb.SignPsetWithSchnorrKeyRequest{
Tx: unsignedPset,
SighashType: uint32(txscript.SigHashDefault),
})
if err != nil {
return "", err
}
return signedPset.GetSignedTx(), nil
}
func (s *service) EstimateFees( func (s *service) EstimateFees(
ctx context.Context, pset string, ctx context.Context, pset string,
) (uint64, error) { ) (uint64, error) {
@@ -102,15 +218,30 @@ func (s *service) EstimateFees(
outputs := make([]*pb.Output, 0, len(tx.Outputs)) outputs := make([]*pb.Output, 0, len(tx.Outputs))
for _, in := range tx.Inputs { for _, in := range tx.Inputs {
if in.WitnessUtxo == nil { pbInput := &pb.Input{
return 0, fmt.Errorf("missing witness utxo, cannot estimate fees") Txid: chainhash.Hash(in.PreviousTxid).String(),
Index: in.PreviousTxIndex,
} }
inputs = append(inputs, &pb.Input{ if len(in.TapLeafScript) == 1 {
Txid: chainhash.Hash(in.PreviousTxid).String(), isSweep, _, _, err := tree.DecodeSweepScript(in.TapLeafScript[0].Script)
Index: in.PreviousTxIndex, if err != nil {
Script: hex.EncodeToString(in.WitnessUtxo.Script), return 0, err
}) }
if isSweep {
pbInput.WitnessSize = 64
pbInput.ScriptsigSize = 0
}
} else {
if in.WitnessUtxo == nil {
return 0, fmt.Errorf("missing witness utxo, cannot estimate fees")
}
pbInput.Script = hex.EncodeToString(in.WitnessUtxo.Script)
}
inputs = append(inputs, pbInput)
} }
for _, out := range tx.Outputs { for _, out := range tx.Outputs {

View File

@@ -8,30 +8,16 @@ import (
"github.com/ark-network/ark/internal/core/ports" "github.com/ark-network/ark/internal/core/ports"
"github.com/btcsuite/btcd/btcutil/hdkeychain" "github.com/btcsuite/btcd/btcutil/hdkeychain"
"github.com/decred/dcrd/dcrec/secp256k1/v4" "github.com/decred/dcrd/dcrec/secp256k1/v4"
"github.com/vulpemventures/go-bip32"
) )
const accountLabel = "ark" const accountLabel = "ark"
var derivationPath = []uint32{0, 0}
func (s *service) GetPubkey(ctx context.Context) (*secp256k1.PublicKey, error) { func (s *service) GetPubkey(ctx context.Context) (*secp256k1.PublicKey, error) {
res, err := s.walletClient.GetInfo(ctx, &pb.GetInfoRequest{}) key, _, err := s.getPubkey(ctx)
if err != nil { return key, err
return nil, err
}
if len(res.GetAccounts()) <= 0 {
return nil, fmt.Errorf("wallet is locked")
}
xpub := res.GetAccounts()[0].GetXpubs()[0]
node, err := hdkeychain.NewKeyFromString(xpub)
if err != nil {
return nil, err
}
for i := 0; i < 2; i++ {
node, err = node.Derive(0)
if err != nil {
return nil, err
}
}
return node.ECPubKey()
} }
func (s *service) Status( func (s *service) Status(
@@ -57,3 +43,53 @@ func (w walletStatus) IsUnlocked() bool {
func (w walletStatus) IsSynced() bool { func (w walletStatus) IsSynced() bool {
return w.StatusResponse.GetSynced() return w.StatusResponse.GetSynced()
} }
func (s *service) findAccount(ctx context.Context, label string) (*pb.AccountInfo, error) {
res, err := s.walletClient.GetInfo(ctx, &pb.GetInfoRequest{})
if err != nil {
return nil, err
}
if len(res.GetAccounts()) <= 0 {
return nil, fmt.Errorf("wallet is locked")
}
for _, account := range res.GetAccounts() {
if account.GetLabel() == label {
return account, nil
}
}
return nil, fmt.Errorf("account not found")
}
func (s *service) getPubkey(ctx context.Context) (*secp256k1.PublicKey, *bip32.Key, error) {
account, err := s.findAccount(ctx, accountLabel)
if err != nil {
return nil, nil, err
}
xpub := account.GetXpubs()[0]
node, err := hdkeychain.NewKeyFromString(xpub)
if err != nil {
return nil, nil, err
}
for _, i := range derivationPath {
node, err = node.Derive(i)
if err != nil {
return nil, nil, err
}
}
key, err := node.ECPubKey()
if err != nil {
return nil, nil, err
}
masterKey, err := bip32.B58Deserialize(xpub)
if err != nil {
return nil, nil, err
}
return key, masterKey, nil
}

View File

@@ -1,6 +1,7 @@
package scheduler package scheduler
import ( import (
"fmt"
"time" "time"
"github.com/ark-network/ark/internal/core/ports" "github.com/ark-network/ark/internal/core/ports"
@@ -32,3 +33,13 @@ func (s *service) ScheduleTask(interval int64, immediate bool, task func()) erro
_, err := s.scheduler.Every(int(interval)).Seconds().WaitForSchedule().Do(task) _, err := s.scheduler.Every(int(interval)).Seconds().WaitForSchedule().Do(task)
return err return err
} }
func (s *service) ScheduleTaskOnce(at int64, task func()) error {
delay := at - time.Now().Unix()
if delay < 0 {
return fmt.Errorf("cannot schedule task in the past")
}
_, err := s.scheduler.Every(int(delay)).Seconds().WaitForSchedule().LimitRunsTo(1).Do(task)
return err
}

View File

@@ -21,14 +21,15 @@ const (
) )
type txBuilder struct { type txBuilder struct {
wallet ports.WalletService wallet ports.WalletService
net *network.Network net *network.Network
roundLifetime int64 // in seconds
} }
func NewTxBuilder( func NewTxBuilder(
wallet ports.WalletService, net network.Network, wallet ports.WalletService, net network.Network, roundLifetime int64,
) ports.TxBuilder { ) ports.TxBuilder {
return &txBuilder{wallet, &net} return &txBuilder{wallet, &net, roundLifetime}
} }
func (b *txBuilder) GetVtxoScript(userPubkey, aspPubkey *secp256k1.PublicKey) ([]byte, error) { func (b *txBuilder) GetVtxoScript(userPubkey, aspPubkey *secp256k1.PublicKey) ([]byte, error) {
@@ -39,6 +40,44 @@ func (b *txBuilder) GetVtxoScript(userPubkey, aspPubkey *secp256k1.PublicKey) ([
return outputScript, nil return outputScript, nil
} }
func (b *txBuilder) BuildSweepTx(wallet ports.WalletService, inputs []ports.SweepInput) (signedSweepTx string, err error) {
sweepPset, err := sweepTransaction(
wallet,
inputs,
b.net.AssetID,
)
if err != nil {
return "", err
}
sweepPsetBase64, err := sweepPset.ToBase64()
if err != nil {
return "", err
}
ctx := context.Background()
signedSweepPsetB64, err := wallet.SignPsetWithKey(ctx, sweepPsetBase64, nil)
if err != nil {
return "", err
}
signedPset, err := psetv2.NewPsetFromBase64(signedSweepPsetB64)
if err != nil {
return "", err
}
if err := psetv2.FinalizeAll(signedPset); err != nil {
return "", err
}
extractedTx, err := psetv2.Extract(signedPset)
if err != nil {
return "", err
}
return extractedTx.ToHex()
}
func (b *txBuilder) BuildForfeitTxs( func (b *txBuilder) BuildForfeitTxs(
aspPubkey *secp256k1.PublicKey, poolTx string, payments []domain.Payment, aspPubkey *secp256k1.PublicKey, poolTx string, payments []domain.Payment,
) (connectors []string, forfeitTxs []string, err error) { ) (connectors []string, forfeitTxs []string, err error) {
@@ -74,7 +113,7 @@ func (b *txBuilder) BuildPoolTx(
// This is safe as the memory allocated for `craftCongestionTree` is freed // This is safe as the memory allocated for `craftCongestionTree` is freed
// only after `BuildPoolTx` returns. // only after `BuildPoolTx` returns.
treeFactoryFn, sharedOutputScript, sharedOutputAmount, err := craftCongestionTree( treeFactoryFn, sharedOutputScript, sharedOutputAmount, err := craftCongestionTree(
b.net.AssetID, aspPubkey, payments, minRelayFee, b.net.AssetID, aspPubkey, payments, minRelayFee, b.roundLifetime,
) )
if err != nil { if err != nil {
return return
@@ -109,6 +148,45 @@ func (b *txBuilder) BuildPoolTx(
return return
} }
func (b *txBuilder) GetLeafSweepClosure(
node tree.Node, userPubKey *secp256k1.PublicKey,
) (*psetv2.TapLeafScript, int64, error) {
if !node.Leaf {
return nil, 0, fmt.Errorf("node is not a leaf")
}
pset, err := psetv2.NewPsetFromBase64(node.Tx)
if err != nil {
return nil, 0, err
}
input := pset.Inputs[0]
sweepLeaf, lifetime, err := extractSweepLeaf(input)
if err != nil {
return nil, 0, err
}
// craft the vtxo taproot tree
vtxoScript, err := tree.VtxoScript(userPubKey)
if err != nil {
return nil, 0, err
}
vtxoTaprootTree := taproot.AssembleTaprootScriptTree(
*vtxoScript,
sweepLeaf.TapElementsLeaf,
)
proofIndex := vtxoTaprootTree.LeafProofIndex[sweepLeaf.TapHash()]
proof := vtxoTaprootTree.LeafMerkleProofs[proofIndex]
return &psetv2.TapLeafScript{
TapElementsLeaf: proof.TapElementsLeaf,
ControlBlock: proof.ToControlBlock(sweepLeaf.ControlBlock.InternalKey),
}, lifetime, nil
}
func (b *txBuilder) getLeafScriptAndTree( func (b *txBuilder) getLeafScriptAndTree(
userPubkey, aspPubkey *secp256k1.PublicKey, userPubkey, aspPubkey *secp256k1.PublicKey,
) ([]byte, *taproot.IndexedElementsTapScriptTree, error) { ) ([]byte, *taproot.IndexedElementsTapScriptTree, error) {
@@ -117,7 +195,7 @@ func (b *txBuilder) getLeafScriptAndTree(
return nil, nil, err return nil, nil, err
} }
sweepClosure, err := tree.SweepScript(aspPubkey, expirationTime) sweepClosure, err := tree.SweepScript(aspPubkey, uint(b.roundLifetime))
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@@ -382,3 +460,24 @@ func (b *txBuilder) createForfeitTxs(
} }
return forfeitTxs, nil return forfeitTxs, nil
} }
// given a congestion tree input, searches and returns the sweep leaf and its lifetime in seconds
func extractSweepLeaf(input psetv2.Input) (sweepLeaf *psetv2.TapLeafScript, lifetime int64, err error) {
for _, leaf := range input.TapLeafScript {
isSweep, _, seconds, err := tree.DecodeSweepScript(leaf.Script)
if err != nil {
return nil, 0, err
}
if isSweep {
lifetime = int64(seconds)
sweepLeaf = &leaf
break
}
}
if sweepLeaf == nil {
return nil, 0, fmt.Errorf("sweep leaf not found")
}
return sweepLeaf, lifetime, nil
}

View File

@@ -44,7 +44,7 @@ func TestMain(m *testing.M) {
} }
func TestBuildPoolTx(t *testing.T) { func TestBuildPoolTx(t *testing.T) {
builder := txbuilder.NewTxBuilder(wallet, network.Liquid) builder := txbuilder.NewTxBuilder(wallet, network.Liquid, roundLifetime)
fixtures, err := parsePoolTxFixtures() fixtures, err := parsePoolTxFixtures()
require.NoError(t, err) require.NoError(t, err)
@@ -79,7 +79,7 @@ func TestBuildPoolTx(t *testing.T) {
} }
func TestBuildForfeitTxs(t *testing.T) { func TestBuildForfeitTxs(t *testing.T) {
builder := txbuilder.NewTxBuilder(wallet, network.Liquid) builder := txbuilder.NewTxBuilder(wallet, network.Liquid, 1209344)
fixtures, err := parseForfeitTxsFixtures() fixtures, err := parseForfeitTxsFixtures()
require.NoError(t, err) require.NoError(t, err)

View File

@@ -97,6 +97,32 @@ func (m *mockedWallet) EstimateFees(ctx context.Context, pset string) (uint64, e
return res, args.Error(1) return res, args.Error(1)
} }
func (m *mockedWallet) IsTransactionPublished(ctx context.Context, txid string) (bool, int64, error) {
args := m.Called(ctx, txid)
var res bool
if a := args.Get(0); a != nil {
res = a.(bool)
}
var blocktime int64
if b := args.Get(1); b != nil {
blocktime = b.(int64)
}
return res, blocktime, args.Error(2)
}
func (m *mockedWallet) SignPsetWithKey(ctx context.Context, pset string, inputIndexes []int) (string, error) {
args := m.Called(ctx, pset, inputIndexes)
var res string
if a := args.Get(0); a != nil {
res = a.(string)
}
return res, args.Error(1)
}
func (m *mockedWallet) WatchScripts( func (m *mockedWallet) WatchScripts(
ctx context.Context, scripts []string, ctx context.Context, scripts []string,
) error { ) error {

View File

@@ -0,0 +1,135 @@
package txbuilder
import (
"context"
"fmt"
"github.com/ark-network/ark/common"
"github.com/ark-network/ark/common/tree"
"github.com/ark-network/ark/internal/core/ports"
"github.com/vulpemventures/go-elements/address"
"github.com/vulpemventures/go-elements/elementsutil"
"github.com/vulpemventures/go-elements/psetv2"
"github.com/vulpemventures/go-elements/taproot"
"github.com/vulpemventures/go-elements/transaction"
)
func sweepTransaction(
wallet ports.WalletService,
sweepInputs []ports.SweepInput,
lbtc string,
) (*psetv2.Pset, error) {
sweepPset, err := psetv2.New(nil, nil, nil)
if err != nil {
return nil, err
}
updater, err := psetv2.NewUpdater(sweepPset)
if err != nil {
return nil, err
}
amount := uint64(0)
for i, input := range sweepInputs {
leaf := input.SweepLeaf
isSweep, _, lifetime, err := tree.DecodeSweepScript(leaf.Script)
if err != nil {
return nil, err
}
if isSweep {
amount += input.Amount
if err := updater.AddInputs([]psetv2.InputArgs{input.InputArgs}); err != nil {
return nil, err
}
if err := updater.AddInTapLeafScript(i, leaf); err != nil {
return nil, err
}
assetHash, err := elementsutil.AssetHashToBytes(lbtc)
if err != nil {
return nil, err
}
value, err := elementsutil.ValueToBytes(input.Amount)
if err != nil {
return nil, err
}
root := leaf.ControlBlock.RootHash(leaf.Script)
taprootKey := taproot.ComputeTaprootOutputKey(leaf.ControlBlock.InternalKey, root)
script, err := taprootOutputScript(taprootKey)
if err != nil {
return nil, err
}
witnessUtxo := transaction.NewTxOutput(assetHash, value, script)
if err := updater.AddInWitnessUtxo(i, witnessUtxo); err != nil {
return nil, err
}
sequence, err := common.BIP68EncodeAsNumber(lifetime)
if err != nil {
return nil, err
}
updater.Pset.Inputs[i].Sequence = sequence
continue
}
return nil, fmt.Errorf("invalid sweep script")
}
ctx := context.Background()
sweepAddress, err := wallet.DeriveAddresses(ctx, 1)
if err != nil {
return nil, err
}
script, err := address.ToOutputScript(sweepAddress[0])
if err != nil {
return nil, err
}
if err := updater.AddOutputs([]psetv2.OutputArgs{
{
Asset: lbtc,
Amount: amount,
Script: script,
},
}); err != nil {
return nil, err
}
b64, err := sweepPset.ToBase64()
if err != nil {
return nil, err
}
fees, err := wallet.EstimateFees(ctx, b64)
if err != nil {
return nil, err
}
if amount < fees {
return nil, fmt.Errorf("insufficient funds (%d) to cover fees (%d) for sweep transaction", amount, fees)
}
updater.Pset.Outputs[0].Value = amount - fees
if err := updater.AddOutputs([]psetv2.OutputArgs{
{
Asset: lbtc,
Amount: fees,
},
}); err != nil {
return nil, err
}
return sweepPset, nil
}

View File

@@ -12,19 +12,16 @@ import (
"github.com/vulpemventures/go-elements/taproot" "github.com/vulpemventures/go-elements/taproot"
) )
const (
expirationTime = 60 * 60 * 24 * 14 // 14 days in seconds
)
type treeFactory func(outpoint psetv2.InputArgs) (tree.CongestionTree, error) type treeFactory func(outpoint psetv2.InputArgs) (tree.CongestionTree, error)
type node struct { type node struct {
sweepKey *secp256k1.PublicKey sweepKey *secp256k1.PublicKey
receivers []domain.Receiver receivers []domain.Receiver
left *node left *node
right *node right *node
asset string asset string
feeSats uint64 feeSats uint64
roundLifetime int64
_inputTaprootKey *secp256k1.PublicKey _inputTaprootKey *secp256k1.PublicKey
_inputTaprootTree *taproot.IndexedElementsTapScriptTree _inputTaprootTree *taproot.IndexedElementsTapScriptTree
@@ -133,7 +130,7 @@ func (n *node) getWitnessData() (
return n._inputTaprootKey, n._inputTaprootTree, nil return n._inputTaprootKey, n._inputTaprootTree, nil
} }
sweepClosure, err := tree.SweepScript(n.sweepKey, expirationTime) sweepClosure, err := tree.SweepScript(n.sweepKey, uint(n.roundLifetime))
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@@ -203,7 +200,7 @@ func (n *node) getVtxoWitnessData() (
return nil, nil, fmt.Errorf("cannot call vtxoWitness on a non-leaf node") return nil, nil, fmt.Errorf("cannot call vtxoWitness on a non-leaf node")
} }
sweepClosure, err := tree.SweepScript(n.sweepKey, expirationTime) sweepClosure, err := tree.SweepScript(n.sweepKey, uint(n.roundLifetime))
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@@ -360,14 +357,14 @@ func (n *node) createFinalCongestionTree() treeFactory {
func craftCongestionTree( func craftCongestionTree(
asset string, aspPublicKey *secp256k1.PublicKey, asset string, aspPublicKey *secp256k1.PublicKey,
payments []domain.Payment, feeSatsPerNode uint64, payments []domain.Payment, feeSatsPerNode uint64, roundLifetime int64,
) ( ) (
buildCongestionTree treeFactory, buildCongestionTree treeFactory,
sharedOutputScript []byte, sharedOutputAmount uint64, err error, sharedOutputScript []byte, sharedOutputAmount uint64, err error,
) { ) {
receivers := getOffchainReceivers(payments) receivers := getOffchainReceivers(payments)
root, err := createPartialCongestionTree( root, err := createPartialCongestionTree(
receivers, aspPublicKey, asset, feeSatsPerNode, receivers, aspPublicKey, asset, feeSatsPerNode, roundLifetime,
) )
if err != nil { if err != nil {
return return
@@ -393,6 +390,7 @@ func createPartialCongestionTree(
aspPublicKey *secp256k1.PublicKey, aspPublicKey *secp256k1.PublicKey,
asset string, asset string,
feeSatsPerNode uint64, feeSatsPerNode uint64,
roundLifetime int64,
) (root *node, err error) { ) (root *node, err error) {
if len(receivers) == 0 { if len(receivers) == 0 {
return nil, fmt.Errorf("no receivers provided") return nil, fmt.Errorf("no receivers provided")
@@ -401,10 +399,11 @@ func createPartialCongestionTree(
nodes := make([]*node, 0, len(receivers)) nodes := make([]*node, 0, len(receivers))
for _, r := range receivers { for _, r := range receivers {
leafNode := &node{ leafNode := &node{
sweepKey: aspPublicKey, sweepKey: aspPublicKey,
receivers: []domain.Receiver{r}, receivers: []domain.Receiver{r},
asset: asset, asset: asset,
feeSats: feeSatsPerNode, feeSats: feeSatsPerNode,
roundLifetime: roundLifetime,
} }
nodes = append(nodes, leafNode) nodes = append(nodes, leafNode)
} }
@@ -435,12 +434,13 @@ func createUpperLevel(nodes []*node) ([]*node, error) {
left := nodes[i] left := nodes[i]
right := nodes[i+1] right := nodes[i+1]
branchNode := &node{ branchNode := &node{
sweepKey: left.sweepKey, sweepKey: left.sweepKey,
receivers: append(left.receivers, right.receivers...), receivers: append(left.receivers, right.receivers...),
left: left, left: left,
right: right, right: right,
asset: left.asset, asset: left.asset,
feeSats: left.feeSats, feeSats: left.feeSats,
roundLifetime: left.roundLifetime,
} }
pairs = append(pairs, branchNode) pairs = append(pairs, branchNode)
} }

View File

@@ -16,6 +16,7 @@ import (
const ( const (
connectorAmount = 450 connectorAmount = 450
sevenDays = 7 * 24 * 60 * 60
) )
type txBuilder struct { type txBuilder struct {
@@ -29,6 +30,11 @@ func NewTxBuilder(
return &txBuilder{wallet, net} return &txBuilder{wallet, net}
} }
// BuildSweepTx implements ports.TxBuilder.
func (*txBuilder) BuildSweepTx(wallet ports.WalletService, inputs []ports.SweepInput) (signedSweepTx string, err error) {
panic("unimplemented")
}
// BuildForfeitTxs implements ports.TxBuilder. // BuildForfeitTxs implements ports.TxBuilder.
func (b *txBuilder) BuildForfeitTxs( func (b *txBuilder) BuildForfeitTxs(
aspPubkey *secp256k1.PublicKey, poolTx string, payments []domain.Payment, aspPubkey *secp256k1.PublicKey, poolTx string, payments []domain.Payment,
@@ -185,6 +191,12 @@ func (b *txBuilder) GetVtxoScript(userPubkey, _ *secp256k1.PublicKey) ([]byte, e
return address.ToOutputScript(addr) return address.ToOutputScript(addr)
} }
func (b *txBuilder) GetLeafSweepClosure(
node tree.Node, userPubKey *secp256k1.PublicKey,
) (*psetv2.TapLeafScript, int64, error) {
panic("unimplemented")
}
func connectorsToInputArgs(connectors []string) ([]psetv2.InputArgs, error) { func connectorsToInputArgs(connectors []string) ([]psetv2.InputArgs, error) {
inputs := make([]psetv2.InputArgs, 0, len(connectors)+1) inputs := make([]psetv2.InputArgs, 0, len(connectors)+1)
for i, psetb64 := range connectors { for i, psetb64 := range connectors {

View File

@@ -109,6 +109,14 @@ func (*mockedWalletService) EstimateFees(ctx context.Context, pset string) (uint
return 100, nil return 100, nil
} }
func (*mockedWalletService) SignPsetWithKey(ctx context.Context, pset string, inputIndex []int) (string, error) {
panic("unimplemented")
}
func (*mockedWalletService) IsTransactionPublished(ctx context.Context, txid string) (bool, int64, error) {
panic("unimplemented")
}
func TestBuildCongestionTree(t *testing.T) { func TestBuildCongestionTree(t *testing.T) {
builder := txbuilder.NewTxBuilder(&mockedWalletService{}, network.Liquid) builder := txbuilder.NewTxBuilder(&mockedWalletService{}, network.Liquid)

View File

@@ -52,6 +52,7 @@ func (s *service) Start() error {
return fmt.Errorf("failed to start app service: %s", err) return fmt.Errorf("failed to start app service: %s", err)
} }
log.Info("started app service") log.Info("started app service")
return nil return nil
} }

View File

@@ -18,17 +18,25 @@ func closerToModulo512(x uint) uint {
return x - (x % 512) return x - (x % 512)
} }
// BIP68Encode returns the encoded sequence locktime for the given number of seconds. func BIP68EncodeAsNumber(seconds uint) (uint32, error) {
func BIP68Encode(seconds uint) ([]byte, error) {
seconds = closerToModulo512(seconds) seconds = closerToModulo512(seconds)
if seconds > SECONDS_MAX { if seconds > SECONDS_MAX {
return nil, fmt.Errorf("seconds too large, max is %d", SECONDS_MAX) return 0, fmt.Errorf("seconds too large, max is %d", SECONDS_MAX)
} }
if seconds%SECONDS_MOD != 0 { if seconds%SECONDS_MOD != 0 {
return nil, fmt.Errorf("seconds must be a multiple of %d", SECONDS_MOD) return 0, fmt.Errorf("seconds must be a multiple of %d", SECONDS_MOD)
} }
asNumber := SEQUENCE_LOCKTIME_TYPE_FLAG | (seconds >> SEQUENCE_LOCKTIME_GRANULARITY) asNumber := SEQUENCE_LOCKTIME_TYPE_FLAG | (seconds >> SEQUENCE_LOCKTIME_GRANULARITY)
return uint32(asNumber), nil
}
// BIP68Encode returns the encoded sequence locktime for the given number of seconds.
func BIP68Encode(seconds uint) ([]byte, error) {
asNumber, err := BIP68EncodeAsNumber(seconds)
if err != nil {
return nil, err
}
hexString := fmt.Sprintf("%x", asNumber) hexString := fmt.Sprintf("%x", asNumber)
reversed, err := hex.DecodeString(hexString) reversed, err := hex.DecodeString(hexString)
if err != nil { if err != nil {

View File

@@ -1,6 +1,11 @@
package tree package tree
import "errors" import (
"errors"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/vulpemventures/go-elements/psetv2"
)
// Node is a struct embedding the transaction and the parent txid of a congestion tree node // Node is a struct embedding the transaction and the parent txid of a congestion tree node
type Node struct { type Node struct {
@@ -19,6 +24,19 @@ var (
// the first level of the matrix is the root of the tree // the first level of the matrix is the root of the tree
type CongestionTree [][]Node type CongestionTree [][]Node
// Root returns the root node of the congestion tree
func (c CongestionTree) Root() (Node, error) {
if len(c) <= 0 {
return Node{}, errors.New("empty congestion tree")
}
if len(c[0]) <= 0 {
return Node{}, errors.New("empty congestion tree")
}
return c[0][0], nil
}
// Leaves returns the leaves of the congestion tree (the vtxos txs) // Leaves returns the leaves of the congestion tree (the vtxos txs)
func (c CongestionTree) Leaves() []Node { func (c CongestionTree) Leaves() []Node {
leaves := c[len(c)-1] leaves := c[len(c)-1]
@@ -47,6 +65,7 @@ func (c CongestionTree) Children(nodeTxid string) []Node {
return children return children
} }
// NumberOfNodes returns the total number of pset in the congestion tree
func (c CongestionTree) NumberOfNodes() int { func (c CongestionTree) NumberOfNodes() int {
var count int var count int
for _, level := range c { for _, level := range c {
@@ -55,6 +74,7 @@ func (c CongestionTree) NumberOfNodes() int {
return count return count
} }
// Branch returns the branch of the given vtxo txid from root to leaf in the order of the congestion tree
func (c CongestionTree) Branch(vtxoTxid string) ([]Node, error) { func (c CongestionTree) Branch(vtxoTxid string) ([]Node, error) {
branch := make([]Node, 0) branch := make([]Node, 0)
@@ -85,6 +105,37 @@ func (c CongestionTree) Branch(vtxoTxid string) ([]Node, error) {
return branch, nil return branch, nil
} }
// FindLeaves returns all the leaves that are reachable from the given node output
func (c CongestionTree) FindLeaves(fromtxid string, vout uint32) ([]Node, error) {
allLeaves := c.Leaves()
foundLeaves := make([]Node, 0)
for _, leaf := range allLeaves {
branch, err := c.Branch(leaf.Txid)
if err != nil {
return nil, err
}
for _, node := range branch {
pset, err := psetv2.NewPsetFromBase64(node.Tx)
if err != nil {
return nil, err
}
input := pset.Inputs[0]
txid := chainhash.Hash(input.PreviousTxid).String()
index := input.PreviousTxIndex
if txid == fromtxid && index == vout {
foundLeaves = append(foundLeaves, leaf)
break
}
}
}
return foundLeaves, nil
}
func (n Node) findParent(tree CongestionTree) (Node, error) { func (n Node) findParent(tree CongestionTree) (Node, error) {
for _, level := range tree { for _, level := range tree {
for _, node := range level { for _, node := range level {

View File

@@ -106,12 +106,8 @@ func decodeWithOutputScript(script []byte, expectedIndex byte, isVerify bool) (v
return false, nil, 0, err return false, nil, 0, err
} }
inspectOutputValueIndex := bytes.IndexByte(script, OP_INSPECTOUTPUTVALUE) // verify the index of INSPECTVALUE
if inspectOutputValueIndex == -1 { if script[38] != expectedIndex {
return false, nil, 0, nil
}
if script[inspectOutputValueIndex-1] != expectedIndex {
return false, nil, 0, nil return false, nil, 0, nil
} }
@@ -128,12 +124,12 @@ func decodeWithOutputScript(script []byte, expectedIndex byte, isVerify bool) (v
} }
func decodeChecksigScript(script []byte) (valid bool, pubkey *secp256k1.PublicKey, err error) { func decodeChecksigScript(script []byte) (valid bool, pubkey *secp256k1.PublicKey, err error) {
checksigIndex := bytes.Index(script, []byte{txscript.OP_CHECKSIG}) data32Index := bytes.Index(script, []byte{txscript.OP_DATA_32})
if checksigIndex == -1 || checksigIndex == 0 { if data32Index == -1 {
return false, nil, nil return false, nil, nil
} }
key := script[1:checksigIndex] key := script[data32Index+1 : data32Index+33]
if len(key) != 32 { if len(key) != 32 {
return false, nil, nil return false, nil, nil
} }
@@ -155,13 +151,13 @@ func decodeChecksigScript(script []byte) (valid bool, pubkey *secp256k1.PublicKe
return true, pubkey, nil return true, pubkey, nil
} }
func decodeSweepScript(script []byte) (valid bool, aspPubKey *secp256k1.PublicKey, seconds uint, err error) { func DecodeSweepScript(script []byte) (valid bool, aspPubKey *secp256k1.PublicKey, seconds uint, err error) {
csvIndex := bytes.Index(script, []byte{txscript.OP_CHECKSEQUENCEVERIFY, txscript.OP_DROP}) csvIndex := bytes.Index(script, []byte{txscript.OP_CHECKSEQUENCEVERIFY, txscript.OP_DROP})
if csvIndex == -1 || csvIndex == 0 { if csvIndex == -1 || csvIndex == 0 {
return false, nil, 0, nil return false, nil, 0, nil
} }
sequence := script[:csvIndex] sequence := script[1:csvIndex]
seconds, err = common.BIP68Decode(sequence) seconds, err = common.BIP68Decode(sequence)
if err != nil { if err != nil {
@@ -174,6 +170,10 @@ func decodeSweepScript(script []byte) (valid bool, aspPubKey *secp256k1.PublicKe
return false, nil, 0, err return false, nil, 0, err
} }
if !valid {
return false, nil, 0, nil
}
rebuilt, err := csvChecksigScript(aspPubKey, seconds) rebuilt, err := csvChecksigScript(aspPubKey, seconds)
if err != nil { if err != nil {
return false, nil, 0, err return false, nil, 0, err
@@ -193,10 +193,10 @@ func checkSequenceVerifyScript(seconds uint) ([]byte, error) {
return nil, err return nil, err
} }
return append(sequence, []byte{ return txscript.NewScriptBuilder().AddData(sequence).AddOps([]byte{
txscript.OP_CHECKSEQUENCEVERIFY, txscript.OP_CHECKSEQUENCEVERIFY,
txscript.OP_DROP, txscript.OP_DROP,
}...), nil }).Script()
} }
// checkSequenceVerifyScript + checksig // checkSequenceVerifyScript + checksig

View File

@@ -229,7 +229,7 @@ func validateNodeTransaction(
return ErrInvalidTaprootScript return ErrInvalidTaprootScript
} }
isSweepLeaf, aspKey, seconds, err := decodeSweepScript(tapLeaf.Script) isSweepLeaf, aspKey, seconds, err := DecodeSweepScript(tapLeaf.Script)
if err != nil { if err != nil {
return fmt.Errorf("invalid sweep script: %w", err) return fmt.Errorf("invalid sweep script: %w", err)
} }

View File

@@ -31,7 +31,7 @@ require (
github.com/golang/protobuf v1.5.3 // indirect github.com/golang/protobuf v1.5.3 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.18.1 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.18.1 // indirect
github.com/russross/blackfriday/v2 v2.1.0 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect
github.com/vulpemventures/go-elements v0.5.2 github.com/vulpemventures/go-elements v0.5.3
github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect
golang.org/x/net v0.19.0 // indirect golang.org/x/net v0.19.0 // indirect
golang.org/x/sys v0.15.0 // indirect golang.org/x/sys v0.15.0 // indirect

View File

@@ -91,8 +91,8 @@ github.com/urfave/cli/v2 v2.26.0 h1:3f3AMg3HpThFNT4I++TKOejZO8yU55t3JnnSr4S4QEI=
github.com/urfave/cli/v2 v2.26.0/go.mod h1:8qnjx1vcq5s2/wpsqoZFndg2CE5tNFyrTvS6SinrnYQ= github.com/urfave/cli/v2 v2.26.0/go.mod h1:8qnjx1vcq5s2/wpsqoZFndg2CE5tNFyrTvS6SinrnYQ=
github.com/vulpemventures/fastsha256 v0.0.0-20160815193821-637e65642941 h1:CTcw80hz/Sw8hqlKX5ZYvBUF5gAHSHwdjXxRf/cjDcI= github.com/vulpemventures/fastsha256 v0.0.0-20160815193821-637e65642941 h1:CTcw80hz/Sw8hqlKX5ZYvBUF5gAHSHwdjXxRf/cjDcI=
github.com/vulpemventures/fastsha256 v0.0.0-20160815193821-637e65642941/go.mod h1:GXBJykxW2kUcktGdsgyay7uwwWvkljASfljNcT0mbh8= github.com/vulpemventures/fastsha256 v0.0.0-20160815193821-637e65642941/go.mod h1:GXBJykxW2kUcktGdsgyay7uwwWvkljASfljNcT0mbh8=
github.com/vulpemventures/go-elements v0.5.2 h1:vIDzVpRXG5PnlzHA8tCnr2Tn7raIV5cHy7bRyDrbuM4= github.com/vulpemventures/go-elements v0.5.3 h1:zaC/ynHFwCAzFSOMfzb6BcbD6FXASppSiGMycc95WVA=
github.com/vulpemventures/go-elements v0.5.2/go.mod h1:aBGuWXHaiAIUIcwqCdtEh2iQ3kJjKwHU9ywvhlcRSeU= github.com/vulpemventures/go-elements v0.5.3/go.mod h1:aBGuWXHaiAIUIcwqCdtEh2iQ3kJjKwHU9ywvhlcRSeU=
github.com/vulpemventures/go-secp256k1-zkp v1.1.6 h1:BmsrmXRLUibwa75Qkk8yELjpzCzlAjYFGLiLiOdq7Xo= github.com/vulpemventures/go-secp256k1-zkp v1.1.6 h1:BmsrmXRLUibwa75Qkk8yELjpzCzlAjYFGLiLiOdq7Xo=
github.com/vulpemventures/go-secp256k1-zkp v1.1.6/go.mod h1:zo7CpgkuPgoe7fAV+inyxsI9IhGmcoFgyD8nqZaPSOM= github.com/vulpemventures/go-secp256k1-zkp v1.1.6/go.mod h1:zo7CpgkuPgoe7fAV+inyxsI9IhGmcoFgyD8nqZaPSOM=
github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 h1:bAn7/zixMGCfxrRTfdpNzjtPYqr8smhKouy9mxVdGPU= github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 h1:bAn7/zixMGCfxrRTfdpNzjtPYqr8smhKouy9mxVdGPU=

View File

@@ -218,16 +218,6 @@ func unilateralRedeem(ctx *cli.Context, addr string) error {
transactionsMap := make(map[string]struct{}, 0) transactionsMap := make(map[string]struct{}, 0)
transactions := make([]string, 0) transactions := make([]string, 0)
aspPublicKey, err := getServiceProviderPublicKey()
if err != nil {
return err
}
sweepLeaf, err := tree.SweepScript(aspPublicKey, 1209344)
if err != nil {
return err
}
for _, vtxo := range vtxos { for _, vtxo := range vtxos {
if _, ok := congestionTrees[vtxo.poolTxid]; !ok { if _, ok := congestionTrees[vtxo.poolTxid]; !ok {
round, err := client.GetRound(ctx.Context, &arkv1.GetRoundRequest{ round, err := client.GetRound(ctx.Context, &arkv1.GetRoundRequest{
@@ -246,7 +236,7 @@ func unilateralRedeem(ctx *cli.Context, addr string) error {
congestionTrees[vtxo.poolTxid] = congestionTree congestionTrees[vtxo.poolTxid] = congestionTree
} }
redeemBranch, err := newRedeemBranch(ctx, congestionTrees[vtxo.poolTxid], vtxo, sweepLeaf) redeemBranch, err := newRedeemBranch(ctx, congestionTrees[vtxo.poolTxid], vtxo)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -1,7 +1,6 @@
package main package main
import ( import (
"bytes"
"fmt" "fmt"
"github.com/ark-network/ark/common/tree" "github.com/ark-network/ark/common/tree"
@@ -24,16 +23,40 @@ type RedeemBranch interface {
type redeemBranch struct { type redeemBranch struct {
vtxo *vtxo vtxo *vtxo
branch []*psetv2.Pset branch []*psetv2.Pset
sweepTapLeaf *taproot.TapElementsLeaf
internalKey *secp256k1.PublicKey internalKey *secp256k1.PublicKey
sweepClosure *taproot.TapElementsLeaf
} }
func newRedeemBranch(ctx *cli.Context, congestionTree tree.CongestionTree, vtxo vtxo, sweepLeaf *taproot.TapElementsLeaf) (RedeemBranch, error) { func newRedeemBranch(ctx *cli.Context, congestionTree tree.CongestionTree, vtxo vtxo) (RedeemBranch, error) {
nodes, err := congestionTree.Branch(vtxo.txid) nodes, err := congestionTree.Branch(vtxo.txid)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// find the sweep closure
tx, err := psetv2.NewPsetFromBase64(nodes[0].Tx)
if err != nil {
return nil, err
}
var sweepClosure *taproot.TapElementsLeaf
for _, tapLeaf := range tx.Inputs[0].TapLeafScript {
isSweep, _, _, err := tree.DecodeSweepScript(tapLeaf.Script)
if err != nil {
continue
}
if isSweep {
sweepClosure = &tapLeaf.TapElementsLeaf
break
}
}
if sweepClosure == nil {
return nil, fmt.Errorf("sweep closure not found")
}
branch := make([]*psetv2.Pset, 0, len(nodes)) branch := make([]*psetv2.Pset, 0, len(nodes))
for _, node := range nodes { for _, node := range nodes {
pset, err := psetv2.NewPsetFromBase64(node.Tx) pset, err := psetv2.NewPsetFromBase64(node.Tx)
@@ -52,8 +75,8 @@ func newRedeemBranch(ctx *cli.Context, congestionTree tree.CongestionTree, vtxo
return &redeemBranch{ return &redeemBranch{
vtxo: &vtxo, vtxo: &vtxo,
branch: branch, branch: branch,
sweepTapLeaf: sweepLeaf,
internalKey: internalKey, internalKey: internalKey,
sweepClosure: sweepClosure,
}, nil }, nil
} }
@@ -96,10 +119,13 @@ func (r *redeemBranch) RedeemPath() ([]string, error) {
return nil, fmt.Errorf("tap leaf script not found on input #%d", i) return nil, fmt.Errorf("tap leaf script not found on input #%d", i)
} }
sweepTapLeafScript := r.sweepTapLeaf.Script
for _, leaf := range input.TapLeafScript { for _, leaf := range input.TapLeafScript {
if bytes.Equal(leaf.Script, sweepTapLeafScript) { isSweep, _, _, err := tree.DecodeSweepScript(leaf.Script)
if err != nil {
return nil, err
}
if isSweep {
continue continue
} }
@@ -159,7 +185,7 @@ func (r *redeemBranch) AddVtxoInput(updater *psetv2.Updater) error {
vtxoTaprootTree := taproot.AssembleTaprootScriptTree( vtxoTaprootTree := taproot.AssembleTaprootScriptTree(
*checksigLeaf, *checksigLeaf,
*r.sweepTapLeaf, *r.sweepClosure,
) )
proofIndex := vtxoTaprootTree.LeafProofIndex[checksigLeaf.TapHash()] proofIndex := vtxoTaprootTree.LeafProofIndex[checksigLeaf.TapHash()]