Rename folders (#97)

* Rename arkd folder & drop cli

* Rename ark cli folder & update docs

* Update readme

* Fix

* scripts: add build-all

* Add target to build cli for all platforms

* Update build scripts

---------

Co-authored-by: tiero <3596602+tiero@users.noreply.github.com>
This commit is contained in:
Pietralberto Mazza
2024-02-09 19:32:58 +01:00
committed by GitHub
parent 0d8c7bffb2
commit dc00d60585
119 changed files with 154 additions and 449 deletions

View File

@@ -0,0 +1,249 @@
package appconfig
import (
"fmt"
"strings"
"github.com/ark-network/ark/common"
"github.com/ark-network/ark/internal/core/application"
"github.com/ark-network/ark/internal/core/ports"
"github.com/ark-network/ark/internal/infrastructure/db"
oceanwallet "github.com/ark-network/ark/internal/infrastructure/ocean-wallet"
scheduler "github.com/ark-network/ark/internal/infrastructure/scheduler/gocron"
txbuilder "github.com/ark-network/ark/internal/infrastructure/tx-builder/covenant"
txbuilderdummy "github.com/ark-network/ark/internal/infrastructure/tx-builder/dummy"
log "github.com/sirupsen/logrus"
"github.com/vulpemventures/go-elements/network"
)
var (
supportedDbs = supportedType{
"badger": {},
}
supportedSchedulers = supportedType{
"gocron": {},
}
supportedTxBuilders = supportedType{
"dummy": {},
"covenant": {},
}
supportedScanners = supportedType{
"ocean": {},
}
)
type Config struct {
DbType string
DbDir string
RoundInterval int64
Network common.Network
SchedulerType string
TxBuilderType string
BlockchainScannerType string
WalletAddr string
MinRelayFee uint64
RoundLifetime int64
repo ports.RepoManager
svc application.Service
wallet ports.WalletService
txBuilder ports.TxBuilder
scanner ports.BlockchainScanner
scheduler ports.SchedulerService
}
func (c *Config) Validate() error {
if !supportedDbs.supports(c.DbType) {
return fmt.Errorf("db type not supported, please select one of: %s", supportedDbs)
}
if !supportedSchedulers.supports(c.SchedulerType) {
return fmt.Errorf("scheduler type not supported, please select one of: %s", supportedSchedulers)
}
if !supportedTxBuilders.supports(c.TxBuilderType) {
return fmt.Errorf("tx builder type not supported, please select one of: %s", supportedTxBuilders)
}
if !supportedScanners.supports(c.BlockchainScannerType) {
return fmt.Errorf("blockchain scanner type not supported, please select one of: %s", supportedScanners)
}
if c.RoundInterval < 5 {
return fmt.Errorf("invalid round interval, must be at least 5 seconds")
}
if c.Network.Name != "liquid" && c.Network.Name != "testnet" {
return fmt.Errorf("invalid network, must be either liquid or testnet")
}
if len(c.WalletAddr) <= 0 {
return fmt.Errorf("missing onchain wallet address")
}
if c.MinRelayFee < 30 {
return fmt.Errorf("invalid min relay fee, must be at least 30 sats")
}
if err := c.repoManager(); err != nil {
return err
}
if err := c.walletService(); err != nil {
return fmt.Errorf("failed to connect to wallet: %s", err)
}
if err := c.txBuilderService(); err != nil {
return err
}
if err := c.scannerService(); err != nil {
return err
}
if err := c.schedulerService(); err != nil {
return err
}
if err := c.appService(); err != nil {
return err
}
// round life time must be a multiple of 512
if c.RoundLifetime <= 0 || c.RoundLifetime%512 != 0 {
return fmt.Errorf("invalid round lifetime, must be greater than 0 and a multiple of 512")
}
seq, err := common.BIP68Encode(uint(c.RoundLifetime))
if err != nil {
return fmt.Errorf("invalid round lifetime, %s", err)
}
seconds, err := common.BIP68Decode(seq)
if err != nil {
return fmt.Errorf("invalid round lifetime, %s", err)
}
if seconds != uint(c.RoundLifetime) {
return fmt.Errorf("invalid round lifetime, must be a multiple of 512")
}
return nil
}
func (c *Config) AppService() application.Service {
return c.svc
}
func (c *Config) repoManager() error {
var svc ports.RepoManager
var err error
switch c.DbType {
case "badger":
logger := log.New()
svc, err = db.NewService(db.ServiceConfig{
EventStoreType: c.DbType,
RoundStoreType: c.DbType,
VtxoStoreType: c.DbType,
EventStoreConfig: []interface{}{c.DbDir, logger},
RoundStoreConfig: []interface{}{c.DbDir, logger},
VtxoStoreConfig: []interface{}{c.DbDir, logger},
})
default:
return fmt.Errorf("unknown db type")
}
if err != nil {
return err
}
c.repo = svc
return nil
}
func (c *Config) walletService() error {
svc, err := oceanwallet.NewService(c.WalletAddr)
if err != nil {
return err
}
c.wallet = svc
return nil
}
func (c *Config) txBuilderService() error {
var svc ports.TxBuilder
var err error
net := c.mainChain()
switch c.TxBuilderType {
case "dummy":
svc = txbuilderdummy.NewTxBuilder(c.wallet, net)
case "covenant":
svc = txbuilder.NewTxBuilder(c.wallet, net, c.RoundLifetime)
default:
err = fmt.Errorf("unknown tx builder type")
}
if err != nil {
return err
}
c.txBuilder = svc
return nil
}
func (c *Config) scannerService() error {
var svc ports.BlockchainScanner
var err error
switch c.BlockchainScannerType {
case "ocean":
svc = c.wallet
default:
err = fmt.Errorf("unknown blockchain scanner type")
}
if err != nil {
return err
}
c.scanner = svc
return nil
}
func (c *Config) schedulerService() error {
var svc ports.SchedulerService
var err error
switch c.SchedulerType {
case "gocron":
svc = scheduler.NewScheduler()
default:
err = fmt.Errorf("unknown scheduler type")
}
if err != nil {
return err
}
c.scheduler = svc
return nil
}
func (c *Config) appService() error {
net := c.mainChain()
svc, err := application.NewService(
c.Network, net, c.RoundInterval, c.RoundLifetime, c.MinRelayFee,
c.wallet, c.repo, c.txBuilder, c.scanner, c.scheduler,
)
if err != nil {
return err
}
c.svc = svc
return nil
}
func (c *Config) mainChain() network.Network {
net := network.Liquid
if c.Network.Name != "mainnet" {
net = network.Testnet
}
return net
}
type supportedType map[string]struct{}
func (t supportedType) String() string {
types := make([]string, 0, len(t))
for tt := range t {
types = append(types, tt)
}
return strings.Join(types, " | ")
}
func (t supportedType) supports(typeStr string) bool {
_, ok := t[typeStr]
return ok
}

View File

@@ -0,0 +1,122 @@
package config
import (
"fmt"
"os"
"path/filepath"
"strings"
common "github.com/ark-network/ark/common"
"github.com/spf13/viper"
)
type Config struct {
WalletAddr string
RoundInterval int64
Port uint32
DbType string
DbDir string
SchedulerType string
TxBuilderType string
BlockchainScannerType string
NoTLS bool
Network common.Network
LogLevel int
MinRelayFee uint64
RoundLifetime int64
}
var (
Datadir = "DATADIR"
WalletAddr = "WALLET_ADDR"
RoundInterval = "ROUND_INTERVAL"
Port = "PORT"
DbType = "DB_TYPE"
SchedulerType = "SCHEDULER_TYPE"
TxBuilderType = "TX_BUILDER_TYPE"
BlockchainScannerType = "BC_SCANNER_TYPE"
Insecure = "INSECURE"
LogLevel = "LOG_LEVEL"
Network = "NETWORK"
MinRelayFee = "MIN_RELAY_FEE"
RoundLifetime = "ROUND_LIFETIME"
defaultDatadir = common.AppDataDir("arkd", false)
defaultRoundInterval = 60
defaultPort = 6000
defaultDbType = "badger"
defaultSchedulerType = "gocron"
defaultTxBuilderType = "covenant"
defaultBlockchainScannerType = "ocean"
defaultInsecure = true
defaultNetwork = "testnet"
defaultLogLevel = 5
defaultMinRelayFee = 30
defaultRoundLifetime = 512
)
func LoadConfig() (*Config, error) {
viper.SetEnvPrefix("ARK")
viper.AutomaticEnv()
viper.SetDefault(Datadir, defaultDatadir)
viper.SetDefault(RoundInterval, defaultRoundInterval)
viper.SetDefault(Port, defaultPort)
viper.SetDefault(DbType, defaultDbType)
viper.SetDefault(SchedulerType, defaultSchedulerType)
viper.SetDefault(TxBuilderType, defaultTxBuilderType)
viper.SetDefault(BlockchainScannerType, defaultBlockchainScannerType)
viper.SetDefault(Insecure, defaultInsecure)
viper.SetDefault(LogLevel, defaultLogLevel)
viper.SetDefault(Network, defaultNetwork)
viper.SetDefault(RoundLifetime, defaultRoundLifetime)
viper.SetDefault(MinRelayFee, defaultMinRelayFee)
net, err := getNetwork()
if err != nil {
return nil, err
}
if err := initDatadir(); err != nil {
return nil, fmt.Errorf("error while creating datadir: %s", err)
}
return &Config{
WalletAddr: viper.GetString(WalletAddr),
RoundInterval: viper.GetInt64(RoundInterval),
Port: viper.GetUint32(Port),
DbType: viper.GetString(DbType),
SchedulerType: viper.GetString(SchedulerType),
TxBuilderType: viper.GetString(TxBuilderType),
BlockchainScannerType: viper.GetString(BlockchainScannerType),
NoTLS: viper.GetBool(Insecure),
DbDir: filepath.Join(viper.GetString(Datadir), "db"),
LogLevel: viper.GetInt(LogLevel),
Network: net,
MinRelayFee: viper.GetUint64(MinRelayFee),
RoundLifetime: viper.GetInt64(RoundLifetime),
}, nil
}
func initDatadir() error {
datadir := viper.GetString(Datadir)
return makeDirectoryIfNotExists(datadir)
}
func makeDirectoryIfNotExists(path string) error {
if _, err := os.Stat(path); os.IsNotExist(err) {
return os.MkdirAll(path, os.ModeDir|0755)
}
return nil
}
func getNetwork() (common.Network, error) {
switch strings.ToLower(viper.GetString(Network)) {
case "mainnet":
return common.MainNet, nil
case "testnet":
return common.TestNet, nil
default:
return common.Network{}, fmt.Errorf("unknown network")
}
}

View File

@@ -0,0 +1,601 @@
package application
import (
"bytes"
"context"
"encoding/hex"
"fmt"
"time"
"github.com/ark-network/ark/common"
"github.com/ark-network/ark/internal/core/domain"
"github.com/ark-network/ark/internal/core/ports"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/decred/dcrd/dcrec/secp256k1/v4"
log "github.com/sirupsen/logrus"
"github.com/vulpemventures/go-elements/network"
"github.com/vulpemventures/go-elements/psetv2"
)
var (
paymentsThreshold = int64(128)
dustAmount = uint64(450)
faucetVtxo = domain.VtxoKey{
Txid: "0000000000000000000000000000000000000000000000000000000000000000",
VOut: 0,
}
)
type Service interface {
Start() error
Stop()
SpendVtxos(ctx context.Context, inputs []domain.VtxoKey) (string, error)
ClaimVtxos(ctx context.Context, creds string, receivers []domain.Receiver) error
SignVtxos(ctx context.Context, forfeitTxs []string) error
FaucetVtxos(ctx context.Context, pubkey *secp256k1.PublicKey) error
GetRoundByTxid(ctx context.Context, poolTxid string) (*domain.Round, error)
GetEventsChannel(ctx context.Context) <-chan domain.RoundEvent
UpdatePaymentStatus(ctx context.Context, id string) error
ListVtxos(ctx context.Context, pubkey *secp256k1.PublicKey) ([]domain.Vtxo, error)
GetPubkey(ctx context.Context) (string, error)
}
type service struct {
network common.Network
onchainNework network.Network
pubkey *secp256k1.PublicKey
roundLifetime int64
roundInterval int64
minRelayFee uint64
wallet ports.WalletService
repoManager ports.RepoManager
builder ports.TxBuilder
scanner ports.BlockchainScanner
sweeper *sweeper
paymentRequests *paymentsMap
forfeitTxs *forfeitTxsMap
eventsCh chan domain.RoundEvent
}
func NewService(
network common.Network, onchainNetwork network.Network,
roundInterval, roundLifetime int64, minRelayFee uint64,
walletSvc ports.WalletService, repoManager ports.RepoManager,
builder ports.TxBuilder, scanner ports.BlockchainScanner,
scheduler ports.SchedulerService,
) (Service, error) {
eventsCh := make(chan domain.RoundEvent)
paymentRequests := newPaymentsMap(nil)
genesisHash, _ := chainhash.NewHashFromStr(onchainNetwork.GenesisBlockHash)
forfeitTxs := newForfeitTxsMap(genesisHash)
pubkey, err := walletSvc.GetPubkey(context.Background())
if err != nil {
return nil, fmt.Errorf("failed to fetch pubkey: %s", err)
}
sweeper := newSweeper(walletSvc, repoManager, builder, scheduler)
svc := &service{
network, onchainNetwork, pubkey,
roundLifetime, roundInterval, minRelayFee,
walletSvc, repoManager, builder, scanner, sweeper,
paymentRequests, forfeitTxs, eventsCh,
}
repoManager.RegisterEventsHandler(
func(round *domain.Round) {
svc.updateProjectionStore(round)
svc.propagateEvents(round)
},
)
if err := svc.restoreWatchingVtxos(); err != nil {
return nil, fmt.Errorf("failed to restore watching vtxos: %s", err)
}
go svc.listenToRedemptions()
return svc, nil
}
func (s *service) Start() error {
log.Debug("starting sweeper service")
if err := s.sweeper.start(); err != nil {
return err
}
log.Debug("starting app service")
go s.start()
return nil
}
func (s *service) Stop() {
s.sweeper.stop()
// nolint
vtxos, _ := s.repoManager.Vtxos().GetSpendableVtxos(
context.Background(), "",
)
if len(vtxos) > 0 {
s.stopWatchingVtxos(vtxos)
}
s.wallet.Close()
log.Debug("closed connection to wallet")
s.repoManager.Close()
log.Debug("closed connection to db")
}
func (s *service) SpendVtxos(ctx context.Context, inputs []domain.VtxoKey) (string, error) {
vtxos, err := s.repoManager.Vtxos().GetVtxos(ctx, inputs)
if err != nil {
return "", err
}
for _, v := range vtxos {
if v.Spent {
return "", fmt.Errorf("input %s:%d already spent", v.Txid, v.VOut)
}
}
payment, err := domain.NewPayment(vtxos)
if err != nil {
return "", err
}
if err := s.paymentRequests.push(*payment); err != nil {
return "", err
}
return payment.Id, nil
}
func (s *service) ClaimVtxos(ctx context.Context, creds string, receivers []domain.Receiver) error {
// Check credentials
payment, ok := s.paymentRequests.view(creds)
if !ok {
return fmt.Errorf("invalid credentials")
}
if err := payment.AddReceivers(receivers); err != nil {
return err
}
return s.paymentRequests.update(*payment)
}
func (s *service) UpdatePaymentStatus(_ context.Context, id string) error {
return s.paymentRequests.updatePingTimestamp(id)
}
func (s *service) FaucetVtxos(ctx context.Context, userPubkey *secp256k1.PublicKey) error {
pubkey := hex.EncodeToString(userPubkey.SerializeCompressed())
payment, err := domain.NewPayment([]domain.Vtxo{
{
VtxoKey: faucetVtxo,
Receiver: domain.Receiver{
Pubkey: pubkey,
Amount: 10000,
},
},
})
if err != nil {
return err
}
if err := payment.AddReceivers([]domain.Receiver{
{Pubkey: pubkey, Amount: 1000},
{Pubkey: pubkey, Amount: 1000},
{Pubkey: pubkey, Amount: 1000},
{Pubkey: pubkey, Amount: 1000},
{Pubkey: pubkey, Amount: 1000},
{Pubkey: pubkey, Amount: 1000},
{Pubkey: pubkey, Amount: 1000},
{Pubkey: pubkey, Amount: 1000},
{Pubkey: pubkey, Amount: 1000},
{Pubkey: pubkey, Amount: 1000},
}); err != nil {
return err
}
if err := s.paymentRequests.push(*payment); err != nil {
return err
}
return s.paymentRequests.updatePingTimestamp(payment.Id)
}
func (s *service) SignVtxos(ctx context.Context, forfeitTxs []string) error {
return s.forfeitTxs.sign(forfeitTxs)
}
func (s *service) ListVtxos(ctx context.Context, pubkey *secp256k1.PublicKey) ([]domain.Vtxo, error) {
pk := hex.EncodeToString(pubkey.SerializeCompressed())
return s.repoManager.Vtxos().GetSpendableVtxos(ctx, pk)
}
func (s *service) GetEventsChannel(ctx context.Context) <-chan domain.RoundEvent {
return s.eventsCh
}
func (s *service) GetRoundByTxid(ctx context.Context, poolTxid string) (*domain.Round, error) {
return s.repoManager.Rounds().GetRoundWithTxid(ctx, poolTxid)
}
func (s *service) GetPubkey(ctx context.Context) (string, error) {
pubkey, err := common.EncodePubKey(s.network.PubKey, s.pubkey)
if err != nil {
return "", err
}
return pubkey, nil
}
func (s *service) start() {
s.startRound()
}
func (s *service) startRound() {
round := domain.NewRound(dustAmount)
changes, _ := round.StartRegistration()
if err := s.repoManager.Events().Save(
context.Background(), round.Id, changes...,
); err != nil {
log.WithError(err).Warn("failed to store new round events")
return
}
defer func() {
time.Sleep(time.Duration(s.roundInterval/2) * time.Second)
s.startFinalization()
}()
log.Debugf("started registration stage for new round: %s", round.Id)
}
func (s *service) startFinalization() {
ctx := context.Background()
round, err := s.repoManager.Rounds().GetCurrentRound(ctx)
if err != nil {
log.WithError(err).Warn("failed to retrieve current round")
return
}
var changes []domain.RoundEvent
defer func() {
if len(changes) > 0 {
if err := s.repoManager.Events().Save(ctx, round.Id, changes...); err != nil {
log.WithError(err).Warn("failed to store new round events")
}
}
if round.IsFailed() {
s.startRound()
return
}
time.Sleep(time.Duration((s.roundInterval/2)-1) * time.Second)
s.finalizeRound()
}()
if round.IsFailed() {
return
}
// TODO: understand how many payments must be popped from the queue and actually registered for the round
num := s.paymentRequests.len()
if num == 0 {
err := fmt.Errorf("no payments registered")
changes = round.Fail(fmt.Errorf("round aborted: %s", err))
log.WithError(err).Debugf("round %s aborted", round.Id)
return
}
if num > paymentsThreshold {
num = paymentsThreshold
}
payments := s.paymentRequests.pop(num)
changes, err = round.RegisterPayments(payments)
if err != nil {
changes = round.Fail(fmt.Errorf("failed to register payments: %s", err))
log.WithError(err).Warn("failed to register payments")
return
}
unsignedPoolTx, tree, err := s.builder.BuildPoolTx(s.pubkey, payments, s.minRelayFee)
if err != nil {
changes = round.Fail(fmt.Errorf("failed to create pool tx: %s", err))
log.WithError(err).Warn("failed to create pool tx")
return
}
log.Debugf("pool tx created for round %s", round.Id)
connectors, forfeitTxs, err := s.builder.BuildForfeitTxs(s.pubkey, unsignedPoolTx, payments)
if err != nil {
changes = round.Fail(fmt.Errorf("failed to create connectors and forfeit txs: %s", err))
log.WithError(err).Warn("failed to create connectors and forfeit txs")
return
}
log.Debugf("forfeit transactions created for round %s", round.Id)
events, err := round.StartFinalization(connectors, tree, unsignedPoolTx)
if err != nil {
changes = round.Fail(fmt.Errorf("failed to start finalization: %s", err))
log.WithError(err).Warn("failed to start finalization")
return
}
changes = append(changes, events...)
s.forfeitTxs.push(forfeitTxs)
log.Debugf("started finalization stage for round: %s", round.Id)
}
func (s *service) finalizeRound() {
defer s.startRound()
ctx := context.Background()
round, err := s.repoManager.Rounds().GetCurrentRound(ctx)
if err != nil {
log.WithError(err).Warn("failed to retrieve current round")
return
}
if round.IsFailed() {
return
}
var changes []domain.RoundEvent
defer func() {
if err := s.repoManager.Events().Save(ctx, round.Id, changes...); err != nil {
log.WithError(err).Warn("failed to store new round events")
return
}
}()
forfeitTxs, leftUnsigned := s.forfeitTxs.pop()
if len(leftUnsigned) > 0 {
err := fmt.Errorf("%d forfeit txs left to sign", len(leftUnsigned))
changes = round.Fail(fmt.Errorf("failed to finalize round: %s", err))
log.WithError(err).Warn("failed to finalize round")
return
}
log.Debugf("signing round transaction %s\n", round.Id)
signedPoolTx, err := s.wallet.SignPset(ctx, round.UnsignedTx, true)
if err != nil {
changes = round.Fail(fmt.Errorf("failed to sign round tx: %s", err))
log.WithError(err).Warn("failed to sign round tx")
return
}
txid, err := s.wallet.BroadcastTransaction(ctx, signedPoolTx)
if err != nil {
changes = round.Fail(fmt.Errorf("failed to broadcast pool tx: %s", err))
log.WithError(err).Warn("failed to broadcast pool tx")
return
}
now := time.Now().Unix()
expirationTimestamp := now + s.roundLifetime + 30 // add 30 secs to be sure that the tx is confirmed
if err := s.sweeper.schedule(expirationTimestamp, txid, round.CongestionTree); err != nil {
changes = round.Fail(fmt.Errorf("failed to schedule sweep tx: %s", err))
log.WithError(err).Warn("failed to schedule sweep tx")
return
}
changes, _ = round.EndFinalization(forfeitTxs, txid)
log.Debugf("finalized round %s with pool tx %s", round.Id, round.Txid)
}
func (s *service) listenToRedemptions() {
ctx := context.Background()
chVtxos := s.scanner.GetNotificationChannel(ctx)
for vtxoKeys := range chVtxos {
if len(vtxoKeys) > 0 {
for {
// TODO: make sure that the vtxos haven't been already spent, otherwise
// broadcast the corresponding forfeit tx and connector to prevent
// getting cheated.
vtxos, err := s.repoManager.Vtxos().RedeemVtxos(ctx, vtxoKeys)
if err != nil {
log.WithError(err).Warn("failed to redeem vtxos, retrying...")
time.Sleep(100 * time.Millisecond)
continue
}
if len(vtxos) > 0 {
log.Debugf("redeemed %d vtxos", len(vtxos))
}
break
}
}
}
}
func (s *service) updateProjectionStore(round *domain.Round) {
ctx := context.Background()
lastChange := round.Events()[len(round.Events())-1]
// Update the vtxo set only after a round is finalized.
if _, ok := lastChange.(domain.RoundFinalized); ok {
repo := s.repoManager.Vtxos()
spentVtxos := getSpentVtxos(round.Payments)
if len(spentVtxos) > 0 {
for {
if err := repo.SpendVtxos(ctx, spentVtxos); err != nil {
log.WithError(err).Warn("failed to add new vtxos, retrying soon")
time.Sleep(100 * time.Millisecond)
continue
}
log.Debugf("spent %d vtxos", len(spentVtxos))
break
}
}
newVtxos := s.getNewVtxos(round)
for {
if err := repo.AddVtxos(ctx, newVtxos); err != nil {
log.WithError(err).Warn("failed to add new vtxos, retrying soon")
time.Sleep(100 * time.Millisecond)
continue
}
log.Debugf("added %d new vtxos", len(newVtxos))
break
}
go func() {
for {
if err := s.startWatchingVtxos(newVtxos); err != nil {
log.WithError(err).Warn(
"failed to start watching vtxos, retrying in a moment...",
)
continue
}
log.Debugf("started watching %d vtxos", len(newVtxos))
return
}
}()
}
// Always update the status of the round.
for {
if err := s.repoManager.Rounds().AddOrUpdateRound(ctx, *round); err != nil {
time.Sleep(100 * time.Millisecond)
continue
}
break
}
}
func (s *service) propagateEvents(round *domain.Round) {
lastEvent := round.Events()[len(round.Events())-1]
switch e := lastEvent.(type) {
case domain.RoundFinalizationStarted:
forfeitTxs := s.forfeitTxs.view()
s.eventsCh <- domain.RoundFinalizationStarted{
Id: e.Id,
CongestionTree: e.CongestionTree,
Connectors: e.Connectors,
PoolTx: e.PoolTx,
UnsignedForfeitTxs: forfeitTxs,
}
case domain.RoundFinalized, domain.RoundFailed:
s.eventsCh <- e
}
}
func (s *service) getNewVtxos(round *domain.Round) []domain.Vtxo {
leaves := round.CongestionTree.Leaves()
vtxos := make([]domain.Vtxo, 0)
for _, node := range leaves {
tx, _ := psetv2.NewPsetFromBase64(node.Tx)
for i, out := range tx.Outputs {
for _, p := range round.Payments {
var pubkey string
found := false
for _, r := range p.Receivers {
if r.IsOnchain() {
continue
}
buf, _ := hex.DecodeString(r.Pubkey)
pk, _ := secp256k1.ParsePubKey(buf)
script, _ := s.builder.GetVtxoScript(pk, s.pubkey)
if bytes.Equal(script, out.Script) {
found = true
pubkey = r.Pubkey
break
}
}
if found {
vtxos = append(vtxos, domain.Vtxo{
VtxoKey: domain.VtxoKey{Txid: node.Txid, VOut: uint32(i)},
Receiver: domain.Receiver{Pubkey: pubkey, Amount: out.Value},
PoolTx: round.Txid,
})
break
}
}
}
}
return vtxos
}
func (s *service) startWatchingVtxos(vtxos []domain.Vtxo) error {
scripts, err := s.extractVtxosScripts(vtxos)
if err != nil {
return err
}
return s.scanner.WatchScripts(context.Background(), scripts)
}
func (s *service) stopWatchingVtxos(vtxos []domain.Vtxo) {
scripts, err := s.extractVtxosScripts(vtxos)
if err != nil {
log.WithError(err).Warn("failed to extract scripts from vtxos")
return
}
for {
if err := s.scanner.UnwatchScripts(context.Background(), scripts); err != nil {
log.WithError(err).Warn("failed to stop watching vtxos, retrying in a moment...")
time.Sleep(100 * time.Millisecond)
continue
}
log.Debugf("stopped watching %d vtxos", len(vtxos))
break
}
}
func (s *service) restoreWatchingVtxos() error {
vtxos, err := s.repoManager.Vtxos().GetSpendableVtxos(
context.Background(), "",
)
if err != nil {
return err
}
if len(vtxos) <= 0 {
return nil
}
if err := s.startWatchingVtxos(vtxos); err != nil {
return err
}
log.Debugf("restored watching %d vtxos", len(vtxos))
return nil
}
func (s *service) extractVtxosScripts(vtxos []domain.Vtxo) ([]string, error) {
indexedScripts := make(map[string]struct{})
for _, vtxo := range vtxos {
buf, err := hex.DecodeString(vtxo.Pubkey)
if err != nil {
return nil, err
}
userPubkey, err := secp256k1.ParsePubKey(buf)
if err != nil {
return nil, err
}
script, err := s.builder.GetVtxoScript(userPubkey, s.pubkey)
if err != nil {
return nil, err
}
indexedScripts[hex.EncodeToString(script)] = struct{}{}
}
scripts := make([]string, 0, len(indexedScripts))
for script := range indexedScripts {
scripts = append(scripts, script)
}
return scripts, nil
}
func getSpentVtxos(payments map[string]domain.Payment) []domain.VtxoKey {
vtxos := make([]domain.VtxoKey, 0)
for _, p := range payments {
for _, vtxo := range p.Inputs {
if vtxo.VtxoKey == faucetVtxo {
continue
}
vtxos = append(vtxos, vtxo.VtxoKey)
}
}
return vtxos
}

View File

@@ -0,0 +1,579 @@
package application
import (
"context"
"encoding/hex"
"fmt"
"time"
"github.com/ark-network/ark/common/tree"
"github.com/ark-network/ark/internal/core/domain"
"github.com/ark-network/ark/internal/core/ports"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/decred/dcrd/dcrec/secp256k1/v4"
log "github.com/sirupsen/logrus"
"github.com/vulpemventures/go-elements/psetv2"
)
// sweeper is an unexported service running while the main application service is started
// it is responsible for sweeping onchain shared outputs that expired
// it also handles delaying the sweep events in case some parts of the tree are broadcasted
// when a round is finalized, the main application service schedules a sweep event on the newly created congestion tree
type sweeper struct {
wallet ports.WalletService
repoManager ports.RepoManager
builder ports.TxBuilder
scheduler ports.SchedulerService
// cache of scheduled tasks, avoid scheduling the same sweep event multiple times
scheduledTasks map[string]struct{}
}
func newSweeper(
wallet ports.WalletService,
repoManager ports.RepoManager,
builder ports.TxBuilder,
scheduler ports.SchedulerService,
) *sweeper {
return &sweeper{
wallet,
repoManager,
builder,
scheduler,
make(map[string]struct{}),
}
}
func (s *sweeper) start() error {
s.scheduler.Start()
allRounds, err := s.repoManager.Rounds().GetSweepableRounds(context.Background())
if err != nil {
return err
}
for _, round := range allRounds {
task := s.createTask(round.Txid, round.CongestionTree)
task()
}
return nil
}
func (s *sweeper) stop() {
s.scheduler.Stop()
}
// removeTask update the cached map of scheduled tasks
func (s *sweeper) removeTask(treeRootTxid string) {
delete(s.scheduledTasks, treeRootTxid)
}
// schedule set up a task to be executed once at the given timestamp
func (s *sweeper) schedule(
expirationTimestamp int64, roundTxid string, congestionTree tree.CongestionTree,
) error {
root, err := congestionTree.Root()
if err != nil {
return err
}
if _, scheduled := s.scheduledTasks[root.Txid]; scheduled {
return nil
}
task := s.createTask(roundTxid, congestionTree)
fancyTime := time.Unix(expirationTimestamp, 0).Format("2006-01-02 15:04:05")
log.Debugf("scheduled sweep task at %s", fancyTime)
if err := s.scheduler.ScheduleTaskOnce(expirationTimestamp, task); err != nil {
return err
}
s.scheduledTasks[root.Txid] = struct{}{}
return nil
}
// createTask returns a function passed as handler in the scheduler
// it tries to craft a sweep tx containing the onchain outputs of the given congestion tree
// if some parts of the tree have been broadcasted in the meantine, it will schedule the next taskes for the remaining parts of the tree
func (s *sweeper) createTask(
roundTxid string, congestionTree tree.CongestionTree,
) func() {
return func() {
ctx := context.Background()
root, err := congestionTree.Root()
if err != nil {
log.WithError(err).Error("error while getting root node")
return
}
s.removeTask(root.Txid)
log.Debugf("sweeper: %s", root.Txid)
sweepInputs := make([]ports.SweepInput, 0)
vtxoKeys := make([]domain.VtxoKey, 0) // vtxos associated to the sweep inputs
// inspect the congestion tree to find onchain shared outputs
sharedOutputs, err := s.findSweepableOutputs(ctx, congestionTree)
if err != nil {
log.WithError(err).Error("error while inspecting congestion tree")
return
}
for expiredAt, inputs := range sharedOutputs {
// if the shared outputs are not expired, schedule a sweep task for it
if time.Unix(expiredAt, 0).After(time.Now()) {
subtrees, err := computeSubTrees(congestionTree, inputs)
if err != nil {
log.WithError(err).Error("error while computing subtrees")
continue
}
for _, subTree := range subtrees {
// mitigate the risk to get BIP68 non-final errors by scheduling the task 30 seconds after the expiration time
if err := s.schedule(int64(expiredAt), roundTxid, subTree); err != nil {
log.WithError(err).Error("error while scheduling sweep task")
continue
}
}
continue
}
// iterate over the expired shared outputs
for _, input := range inputs {
// sweepableVtxos related to the sweep input
sweepableVtxos := make([]domain.VtxoKey, 0)
// check if input is the vtxo itself
vtxos, _ := s.repoManager.Vtxos().GetVtxos(
ctx,
[]domain.VtxoKey{
{
Txid: input.InputArgs.Txid,
VOut: input.InputArgs.TxIndex,
},
},
)
if len(vtxos) > 0 {
if !vtxos[0].Swept && !vtxos[0].Redeemed {
sweepableVtxos = append(sweepableVtxos, vtxos[0].VtxoKey)
}
} else {
// if it's not a vtxo, find all the vtxos leaves reachable from that input
vtxosLeaves, err := congestionTree.FindLeaves(input.InputArgs.Txid, input.InputArgs.TxIndex)
if err != nil {
log.WithError(err).Error("error while finding vtxos leaves")
continue
}
for _, leaf := range vtxosLeaves {
pset, err := psetv2.NewPsetFromBase64(leaf.Tx)
if err != nil {
log.Error(fmt.Errorf("error while decoding pset: %w", err))
continue
}
vtxo, err := extractVtxoOutpoint(pset)
if err != nil {
log.Error(err)
continue
}
sweepableVtxos = append(sweepableVtxos, *vtxo)
}
if len(sweepableVtxos) <= 0 {
continue
}
firstVtxo, err := s.repoManager.Vtxos().GetVtxos(ctx, sweepableVtxos[:1])
if err != nil {
log.Error(fmt.Errorf("error while getting vtxo: %w", err))
sweepInputs = append(sweepInputs, input) // add the input anyway in order to try to sweep it
continue
}
if firstVtxo[0].Swept || firstVtxo[0].Redeemed {
// we assume that if the first vtxo is swept or redeemed, the shared output has been spent
// skip, the output is already swept or spent by a unilateral redeem
continue
}
}
if len(sweepableVtxos) > 0 {
vtxoKeys = append(vtxoKeys, sweepableVtxos...)
sweepInputs = append(sweepInputs, input)
}
}
}
if len(sweepInputs) > 0 {
// build the sweep transaction with all the expired non-swept shared outputs
sweepTx, err := s.builder.BuildSweepTx(s.wallet, sweepInputs)
if err != nil {
log.WithError(err).Error("error while building sweep tx")
return
}
err = nil
txid := ""
// retry until the tx is broadcasted or the error is not BIP68 final
for len(txid) == 0 && (err == nil || err == fmt.Errorf("non-BIP68-final")) {
if err != nil {
log.Debugln("sweep tx not BIP68 final, retrying in 5 seconds")
time.Sleep(5 * time.Second)
}
txid, err = s.wallet.BroadcastTransaction(ctx, sweepTx)
}
if err != nil {
log.WithError(err).Error("error while broadcasting sweep tx")
return
}
if len(txid) > 0 {
log.Debugln("sweep tx broadcasted:", txid)
vtxosRepository := s.repoManager.Vtxos()
// mark the vtxos as swept
if err := vtxosRepository.SweepVtxos(ctx, vtxoKeys); err != nil {
log.Error(fmt.Errorf("error while deleting vtxos: %w", err))
return
}
log.Debugf("%d vtxos swept", len(vtxoKeys))
roundVtxos, err := vtxosRepository.GetVtxosForRound(ctx, roundTxid)
if err != nil {
log.WithError(err).Error("error while getting vtxos for round")
return
}
allSwept := true
for _, vtxo := range roundVtxos {
allSwept = allSwept && vtxo.Swept
if !allSwept {
break
}
}
if allSwept {
// update the round
roundRepo := s.repoManager.Rounds()
round, err := roundRepo.GetRoundWithTxid(ctx, roundTxid)
if err != nil {
log.WithError(err).Error("error while getting round")
return
}
round.Sweep()
if err := roundRepo.AddOrUpdateRound(ctx, *round); err != nil {
log.WithError(err).Error("error while marking round as swept")
return
}
}
}
}
}
}
// onchainOutputs iterates over all the nodes' outputs in the congestion tree and checks their onchain state
// returns the sweepable outputs as ports.SweepInput mapped by their expiration time
func (s *sweeper) findSweepableOutputs(
ctx context.Context,
congestionTree tree.CongestionTree,
) (map[int64][]ports.SweepInput, error) {
sweepableOutputs := make(map[int64][]ports.SweepInput)
blocktimeCache := make(map[string]int64) // txid -> blocktime
nodesToCheck := congestionTree[0] // init with the root
for len(nodesToCheck) > 0 {
newNodesToCheck := make([]tree.Node, 0)
for _, node := range nodesToCheck {
isPublished, blocktime, err := s.wallet.IsTransactionPublished(ctx, node.Txid)
if err != nil {
return nil, err
}
var expirationTime int64
var sweepInputs []ports.SweepInput
if !isPublished {
if _, ok := blocktimeCache[node.ParentTxid]; !ok {
isPublished, blocktime, err := s.wallet.IsTransactionPublished(ctx, node.ParentTxid)
if !isPublished || err != nil {
return nil, fmt.Errorf("tx %s not found", node.Txid)
}
blocktimeCache[node.ParentTxid] = blocktime
}
expirationTime, sweepInputs, err = s.nodeToSweepInputs(blocktimeCache[node.ParentTxid], node)
if err != nil {
return nil, err
}
} else {
// cache the blocktime for future use
blocktimeCache[node.Txid] = int64(blocktime)
// if the tx is onchain, it means that the input is spent
// add the children to the nodes in order to check them during the next iteration
// We will return the error below, but are we going to schedule the tasks for the "children roots"?
if !node.Leaf {
children := congestionTree.Children(node.Txid)
newNodesToCheck = append(newNodesToCheck, children...)
continue
}
// if the node is a leaf, the vtxos outputs should added as onchain outputs if they are not swept yet
vtxoExpiration, sweepInput, err := s.leafToSweepInput(ctx, blocktime, node)
if err != nil {
return nil, err
}
if sweepInput != nil {
expirationTime = vtxoExpiration
sweepInputs = []ports.SweepInput{*sweepInput}
}
}
if _, ok := sweepableOutputs[expirationTime]; !ok {
sweepableOutputs[expirationTime] = make([]ports.SweepInput, 0)
}
sweepableOutputs[expirationTime] = append(sweepableOutputs[expirationTime], sweepInputs...)
}
nodesToCheck = newNodesToCheck
}
return sweepableOutputs, nil
}
func (s *sweeper) leafToSweepInput(ctx context.Context, txBlocktime int64, node tree.Node) (int64, *ports.SweepInput, error) {
pset, err := psetv2.NewPsetFromBase64(node.Tx)
if err != nil {
return -1, nil, err
}
vtxo, err := extractVtxoOutpoint(pset)
if err != nil {
return -1, nil, err
}
fromRepo, err := s.repoManager.Vtxos().GetVtxos(ctx, []domain.VtxoKey{*vtxo})
if err != nil {
return -1, nil, err
}
if len(fromRepo) == 0 {
return -1, nil, fmt.Errorf("vtxo not found")
}
if fromRepo[0].Swept {
return -1, nil, nil
}
if fromRepo[0].Redeemed {
return -1, nil, nil
}
// if the vtxo is not swept or redeemed, add it to the onchain outputs
pubKeyBytes, err := hex.DecodeString(fromRepo[0].Pubkey)
if err != nil {
return -1, nil, err
}
pubKey, err := secp256k1.ParsePubKey(pubKeyBytes)
if err != nil {
return -1, nil, err
}
sweepLeaf, lifetime, err := s.builder.GetLeafSweepClosure(node, pubKey)
if err != nil {
return -1, nil, err
}
sweepInput := ports.SweepInput{
InputArgs: psetv2.InputArgs{
Txid: vtxo.Txid,
TxIndex: vtxo.VOut,
},
SweepLeaf: *sweepLeaf,
Amount: fromRepo[0].Amount,
}
return txBlocktime + lifetime, &sweepInput, nil
}
func (s *sweeper) nodeToSweepInputs(parentBlocktime int64, node tree.Node) (int64, []ports.SweepInput, error) {
pset, err := psetv2.NewPsetFromBase64(node.Tx)
if err != nil {
return -1, nil, err
}
if len(pset.Inputs) != 1 {
return -1, nil, fmt.Errorf("invalid node pset, expect 1 input, got %d", len(pset.Inputs))
}
// if the tx is not onchain, it means that the input is an existing shared output
input := pset.Inputs[0]
txid := chainhash.Hash(input.PreviousTxid).String()
index := input.PreviousTxIndex
sweepLeaf, lifetime, err := extractSweepLeaf(input)
if err != nil {
return -1, nil, err
}
expirationTime := parentBlocktime + lifetime
amount := uint64(0)
for _, out := range pset.Outputs {
amount += out.Value
}
sweepInputs := []ports.SweepInput{
{
InputArgs: psetv2.InputArgs{
Txid: txid,
TxIndex: index,
},
SweepLeaf: *sweepLeaf,
Amount: amount,
},
}
return expirationTime, sweepInputs, nil
}
func computeSubTrees(congestionTree tree.CongestionTree, inputs []ports.SweepInput) ([]tree.CongestionTree, error) {
subTrees := make(map[string]tree.CongestionTree, 0)
// for each sweepable input, create a sub congestion tree
// it allows to skip the part of the tree that has been broadcasted in the next task
for _, input := range inputs {
subTree, err := computeSubTree(congestionTree, input.InputArgs.Txid)
if err != nil {
log.WithError(err).Error("error while finding sub tree")
continue
}
root, err := subTree.Root()
if err != nil {
log.WithError(err).Error("error while getting root node")
continue
}
subTrees[root.Txid] = subTree
}
// filter out the sub trees, remove the ones that are included in others
filteredSubTrees := make([]tree.CongestionTree, 0)
for i, subTree := range subTrees {
notIncludedInOtherTrees := true
for j, otherSubTree := range subTrees {
if i == j {
continue
}
contains, err := containsTree(otherSubTree, subTree)
if err != nil {
log.WithError(err).Error("error while checking if a tree contains another")
continue
}
if contains {
notIncludedInOtherTrees = false
break
}
}
if notIncludedInOtherTrees {
filteredSubTrees = append(filteredSubTrees, subTree)
}
}
return filteredSubTrees, nil
}
func computeSubTree(congestionTree tree.CongestionTree, newRoot string) (tree.CongestionTree, error) {
for _, level := range congestionTree {
for _, node := range level {
if node.Txid == newRoot || node.ParentTxid == newRoot {
newTree := make(tree.CongestionTree, 0)
newTree = append(newTree, []tree.Node{node})
children := congestionTree.Children(node.Txid)
for len(children) > 0 {
newTree = append(newTree, children)
newChildren := make([]tree.Node, 0)
for _, child := range children {
newChildren = append(newChildren, congestionTree.Children(child.Txid)...)
}
children = newChildren
}
return newTree, nil
}
}
}
return nil, fmt.Errorf("failed to create subtree, new root not found")
}
func containsTree(tr0 tree.CongestionTree, tr1 tree.CongestionTree) (bool, error) {
tr1Root, err := tr1.Root()
if err != nil {
return false, err
}
for _, level := range tr0 {
for _, node := range level {
if node.Txid == tr1Root.Txid {
return true, nil
}
}
}
return false, nil
}
// given a congestion tree input, searches and returns the sweep leaf and its lifetime in seconds
func extractSweepLeaf(input psetv2.Input) (sweepLeaf *psetv2.TapLeafScript, lifetime int64, err error) {
for _, leaf := range input.TapLeafScript {
isSweep, _, seconds, err := tree.DecodeSweepScript(leaf.Script)
if err != nil {
return nil, 0, err
}
if isSweep {
lifetime = int64(seconds)
sweepLeaf = &leaf
break
}
}
if sweepLeaf == nil {
return nil, 0, fmt.Errorf("sweep leaf not found")
}
return sweepLeaf, lifetime, nil
}
// assuming the pset is a leaf in the congestion tree, returns the vtxos outputs
func extractVtxoOutpoint(pset *psetv2.Pset) (*domain.VtxoKey, error) {
if len(pset.Outputs) != 2 {
return nil, fmt.Errorf("invalid leaf pset, expect 2 outputs, got %d", len(pset.Outputs))
}
utx, err := pset.UnsignedTx()
if err != nil {
return nil, err
}
return &domain.VtxoKey{
Txid: utx.TxHash().String(),
VOut: 0,
}, nil
}

View File

@@ -0,0 +1,254 @@
package application
import (
"bytes"
"encoding/hex"
"fmt"
"sort"
"sync"
"time"
"github.com/ark-network/ark/common"
"github.com/ark-network/ark/internal/core/domain"
"github.com/btcsuite/btcd/btcec/v2/schnorr"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/vulpemventures/go-elements/psetv2"
)
type timedPayment struct {
domain.Payment
timestamp time.Time
pingTimestamp time.Time
}
type paymentsMap struct {
lock *sync.RWMutex
payments map[string]*timedPayment
}
func newPaymentsMap(payments []domain.Payment) *paymentsMap {
paymentsById := make(map[string]*timedPayment)
for _, p := range payments {
paymentsById[p.Id] = &timedPayment{p, time.Now(), time.Time{}}
}
lock := &sync.RWMutex{}
return &paymentsMap{lock, paymentsById}
}
func (m *paymentsMap) len() int64 {
m.lock.RLock()
defer m.lock.RUnlock()
count := int64(0)
for _, p := range m.payments {
if len(p.Receivers) > 0 {
count++
}
}
return count
}
func (m *paymentsMap) push(payment domain.Payment) error {
m.lock.Lock()
defer m.lock.Unlock()
if _, ok := m.payments[payment.Id]; ok {
return fmt.Errorf("duplicated inputs")
}
m.payments[payment.Id] = &timedPayment{payment, time.Now(), time.Time{}}
return nil
}
func (m *paymentsMap) pop(num int64) []domain.Payment {
m.lock.Lock()
defer m.lock.Unlock()
paymentsByTime := make([]timedPayment, 0, len(m.payments))
for _, p := range m.payments {
// Skip payments without registered receivers.
if len(p.Receivers) <= 0 {
continue
}
// Skip payments for which users didn't notify to be online in the last minute.
if p.pingTimestamp.IsZero() || time.Since(p.pingTimestamp).Minutes() > 1 {
continue
}
paymentsByTime = append(paymentsByTime, *p)
}
sort.SliceStable(paymentsByTime, func(i, j int) bool {
return paymentsByTime[i].timestamp.Before(paymentsByTime[j].timestamp)
})
if num < 0 || num > int64(len(paymentsByTime)) {
num = int64(len(paymentsByTime))
}
payments := make([]domain.Payment, 0, num)
for _, p := range paymentsByTime[:num] {
payments = append(payments, p.Payment)
delete(m.payments, p.Id)
}
return payments
}
func (m *paymentsMap) update(payment domain.Payment) error {
m.lock.Lock()
defer m.lock.Unlock()
p, ok := m.payments[payment.Id]
if !ok {
return fmt.Errorf("payment %s not found", payment.Id)
}
p.Payment = payment
return nil
}
func (m *paymentsMap) updatePingTimestamp(id string) error {
m.lock.Lock()
defer m.lock.Unlock()
payment, ok := m.payments[id]
if !ok {
return fmt.Errorf("payment %s not found", id)
}
payment.pingTimestamp = time.Now()
return nil
}
func (m *paymentsMap) view(id string) (*domain.Payment, bool) {
m.lock.RLock()
defer m.lock.RUnlock()
payment, ok := m.payments[id]
if !ok {
return nil, false
}
return &domain.Payment{
Id: payment.Id,
Inputs: payment.Inputs,
Receivers: payment.Receivers,
}, true
}
type signedTx struct {
tx string
signed bool
}
type forfeitTxsMap struct {
lock *sync.RWMutex
forfeitTxs map[string]*signedTx
genesisBlockHash *chainhash.Hash
}
func newForfeitTxsMap(genesisBlockHash *chainhash.Hash) *forfeitTxsMap {
return &forfeitTxsMap{&sync.RWMutex{}, make(map[string]*signedTx), genesisBlockHash}
}
func (m *forfeitTxsMap) push(txs []string) {
m.lock.Lock()
defer m.lock.Unlock()
faucetTxID, _ := hex.DecodeString(faucetVtxo.Txid)
for _, tx := range txs {
ptx, _ := psetv2.NewPsetFromBase64(tx)
utx, _ := ptx.UnsignedTx()
signed := false
// find the faucet vtxos, and mark them as signed
for _, input := range ptx.Inputs {
if bytes.Equal(input.PreviousTxid, faucetTxID) {
signed = true
break
}
}
m.forfeitTxs[utx.TxHash().String()] = &signedTx{tx, signed}
}
}
func (m *forfeitTxsMap) sign(txs []string) error {
m.lock.Lock()
defer m.lock.Unlock()
for _, tx := range txs {
ptx, _ := psetv2.NewPsetFromBase64(tx)
utx, _ := ptx.UnsignedTx()
txid := utx.TxHash().String()
if _, ok := m.forfeitTxs[txid]; ok {
for index, input := range ptx.Inputs {
if len(input.TapScriptSig) > 0 {
for _, tapScriptSig := range input.TapScriptSig {
leafHash, err := chainhash.NewHash(tapScriptSig.LeafHash)
if err != nil {
return err
}
preimage, err := common.TaprootPreimage(
m.genesisBlockHash,
ptx,
index,
leafHash,
)
if err != nil {
return err
}
sig, err := schnorr.ParseSignature(tapScriptSig.Signature)
if err != nil {
return err
}
pubkey, err := schnorr.ParsePubKey(tapScriptSig.PubKey)
if err != nil {
return err
}
if sig.Verify(preimage, pubkey) {
m.forfeitTxs[txid].signed = true
} else {
return fmt.Errorf("invalid signature")
}
}
}
}
}
}
return nil
}
func (m *forfeitTxsMap) pop() (signed, unsigned []string) {
m.lock.Lock()
defer m.lock.Unlock()
for _, t := range m.forfeitTxs {
if t.signed {
signed = append(signed, t.tx)
} else {
unsigned = append(unsigned, t.tx)
}
}
m.forfeitTxs = make(map[string]*signedTx)
return signed, unsigned
}
func (m *forfeitTxsMap) view() []string {
m.lock.RLock()
defer m.lock.RUnlock()
txs := make([]string, 0, len(m.forfeitTxs))
for _, tx := range m.forfeitTxs {
txs = append(txs, tx.tx)
}
return txs
}

View File

@@ -0,0 +1,44 @@
package domain
import "github.com/ark-network/ark/common/tree"
type RoundEvent interface {
isEvent()
}
func (r RoundStarted) isEvent() {}
func (r RoundFinalizationStarted) isEvent() {}
func (r RoundFinalized) isEvent() {}
func (r RoundFailed) isEvent() {}
func (r PaymentsRegistered) isEvent() {}
type RoundStarted struct {
Id string
Timestamp int64
}
type RoundFinalizationStarted struct {
Id string
CongestionTree tree.CongestionTree
Connectors []string
UnsignedForfeitTxs []string
PoolTx string
}
type RoundFinalized struct {
Id string
Txid string
ForfeitTxs []string
Timestamp int64
}
type RoundFailed struct {
Id string
Err string
Timestamp int64
}
type PaymentsRegistered struct {
Id string
Payments []Payment
}

View File

@@ -0,0 +1,129 @@
package domain
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"hash"
"github.com/google/uuid"
)
const dustAmount = 450
type Payment struct {
Id string
Inputs []Vtxo
Receivers []Receiver
}
func NewPayment(inputs []Vtxo) (*Payment, error) {
p := &Payment{
Id: uuid.New().String(),
Inputs: inputs,
}
if err := p.validate(true); err != nil {
return nil, err
}
return p, nil
}
func (p *Payment) AddReceivers(receivers []Receiver) (err error) {
if p.Receivers == nil {
p.Receivers = make([]Receiver, 0)
}
p.Receivers = append(p.Receivers, receivers...)
defer func() {
if err != nil {
p.Receivers = p.Receivers[:len(p.Receivers)-len(receivers)]
}
}()
err = p.validate(false)
return
}
func (p Payment) TotalInputAmount() uint64 {
tot := uint64(0)
for _, in := range p.Inputs {
tot += in.Amount
}
return tot
}
func (p Payment) TotalOutputAmount() uint64 {
tot := uint64(0)
for _, r := range p.Receivers {
tot += r.Amount
}
return tot
}
func (p Payment) validate(ignoreOuts bool) error {
if len(p.Id) <= 0 {
return fmt.Errorf("missing id")
}
if len(p.Inputs) <= 0 {
return fmt.Errorf("missing inputs")
}
if ignoreOuts {
return nil
}
if len(p.Receivers) <= 0 {
return fmt.Errorf("missing outputs")
}
// Check that input and output and output amounts match.
inAmount := p.TotalInputAmount()
outAmount := uint64(0)
for _, r := range p.Receivers {
if len(r.OnchainAddress) <= 0 && len(r.Pubkey) <= 0 {
return fmt.Errorf("missing receiver destination")
}
if r.Amount < dustAmount {
return fmt.Errorf("receiver amount must be greater than dust")
}
outAmount += r.Amount
}
if inAmount != outAmount {
return fmt.Errorf("input and output amounts mismatch")
}
return nil
}
type VtxoKey struct {
Txid string
VOut uint32
}
func (k VtxoKey) Hash() string {
calcHash := func(buf []byte, hasher hash.Hash) []byte {
_, _ = hasher.Write(buf)
return hasher.Sum(nil)
}
hash160 := func(buf []byte) []byte {
return calcHash(calcHash(buf, sha256.New()), sha256.New())
}
buf, _ := hex.DecodeString(k.Txid)
buf = append(buf, byte(k.VOut))
return hex.EncodeToString(hash160(buf))
}
type Receiver struct {
Pubkey string
Amount uint64
OnchainAddress string
}
func (r Receiver) IsOnchain() bool {
return len(r.OnchainAddress) > 0
}
type Vtxo struct {
VtxoKey
Receiver
PoolTx string
Spent bool
Redeemed bool
Swept bool
}

View File

@@ -0,0 +1,111 @@
package domain_test
import (
"testing"
"github.com/ark-network/ark/internal/core/domain"
"github.com/stretchr/testify/require"
)
var inputs = []domain.Vtxo{
{
VtxoKey: domain.VtxoKey{
Txid: "0000000000000000000000000000000000000000000000000000000000000000",
VOut: 0,
},
Receiver: domain.Receiver{
Pubkey: "030000000000000000000000000000000000000000000000000000000000000001",
Amount: 1000,
},
},
}
func TestPayment(t *testing.T) {
t.Run("new_payment", func(t *testing.T) {
t.Run("vaild", func(t *testing.T) {
payment, err := domain.NewPayment(inputs)
require.NoError(t, err)
require.NotNil(t, payment)
require.NotEmpty(t, payment.Id)
require.Exactly(t, inputs, payment.Inputs)
require.Empty(t, payment.Receivers)
})
t.Run("invaild", func(t *testing.T) {
fixtures := []struct {
inputs []domain.Vtxo
expectedErr string
}{
{
inputs: nil,
expectedErr: "missing inputs",
},
}
for _, f := range fixtures {
payment, err := domain.NewPayment(f.inputs)
require.EqualError(t, err, f.expectedErr)
require.Nil(t, payment)
}
})
})
t.Run("add_receivers", func(t *testing.T) {
t.Run("valid", func(t *testing.T) {
payment, err := domain.NewPayment(inputs)
require.NoError(t, err)
require.NotNil(t, payment)
err = payment.AddReceivers([]domain.Receiver{
{
Pubkey: "030000000000000000000000000000000000000000000000000000000000000001",
Amount: 450,
},
{
Pubkey: "020000000000000000000000000000000000000000000000000000000000000002",
Amount: 550,
},
})
require.NoError(t, err)
})
t.Run("invalid", func(t *testing.T) {
fixtures := []struct {
receivers []domain.Receiver
expectedErr string
}{
{
receivers: nil,
expectedErr: "missing outputs",
},
{
receivers: []domain.Receiver{
{
Pubkey: "030000000000000000000000000000000000000000000000000000000000000001",
Amount: 400,
},
},
expectedErr: "receiver amount must be greater than dust",
},
{
receivers: []domain.Receiver{
{
Pubkey: "030000000000000000000000000000000000000000000000000000000000000001",
Amount: 600,
},
},
expectedErr: "input and output amounts mismatch",
},
}
payment, err := domain.NewPayment(inputs)
require.NoError(t, err)
require.NotNil(t, payment)
for _, f := range fixtures {
err := payment.AddReceivers(f.receivers)
require.EqualError(t, err, f.expectedErr)
}
})
})
}

View File

@@ -0,0 +1,253 @@
package domain
import (
"fmt"
"time"
"github.com/ark-network/ark/common/tree"
"github.com/google/uuid"
)
const (
UndefinedStage RoundStage = iota
RegistrationStage
FinalizationStage
)
type RoundStage int
func (s RoundStage) String() string {
switch s {
case RegistrationStage:
return "REGISTRATION_STAGE"
case FinalizationStage:
return "FINALIZATION_STAGE"
default:
return "UNDEFINED_STAGE"
}
}
type Stage struct {
Code RoundStage
Ended bool
Failed bool
}
type Round struct {
Id string
StartingTimestamp int64
EndingTimestamp int64
Stage Stage
Payments map[string]Payment
Txid string
UnsignedTx string
ForfeitTxs []string
CongestionTree tree.CongestionTree
Connectors []string
DustAmount uint64
Version uint
Swept bool // true if all the vtxos are vtxo.Swept
changes []RoundEvent
}
func NewRound(dustAmount uint64) *Round {
return &Round{
Id: uuid.New().String(),
DustAmount: dustAmount,
Payments: make(map[string]Payment),
changes: make([]RoundEvent, 0),
}
}
func NewRoundFromEvents(events []RoundEvent) *Round {
r := &Round{}
for _, event := range events {
r.On(event, true)
}
r.changes = append([]RoundEvent{}, events...)
return r
}
func (r *Round) Events() []RoundEvent {
return r.changes
}
func (r *Round) On(event RoundEvent, replayed bool) {
switch e := event.(type) {
case RoundStarted:
r.Stage.Code = RegistrationStage
r.Id = e.Id
r.StartingTimestamp = e.Timestamp
case RoundFinalizationStarted:
r.Stage.Code = FinalizationStage
r.CongestionTree = e.CongestionTree
r.Connectors = append([]string{}, e.Connectors...)
r.UnsignedTx = e.PoolTx
case RoundFinalized:
r.Stage.Ended = true
r.Txid = e.Txid
r.ForfeitTxs = append([]string{}, e.ForfeitTxs...)
r.EndingTimestamp = e.Timestamp
case RoundFailed:
r.Stage.Failed = true
r.EndingTimestamp = e.Timestamp
case PaymentsRegistered:
if r.Payments == nil {
r.Payments = make(map[string]Payment)
}
for _, p := range e.Payments {
r.Payments[p.Id] = p
}
}
if replayed {
r.Version++
}
}
func (r *Round) StartRegistration() ([]RoundEvent, error) {
empty := Stage{}
if r.Stage != empty {
return nil, fmt.Errorf("not in a valid stage to start payment registration")
}
event := RoundStarted{
Id: r.Id,
Timestamp: time.Now().Unix(),
}
r.raise(event)
return []RoundEvent{event}, nil
}
func (r *Round) RegisterPayments(payments []Payment) ([]RoundEvent, error) {
if r.Stage.Code != RegistrationStage || r.IsFailed() {
return nil, fmt.Errorf("not in a valid stage to register payments")
}
if len(payments) <= 0 {
return nil, fmt.Errorf("missing payments to register")
}
for _, p := range payments {
if err := p.validate(false); err != nil {
return nil, err
}
}
event := PaymentsRegistered{
Id: r.Id,
Payments: payments,
}
r.raise(event)
return []RoundEvent{event}, nil
}
func (r *Round) StartFinalization(connectors []string, congestionTree tree.CongestionTree, poolTx string) ([]RoundEvent, error) {
if len(connectors) <= 0 {
return nil, fmt.Errorf("missing list of connectors")
}
if len(congestionTree) <= 0 {
return nil, fmt.Errorf("missing congestion tree")
}
if len(poolTx) <= 0 {
return nil, fmt.Errorf("missing unsigned pool tx")
}
if r.Stage.Code != RegistrationStage || r.IsFailed() {
return nil, fmt.Errorf("not in a valid stage to start payment finalization")
}
if len(r.Payments) <= 0 {
return nil, fmt.Errorf("no payments registered")
}
event := RoundFinalizationStarted{
Id: r.Id,
CongestionTree: congestionTree,
Connectors: connectors,
PoolTx: poolTx,
}
r.raise(event)
return []RoundEvent{event}, nil
}
func (r *Round) EndFinalization(forfeitTxs []string, txid string) ([]RoundEvent, error) {
if len(forfeitTxs) <= 0 {
return nil, fmt.Errorf("missing list of signed forfeit txs")
}
if len(txid) <= 0 {
return nil, fmt.Errorf("missing pool txid")
}
if r.Stage.Code != FinalizationStage || r.IsFailed() {
return nil, fmt.Errorf("not in a valid stage to end payment finalization")
}
if r.Stage.Ended {
return nil, fmt.Errorf("round already finalized")
}
event := RoundFinalized{
Id: r.Id,
Txid: txid,
ForfeitTxs: forfeitTxs,
Timestamp: time.Now().Unix(),
}
r.raise(event)
return []RoundEvent{event}, nil
}
func (r *Round) Fail(err error) []RoundEvent {
if r.Stage.Failed {
return nil
}
event := RoundFailed{
Id: r.Id,
Err: err.Error(),
Timestamp: time.Now().Unix(),
}
r.raise(event)
return []RoundEvent{event}
}
func (r *Round) IsStarted() bool {
empty := Stage{}
return !r.IsFailed() && !r.IsEnded() && r.Stage != empty
}
func (r *Round) IsEnded() bool {
return !r.IsFailed() && r.Stage.Code == FinalizationStage && r.Stage.Ended
}
func (r *Round) IsFailed() bool {
return r.Stage.Failed
}
func (r *Round) TotalInputAmount() uint64 {
totInputs := 0
for _, p := range r.Payments {
totInputs += len(p.Inputs)
}
return uint64(totInputs * int(r.DustAmount))
}
func (r *Round) TotalOutputAmount() uint64 {
tot := uint64(0)
for _, p := range r.Payments {
tot += p.TotalOutputAmount()
}
return tot
}
func (r *Round) Sweep() {
r.Swept = true
}
func (r *Round) raise(event RoundEvent) {
if r.changes == nil {
r.changes = make([]RoundEvent, 0)
}
r.changes = append(r.changes, event)
r.On(event, false)
}

View File

@@ -0,0 +1,28 @@
package domain
import (
"context"
)
type RoundEventRepository interface {
Save(ctx context.Context, id string, events ...RoundEvent) error
Load(ctx context.Context, id string) (*Round, error)
}
type RoundRepository interface {
AddOrUpdateRound(ctx context.Context, round Round) error
GetCurrentRound(ctx context.Context) (*Round, error)
GetRoundWithId(ctx context.Context, id string) (*Round, error)
GetRoundWithTxid(ctx context.Context, txid string) (*Round, error)
GetSweepableRounds(ctx context.Context) ([]Round, error)
}
type VtxoRepository interface {
AddVtxos(ctx context.Context, vtxos []Vtxo) error
SpendVtxos(ctx context.Context, vtxos []VtxoKey) error
RedeemVtxos(ctx context.Context, vtxos []VtxoKey) ([]Vtxo, error)
GetVtxos(ctx context.Context, vtxos []VtxoKey) ([]Vtxo, error)
GetVtxosForRound(ctx context.Context, txid string) ([]Vtxo, error)
SweepVtxos(ctx context.Context, vtxos []VtxoKey) error
GetSpendableVtxos(ctx context.Context, pubkey string) ([]Vtxo, error)
}

View File

@@ -0,0 +1,576 @@
package domain_test
import (
"fmt"
"testing"
"github.com/ark-network/ark/common/tree"
"github.com/ark-network/ark/internal/core/domain"
"github.com/stretchr/testify/require"
)
var (
dustAmount = uint64(450)
payments = []domain.Payment{
{
Id: "0",
Inputs: []domain.Vtxo{{
VtxoKey: domain.VtxoKey{
Txid: txid,
VOut: 0,
},
Receiver: domain.Receiver{
Pubkey: pubkey,
Amount: 2000,
},
}},
Receivers: []domain.Receiver{
{
Pubkey: pubkey,
Amount: 700,
},
{
Pubkey: pubkey,
Amount: 700,
},
{
Pubkey: pubkey,
Amount: 600,
},
},
},
{
Id: "1",
Inputs: []domain.Vtxo{
{
VtxoKey: domain.VtxoKey{
Txid: txid,
VOut: 0,
},
Receiver: domain.Receiver{
Pubkey: pubkey,
Amount: 1000,
},
},
{
VtxoKey: domain.VtxoKey{
Txid: txid,
VOut: 0,
},
Receiver: domain.Receiver{
Pubkey: pubkey,
Amount: 1000,
},
},
},
Receivers: []domain.Receiver{{
Pubkey: pubkey,
Amount: 2000,
}},
},
}
emptyPtx = "cHNldP8BAgQCAAAAAQQBAAEFAQABBgEDAfsEAgAAAAA="
emptyTx = "0200000000000000000000"
txid = "0000000000000000000000000000000000000000000000000000000000000000"
pubkey = "030000000000000000000000000000000000000000000000000000000000000001"
congestionTree = tree.CongestionTree{
{
{
Txid: txid,
Tx: emptyPtx,
ParentTxid: txid,
},
},
{
{
Txid: txid,
Tx: emptyPtx,
ParentTxid: txid,
},
{
Txid: txid,
Tx: emptyPtx,
ParentTxid: txid,
},
},
{
{
Txid: txid,
Tx: emptyPtx,
ParentTxid: txid,
},
{
Txid: txid,
Tx: emptyPtx,
ParentTxid: txid,
},
{
Txid: txid,
Tx: emptyPtx,
ParentTxid: txid,
},
{
Txid: txid,
Tx: emptyPtx,
ParentTxid: txid,
},
},
}
connectors = []string{emptyPtx, emptyPtx, emptyPtx}
forfeitTxs = []string{emptyPtx, emptyPtx, emptyPtx, emptyPtx, emptyPtx, emptyPtx, emptyPtx, emptyPtx, emptyPtx}
poolTx = emptyTx
)
func TestRound(t *testing.T) {
testStartRegistration(t)
testRegisterPayments(t)
testStartFinalization(t)
testEndFinalization(t)
testFail(t)
}
func testStartRegistration(t *testing.T) {
t.Run("start_registration", func(t *testing.T) {
t.Run("valid", func(t *testing.T) {
round := domain.NewRound(dustAmount)
require.NotNil(t, round)
require.NotEmpty(t, round.Id)
require.Empty(t, round.Events())
require.False(t, round.IsStarted())
require.False(t, round.IsEnded())
require.False(t, round.IsFailed())
events, err := round.StartRegistration()
require.NoError(t, err)
require.Len(t, events, 1)
require.True(t, round.IsStarted())
require.False(t, round.IsEnded())
require.False(t, round.IsFailed())
event, ok := events[0].(domain.RoundStarted)
require.True(t, ok)
require.Equal(t, round.Id, event.Id)
require.Equal(t, round.StartingTimestamp, event.Timestamp)
})
t.Run("invalid", func(t *testing.T) {
fixtures := []struct {
round *domain.Round
expectedErr string
}{
{
round: &domain.Round{
Id: "id",
Stage: domain.Stage{
Code: domain.UndefinedStage,
Failed: true,
},
},
expectedErr: "not in a valid stage to start payment registration",
},
{
round: &domain.Round{
Id: "id",
Stage: domain.Stage{
Code: domain.RegistrationStage,
},
},
expectedErr: "not in a valid stage to start payment registration",
},
{
round: &domain.Round{
Id: "id",
Stage: domain.Stage{
Code: domain.FinalizationStage,
},
},
expectedErr: "not in a valid stage to start payment registration",
},
}
for _, f := range fixtures {
events, err := f.round.StartRegistration()
require.EqualError(t, err, f.expectedErr)
require.Empty(t, events)
}
})
})
}
func testRegisterPayments(t *testing.T) {
t.Run("register_payments", func(t *testing.T) {
t.Run("valid", func(t *testing.T) {
round := domain.NewRound(dustAmount)
events, err := round.StartRegistration()
require.NoError(t, err)
require.NotEmpty(t, events)
events, err = round.RegisterPayments(payments)
require.NoError(t, err)
require.Len(t, events, 1)
require.Condition(t, func() bool {
for _, payment := range payments {
_, ok := round.Payments[payment.Id]
if !ok {
return false
}
}
return true
})
event, ok := events[0].(domain.PaymentsRegistered)
require.True(t, ok)
require.Equal(t, round.Id, event.Id)
require.Equal(t, payments, event.Payments)
})
t.Run("invalid", func(t *testing.T) {
fixtures := []struct {
round *domain.Round
payments []domain.Payment
expectedErr string
}{
{
round: &domain.Round{
Id: "id",
Stage: domain.Stage{},
},
payments: payments,
expectedErr: "not in a valid stage to register payments",
},
{
round: &domain.Round{
Id: "id",
Stage: domain.Stage{
Code: domain.RegistrationStage,
Failed: true,
},
},
payments: payments,
expectedErr: "not in a valid stage to register payments",
},
{
round: &domain.Round{
Id: "id",
Stage: domain.Stage{
Code: domain.FinalizationStage,
},
},
payments: payments,
expectedErr: "not in a valid stage to register payments",
},
{
round: &domain.Round{
Id: "id",
Stage: domain.Stage{
Code: domain.RegistrationStage,
},
},
payments: nil,
expectedErr: "missing payments to register",
},
}
for _, f := range fixtures {
events, err := f.round.RegisterPayments(f.payments)
require.EqualError(t, err, f.expectedErr)
require.Empty(t, events)
}
})
})
}
func testStartFinalization(t *testing.T) {
t.Run("start_finalization", func(t *testing.T) {
t.Run("valid", func(t *testing.T) {
round := domain.NewRound(dustAmount)
events, err := round.StartRegistration()
require.NoError(t, err)
require.NotEmpty(t, events)
events, err = round.RegisterPayments(payments)
require.NoError(t, err)
require.NotEmpty(t, events)
events, err = round.StartFinalization(connectors, congestionTree, poolTx)
require.NoError(t, err)
require.Len(t, events, 1)
require.True(t, round.IsStarted())
require.False(t, round.IsEnded())
require.False(t, round.IsFailed())
event, ok := events[0].(domain.RoundFinalizationStarted)
require.True(t, ok)
require.Equal(t, round.Id, event.Id)
require.Exactly(t, connectors, event.Connectors)
require.Exactly(t, congestionTree, event.CongestionTree)
require.Exactly(t, poolTx, event.PoolTx)
})
t.Run("invalid", func(t *testing.T) {
paymentsById := map[string]domain.Payment{}
for _, p := range payments {
paymentsById[p.Id] = p
}
fixtures := []struct {
round *domain.Round
connectors []string
tree tree.CongestionTree
poolTx string
expectedErr string
}{
{
round: &domain.Round{
Id: "0",
Stage: domain.Stage{
Code: domain.RegistrationStage,
},
Payments: paymentsById,
},
connectors: nil,
tree: congestionTree,
poolTx: poolTx,
expectedErr: "missing list of connectors",
},
{
round: &domain.Round{
Id: "0",
Stage: domain.Stage{
Code: domain.RegistrationStage,
},
Payments: paymentsById,
},
connectors: connectors,
tree: nil,
poolTx: poolTx,
expectedErr: "missing congestion tree",
},
{
round: &domain.Round{
Id: "0",
Stage: domain.Stage{
Code: domain.RegistrationStage,
},
Payments: paymentsById,
},
connectors: connectors,
tree: congestionTree,
poolTx: "",
expectedErr: "missing unsigned pool tx",
},
{
round: &domain.Round{
Id: "0",
Stage: domain.Stage{
Code: domain.RegistrationStage,
},
Payments: nil,
},
connectors: connectors,
tree: congestionTree,
poolTx: poolTx,
expectedErr: "no payments registered",
},
{
round: &domain.Round{
Id: "0",
Stage: domain.Stage{
Code: domain.UndefinedStage,
},
Payments: paymentsById,
},
connectors: connectors,
tree: congestionTree,
poolTx: poolTx,
expectedErr: "not in a valid stage to start payment finalization",
},
{
round: &domain.Round{
Id: "0",
Stage: domain.Stage{
Code: domain.RegistrationStage,
Failed: true,
},
Payments: paymentsById,
},
connectors: connectors,
tree: congestionTree,
poolTx: poolTx,
expectedErr: "not in a valid stage to start payment finalization",
},
{
round: &domain.Round{
Id: "0",
Stage: domain.Stage{
Code: domain.FinalizationStage,
},
Payments: paymentsById,
},
connectors: connectors,
tree: congestionTree,
poolTx: poolTx,
expectedErr: "not in a valid stage to start payment finalization",
},
}
for _, f := range fixtures {
events, err := f.round.StartFinalization(f.connectors, f.tree, f.poolTx)
require.EqualError(t, err, f.expectedErr)
require.Empty(t, events)
}
})
})
}
func testEndFinalization(t *testing.T) {
t.Run("end_registration", func(t *testing.T) {
t.Run("valid", func(t *testing.T) {
round := domain.NewRound(dustAmount)
events, err := round.StartRegistration()
require.NoError(t, err)
require.NotEmpty(t, events)
events, err = round.RegisterPayments(payments)
require.NoError(t, err)
require.NotEmpty(t, events)
events, err = round.StartFinalization(connectors, congestionTree, poolTx)
require.NoError(t, err)
require.NotEmpty(t, events)
events, err = round.EndFinalization(forfeitTxs, txid)
require.NoError(t, err)
require.Len(t, events, 1)
require.False(t, round.IsStarted())
require.True(t, round.IsEnded())
require.False(t, round.IsFailed())
event, ok := events[0].(domain.RoundFinalized)
require.True(t, ok)
require.Equal(t, round.Id, event.Id)
require.Exactly(t, txid, event.Txid)
require.Exactly(t, forfeitTxs, event.ForfeitTxs)
require.Exactly(t, round.EndingTimestamp, event.Timestamp)
})
t.Run("invalid", func(t *testing.T) {
paymentsById := map[string]domain.Payment{}
for _, p := range payments {
paymentsById[p.Id] = p
}
fixtures := []struct {
round *domain.Round
forfeitTxs []string
txid string
expectedErr string
}{
{
round: &domain.Round{
Id: "0",
Stage: domain.Stage{
Code: domain.FinalizationStage,
},
},
forfeitTxs: nil,
txid: txid,
expectedErr: "missing list of signed forfeit txs",
},
{
round: &domain.Round{
Id: "0",
Stage: domain.Stage{
Code: domain.FinalizationStage,
},
},
forfeitTxs: forfeitTxs,
txid: "",
expectedErr: "missing pool txid",
},
{
round: &domain.Round{
Id: "0",
},
forfeitTxs: forfeitTxs,
txid: txid,
expectedErr: "not in a valid stage to end payment finalization",
},
{
round: &domain.Round{
Id: "0",
Stage: domain.Stage{
Code: domain.RegistrationStage,
},
},
forfeitTxs: forfeitTxs,
txid: txid,
expectedErr: "not in a valid stage to end payment finalization",
},
{
round: &domain.Round{
Id: "0",
Stage: domain.Stage{
Code: domain.FinalizationStage,
Failed: true,
},
},
forfeitTxs: []string{emptyPtx, emptyPtx, emptyPtx, emptyPtx},
txid: txid,
expectedErr: "not in a valid stage to end payment finalization",
},
{
round: &domain.Round{
Id: "0",
Stage: domain.Stage{
Code: domain.FinalizationStage,
Ended: true,
},
},
forfeitTxs: []string{emptyPtx, emptyPtx, emptyPtx, emptyPtx},
txid: txid,
expectedErr: "round already finalized",
},
}
for _, f := range fixtures {
events, err := f.round.EndFinalization(f.forfeitTxs, f.txid)
require.EqualError(t, err, f.expectedErr)
require.Empty(t, events)
}
})
})
}
func testFail(t *testing.T) {
t.Run("fail", func(t *testing.T) {
t.Run("valid", func(t *testing.T) {
round := domain.NewRound(dustAmount)
events, err := round.StartRegistration()
require.NoError(t, err)
require.NotEmpty(t, events)
events, err = round.RegisterPayments(payments)
require.NoError(t, err)
require.NotEmpty(t, events)
reason := fmt.Errorf("some valid reason")
events = round.Fail(reason)
require.Len(t, events, 1)
require.False(t, round.IsStarted())
require.False(t, round.IsEnded())
require.True(t, round.IsFailed())
event, ok := events[0].(domain.RoundFailed)
require.True(t, ok)
require.Exactly(t, round.Id, event.Id)
require.Exactly(t, round.EndingTimestamp, event.Timestamp)
require.EqualError(t, reason, event.Err)
events = round.Fail(reason)
require.Empty(t, events)
})
})
}

View File

@@ -0,0 +1,11 @@
package ports
import "github.com/ark-network/ark/internal/core/domain"
type RepoManager interface {
Events() domain.RoundEventRepository
Rounds() domain.RoundRepository
Vtxos() domain.VtxoRepository
RegisterEventsHandler(func(*domain.Round))
Close()
}

View File

@@ -0,0 +1,12 @@
package ports
import (
"github.com/ark-network/ark/internal/core/domain"
"golang.org/x/net/context"
)
type BlockchainScanner interface {
WatchScripts(ctx context.Context, scripts []string) error
UnwatchScripts(ctx context.Context, scripts []string) error
GetNotificationChannel(ctx context.Context) chan []domain.VtxoKey
}

View File

@@ -0,0 +1,9 @@
package ports
type SchedulerService interface {
Start()
Stop()
ScheduleTask(interval int64, immediate bool, task func()) error
ScheduleTaskOnce(delay int64, task func()) error
}

View File

@@ -0,0 +1,29 @@
package ports
import (
"github.com/ark-network/ark/common/tree"
"github.com/ark-network/ark/internal/core/domain"
"github.com/decred/dcrd/dcrec/secp256k1/v4"
"github.com/vulpemventures/go-elements/psetv2"
)
type SweepInput struct {
InputArgs psetv2.InputArgs
SweepLeaf psetv2.TapLeafScript
Amount uint64
}
type TxBuilder interface {
BuildPoolTx(
aspPubkey *secp256k1.PublicKey, payments []domain.Payment, minRelayFee uint64,
) (poolTx string, congestionTree tree.CongestionTree, err error)
BuildForfeitTxs(
aspPubkey *secp256k1.PublicKey, poolTx string, payments []domain.Payment,
) (connectors []string, forfeitTxs []string, err error)
BuildSweepTx(
wallet WalletService,
inputs []SweepInput,
) (signedSweepTx string, err error)
GetLeafSweepClosure(node tree.Node, userPubKey *secp256k1.PublicKey) (*psetv2.TapLeafScript, int64, error)
GetVtxoScript(userPubkey, aspPubkey *secp256k1.PublicKey) ([]byte, error)
}

View File

@@ -0,0 +1,37 @@
package ports
import (
"context"
"github.com/decred/dcrd/dcrec/secp256k1/v4"
)
type WalletService interface {
BlockchainScanner
Status(ctx context.Context) (WalletStatus, error)
GetPubkey(ctx context.Context) (*secp256k1.PublicKey, error)
DeriveAddresses(ctx context.Context, num int) ([]string, error)
SignPset(
ctx context.Context, pset string, extractRawTx bool,
) (string, error)
SelectUtxos(ctx context.Context, asset string, amount uint64) ([]TxInput, uint64, error)
BroadcastTransaction(ctx context.Context, txHex string) (string, error)
SignPsetWithKey(ctx context.Context, pset string, inputIndexes []int) (string, error) // inputIndexes == nil means sign all inputs
IsTransactionPublished(ctx context.Context, txid string) (isPublished bool, blocktime int64, err error)
EstimateFees(ctx context.Context, pset string) (uint64, error)
Close()
}
type WalletStatus interface {
IsInitialized() bool
IsUnlocked() bool
IsSynced() bool
}
type TxInput interface {
GetTxid() string
GetIndex() uint32
GetScript() string
GetAsset() string
GetValue() uint64
}

View File

@@ -0,0 +1,154 @@
package badgerdb
import (
"context"
"fmt"
"path/filepath"
"sync"
"github.com/ark-network/ark/internal/core/domain"
dbtypes "github.com/ark-network/ark/internal/infrastructure/db/types"
"github.com/dgraph-io/badger/v4"
"github.com/timshannon/badgerhold/v4"
)
const eventStoreDir = "round-events"
type eventsDTO struct {
Events [][]byte
}
type eventRepository struct {
store *badgerhold.Store
lock *sync.RWMutex
chUpdates chan *domain.Round
handler func(round *domain.Round)
}
func NewRoundEventRepository(config ...interface{}) (dbtypes.EventStore, error) {
if len(config) != 2 {
return nil, fmt.Errorf("invalid config")
}
baseDir, ok := config[0].(string)
if !ok {
return nil, fmt.Errorf("invalid base directory")
}
var logger badger.Logger
if config[1] != nil {
logger, ok = config[1].(badger.Logger)
if !ok {
return nil, fmt.Errorf("invalid logger")
}
}
var dir string
if len(baseDir) > 0 {
dir = filepath.Join(baseDir, eventStoreDir)
}
store, err := createDB(dir, logger)
if err != nil {
return nil, fmt.Errorf("failed to open round events store: %s", err)
}
chEvents := make(chan *domain.Round)
lock := &sync.RWMutex{}
repo := &eventRepository{store, lock, chEvents, nil}
go repo.listen()
return repo, nil
}
func (r *eventRepository) Save(
ctx context.Context, id string, events ...domain.RoundEvent,
) error {
allEvents, err := r.get(ctx, id)
if err != nil {
return err
}
allEvents = append(allEvents, events...)
if err := r.upsert(ctx, id, allEvents); err != nil {
return err
}
go r.publishEvents(allEvents)
return nil
}
func (r *eventRepository) Load(
ctx context.Context, id string,
) (*domain.Round, error) {
events, err := r.get(ctx, id)
if err != nil {
return nil, err
}
return domain.NewRoundFromEvents(events), nil
}
func (r *eventRepository) RegisterEventsHandler(
handler func(round *domain.Round),
) {
r.lock.Lock()
defer r.lock.Unlock()
r.handler = handler
}
func (r *eventRepository) Close() {
close(r.chUpdates)
r.store.Close()
}
func (r *eventRepository) get(
ctx context.Context, id string,
) ([]domain.RoundEvent, error) {
dto := eventsDTO{}
var err error
if ctx.Value("tx") != nil {
tx := ctx.Value("tx").(*badger.Txn)
err = r.store.TxGet(tx, id, &dto)
} else {
err = r.store.Get(id, &dto)
}
if err != nil {
if err == badgerhold.ErrNotFound {
return nil, nil
}
return nil, fmt.Errorf("failed to get events with id %s: %s", id, err)
}
return deserializeEvents(dto.Events)
}
func (r *eventRepository) upsert(
ctx context.Context, id string, events []domain.RoundEvent,
) error {
buf, err := serializeEvents(events)
if err != nil {
return err
}
if ctx.Value("tx") != nil {
tx := ctx.Value("tx").(*badger.Txn)
err = r.store.TxUpsert(tx, id, buf)
} else {
err = r.store.Upsert(id, buf)
}
if err != nil {
return fmt.Errorf("failed to upsert events with id %s: %s", id, err)
}
return nil
}
func (r *eventRepository) listen() {
for updatedRound := range r.chUpdates {
r.lock.RLock()
if r.handler != nil {
r.handler(updatedRound)
}
r.lock.RUnlock()
}
}
func (r *eventRepository) publishEvents(events []domain.RoundEvent) {
r.lock.Lock()
defer r.lock.Unlock()
round := domain.NewRoundFromEvents(events)
r.chUpdates <- round
}

View File

@@ -0,0 +1,140 @@
package badgerdb
import (
"context"
"fmt"
"path/filepath"
"github.com/ark-network/ark/internal/core/domain"
dbtypes "github.com/ark-network/ark/internal/infrastructure/db/types"
"github.com/dgraph-io/badger/v4"
"github.com/timshannon/badgerhold/v4"
)
const roundStoreDir = "rounds"
type roundRepository struct {
store *badgerhold.Store
}
func NewRoundRepository(config ...interface{}) (dbtypes.RoundStore, error) {
if len(config) != 2 {
return nil, fmt.Errorf("invalid config")
}
baseDir, ok := config[0].(string)
if !ok {
return nil, fmt.Errorf("invalid base directory")
}
var logger badger.Logger
if config[1] != nil {
logger, ok = config[1].(badger.Logger)
if !ok {
return nil, fmt.Errorf("invalid logger")
}
}
var dir string
if len(baseDir) > 0 {
dir = filepath.Join(baseDir, roundStoreDir)
}
store, err := createDB(dir, logger)
if err != nil {
return nil, fmt.Errorf("failed to open round events store: %s", err)
}
return &roundRepository{store}, nil
}
func (r *roundRepository) AddOrUpdateRound(
ctx context.Context, round domain.Round,
) error {
return r.addOrUpdateRound(ctx, round)
}
func (r *roundRepository) GetCurrentRound(
ctx context.Context,
) (*domain.Round, error) {
query := badgerhold.Where("Stage.Ended").Eq(false).And("Stage.Failed").Eq(false)
rounds, err := r.findRound(ctx, query)
if err != nil {
return nil, err
}
if len(rounds) <= 0 {
return nil, fmt.Errorf("ongoing round not found")
}
return &rounds[0], nil
}
func (r *roundRepository) GetRoundWithId(
ctx context.Context, id string,
) (*domain.Round, error) {
query := badgerhold.Where("Id").Eq(id)
rounds, err := r.findRound(ctx, query)
if err != nil {
return nil, err
}
if len(rounds) <= 0 {
return nil, fmt.Errorf("round with id %s not found", id)
}
round := &rounds[0]
return round, nil
}
func (r *roundRepository) GetRoundWithTxid(
ctx context.Context, txid string,
) (*domain.Round, error) {
query := badgerhold.Where("Txid").Eq(txid)
rounds, err := r.findRound(ctx, query)
if err != nil {
return nil, err
}
if len(rounds) <= 0 {
return nil, fmt.Errorf("round with txid %s not found", txid)
}
round := &rounds[0]
return round, nil
}
func (r *roundRepository) GetSweepableRounds(
ctx context.Context,
) ([]domain.Round, error) {
query := badgerhold.Where("Stage.Code").Eq(domain.FinalizationStage).
And("Stage.Ended").Eq(true).And("Swept").Eq(false)
rounds, err := r.findRound(ctx, query)
if err != nil {
return nil, err
}
return rounds, nil
}
func (r *roundRepository) Close() {
r.store.Close()
}
func (r *roundRepository) findRound(
ctx context.Context, query *badgerhold.Query,
) ([]domain.Round, error) {
var rounds []domain.Round
var err error
if ctx.Value("tx") != nil {
tx := ctx.Value("tx").(*badger.Txn)
err = r.store.TxFind(tx, &rounds, query)
} else {
err = r.store.Find(&rounds, query)
}
return rounds, err
}
func (r *roundRepository) addOrUpdateRound(
ctx context.Context, round domain.Round,
) (err error) {
if ctx.Value("tx") != nil {
tx := ctx.Value("tx").(*badger.Txn)
err = r.store.TxUpsert(tx, round.Id, round)
} else {
err = r.store.Upsert(round.Id, round)
}
return
}

View File

@@ -0,0 +1,116 @@
package badgerdb
import (
"encoding/json"
"fmt"
"time"
"github.com/ark-network/ark/internal/core/domain"
"github.com/dgraph-io/badger/v4"
"github.com/dgraph-io/badger/v4/options"
"github.com/timshannon/badgerhold/v4"
)
func createDB(dbDir string, logger badger.Logger) (*badgerhold.Store, error) {
isInMemory := len(dbDir) <= 0
opts := badger.DefaultOptions(dbDir)
opts.Logger = logger
if isInMemory {
opts.InMemory = true
} else {
opts.Compression = options.ZSTD
}
db, err := badgerhold.Open(badgerhold.Options{
Encoder: badgerhold.DefaultEncode,
Decoder: badgerhold.DefaultDecode,
SequenceBandwith: 100,
Options: opts,
})
if err != nil {
return nil, err
}
if !isInMemory {
ticker := time.NewTicker(30 * time.Minute)
go func() {
for {
<-ticker.C
if err := db.Badger().RunValueLogGC(0.5); err != nil && err != badger.ErrNoRewrite {
logger.Errorf("%s", err)
}
}
}()
}
return db, nil
}
func serializeEvents(events []domain.RoundEvent) (*eventsDTO, error) {
rawEvents := make([][]byte, 0, len(events))
for _, event := range events {
buf, err := serializeEvent(event)
if err != nil {
return nil, err
}
rawEvents = append(rawEvents, buf)
}
return &eventsDTO{rawEvents}, nil
}
func deserializeEvents(rawEvents [][]byte) ([]domain.RoundEvent, error) {
events := make([]domain.RoundEvent, 0)
for _, buf := range rawEvents {
event, err := deserializeEvent(buf)
if err != nil {
return nil, err
}
events = append(events, event)
}
return events, nil
}
func serializeEvent(event domain.RoundEvent) ([]byte, error) {
switch eventType := event.(type) {
default:
return json.Marshal(eventType)
}
}
func deserializeEvent(buf []byte) (domain.RoundEvent, error) {
{
var event = domain.RoundFailed{}
if err := json.Unmarshal(buf, &event); err == nil && len(event.Err) > 0 {
return event, nil
}
}
{
var event = domain.RoundFinalized{}
if err := json.Unmarshal(buf, &event); err == nil && len(event.Txid) > 0 {
return event, nil
}
}
{
var event = domain.RoundFinalizationStarted{}
if err := json.Unmarshal(buf, &event); err == nil && len(event.CongestionTree) > 0 {
return event, nil
}
}
{
var event = domain.PaymentsRegistered{}
if err := json.Unmarshal(buf, &event); err == nil && len(event.Payments) > 0 {
return event, nil
}
}
{
var event = domain.RoundStarted{}
if err := json.Unmarshal(buf, &event); err == nil && event.Timestamp > 0 {
return event, nil
}
}
return nil, fmt.Errorf("unknown event")
}

View File

@@ -0,0 +1,245 @@
package badgerdb
import (
"context"
"fmt"
"path/filepath"
"strings"
"github.com/ark-network/ark/internal/core/domain"
dbtypes "github.com/ark-network/ark/internal/infrastructure/db/types"
"github.com/dgraph-io/badger/v4"
"github.com/timshannon/badgerhold/v4"
)
const vtxoStoreDir = "vtxos"
type vtxoRepository struct {
store *badgerhold.Store
}
func NewVtxoRepository(config ...interface{}) (dbtypes.VtxoStore, error) {
if len(config) != 2 {
return nil, fmt.Errorf("invalid config")
}
baseDir, ok := config[0].(string)
if !ok {
return nil, fmt.Errorf("invalid base directory")
}
var logger badger.Logger
if config[1] != nil {
logger, ok = config[1].(badger.Logger)
if !ok {
return nil, fmt.Errorf("invalid logger")
}
}
var dir string
if len(baseDir) > 0 {
dir = filepath.Join(baseDir, vtxoStoreDir)
}
store, err := createDB(dir, logger)
if err != nil {
return nil, fmt.Errorf("failed to open round events store: %s", err)
}
return &vtxoRepository{store}, nil
}
func (r *vtxoRepository) AddVtxos(
ctx context.Context, vtxos []domain.Vtxo,
) error {
return r.addVtxos(ctx, vtxos)
}
func (r *vtxoRepository) SpendVtxos(
ctx context.Context, vtxoKeys []domain.VtxoKey,
) error {
for _, vtxoKey := range vtxoKeys {
if err := r.spendVtxo(ctx, vtxoKey); err != nil {
return err
}
}
return nil
}
func (r *vtxoRepository) RedeemVtxos(
ctx context.Context, vtxoKeys []domain.VtxoKey,
) ([]domain.Vtxo, error) {
vtxos := make([]domain.Vtxo, 0, len(vtxoKeys))
for _, vtxoKey := range vtxoKeys {
vtxo, err := r.redeemVtxo(ctx, vtxoKey)
if err != nil {
return nil, err
}
if vtxo != nil {
vtxos = append(vtxos, *vtxo)
}
}
return vtxos, nil
}
func (r *vtxoRepository) GetVtxos(
ctx context.Context, vtxoKeys []domain.VtxoKey,
) ([]domain.Vtxo, error) {
vtxos := make([]domain.Vtxo, 0, len(vtxoKeys))
for _, vtxoKey := range vtxoKeys {
vtxo, err := r.getVtxo(ctx, vtxoKey)
if err != nil {
return nil, err
}
vtxos = append(vtxos, *vtxo)
}
return vtxos, nil
}
func (r *vtxoRepository) GetVtxosForRound(
ctx context.Context, txid string,
) ([]domain.Vtxo, error) {
query := badgerhold.Where("Txid").Eq(txid)
return r.findVtxos(ctx, query)
}
func (r *vtxoRepository) GetSpendableVtxos(
ctx context.Context, pubkey string,
) ([]domain.Vtxo, error) {
query := badgerhold.Where("Spent").Eq(false).And("Redeemed").Eq(false).And("Swept").Eq(false)
if len(pubkey) > 0 {
query = query.And("Pubkey").Eq(pubkey)
}
return r.findVtxos(ctx, query)
}
func (r *vtxoRepository) SweepVtxos(
ctx context.Context, vtxoKeys []domain.VtxoKey,
) error {
for _, vtxoKey := range vtxoKeys {
if err := r.sweepVtxo(ctx, vtxoKey); err != nil {
return err
}
}
return nil
}
func (r *vtxoRepository) Close() {
r.store.Close()
}
func (r *vtxoRepository) addVtxos(
ctx context.Context, vtxos []domain.Vtxo,
) (err error) {
for _, vtxo := range vtxos {
vtxoKey := vtxo.VtxoKey.Hash()
if ctx.Value("tx") != nil {
tx := ctx.Value("tx").(*badger.Txn)
err = r.store.TxInsert(tx, vtxoKey, vtxo)
} else {
err = r.store.Insert(vtxoKey, vtxo)
}
}
if err != nil && err == badgerhold.ErrKeyExists {
err = nil
}
return
}
func (r *vtxoRepository) getVtxo(
ctx context.Context, vtxoKey domain.VtxoKey,
) (*domain.Vtxo, error) {
var vtxo domain.Vtxo
var err error
if ctx.Value("tx") != nil {
tx := ctx.Value("tx").(*badger.Txn)
err = r.store.TxGet(tx, vtxoKey.Hash(), &vtxo)
} else {
err = r.store.Get(vtxoKey.Hash(), &vtxo)
}
if err != nil && err == badgerhold.ErrNotFound {
return nil, fmt.Errorf("vtxo %s:%d not found", vtxoKey.Txid, vtxoKey.VOut)
}
return &vtxo, nil
}
func (r *vtxoRepository) spendVtxo(ctx context.Context, vtxoKey domain.VtxoKey) error {
vtxo, err := r.getVtxo(ctx, vtxoKey)
if err != nil {
if strings.Contains(err.Error(), "not found") {
return nil
}
return err
}
if vtxo.Spent {
return nil
}
vtxo.Spent = true
if ctx.Value("tx") != nil {
tx := ctx.Value("tx").(*badger.Txn)
err = r.store.TxUpdate(tx, vtxoKey.Hash(), *vtxo)
} else {
err = r.store.Update(vtxoKey.Hash(), *vtxo)
}
return err
}
func (r *vtxoRepository) redeemVtxo(ctx context.Context, vtxoKey domain.VtxoKey) (*domain.Vtxo, error) {
vtxo, err := r.getVtxo(ctx, vtxoKey)
if err != nil {
if strings.Contains(err.Error(), "not found") {
return nil, nil
}
return nil, err
}
if vtxo.Redeemed {
return nil, nil
}
vtxo.Redeemed = true
if ctx.Value("tx") != nil {
tx := ctx.Value("tx").(*badger.Txn)
err = r.store.TxUpdate(tx, vtxoKey.Hash(), *vtxo)
} else {
err = r.store.Update(vtxoKey.Hash(), *vtxo)
}
if err != nil {
return nil, err
}
return vtxo, nil
}
func (r *vtxoRepository) findVtxos(ctx context.Context, query *badgerhold.Query) ([]domain.Vtxo, error) {
var vtxos []domain.Vtxo
var err error
if ctx.Value("tx") != nil {
tx := ctx.Value("tx").(*badger.Txn)
err = r.store.TxFind(tx, &vtxos, query)
} else {
err = r.store.Find(&vtxos, query)
}
return vtxos, err
}
func (r *vtxoRepository) sweepVtxo(ctx context.Context, vtxoKey domain.VtxoKey) error {
vtxo, err := r.getVtxo(ctx, vtxoKey)
if err != nil {
return err
}
if vtxo.Swept {
return nil
}
vtxo.Swept = true
if ctx.Value("tx") != nil {
tx := ctx.Value("tx").(*badger.Txn)
err = r.store.TxUpdate(tx, vtxoKey.Hash(), *vtxo)
} else {
err = r.store.Update(vtxoKey.Hash(), *vtxo)
}
if err != nil {
return err
}
return nil
}

View File

@@ -0,0 +1,90 @@
package db
import (
"fmt"
"github.com/ark-network/ark/internal/core/domain"
"github.com/ark-network/ark/internal/core/ports"
badgerdb "github.com/ark-network/ark/internal/infrastructure/db/badger"
dbtypes "github.com/ark-network/ark/internal/infrastructure/db/types"
)
var (
eventStoreTypes = map[string]func(...interface{}) (dbtypes.EventStore, error){
"badger": badgerdb.NewRoundEventRepository,
}
roundStoreTypes = map[string]func(...interface{}) (dbtypes.RoundStore, error){
"badger": badgerdb.NewRoundRepository,
}
vtxoStoreTypes = map[string]func(...interface{}) (dbtypes.VtxoStore, error){
"badger": badgerdb.NewVtxoRepository,
}
)
type ServiceConfig struct {
EventStoreType string
RoundStoreType string
VtxoStoreType string
EventStoreConfig []interface{}
RoundStoreConfig []interface{}
VtxoStoreConfig []interface{}
}
type service struct {
eventStore dbtypes.EventStore
roundStore dbtypes.RoundStore
vtxoStore dbtypes.VtxoStore
}
func NewService(config ServiceConfig) (ports.RepoManager, error) {
eventStoreFactory, ok := eventStoreTypes[config.EventStoreType]
if !ok {
return nil, fmt.Errorf("event store type not supported")
}
roundStoreFactory, ok := roundStoreTypes[config.RoundStoreType]
if !ok {
return nil, fmt.Errorf("round store type not supported")
}
vtxoStoreFactory, ok := vtxoStoreTypes[config.VtxoStoreType]
if !ok {
return nil, fmt.Errorf("vtxo store type not supported")
}
eventStore, err := eventStoreFactory(config.EventStoreConfig...)
if err != nil {
return nil, fmt.Errorf("failed to open event store: %s", err)
}
roundStore, err := roundStoreFactory(config.RoundStoreConfig...)
if err != nil {
return nil, fmt.Errorf("failed to open round store: %s", err)
}
vtxoStore, err := vtxoStoreFactory(config.VtxoStoreConfig...)
if err != nil {
return nil, fmt.Errorf("failed to open vtxo store: %s", err)
}
return &service{eventStore, roundStore, vtxoStore}, nil
}
func (s *service) RegisterEventsHandler(handler func(round *domain.Round)) {
s.eventStore.RegisterEventsHandler(handler)
}
func (s *service) Events() domain.RoundEventRepository {
return s.eventStore
}
func (s *service) Rounds() domain.RoundRepository {
return s.roundStore
}
func (s *service) Vtxos() domain.VtxoRepository {
return s.vtxoStore
}
func (s *service) Close() {
s.eventStore.Close()
s.roundStore.Close()
s.vtxoStore.Close()
}

View File

@@ -0,0 +1,417 @@
package db_test
import (
"context"
"reflect"
"testing"
"time"
"github.com/ark-network/ark/common/tree"
"github.com/ark-network/ark/internal/core/domain"
"github.com/ark-network/ark/internal/core/ports"
"github.com/ark-network/ark/internal/infrastructure/db"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
const (
emptyPtx = "cHNldP8BAgQCAAAAAQQBAAEFAQABBgEDAfsEAgAAAAA="
emptyTx = "0200000000000000000000"
txid = "00000000000000000000000000000000000000000000000000000000000000000"
pubkey1 = "0300000000000000000000000000000000000000000000000000000000000000001"
pubkey2 = "0200000000000000000000000000000000000000000000000000000000000000002"
)
var congestionTree = [][]tree.Node{
{
{
Txid: txid,
Tx: emptyPtx,
ParentTxid: txid,
},
},
{
{
Txid: txid,
Tx: emptyPtx,
ParentTxid: txid,
},
{
Txid: txid,
Tx: emptyPtx,
ParentTxid: txid,
},
},
{
{
Txid: txid,
Tx: emptyPtx,
ParentTxid: txid,
},
{
Txid: txid,
Tx: emptyPtx,
ParentTxid: txid,
},
{
Txid: txid,
Tx: emptyPtx,
ParentTxid: txid,
},
{
Txid: txid,
Tx: emptyPtx,
ParentTxid: txid,
},
},
}
func TestService(t *testing.T) {
tests := []struct {
name string
config db.ServiceConfig
}{
{
name: "repo_manager_with_badger_stores",
config: db.ServiceConfig{
EventStoreType: "badger",
RoundStoreType: "badger",
VtxoStoreType: "badger",
EventStoreConfig: []interface{}{"", nil},
RoundStoreConfig: []interface{}{"", nil},
VtxoStoreConfig: []interface{}{"", nil},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
svc, err := db.NewService(tt.config)
require.NoError(t, err)
require.NotNil(t, svc)
testRoundEventRepository(t, svc)
testRoundRepository(t, svc)
testVtxoRepository(t, svc)
svc.Close()
})
}
}
func testRoundEventRepository(t *testing.T, svc ports.RepoManager) {
t.Run("test_event_repository", func(t *testing.T) {
fixtures := []struct {
roundId string
events []domain.RoundEvent
handler func(*domain.Round)
}{
{
roundId: "42dd81f7-cadd-482c-bf69-8e9209aae9f3",
events: []domain.RoundEvent{
domain.RoundStarted{
Id: "42dd81f7-cadd-482c-bf69-8e9209aae9f3",
Timestamp: 1701190270,
},
},
handler: func(round *domain.Round) {
require.NotNil(t, round)
require.Len(t, round.Events(), 1)
require.True(t, round.IsStarted())
require.False(t, round.IsFailed())
require.False(t, round.IsEnded())
},
},
{
roundId: "1ea610ff-bf3e-4068-9bfd-b6c3f553467e",
events: []domain.RoundEvent{
domain.RoundStarted{
Id: "1ea610ff-bf3e-4068-9bfd-b6c3f553467e",
Timestamp: 1701190270,
},
domain.RoundFinalizationStarted{
Id: "1ea610ff-bf3e-4068-9bfd-b6c3f553467e",
CongestionTree: congestionTree,
Connectors: []string{emptyPtx, emptyPtx},
PoolTx: emptyTx,
},
},
handler: func(round *domain.Round) {
require.NotNil(t, round)
require.Len(t, round.Events(), 2)
require.Len(t, round.CongestionTree, 3)
require.Equal(t, round.CongestionTree.NumberOfNodes(), 7)
require.Len(t, round.Connectors, 2)
},
},
{
roundId: "7578231e-428d-45ae-aaa4-e62c77ad5cec",
events: []domain.RoundEvent{
domain.RoundStarted{
Id: "7578231e-428d-45ae-aaa4-e62c77ad5cec",
Timestamp: 1701190270,
},
domain.RoundFinalizationStarted{
Id: "7578231e-428d-45ae-aaa4-e62c77ad5cec",
CongestionTree: congestionTree,
Connectors: []string{emptyPtx, emptyPtx},
PoolTx: emptyTx,
},
domain.RoundFinalized{
Id: "7578231e-428d-45ae-aaa4-e62c77ad5cec",
Txid: txid,
ForfeitTxs: []string{emptyPtx, emptyPtx, emptyPtx, emptyPtx},
Timestamp: 1701190300,
},
},
handler: func(round *domain.Round) {
require.NotNil(t, round)
require.Len(t, round.Events(), 3)
require.False(t, round.IsStarted())
require.False(t, round.IsFailed())
require.True(t, round.IsEnded())
require.NotEmpty(t, round.Txid)
},
},
}
ctx := context.Background()
for _, f := range fixtures {
svc.RegisterEventsHandler(f.handler)
err := svc.Events().Save(ctx, f.roundId, f.events...)
require.NoError(t, err)
round, err := svc.Events().Load(ctx, f.roundId)
require.NoError(t, err)
require.NotNil(t, round)
require.Equal(t, f.roundId, round.Id)
require.Len(t, round.Events(), len(f.events))
}
})
}
func testRoundRepository(t *testing.T, svc ports.RepoManager) {
t.Run("test_round_repository", func(t *testing.T) {
ctx := context.Background()
now := time.Now()
roundId := uuid.New().String()
round, err := svc.Rounds().GetRoundWithId(ctx, roundId)
require.Error(t, err)
require.Nil(t, round)
events := []domain.RoundEvent{
domain.RoundStarted{
Id: roundId,
Timestamp: now.Unix(),
},
}
round = domain.NewRoundFromEvents(events)
err = svc.Rounds().AddOrUpdateRound(ctx, *round)
require.NoError(t, err)
currentRound, err := svc.Rounds().GetCurrentRound(ctx)
require.NoError(t, err)
require.NotNil(t, currentRound)
require.Condition(t, roundsMatch(*round, *currentRound))
roundById, err := svc.Rounds().GetRoundWithId(ctx, roundId)
require.NoError(t, err)
require.NotNil(t, roundById)
require.Condition(t, roundsMatch(*round, *roundById))
newEvents := []domain.RoundEvent{
domain.PaymentsRegistered{
Id: roundId,
Payments: []domain.Payment{
{
Id: uuid.New().String(),
Inputs: []domain.Vtxo{{}},
Receivers: []domain.Receiver{{}},
},
{
Id: uuid.New().String(),
Inputs: []domain.Vtxo{{}},
Receivers: []domain.Receiver{{}, {}, {}},
},
},
},
domain.RoundFinalizationStarted{
Id: roundId,
CongestionTree: congestionTree,
Connectors: []string{emptyPtx, emptyPtx},
PoolTx: emptyTx,
},
}
events = append(events, newEvents...)
updatedRound := domain.NewRoundFromEvents(events)
err = svc.Rounds().AddOrUpdateRound(ctx, *updatedRound)
require.NoError(t, err)
currentRound, err = svc.Rounds().GetCurrentRound(ctx)
require.NoError(t, err)
require.NotNil(t, currentRound)
require.Condition(t, roundsMatch(*updatedRound, *currentRound))
roundById, err = svc.Rounds().GetRoundWithId(ctx, updatedRound.Id)
require.NoError(t, err)
require.NotNil(t, currentRound)
require.Condition(t, roundsMatch(*updatedRound, *roundById))
newEvents = []domain.RoundEvent{
domain.RoundFinalized{
Id: roundId,
Txid: txid,
ForfeitTxs: []string{emptyPtx, emptyPtx, emptyPtx, emptyPtx},
Timestamp: now.Add(60 * time.Second).Unix(),
},
}
events = append(events, newEvents...)
finalizedRound := domain.NewRoundFromEvents(events)
err = svc.Rounds().AddOrUpdateRound(ctx, *finalizedRound)
require.NoError(t, err)
currentRound, err = svc.Rounds().GetCurrentRound(ctx)
require.Error(t, err)
require.Nil(t, currentRound)
roundById, err = svc.Rounds().GetRoundWithId(ctx, roundId)
require.NoError(t, err)
require.NotNil(t, roundById)
require.Condition(t, roundsMatch(*finalizedRound, *roundById))
roundByTxid, err := svc.Rounds().GetRoundWithTxid(ctx, txid)
require.NoError(t, err)
require.NotNil(t, roundByTxid)
require.Condition(t, roundsMatch(*finalizedRound, *roundByTxid))
})
}
func testVtxoRepository(t *testing.T, svc ports.RepoManager) {
t.Run("test_vtxo_repository", func(t *testing.T) {
ctx := context.Background()
userVtxos := []domain.Vtxo{
{
VtxoKey: domain.VtxoKey{
Txid: txid,
VOut: 0,
},
Receiver: domain.Receiver{
Pubkey: pubkey1,
Amount: 1000,
},
},
{
VtxoKey: domain.VtxoKey{
Txid: txid,
VOut: 1,
},
Receiver: domain.Receiver{
Pubkey: pubkey1,
Amount: 2000,
},
},
}
newVtxos := append(userVtxos, domain.Vtxo{
VtxoKey: domain.VtxoKey{
Txid: txid,
VOut: 1,
},
Receiver: domain.Receiver{
Pubkey: pubkey2,
Amount: 2000,
},
})
vtxoKeys := make([]domain.VtxoKey, 0, len(userVtxos))
for _, v := range userVtxos {
vtxoKeys = append(vtxoKeys, v.VtxoKey)
}
vtxos, err := svc.Vtxos().GetVtxos(ctx, vtxoKeys)
require.Error(t, err)
require.Empty(t, vtxos)
spendableVtxos, err := svc.Vtxos().GetSpendableVtxos(ctx, pubkey1)
require.NoError(t, err)
require.Empty(t, spendableVtxos)
spendableVtxos, err = svc.Vtxos().GetSpendableVtxos(ctx, "")
require.NoError(t, err)
require.Empty(t, spendableVtxos)
err = svc.Vtxos().AddVtxos(ctx, newVtxos)
require.NoError(t, err)
vtxos, err = svc.Vtxos().GetVtxos(ctx, vtxoKeys)
require.NoError(t, err)
require.Exactly(t, userVtxos, vtxos)
spendableVtxos, err = svc.Vtxos().GetSpendableVtxos(ctx, pubkey1)
require.NoError(t, err)
require.Exactly(t, vtxos, spendableVtxos)
spendableVtxos, err = svc.Vtxos().GetSpendableVtxos(ctx, "")
require.NoError(t, err)
require.Exactly(t, userVtxos, spendableVtxos)
err = svc.Vtxos().SpendVtxos(ctx, vtxoKeys[:1])
require.NoError(t, err)
spentVtxos, err := svc.Vtxos().GetVtxos(ctx, vtxoKeys[:1])
require.NoError(t, err)
require.Len(t, spentVtxos, len(vtxoKeys[:1]))
for _, v := range spentVtxos {
require.True(t, v.Spent)
}
spendableVtxos, err = svc.Vtxos().GetSpendableVtxos(ctx, pubkey1)
require.NoError(t, err)
require.Exactly(t, vtxos[1:], spendableVtxos)
})
}
func roundsMatch(expected, got domain.Round) assert.Comparison {
return func() bool {
if expected.Id != got.Id {
return false
}
if expected.StartingTimestamp != got.StartingTimestamp {
return false
}
if expected.EndingTimestamp != got.EndingTimestamp {
return false
}
if expected.Stage != got.Stage {
return false
}
if !reflect.DeepEqual(expected.Payments, got.Payments) {
return false
}
if expected.Txid != got.Txid {
return false
}
if expected.UnsignedTx != got.UnsignedTx {
return false
}
if !reflect.DeepEqual(expected.ForfeitTxs, got.ForfeitTxs) {
return false
}
if !reflect.DeepEqual(expected.CongestionTree, got.CongestionTree) {
return false
}
if !reflect.DeepEqual(expected.Connectors, got.Connectors) {
return false
}
if expected.Version != got.Version {
return false
}
return true
}
}

View File

@@ -0,0 +1,19 @@
package dbtypes
import "github.com/ark-network/ark/internal/core/domain"
type EventStore interface {
domain.RoundEventRepository
RegisterEventsHandler(func(*domain.Round))
Close()
}
type RoundStore interface {
domain.RoundRepository
Close()
}
type VtxoStore interface {
domain.VtxoRepository
Close()
}

View File

@@ -0,0 +1,30 @@
package oceanwallet
import (
"context"
pb "github.com/ark-network/ark/api-spec/protobuf/gen/ocean/v1"
"github.com/vulpemventures/go-elements/address"
)
func (s *service) DeriveAddresses(
ctx context.Context, numOfAddresses int,
) ([]string, error) {
res, err := s.accountClient.DeriveAddresses(ctx, &pb.DeriveAddressesRequest{
AccountName: accountLabel,
NumOfAddresses: uint64(numOfAddresses),
})
if err != nil {
return nil, err
}
addresses := make([]string, 0, numOfAddresses)
for _, addr := range res.GetAddresses() {
if isConf, _ := address.IsConfidential(addr); !isConf {
addresses = append(addresses, addr)
continue
}
info, _ := address.FromConfidential(addr)
addresses = append(addresses, info.Address)
}
return addresses, nil
}

View File

@@ -0,0 +1,47 @@
package oceanwallet
import (
"context"
"crypto/sha256"
"encoding/hex"
pb "github.com/ark-network/ark/api-spec/protobuf/gen/ocean/v1"
"github.com/ark-network/ark/internal/core/domain"
"github.com/btcsuite/btcd/chaincfg/chainhash"
)
func (s *service) WatchScripts(ctx context.Context, scripts []string) error {
for _, script := range scripts {
if _, err := s.notifyClient.WatchExternalScript(ctx, &pb.WatchExternalScriptRequest{
Script: script,
}); err != nil {
return err
}
}
return nil
}
func (s *service) UnwatchScripts(ctx context.Context, scripts []string) error {
for _, script := range scripts {
scriptHash := calcScriptHash(script)
if _, err := s.notifyClient.UnwatchExternalScript(ctx, &pb.UnwatchExternalScriptRequest{
Label: scriptHash,
}); err != nil {
return err
}
}
return nil
}
func (s *service) GetNotificationChannel(ctx context.Context) chan []domain.VtxoKey {
return s.chVtxos
}
func calcScriptHash(script string) string {
buf, _ := hex.DecodeString(script)
hashedBuf := sha256.Sum256(buf)
hash, _ := chainhash.NewHash(hashedBuf[:])
return hash.String()
}

View File

@@ -0,0 +1,140 @@
package oceanwallet
import (
"context"
"fmt"
"io"
"strings"
pb "github.com/ark-network/ark/api-spec/protobuf/gen/ocean/v1"
"github.com/ark-network/ark/internal/core/domain"
"github.com/ark-network/ark/internal/core/ports"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/status"
)
type service struct {
addr string
conn *grpc.ClientConn
walletClient pb.WalletServiceClient
accountClient pb.AccountServiceClient
txClient pb.TransactionServiceClient
notifyClient pb.NotificationServiceClient
chVtxos chan []domain.VtxoKey
}
func NewService(addr string) (ports.WalletService, error) {
conn, err := grpc.Dial(addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
}
walletClient := pb.NewWalletServiceClient(conn)
accountClient := pb.NewAccountServiceClient(conn)
txClient := pb.NewTransactionServiceClient(conn)
notifyClient := pb.NewNotificationServiceClient(conn)
chVtxos := make(chan []domain.VtxoKey)
svc := &service{
addr: addr,
conn: conn,
walletClient: walletClient,
accountClient: accountClient,
txClient: txClient,
notifyClient: notifyClient,
chVtxos: chVtxos,
}
ctx := context.Background()
status, err := svc.Status(ctx)
if err != nil {
return nil, err
}
if !(status.IsInitialized() && status.IsUnlocked()) {
return nil, fmt.Errorf("wallet must be already initialized and unlocked")
}
// Create ark account at startup if needed.
info, err := walletClient.GetInfo(ctx, &pb.GetInfoRequest{})
if err != nil {
return nil, err
}
found := false
for _, account := range info.GetAccounts() {
if account.GetLabel() == accountLabel {
found = true
break
}
}
if !found {
if _, err := accountClient.CreateAccountBIP44(ctx, &pb.CreateAccountBIP44Request{
Label: accountLabel,
Unconfidential: true,
}); err != nil {
return nil, err
}
}
go svc.listenToNotificaitons()
return svc, nil
}
func (s *service) Close() {
close(s.chVtxos)
s.conn.Close()
}
func (s *service) listenToNotificaitons() {
var stream pb.NotificationService_UtxosNotificationsClient
var err error
for {
stream, err = s.notifyClient.UtxosNotifications(context.Background(), &pb.UtxosNotificationsRequest{})
if err != nil {
continue
}
break
}
for {
msg, err := stream.Recv()
if err != nil {
if err == io.EOF || status.Convert(err).Code() == codes.Canceled {
return
}
log.WithError(err).Warn("received unexpected error from source")
return
}
if msg.GetEventType() != pb.UtxoEventType_UTXO_EVENT_TYPE_NEW &&
msg.GetEventType() != pb.UtxoEventType_UTXO_EVENT_TYPE_CONFIRMED {
continue
}
vtxos := toVtxos(msg.GetUtxos())
if len(vtxos) > 0 {
go func() {
s.chVtxos <- vtxos
}()
}
}
}
func toVtxos(utxos []*pb.Utxo) []domain.VtxoKey {
vtxos := make([]domain.VtxoKey, 0, len(utxos))
for _, utxo := range utxos {
// We want to notify for activity related to vtxos owner, therefore we skip
// returning anything related to the internal accounts of the wallet, like
// for example bip84-account0.
if strings.HasPrefix(utxo.GetAccountName(), "bip") {
continue
}
vtxos = append(vtxos, domain.VtxoKey{
Txid: utxo.GetTxid(),
VOut: utxo.GetIndex(),
})
}
return vtxos
}

View File

@@ -0,0 +1,270 @@
package oceanwallet
import (
"context"
"encoding/binary"
"encoding/hex"
"fmt"
"strings"
"time"
pb "github.com/ark-network/ark/api-spec/protobuf/gen/ocean/v1"
"github.com/ark-network/ark/common/tree"
"github.com/ark-network/ark/internal/core/ports"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/txscript"
"github.com/vulpemventures/go-elements/elementsutil"
"github.com/vulpemventures/go-elements/psetv2"
)
const (
zero32 = "0000000000000000000000000000000000000000000000000000000000000000"
)
func (s *service) SignPset(
ctx context.Context, pset string, extractRawTx bool,
) (string, error) {
res, err := s.txClient.SignPset(ctx, &pb.SignPsetRequest{
Pset: pset,
})
if err != nil {
return "", err
}
signedPset := res.GetPset()
if !extractRawTx {
return signedPset, nil
}
ptx, err := psetv2.NewPsetFromBase64(signedPset)
if err != nil {
return "", err
}
if err := psetv2.MaybeFinalizeAll(ptx); err != nil {
return "", fmt.Errorf("failed to finalize signed pset: %s", err)
}
extractedTx, err := psetv2.Extract(ptx)
if err != nil {
return "", fmt.Errorf("failed to extract signed pset: %s", err)
}
txHex, err := extractedTx.ToHex()
if err != nil {
return "", fmt.Errorf("failed to convert extracted tx to hex: %s", err)
}
return txHex, nil
}
func (s *service) SelectUtxos(ctx context.Context, asset string, amount uint64) ([]ports.TxInput, uint64, error) {
res, err := s.txClient.SelectUtxos(ctx, &pb.SelectUtxosRequest{
AccountName: accountLabel,
TargetAsset: asset,
TargetAmount: amount,
})
if err != nil {
return nil, 0, err
}
inputs := make([]ports.TxInput, 0, len(res.GetUtxos()))
for _, utxo := range res.GetUtxos() {
// check that the utxos are not confidential
if utxo.GetAssetBlinder() != zero32 || utxo.GetValueBlinder() != zero32 {
return nil, 0, fmt.Errorf("utxo is confidential")
}
inputs = append(inputs, utxo)
}
return inputs, res.GetChange(), nil
}
func (s *service) GetTransaction(
ctx context.Context, txid string,
) (string, int64, error) {
res, err := s.txClient.GetTransaction(ctx, &pb.GetTransactionRequest{
Txid: txid,
})
if err != nil {
return "", 0, err
}
if res.GetBlockDetails().GetTimestamp() > 0 {
return res.GetTxHex(), res.BlockDetails.GetTimestamp(), nil
}
// if not confirmed, we return now + 30 secs to estimate the next blocktime
return res.GetTxHex(), time.Now().Unix() + 30, nil
}
func (s *service) BroadcastTransaction(
ctx context.Context, txHex string,
) (string, error) {
res, err := s.txClient.BroadcastTransaction(
ctx, &pb.BroadcastTransactionRequest{
TxHex: txHex,
},
)
if err != nil {
if strings.Contains(err.Error(), "non-BIP68-final") {
return "", fmt.Errorf("non-BIP68-final")
}
return "", err
}
return res.GetTxid(), nil
}
func (s *service) IsTransactionPublished(
ctx context.Context, txid string,
) (bool, int64, error) {
_, blocktime, err := s.GetTransaction(ctx, txid)
if err != nil {
if strings.Contains(strings.ToLower(err.Error()), "missing transaction") {
return false, 0, nil
}
return false, 0, err
}
return true, blocktime, nil
}
func (s *service) SignPsetWithKey(ctx context.Context, b64 string, indexes []int) (string, error) {
pset, err := psetv2.NewPsetFromBase64(b64)
if err != nil {
return "", err
}
if indexes == nil {
for i := 0; i < len(pset.Inputs); i++ {
indexes = append(indexes, i)
}
}
key, masterKey, err := s.getPubkey(ctx)
if err != nil {
return "", err
}
fingerprint := binary.LittleEndian.Uint32(masterKey.FingerPrint)
extendedKey, err := masterKey.Serialize()
if err != nil {
return "", err
}
pset.Global.Xpubs = []psetv2.Xpub{{
ExtendedKey: extendedKey[:len(extendedKey)-4],
MasterFingerprint: fingerprint,
DerivationPath: derivationPath,
}}
updater, err := psetv2.NewUpdater(pset)
if err != nil {
return "", err
}
bip32derivation := psetv2.DerivationPathWithPubKey{
PubKey: key.SerializeCompressed(),
MasterKeyFingerprint: fingerprint,
Bip32Path: derivationPath,
}
for _, i := range indexes {
if len(pset.Inputs[i].TapLeafScript) == 0 {
return "", fmt.Errorf("no tap leaf script found for input %d", i)
}
leafHash := pset.Inputs[i].TapLeafScript[0].TapHash()
if err := updater.AddInTapBip32Derivation(i, psetv2.TapDerivationPathWithPubKey{
DerivationPathWithPubKey: bip32derivation,
LeafHashes: [][]byte{leafHash[:]},
}); err != nil {
return "", err
}
if err := updater.AddInSighashType(i, txscript.SigHashDefault); err != nil {
return "", err
}
}
unsignedPset, err := pset.ToBase64()
if err != nil {
return "", err
}
signedPset, err := s.txClient.SignPsetWithSchnorrKey(ctx, &pb.SignPsetWithSchnorrKeyRequest{
Tx: unsignedPset,
SighashType: uint32(txscript.SigHashDefault),
})
if err != nil {
return "", err
}
return signedPset.GetSignedTx(), nil
}
func (s *service) EstimateFees(
ctx context.Context, pset string,
) (uint64, error) {
tx, err := psetv2.NewPsetFromBase64(pset)
if err != nil {
return 0, err
}
inputs := make([]*pb.Input, 0, len(tx.Inputs))
outputs := make([]*pb.Output, 0, len(tx.Outputs))
for _, in := range tx.Inputs {
pbInput := &pb.Input{
Txid: chainhash.Hash(in.PreviousTxid).String(),
Index: in.PreviousTxIndex,
}
if len(in.TapLeafScript) == 1 {
isSweep, _, _, err := tree.DecodeSweepScript(in.TapLeafScript[0].Script)
if err != nil {
return 0, err
}
if isSweep {
pbInput.WitnessSize = 64
pbInput.ScriptsigSize = 0
}
} else {
if in.WitnessUtxo == nil {
return 0, fmt.Errorf("missing witness utxo, cannot estimate fees")
}
pbInput.Script = hex.EncodeToString(in.WitnessUtxo.Script)
}
inputs = append(inputs, pbInput)
}
for _, out := range tx.Outputs {
outputs = append(outputs, &pb.Output{
Asset: elementsutil.AssetHashFromBytes(
append([]byte{0x01}, out.Asset...),
),
Amount: out.Value,
Script: hex.EncodeToString(out.Script),
})
}
fee, err := s.txClient.EstimateFees(
ctx,
&pb.EstimateFeesRequest{
Inputs: inputs,
Outputs: outputs,
},
)
if err != nil {
return 0, fmt.Errorf("failed to estimate fees: %s", err)
}
// we add 5 sats in order to avoid min-relay-fee not met errors
return fee.GetFeeAmount() + 5, nil
}

View File

@@ -0,0 +1,95 @@
package oceanwallet
import (
"context"
"fmt"
pb "github.com/ark-network/ark/api-spec/protobuf/gen/ocean/v1"
"github.com/ark-network/ark/internal/core/ports"
"github.com/btcsuite/btcd/btcutil/hdkeychain"
"github.com/decred/dcrd/dcrec/secp256k1/v4"
"github.com/vulpemventures/go-bip32"
)
const accountLabel = "ark"
var derivationPath = []uint32{0, 0}
func (s *service) GetPubkey(ctx context.Context) (*secp256k1.PublicKey, error) {
key, _, err := s.getPubkey(ctx)
return key, err
}
func (s *service) Status(
ctx context.Context,
) (ports.WalletStatus, error) {
res, err := s.walletClient.Status(ctx, &pb.StatusRequest{})
if err != nil {
return nil, err
}
return walletStatus{res}, nil
}
type walletStatus struct {
*pb.StatusResponse
}
func (w walletStatus) IsInitialized() bool {
return w.StatusResponse.GetInitialized()
}
func (w walletStatus) IsUnlocked() bool {
return w.StatusResponse.GetUnlocked()
}
func (w walletStatus) IsSynced() bool {
return w.StatusResponse.GetSynced()
}
func (s *service) findAccount(ctx context.Context, label string) (*pb.AccountInfo, error) {
res, err := s.walletClient.GetInfo(ctx, &pb.GetInfoRequest{})
if err != nil {
return nil, err
}
if len(res.GetAccounts()) <= 0 {
return nil, fmt.Errorf("wallet is locked")
}
for _, account := range res.GetAccounts() {
if account.GetLabel() == label {
return account, nil
}
}
return nil, fmt.Errorf("account not found")
}
func (s *service) getPubkey(ctx context.Context) (*secp256k1.PublicKey, *bip32.Key, error) {
account, err := s.findAccount(ctx, accountLabel)
if err != nil {
return nil, nil, err
}
xpub := account.GetXpubs()[0]
node, err := hdkeychain.NewKeyFromString(xpub)
if err != nil {
return nil, nil, err
}
for _, i := range derivationPath {
node, err = node.Derive(i)
if err != nil {
return nil, nil, err
}
}
key, err := node.ECPubKey()
if err != nil {
return nil, nil, err
}
masterKey, err := bip32.B58Deserialize(xpub)
if err != nil {
return nil, nil, err
}
return key, masterKey, nil
}

View File

@@ -0,0 +1,45 @@
package scheduler
import (
"fmt"
"time"
"github.com/ark-network/ark/internal/core/ports"
"github.com/go-co-op/gocron"
)
type service struct {
scheduler *gocron.Scheduler
}
func NewScheduler() ports.SchedulerService {
svc := gocron.NewScheduler(time.UTC)
return &service{svc}
}
func (s *service) Start() {
s.scheduler.StartAsync()
}
func (s *service) Stop() {
s.scheduler.Stop()
}
func (s *service) ScheduleTask(interval int64, immediate bool, task func()) error {
if immediate {
_, err := s.scheduler.Every(int(interval)).Seconds().Do(task)
return err
}
_, err := s.scheduler.Every(int(interval)).Seconds().WaitForSchedule().Do(task)
return err
}
func (s *service) ScheduleTaskOnce(at int64, task func()) error {
delay := at - time.Now().Unix()
if delay < 0 {
return fmt.Errorf("cannot schedule task in the past")
}
_, err := s.scheduler.Every(int(delay)).Seconds().WaitForSchedule().LimitRunsTo(1).Do(task)
return err
}

View File

@@ -0,0 +1,483 @@
package txbuilder
import (
"context"
"encoding/hex"
"fmt"
"github.com/ark-network/ark/common/tree"
"github.com/ark-network/ark/internal/core/domain"
"github.com/ark-network/ark/internal/core/ports"
"github.com/decred/dcrd/dcrec/secp256k1/v4"
"github.com/vulpemventures/go-elements/address"
"github.com/vulpemventures/go-elements/network"
"github.com/vulpemventures/go-elements/psetv2"
"github.com/vulpemventures/go-elements/taproot"
)
const (
connectorAmount = uint64(450)
dustLimit = uint64(450)
)
type txBuilder struct {
wallet ports.WalletService
net *network.Network
roundLifetime int64 // in seconds
}
func NewTxBuilder(
wallet ports.WalletService, net network.Network, roundLifetime int64,
) ports.TxBuilder {
return &txBuilder{wallet, &net, roundLifetime}
}
func (b *txBuilder) GetVtxoScript(userPubkey, aspPubkey *secp256k1.PublicKey) ([]byte, error) {
outputScript, _, err := b.getLeafScriptAndTree(userPubkey, aspPubkey)
if err != nil {
return nil, err
}
return outputScript, nil
}
func (b *txBuilder) BuildSweepTx(wallet ports.WalletService, inputs []ports.SweepInput) (signedSweepTx string, err error) {
sweepPset, err := sweepTransaction(
wallet,
inputs,
b.net.AssetID,
)
if err != nil {
return "", err
}
sweepPsetBase64, err := sweepPset.ToBase64()
if err != nil {
return "", err
}
ctx := context.Background()
signedSweepPsetB64, err := wallet.SignPsetWithKey(ctx, sweepPsetBase64, nil)
if err != nil {
return "", err
}
signedPset, err := psetv2.NewPsetFromBase64(signedSweepPsetB64)
if err != nil {
return "", err
}
if err := psetv2.FinalizeAll(signedPset); err != nil {
return "", err
}
extractedTx, err := psetv2.Extract(signedPset)
if err != nil {
return "", err
}
return extractedTx.ToHex()
}
func (b *txBuilder) BuildForfeitTxs(
aspPubkey *secp256k1.PublicKey, poolTx string, payments []domain.Payment,
) (connectors []string, forfeitTxs []string, err error) {
connectorTxs, err := b.createConnectors(poolTx, payments, aspPubkey)
if err != nil {
return nil, nil, err
}
forfeitTxs, err = b.createForfeitTxs(aspPubkey, payments, connectorTxs)
if err != nil {
return nil, nil, err
}
for _, tx := range connectorTxs {
buf, _ := tx.ToBase64()
connectors = append(connectors, buf)
}
return connectors, forfeitTxs, nil
}
func (b *txBuilder) BuildPoolTx(
aspPubkey *secp256k1.PublicKey, payments []domain.Payment, minRelayFee uint64,
) (poolTx string, congestionTree tree.CongestionTree, err error) {
// The creation of the tree and the pool tx are tightly coupled:
// - building the tree requires knowing the shared outpoint (txid:vout)
// - building the pool tx requires knowing the shared output script and amount
// The idea here is to first create all the data for the outputs of the txs
// of the congestion tree to calculate the shared output script and amount.
// With these data the pool tx can be created, and once the shared utxo
// outpoint is obtained, the congestion tree can be finally created.
// The factory function `treeFactoryFn` returned below holds all outputs data
// generated in the process and takes the shared utxo outpoint as argument.
// This is safe as the memory allocated for `craftCongestionTree` is freed
// only after `BuildPoolTx` returns.
treeFactoryFn, sharedOutputScript, sharedOutputAmount, err := craftCongestionTree(
b.net.AssetID, aspPubkey, payments, minRelayFee, b.roundLifetime,
)
if err != nil {
return
}
ptx, err := b.createPoolTx(
sharedOutputAmount, sharedOutputScript, payments, aspPubkey,
)
if err != nil {
return
}
unsignedTx, err := ptx.UnsignedTx()
if err != nil {
return
}
tree, err := treeFactoryFn(psetv2.InputArgs{
Txid: unsignedTx.TxHash().String(),
TxIndex: 0,
})
if err != nil {
return
}
poolTx, err = ptx.ToBase64()
if err != nil {
return
}
congestionTree = tree
return
}
func (b *txBuilder) GetLeafSweepClosure(
node tree.Node, userPubKey *secp256k1.PublicKey,
) (*psetv2.TapLeafScript, int64, error) {
if !node.Leaf {
return nil, 0, fmt.Errorf("node is not a leaf")
}
pset, err := psetv2.NewPsetFromBase64(node.Tx)
if err != nil {
return nil, 0, err
}
input := pset.Inputs[0]
sweepLeaf, lifetime, err := extractSweepLeaf(input)
if err != nil {
return nil, 0, err
}
// craft the vtxo taproot tree
vtxoScript, err := tree.VtxoScript(userPubKey)
if err != nil {
return nil, 0, err
}
vtxoTaprootTree := taproot.AssembleTaprootScriptTree(
*vtxoScript,
sweepLeaf.TapElementsLeaf,
)
proofIndex := vtxoTaprootTree.LeafProofIndex[sweepLeaf.TapHash()]
proof := vtxoTaprootTree.LeafMerkleProofs[proofIndex]
return &psetv2.TapLeafScript{
TapElementsLeaf: proof.TapElementsLeaf,
ControlBlock: proof.ToControlBlock(sweepLeaf.ControlBlock.InternalKey),
}, lifetime, nil
}
func (b *txBuilder) getLeafScriptAndTree(
userPubkey, aspPubkey *secp256k1.PublicKey,
) ([]byte, *taproot.IndexedElementsTapScriptTree, error) {
redeemClosure, err := tree.VtxoScript(userPubkey)
if err != nil {
return nil, nil, err
}
sweepClosure, err := tree.SweepScript(aspPubkey, uint(b.roundLifetime))
if err != nil {
return nil, nil, err
}
taprootTree := taproot.AssembleTaprootScriptTree(
*redeemClosure, *sweepClosure,
)
root := taprootTree.RootNode.TapHash()
unspendableKey := tree.UnspendableKey()
taprootKey := taproot.ComputeTaprootOutputKey(unspendableKey, root[:])
outputScript, err := taprootOutputScript(taprootKey)
if err != nil {
return nil, nil, err
}
return outputScript, taprootTree, nil
}
func (b *txBuilder) createPoolTx(
sharedOutputAmount uint64, sharedOutputScript []byte,
payments []domain.Payment, aspPubKey *secp256k1.PublicKey,
) (*psetv2.Pset, error) {
aspScript, err := p2wpkhScript(aspPubKey, b.net)
if err != nil {
return nil, err
}
receivers := getOnchainReceivers(payments)
connectorsAmount := connectorAmount * countSpentVtxos(payments)
targetAmount := sharedOutputAmount + connectorsAmount
outputs := []psetv2.OutputArgs{
{
Asset: b.net.AssetID,
Amount: sharedOutputAmount,
Script: sharedOutputScript,
},
{
Asset: b.net.AssetID,
Amount: connectorsAmount,
Script: aspScript,
},
}
for _, receiver := range receivers {
targetAmount += receiver.Amount
receiverScript, err := address.ToOutputScript(receiver.OnchainAddress)
if err != nil {
return nil, err
}
outputs = append(outputs, psetv2.OutputArgs{
Asset: b.net.AssetID,
Amount: receiver.Amount,
Script: receiverScript,
})
}
ctx := context.Background()
utxos, change, err := b.wallet.SelectUtxos(ctx, b.net.AssetID, targetAmount)
if err != nil {
return nil, err
}
var dust uint64
if change > 0 {
if change < dustLimit {
dust = change
change = 0
} else {
outputs = append(outputs, psetv2.OutputArgs{
Asset: b.net.AssetID,
Amount: change,
Script: aspScript,
})
}
}
ptx, err := psetv2.New(nil, outputs, nil)
if err != nil {
return nil, err
}
updater, err := psetv2.NewUpdater(ptx)
if err != nil {
return nil, err
}
if err := addInputs(updater, utxos); err != nil {
return nil, err
}
b64, err := ptx.ToBase64()
if err != nil {
return nil, err
}
feeAmount, err := b.wallet.EstimateFees(ctx, b64)
if err != nil {
return nil, err
}
if dust > feeAmount {
feeAmount = dust
} else {
feeAmount += dust
}
if dust == 0 {
if feeAmount == change {
// fees = change, remove change output
ptx.Outputs = ptx.Outputs[:len(ptx.Outputs)-1]
} else if feeAmount < change {
// change covers the fees, reduce change amount
ptx.Outputs[len(ptx.Outputs)-1].Value = change - feeAmount
} else {
// change is not enough to cover fees, re-select utxos
if change > 0 {
// remove change output if present
ptx.Outputs = ptx.Outputs[:len(ptx.Outputs)-1]
}
newUtxos, change, err := b.wallet.SelectUtxos(ctx, b.net.AssetID, feeAmount-change)
if err != nil {
return nil, err
}
if change > 0 {
if err := updater.AddOutputs([]psetv2.OutputArgs{
{
Asset: b.net.AssetID,
Amount: change,
Script: aspScript,
},
}); err != nil {
return nil, err
}
}
if err := addInputs(updater, newUtxos); err != nil {
return nil, err
}
}
}
// add fee output
if err := updater.AddOutputs([]psetv2.OutputArgs{
{
Asset: b.net.AssetID,
Amount: feeAmount,
},
}); err != nil {
return nil, err
}
return ptx, nil
}
func (b *txBuilder) createConnectors(
poolTx string, payments []domain.Payment, aspPubkey *secp256k1.PublicKey,
) ([]*psetv2.Pset, error) {
txid, _ := getTxid(poolTx)
aspScript, err := p2wpkhScript(aspPubkey, b.net)
if err != nil {
return nil, err
}
connectorOutput := psetv2.OutputArgs{
Asset: b.net.AssetID,
Script: aspScript,
Amount: connectorAmount,
}
numberOfConnectors := countSpentVtxos(payments)
previousInput := psetv2.InputArgs{
Txid: txid,
TxIndex: 1,
}
if numberOfConnectors == 1 {
outputs := []psetv2.OutputArgs{connectorOutput}
connectorTx, err := craftConnectorTx(previousInput, outputs)
if err != nil {
return nil, err
}
return []*psetv2.Pset{connectorTx}, nil
}
totalConnectorAmount := connectorAmount * numberOfConnectors
connectors := make([]*psetv2.Pset, 0, numberOfConnectors-1)
for i := uint64(0); i < numberOfConnectors-1; i++ {
outputs := []psetv2.OutputArgs{connectorOutput}
totalConnectorAmount -= connectorAmount
if totalConnectorAmount > 0 {
outputs = append(outputs, psetv2.OutputArgs{
Asset: b.net.AssetID,
Script: aspScript,
Amount: totalConnectorAmount,
})
}
connectorTx, err := craftConnectorTx(previousInput, outputs)
if err != nil {
return nil, err
}
txid, _ := getPsetId(connectorTx)
previousInput = psetv2.InputArgs{
Txid: txid,
TxIndex: 1,
}
connectors = append(connectors, connectorTx)
}
return connectors, nil
}
func (b *txBuilder) createForfeitTxs(
aspPubkey *secp256k1.PublicKey, payments []domain.Payment, connectors []*psetv2.Pset,
) ([]string, error) {
aspScript, err := p2wpkhScript(aspPubkey, b.net)
if err != nil {
return nil, err
}
forfeitTxs := make([]string, 0)
for _, payment := range payments {
for _, vtxo := range payment.Inputs {
pubkeyBytes, err := hex.DecodeString(vtxo.Pubkey)
if err != nil {
return nil, fmt.Errorf("failed to decode pubkey: %s", err)
}
vtxoPubkey, err := secp256k1.ParsePubKey(pubkeyBytes)
if err != nil {
return nil, err
}
vtxoScript, vtxoTaprootTree, err := b.getLeafScriptAndTree(vtxoPubkey, aspPubkey)
if err != nil {
return nil, err
}
for _, connector := range connectors {
txs, err := craftForfeitTxs(
connector, vtxo, vtxoTaprootTree, vtxoScript, aspScript,
)
if err != nil {
return nil, err
}
forfeitTxs = append(forfeitTxs, txs...)
}
}
}
return forfeitTxs, nil
}
// given a congestion tree input, searches and returns the sweep leaf and its lifetime in seconds
func extractSweepLeaf(input psetv2.Input) (sweepLeaf *psetv2.TapLeafScript, lifetime int64, err error) {
for _, leaf := range input.TapLeafScript {
isSweep, _, seconds, err := tree.DecodeSweepScript(leaf.Script)
if err != nil {
return nil, 0, err
}
if isSweep {
lifetime = int64(seconds)
sweepLeaf = &leaf
break
}
}
if sweepLeaf == nil {
return nil, 0, fmt.Errorf("sweep leaf not found")
}
return sweepLeaf, lifetime, nil
}

View File

@@ -0,0 +1,229 @@
package txbuilder_test
import (
"crypto/rand"
"encoding/hex"
"encoding/json"
"os"
"testing"
"github.com/ark-network/ark/common"
"github.com/ark-network/ark/common/tree"
"github.com/ark-network/ark/internal/core/domain"
"github.com/ark-network/ark/internal/core/ports"
txbuilder "github.com/ark-network/ark/internal/infrastructure/tx-builder/covenant"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/decred/dcrd/dcrec/secp256k1/v4"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/vulpemventures/go-elements/network"
"github.com/vulpemventures/go-elements/psetv2"
)
const (
testingKey = "apub1qgvdtj5ttpuhkldavhq8thtm5auyk0ec4dcmrfdgu0u5hgp9we22v3hrs4x"
minRelayFee = uint64(30)
roundLifetime = int64(1209344)
)
var (
wallet *mockedWallet
pubkey *secp256k1.PublicKey
)
func TestMain(m *testing.M) {
wallet = &mockedWallet{}
wallet.On("EstimateFees", mock.Anything, mock.Anything).
Return(uint64(100), nil)
wallet.On("SelectUtxos", mock.Anything, mock.Anything, mock.Anything).
Return(randomInput, uint64(0), nil)
_, pubkey, _ = common.DecodePubKey(testingKey)
os.Exit(m.Run())
}
func TestBuildPoolTx(t *testing.T) {
builder := txbuilder.NewTxBuilder(wallet, network.Liquid, roundLifetime)
fixtures, err := parsePoolTxFixtures()
require.NoError(t, err)
require.NotEmpty(t, fixtures)
if len(fixtures.Valid) > 0 {
t.Run("valid", func(t *testing.T) {
for _, f := range fixtures.Valid {
poolTx, congestionTree, err := builder.BuildPoolTx(pubkey, f.Payments, minRelayFee)
require.NoError(t, err)
require.NotEmpty(t, poolTx)
require.NotEmpty(t, congestionTree)
require.Equal(t, f.ExpectedNumOfNodes, congestionTree.NumberOfNodes())
require.Len(t, congestionTree.Leaves(), f.ExpectedNumOfLeaves)
err = tree.ValidateCongestionTree(congestionTree, poolTx, pubkey, roundLifetime)
require.NoError(t, err)
}
})
}
if len(fixtures.Invalid) > 0 {
t.Run("invalid", func(t *testing.T) {
for _, f := range fixtures.Invalid {
poolTx, congestionTree, err := builder.BuildPoolTx(pubkey, f.Payments, minRelayFee)
require.EqualError(t, err, f.ExpectedErr)
require.Empty(t, poolTx)
require.Empty(t, congestionTree)
}
})
}
}
func TestBuildForfeitTxs(t *testing.T) {
builder := txbuilder.NewTxBuilder(wallet, network.Liquid, 1209344)
fixtures, err := parseForfeitTxsFixtures()
require.NoError(t, err)
require.NotEmpty(t, fixtures)
if len(fixtures.Valid) > 0 {
t.Run("valid", func(t *testing.T) {
for _, f := range fixtures.Valid {
connectors, forfeitTxs, err := builder.BuildForfeitTxs(
pubkey, f.PoolTx, f.Payments,
)
require.NoError(t, err)
require.Len(t, connectors, f.ExpectedNumOfConnectors)
require.Len(t, forfeitTxs, f.ExpectedNumOfForfeitTxs)
expectedInputTxid := f.PoolTxid
// Verify the chain of connectors
for _, connector := range connectors {
tx, err := psetv2.NewPsetFromBase64(connector)
require.NoError(t, err)
require.NotNil(t, tx)
require.Len(t, tx.Inputs, 1)
require.Len(t, tx.Outputs, 2)
inputTxid := chainhash.Hash(tx.Inputs[0].PreviousTxid).String()
require.Equal(t, expectedInputTxid, inputTxid)
require.Equal(t, 1, int(tx.Inputs[0].PreviousTxIndex))
expectedInputTxid = getTxid(tx)
}
// decode and check forfeit txs
for _, forfeitTx := range forfeitTxs {
tx, err := psetv2.NewPsetFromBase64(forfeitTx)
require.NoError(t, err)
require.Len(t, tx.Inputs, 2)
require.Len(t, tx.Outputs, 2)
}
}
})
}
if len(fixtures.Invalid) > 0 {
t.Run("invalid", func(t *testing.T) {
for _, f := range fixtures.Invalid {
connectors, forfeitTxs, err := builder.BuildForfeitTxs(
pubkey, f.PoolTx, f.Payments,
)
require.EqualError(t, err, f.ExpectedErr)
require.Empty(t, connectors)
require.Empty(t, forfeitTxs)
}
})
}
}
func randomInput() []ports.TxInput {
txid := randomHex(32)
input := &mockedInput{}
input.On("GetAsset").Return("5ac9f65c0efcc4775e0baec4ec03abdde22473cd3cf33c0419ca290e0751b225")
input.On("GetValue").Return(uint64(1000))
input.On("GetScript").Return("a914ea9f486e82efb3dd83a69fd96e3f0113757da03c87")
input.On("GetTxid").Return(txid)
input.On("GetIndex").Return(uint32(0))
return []ports.TxInput{input}
}
func randomHex(len int) string {
buf := make([]byte, len)
// nolint
rand.Read(buf)
return hex.EncodeToString(buf)
}
type poolTxFixtures struct {
Valid []struct {
Payments []domain.Payment
ExpectedNumOfNodes int
ExpectedNumOfLeaves int
}
Invalid []struct {
Payments []domain.Payment
ExpectedErr string
}
}
func parsePoolTxFixtures() (*poolTxFixtures, error) {
file, err := os.ReadFile("testdata/fixtures.json")
if err != nil {
return nil, err
}
v := map[string]interface{}{}
if err := json.Unmarshal(file, &v); err != nil {
return nil, err
}
vv := v["buildPoolTx"].(map[string]interface{})
file, _ = json.Marshal(vv)
var fixtures poolTxFixtures
if err := json.Unmarshal(file, &fixtures); err != nil {
return nil, err
}
return &fixtures, nil
}
type forfeitTxsFixtures struct {
Valid []struct {
Payments []domain.Payment
ExpectedNumOfConnectors int
ExpectedNumOfForfeitTxs int
PoolTx string
PoolTxid string
}
Invalid []struct {
Payments []domain.Payment
ExpectedErr string
PoolTx string
}
}
func parseForfeitTxsFixtures() (*forfeitTxsFixtures, error) {
file, err := os.ReadFile("testdata/fixtures.json")
if err != nil {
return nil, err
}
v := map[string]interface{}{}
if err := json.Unmarshal(file, &v); err != nil {
return nil, err
}
vv := v["buildForfeitTxs"].(map[string]interface{})
file, _ = json.Marshal(vv)
var fixtures forfeitTxsFixtures
if err := json.Unmarshal(file, &fixtures); err != nil {
return nil, err
}
return &fixtures, nil
}
func getTxid(tx *psetv2.Pset) string {
utx, _ := tx.UnsignedTx()
return utx.TxHash().String()
}

View File

@@ -0,0 +1,48 @@
package txbuilder
import (
"github.com/vulpemventures/go-elements/psetv2"
"github.com/vulpemventures/go-elements/transaction"
)
func craftConnectorTx(
input psetv2.InputArgs, outputs []psetv2.OutputArgs,
) (*psetv2.Pset, error) {
ptx, _ := psetv2.New(nil, nil, nil)
updater, _ := psetv2.NewUpdater(ptx)
if err := updater.AddInputs(
[]psetv2.InputArgs{input},
); err != nil {
return nil, err
}
// TODO: add prevout.
if err := updater.AddOutputs(outputs); err != nil {
return nil, err
}
return ptx, nil
}
func getConnectorInputs(pset *psetv2.Pset) ([]psetv2.InputArgs, []*transaction.TxOutput) {
txID, _ := getPsetId(pset)
inputs := make([]psetv2.InputArgs, 0, len(pset.Outputs))
witnessUtxos := make([]*transaction.TxOutput, 0, len(pset.Outputs))
for i, output := range pset.Outputs {
utx, _ := pset.UnsignedTx()
if output.Value == connectorAmount && len(output.Script) > 0 {
inputs = append(inputs, psetv2.InputArgs{
Txid: txID,
TxIndex: uint32(i),
})
witnessUtxos = append(witnessUtxos, utx.Outputs[i])
}
}
return inputs, witnessUtxos
}

View File

@@ -0,0 +1,104 @@
package txbuilder
import (
"github.com/ark-network/ark/common/tree"
"github.com/ark-network/ark/internal/core/domain"
"github.com/btcsuite/btcd/txscript"
"github.com/vulpemventures/go-elements/elementsutil"
"github.com/vulpemventures/go-elements/psetv2"
"github.com/vulpemventures/go-elements/taproot"
"github.com/vulpemventures/go-elements/transaction"
)
func craftForfeitTxs(
connectorTx *psetv2.Pset,
vtxo domain.Vtxo,
vtxoTaprootTree *taproot.IndexedElementsTapScriptTree,
vtxoScript, aspScript []byte,
) (forfeitTxs []string, err error) {
connectors, prevouts := getConnectorInputs(connectorTx)
for i, connectorInput := range connectors {
connectorPrevout := prevouts[i]
asset := elementsutil.AssetHashFromBytes(connectorPrevout.Asset)
pset, err := psetv2.New(nil, nil, nil)
if err != nil {
return nil, err
}
updater, err := psetv2.NewUpdater(pset)
if err != nil {
return nil, err
}
vtxoInput := psetv2.InputArgs{
Txid: vtxo.Txid,
TxIndex: vtxo.VOut,
}
vtxoAmount, _ := elementsutil.ValueToBytes(vtxo.Amount)
vtxoPrevout := &transaction.TxOutput{
Asset: connectorPrevout.Asset,
Value: vtxoAmount,
Script: vtxoScript,
}
if err := updater.AddInputs([]psetv2.InputArgs{connectorInput, vtxoInput}); err != nil {
return nil, err
}
if err = updater.AddInWitnessUtxo(0, connectorPrevout); err != nil {
return nil, err
}
if err := updater.AddInSighashType(0, txscript.SigHashAll); err != nil {
return nil, err
}
if err = updater.AddInWitnessUtxo(1, vtxoPrevout); err != nil {
return nil, err
}
if err := updater.AddInSighashType(1, txscript.SigHashDefault); err != nil {
return nil, err
}
unspendableKey := tree.UnspendableKey()
for _, proof := range vtxoTaprootTree.LeafMerkleProofs {
tapScript := psetv2.NewTapLeafScript(proof, unspendableKey)
if err := updater.AddInTapLeafScript(1, tapScript); err != nil {
return nil, err
}
}
connectorAmount, err := elementsutil.ValueFromBytes(connectorPrevout.Value)
if err != nil {
return nil, err
}
err = updater.AddOutputs([]psetv2.OutputArgs{
{
Asset: asset,
Amount: vtxo.Amount + connectorAmount - 30,
Script: aspScript,
},
{
Asset: asset,
Amount: 30,
},
})
if err != nil {
return nil, err
}
tx, err := pset.ToBase64()
if err != nil {
return nil, err
}
forfeitTxs = append(forfeitTxs, tx)
}
return forfeitTxs, nil
}

View File

@@ -0,0 +1,203 @@
package txbuilder_test
import (
"context"
"github.com/ark-network/ark/internal/core/domain"
"github.com/ark-network/ark/internal/core/ports"
"github.com/decred/dcrd/dcrec/secp256k1/v4"
"github.com/stretchr/testify/mock"
)
type mockedWallet struct {
mock.Mock
}
// BroadcastTransaction implements ports.WalletService.
func (m *mockedWallet) BroadcastTransaction(ctx context.Context, txHex string) (string, error) {
args := m.Called(ctx, txHex)
var res string
if a := args.Get(0); a != nil {
res = a.(string)
}
return res, args.Error(1)
}
// Close implements ports.WalletService.
func (m *mockedWallet) Close() {
m.Called()
}
// DeriveAddresses implements ports.WalletService.
func (m *mockedWallet) DeriveAddresses(ctx context.Context, num int) ([]string, error) {
args := m.Called(ctx, num)
var res []string
if a := args.Get(0); a != nil {
res = a.([]string)
}
return res, args.Error(1)
}
// GetPubkey implements ports.WalletService.
func (m *mockedWallet) GetPubkey(ctx context.Context) (*secp256k1.PublicKey, error) {
args := m.Called(ctx)
var res *secp256k1.PublicKey
if a := args.Get(0); a != nil {
res = a.(*secp256k1.PublicKey)
}
return res, args.Error(1)
}
// SignPset implements ports.WalletService.
func (m *mockedWallet) SignPset(ctx context.Context, pset string, extractRawTx bool) (string, error) {
args := m.Called(ctx, pset, extractRawTx)
var res string
if a := args.Get(0); a != nil {
res = a.(string)
}
return res, args.Error(1)
}
// Status implements ports.WalletService.
func (m *mockedWallet) Status(ctx context.Context) (ports.WalletStatus, error) {
args := m.Called(ctx)
var res ports.WalletStatus
if a := args.Get(0); a != nil {
res = a.(ports.WalletStatus)
}
return res, args.Error(1)
}
func (m *mockedWallet) SelectUtxos(ctx context.Context, asset string, amount uint64) ([]ports.TxInput, uint64, error) {
args := m.Called(ctx, asset, amount)
var res0 func() []ports.TxInput
if a := args.Get(0); a != nil {
res0 = a.(func() []ports.TxInput)
}
var res1 uint64
if a := args.Get(1); a != nil {
res1 = a.(uint64)
}
return res0(), res1, args.Error(2)
}
func (m *mockedWallet) EstimateFees(ctx context.Context, pset string) (uint64, error) {
args := m.Called(ctx, pset)
var res uint64
if a := args.Get(0); a != nil {
res = a.(uint64)
}
return res, args.Error(1)
}
func (m *mockedWallet) IsTransactionPublished(ctx context.Context, txid string) (bool, int64, error) {
args := m.Called(ctx, txid)
var res bool
if a := args.Get(0); a != nil {
res = a.(bool)
}
var blocktime int64
if b := args.Get(1); b != nil {
blocktime = b.(int64)
}
return res, blocktime, args.Error(2)
}
func (m *mockedWallet) SignPsetWithKey(ctx context.Context, pset string, inputIndexes []int) (string, error) {
args := m.Called(ctx, pset, inputIndexes)
var res string
if a := args.Get(0); a != nil {
res = a.(string)
}
return res, args.Error(1)
}
func (m *mockedWallet) WatchScripts(
ctx context.Context, scripts []string,
) error {
args := m.Called(ctx, scripts)
return args.Error(0)
}
func (m *mockedWallet) UnwatchScripts(
ctx context.Context, scripts []string,
) error {
args := m.Called(ctx, scripts)
return args.Error(0)
}
func (m *mockedWallet) GetNotificationChannel(ctx context.Context) chan []domain.VtxoKey {
args := m.Called(ctx)
var res chan []domain.VtxoKey
if a := args.Get(0); a != nil {
res = a.(chan []domain.VtxoKey)
}
return res
}
type mockedInput struct {
mock.Mock
}
func (m *mockedInput) GetTxid() string {
args := m.Called()
var res string
if a := args.Get(0); a != nil {
res = a.(string)
}
return res
}
func (m *mockedInput) GetIndex() uint32 {
args := m.Called()
var res uint32
if a := args.Get(0); a != nil {
res = a.(uint32)
}
return res
}
func (m *mockedInput) GetScript() string {
args := m.Called()
var res string
if a := args.Get(0); a != nil {
res = a.(string)
}
return res
}
func (m *mockedInput) GetAsset() string {
args := m.Called()
var res string
if a := args.Get(0); a != nil {
res = a.(string)
}
return res
}
func (m *mockedInput) GetValue() uint64 {
args := m.Called()
var res uint64
if a := args.Get(0); a != nil {
res = a.(uint64)
}
return res
}

View File

@@ -0,0 +1,135 @@
package txbuilder
import (
"context"
"fmt"
"github.com/ark-network/ark/common"
"github.com/ark-network/ark/common/tree"
"github.com/ark-network/ark/internal/core/ports"
"github.com/vulpemventures/go-elements/address"
"github.com/vulpemventures/go-elements/elementsutil"
"github.com/vulpemventures/go-elements/psetv2"
"github.com/vulpemventures/go-elements/taproot"
"github.com/vulpemventures/go-elements/transaction"
)
func sweepTransaction(
wallet ports.WalletService,
sweepInputs []ports.SweepInput,
lbtc string,
) (*psetv2.Pset, error) {
sweepPset, err := psetv2.New(nil, nil, nil)
if err != nil {
return nil, err
}
updater, err := psetv2.NewUpdater(sweepPset)
if err != nil {
return nil, err
}
amount := uint64(0)
for i, input := range sweepInputs {
leaf := input.SweepLeaf
isSweep, _, lifetime, err := tree.DecodeSweepScript(leaf.Script)
if err != nil {
return nil, err
}
if isSweep {
amount += input.Amount
if err := updater.AddInputs([]psetv2.InputArgs{input.InputArgs}); err != nil {
return nil, err
}
if err := updater.AddInTapLeafScript(i, leaf); err != nil {
return nil, err
}
assetHash, err := elementsutil.AssetHashToBytes(lbtc)
if err != nil {
return nil, err
}
value, err := elementsutil.ValueToBytes(input.Amount)
if err != nil {
return nil, err
}
root := leaf.ControlBlock.RootHash(leaf.Script)
taprootKey := taproot.ComputeTaprootOutputKey(leaf.ControlBlock.InternalKey, root)
script, err := taprootOutputScript(taprootKey)
if err != nil {
return nil, err
}
witnessUtxo := transaction.NewTxOutput(assetHash, value, script)
if err := updater.AddInWitnessUtxo(i, witnessUtxo); err != nil {
return nil, err
}
sequence, err := common.BIP68EncodeAsNumber(lifetime)
if err != nil {
return nil, err
}
updater.Pset.Inputs[i].Sequence = sequence
continue
}
return nil, fmt.Errorf("invalid sweep script")
}
ctx := context.Background()
sweepAddress, err := wallet.DeriveAddresses(ctx, 1)
if err != nil {
return nil, err
}
script, err := address.ToOutputScript(sweepAddress[0])
if err != nil {
return nil, err
}
if err := updater.AddOutputs([]psetv2.OutputArgs{
{
Asset: lbtc,
Amount: amount,
Script: script,
},
}); err != nil {
return nil, err
}
b64, err := sweepPset.ToBase64()
if err != nil {
return nil, err
}
fees, err := wallet.EstimateFees(ctx, b64)
if err != nil {
return nil, err
}
if amount < fees {
return nil, fmt.Errorf("insufficient funds (%d) to cover fees (%d) for sweep transaction", amount, fees)
}
updater.Pset.Outputs[0].Value = amount - fees
if err := updater.AddOutputs([]psetv2.OutputArgs{
{
Asset: lbtc,
Amount: fees,
},
}); err != nil {
return nil, err
}
return sweepPset, nil
}

View File

@@ -0,0 +1,229 @@
{
"buildPoolTx": {
"valid": [
{
"payments": [
{
"id": "0",
"inputs": [
{
"txid": "fd68e3c5796cc7db0a8036d486d5f625b6b2f2c014810ac020e1ac23e82c59d6",
"vout": 0,
"pubkey": "020000000000000000000000000000000000000000000000000000000000000002",
"amount": 1100
}
],
"receivers": [
{
"pubkey": "020000000000000000000000000000000000000000000000000000000000000002",
"amount": 1100
}
]
}
],
"expectedNumOfNodes": 1,
"expectedNumOfLeaves": 1
},
{
"payments": [
{
"id": "0",
"inputs": [
{
"txid": "fd68e3c5796cc7db0a8036d486d5f625b6b2f2c014810ac020e1ac23e82c59d6",
"vout": 0,
"pubkey": "020000000000000000000000000000000000000000000000000000000000000002",
"amount": 1100
}
],
"receivers": [
{
"pubkey": "020000000000000000000000000000000000000000000000000000000000000002",
"amount": 600
},
{
"pubkey": "020000000000000000000000000000000000000000000000000000000000000002",
"amount": 500
}
]
}
],
"expectedNumOfNodes": 3,
"expectedNumOfLeaves": 2
},
{
"payments": [
{
"id": "0",
"inputs": [
{
"txid": "fd68e3c5796cc7db0a8036d486d5f625b6b2f2c014810ac020e1ac23e82c59d6",
"vout": 0,
"pubkey": "020000000000000000000000000000000000000000000000000000000000000002",
"amount": 1100
}
],
"receivers": [
{
"pubkey": "020000000000000000000000000000000000000000000000000000000000000002",
"amount": 600
},
{
"pubkey": "020000000000000000000000000000000000000000000000000000000000000002",
"amount": 500
}
]
},
{
"id": "0",
"inputs": [
{
"txid": "fd68e3c5796cc7db0a8036d486d5f625b6b2f2c014810ac020e1ac23e82c59d6",
"vout": 0,
"pubkey": "020000000000000000000000000000000000000000000000000000000000000002",
"amount": 1100
}
],
"receivers": [
{
"pubkey": "020000000000000000000000000000000000000000000000000000000000000002",
"amount": 600
},
{
"pubkey": "020000000000000000000000000000000000000000000000000000000000000002",
"amount": 500
}
]
},
{
"id": "0",
"inputs": [
{
"txid": "fd68e3c5796cc7db0a8036d486d5f625b6b2f2c014810ac020e1ac23e82c59d6",
"vout": 0,
"pubkey": "020000000000000000000000000000000000000000000000000000000000000002",
"amount": 1100
}
],
"receivers": [
{
"pubkey": "020000000000000000000000000000000000000000000000000000000000000002",
"amount": 600
},
{
"pubkey": "020000000000000000000000000000000000000000000000000000000000000002",
"amount": 500
}
]
}
],
"expectedNumOfNodes": 11,
"expectedNumOfLeaves": 6
},
{
"payments": [
{
"id": "a242cdd8-f3d5-46c0-ae98-94135a2bee3f",
"inputs": [
{
"txid": "755c820771284d85ea4bbcc246565b4eddadc44237a7e57a0f9cb78a840d1d41",
"vout": 0,
"pubkey": "02c87e5c1758df5ad42a918ec507b6e8dfcdcebf22f64f58eb4ad5804257d658a5",
"amount": 1000
},
{
"txid": "66a0df86fcdeb84b8877adfe0b2c556dba30305d72ddbd4c49355f6930355357",
"vout": 0,
"pubkey": "02c87e5c1758df5ad42a918ec507b6e8dfcdcebf22f64f58eb4ad5804257d658a5",
"amount": 1000
},
{
"txid": "9913159bc7aa493ca53cbb9cbc88f97ba01137c814009dc7ef520c3fafc67909",
"vout": 1,
"pubkey": "02c87e5c1758df5ad42a918ec507b6e8dfcdcebf22f64f58eb4ad5804257d658a5",
"amount": 500
},
{
"txid": "5e10e77a7cdedc153be5193a4b6055a7802706ded4f2a9efefe86ed2f9a6ae60",
"vout": 0,
"pubkey": "02c87e5c1758df5ad42a918ec507b6e8dfcdcebf22f64f58eb4ad5804257d658a5",
"amount": 1000
},
{
"txid": "5e10e77a7cdedc153be5193a4b6055a7802706ded4f2a9efefe86ed2f9a6ae60",
"vout": 1,
"pubkey": "02c87e5c1758df5ad42a918ec507b6e8dfcdcebf22f64f58eb4ad5804257d658a5",
"amount": 1000
}
],
"receivers": [
{
"pubkey": "02c87e5c1758df5ad42a918ec507b6e8dfcdcebf22f64f58eb4ad5804257d658a5",
"amount": 1000
},
{
"pubkey": "02c87e5c1758df5ad42a918ec507b6e8dfcdcebf22f64f58eb4ad5804257d658a5",
"amount": 1000
},
{
"pubkey": "02c87e5c1758df5ad42a918ec507b6e8dfcdcebf22f64f58eb4ad5804257d658a5",
"amount": 1000
},
{
"pubkey": "02c87e5c1758df5ad42a918ec507b6e8dfcdcebf22f64f58eb4ad5804257d658a5",
"amount": 1000
},
{
"pubkey": "02c87e5c1758df5ad42a918ec507b6e8dfcdcebf22f64f58eb4ad5804257d658a5",
"amount": 500
}
]
}
],
"expectedNumOfNodes": 9,
"expectedNumOfLeaves": 5
}
],
"invalid": []
},
"buildForfeitTxs": {
"valid": [
{
"payments": [
{
"id": "0",
"inputs": [
{
"txid": "fd68e3c5796cc7db0a8036d486d5f625b6b2f2c014810ac020e1ac23e82c59d6",
"vout": 0,
"pubkey": "020000000000000000000000000000000000000000000000000000000000000002",
"amount": 600
},
{
"txid": "fd68e3c5796cc7db0a8036d486d5f625b6b2f2c014810ac020e1ac23e82c59d6",
"vout": 1,
"pubkey": "020000000000000000000000000000000000000000000000000000000000000002",
"amount": 500
}
],
"receivers": [
{
"pubkey": "020000000000000000000000000000000000000000000000000000000000000002",
"amount": 600
},
{
"pubkey": "020000000000000000000000000000000000000000000000000000000000000002",
"amount": 500
}
]
}
],
"poolTx": "cHNldP8BAgQCAAAAAQQBAQEFAQMBBgEDAfsEAgAAAAABDiDk7dXxh4KQzgLO8i1ABtaLCe4aPL12GVhN1E9zM1ePLwEPBAAAAAABEAT/////AAEDCOgDAAAAAAAAAQQWABSNnpy01UJqd99eTg2M1IpdKId11gf8BHBzZXQCICWyUQcOKcoZBDzzPM1zJOLdqwPsxK4LXnfE/A5c9slaB/wEcHNldAgEAAAAAAABAwh4BQAAAAAAAAEEFgAUjZ6ctNVCanffXk4NjNSKXSiHddYH/ARwc2V0AiAlslEHDinKGQQ88zzNcyTi3asD7MSuC153xPwOXPbJWgf8BHBzZXQIBAAAAAAAAQMI9AEAAAAAAAABBAAH/ARwc2V0AiAlslEHDinKGQQ88zzNcyTi3asD7MSuC153xPwOXPbJWgf8BHBzZXQIBAAAAAAA",
"poolTxid": "7981fce656f266472cc742444527cb32a8bed8c90fed6d47adbfc4c8780d4d9a",
"expectedNumOfForfeitTxs": 4,
"expectedNumOfConnectors": 1
}
],
"invalid": []
}
}

View File

@@ -0,0 +1,448 @@
package txbuilder
import (
"encoding/hex"
"fmt"
"github.com/ark-network/ark/common/tree"
"github.com/ark-network/ark/internal/core/domain"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/decred/dcrd/dcrec/secp256k1/v4"
"github.com/vulpemventures/go-elements/psetv2"
"github.com/vulpemventures/go-elements/taproot"
)
type treeFactory func(outpoint psetv2.InputArgs) (tree.CongestionTree, error)
type node struct {
sweepKey *secp256k1.PublicKey
receivers []domain.Receiver
left *node
right *node
asset string
feeSats uint64
roundLifetime int64
_inputTaprootKey *secp256k1.PublicKey
_inputTaprootTree *taproot.IndexedElementsTapScriptTree
}
func (n *node) isLeaf() bool {
return len(n.receivers) == 1
}
func (n *node) getAmount() uint64 {
var amount uint64
for _, r := range n.receivers {
amount += r.Amount
}
if n.isLeaf() {
return amount
}
return amount + n.feeSats*uint64(n.countChildren())
}
func (n *node) countChildren() int {
result := 0
if n.left != nil {
result++
result += n.left.countChildren()
}
if n.right != nil {
result++
result += n.right.countChildren()
}
return result
}
func (n *node) getChildren() []*node {
if n.isLeaf() {
return nil
}
children := make([]*node, 0, 2)
if n.left != nil {
children = append(children, n.left)
}
if n.right != nil {
children = append(children, n.right)
}
return children
}
func (n *node) getOutputs() ([]psetv2.OutputArgs, error) {
if n.isLeaf() {
taprootKey, _, err := n.getVtxoWitnessData()
if err != nil {
return nil, err
}
script, err := taprootOutputScript(taprootKey)
if err != nil {
return nil, err
}
output := &psetv2.OutputArgs{
Asset: n.asset,
Amount: uint64(n.getAmount()),
Script: script,
}
return []psetv2.OutputArgs{*output}, nil
}
outputs := make([]psetv2.OutputArgs, 0, 2)
children := n.getChildren()
for _, child := range children {
childWitnessProgram, _, err := child.getWitnessData()
if err != nil {
return nil, err
}
script, err := taprootOutputScript(childWitnessProgram)
if err != nil {
return nil, err
}
outputs = append(outputs, psetv2.OutputArgs{
Asset: n.asset,
Amount: child.getAmount() + child.feeSats,
Script: script,
})
}
return outputs, nil
}
func (n *node) getWitnessData() (
*secp256k1.PublicKey, *taproot.IndexedElementsTapScriptTree, error,
) {
if n._inputTaprootKey != nil && n._inputTaprootTree != nil {
return n._inputTaprootKey, n._inputTaprootTree, nil
}
sweepClosure, err := tree.SweepScript(n.sweepKey, uint(n.roundLifetime))
if err != nil {
return nil, nil, err
}
if n.isLeaf() {
taprootKey, _, err := n.getVtxoWitnessData()
if err != nil {
return nil, nil, err
}
branchTaprootScript := tree.BranchScript(
taprootKey, nil, n.getAmount(), 0,
)
branchTaprootTree := taproot.AssembleTaprootScriptTree(
branchTaprootScript, *sweepClosure,
)
root := branchTaprootTree.RootNode.TapHash()
inputTapkey := taproot.ComputeTaprootOutputKey(
tree.UnspendableKey(),
root[:],
)
n._inputTaprootKey = inputTapkey
n._inputTaprootTree = branchTaprootTree
return inputTapkey, branchTaprootTree, nil
}
leftKey, _, err := n.left.getWitnessData()
if err != nil {
return nil, nil, err
}
rightKey, _, err := n.right.getWitnessData()
if err != nil {
return nil, nil, err
}
leftAmount := n.left.getAmount() + n.feeSats
rightAmount := n.right.getAmount() + n.feeSats
branchTaprootLeaf := tree.BranchScript(
leftKey, rightKey, leftAmount, rightAmount,
)
branchTaprootTree := taproot.AssembleTaprootScriptTree(
branchTaprootLeaf, *sweepClosure,
)
root := branchTaprootTree.RootNode.TapHash()
taprootKey := taproot.ComputeTaprootOutputKey(
tree.UnspendableKey(),
root[:],
)
n._inputTaprootKey = taprootKey
n._inputTaprootTree = branchTaprootTree
return taprootKey, branchTaprootTree, nil
}
func (n *node) getVtxoWitnessData() (
*secp256k1.PublicKey, *taproot.IndexedElementsTapScriptTree, error,
) {
if !n.isLeaf() {
return nil, nil, fmt.Errorf("cannot call vtxoWitness on a non-leaf node")
}
sweepClosure, err := tree.SweepScript(n.sweepKey, uint(n.roundLifetime))
if err != nil {
return nil, nil, err
}
key, err := hex.DecodeString(n.receivers[0].Pubkey)
if err != nil {
return nil, nil, err
}
pubkey, err := secp256k1.ParsePubKey(key)
if err != nil {
return nil, nil, err
}
vtxoLeaf, err := tree.VtxoScript(pubkey)
if err != nil {
return nil, nil, err
}
// TODO: add forfeit path
leafTaprootTree := taproot.AssembleTaprootScriptTree(
*vtxoLeaf, *sweepClosure,
)
root := leafTaprootTree.RootNode.TapHash()
taprootKey := taproot.ComputeTaprootOutputKey(
tree.UnspendableKey(),
root[:],
)
return taprootKey, leafTaprootTree, nil
}
func (n *node) getTreeNode(
input psetv2.InputArgs, tapTree *taproot.IndexedElementsTapScriptTree,
) (tree.Node, error) {
pset, err := n.getTx(input, tapTree)
if err != nil {
return tree.Node{}, err
}
txid, err := getPsetId(pset)
if err != nil {
return tree.Node{}, err
}
tx, err := pset.ToBase64()
if err != nil {
return tree.Node{}, err
}
parentTxid := chainhash.Hash(pset.Inputs[0].PreviousTxid).String()
return tree.Node{
Txid: txid,
Tx: tx,
ParentTxid: parentTxid,
Leaf: n.isLeaf(),
}, nil
}
func (n *node) getTx(
input psetv2.InputArgs, inputTapTree *taproot.IndexedElementsTapScriptTree,
) (*psetv2.Pset, error) {
pset, err := psetv2.New(nil, nil, nil)
if err != nil {
return nil, err
}
updater, err := psetv2.NewUpdater(pset)
if err != nil {
return nil, err
}
if err := addTaprootInput(
updater, input, tree.UnspendableKey(), inputTapTree,
); err != nil {
return nil, err
}
feeOutput := psetv2.OutputArgs{
Amount: uint64(n.feeSats),
Asset: n.asset,
}
outputs, err := n.getOutputs()
if err != nil {
return nil, err
}
if err := updater.AddOutputs(append(outputs, feeOutput)); err != nil {
return nil, err
}
return pset, nil
}
func (n *node) createFinalCongestionTree() treeFactory {
return func(poolTxInput psetv2.InputArgs) (tree.CongestionTree, error) {
congestionTree := make(tree.CongestionTree, 0)
_, taprootTree, err := n.getWitnessData()
if err != nil {
return nil, err
}
ins := []psetv2.InputArgs{poolTxInput}
inTrees := []*taproot.IndexedElementsTapScriptTree{taprootTree}
nodes := []*node{n}
for len(nodes) > 0 {
nextNodes := make([]*node, 0)
nextInputsArgs := make([]psetv2.InputArgs, 0)
nextTaprootTrees := make([]*taproot.IndexedElementsTapScriptTree, 0)
treeLevel := make([]tree.Node, 0)
for i, node := range nodes {
treeNode, err := node.getTreeNode(ins[i], inTrees[i])
if err != nil {
return nil, err
}
treeLevel = append(treeLevel, treeNode)
children := node.getChildren()
for i, child := range children {
_, taprootTree, err := child.getWitnessData()
if err != nil {
return nil, err
}
nextNodes = append(nextNodes, child)
nextInputsArgs = append(nextInputsArgs, psetv2.InputArgs{
Txid: treeNode.Txid,
TxIndex: uint32(i),
})
nextTaprootTrees = append(nextTaprootTrees, taprootTree)
}
}
congestionTree = append(congestionTree, treeLevel)
nodes = append([]*node{}, nextNodes...)
ins = append([]psetv2.InputArgs{}, nextInputsArgs...)
inTrees = append(
[]*taproot.IndexedElementsTapScriptTree{}, nextTaprootTrees...,
)
}
return congestionTree, nil
}
}
func craftCongestionTree(
asset string, aspPublicKey *secp256k1.PublicKey,
payments []domain.Payment, feeSatsPerNode uint64, roundLifetime int64,
) (
buildCongestionTree treeFactory,
sharedOutputScript []byte, sharedOutputAmount uint64, err error,
) {
receivers := getOffchainReceivers(payments)
root, err := createPartialCongestionTree(
receivers, aspPublicKey, asset, feeSatsPerNode, roundLifetime,
)
if err != nil {
return
}
taprootKey, _, err := root.getWitnessData()
if err != nil {
return
}
sharedOutputScript, err = taprootOutputScript(taprootKey)
if err != nil {
return
}
sharedOutputAmount = root.getAmount() + root.feeSats
buildCongestionTree = root.createFinalCongestionTree()
return
}
func createPartialCongestionTree(
receivers []domain.Receiver,
aspPublicKey *secp256k1.PublicKey,
asset string,
feeSatsPerNode uint64,
roundLifetime int64,
) (root *node, err error) {
if len(receivers) == 0 {
return nil, fmt.Errorf("no receivers provided")
}
nodes := make([]*node, 0, len(receivers))
for _, r := range receivers {
leafNode := &node{
sweepKey: aspPublicKey,
receivers: []domain.Receiver{r},
asset: asset,
feeSats: feeSatsPerNode,
roundLifetime: roundLifetime,
}
nodes = append(nodes, leafNode)
}
for len(nodes) > 1 {
nodes, err = createUpperLevel(nodes)
if err != nil {
return
}
}
return nodes[0], nil
}
func createUpperLevel(nodes []*node) ([]*node, error) {
if len(nodes)%2 != 0 {
last := nodes[len(nodes)-1]
pairs, err := createUpperLevel(nodes[:len(nodes)-1])
if err != nil {
return nil, err
}
return append(pairs, last), nil
}
pairs := make([]*node, 0, len(nodes)/2)
for i := 0; i < len(nodes); i += 2 {
left := nodes[i]
right := nodes[i+1]
branchNode := &node{
sweepKey: left.sweepKey,
receivers: append(left.receivers, right.receivers...),
left: left,
right: right,
asset: left.asset,
feeSats: left.feeSats,
roundLifetime: left.roundLifetime,
}
pairs = append(pairs, branchNode)
}
return pairs, nil
}

View File

@@ -0,0 +1,167 @@
package txbuilder
import (
"encoding/hex"
"fmt"
"github.com/ark-network/ark/internal/core/domain"
"github.com/ark-network/ark/internal/core/ports"
"github.com/btcsuite/btcd/btcec/v2/schnorr"
"github.com/btcsuite/btcd/txscript"
"github.com/decred/dcrd/dcrec/secp256k1/v4"
"github.com/vulpemventures/go-elements/address"
"github.com/vulpemventures/go-elements/elementsutil"
"github.com/vulpemventures/go-elements/network"
"github.com/vulpemventures/go-elements/payment"
"github.com/vulpemventures/go-elements/psetv2"
"github.com/vulpemventures/go-elements/taproot"
"github.com/vulpemventures/go-elements/transaction"
)
func p2wpkhScript(publicKey *secp256k1.PublicKey, net *network.Network) ([]byte, error) {
payment := payment.FromPublicKey(publicKey, net, nil)
addr, err := payment.WitnessPubKeyHash()
if err != nil {
return nil, err
}
return address.ToOutputScript(addr)
}
func getTxid(txStr string) (string, error) {
pset, err := psetv2.NewPsetFromBase64(txStr)
if err != nil {
return "", err
}
return getPsetId(pset)
}
func getPsetId(pset *psetv2.Pset) (string, error) {
utx, err := pset.UnsignedTx()
if err != nil {
return "", err
}
return utx.TxHash().String(), nil
}
func getOnchainReceivers(
payments []domain.Payment,
) []domain.Receiver {
receivers := make([]domain.Receiver, 0)
for _, payment := range payments {
for _, receiver := range payment.Receivers {
if receiver.IsOnchain() {
receivers = append(receivers, receiver)
}
}
}
return receivers
}
func getOffchainReceivers(
payments []domain.Payment,
) []domain.Receiver {
receivers := make([]domain.Receiver, 0)
for _, payment := range payments {
for _, receiver := range payment.Receivers {
if !receiver.IsOnchain() {
receivers = append(receivers, receiver)
}
}
}
return receivers
}
func toWitnessUtxo(in ports.TxInput) (*transaction.TxOutput, error) {
valueBytes, err := elementsutil.ValueToBytes(in.GetValue())
if err != nil {
return nil, fmt.Errorf("failed to convert value to bytes: %s", err)
}
assetBytes, err := elementsutil.AssetHashToBytes(in.GetAsset())
if err != nil {
return nil, fmt.Errorf("failed to convert asset to bytes: %s", err)
}
scriptBytes, err := hex.DecodeString(in.GetScript())
if err != nil {
return nil, fmt.Errorf("failed to decode script: %s", err)
}
return transaction.NewTxOutput(assetBytes, valueBytes, scriptBytes), nil
}
func countSpentVtxos(payments []domain.Payment) uint64 {
var sum uint64
for _, payment := range payments {
sum += uint64(len(payment.Inputs))
}
return sum
}
func addInputs(
updater *psetv2.Updater,
inputs []ports.TxInput,
) error {
for _, in := range inputs {
inputArg := psetv2.InputArgs{
Txid: in.GetTxid(),
TxIndex: in.GetIndex(),
}
witnessUtxo, err := toWitnessUtxo(in)
if err != nil {
return err
}
if err := updater.AddInputs([]psetv2.InputArgs{inputArg}); err != nil {
return err
}
index := int(updater.Pset.Global.InputCount) - 1
if err := updater.AddInWitnessUtxo(index, witnessUtxo); err != nil {
return err
}
if err := updater.AddInSighashType(index, txscript.SigHashAll); err != nil {
return err
}
}
return nil
}
// wrapper of updater methods adding a taproot input to the pset with all the necessary data to spend it via any taproot script
func addTaprootInput(
updater *psetv2.Updater,
input psetv2.InputArgs,
internalTaprootKey *secp256k1.PublicKey,
taprootTree *taproot.IndexedElementsTapScriptTree,
) error {
if err := updater.AddInputs([]psetv2.InputArgs{input}); err != nil {
return err
}
if err := updater.AddInTapInternalKey(0, schnorr.SerializePubKey(internalTaprootKey)); err != nil {
return err
}
for _, proof := range taprootTree.LeafMerkleProofs {
controlBlock := proof.ToControlBlock(internalTaprootKey)
if err := updater.AddInTapLeafScript(0, psetv2.TapLeafScript{
TapElementsLeaf: taproot.NewBaseTapElementsLeaf(proof.Script),
ControlBlock: controlBlock,
}); err != nil {
return err
}
}
return nil
}
func taprootOutputScript(taprootKey *secp256k1.PublicKey) ([]byte, error) {
return txscript.NewScriptBuilder().AddOp(txscript.OP_1).AddData(schnorr.SerializePubKey(taprootKey)).Script()
}

View File

@@ -0,0 +1,285 @@
package txbuilder
import (
"context"
"github.com/ark-network/ark/common/tree"
"github.com/ark-network/ark/internal/core/domain"
"github.com/ark-network/ark/internal/core/ports"
"github.com/decred/dcrd/dcrec/secp256k1/v4"
"github.com/vulpemventures/go-elements/address"
"github.com/vulpemventures/go-elements/network"
"github.com/vulpemventures/go-elements/payment"
"github.com/vulpemventures/go-elements/psetv2"
"github.com/vulpemventures/go-elements/transaction"
)
const (
connectorAmount = 450
sevenDays = 7 * 24 * 60 * 60
)
type txBuilder struct {
wallet ports.WalletService
net network.Network
}
func NewTxBuilder(
wallet ports.WalletService, net network.Network,
) ports.TxBuilder {
return &txBuilder{wallet, net}
}
// BuildSweepTx implements ports.TxBuilder.
func (*txBuilder) BuildSweepTx(wallet ports.WalletService, inputs []ports.SweepInput) (signedSweepTx string, err error) {
panic("unimplemented")
}
// BuildForfeitTxs implements ports.TxBuilder.
func (b *txBuilder) BuildForfeitTxs(
aspPubkey *secp256k1.PublicKey, poolTx string, payments []domain.Payment,
) (connectors []string, forfeitTxs []string, err error) {
poolTxID, err := getTxid(poolTx)
if err != nil {
return nil, nil, err
}
aspScript, err := p2wpkhScript(aspPubkey, b.net)
if err != nil {
return nil, nil, err
}
numberOfConnectors := countSpentVtxos(payments)
connectors, err = createConnectors(
poolTxID,
1,
psetv2.OutputArgs{
Asset: b.net.AssetID,
Amount: connectorAmount,
Script: aspScript,
},
aspScript,
numberOfConnectors,
)
if err != nil {
return nil, nil, err
}
connectorsAsInputs, err := connectorsToInputArgs(connectors)
if err != nil {
return nil, nil, err
}
forfeitTxs = make([]string, 0)
for _, payment := range payments {
for _, vtxo := range payment.Inputs {
for _, connector := range connectorsAsInputs {
forfeitTx, err := createForfeitTx(
connector,
psetv2.InputArgs{
Txid: vtxo.Txid,
TxIndex: vtxo.VOut,
},
vtxo.Amount,
aspScript,
b.net,
)
if err != nil {
return nil, nil, err
}
forfeitTxs = append(forfeitTxs, forfeitTx)
}
}
}
return connectors, forfeitTxs, nil
}
// BuildPoolTx implements ports.TxBuilder.
func (b *txBuilder) BuildPoolTx(
aspPubkey *secp256k1.PublicKey, payments []domain.Payment, minRelayFee uint64,
) (poolTx string, congestionTree tree.CongestionTree, err error) {
aspScriptBytes, err := p2wpkhScript(aspPubkey, b.net)
if err != nil {
return "", nil, err
}
offchainReceivers, onchainReceivers := receiversFromPayments(payments)
sharedOutputAmount := sumReceivers(offchainReceivers)
numberOfConnectors := countSpentVtxos(payments)
connectorOutputAmount := connectorAmount * numberOfConnectors
ctx := context.Background()
outputs := []psetv2.OutputArgs{
{
Asset: b.net.AssetID,
Amount: sharedOutputAmount,
Script: aspScriptBytes,
},
{
Asset: b.net.AssetID,
Amount: connectorOutputAmount,
Script: aspScriptBytes,
},
}
amountToSelect := sharedOutputAmount + connectorOutputAmount
for _, receiver := range onchainReceivers {
amountToSelect += receiver.Amount
receiverScript, err := address.ToOutputScript(receiver.OnchainAddress)
if err != nil {
return "", nil, err
}
outputs = append(outputs, psetv2.OutputArgs{
Asset: b.net.AssetID,
Amount: receiver.Amount,
Script: receiverScript,
})
}
utxos, change, err := b.wallet.SelectUtxos(ctx, b.net.AssetID, amountToSelect)
if err != nil {
return
}
if change > 0 {
outputs = append(outputs, psetv2.OutputArgs{
Asset: b.net.AssetID,
Amount: change,
Script: aspScriptBytes,
})
}
ptx, err := psetv2.New(toInputArgs(utxos), outputs, nil)
if err != nil {
return
}
utx, err := ptx.UnsignedTx()
if err != nil {
return
}
congestionTree, err = buildCongestionTree(
newOutputScriptFactory(aspPubkey, b.net),
b.net,
utx.TxHash().String(),
offchainReceivers,
)
if err != nil {
return
}
poolTx, err = ptx.ToBase64()
if err != nil {
return
}
return poolTx, congestionTree, err
}
func (b *txBuilder) GetVtxoScript(userPubkey, _ *secp256k1.PublicKey) ([]byte, error) {
p2wpkh := payment.FromPublicKey(userPubkey, &b.net, nil)
addr, _ := p2wpkh.WitnessPubKeyHash()
return address.ToOutputScript(addr)
}
func (b *txBuilder) GetLeafSweepClosure(
node tree.Node, userPubKey *secp256k1.PublicKey,
) (*psetv2.TapLeafScript, int64, error) {
panic("unimplemented")
}
func connectorsToInputArgs(connectors []string) ([]psetv2.InputArgs, error) {
inputs := make([]psetv2.InputArgs, 0, len(connectors)+1)
for i, psetb64 := range connectors {
tx, err := psetv2.NewPsetFromBase64(psetb64)
if err != nil {
return nil, err
}
utx, err := tx.UnsignedTx()
if err != nil {
return nil, err
}
txid := utx.TxHash().String()
for j := range tx.Outputs {
inputs = append(inputs, psetv2.InputArgs{
Txid: txid,
TxIndex: uint32(j),
})
if i != len(connectors)-1 {
break
}
}
}
return inputs, nil
}
func getTxid(txStr string) (string, error) {
pset, err := psetv2.NewPsetFromBase64(txStr)
if err != nil {
tx, err := transaction.NewTxFromHex(txStr)
if err != nil {
return "", err
}
return tx.TxHash().String(), nil
}
utx, err := pset.UnsignedTx()
if err != nil {
return "", err
}
return utx.TxHash().String(), nil
}
func countSpentVtxos(payments []domain.Payment) uint64 {
var sum uint64
for _, payment := range payments {
sum += uint64(len(payment.Inputs))
}
return sum
}
func receiversFromPayments(
payments []domain.Payment,
) (offchainReceivers, onchainReceivers []domain.Receiver) {
for _, payment := range payments {
for _, receiver := range payment.Receivers {
if receiver.IsOnchain() {
onchainReceivers = append(onchainReceivers, receiver)
} else {
offchainReceivers = append(offchainReceivers, receiver)
}
}
}
return
}
func sumReceivers(receivers []domain.Receiver) uint64 {
var sum uint64
for _, r := range receivers {
sum += r.Amount
}
return sum
}
func toInputArgs(
ins []ports.TxInput,
) []psetv2.InputArgs {
inputs := make([]psetv2.InputArgs, 0, len(ins))
for _, in := range ins {
inputs = append(inputs, psetv2.InputArgs{
Txid: in.GetTxid(),
TxIndex: in.GetIndex(),
})
}
return inputs
}

View File

@@ -0,0 +1,407 @@
package txbuilder_test
import (
"context"
"crypto/rand"
"encoding/hex"
"testing"
"github.com/ark-network/ark/common"
"github.com/ark-network/ark/internal/core/domain"
"github.com/ark-network/ark/internal/core/ports"
txbuilder "github.com/ark-network/ark/internal/infrastructure/tx-builder/dummy"
"github.com/btcsuite/btcd/chaincfg/chainhash"
secp256k1 "github.com/decred/dcrd/dcrec/secp256k1/v4"
"github.com/stretchr/testify/require"
"github.com/vulpemventures/go-elements/network"
"github.com/vulpemventures/go-elements/psetv2"
)
const (
testingKey = "apub1qgvdtj5ttpuhkldavhq8thtm5auyk0ec4dcmrfdgu0u5hgp9we22v3hrs4x"
fakePoolTx = "cHNldP8BAgQCAAAAAQQBAQEFAQMBBgEDAfsEAgAAAAABDiDk7dXxh4KQzgLO8i1ABtaLCe4aPL12GVhN1E9zM1ePLwEPBAAAAAABEAT/////AAEDCOgDAAAAAAAAAQQWABSNnpy01UJqd99eTg2M1IpdKId11gf8BHBzZXQCICWyUQcOKcoZBDzzPM1zJOLdqwPsxK4LXnfE/A5c9slaB/wEcHNldAgEAAAAAAABAwh4BQAAAAAAAAEEFgAUjZ6ctNVCanffXk4NjNSKXSiHddYH/ARwc2V0AiAlslEHDinKGQQ88zzNcyTi3asD7MSuC153xPwOXPbJWgf8BHBzZXQIBAAAAAAAAQMI9AEAAAAAAAABBAAH/ARwc2V0AiAlslEHDinKGQQ88zzNcyTi3asD7MSuC153xPwOXPbJWgf8BHBzZXQIBAAAAAAA"
)
type input struct {
txid string
vout uint32
}
func (i *input) GetTxid() string {
return i.txid
}
func (i *input) GetIndex() uint32 {
return i.vout
}
func (i *input) GetScript() string {
return "a914ea9f486e82efb3dd83a69fd96e3f0113757da03c87"
}
func (i *input) GetAsset() string {
return "5ac9f65c0efcc4775e0baec4ec03abdde22473cd3cf33c0419ca290e0751b225"
}
func (i *input) GetValue() uint64 {
return 1000
}
type mockedWalletService struct{}
// BroadcastTransaction implements ports.WalletService.
func (*mockedWalletService) BroadcastTransaction(ctx context.Context, txHex string) (string, error) {
panic("unimplemented")
}
// Close implements ports.WalletService.
func (*mockedWalletService) Close() {
panic("unimplemented")
}
// DeriveAddresses implements ports.WalletService.
func (*mockedWalletService) DeriveAddresses(ctx context.Context, num int) ([]string, error) {
panic("unimplemented")
}
// GetPubkey implements ports.WalletService.
func (*mockedWalletService) GetPubkey(ctx context.Context) (*secp256k1.PublicKey, error) {
panic("unimplemented")
}
// SignPset implements ports.WalletService.
func (*mockedWalletService) SignPset(ctx context.Context, pset string, extractRawTx bool) (string, error) {
panic("unimplemented")
}
// Status implements ports.WalletService.
func (*mockedWalletService) Status(ctx context.Context) (ports.WalletStatus, error) {
panic("unimplemented")
}
func (*mockedWalletService) WatchScripts(ctx context.Context, scripts []string) error {
panic("unimplemented")
}
func (*mockedWalletService) UnwatchScripts(ctx context.Context, scripts []string) error {
panic("unimplemented")
}
func (*mockedWalletService) GetNotificationChannel(ctx context.Context) chan []domain.VtxoKey {
panic("unimplemented")
}
func (*mockedWalletService) SelectUtxos(ctx context.Context, asset string, amount uint64) ([]ports.TxInput, uint64, error) {
// random txid
bytes := make([]byte, 32)
if _, err := rand.Read(bytes); err != nil {
return nil, 0, err
}
fakeInput := input{
txid: hex.EncodeToString(bytes),
vout: 0,
}
return []ports.TxInput{&fakeInput}, 0, nil
}
func (*mockedWalletService) EstimateFees(ctx context.Context, pset string) (uint64, error) {
return 100, nil
}
func (*mockedWalletService) SignPsetWithKey(ctx context.Context, pset string, inputIndex []int) (string, error) {
panic("unimplemented")
}
func (*mockedWalletService) IsTransactionPublished(ctx context.Context, txid string) (bool, int64, error) {
panic("unimplemented")
}
func TestBuildCongestionTree(t *testing.T) {
builder := txbuilder.NewTxBuilder(&mockedWalletService{}, network.Liquid)
fixtures := []struct {
payments []domain.Payment
expectedNodesNum int // 2*len(receivers)-1
expectedLeavesNum int
}{
{
payments: []domain.Payment{
{
Id: "0",
Inputs: []domain.Vtxo{
{
VtxoKey: domain.VtxoKey{
Txid: "fd68e3c5796cc7db0a8036d486d5f625b6b2f2c014810ac020e1ac23e82c59d6",
VOut: 0,
},
Receiver: domain.Receiver{
Pubkey: "020000000000000000000000000000000000000000000000000000000000000002",
Amount: 600,
},
},
},
Receivers: []domain.Receiver{
{
Pubkey: "020000000000000000000000000000000000000000000000000000000000000002",
Amount: 600,
},
{
Pubkey: "020000000000000000000000000000000000000000000000000000000000000002",
Amount: 400,
},
},
},
},
expectedNodesNum: 3,
expectedLeavesNum: 2,
},
{
payments: []domain.Payment{
{
Id: "0",
Inputs: []domain.Vtxo{
{
VtxoKey: domain.VtxoKey{
Txid: "fd68e3c5796cc7db0a8036d486d5f625b6b2f2c014810ac020e1ac23e82c59d6",
VOut: 0,
},
Receiver: domain.Receiver{
Pubkey: "020000000000000000000000000000000000000000000000000000000000000002",
Amount: 600,
},
},
},
Receivers: []domain.Receiver{
{
Pubkey: "020000000000000000000000000000000000000000000000000000000000000002",
Amount: 600,
},
{
Pubkey: "020000000000000000000000000000000000000000000000000000000000000002",
Amount: 400,
},
},
},
{
Id: "0",
Inputs: []domain.Vtxo{
{
VtxoKey: domain.VtxoKey{
Txid: "fd68e3c5796cc7db0a8036d486d5f625b6b2f2c014810ac020e1ac23e82c59d6",
VOut: 0,
},
Receiver: domain.Receiver{
Pubkey: "020000000000000000000000000000000000000000000000000000000000000002",
Amount: 600,
},
},
},
Receivers: []domain.Receiver{
{
Pubkey: "020000000000000000000000000000000000000000000000000000000000000002",
Amount: 600,
},
{
Pubkey: "020000000000000000000000000000000000000000000000000000000000000002",
Amount: 400,
},
},
},
{
Id: "0",
Inputs: []domain.Vtxo{
{
VtxoKey: domain.VtxoKey{
Txid: "fd68e3c5796cc7db0a8036d486d5f625b6b2f2c014810ac020e1ac23e82c59d6",
VOut: 0,
},
Receiver: domain.Receiver{
Pubkey: "020000000000000000000000000000000000000000000000000000000000000002",
Amount: 600,
},
},
},
Receivers: []domain.Receiver{
{
Pubkey: "020000000000000000000000000000000000000000000000000000000000000002",
Amount: 600,
},
{
Pubkey: "020000000000000000000000000000000000000000000000000000000000000002",
Amount: 400,
},
},
},
},
expectedNodesNum: 11,
expectedLeavesNum: 6,
},
}
_, key, err := common.DecodePubKey(testingKey)
require.NoError(t, err)
require.NotNil(t, key)
for _, f := range fixtures {
poolTx, tree, err := builder.BuildPoolTx(key, f.payments, 30)
require.NoError(t, err)
require.Equal(t, f.expectedNodesNum, tree.NumberOfNodes())
require.Len(t, tree.Leaves(), f.expectedLeavesNum)
poolPset, err := psetv2.NewPsetFromBase64(poolTx)
require.NoError(t, err)
poolTxUnsigned, err := poolPset.UnsignedTx()
require.NoError(t, err)
poolTxID := poolTxUnsigned.TxHash().String()
// check the root
require.Len(t, tree[0], 1)
require.Equal(t, poolTxID, tree[0][0].ParentTxid)
// check the leaves
for _, leaf := range tree.Leaves() {
pset, err := psetv2.NewPsetFromBase64(leaf.Tx)
require.NoError(t, err)
require.Len(t, pset.Inputs, 1)
require.Len(t, pset.Outputs, 1)
inputTxID := chainhash.Hash(pset.Inputs[0].PreviousTxid).String()
require.Equal(t, leaf.ParentTxid, inputTxID)
}
// check the nodes
for _, level := range tree[:len(tree)-2] {
for _, node := range level {
pset, err := psetv2.NewPsetFromBase64(node.Tx)
require.NoError(t, err)
require.Len(t, pset.Inputs, 1)
require.Len(t, pset.Outputs, 2)
inputTxID := chainhash.Hash(pset.Inputs[0].PreviousTxid).String()
require.Equal(t, node.ParentTxid, inputTxID)
children := tree.Children(node.Txid)
require.Len(t, children, 2)
}
}
}
}
func TestBuildForfeitTxs(t *testing.T) {
builder := txbuilder.NewTxBuilder(&mockedWalletService{}, network.Liquid)
poolPset, err := psetv2.NewPsetFromBase64(fakePoolTx)
require.NoError(t, err)
poolTxUnsigned, err := poolPset.UnsignedTx()
require.NoError(t, err)
poolTxID := poolTxUnsigned.TxHash().String()
fixtures := []struct {
payments []domain.Payment
expectedNumOfForfeitTxs int
expectedNumOfConnectors int
}{
{
payments: []domain.Payment{
{
Id: "0",
Inputs: []domain.Vtxo{
{
VtxoKey: domain.VtxoKey{
Txid: "fd68e3c5796cc7db0a8036d486d5f625b6b2f2c014810ac020e1ac23e82c59d6",
VOut: 0,
},
Receiver: domain.Receiver{
Pubkey: "020000000000000000000000000000000000000000000000000000000000000002",
Amount: 600,
},
},
{
VtxoKey: domain.VtxoKey{
Txid: "fd68e3c5796cc7db0a8036d486d5f625b6b2f2c014810ac020e1ac23e82c59d6",
VOut: 1,
},
Receiver: domain.Receiver{
Pubkey: "020000000000000000000000000000000000000000000000000000000000000002",
Amount: 400,
},
},
},
Receivers: []domain.Receiver{
{
Pubkey: "020000000000000000000000000000000000000000000000000000000000000002",
Amount: 600,
},
{
Pubkey: "020000000000000000000000000000000000000000000000000000000000000002",
Amount: 400,
},
},
},
},
expectedNumOfForfeitTxs: 4,
expectedNumOfConnectors: 1,
},
}
_, key, err := common.DecodePubKey(testingKey)
require.NoError(t, err)
require.NotNil(t, key)
for _, f := range fixtures {
connectors, forfeitTxs, err := builder.BuildForfeitTxs(
key, fakePoolTx, f.payments,
)
require.NoError(t, err)
require.Len(t, connectors, f.expectedNumOfConnectors)
require.Len(t, forfeitTxs, f.expectedNumOfForfeitTxs)
// decode and check connectors
connectorsPsets := make([]*psetv2.Pset, 0, f.expectedNumOfConnectors)
for _, pset := range connectors {
p, err := psetv2.NewPsetFromBase64(pset)
require.NoError(t, err)
connectorsPsets = append(connectorsPsets, p)
}
for i, pset := range connectorsPsets {
require.Len(t, pset.Inputs, 1)
require.Len(t, pset.Outputs, 2)
expectedInputTxid := poolTxID
expectedInputVout := uint32(1)
if i > 0 {
tx, err := connectorsPsets[i-1].UnsignedTx()
require.NoError(t, err)
require.NotNil(t, tx)
expectedInputTxid = tx.TxHash().String()
}
inputTxid := chainhash.Hash(pset.Inputs[0].PreviousTxid).String()
require.Equal(t, expectedInputTxid, inputTxid)
require.Equal(t, expectedInputVout, pset.Inputs[0].PreviousTxIndex)
}
// decode and check forfeit txs
forfeitTxsPsets := make([]*psetv2.Pset, 0, f.expectedNumOfForfeitTxs)
for _, pset := range forfeitTxs {
p, err := psetv2.NewPsetFromBase64(pset)
require.NoError(t, err)
forfeitTxsPsets = append(forfeitTxsPsets, p)
}
// each forfeit tx should have 2 inputs and 2 outputs
for _, pset := range forfeitTxsPsets {
require.Len(t, pset.Inputs, 2)
require.Len(t, pset.Outputs, 1)
}
}
}

View File

@@ -0,0 +1,103 @@
package txbuilder
import (
"github.com/vulpemventures/go-elements/psetv2"
)
func createConnectors(
poolTxID string,
connectorOutputIndex uint32,
connectorOutput psetv2.OutputArgs,
changeScript []byte,
numberOfConnectors uint64,
) (connectorsPsets []string, err error) {
previousInput := psetv2.InputArgs{
Txid: poolTxID,
TxIndex: connectorOutputIndex,
}
if numberOfConnectors == 1 {
pset, err := psetv2.New(nil, nil, nil)
if err != nil {
return nil, err
}
updater, err := psetv2.NewUpdater(pset)
if err != nil {
return nil, err
}
err = updater.AddInputs([]psetv2.InputArgs{previousInput})
if err != nil {
return nil, err
}
err = updater.AddOutputs([]psetv2.OutputArgs{connectorOutput})
if err != nil {
return nil, err
}
base64, err := pset.ToBase64()
if err != nil {
return nil, err
}
return []string{base64}, nil
}
// compute the initial amount of the connectors output in pool transaction
remainingAmount := connectorAmount * numberOfConnectors
connectorsPset := make([]string, 0, numberOfConnectors-1)
for i := uint64(0); i < numberOfConnectors-1; i++ {
// create a new pset
pset, err := psetv2.New(nil, nil, nil)
if err != nil {
return nil, err
}
updater, err := psetv2.NewUpdater(pset)
if err != nil {
return nil, err
}
err = updater.AddInputs([]psetv2.InputArgs{previousInput})
if err != nil {
return nil, err
}
err = updater.AddOutputs([]psetv2.OutputArgs{connectorOutput})
if err != nil {
return nil, err
}
changeAmount := remainingAmount - connectorOutput.Amount
if changeAmount > 0 {
changeOutput := psetv2.OutputArgs{
Asset: connectorOutput.Asset,
Amount: changeAmount,
Script: changeScript,
}
err = updater.AddOutputs([]psetv2.OutputArgs{changeOutput})
if err != nil {
return nil, err
}
tx, _ := pset.UnsignedTx()
txid := tx.TxHash().String()
// make the change the next previousInput
previousInput = psetv2.InputArgs{
Txid: txid,
TxIndex: 1,
}
}
base64, err := pset.ToBase64()
if err != nil {
return nil, err
}
connectorsPset = append(connectorsPset, base64)
}
return connectorsPset, nil
}

View File

@@ -0,0 +1,42 @@
package txbuilder
import (
"github.com/vulpemventures/go-elements/network"
"github.com/vulpemventures/go-elements/psetv2"
)
func createForfeitTx(
connectorInput psetv2.InputArgs,
vtxoInput psetv2.InputArgs,
vtxoAmount uint64,
aspScript []byte,
net network.Network,
) (forfeitTx string, err error) {
pset, err := psetv2.New(nil, nil, nil)
if err != nil {
return "", err
}
updater, err := psetv2.NewUpdater(pset)
if err != nil {
return "", err
}
err = updater.AddInputs([]psetv2.InputArgs{connectorInput, vtxoInput})
if err != nil {
return "", err
}
err = updater.AddOutputs([]psetv2.OutputArgs{
{
Asset: net.AssetID,
Amount: vtxoAmount,
Script: aspScript,
},
})
if err != nil {
return "", err
}
return pset.ToBase64()
}

View File

@@ -0,0 +1,309 @@
package txbuilder
import (
"encoding/hex"
"github.com/ark-network/ark/common/tree"
"github.com/ark-network/ark/internal/core/domain"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/decred/dcrd/dcrec/secp256k1/v4"
"github.com/vulpemventures/go-elements/address"
"github.com/vulpemventures/go-elements/network"
"github.com/vulpemventures/go-elements/payment"
"github.com/vulpemventures/go-elements/psetv2"
)
const (
sharedOutputIndex = 0
)
type outputScriptFactory func(leaves []domain.Receiver) ([]byte, error)
func p2wpkhScript(publicKey *secp256k1.PublicKey, net network.Network) ([]byte, error) {
payment := payment.FromPublicKey(publicKey, &net, nil)
addr, err := payment.WitnessPubKeyHash()
if err != nil {
return nil, err
}
return address.ToOutputScript(addr)
}
// newOtputScriptFactory returns an output script factory func that lock funds using the ASP public key only on all branches psbt. The leaves are instead locked by the leaf public key.
func newOutputScriptFactory(aspPublicKey *secp256k1.PublicKey, net network.Network) outputScriptFactory {
return func(leaves []domain.Receiver) ([]byte, error) {
aspScript, err := p2wpkhScript(aspPublicKey, net)
if err != nil {
return nil, err
}
switch len(leaves) {
case 0:
return nil, nil
case 1: // it's a leaf
buf, err := hex.DecodeString(leaves[0].Pubkey)
if err != nil {
return nil, err
}
key, err := secp256k1.ParsePubKey(buf)
if err != nil {
return nil, err
}
return p2wpkhScript(key, net)
default: // it's a branch, lock funds with ASP public key
return aspScript, nil
}
}
}
// congestionTree builder iteratively creates a binary tree of Pset from a set of receivers
// it also expect createOutputScript func managing the output script creation and the network to use (mainly for L-BTC asset id)
func buildCongestionTree(
createOutputScript outputScriptFactory,
net network.Network,
poolTxID string,
receivers []domain.Receiver,
) (congestionTree tree.CongestionTree, err error) {
var nodes []*node
for _, r := range receivers {
nodes = append(nodes, newLeaf(createOutputScript, net, r))
}
for len(nodes) > 1 {
nodes, err = createTreeLevel(nodes)
if err != nil {
return nil, err
}
}
psets, err := nodes[0].psets(psetv2.InputArgs{
Txid: poolTxID,
TxIndex: sharedOutputIndex,
}, 0)
if err != nil {
return nil, err
}
maxLevel := 0
for _, psetWithLevel := range psets {
if psetWithLevel.level > maxLevel {
maxLevel = psetWithLevel.level
}
}
congestionTree = make(tree.CongestionTree, maxLevel+1)
for _, psetWithLevel := range psets {
utx, err := psetWithLevel.pset.UnsignedTx()
if err != nil {
return nil, err
}
txid := utx.TxHash().String()
psetB64, err := psetWithLevel.pset.ToBase64()
if err != nil {
return nil, err
}
parentTxid := chainhash.Hash(psetWithLevel.pset.Inputs[0].PreviousTxid).String()
congestionTree[psetWithLevel.level] = append(congestionTree[psetWithLevel.level], tree.Node{
Txid: txid,
Tx: psetB64,
ParentTxid: parentTxid,
Leaf: psetWithLevel.leaf,
})
}
return congestionTree, nil
}
func createTreeLevel(nodes []*node) ([]*node, error) {
if len(nodes)%2 != 0 {
last := nodes[len(nodes)-1]
pairs, err := createTreeLevel(nodes[:len(nodes)-1])
if err != nil {
return nil, err
}
return append(pairs, last), nil
}
pairs := make([]*node, 0, len(nodes)/2)
for i := 0; i < len(nodes); i += 2 {
pairs = append(pairs, newBranch(nodes[i], nodes[i+1]))
}
return pairs, nil
}
// internal struct to build a binary tree of Pset
type node struct {
receivers []domain.Receiver
left *node
right *node
createOutputScript outputScriptFactory
network network.Network
}
// create a node from a single receiver
func newLeaf(
createOutputScript outputScriptFactory,
network network.Network,
receiver domain.Receiver,
) *node {
return &node{
receivers: []domain.Receiver{receiver},
createOutputScript: createOutputScript,
network: network,
left: nil,
right: nil,
}
}
// aggregate two nodes into a branch node
func newBranch(
left *node,
right *node,
) *node {
return &node{
receivers: append(left.receivers, right.receivers...),
createOutputScript: left.createOutputScript,
network: left.network,
left: left,
right: right,
}
}
// is it the final node of the tree
func (n *node) isLeaf() bool {
return n.left == nil && n.right == nil
}
// compute the output amount of a node
func (n *node) amount() uint64 {
var amount uint64
for _, r := range n.receivers {
amount += r.Amount
}
return amount
}
// compute the output script of a node
func (n *node) script() ([]byte, error) {
return n.createOutputScript(n.receivers)
}
// use script & amount to create OutputArgs
func (n *node) output() (*psetv2.OutputArgs, error) {
script, err := n.script()
if err != nil {
return nil, err
}
return &psetv2.OutputArgs{
Asset: n.network.AssetID,
Amount: n.amount(),
Script: script,
}, nil
}
// create the node Pset from the previous node Pset represented by input arg
// if node is a branch, it adds two outputs to the Pset, one for the left branch and one for the right branch
// if node is a leaf, it only adds one output to the Pset (the node output)
func (n *node) pset(input psetv2.InputArgs) (*psetv2.Pset, error) {
pset, err := psetv2.New(nil, nil, nil)
if err != nil {
return nil, err
}
updater, err := psetv2.NewUpdater(pset)
if err != nil {
return nil, err
}
err = updater.AddInputs([]psetv2.InputArgs{input})
if err != nil {
return nil, err
}
if n.isLeaf() {
output, err := n.output()
if err != nil {
return nil, err
}
err = updater.AddOutputs([]psetv2.OutputArgs{*output})
if err != nil {
return nil, err
}
return pset, nil
}
outputLeft, err := n.left.output()
if err != nil {
return nil, err
}
outputRight, err := n.right.output()
if err != nil {
return nil, err
}
err = updater.AddOutputs([]psetv2.OutputArgs{*outputLeft, *outputRight})
if err != nil {
return nil, err
}
return pset, nil
}
type psetWithLevel struct {
pset *psetv2.Pset
level int
leaf bool
}
// create the node pset and all the psets of its children recursively, updating the input arg at each step
// the function stops when it reaches a leaf node
func (n *node) psets(input psetv2.InputArgs, level int) ([]psetWithLevel, error) {
pset, err := n.pset(input)
if err != nil {
return nil, err
}
nodeResult := []psetWithLevel{
{pset, level, n.isLeaf()},
}
if n.isLeaf() {
return nodeResult, nil
}
unsignedTx, err := pset.UnsignedTx()
if err != nil {
return nil, err
}
txID := unsignedTx.TxHash().String()
psetsLeft, err := n.left.psets(psetv2.InputArgs{
Txid: txID,
TxIndex: 0,
}, level+1)
if err != nil {
return nil, err
}
psetsRight, err := n.right.psets(psetv2.InputArgs{
Txid: txID,
TxIndex: 1,
}, level+1)
if err != nil {
return nil, err
}
return append(nodeResult, append(psetsLeft, psetsRight...)...), nil
}

View File

@@ -0,0 +1,46 @@
package grpcservice
import (
"crypto/tls"
"fmt"
"net"
)
type Config struct {
Port uint32
NoTLS bool
}
func (c Config) Validate() error {
lis, err := net.Listen("tcp", c.address())
if err != nil {
return fmt.Errorf("invalid port: %s", err)
}
defer lis.Close()
if !c.NoTLS {
return fmt.Errorf("tls termination not supported yet")
}
return nil
}
func (c Config) insecure() bool {
return c.NoTLS
}
func (c Config) address() string {
return fmt.Sprintf(":%d", c.Port)
}
func (c Config) listener() net.Listener {
lis, _ := net.Listen("tcp", c.address())
if c.insecure() {
return lis
}
return tls.NewListener(lis, c.tlsConfig())
}
func (c Config) tlsConfig() *tls.Config {
return nil
}

View File

@@ -0,0 +1,305 @@
package handlers
import (
"context"
"encoding/hex"
"sync"
arkv1 "github.com/ark-network/ark/api-spec/protobuf/gen/ark/v1"
"github.com/ark-network/ark/common"
"github.com/ark-network/ark/common/tree"
"github.com/ark-network/ark/internal/core/application"
"github.com/ark-network/ark/internal/core/domain"
"github.com/decred/dcrd/dcrec/secp256k1/v4"
"github.com/google/uuid"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
type listener struct {
id string
ch chan *arkv1.GetEventStreamResponse
}
type handler struct {
svc application.Service
listenersLock *sync.Mutex
listeners []*listener
}
func NewHandler(service application.Service) arkv1.ArkServiceServer {
h := &handler{
svc: service,
listenersLock: &sync.Mutex{},
listeners: make([]*listener, 0),
}
go h.listenToEvents()
return h
}
func (h *handler) Ping(ctx context.Context, req *arkv1.PingRequest) (*arkv1.PingResponse, error) {
if req.GetPaymentId() == "" {
return nil, status.Error(codes.InvalidArgument, "missing payment id")
}
if err := h.svc.UpdatePaymentStatus(ctx, req.GetPaymentId()); err != nil {
return nil, err
}
return &arkv1.PingResponse{}, nil
}
func (h *handler) RegisterPayment(ctx context.Context, req *arkv1.RegisterPaymentRequest) (*arkv1.RegisterPaymentResponse, error) {
vtxosKeys := make([]domain.VtxoKey, 0, len(req.GetInputs()))
for _, input := range req.GetInputs() {
vtxosKeys = append(vtxosKeys, domain.VtxoKey{
Txid: input.GetTxid(),
VOut: input.GetVout(),
})
}
id, err := h.svc.SpendVtxos(ctx, vtxosKeys)
if err != nil {
return nil, err
}
return &arkv1.RegisterPaymentResponse{
Id: id,
}, nil
}
func (h *handler) ClaimPayment(ctx context.Context, req *arkv1.ClaimPaymentRequest) (*arkv1.ClaimPaymentResponse, error) {
receivers, err := parseReceivers(req.GetOutputs())
if err != nil {
return nil, status.Error(codes.InvalidArgument, err.Error())
}
if err := h.svc.ClaimVtxos(ctx, req.GetId(), receivers); err != nil {
return nil, err
}
return &arkv1.ClaimPaymentResponse{}, nil
}
func (h *handler) FinalizePayment(ctx context.Context, req *arkv1.FinalizePaymentRequest) (*arkv1.FinalizePaymentResponse, error) {
forfeitTxs, err := parseTxs(req.GetSignedForfeitTxs())
if err != nil {
return nil, status.Error(codes.InvalidArgument, err.Error())
}
if err := h.svc.SignVtxos(ctx, forfeitTxs); err != nil {
return nil, err
}
return &arkv1.FinalizePaymentResponse{}, nil
}
func (h *handler) Faucet(ctx context.Context, req *arkv1.FaucetRequest) (*arkv1.FaucetResponse, error) {
_, pubkey, _, err := parseAddress(req.GetAddress())
if err != nil {
return nil, status.Error(codes.InvalidArgument, err.Error())
}
if err := h.svc.FaucetVtxos(ctx, pubkey); err != nil {
return nil, err
}
return &arkv1.FaucetResponse{}, nil
}
func (h *handler) GetRound(ctx context.Context, req *arkv1.GetRoundRequest) (*arkv1.GetRoundResponse, error) {
if len(req.GetTxid()) <= 0 {
return nil, status.Error(codes.InvalidArgument, "missing pool txid")
}
round, err := h.svc.GetRoundByTxid(ctx, req.GetTxid())
if err != nil {
return nil, err
}
return &arkv1.GetRoundResponse{
Round: &arkv1.Round{
Id: round.Id,
Start: round.StartingTimestamp,
End: round.EndingTimestamp,
Txid: round.Txid,
CongestionTree: castCongestionTree(round.CongestionTree),
},
}, nil
}
func (h *handler) GetEventStream(_ *arkv1.GetEventStreamRequest, stream arkv1.ArkService_GetEventStreamServer) error {
listener := &listener{
id: uuid.NewString(),
ch: make(chan *arkv1.GetEventStreamResponse),
}
defer h.removeListener(listener.id)
defer close(listener.ch)
h.pushListener(listener)
for {
select {
case <-stream.Context().Done():
return nil
case ev := <-listener.ch:
if err := stream.Send(ev); err != nil {
return err
}
switch ev.Event.(type) {
case *arkv1.GetEventStreamResponse_RoundFinalized, *arkv1.GetEventStreamResponse_RoundFailed:
if err := stream.Send(ev); err != nil {
return err
}
return nil
}
}
}
}
func (h *handler) ListVtxos(ctx context.Context, req *arkv1.ListVtxosRequest) (*arkv1.ListVtxosResponse, error) {
hrp, userPubkey, aspPubkey, err := parseAddress(req.GetAddress())
if err != nil {
return nil, status.Error(codes.InvalidArgument, err.Error())
}
vtxos, err := h.svc.ListVtxos(ctx, userPubkey)
if err != nil {
return nil, err
}
return &arkv1.ListVtxosResponse{
Vtxos: vtxoList(vtxos).toProto(hrp, aspPubkey),
}, nil
}
func (h *handler) GetPubkey(ctx context.Context, req *arkv1.GetPubkeyRequest) (*arkv1.GetPubkeyResponse, error) {
pubkey, err := h.svc.GetPubkey(ctx)
if err != nil {
return nil, err
}
return &arkv1.GetPubkeyResponse{
Pubkey: pubkey,
}, nil
}
func (h *handler) pushListener(l *listener) {
h.listenersLock.Lock()
defer h.listenersLock.Unlock()
h.listeners = append(h.listeners, l)
}
func (h *handler) removeListener(id string) {
h.listenersLock.Lock()
defer h.listenersLock.Unlock()
for i, listener := range h.listeners {
if listener.id == id {
h.listeners = append(h.listeners[:i], h.listeners[i+1:]...)
return
}
}
}
// listenToEvents forwards events from the application layer to the set of listeners
func (h *handler) listenToEvents() {
channel := h.svc.GetEventsChannel(context.Background())
for event := range channel {
var ev *arkv1.GetEventStreamResponse
switch e := event.(type) {
case domain.RoundFinalizationStarted:
ev = &arkv1.GetEventStreamResponse{
Event: &arkv1.GetEventStreamResponse_RoundFinalization{
RoundFinalization: &arkv1.RoundFinalizationEvent{
Id: e.Id,
PoolPartialTx: e.PoolTx,
CongestionTree: castCongestionTree(e.CongestionTree),
ForfeitTxs: e.UnsignedForfeitTxs,
},
},
}
case domain.RoundFinalized:
ev = &arkv1.GetEventStreamResponse{
Event: &arkv1.GetEventStreamResponse_RoundFinalized{
RoundFinalized: &arkv1.RoundFinalizedEvent{
Id: e.Id,
PoolTxid: e.Txid,
},
},
}
case domain.RoundFailed:
ev = &arkv1.GetEventStreamResponse{
Event: &arkv1.GetEventStreamResponse_RoundFailed{
RoundFailed: &arkv1.RoundFailed{
Id: e.Id,
Reason: e.Err,
},
},
}
}
if ev != nil {
for _, listener := range h.listeners {
listener.ch <- ev
}
}
}
}
type vtxoList []domain.Vtxo
func (v vtxoList) toProto(hrp string, aspKey *secp256k1.PublicKey) []*arkv1.Vtxo {
list := make([]*arkv1.Vtxo, 0, len(v))
for _, vv := range v {
addr := vv.OnchainAddress
if vv.Pubkey != "" {
buf, _ := hex.DecodeString(vv.Pubkey)
key, _ := secp256k1.ParsePubKey(buf)
addr, _ = common.EncodeAddress(hrp, key, aspKey)
}
list = append(list, &arkv1.Vtxo{
Outpoint: &arkv1.Input{
Txid: vv.Txid,
Vout: vv.VOut,
},
Receiver: &arkv1.Output{
Address: addr,
Amount: vv.Amount,
},
PoolTxid: vv.PoolTx,
Spent: vv.Spent,
})
}
return list
}
// castCongestionTree converts a tree.CongestionTree to a repeated arkv1.TreeLevel
func castCongestionTree(congestionTree tree.CongestionTree) *arkv1.Tree {
levels := make([]*arkv1.TreeLevel, 0, len(congestionTree))
for _, level := range congestionTree {
levelProto := &arkv1.TreeLevel{
Nodes: make([]*arkv1.Node, 0, len(level)),
}
for _, node := range level {
levelProto.Nodes = append(levelProto.Nodes, &arkv1.Node{
Txid: node.Txid,
Tx: node.Tx,
ParentTxid: node.ParentTxid,
})
}
levels = append(levels, levelProto)
}
return &arkv1.Tree{
Levels: levels,
}
}

View File

@@ -0,0 +1,64 @@
package handlers
import (
"encoding/hex"
"fmt"
arkv1 "github.com/ark-network/ark/api-spec/protobuf/gen/ark/v1"
"github.com/ark-network/ark/common"
"github.com/ark-network/ark/internal/core/domain"
"github.com/decred/dcrd/dcrec/secp256k1/v4"
"github.com/vulpemventures/go-elements/address"
"github.com/vulpemventures/go-elements/psetv2"
)
func parseTxs(txs []string) ([]string, error) {
if len(txs) <= 0 {
return nil, fmt.Errorf("missing list of forfeit txs")
}
for _, tx := range txs {
if _, err := psetv2.NewPsetFromBase64(tx); err != nil {
return nil, fmt.Errorf("invalid tx format")
}
}
return txs, nil
}
func parseAddress(addr string) (string, *secp256k1.PublicKey, *secp256k1.PublicKey, error) {
if len(addr) <= 0 {
return "", nil, nil, fmt.Errorf("missing address")
}
return common.DecodeAddress(addr)
}
func parseReceivers(outs []*arkv1.Output) ([]domain.Receiver, error) {
receivers := make([]domain.Receiver, 0, len(outs))
for _, out := range outs {
if out.GetAmount() == 0 {
return nil, fmt.Errorf("missing output amount")
}
if len(out.GetAddress()) <= 0 {
return nil, fmt.Errorf("missing output address")
}
var pubkey, addr string
_, pk, _, err := common.DecodeAddress(out.GetAddress())
if err != nil {
if _, err := address.ToOutputScript(out.GetAddress()); err != nil {
return nil, fmt.Errorf("invalid output address: unknown format")
}
if isConf, _ := address.IsConfidential(out.GetAddress()); isConf {
return nil, fmt.Errorf("invalid output address: must be unconfidential")
}
addr = out.GetAddress()
}
if pk != nil {
pubkey = hex.EncodeToString(pk.SerializeCompressed())
}
receivers = append(receivers, domain.Receiver{
Pubkey: pubkey,
Amount: out.GetAmount(),
OnchainAddress: addr,
})
}
return receivers, nil
}

View File

@@ -0,0 +1,16 @@
package interceptors
import (
middleware "github.com/grpc-ecosystem/go-grpc-middleware"
"google.golang.org/grpc"
)
// UnaryInterceptor returns the unary interceptor
func UnaryInterceptor() grpc.ServerOption {
return grpc.UnaryInterceptor(middleware.ChainUnaryServer(unaryLogger))
}
// StreamInterceptor returns the stream interceptor with a logrus log
func StreamInterceptor() grpc.ServerOption {
return grpc.StreamInterceptor(middleware.ChainStreamServer(streamLogger))
}

View File

@@ -0,0 +1,28 @@
package interceptors
import (
"context"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
)
func unaryLogger(
ctx context.Context,
req interface{},
info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler,
) (interface{}, error) {
log.Debugf("gRPC method: %s", info.FullMethod)
return handler(ctx, req)
}
func streamLogger(
srv interface{},
stream grpc.ServerStream,
info *grpc.StreamServerInfo,
handler grpc.StreamHandler,
) error {
log.Debugf("gRPC method: %s", info.FullMethod)
return handler(srv, stream)
}

View File

@@ -0,0 +1,64 @@
package grpcservice
import (
"fmt"
arkv1 "github.com/ark-network/ark/api-spec/protobuf/gen/ark/v1"
appconfig "github.com/ark-network/ark/internal/app-config"
interfaces "github.com/ark-network/ark/internal/interface"
"github.com/ark-network/ark/internal/interface/grpc/handlers"
"github.com/ark-network/ark/internal/interface/grpc/interceptors"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
)
type service struct {
config Config
appConfig *appconfig.Config
server *grpc.Server
}
func NewService(
svcConfig Config, appConfig *appconfig.Config,
) (interfaces.Service, error) {
if err := svcConfig.Validate(); err != nil {
return nil, fmt.Errorf("invalid service config: %s", err)
}
if err := appConfig.Validate(); err != nil {
return nil, fmt.Errorf("invalid app config: %s", err)
}
grpcConfig := []grpc.ServerOption{
interceptors.UnaryInterceptor(), interceptors.StreamInterceptor(),
}
if !svcConfig.NoTLS {
return nil, fmt.Errorf("tls termination not supported yet")
}
creds := insecure.NewCredentials()
grpcConfig = append(grpcConfig, grpc.Creds(creds))
server := grpc.NewServer(grpcConfig...)
handler := handlers.NewHandler(appConfig.AppService())
arkv1.RegisterArkServiceServer(server, handler)
return &service{svcConfig, appConfig, server}, nil
}
func (s *service) Start() error {
// nolint:all
go s.server.Serve(s.config.listener())
log.Infof("started listening at %s", s.config.address())
if err := s.appConfig.AppService().Start(); err != nil {
return fmt.Errorf("failed to start app service: %s", err)
}
log.Info("started app service")
return nil
}
func (s *service) Stop() {
s.server.Stop()
log.Info("stopped grpc server")
s.appConfig.AppService().Stop()
log.Info("stopped app service")
}

View File

@@ -0,0 +1,6 @@
package interfaces
type Service interface {
Start() error
Stop()
}

0
server/internal/test/.gitkeep Executable file
View File