mirror of
https://github.com/aljazceru/ark.git
synced 2025-12-18 12:44:19 +01:00
Rename folders (#97)
* Rename arkd folder & drop cli * Rename ark cli folder & update docs * Update readme * Fix * scripts: add build-all * Add target to build cli for all platforms * Update build scripts --------- Co-authored-by: tiero <3596602+tiero@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
0d8c7bffb2
commit
dc00d60585
249
server/internal/app-config/config.go
Normal file
249
server/internal/app-config/config.go
Normal file
@@ -0,0 +1,249 @@
|
||||
package appconfig
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/ark-network/ark/common"
|
||||
"github.com/ark-network/ark/internal/core/application"
|
||||
"github.com/ark-network/ark/internal/core/ports"
|
||||
"github.com/ark-network/ark/internal/infrastructure/db"
|
||||
oceanwallet "github.com/ark-network/ark/internal/infrastructure/ocean-wallet"
|
||||
scheduler "github.com/ark-network/ark/internal/infrastructure/scheduler/gocron"
|
||||
txbuilder "github.com/ark-network/ark/internal/infrastructure/tx-builder/covenant"
|
||||
txbuilderdummy "github.com/ark-network/ark/internal/infrastructure/tx-builder/dummy"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/vulpemventures/go-elements/network"
|
||||
)
|
||||
|
||||
var (
|
||||
supportedDbs = supportedType{
|
||||
"badger": {},
|
||||
}
|
||||
supportedSchedulers = supportedType{
|
||||
"gocron": {},
|
||||
}
|
||||
supportedTxBuilders = supportedType{
|
||||
"dummy": {},
|
||||
"covenant": {},
|
||||
}
|
||||
supportedScanners = supportedType{
|
||||
"ocean": {},
|
||||
}
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
DbType string
|
||||
DbDir string
|
||||
RoundInterval int64
|
||||
Network common.Network
|
||||
SchedulerType string
|
||||
TxBuilderType string
|
||||
BlockchainScannerType string
|
||||
WalletAddr string
|
||||
MinRelayFee uint64
|
||||
RoundLifetime int64
|
||||
|
||||
repo ports.RepoManager
|
||||
svc application.Service
|
||||
wallet ports.WalletService
|
||||
txBuilder ports.TxBuilder
|
||||
scanner ports.BlockchainScanner
|
||||
scheduler ports.SchedulerService
|
||||
}
|
||||
|
||||
func (c *Config) Validate() error {
|
||||
if !supportedDbs.supports(c.DbType) {
|
||||
return fmt.Errorf("db type not supported, please select one of: %s", supportedDbs)
|
||||
}
|
||||
if !supportedSchedulers.supports(c.SchedulerType) {
|
||||
return fmt.Errorf("scheduler type not supported, please select one of: %s", supportedSchedulers)
|
||||
}
|
||||
if !supportedTxBuilders.supports(c.TxBuilderType) {
|
||||
return fmt.Errorf("tx builder type not supported, please select one of: %s", supportedTxBuilders)
|
||||
}
|
||||
if !supportedScanners.supports(c.BlockchainScannerType) {
|
||||
return fmt.Errorf("blockchain scanner type not supported, please select one of: %s", supportedScanners)
|
||||
}
|
||||
if c.RoundInterval < 5 {
|
||||
return fmt.Errorf("invalid round interval, must be at least 5 seconds")
|
||||
}
|
||||
if c.Network.Name != "liquid" && c.Network.Name != "testnet" {
|
||||
return fmt.Errorf("invalid network, must be either liquid or testnet")
|
||||
}
|
||||
if len(c.WalletAddr) <= 0 {
|
||||
return fmt.Errorf("missing onchain wallet address")
|
||||
}
|
||||
if c.MinRelayFee < 30 {
|
||||
return fmt.Errorf("invalid min relay fee, must be at least 30 sats")
|
||||
}
|
||||
if err := c.repoManager(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.walletService(); err != nil {
|
||||
return fmt.Errorf("failed to connect to wallet: %s", err)
|
||||
}
|
||||
if err := c.txBuilderService(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.scannerService(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.schedulerService(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.appService(); err != nil {
|
||||
return err
|
||||
}
|
||||
// round life time must be a multiple of 512
|
||||
if c.RoundLifetime <= 0 || c.RoundLifetime%512 != 0 {
|
||||
return fmt.Errorf("invalid round lifetime, must be greater than 0 and a multiple of 512")
|
||||
}
|
||||
seq, err := common.BIP68Encode(uint(c.RoundLifetime))
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid round lifetime, %s", err)
|
||||
}
|
||||
|
||||
seconds, err := common.BIP68Decode(seq)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid round lifetime, %s", err)
|
||||
}
|
||||
|
||||
if seconds != uint(c.RoundLifetime) {
|
||||
return fmt.Errorf("invalid round lifetime, must be a multiple of 512")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Config) AppService() application.Service {
|
||||
return c.svc
|
||||
}
|
||||
|
||||
func (c *Config) repoManager() error {
|
||||
var svc ports.RepoManager
|
||||
var err error
|
||||
switch c.DbType {
|
||||
case "badger":
|
||||
logger := log.New()
|
||||
svc, err = db.NewService(db.ServiceConfig{
|
||||
EventStoreType: c.DbType,
|
||||
RoundStoreType: c.DbType,
|
||||
VtxoStoreType: c.DbType,
|
||||
|
||||
EventStoreConfig: []interface{}{c.DbDir, logger},
|
||||
RoundStoreConfig: []interface{}{c.DbDir, logger},
|
||||
VtxoStoreConfig: []interface{}{c.DbDir, logger},
|
||||
})
|
||||
default:
|
||||
return fmt.Errorf("unknown db type")
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.repo = svc
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Config) walletService() error {
|
||||
svc, err := oceanwallet.NewService(c.WalletAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.wallet = svc
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Config) txBuilderService() error {
|
||||
var svc ports.TxBuilder
|
||||
var err error
|
||||
net := c.mainChain()
|
||||
|
||||
switch c.TxBuilderType {
|
||||
case "dummy":
|
||||
svc = txbuilderdummy.NewTxBuilder(c.wallet, net)
|
||||
case "covenant":
|
||||
svc = txbuilder.NewTxBuilder(c.wallet, net, c.RoundLifetime)
|
||||
default:
|
||||
err = fmt.Errorf("unknown tx builder type")
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.txBuilder = svc
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Config) scannerService() error {
|
||||
var svc ports.BlockchainScanner
|
||||
var err error
|
||||
switch c.BlockchainScannerType {
|
||||
case "ocean":
|
||||
svc = c.wallet
|
||||
default:
|
||||
err = fmt.Errorf("unknown blockchain scanner type")
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.scanner = svc
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Config) schedulerService() error {
|
||||
var svc ports.SchedulerService
|
||||
var err error
|
||||
switch c.SchedulerType {
|
||||
case "gocron":
|
||||
svc = scheduler.NewScheduler()
|
||||
default:
|
||||
err = fmt.Errorf("unknown scheduler type")
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.scheduler = svc
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Config) appService() error {
|
||||
net := c.mainChain()
|
||||
svc, err := application.NewService(
|
||||
c.Network, net, c.RoundInterval, c.RoundLifetime, c.MinRelayFee,
|
||||
c.wallet, c.repo, c.txBuilder, c.scanner, c.scheduler,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.svc = svc
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Config) mainChain() network.Network {
|
||||
net := network.Liquid
|
||||
if c.Network.Name != "mainnet" {
|
||||
net = network.Testnet
|
||||
}
|
||||
return net
|
||||
}
|
||||
|
||||
type supportedType map[string]struct{}
|
||||
|
||||
func (t supportedType) String() string {
|
||||
types := make([]string, 0, len(t))
|
||||
for tt := range t {
|
||||
types = append(types, tt)
|
||||
}
|
||||
return strings.Join(types, " | ")
|
||||
}
|
||||
|
||||
func (t supportedType) supports(typeStr string) bool {
|
||||
_, ok := t[typeStr]
|
||||
return ok
|
||||
}
|
||||
122
server/internal/config/config.go
Normal file
122
server/internal/config/config.go
Normal file
@@ -0,0 +1,122 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
common "github.com/ark-network/ark/common"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
WalletAddr string
|
||||
RoundInterval int64
|
||||
Port uint32
|
||||
DbType string
|
||||
DbDir string
|
||||
SchedulerType string
|
||||
TxBuilderType string
|
||||
BlockchainScannerType string
|
||||
NoTLS bool
|
||||
Network common.Network
|
||||
LogLevel int
|
||||
MinRelayFee uint64
|
||||
RoundLifetime int64
|
||||
}
|
||||
|
||||
var (
|
||||
Datadir = "DATADIR"
|
||||
WalletAddr = "WALLET_ADDR"
|
||||
RoundInterval = "ROUND_INTERVAL"
|
||||
Port = "PORT"
|
||||
DbType = "DB_TYPE"
|
||||
SchedulerType = "SCHEDULER_TYPE"
|
||||
TxBuilderType = "TX_BUILDER_TYPE"
|
||||
BlockchainScannerType = "BC_SCANNER_TYPE"
|
||||
Insecure = "INSECURE"
|
||||
LogLevel = "LOG_LEVEL"
|
||||
Network = "NETWORK"
|
||||
MinRelayFee = "MIN_RELAY_FEE"
|
||||
RoundLifetime = "ROUND_LIFETIME"
|
||||
|
||||
defaultDatadir = common.AppDataDir("arkd", false)
|
||||
defaultRoundInterval = 60
|
||||
defaultPort = 6000
|
||||
defaultDbType = "badger"
|
||||
defaultSchedulerType = "gocron"
|
||||
defaultTxBuilderType = "covenant"
|
||||
defaultBlockchainScannerType = "ocean"
|
||||
defaultInsecure = true
|
||||
defaultNetwork = "testnet"
|
||||
defaultLogLevel = 5
|
||||
defaultMinRelayFee = 30
|
||||
defaultRoundLifetime = 512
|
||||
)
|
||||
|
||||
func LoadConfig() (*Config, error) {
|
||||
viper.SetEnvPrefix("ARK")
|
||||
viper.AutomaticEnv()
|
||||
|
||||
viper.SetDefault(Datadir, defaultDatadir)
|
||||
viper.SetDefault(RoundInterval, defaultRoundInterval)
|
||||
viper.SetDefault(Port, defaultPort)
|
||||
viper.SetDefault(DbType, defaultDbType)
|
||||
viper.SetDefault(SchedulerType, defaultSchedulerType)
|
||||
viper.SetDefault(TxBuilderType, defaultTxBuilderType)
|
||||
viper.SetDefault(BlockchainScannerType, defaultBlockchainScannerType)
|
||||
viper.SetDefault(Insecure, defaultInsecure)
|
||||
viper.SetDefault(LogLevel, defaultLogLevel)
|
||||
viper.SetDefault(Network, defaultNetwork)
|
||||
viper.SetDefault(RoundLifetime, defaultRoundLifetime)
|
||||
viper.SetDefault(MinRelayFee, defaultMinRelayFee)
|
||||
|
||||
net, err := getNetwork()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := initDatadir(); err != nil {
|
||||
return nil, fmt.Errorf("error while creating datadir: %s", err)
|
||||
}
|
||||
|
||||
return &Config{
|
||||
WalletAddr: viper.GetString(WalletAddr),
|
||||
RoundInterval: viper.GetInt64(RoundInterval),
|
||||
Port: viper.GetUint32(Port),
|
||||
DbType: viper.GetString(DbType),
|
||||
SchedulerType: viper.GetString(SchedulerType),
|
||||
TxBuilderType: viper.GetString(TxBuilderType),
|
||||
BlockchainScannerType: viper.GetString(BlockchainScannerType),
|
||||
NoTLS: viper.GetBool(Insecure),
|
||||
DbDir: filepath.Join(viper.GetString(Datadir), "db"),
|
||||
LogLevel: viper.GetInt(LogLevel),
|
||||
Network: net,
|
||||
MinRelayFee: viper.GetUint64(MinRelayFee),
|
||||
RoundLifetime: viper.GetInt64(RoundLifetime),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func initDatadir() error {
|
||||
datadir := viper.GetString(Datadir)
|
||||
return makeDirectoryIfNotExists(datadir)
|
||||
}
|
||||
|
||||
func makeDirectoryIfNotExists(path string) error {
|
||||
if _, err := os.Stat(path); os.IsNotExist(err) {
|
||||
return os.MkdirAll(path, os.ModeDir|0755)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getNetwork() (common.Network, error) {
|
||||
switch strings.ToLower(viper.GetString(Network)) {
|
||||
case "mainnet":
|
||||
return common.MainNet, nil
|
||||
case "testnet":
|
||||
return common.TestNet, nil
|
||||
default:
|
||||
return common.Network{}, fmt.Errorf("unknown network")
|
||||
}
|
||||
}
|
||||
601
server/internal/core/application/service.go
Normal file
601
server/internal/core/application/service.go
Normal file
@@ -0,0 +1,601 @@
|
||||
package application
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/ark-network/ark/common"
|
||||
"github.com/ark-network/ark/internal/core/domain"
|
||||
"github.com/ark-network/ark/internal/core/ports"
|
||||
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
||||
"github.com/decred/dcrd/dcrec/secp256k1/v4"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/vulpemventures/go-elements/network"
|
||||
"github.com/vulpemventures/go-elements/psetv2"
|
||||
)
|
||||
|
||||
var (
|
||||
paymentsThreshold = int64(128)
|
||||
dustAmount = uint64(450)
|
||||
faucetVtxo = domain.VtxoKey{
|
||||
Txid: "0000000000000000000000000000000000000000000000000000000000000000",
|
||||
VOut: 0,
|
||||
}
|
||||
)
|
||||
|
||||
type Service interface {
|
||||
Start() error
|
||||
Stop()
|
||||
SpendVtxos(ctx context.Context, inputs []domain.VtxoKey) (string, error)
|
||||
ClaimVtxos(ctx context.Context, creds string, receivers []domain.Receiver) error
|
||||
SignVtxos(ctx context.Context, forfeitTxs []string) error
|
||||
FaucetVtxos(ctx context.Context, pubkey *secp256k1.PublicKey) error
|
||||
GetRoundByTxid(ctx context.Context, poolTxid string) (*domain.Round, error)
|
||||
GetEventsChannel(ctx context.Context) <-chan domain.RoundEvent
|
||||
UpdatePaymentStatus(ctx context.Context, id string) error
|
||||
ListVtxos(ctx context.Context, pubkey *secp256k1.PublicKey) ([]domain.Vtxo, error)
|
||||
GetPubkey(ctx context.Context) (string, error)
|
||||
}
|
||||
|
||||
type service struct {
|
||||
network common.Network
|
||||
onchainNework network.Network
|
||||
pubkey *secp256k1.PublicKey
|
||||
roundLifetime int64
|
||||
roundInterval int64
|
||||
minRelayFee uint64
|
||||
|
||||
wallet ports.WalletService
|
||||
repoManager ports.RepoManager
|
||||
builder ports.TxBuilder
|
||||
scanner ports.BlockchainScanner
|
||||
sweeper *sweeper
|
||||
|
||||
paymentRequests *paymentsMap
|
||||
forfeitTxs *forfeitTxsMap
|
||||
|
||||
eventsCh chan domain.RoundEvent
|
||||
}
|
||||
|
||||
func NewService(
|
||||
network common.Network, onchainNetwork network.Network,
|
||||
roundInterval, roundLifetime int64, minRelayFee uint64,
|
||||
walletSvc ports.WalletService, repoManager ports.RepoManager,
|
||||
builder ports.TxBuilder, scanner ports.BlockchainScanner,
|
||||
scheduler ports.SchedulerService,
|
||||
) (Service, error) {
|
||||
eventsCh := make(chan domain.RoundEvent)
|
||||
paymentRequests := newPaymentsMap(nil)
|
||||
|
||||
genesisHash, _ := chainhash.NewHashFromStr(onchainNetwork.GenesisBlockHash)
|
||||
forfeitTxs := newForfeitTxsMap(genesisHash)
|
||||
pubkey, err := walletSvc.GetPubkey(context.Background())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch pubkey: %s", err)
|
||||
}
|
||||
|
||||
sweeper := newSweeper(walletSvc, repoManager, builder, scheduler)
|
||||
|
||||
svc := &service{
|
||||
network, onchainNetwork, pubkey,
|
||||
roundLifetime, roundInterval, minRelayFee,
|
||||
walletSvc, repoManager, builder, scanner, sweeper,
|
||||
paymentRequests, forfeitTxs, eventsCh,
|
||||
}
|
||||
repoManager.RegisterEventsHandler(
|
||||
func(round *domain.Round) {
|
||||
svc.updateProjectionStore(round)
|
||||
svc.propagateEvents(round)
|
||||
},
|
||||
)
|
||||
|
||||
if err := svc.restoreWatchingVtxos(); err != nil {
|
||||
return nil, fmt.Errorf("failed to restore watching vtxos: %s", err)
|
||||
}
|
||||
go svc.listenToRedemptions()
|
||||
return svc, nil
|
||||
}
|
||||
|
||||
func (s *service) Start() error {
|
||||
log.Debug("starting sweeper service")
|
||||
if err := s.sweeper.start(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debug("starting app service")
|
||||
go s.start()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *service) Stop() {
|
||||
s.sweeper.stop()
|
||||
// nolint
|
||||
vtxos, _ := s.repoManager.Vtxos().GetSpendableVtxos(
|
||||
context.Background(), "",
|
||||
)
|
||||
if len(vtxos) > 0 {
|
||||
s.stopWatchingVtxos(vtxos)
|
||||
}
|
||||
|
||||
s.wallet.Close()
|
||||
log.Debug("closed connection to wallet")
|
||||
s.repoManager.Close()
|
||||
log.Debug("closed connection to db")
|
||||
}
|
||||
|
||||
func (s *service) SpendVtxos(ctx context.Context, inputs []domain.VtxoKey) (string, error) {
|
||||
vtxos, err := s.repoManager.Vtxos().GetVtxos(ctx, inputs)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
for _, v := range vtxos {
|
||||
if v.Spent {
|
||||
return "", fmt.Errorf("input %s:%d already spent", v.Txid, v.VOut)
|
||||
}
|
||||
}
|
||||
|
||||
payment, err := domain.NewPayment(vtxos)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := s.paymentRequests.push(*payment); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return payment.Id, nil
|
||||
}
|
||||
|
||||
func (s *service) ClaimVtxos(ctx context.Context, creds string, receivers []domain.Receiver) error {
|
||||
// Check credentials
|
||||
payment, ok := s.paymentRequests.view(creds)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid credentials")
|
||||
}
|
||||
|
||||
if err := payment.AddReceivers(receivers); err != nil {
|
||||
return err
|
||||
}
|
||||
return s.paymentRequests.update(*payment)
|
||||
}
|
||||
|
||||
func (s *service) UpdatePaymentStatus(_ context.Context, id string) error {
|
||||
return s.paymentRequests.updatePingTimestamp(id)
|
||||
}
|
||||
|
||||
func (s *service) FaucetVtxos(ctx context.Context, userPubkey *secp256k1.PublicKey) error {
|
||||
pubkey := hex.EncodeToString(userPubkey.SerializeCompressed())
|
||||
|
||||
payment, err := domain.NewPayment([]domain.Vtxo{
|
||||
{
|
||||
VtxoKey: faucetVtxo,
|
||||
Receiver: domain.Receiver{
|
||||
Pubkey: pubkey,
|
||||
Amount: 10000,
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := payment.AddReceivers([]domain.Receiver{
|
||||
{Pubkey: pubkey, Amount: 1000},
|
||||
{Pubkey: pubkey, Amount: 1000},
|
||||
{Pubkey: pubkey, Amount: 1000},
|
||||
{Pubkey: pubkey, Amount: 1000},
|
||||
{Pubkey: pubkey, Amount: 1000},
|
||||
{Pubkey: pubkey, Amount: 1000},
|
||||
{Pubkey: pubkey, Amount: 1000},
|
||||
{Pubkey: pubkey, Amount: 1000},
|
||||
{Pubkey: pubkey, Amount: 1000},
|
||||
{Pubkey: pubkey, Amount: 1000},
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.paymentRequests.push(*payment); err != nil {
|
||||
return err
|
||||
}
|
||||
return s.paymentRequests.updatePingTimestamp(payment.Id)
|
||||
}
|
||||
|
||||
func (s *service) SignVtxos(ctx context.Context, forfeitTxs []string) error {
|
||||
return s.forfeitTxs.sign(forfeitTxs)
|
||||
}
|
||||
|
||||
func (s *service) ListVtxos(ctx context.Context, pubkey *secp256k1.PublicKey) ([]domain.Vtxo, error) {
|
||||
pk := hex.EncodeToString(pubkey.SerializeCompressed())
|
||||
return s.repoManager.Vtxos().GetSpendableVtxos(ctx, pk)
|
||||
}
|
||||
|
||||
func (s *service) GetEventsChannel(ctx context.Context) <-chan domain.RoundEvent {
|
||||
return s.eventsCh
|
||||
}
|
||||
|
||||
func (s *service) GetRoundByTxid(ctx context.Context, poolTxid string) (*domain.Round, error) {
|
||||
return s.repoManager.Rounds().GetRoundWithTxid(ctx, poolTxid)
|
||||
}
|
||||
|
||||
func (s *service) GetPubkey(ctx context.Context) (string, error) {
|
||||
pubkey, err := common.EncodePubKey(s.network.PubKey, s.pubkey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return pubkey, nil
|
||||
}
|
||||
|
||||
func (s *service) start() {
|
||||
s.startRound()
|
||||
}
|
||||
|
||||
func (s *service) startRound() {
|
||||
round := domain.NewRound(dustAmount)
|
||||
changes, _ := round.StartRegistration()
|
||||
if err := s.repoManager.Events().Save(
|
||||
context.Background(), round.Id, changes...,
|
||||
); err != nil {
|
||||
log.WithError(err).Warn("failed to store new round events")
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
time.Sleep(time.Duration(s.roundInterval/2) * time.Second)
|
||||
s.startFinalization()
|
||||
}()
|
||||
|
||||
log.Debugf("started registration stage for new round: %s", round.Id)
|
||||
}
|
||||
|
||||
func (s *service) startFinalization() {
|
||||
ctx := context.Background()
|
||||
round, err := s.repoManager.Rounds().GetCurrentRound(ctx)
|
||||
if err != nil {
|
||||
log.WithError(err).Warn("failed to retrieve current round")
|
||||
return
|
||||
}
|
||||
|
||||
var changes []domain.RoundEvent
|
||||
defer func() {
|
||||
if len(changes) > 0 {
|
||||
if err := s.repoManager.Events().Save(ctx, round.Id, changes...); err != nil {
|
||||
log.WithError(err).Warn("failed to store new round events")
|
||||
}
|
||||
}
|
||||
|
||||
if round.IsFailed() {
|
||||
s.startRound()
|
||||
return
|
||||
}
|
||||
time.Sleep(time.Duration((s.roundInterval/2)-1) * time.Second)
|
||||
s.finalizeRound()
|
||||
}()
|
||||
|
||||
if round.IsFailed() {
|
||||
return
|
||||
}
|
||||
|
||||
// TODO: understand how many payments must be popped from the queue and actually registered for the round
|
||||
num := s.paymentRequests.len()
|
||||
if num == 0 {
|
||||
err := fmt.Errorf("no payments registered")
|
||||
changes = round.Fail(fmt.Errorf("round aborted: %s", err))
|
||||
log.WithError(err).Debugf("round %s aborted", round.Id)
|
||||
return
|
||||
}
|
||||
if num > paymentsThreshold {
|
||||
num = paymentsThreshold
|
||||
}
|
||||
payments := s.paymentRequests.pop(num)
|
||||
changes, err = round.RegisterPayments(payments)
|
||||
if err != nil {
|
||||
changes = round.Fail(fmt.Errorf("failed to register payments: %s", err))
|
||||
log.WithError(err).Warn("failed to register payments")
|
||||
return
|
||||
}
|
||||
|
||||
unsignedPoolTx, tree, err := s.builder.BuildPoolTx(s.pubkey, payments, s.minRelayFee)
|
||||
if err != nil {
|
||||
changes = round.Fail(fmt.Errorf("failed to create pool tx: %s", err))
|
||||
log.WithError(err).Warn("failed to create pool tx")
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("pool tx created for round %s", round.Id)
|
||||
|
||||
connectors, forfeitTxs, err := s.builder.BuildForfeitTxs(s.pubkey, unsignedPoolTx, payments)
|
||||
if err != nil {
|
||||
changes = round.Fail(fmt.Errorf("failed to create connectors and forfeit txs: %s", err))
|
||||
log.WithError(err).Warn("failed to create connectors and forfeit txs")
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("forfeit transactions created for round %s", round.Id)
|
||||
|
||||
events, err := round.StartFinalization(connectors, tree, unsignedPoolTx)
|
||||
if err != nil {
|
||||
changes = round.Fail(fmt.Errorf("failed to start finalization: %s", err))
|
||||
log.WithError(err).Warn("failed to start finalization")
|
||||
return
|
||||
}
|
||||
changes = append(changes, events...)
|
||||
|
||||
s.forfeitTxs.push(forfeitTxs)
|
||||
|
||||
log.Debugf("started finalization stage for round: %s", round.Id)
|
||||
}
|
||||
|
||||
func (s *service) finalizeRound() {
|
||||
defer s.startRound()
|
||||
|
||||
ctx := context.Background()
|
||||
round, err := s.repoManager.Rounds().GetCurrentRound(ctx)
|
||||
if err != nil {
|
||||
log.WithError(err).Warn("failed to retrieve current round")
|
||||
return
|
||||
}
|
||||
if round.IsFailed() {
|
||||
return
|
||||
}
|
||||
|
||||
var changes []domain.RoundEvent
|
||||
defer func() {
|
||||
if err := s.repoManager.Events().Save(ctx, round.Id, changes...); err != nil {
|
||||
log.WithError(err).Warn("failed to store new round events")
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
forfeitTxs, leftUnsigned := s.forfeitTxs.pop()
|
||||
if len(leftUnsigned) > 0 {
|
||||
err := fmt.Errorf("%d forfeit txs left to sign", len(leftUnsigned))
|
||||
changes = round.Fail(fmt.Errorf("failed to finalize round: %s", err))
|
||||
log.WithError(err).Warn("failed to finalize round")
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("signing round transaction %s\n", round.Id)
|
||||
signedPoolTx, err := s.wallet.SignPset(ctx, round.UnsignedTx, true)
|
||||
if err != nil {
|
||||
changes = round.Fail(fmt.Errorf("failed to sign round tx: %s", err))
|
||||
log.WithError(err).Warn("failed to sign round tx")
|
||||
return
|
||||
}
|
||||
|
||||
txid, err := s.wallet.BroadcastTransaction(ctx, signedPoolTx)
|
||||
if err != nil {
|
||||
changes = round.Fail(fmt.Errorf("failed to broadcast pool tx: %s", err))
|
||||
log.WithError(err).Warn("failed to broadcast pool tx")
|
||||
return
|
||||
}
|
||||
|
||||
now := time.Now().Unix()
|
||||
expirationTimestamp := now + s.roundLifetime + 30 // add 30 secs to be sure that the tx is confirmed
|
||||
|
||||
if err := s.sweeper.schedule(expirationTimestamp, txid, round.CongestionTree); err != nil {
|
||||
changes = round.Fail(fmt.Errorf("failed to schedule sweep tx: %s", err))
|
||||
log.WithError(err).Warn("failed to schedule sweep tx")
|
||||
return
|
||||
}
|
||||
|
||||
changes, _ = round.EndFinalization(forfeitTxs, txid)
|
||||
|
||||
log.Debugf("finalized round %s with pool tx %s", round.Id, round.Txid)
|
||||
}
|
||||
|
||||
func (s *service) listenToRedemptions() {
|
||||
ctx := context.Background()
|
||||
chVtxos := s.scanner.GetNotificationChannel(ctx)
|
||||
for vtxoKeys := range chVtxos {
|
||||
if len(vtxoKeys) > 0 {
|
||||
for {
|
||||
// TODO: make sure that the vtxos haven't been already spent, otherwise
|
||||
// broadcast the corresponding forfeit tx and connector to prevent
|
||||
// getting cheated.
|
||||
vtxos, err := s.repoManager.Vtxos().RedeemVtxos(ctx, vtxoKeys)
|
||||
if err != nil {
|
||||
log.WithError(err).Warn("failed to redeem vtxos, retrying...")
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
if len(vtxos) > 0 {
|
||||
log.Debugf("redeemed %d vtxos", len(vtxos))
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *service) updateProjectionStore(round *domain.Round) {
|
||||
ctx := context.Background()
|
||||
lastChange := round.Events()[len(round.Events())-1]
|
||||
// Update the vtxo set only after a round is finalized.
|
||||
if _, ok := lastChange.(domain.RoundFinalized); ok {
|
||||
repo := s.repoManager.Vtxos()
|
||||
spentVtxos := getSpentVtxos(round.Payments)
|
||||
if len(spentVtxos) > 0 {
|
||||
for {
|
||||
if err := repo.SpendVtxos(ctx, spentVtxos); err != nil {
|
||||
log.WithError(err).Warn("failed to add new vtxos, retrying soon")
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
log.Debugf("spent %d vtxos", len(spentVtxos))
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
newVtxos := s.getNewVtxos(round)
|
||||
for {
|
||||
if err := repo.AddVtxos(ctx, newVtxos); err != nil {
|
||||
log.WithError(err).Warn("failed to add new vtxos, retrying soon")
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
log.Debugf("added %d new vtxos", len(newVtxos))
|
||||
break
|
||||
}
|
||||
|
||||
go func() {
|
||||
for {
|
||||
if err := s.startWatchingVtxos(newVtxos); err != nil {
|
||||
log.WithError(err).Warn(
|
||||
"failed to start watching vtxos, retrying in a moment...",
|
||||
)
|
||||
continue
|
||||
}
|
||||
log.Debugf("started watching %d vtxos", len(newVtxos))
|
||||
return
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Always update the status of the round.
|
||||
for {
|
||||
if err := s.repoManager.Rounds().AddOrUpdateRound(ctx, *round); err != nil {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
func (s *service) propagateEvents(round *domain.Round) {
|
||||
lastEvent := round.Events()[len(round.Events())-1]
|
||||
switch e := lastEvent.(type) {
|
||||
case domain.RoundFinalizationStarted:
|
||||
forfeitTxs := s.forfeitTxs.view()
|
||||
s.eventsCh <- domain.RoundFinalizationStarted{
|
||||
Id: e.Id,
|
||||
CongestionTree: e.CongestionTree,
|
||||
Connectors: e.Connectors,
|
||||
PoolTx: e.PoolTx,
|
||||
UnsignedForfeitTxs: forfeitTxs,
|
||||
}
|
||||
case domain.RoundFinalized, domain.RoundFailed:
|
||||
s.eventsCh <- e
|
||||
}
|
||||
}
|
||||
|
||||
func (s *service) getNewVtxos(round *domain.Round) []domain.Vtxo {
|
||||
leaves := round.CongestionTree.Leaves()
|
||||
vtxos := make([]domain.Vtxo, 0)
|
||||
for _, node := range leaves {
|
||||
tx, _ := psetv2.NewPsetFromBase64(node.Tx)
|
||||
for i, out := range tx.Outputs {
|
||||
for _, p := range round.Payments {
|
||||
var pubkey string
|
||||
found := false
|
||||
for _, r := range p.Receivers {
|
||||
if r.IsOnchain() {
|
||||
continue
|
||||
}
|
||||
|
||||
buf, _ := hex.DecodeString(r.Pubkey)
|
||||
pk, _ := secp256k1.ParsePubKey(buf)
|
||||
script, _ := s.builder.GetVtxoScript(pk, s.pubkey)
|
||||
if bytes.Equal(script, out.Script) {
|
||||
found = true
|
||||
pubkey = r.Pubkey
|
||||
break
|
||||
}
|
||||
}
|
||||
if found {
|
||||
vtxos = append(vtxos, domain.Vtxo{
|
||||
VtxoKey: domain.VtxoKey{Txid: node.Txid, VOut: uint32(i)},
|
||||
Receiver: domain.Receiver{Pubkey: pubkey, Amount: out.Value},
|
||||
PoolTx: round.Txid,
|
||||
})
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return vtxos
|
||||
}
|
||||
|
||||
func (s *service) startWatchingVtxos(vtxos []domain.Vtxo) error {
|
||||
scripts, err := s.extractVtxosScripts(vtxos)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.scanner.WatchScripts(context.Background(), scripts)
|
||||
}
|
||||
|
||||
func (s *service) stopWatchingVtxos(vtxos []domain.Vtxo) {
|
||||
scripts, err := s.extractVtxosScripts(vtxos)
|
||||
if err != nil {
|
||||
log.WithError(err).Warn("failed to extract scripts from vtxos")
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
if err := s.scanner.UnwatchScripts(context.Background(), scripts); err != nil {
|
||||
log.WithError(err).Warn("failed to stop watching vtxos, retrying in a moment...")
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
log.Debugf("stopped watching %d vtxos", len(vtxos))
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
func (s *service) restoreWatchingVtxos() error {
|
||||
vtxos, err := s.repoManager.Vtxos().GetSpendableVtxos(
|
||||
context.Background(), "",
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(vtxos) <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := s.startWatchingVtxos(vtxos); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debugf("restored watching %d vtxos", len(vtxos))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *service) extractVtxosScripts(vtxos []domain.Vtxo) ([]string, error) {
|
||||
indexedScripts := make(map[string]struct{})
|
||||
for _, vtxo := range vtxos {
|
||||
buf, err := hex.DecodeString(vtxo.Pubkey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userPubkey, err := secp256k1.ParsePubKey(buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
script, err := s.builder.GetVtxoScript(userPubkey, s.pubkey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
indexedScripts[hex.EncodeToString(script)] = struct{}{}
|
||||
}
|
||||
scripts := make([]string, 0, len(indexedScripts))
|
||||
for script := range indexedScripts {
|
||||
scripts = append(scripts, script)
|
||||
}
|
||||
return scripts, nil
|
||||
}
|
||||
|
||||
func getSpentVtxos(payments map[string]domain.Payment) []domain.VtxoKey {
|
||||
vtxos := make([]domain.VtxoKey, 0)
|
||||
for _, p := range payments {
|
||||
for _, vtxo := range p.Inputs {
|
||||
if vtxo.VtxoKey == faucetVtxo {
|
||||
continue
|
||||
}
|
||||
vtxos = append(vtxos, vtxo.VtxoKey)
|
||||
}
|
||||
}
|
||||
return vtxos
|
||||
}
|
||||
579
server/internal/core/application/sweeper.go
Normal file
579
server/internal/core/application/sweeper.go
Normal file
@@ -0,0 +1,579 @@
|
||||
package application
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/ark-network/ark/common/tree"
|
||||
"github.com/ark-network/ark/internal/core/domain"
|
||||
"github.com/ark-network/ark/internal/core/ports"
|
||||
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
||||
"github.com/decred/dcrd/dcrec/secp256k1/v4"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/vulpemventures/go-elements/psetv2"
|
||||
)
|
||||
|
||||
// sweeper is an unexported service running while the main application service is started
|
||||
// it is responsible for sweeping onchain shared outputs that expired
|
||||
// it also handles delaying the sweep events in case some parts of the tree are broadcasted
|
||||
// when a round is finalized, the main application service schedules a sweep event on the newly created congestion tree
|
||||
type sweeper struct {
|
||||
wallet ports.WalletService
|
||||
repoManager ports.RepoManager
|
||||
builder ports.TxBuilder
|
||||
scheduler ports.SchedulerService
|
||||
|
||||
// cache of scheduled tasks, avoid scheduling the same sweep event multiple times
|
||||
scheduledTasks map[string]struct{}
|
||||
}
|
||||
|
||||
func newSweeper(
|
||||
wallet ports.WalletService,
|
||||
repoManager ports.RepoManager,
|
||||
builder ports.TxBuilder,
|
||||
scheduler ports.SchedulerService,
|
||||
) *sweeper {
|
||||
return &sweeper{
|
||||
wallet,
|
||||
repoManager,
|
||||
builder,
|
||||
scheduler,
|
||||
make(map[string]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *sweeper) start() error {
|
||||
s.scheduler.Start()
|
||||
|
||||
allRounds, err := s.repoManager.Rounds().GetSweepableRounds(context.Background())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, round := range allRounds {
|
||||
task := s.createTask(round.Txid, round.CongestionTree)
|
||||
task()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *sweeper) stop() {
|
||||
s.scheduler.Stop()
|
||||
}
|
||||
|
||||
// removeTask update the cached map of scheduled tasks
|
||||
func (s *sweeper) removeTask(treeRootTxid string) {
|
||||
delete(s.scheduledTasks, treeRootTxid)
|
||||
}
|
||||
|
||||
// schedule set up a task to be executed once at the given timestamp
|
||||
func (s *sweeper) schedule(
|
||||
expirationTimestamp int64, roundTxid string, congestionTree tree.CongestionTree,
|
||||
) error {
|
||||
root, err := congestionTree.Root()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, scheduled := s.scheduledTasks[root.Txid]; scheduled {
|
||||
return nil
|
||||
}
|
||||
|
||||
task := s.createTask(roundTxid, congestionTree)
|
||||
fancyTime := time.Unix(expirationTimestamp, 0).Format("2006-01-02 15:04:05")
|
||||
log.Debugf("scheduled sweep task at %s", fancyTime)
|
||||
if err := s.scheduler.ScheduleTaskOnce(expirationTimestamp, task); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.scheduledTasks[root.Txid] = struct{}{}
|
||||
return nil
|
||||
}
|
||||
|
||||
// createTask returns a function passed as handler in the scheduler
|
||||
// it tries to craft a sweep tx containing the onchain outputs of the given congestion tree
|
||||
// if some parts of the tree have been broadcasted in the meantine, it will schedule the next taskes for the remaining parts of the tree
|
||||
func (s *sweeper) createTask(
|
||||
roundTxid string, congestionTree tree.CongestionTree,
|
||||
) func() {
|
||||
return func() {
|
||||
ctx := context.Background()
|
||||
root, err := congestionTree.Root()
|
||||
if err != nil {
|
||||
log.WithError(err).Error("error while getting root node")
|
||||
return
|
||||
}
|
||||
|
||||
s.removeTask(root.Txid)
|
||||
log.Debugf("sweeper: %s", root.Txid)
|
||||
|
||||
sweepInputs := make([]ports.SweepInput, 0)
|
||||
vtxoKeys := make([]domain.VtxoKey, 0) // vtxos associated to the sweep inputs
|
||||
|
||||
// inspect the congestion tree to find onchain shared outputs
|
||||
sharedOutputs, err := s.findSweepableOutputs(ctx, congestionTree)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("error while inspecting congestion tree")
|
||||
return
|
||||
}
|
||||
|
||||
for expiredAt, inputs := range sharedOutputs {
|
||||
// if the shared outputs are not expired, schedule a sweep task for it
|
||||
if time.Unix(expiredAt, 0).After(time.Now()) {
|
||||
subtrees, err := computeSubTrees(congestionTree, inputs)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("error while computing subtrees")
|
||||
continue
|
||||
}
|
||||
|
||||
for _, subTree := range subtrees {
|
||||
// mitigate the risk to get BIP68 non-final errors by scheduling the task 30 seconds after the expiration time
|
||||
if err := s.schedule(int64(expiredAt), roundTxid, subTree); err != nil {
|
||||
log.WithError(err).Error("error while scheduling sweep task")
|
||||
continue
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// iterate over the expired shared outputs
|
||||
for _, input := range inputs {
|
||||
// sweepableVtxos related to the sweep input
|
||||
sweepableVtxos := make([]domain.VtxoKey, 0)
|
||||
|
||||
// check if input is the vtxo itself
|
||||
vtxos, _ := s.repoManager.Vtxos().GetVtxos(
|
||||
ctx,
|
||||
[]domain.VtxoKey{
|
||||
{
|
||||
Txid: input.InputArgs.Txid,
|
||||
VOut: input.InputArgs.TxIndex,
|
||||
},
|
||||
},
|
||||
)
|
||||
if len(vtxos) > 0 {
|
||||
if !vtxos[0].Swept && !vtxos[0].Redeemed {
|
||||
sweepableVtxos = append(sweepableVtxos, vtxos[0].VtxoKey)
|
||||
}
|
||||
} else {
|
||||
// if it's not a vtxo, find all the vtxos leaves reachable from that input
|
||||
vtxosLeaves, err := congestionTree.FindLeaves(input.InputArgs.Txid, input.InputArgs.TxIndex)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("error while finding vtxos leaves")
|
||||
continue
|
||||
}
|
||||
|
||||
for _, leaf := range vtxosLeaves {
|
||||
pset, err := psetv2.NewPsetFromBase64(leaf.Tx)
|
||||
if err != nil {
|
||||
log.Error(fmt.Errorf("error while decoding pset: %w", err))
|
||||
continue
|
||||
}
|
||||
|
||||
vtxo, err := extractVtxoOutpoint(pset)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
continue
|
||||
}
|
||||
|
||||
sweepableVtxos = append(sweepableVtxos, *vtxo)
|
||||
}
|
||||
|
||||
if len(sweepableVtxos) <= 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
firstVtxo, err := s.repoManager.Vtxos().GetVtxos(ctx, sweepableVtxos[:1])
|
||||
if err != nil {
|
||||
log.Error(fmt.Errorf("error while getting vtxo: %w", err))
|
||||
sweepInputs = append(sweepInputs, input) // add the input anyway in order to try to sweep it
|
||||
continue
|
||||
}
|
||||
|
||||
if firstVtxo[0].Swept || firstVtxo[0].Redeemed {
|
||||
// we assume that if the first vtxo is swept or redeemed, the shared output has been spent
|
||||
// skip, the output is already swept or spent by a unilateral redeem
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if len(sweepableVtxos) > 0 {
|
||||
vtxoKeys = append(vtxoKeys, sweepableVtxos...)
|
||||
sweepInputs = append(sweepInputs, input)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(sweepInputs) > 0 {
|
||||
// build the sweep transaction with all the expired non-swept shared outputs
|
||||
sweepTx, err := s.builder.BuildSweepTx(s.wallet, sweepInputs)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("error while building sweep tx")
|
||||
return
|
||||
}
|
||||
|
||||
err = nil
|
||||
txid := ""
|
||||
// retry until the tx is broadcasted or the error is not BIP68 final
|
||||
for len(txid) == 0 && (err == nil || err == fmt.Errorf("non-BIP68-final")) {
|
||||
if err != nil {
|
||||
log.Debugln("sweep tx not BIP68 final, retrying in 5 seconds")
|
||||
time.Sleep(5 * time.Second)
|
||||
}
|
||||
|
||||
txid, err = s.wallet.BroadcastTransaction(ctx, sweepTx)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.WithError(err).Error("error while broadcasting sweep tx")
|
||||
return
|
||||
}
|
||||
if len(txid) > 0 {
|
||||
log.Debugln("sweep tx broadcasted:", txid)
|
||||
vtxosRepository := s.repoManager.Vtxos()
|
||||
|
||||
// mark the vtxos as swept
|
||||
if err := vtxosRepository.SweepVtxos(ctx, vtxoKeys); err != nil {
|
||||
log.Error(fmt.Errorf("error while deleting vtxos: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("%d vtxos swept", len(vtxoKeys))
|
||||
|
||||
roundVtxos, err := vtxosRepository.GetVtxosForRound(ctx, roundTxid)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("error while getting vtxos for round")
|
||||
return
|
||||
}
|
||||
|
||||
allSwept := true
|
||||
for _, vtxo := range roundVtxos {
|
||||
allSwept = allSwept && vtxo.Swept
|
||||
if !allSwept {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if allSwept {
|
||||
// update the round
|
||||
roundRepo := s.repoManager.Rounds()
|
||||
round, err := roundRepo.GetRoundWithTxid(ctx, roundTxid)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("error while getting round")
|
||||
return
|
||||
}
|
||||
|
||||
round.Sweep()
|
||||
|
||||
if err := roundRepo.AddOrUpdateRound(ctx, *round); err != nil {
|
||||
log.WithError(err).Error("error while marking round as swept")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// onchainOutputs iterates over all the nodes' outputs in the congestion tree and checks their onchain state
|
||||
// returns the sweepable outputs as ports.SweepInput mapped by their expiration time
|
||||
func (s *sweeper) findSweepableOutputs(
|
||||
ctx context.Context,
|
||||
congestionTree tree.CongestionTree,
|
||||
) (map[int64][]ports.SweepInput, error) {
|
||||
sweepableOutputs := make(map[int64][]ports.SweepInput)
|
||||
blocktimeCache := make(map[string]int64) // txid -> blocktime
|
||||
nodesToCheck := congestionTree[0] // init with the root
|
||||
|
||||
for len(nodesToCheck) > 0 {
|
||||
newNodesToCheck := make([]tree.Node, 0)
|
||||
|
||||
for _, node := range nodesToCheck {
|
||||
isPublished, blocktime, err := s.wallet.IsTransactionPublished(ctx, node.Txid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var expirationTime int64
|
||||
var sweepInputs []ports.SweepInput
|
||||
|
||||
if !isPublished {
|
||||
if _, ok := blocktimeCache[node.ParentTxid]; !ok {
|
||||
isPublished, blocktime, err := s.wallet.IsTransactionPublished(ctx, node.ParentTxid)
|
||||
if !isPublished || err != nil {
|
||||
return nil, fmt.Errorf("tx %s not found", node.Txid)
|
||||
}
|
||||
|
||||
blocktimeCache[node.ParentTxid] = blocktime
|
||||
}
|
||||
|
||||
expirationTime, sweepInputs, err = s.nodeToSweepInputs(blocktimeCache[node.ParentTxid], node)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
// cache the blocktime for future use
|
||||
blocktimeCache[node.Txid] = int64(blocktime)
|
||||
|
||||
// if the tx is onchain, it means that the input is spent
|
||||
// add the children to the nodes in order to check them during the next iteration
|
||||
// We will return the error below, but are we going to schedule the tasks for the "children roots"?
|
||||
if !node.Leaf {
|
||||
children := congestionTree.Children(node.Txid)
|
||||
newNodesToCheck = append(newNodesToCheck, children...)
|
||||
continue
|
||||
}
|
||||
|
||||
// if the node is a leaf, the vtxos outputs should added as onchain outputs if they are not swept yet
|
||||
vtxoExpiration, sweepInput, err := s.leafToSweepInput(ctx, blocktime, node)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if sweepInput != nil {
|
||||
expirationTime = vtxoExpiration
|
||||
sweepInputs = []ports.SweepInput{*sweepInput}
|
||||
}
|
||||
}
|
||||
|
||||
if _, ok := sweepableOutputs[expirationTime]; !ok {
|
||||
sweepableOutputs[expirationTime] = make([]ports.SweepInput, 0)
|
||||
}
|
||||
sweepableOutputs[expirationTime] = append(sweepableOutputs[expirationTime], sweepInputs...)
|
||||
}
|
||||
|
||||
nodesToCheck = newNodesToCheck
|
||||
}
|
||||
|
||||
return sweepableOutputs, nil
|
||||
}
|
||||
|
||||
func (s *sweeper) leafToSweepInput(ctx context.Context, txBlocktime int64, node tree.Node) (int64, *ports.SweepInput, error) {
|
||||
pset, err := psetv2.NewPsetFromBase64(node.Tx)
|
||||
if err != nil {
|
||||
return -1, nil, err
|
||||
}
|
||||
|
||||
vtxo, err := extractVtxoOutpoint(pset)
|
||||
if err != nil {
|
||||
return -1, nil, err
|
||||
}
|
||||
|
||||
fromRepo, err := s.repoManager.Vtxos().GetVtxos(ctx, []domain.VtxoKey{*vtxo})
|
||||
if err != nil {
|
||||
return -1, nil, err
|
||||
}
|
||||
|
||||
if len(fromRepo) == 0 {
|
||||
return -1, nil, fmt.Errorf("vtxo not found")
|
||||
}
|
||||
|
||||
if fromRepo[0].Swept {
|
||||
return -1, nil, nil
|
||||
}
|
||||
|
||||
if fromRepo[0].Redeemed {
|
||||
return -1, nil, nil
|
||||
}
|
||||
|
||||
// if the vtxo is not swept or redeemed, add it to the onchain outputs
|
||||
pubKeyBytes, err := hex.DecodeString(fromRepo[0].Pubkey)
|
||||
if err != nil {
|
||||
return -1, nil, err
|
||||
}
|
||||
|
||||
pubKey, err := secp256k1.ParsePubKey(pubKeyBytes)
|
||||
if err != nil {
|
||||
return -1, nil, err
|
||||
}
|
||||
|
||||
sweepLeaf, lifetime, err := s.builder.GetLeafSweepClosure(node, pubKey)
|
||||
if err != nil {
|
||||
return -1, nil, err
|
||||
}
|
||||
|
||||
sweepInput := ports.SweepInput{
|
||||
InputArgs: psetv2.InputArgs{
|
||||
Txid: vtxo.Txid,
|
||||
TxIndex: vtxo.VOut,
|
||||
},
|
||||
SweepLeaf: *sweepLeaf,
|
||||
Amount: fromRepo[0].Amount,
|
||||
}
|
||||
|
||||
return txBlocktime + lifetime, &sweepInput, nil
|
||||
}
|
||||
|
||||
func (s *sweeper) nodeToSweepInputs(parentBlocktime int64, node tree.Node) (int64, []ports.SweepInput, error) {
|
||||
pset, err := psetv2.NewPsetFromBase64(node.Tx)
|
||||
if err != nil {
|
||||
return -1, nil, err
|
||||
}
|
||||
|
||||
if len(pset.Inputs) != 1 {
|
||||
return -1, nil, fmt.Errorf("invalid node pset, expect 1 input, got %d", len(pset.Inputs))
|
||||
}
|
||||
|
||||
// if the tx is not onchain, it means that the input is an existing shared output
|
||||
input := pset.Inputs[0]
|
||||
txid := chainhash.Hash(input.PreviousTxid).String()
|
||||
index := input.PreviousTxIndex
|
||||
|
||||
sweepLeaf, lifetime, err := extractSweepLeaf(input)
|
||||
if err != nil {
|
||||
return -1, nil, err
|
||||
}
|
||||
|
||||
expirationTime := parentBlocktime + lifetime
|
||||
|
||||
amount := uint64(0)
|
||||
for _, out := range pset.Outputs {
|
||||
amount += out.Value
|
||||
}
|
||||
|
||||
sweepInputs := []ports.SweepInput{
|
||||
{
|
||||
InputArgs: psetv2.InputArgs{
|
||||
Txid: txid,
|
||||
TxIndex: index,
|
||||
},
|
||||
SweepLeaf: *sweepLeaf,
|
||||
Amount: amount,
|
||||
},
|
||||
}
|
||||
|
||||
return expirationTime, sweepInputs, nil
|
||||
}
|
||||
|
||||
func computeSubTrees(congestionTree tree.CongestionTree, inputs []ports.SweepInput) ([]tree.CongestionTree, error) {
|
||||
subTrees := make(map[string]tree.CongestionTree, 0)
|
||||
|
||||
// for each sweepable input, create a sub congestion tree
|
||||
// it allows to skip the part of the tree that has been broadcasted in the next task
|
||||
for _, input := range inputs {
|
||||
subTree, err := computeSubTree(congestionTree, input.InputArgs.Txid)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("error while finding sub tree")
|
||||
continue
|
||||
}
|
||||
|
||||
root, err := subTree.Root()
|
||||
if err != nil {
|
||||
log.WithError(err).Error("error while getting root node")
|
||||
continue
|
||||
}
|
||||
|
||||
subTrees[root.Txid] = subTree
|
||||
}
|
||||
|
||||
// filter out the sub trees, remove the ones that are included in others
|
||||
filteredSubTrees := make([]tree.CongestionTree, 0)
|
||||
for i, subTree := range subTrees {
|
||||
notIncludedInOtherTrees := true
|
||||
|
||||
for j, otherSubTree := range subTrees {
|
||||
if i == j {
|
||||
continue
|
||||
}
|
||||
contains, err := containsTree(otherSubTree, subTree)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("error while checking if a tree contains another")
|
||||
continue
|
||||
}
|
||||
|
||||
if contains {
|
||||
notIncludedInOtherTrees = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if notIncludedInOtherTrees {
|
||||
filteredSubTrees = append(filteredSubTrees, subTree)
|
||||
}
|
||||
}
|
||||
|
||||
return filteredSubTrees, nil
|
||||
}
|
||||
|
||||
func computeSubTree(congestionTree tree.CongestionTree, newRoot string) (tree.CongestionTree, error) {
|
||||
for _, level := range congestionTree {
|
||||
for _, node := range level {
|
||||
if node.Txid == newRoot || node.ParentTxid == newRoot {
|
||||
newTree := make(tree.CongestionTree, 0)
|
||||
newTree = append(newTree, []tree.Node{node})
|
||||
|
||||
children := congestionTree.Children(node.Txid)
|
||||
for len(children) > 0 {
|
||||
newTree = append(newTree, children)
|
||||
newChildren := make([]tree.Node, 0)
|
||||
for _, child := range children {
|
||||
newChildren = append(newChildren, congestionTree.Children(child.Txid)...)
|
||||
}
|
||||
children = newChildren
|
||||
}
|
||||
|
||||
return newTree, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("failed to create subtree, new root not found")
|
||||
}
|
||||
|
||||
func containsTree(tr0 tree.CongestionTree, tr1 tree.CongestionTree) (bool, error) {
|
||||
tr1Root, err := tr1.Root()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
for _, level := range tr0 {
|
||||
for _, node := range level {
|
||||
if node.Txid == tr1Root.Txid {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// given a congestion tree input, searches and returns the sweep leaf and its lifetime in seconds
|
||||
func extractSweepLeaf(input psetv2.Input) (sweepLeaf *psetv2.TapLeafScript, lifetime int64, err error) {
|
||||
for _, leaf := range input.TapLeafScript {
|
||||
isSweep, _, seconds, err := tree.DecodeSweepScript(leaf.Script)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if isSweep {
|
||||
lifetime = int64(seconds)
|
||||
sweepLeaf = &leaf
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if sweepLeaf == nil {
|
||||
return nil, 0, fmt.Errorf("sweep leaf not found")
|
||||
}
|
||||
|
||||
return sweepLeaf, lifetime, nil
|
||||
}
|
||||
|
||||
// assuming the pset is a leaf in the congestion tree, returns the vtxos outputs
|
||||
func extractVtxoOutpoint(pset *psetv2.Pset) (*domain.VtxoKey, error) {
|
||||
if len(pset.Outputs) != 2 {
|
||||
return nil, fmt.Errorf("invalid leaf pset, expect 2 outputs, got %d", len(pset.Outputs))
|
||||
}
|
||||
|
||||
utx, err := pset.UnsignedTx()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &domain.VtxoKey{
|
||||
Txid: utx.TxHash().String(),
|
||||
VOut: 0,
|
||||
}, nil
|
||||
}
|
||||
254
server/internal/core/application/utils.go
Normal file
254
server/internal/core/application/utils.go
Normal file
@@ -0,0 +1,254 @@
|
||||
package application
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ark-network/ark/common"
|
||||
"github.com/ark-network/ark/internal/core/domain"
|
||||
"github.com/btcsuite/btcd/btcec/v2/schnorr"
|
||||
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
||||
"github.com/vulpemventures/go-elements/psetv2"
|
||||
)
|
||||
|
||||
type timedPayment struct {
|
||||
domain.Payment
|
||||
timestamp time.Time
|
||||
pingTimestamp time.Time
|
||||
}
|
||||
|
||||
type paymentsMap struct {
|
||||
lock *sync.RWMutex
|
||||
payments map[string]*timedPayment
|
||||
}
|
||||
|
||||
func newPaymentsMap(payments []domain.Payment) *paymentsMap {
|
||||
paymentsById := make(map[string]*timedPayment)
|
||||
for _, p := range payments {
|
||||
paymentsById[p.Id] = &timedPayment{p, time.Now(), time.Time{}}
|
||||
}
|
||||
lock := &sync.RWMutex{}
|
||||
return &paymentsMap{lock, paymentsById}
|
||||
}
|
||||
|
||||
func (m *paymentsMap) len() int64 {
|
||||
m.lock.RLock()
|
||||
defer m.lock.RUnlock()
|
||||
|
||||
count := int64(0)
|
||||
for _, p := range m.payments {
|
||||
if len(p.Receivers) > 0 {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
func (m *paymentsMap) push(payment domain.Payment) error {
|
||||
m.lock.Lock()
|
||||
defer m.lock.Unlock()
|
||||
|
||||
if _, ok := m.payments[payment.Id]; ok {
|
||||
return fmt.Errorf("duplicated inputs")
|
||||
}
|
||||
|
||||
m.payments[payment.Id] = &timedPayment{payment, time.Now(), time.Time{}}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *paymentsMap) pop(num int64) []domain.Payment {
|
||||
m.lock.Lock()
|
||||
defer m.lock.Unlock()
|
||||
|
||||
paymentsByTime := make([]timedPayment, 0, len(m.payments))
|
||||
for _, p := range m.payments {
|
||||
// Skip payments without registered receivers.
|
||||
if len(p.Receivers) <= 0 {
|
||||
continue
|
||||
}
|
||||
// Skip payments for which users didn't notify to be online in the last minute.
|
||||
if p.pingTimestamp.IsZero() || time.Since(p.pingTimestamp).Minutes() > 1 {
|
||||
continue
|
||||
}
|
||||
paymentsByTime = append(paymentsByTime, *p)
|
||||
}
|
||||
sort.SliceStable(paymentsByTime, func(i, j int) bool {
|
||||
return paymentsByTime[i].timestamp.Before(paymentsByTime[j].timestamp)
|
||||
})
|
||||
|
||||
if num < 0 || num > int64(len(paymentsByTime)) {
|
||||
num = int64(len(paymentsByTime))
|
||||
}
|
||||
|
||||
payments := make([]domain.Payment, 0, num)
|
||||
for _, p := range paymentsByTime[:num] {
|
||||
payments = append(payments, p.Payment)
|
||||
delete(m.payments, p.Id)
|
||||
}
|
||||
return payments
|
||||
}
|
||||
|
||||
func (m *paymentsMap) update(payment domain.Payment) error {
|
||||
m.lock.Lock()
|
||||
defer m.lock.Unlock()
|
||||
|
||||
p, ok := m.payments[payment.Id]
|
||||
if !ok {
|
||||
return fmt.Errorf("payment %s not found", payment.Id)
|
||||
}
|
||||
|
||||
p.Payment = payment
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *paymentsMap) updatePingTimestamp(id string) error {
|
||||
m.lock.Lock()
|
||||
defer m.lock.Unlock()
|
||||
|
||||
payment, ok := m.payments[id]
|
||||
if !ok {
|
||||
return fmt.Errorf("payment %s not found", id)
|
||||
}
|
||||
|
||||
payment.pingTimestamp = time.Now()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *paymentsMap) view(id string) (*domain.Payment, bool) {
|
||||
m.lock.RLock()
|
||||
defer m.lock.RUnlock()
|
||||
|
||||
payment, ok := m.payments[id]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return &domain.Payment{
|
||||
Id: payment.Id,
|
||||
Inputs: payment.Inputs,
|
||||
Receivers: payment.Receivers,
|
||||
}, true
|
||||
}
|
||||
|
||||
type signedTx struct {
|
||||
tx string
|
||||
signed bool
|
||||
}
|
||||
|
||||
type forfeitTxsMap struct {
|
||||
lock *sync.RWMutex
|
||||
forfeitTxs map[string]*signedTx
|
||||
genesisBlockHash *chainhash.Hash
|
||||
}
|
||||
|
||||
func newForfeitTxsMap(genesisBlockHash *chainhash.Hash) *forfeitTxsMap {
|
||||
return &forfeitTxsMap{&sync.RWMutex{}, make(map[string]*signedTx), genesisBlockHash}
|
||||
}
|
||||
|
||||
func (m *forfeitTxsMap) push(txs []string) {
|
||||
m.lock.Lock()
|
||||
defer m.lock.Unlock()
|
||||
|
||||
faucetTxID, _ := hex.DecodeString(faucetVtxo.Txid)
|
||||
|
||||
for _, tx := range txs {
|
||||
ptx, _ := psetv2.NewPsetFromBase64(tx)
|
||||
utx, _ := ptx.UnsignedTx()
|
||||
|
||||
signed := false
|
||||
|
||||
// find the faucet vtxos, and mark them as signed
|
||||
for _, input := range ptx.Inputs {
|
||||
if bytes.Equal(input.PreviousTxid, faucetTxID) {
|
||||
signed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
m.forfeitTxs[utx.TxHash().String()] = &signedTx{tx, signed}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *forfeitTxsMap) sign(txs []string) error {
|
||||
m.lock.Lock()
|
||||
defer m.lock.Unlock()
|
||||
|
||||
for _, tx := range txs {
|
||||
ptx, _ := psetv2.NewPsetFromBase64(tx)
|
||||
utx, _ := ptx.UnsignedTx()
|
||||
txid := utx.TxHash().String()
|
||||
|
||||
if _, ok := m.forfeitTxs[txid]; ok {
|
||||
for index, input := range ptx.Inputs {
|
||||
if len(input.TapScriptSig) > 0 {
|
||||
for _, tapScriptSig := range input.TapScriptSig {
|
||||
leafHash, err := chainhash.NewHash(tapScriptSig.LeafHash)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
preimage, err := common.TaprootPreimage(
|
||||
m.genesisBlockHash,
|
||||
ptx,
|
||||
index,
|
||||
leafHash,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sig, err := schnorr.ParseSignature(tapScriptSig.Signature)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pubkey, err := schnorr.ParsePubKey(tapScriptSig.PubKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if sig.Verify(preimage, pubkey) {
|
||||
m.forfeitTxs[txid].signed = true
|
||||
} else {
|
||||
return fmt.Errorf("invalid signature")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *forfeitTxsMap) pop() (signed, unsigned []string) {
|
||||
m.lock.Lock()
|
||||
defer m.lock.Unlock()
|
||||
|
||||
for _, t := range m.forfeitTxs {
|
||||
if t.signed {
|
||||
signed = append(signed, t.tx)
|
||||
} else {
|
||||
unsigned = append(unsigned, t.tx)
|
||||
}
|
||||
}
|
||||
|
||||
m.forfeitTxs = make(map[string]*signedTx)
|
||||
return signed, unsigned
|
||||
}
|
||||
|
||||
func (m *forfeitTxsMap) view() []string {
|
||||
m.lock.RLock()
|
||||
defer m.lock.RUnlock()
|
||||
|
||||
txs := make([]string, 0, len(m.forfeitTxs))
|
||||
for _, tx := range m.forfeitTxs {
|
||||
txs = append(txs, tx.tx)
|
||||
}
|
||||
return txs
|
||||
}
|
||||
44
server/internal/core/domain/events.go
Normal file
44
server/internal/core/domain/events.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package domain
|
||||
|
||||
import "github.com/ark-network/ark/common/tree"
|
||||
|
||||
type RoundEvent interface {
|
||||
isEvent()
|
||||
}
|
||||
|
||||
func (r RoundStarted) isEvent() {}
|
||||
func (r RoundFinalizationStarted) isEvent() {}
|
||||
func (r RoundFinalized) isEvent() {}
|
||||
func (r RoundFailed) isEvent() {}
|
||||
func (r PaymentsRegistered) isEvent() {}
|
||||
|
||||
type RoundStarted struct {
|
||||
Id string
|
||||
Timestamp int64
|
||||
}
|
||||
|
||||
type RoundFinalizationStarted struct {
|
||||
Id string
|
||||
CongestionTree tree.CongestionTree
|
||||
Connectors []string
|
||||
UnsignedForfeitTxs []string
|
||||
PoolTx string
|
||||
}
|
||||
|
||||
type RoundFinalized struct {
|
||||
Id string
|
||||
Txid string
|
||||
ForfeitTxs []string
|
||||
Timestamp int64
|
||||
}
|
||||
|
||||
type RoundFailed struct {
|
||||
Id string
|
||||
Err string
|
||||
Timestamp int64
|
||||
}
|
||||
|
||||
type PaymentsRegistered struct {
|
||||
Id string
|
||||
Payments []Payment
|
||||
}
|
||||
129
server/internal/core/domain/payment.go
Normal file
129
server/internal/core/domain/payment.go
Normal file
@@ -0,0 +1,129 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"hash"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const dustAmount = 450
|
||||
|
||||
type Payment struct {
|
||||
Id string
|
||||
Inputs []Vtxo
|
||||
Receivers []Receiver
|
||||
}
|
||||
|
||||
func NewPayment(inputs []Vtxo) (*Payment, error) {
|
||||
p := &Payment{
|
||||
Id: uuid.New().String(),
|
||||
Inputs: inputs,
|
||||
}
|
||||
if err := p.validate(true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func (p *Payment) AddReceivers(receivers []Receiver) (err error) {
|
||||
if p.Receivers == nil {
|
||||
p.Receivers = make([]Receiver, 0)
|
||||
}
|
||||
p.Receivers = append(p.Receivers, receivers...)
|
||||
defer func() {
|
||||
if err != nil {
|
||||
p.Receivers = p.Receivers[:len(p.Receivers)-len(receivers)]
|
||||
}
|
||||
}()
|
||||
err = p.validate(false)
|
||||
return
|
||||
}
|
||||
|
||||
func (p Payment) TotalInputAmount() uint64 {
|
||||
tot := uint64(0)
|
||||
for _, in := range p.Inputs {
|
||||
tot += in.Amount
|
||||
}
|
||||
return tot
|
||||
}
|
||||
|
||||
func (p Payment) TotalOutputAmount() uint64 {
|
||||
tot := uint64(0)
|
||||
for _, r := range p.Receivers {
|
||||
tot += r.Amount
|
||||
}
|
||||
return tot
|
||||
}
|
||||
|
||||
func (p Payment) validate(ignoreOuts bool) error {
|
||||
if len(p.Id) <= 0 {
|
||||
return fmt.Errorf("missing id")
|
||||
}
|
||||
if len(p.Inputs) <= 0 {
|
||||
return fmt.Errorf("missing inputs")
|
||||
}
|
||||
if ignoreOuts {
|
||||
return nil
|
||||
}
|
||||
if len(p.Receivers) <= 0 {
|
||||
return fmt.Errorf("missing outputs")
|
||||
}
|
||||
// Check that input and output and output amounts match.
|
||||
inAmount := p.TotalInputAmount()
|
||||
outAmount := uint64(0)
|
||||
for _, r := range p.Receivers {
|
||||
if len(r.OnchainAddress) <= 0 && len(r.Pubkey) <= 0 {
|
||||
return fmt.Errorf("missing receiver destination")
|
||||
}
|
||||
if r.Amount < dustAmount {
|
||||
return fmt.Errorf("receiver amount must be greater than dust")
|
||||
}
|
||||
outAmount += r.Amount
|
||||
}
|
||||
if inAmount != outAmount {
|
||||
return fmt.Errorf("input and output amounts mismatch")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type VtxoKey struct {
|
||||
Txid string
|
||||
VOut uint32
|
||||
}
|
||||
|
||||
func (k VtxoKey) Hash() string {
|
||||
calcHash := func(buf []byte, hasher hash.Hash) []byte {
|
||||
_, _ = hasher.Write(buf)
|
||||
return hasher.Sum(nil)
|
||||
}
|
||||
|
||||
hash160 := func(buf []byte) []byte {
|
||||
return calcHash(calcHash(buf, sha256.New()), sha256.New())
|
||||
}
|
||||
|
||||
buf, _ := hex.DecodeString(k.Txid)
|
||||
buf = append(buf, byte(k.VOut))
|
||||
return hex.EncodeToString(hash160(buf))
|
||||
}
|
||||
|
||||
type Receiver struct {
|
||||
Pubkey string
|
||||
Amount uint64
|
||||
OnchainAddress string
|
||||
}
|
||||
|
||||
func (r Receiver) IsOnchain() bool {
|
||||
return len(r.OnchainAddress) > 0
|
||||
}
|
||||
|
||||
type Vtxo struct {
|
||||
VtxoKey
|
||||
Receiver
|
||||
PoolTx string
|
||||
Spent bool
|
||||
Redeemed bool
|
||||
Swept bool
|
||||
}
|
||||
111
server/internal/core/domain/payment_test.go
Normal file
111
server/internal/core/domain/payment_test.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package domain_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ark-network/ark/internal/core/domain"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var inputs = []domain.Vtxo{
|
||||
{
|
||||
VtxoKey: domain.VtxoKey{
|
||||
Txid: "0000000000000000000000000000000000000000000000000000000000000000",
|
||||
VOut: 0,
|
||||
},
|
||||
Receiver: domain.Receiver{
|
||||
Pubkey: "030000000000000000000000000000000000000000000000000000000000000001",
|
||||
Amount: 1000,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
func TestPayment(t *testing.T) {
|
||||
t.Run("new_payment", func(t *testing.T) {
|
||||
t.Run("vaild", func(t *testing.T) {
|
||||
payment, err := domain.NewPayment(inputs)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, payment)
|
||||
require.NotEmpty(t, payment.Id)
|
||||
require.Exactly(t, inputs, payment.Inputs)
|
||||
require.Empty(t, payment.Receivers)
|
||||
})
|
||||
|
||||
t.Run("invaild", func(t *testing.T) {
|
||||
fixtures := []struct {
|
||||
inputs []domain.Vtxo
|
||||
expectedErr string
|
||||
}{
|
||||
{
|
||||
inputs: nil,
|
||||
expectedErr: "missing inputs",
|
||||
},
|
||||
}
|
||||
|
||||
for _, f := range fixtures {
|
||||
payment, err := domain.NewPayment(f.inputs)
|
||||
require.EqualError(t, err, f.expectedErr)
|
||||
require.Nil(t, payment)
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("add_receivers", func(t *testing.T) {
|
||||
t.Run("valid", func(t *testing.T) {
|
||||
payment, err := domain.NewPayment(inputs)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, payment)
|
||||
|
||||
err = payment.AddReceivers([]domain.Receiver{
|
||||
{
|
||||
Pubkey: "030000000000000000000000000000000000000000000000000000000000000001",
|
||||
Amount: 450,
|
||||
},
|
||||
{
|
||||
Pubkey: "020000000000000000000000000000000000000000000000000000000000000002",
|
||||
Amount: 550,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("invalid", func(t *testing.T) {
|
||||
fixtures := []struct {
|
||||
receivers []domain.Receiver
|
||||
expectedErr string
|
||||
}{
|
||||
{
|
||||
receivers: nil,
|
||||
expectedErr: "missing outputs",
|
||||
},
|
||||
{
|
||||
receivers: []domain.Receiver{
|
||||
{
|
||||
Pubkey: "030000000000000000000000000000000000000000000000000000000000000001",
|
||||
Amount: 400,
|
||||
},
|
||||
},
|
||||
expectedErr: "receiver amount must be greater than dust",
|
||||
},
|
||||
{
|
||||
receivers: []domain.Receiver{
|
||||
{
|
||||
Pubkey: "030000000000000000000000000000000000000000000000000000000000000001",
|
||||
Amount: 600,
|
||||
},
|
||||
},
|
||||
expectedErr: "input and output amounts mismatch",
|
||||
},
|
||||
}
|
||||
|
||||
payment, err := domain.NewPayment(inputs)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, payment)
|
||||
|
||||
for _, f := range fixtures {
|
||||
err := payment.AddReceivers(f.receivers)
|
||||
require.EqualError(t, err, f.expectedErr)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
253
server/internal/core/domain/round.go
Normal file
253
server/internal/core/domain/round.go
Normal file
@@ -0,0 +1,253 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/ark-network/ark/common/tree"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const (
|
||||
UndefinedStage RoundStage = iota
|
||||
RegistrationStage
|
||||
FinalizationStage
|
||||
)
|
||||
|
||||
type RoundStage int
|
||||
|
||||
func (s RoundStage) String() string {
|
||||
switch s {
|
||||
case RegistrationStage:
|
||||
return "REGISTRATION_STAGE"
|
||||
case FinalizationStage:
|
||||
return "FINALIZATION_STAGE"
|
||||
default:
|
||||
return "UNDEFINED_STAGE"
|
||||
}
|
||||
}
|
||||
|
||||
type Stage struct {
|
||||
Code RoundStage
|
||||
Ended bool
|
||||
Failed bool
|
||||
}
|
||||
|
||||
type Round struct {
|
||||
Id string
|
||||
StartingTimestamp int64
|
||||
EndingTimestamp int64
|
||||
Stage Stage
|
||||
Payments map[string]Payment
|
||||
Txid string
|
||||
UnsignedTx string
|
||||
ForfeitTxs []string
|
||||
CongestionTree tree.CongestionTree
|
||||
Connectors []string
|
||||
DustAmount uint64
|
||||
Version uint
|
||||
Swept bool // true if all the vtxos are vtxo.Swept
|
||||
changes []RoundEvent
|
||||
}
|
||||
|
||||
func NewRound(dustAmount uint64) *Round {
|
||||
return &Round{
|
||||
Id: uuid.New().String(),
|
||||
DustAmount: dustAmount,
|
||||
Payments: make(map[string]Payment),
|
||||
changes: make([]RoundEvent, 0),
|
||||
}
|
||||
}
|
||||
|
||||
func NewRoundFromEvents(events []RoundEvent) *Round {
|
||||
r := &Round{}
|
||||
|
||||
for _, event := range events {
|
||||
r.On(event, true)
|
||||
}
|
||||
|
||||
r.changes = append([]RoundEvent{}, events...)
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *Round) Events() []RoundEvent {
|
||||
return r.changes
|
||||
}
|
||||
|
||||
func (r *Round) On(event RoundEvent, replayed bool) {
|
||||
switch e := event.(type) {
|
||||
case RoundStarted:
|
||||
r.Stage.Code = RegistrationStage
|
||||
r.Id = e.Id
|
||||
r.StartingTimestamp = e.Timestamp
|
||||
case RoundFinalizationStarted:
|
||||
r.Stage.Code = FinalizationStage
|
||||
r.CongestionTree = e.CongestionTree
|
||||
r.Connectors = append([]string{}, e.Connectors...)
|
||||
r.UnsignedTx = e.PoolTx
|
||||
case RoundFinalized:
|
||||
r.Stage.Ended = true
|
||||
r.Txid = e.Txid
|
||||
r.ForfeitTxs = append([]string{}, e.ForfeitTxs...)
|
||||
r.EndingTimestamp = e.Timestamp
|
||||
case RoundFailed:
|
||||
r.Stage.Failed = true
|
||||
r.EndingTimestamp = e.Timestamp
|
||||
case PaymentsRegistered:
|
||||
if r.Payments == nil {
|
||||
r.Payments = make(map[string]Payment)
|
||||
}
|
||||
for _, p := range e.Payments {
|
||||
r.Payments[p.Id] = p
|
||||
}
|
||||
}
|
||||
|
||||
if replayed {
|
||||
r.Version++
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Round) StartRegistration() ([]RoundEvent, error) {
|
||||
empty := Stage{}
|
||||
if r.Stage != empty {
|
||||
return nil, fmt.Errorf("not in a valid stage to start payment registration")
|
||||
}
|
||||
|
||||
event := RoundStarted{
|
||||
Id: r.Id,
|
||||
Timestamp: time.Now().Unix(),
|
||||
}
|
||||
r.raise(event)
|
||||
|
||||
return []RoundEvent{event}, nil
|
||||
}
|
||||
|
||||
func (r *Round) RegisterPayments(payments []Payment) ([]RoundEvent, error) {
|
||||
if r.Stage.Code != RegistrationStage || r.IsFailed() {
|
||||
return nil, fmt.Errorf("not in a valid stage to register payments")
|
||||
}
|
||||
if len(payments) <= 0 {
|
||||
return nil, fmt.Errorf("missing payments to register")
|
||||
}
|
||||
for _, p := range payments {
|
||||
if err := p.validate(false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
event := PaymentsRegistered{
|
||||
Id: r.Id,
|
||||
Payments: payments,
|
||||
}
|
||||
r.raise(event)
|
||||
|
||||
return []RoundEvent{event}, nil
|
||||
}
|
||||
|
||||
func (r *Round) StartFinalization(connectors []string, congestionTree tree.CongestionTree, poolTx string) ([]RoundEvent, error) {
|
||||
if len(connectors) <= 0 {
|
||||
return nil, fmt.Errorf("missing list of connectors")
|
||||
}
|
||||
if len(congestionTree) <= 0 {
|
||||
return nil, fmt.Errorf("missing congestion tree")
|
||||
}
|
||||
if len(poolTx) <= 0 {
|
||||
return nil, fmt.Errorf("missing unsigned pool tx")
|
||||
}
|
||||
if r.Stage.Code != RegistrationStage || r.IsFailed() {
|
||||
return nil, fmt.Errorf("not in a valid stage to start payment finalization")
|
||||
}
|
||||
if len(r.Payments) <= 0 {
|
||||
return nil, fmt.Errorf("no payments registered")
|
||||
}
|
||||
|
||||
event := RoundFinalizationStarted{
|
||||
Id: r.Id,
|
||||
CongestionTree: congestionTree,
|
||||
Connectors: connectors,
|
||||
PoolTx: poolTx,
|
||||
}
|
||||
r.raise(event)
|
||||
|
||||
return []RoundEvent{event}, nil
|
||||
}
|
||||
|
||||
func (r *Round) EndFinalization(forfeitTxs []string, txid string) ([]RoundEvent, error) {
|
||||
if len(forfeitTxs) <= 0 {
|
||||
return nil, fmt.Errorf("missing list of signed forfeit txs")
|
||||
}
|
||||
if len(txid) <= 0 {
|
||||
return nil, fmt.Errorf("missing pool txid")
|
||||
}
|
||||
if r.Stage.Code != FinalizationStage || r.IsFailed() {
|
||||
return nil, fmt.Errorf("not in a valid stage to end payment finalization")
|
||||
}
|
||||
if r.Stage.Ended {
|
||||
return nil, fmt.Errorf("round already finalized")
|
||||
}
|
||||
event := RoundFinalized{
|
||||
Id: r.Id,
|
||||
Txid: txid,
|
||||
ForfeitTxs: forfeitTxs,
|
||||
Timestamp: time.Now().Unix(),
|
||||
}
|
||||
r.raise(event)
|
||||
|
||||
return []RoundEvent{event}, nil
|
||||
}
|
||||
|
||||
func (r *Round) Fail(err error) []RoundEvent {
|
||||
if r.Stage.Failed {
|
||||
return nil
|
||||
}
|
||||
event := RoundFailed{
|
||||
Id: r.Id,
|
||||
Err: err.Error(),
|
||||
Timestamp: time.Now().Unix(),
|
||||
}
|
||||
r.raise(event)
|
||||
|
||||
return []RoundEvent{event}
|
||||
}
|
||||
|
||||
func (r *Round) IsStarted() bool {
|
||||
empty := Stage{}
|
||||
return !r.IsFailed() && !r.IsEnded() && r.Stage != empty
|
||||
}
|
||||
|
||||
func (r *Round) IsEnded() bool {
|
||||
return !r.IsFailed() && r.Stage.Code == FinalizationStage && r.Stage.Ended
|
||||
}
|
||||
|
||||
func (r *Round) IsFailed() bool {
|
||||
return r.Stage.Failed
|
||||
}
|
||||
|
||||
func (r *Round) TotalInputAmount() uint64 {
|
||||
totInputs := 0
|
||||
for _, p := range r.Payments {
|
||||
totInputs += len(p.Inputs)
|
||||
}
|
||||
return uint64(totInputs * int(r.DustAmount))
|
||||
}
|
||||
|
||||
func (r *Round) TotalOutputAmount() uint64 {
|
||||
tot := uint64(0)
|
||||
for _, p := range r.Payments {
|
||||
tot += p.TotalOutputAmount()
|
||||
}
|
||||
return tot
|
||||
}
|
||||
|
||||
func (r *Round) Sweep() {
|
||||
r.Swept = true
|
||||
}
|
||||
|
||||
func (r *Round) raise(event RoundEvent) {
|
||||
if r.changes == nil {
|
||||
r.changes = make([]RoundEvent, 0)
|
||||
}
|
||||
r.changes = append(r.changes, event)
|
||||
r.On(event, false)
|
||||
}
|
||||
28
server/internal/core/domain/round_repo.go
Normal file
28
server/internal/core/domain/round_repo.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
type RoundEventRepository interface {
|
||||
Save(ctx context.Context, id string, events ...RoundEvent) error
|
||||
Load(ctx context.Context, id string) (*Round, error)
|
||||
}
|
||||
|
||||
type RoundRepository interface {
|
||||
AddOrUpdateRound(ctx context.Context, round Round) error
|
||||
GetCurrentRound(ctx context.Context) (*Round, error)
|
||||
GetRoundWithId(ctx context.Context, id string) (*Round, error)
|
||||
GetRoundWithTxid(ctx context.Context, txid string) (*Round, error)
|
||||
GetSweepableRounds(ctx context.Context) ([]Round, error)
|
||||
}
|
||||
|
||||
type VtxoRepository interface {
|
||||
AddVtxos(ctx context.Context, vtxos []Vtxo) error
|
||||
SpendVtxos(ctx context.Context, vtxos []VtxoKey) error
|
||||
RedeemVtxos(ctx context.Context, vtxos []VtxoKey) ([]Vtxo, error)
|
||||
GetVtxos(ctx context.Context, vtxos []VtxoKey) ([]Vtxo, error)
|
||||
GetVtxosForRound(ctx context.Context, txid string) ([]Vtxo, error)
|
||||
SweepVtxos(ctx context.Context, vtxos []VtxoKey) error
|
||||
GetSpendableVtxos(ctx context.Context, pubkey string) ([]Vtxo, error)
|
||||
}
|
||||
576
server/internal/core/domain/round_test.go
Normal file
576
server/internal/core/domain/round_test.go
Normal file
@@ -0,0 +1,576 @@
|
||||
package domain_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/ark-network/ark/common/tree"
|
||||
"github.com/ark-network/ark/internal/core/domain"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var (
|
||||
dustAmount = uint64(450)
|
||||
payments = []domain.Payment{
|
||||
{
|
||||
Id: "0",
|
||||
Inputs: []domain.Vtxo{{
|
||||
VtxoKey: domain.VtxoKey{
|
||||
Txid: txid,
|
||||
VOut: 0,
|
||||
},
|
||||
Receiver: domain.Receiver{
|
||||
Pubkey: pubkey,
|
||||
Amount: 2000,
|
||||
},
|
||||
}},
|
||||
Receivers: []domain.Receiver{
|
||||
{
|
||||
Pubkey: pubkey,
|
||||
Amount: 700,
|
||||
},
|
||||
{
|
||||
Pubkey: pubkey,
|
||||
Amount: 700,
|
||||
},
|
||||
{
|
||||
Pubkey: pubkey,
|
||||
Amount: 600,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Id: "1",
|
||||
Inputs: []domain.Vtxo{
|
||||
{
|
||||
VtxoKey: domain.VtxoKey{
|
||||
Txid: txid,
|
||||
VOut: 0,
|
||||
},
|
||||
Receiver: domain.Receiver{
|
||||
Pubkey: pubkey,
|
||||
Amount: 1000,
|
||||
},
|
||||
},
|
||||
{
|
||||
VtxoKey: domain.VtxoKey{
|
||||
Txid: txid,
|
||||
VOut: 0,
|
||||
},
|
||||
Receiver: domain.Receiver{
|
||||
Pubkey: pubkey,
|
||||
Amount: 1000,
|
||||
},
|
||||
},
|
||||
},
|
||||
Receivers: []domain.Receiver{{
|
||||
Pubkey: pubkey,
|
||||
Amount: 2000,
|
||||
}},
|
||||
},
|
||||
}
|
||||
emptyPtx = "cHNldP8BAgQCAAAAAQQBAAEFAQABBgEDAfsEAgAAAAA="
|
||||
emptyTx = "0200000000000000000000"
|
||||
txid = "0000000000000000000000000000000000000000000000000000000000000000"
|
||||
pubkey = "030000000000000000000000000000000000000000000000000000000000000001"
|
||||
congestionTree = tree.CongestionTree{
|
||||
{
|
||||
{
|
||||
Txid: txid,
|
||||
Tx: emptyPtx,
|
||||
ParentTxid: txid,
|
||||
},
|
||||
},
|
||||
{
|
||||
{
|
||||
Txid: txid,
|
||||
Tx: emptyPtx,
|
||||
ParentTxid: txid,
|
||||
},
|
||||
{
|
||||
Txid: txid,
|
||||
Tx: emptyPtx,
|
||||
ParentTxid: txid,
|
||||
},
|
||||
},
|
||||
{
|
||||
{
|
||||
Txid: txid,
|
||||
Tx: emptyPtx,
|
||||
ParentTxid: txid,
|
||||
},
|
||||
{
|
||||
Txid: txid,
|
||||
Tx: emptyPtx,
|
||||
ParentTxid: txid,
|
||||
},
|
||||
{
|
||||
Txid: txid,
|
||||
Tx: emptyPtx,
|
||||
ParentTxid: txid,
|
||||
},
|
||||
{
|
||||
Txid: txid,
|
||||
Tx: emptyPtx,
|
||||
ParentTxid: txid,
|
||||
},
|
||||
},
|
||||
}
|
||||
connectors = []string{emptyPtx, emptyPtx, emptyPtx}
|
||||
forfeitTxs = []string{emptyPtx, emptyPtx, emptyPtx, emptyPtx, emptyPtx, emptyPtx, emptyPtx, emptyPtx, emptyPtx}
|
||||
poolTx = emptyTx
|
||||
)
|
||||
|
||||
func TestRound(t *testing.T) {
|
||||
testStartRegistration(t)
|
||||
|
||||
testRegisterPayments(t)
|
||||
|
||||
testStartFinalization(t)
|
||||
|
||||
testEndFinalization(t)
|
||||
|
||||
testFail(t)
|
||||
}
|
||||
|
||||
func testStartRegistration(t *testing.T) {
|
||||
t.Run("start_registration", func(t *testing.T) {
|
||||
t.Run("valid", func(t *testing.T) {
|
||||
round := domain.NewRound(dustAmount)
|
||||
require.NotNil(t, round)
|
||||
require.NotEmpty(t, round.Id)
|
||||
require.Empty(t, round.Events())
|
||||
require.False(t, round.IsStarted())
|
||||
require.False(t, round.IsEnded())
|
||||
require.False(t, round.IsFailed())
|
||||
|
||||
events, err := round.StartRegistration()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, events, 1)
|
||||
require.True(t, round.IsStarted())
|
||||
require.False(t, round.IsEnded())
|
||||
require.False(t, round.IsFailed())
|
||||
|
||||
event, ok := events[0].(domain.RoundStarted)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, round.Id, event.Id)
|
||||
require.Equal(t, round.StartingTimestamp, event.Timestamp)
|
||||
})
|
||||
|
||||
t.Run("invalid", func(t *testing.T) {
|
||||
fixtures := []struct {
|
||||
round *domain.Round
|
||||
expectedErr string
|
||||
}{
|
||||
{
|
||||
round: &domain.Round{
|
||||
Id: "id",
|
||||
Stage: domain.Stage{
|
||||
Code: domain.UndefinedStage,
|
||||
Failed: true,
|
||||
},
|
||||
},
|
||||
expectedErr: "not in a valid stage to start payment registration",
|
||||
},
|
||||
{
|
||||
round: &domain.Round{
|
||||
Id: "id",
|
||||
Stage: domain.Stage{
|
||||
Code: domain.RegistrationStage,
|
||||
},
|
||||
},
|
||||
expectedErr: "not in a valid stage to start payment registration",
|
||||
},
|
||||
{
|
||||
round: &domain.Round{
|
||||
Id: "id",
|
||||
Stage: domain.Stage{
|
||||
Code: domain.FinalizationStage,
|
||||
},
|
||||
},
|
||||
expectedErr: "not in a valid stage to start payment registration",
|
||||
},
|
||||
}
|
||||
|
||||
for _, f := range fixtures {
|
||||
events, err := f.round.StartRegistration()
|
||||
require.EqualError(t, err, f.expectedErr)
|
||||
require.Empty(t, events)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func testRegisterPayments(t *testing.T) {
|
||||
t.Run("register_payments", func(t *testing.T) {
|
||||
t.Run("valid", func(t *testing.T) {
|
||||
round := domain.NewRound(dustAmount)
|
||||
events, err := round.StartRegistration()
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, events)
|
||||
|
||||
events, err = round.RegisterPayments(payments)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, events, 1)
|
||||
require.Condition(t, func() bool {
|
||||
for _, payment := range payments {
|
||||
_, ok := round.Payments[payment.Id]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
event, ok := events[0].(domain.PaymentsRegistered)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, round.Id, event.Id)
|
||||
require.Equal(t, payments, event.Payments)
|
||||
})
|
||||
|
||||
t.Run("invalid", func(t *testing.T) {
|
||||
fixtures := []struct {
|
||||
round *domain.Round
|
||||
payments []domain.Payment
|
||||
expectedErr string
|
||||
}{
|
||||
{
|
||||
round: &domain.Round{
|
||||
Id: "id",
|
||||
Stage: domain.Stage{},
|
||||
},
|
||||
payments: payments,
|
||||
expectedErr: "not in a valid stage to register payments",
|
||||
},
|
||||
{
|
||||
round: &domain.Round{
|
||||
Id: "id",
|
||||
Stage: domain.Stage{
|
||||
Code: domain.RegistrationStage,
|
||||
Failed: true,
|
||||
},
|
||||
},
|
||||
payments: payments,
|
||||
expectedErr: "not in a valid stage to register payments",
|
||||
},
|
||||
{
|
||||
round: &domain.Round{
|
||||
Id: "id",
|
||||
Stage: domain.Stage{
|
||||
Code: domain.FinalizationStage,
|
||||
},
|
||||
},
|
||||
payments: payments,
|
||||
expectedErr: "not in a valid stage to register payments",
|
||||
},
|
||||
{
|
||||
round: &domain.Round{
|
||||
Id: "id",
|
||||
Stage: domain.Stage{
|
||||
Code: domain.RegistrationStage,
|
||||
},
|
||||
},
|
||||
payments: nil,
|
||||
expectedErr: "missing payments to register",
|
||||
},
|
||||
}
|
||||
|
||||
for _, f := range fixtures {
|
||||
events, err := f.round.RegisterPayments(f.payments)
|
||||
require.EqualError(t, err, f.expectedErr)
|
||||
require.Empty(t, events)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func testStartFinalization(t *testing.T) {
|
||||
t.Run("start_finalization", func(t *testing.T) {
|
||||
t.Run("valid", func(t *testing.T) {
|
||||
round := domain.NewRound(dustAmount)
|
||||
events, err := round.StartRegistration()
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, events)
|
||||
|
||||
events, err = round.RegisterPayments(payments)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, events)
|
||||
|
||||
events, err = round.StartFinalization(connectors, congestionTree, poolTx)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, events, 1)
|
||||
require.True(t, round.IsStarted())
|
||||
require.False(t, round.IsEnded())
|
||||
require.False(t, round.IsFailed())
|
||||
|
||||
event, ok := events[0].(domain.RoundFinalizationStarted)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, round.Id, event.Id)
|
||||
require.Exactly(t, connectors, event.Connectors)
|
||||
require.Exactly(t, congestionTree, event.CongestionTree)
|
||||
require.Exactly(t, poolTx, event.PoolTx)
|
||||
})
|
||||
|
||||
t.Run("invalid", func(t *testing.T) {
|
||||
paymentsById := map[string]domain.Payment{}
|
||||
for _, p := range payments {
|
||||
paymentsById[p.Id] = p
|
||||
}
|
||||
fixtures := []struct {
|
||||
round *domain.Round
|
||||
connectors []string
|
||||
tree tree.CongestionTree
|
||||
poolTx string
|
||||
expectedErr string
|
||||
}{
|
||||
{
|
||||
round: &domain.Round{
|
||||
Id: "0",
|
||||
Stage: domain.Stage{
|
||||
Code: domain.RegistrationStage,
|
||||
},
|
||||
Payments: paymentsById,
|
||||
},
|
||||
connectors: nil,
|
||||
tree: congestionTree,
|
||||
poolTx: poolTx,
|
||||
expectedErr: "missing list of connectors",
|
||||
},
|
||||
{
|
||||
round: &domain.Round{
|
||||
Id: "0",
|
||||
Stage: domain.Stage{
|
||||
Code: domain.RegistrationStage,
|
||||
},
|
||||
Payments: paymentsById,
|
||||
},
|
||||
connectors: connectors,
|
||||
tree: nil,
|
||||
poolTx: poolTx,
|
||||
expectedErr: "missing congestion tree",
|
||||
},
|
||||
{
|
||||
round: &domain.Round{
|
||||
Id: "0",
|
||||
Stage: domain.Stage{
|
||||
Code: domain.RegistrationStage,
|
||||
},
|
||||
Payments: paymentsById,
|
||||
},
|
||||
connectors: connectors,
|
||||
tree: congestionTree,
|
||||
poolTx: "",
|
||||
expectedErr: "missing unsigned pool tx",
|
||||
},
|
||||
{
|
||||
round: &domain.Round{
|
||||
Id: "0",
|
||||
Stage: domain.Stage{
|
||||
Code: domain.RegistrationStage,
|
||||
},
|
||||
Payments: nil,
|
||||
},
|
||||
connectors: connectors,
|
||||
tree: congestionTree,
|
||||
poolTx: poolTx,
|
||||
expectedErr: "no payments registered",
|
||||
},
|
||||
{
|
||||
round: &domain.Round{
|
||||
Id: "0",
|
||||
Stage: domain.Stage{
|
||||
Code: domain.UndefinedStage,
|
||||
},
|
||||
Payments: paymentsById,
|
||||
},
|
||||
connectors: connectors,
|
||||
tree: congestionTree,
|
||||
poolTx: poolTx,
|
||||
expectedErr: "not in a valid stage to start payment finalization",
|
||||
},
|
||||
{
|
||||
round: &domain.Round{
|
||||
Id: "0",
|
||||
Stage: domain.Stage{
|
||||
Code: domain.RegistrationStage,
|
||||
Failed: true,
|
||||
},
|
||||
Payments: paymentsById,
|
||||
},
|
||||
connectors: connectors,
|
||||
tree: congestionTree,
|
||||
poolTx: poolTx,
|
||||
expectedErr: "not in a valid stage to start payment finalization",
|
||||
},
|
||||
{
|
||||
round: &domain.Round{
|
||||
Id: "0",
|
||||
Stage: domain.Stage{
|
||||
Code: domain.FinalizationStage,
|
||||
},
|
||||
Payments: paymentsById,
|
||||
},
|
||||
connectors: connectors,
|
||||
tree: congestionTree,
|
||||
poolTx: poolTx,
|
||||
expectedErr: "not in a valid stage to start payment finalization",
|
||||
},
|
||||
}
|
||||
|
||||
for _, f := range fixtures {
|
||||
events, err := f.round.StartFinalization(f.connectors, f.tree, f.poolTx)
|
||||
require.EqualError(t, err, f.expectedErr)
|
||||
require.Empty(t, events)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func testEndFinalization(t *testing.T) {
|
||||
t.Run("end_registration", func(t *testing.T) {
|
||||
t.Run("valid", func(t *testing.T) {
|
||||
round := domain.NewRound(dustAmount)
|
||||
events, err := round.StartRegistration()
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, events)
|
||||
|
||||
events, err = round.RegisterPayments(payments)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, events)
|
||||
|
||||
events, err = round.StartFinalization(connectors, congestionTree, poolTx)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, events)
|
||||
|
||||
events, err = round.EndFinalization(forfeitTxs, txid)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, events, 1)
|
||||
require.False(t, round.IsStarted())
|
||||
require.True(t, round.IsEnded())
|
||||
require.False(t, round.IsFailed())
|
||||
|
||||
event, ok := events[0].(domain.RoundFinalized)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, round.Id, event.Id)
|
||||
require.Exactly(t, txid, event.Txid)
|
||||
require.Exactly(t, forfeitTxs, event.ForfeitTxs)
|
||||
require.Exactly(t, round.EndingTimestamp, event.Timestamp)
|
||||
})
|
||||
|
||||
t.Run("invalid", func(t *testing.T) {
|
||||
paymentsById := map[string]domain.Payment{}
|
||||
for _, p := range payments {
|
||||
paymentsById[p.Id] = p
|
||||
}
|
||||
fixtures := []struct {
|
||||
round *domain.Round
|
||||
forfeitTxs []string
|
||||
txid string
|
||||
expectedErr string
|
||||
}{
|
||||
{
|
||||
round: &domain.Round{
|
||||
Id: "0",
|
||||
Stage: domain.Stage{
|
||||
Code: domain.FinalizationStage,
|
||||
},
|
||||
},
|
||||
forfeitTxs: nil,
|
||||
txid: txid,
|
||||
expectedErr: "missing list of signed forfeit txs",
|
||||
},
|
||||
{
|
||||
round: &domain.Round{
|
||||
Id: "0",
|
||||
Stage: domain.Stage{
|
||||
Code: domain.FinalizationStage,
|
||||
},
|
||||
},
|
||||
forfeitTxs: forfeitTxs,
|
||||
txid: "",
|
||||
expectedErr: "missing pool txid",
|
||||
},
|
||||
{
|
||||
round: &domain.Round{
|
||||
Id: "0",
|
||||
},
|
||||
forfeitTxs: forfeitTxs,
|
||||
txid: txid,
|
||||
expectedErr: "not in a valid stage to end payment finalization",
|
||||
},
|
||||
{
|
||||
round: &domain.Round{
|
||||
Id: "0",
|
||||
Stage: domain.Stage{
|
||||
Code: domain.RegistrationStage,
|
||||
},
|
||||
},
|
||||
forfeitTxs: forfeitTxs,
|
||||
txid: txid,
|
||||
expectedErr: "not in a valid stage to end payment finalization",
|
||||
},
|
||||
{
|
||||
round: &domain.Round{
|
||||
Id: "0",
|
||||
Stage: domain.Stage{
|
||||
Code: domain.FinalizationStage,
|
||||
Failed: true,
|
||||
},
|
||||
},
|
||||
forfeitTxs: []string{emptyPtx, emptyPtx, emptyPtx, emptyPtx},
|
||||
txid: txid,
|
||||
expectedErr: "not in a valid stage to end payment finalization",
|
||||
},
|
||||
{
|
||||
round: &domain.Round{
|
||||
Id: "0",
|
||||
Stage: domain.Stage{
|
||||
Code: domain.FinalizationStage,
|
||||
Ended: true,
|
||||
},
|
||||
},
|
||||
forfeitTxs: []string{emptyPtx, emptyPtx, emptyPtx, emptyPtx},
|
||||
txid: txid,
|
||||
expectedErr: "round already finalized",
|
||||
},
|
||||
}
|
||||
|
||||
for _, f := range fixtures {
|
||||
events, err := f.round.EndFinalization(f.forfeitTxs, f.txid)
|
||||
require.EqualError(t, err, f.expectedErr)
|
||||
require.Empty(t, events)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func testFail(t *testing.T) {
|
||||
t.Run("fail", func(t *testing.T) {
|
||||
t.Run("valid", func(t *testing.T) {
|
||||
round := domain.NewRound(dustAmount)
|
||||
events, err := round.StartRegistration()
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, events)
|
||||
|
||||
events, err = round.RegisterPayments(payments)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, events)
|
||||
|
||||
reason := fmt.Errorf("some valid reason")
|
||||
events = round.Fail(reason)
|
||||
require.Len(t, events, 1)
|
||||
require.False(t, round.IsStarted())
|
||||
require.False(t, round.IsEnded())
|
||||
require.True(t, round.IsFailed())
|
||||
|
||||
event, ok := events[0].(domain.RoundFailed)
|
||||
require.True(t, ok)
|
||||
require.Exactly(t, round.Id, event.Id)
|
||||
require.Exactly(t, round.EndingTimestamp, event.Timestamp)
|
||||
require.EqualError(t, reason, event.Err)
|
||||
|
||||
events = round.Fail(reason)
|
||||
require.Empty(t, events)
|
||||
})
|
||||
})
|
||||
}
|
||||
11
server/internal/core/ports/repo_manager.go
Normal file
11
server/internal/core/ports/repo_manager.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package ports
|
||||
|
||||
import "github.com/ark-network/ark/internal/core/domain"
|
||||
|
||||
type RepoManager interface {
|
||||
Events() domain.RoundEventRepository
|
||||
Rounds() domain.RoundRepository
|
||||
Vtxos() domain.VtxoRepository
|
||||
RegisterEventsHandler(func(*domain.Round))
|
||||
Close()
|
||||
}
|
||||
12
server/internal/core/ports/scanner.go
Normal file
12
server/internal/core/ports/scanner.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package ports
|
||||
|
||||
import (
|
||||
"github.com/ark-network/ark/internal/core/domain"
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
type BlockchainScanner interface {
|
||||
WatchScripts(ctx context.Context, scripts []string) error
|
||||
UnwatchScripts(ctx context.Context, scripts []string) error
|
||||
GetNotificationChannel(ctx context.Context) chan []domain.VtxoKey
|
||||
}
|
||||
9
server/internal/core/ports/scheduler.go
Normal file
9
server/internal/core/ports/scheduler.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package ports
|
||||
|
||||
type SchedulerService interface {
|
||||
Start()
|
||||
Stop()
|
||||
|
||||
ScheduleTask(interval int64, immediate bool, task func()) error
|
||||
ScheduleTaskOnce(delay int64, task func()) error
|
||||
}
|
||||
29
server/internal/core/ports/tx_builder.go
Normal file
29
server/internal/core/ports/tx_builder.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package ports
|
||||
|
||||
import (
|
||||
"github.com/ark-network/ark/common/tree"
|
||||
"github.com/ark-network/ark/internal/core/domain"
|
||||
"github.com/decred/dcrd/dcrec/secp256k1/v4"
|
||||
"github.com/vulpemventures/go-elements/psetv2"
|
||||
)
|
||||
|
||||
type SweepInput struct {
|
||||
InputArgs psetv2.InputArgs
|
||||
SweepLeaf psetv2.TapLeafScript
|
||||
Amount uint64
|
||||
}
|
||||
|
||||
type TxBuilder interface {
|
||||
BuildPoolTx(
|
||||
aspPubkey *secp256k1.PublicKey, payments []domain.Payment, minRelayFee uint64,
|
||||
) (poolTx string, congestionTree tree.CongestionTree, err error)
|
||||
BuildForfeitTxs(
|
||||
aspPubkey *secp256k1.PublicKey, poolTx string, payments []domain.Payment,
|
||||
) (connectors []string, forfeitTxs []string, err error)
|
||||
BuildSweepTx(
|
||||
wallet WalletService,
|
||||
inputs []SweepInput,
|
||||
) (signedSweepTx string, err error)
|
||||
GetLeafSweepClosure(node tree.Node, userPubKey *secp256k1.PublicKey) (*psetv2.TapLeafScript, int64, error)
|
||||
GetVtxoScript(userPubkey, aspPubkey *secp256k1.PublicKey) ([]byte, error)
|
||||
}
|
||||
37
server/internal/core/ports/wallet.go
Normal file
37
server/internal/core/ports/wallet.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package ports
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/decred/dcrd/dcrec/secp256k1/v4"
|
||||
)
|
||||
|
||||
type WalletService interface {
|
||||
BlockchainScanner
|
||||
Status(ctx context.Context) (WalletStatus, error)
|
||||
GetPubkey(ctx context.Context) (*secp256k1.PublicKey, error)
|
||||
DeriveAddresses(ctx context.Context, num int) ([]string, error)
|
||||
SignPset(
|
||||
ctx context.Context, pset string, extractRawTx bool,
|
||||
) (string, error)
|
||||
SelectUtxos(ctx context.Context, asset string, amount uint64) ([]TxInput, uint64, error)
|
||||
BroadcastTransaction(ctx context.Context, txHex string) (string, error)
|
||||
SignPsetWithKey(ctx context.Context, pset string, inputIndexes []int) (string, error) // inputIndexes == nil means sign all inputs
|
||||
IsTransactionPublished(ctx context.Context, txid string) (isPublished bool, blocktime int64, err error)
|
||||
EstimateFees(ctx context.Context, pset string) (uint64, error)
|
||||
Close()
|
||||
}
|
||||
|
||||
type WalletStatus interface {
|
||||
IsInitialized() bool
|
||||
IsUnlocked() bool
|
||||
IsSynced() bool
|
||||
}
|
||||
|
||||
type TxInput interface {
|
||||
GetTxid() string
|
||||
GetIndex() uint32
|
||||
GetScript() string
|
||||
GetAsset() string
|
||||
GetValue() uint64
|
||||
}
|
||||
154
server/internal/infrastructure/db/badger/event_repo.go
Normal file
154
server/internal/infrastructure/db/badger/event_repo.go
Normal file
@@ -0,0 +1,154 @@
|
||||
package badgerdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
|
||||
"github.com/ark-network/ark/internal/core/domain"
|
||||
dbtypes "github.com/ark-network/ark/internal/infrastructure/db/types"
|
||||
"github.com/dgraph-io/badger/v4"
|
||||
"github.com/timshannon/badgerhold/v4"
|
||||
)
|
||||
|
||||
const eventStoreDir = "round-events"
|
||||
|
||||
type eventsDTO struct {
|
||||
Events [][]byte
|
||||
}
|
||||
|
||||
type eventRepository struct {
|
||||
store *badgerhold.Store
|
||||
lock *sync.RWMutex
|
||||
chUpdates chan *domain.Round
|
||||
handler func(round *domain.Round)
|
||||
}
|
||||
|
||||
func NewRoundEventRepository(config ...interface{}) (dbtypes.EventStore, error) {
|
||||
if len(config) != 2 {
|
||||
return nil, fmt.Errorf("invalid config")
|
||||
}
|
||||
baseDir, ok := config[0].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid base directory")
|
||||
}
|
||||
|
||||
var logger badger.Logger
|
||||
if config[1] != nil {
|
||||
logger, ok = config[1].(badger.Logger)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid logger")
|
||||
}
|
||||
}
|
||||
|
||||
var dir string
|
||||
if len(baseDir) > 0 {
|
||||
dir = filepath.Join(baseDir, eventStoreDir)
|
||||
}
|
||||
store, err := createDB(dir, logger)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open round events store: %s", err)
|
||||
}
|
||||
chEvents := make(chan *domain.Round)
|
||||
lock := &sync.RWMutex{}
|
||||
repo := &eventRepository{store, lock, chEvents, nil}
|
||||
go repo.listen()
|
||||
return repo, nil
|
||||
}
|
||||
|
||||
func (r *eventRepository) Save(
|
||||
ctx context.Context, id string, events ...domain.RoundEvent,
|
||||
) error {
|
||||
allEvents, err := r.get(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
allEvents = append(allEvents, events...)
|
||||
if err := r.upsert(ctx, id, allEvents); err != nil {
|
||||
return err
|
||||
}
|
||||
go r.publishEvents(allEvents)
|
||||
return nil
|
||||
}
|
||||
func (r *eventRepository) Load(
|
||||
ctx context.Context, id string,
|
||||
) (*domain.Round, error) {
|
||||
events, err := r.get(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return domain.NewRoundFromEvents(events), nil
|
||||
}
|
||||
|
||||
func (r *eventRepository) RegisterEventsHandler(
|
||||
handler func(round *domain.Round),
|
||||
) {
|
||||
r.lock.Lock()
|
||||
defer r.lock.Unlock()
|
||||
|
||||
r.handler = handler
|
||||
}
|
||||
|
||||
func (r *eventRepository) Close() {
|
||||
close(r.chUpdates)
|
||||
r.store.Close()
|
||||
}
|
||||
|
||||
func (r *eventRepository) get(
|
||||
ctx context.Context, id string,
|
||||
) ([]domain.RoundEvent, error) {
|
||||
dto := eventsDTO{}
|
||||
var err error
|
||||
if ctx.Value("tx") != nil {
|
||||
tx := ctx.Value("tx").(*badger.Txn)
|
||||
err = r.store.TxGet(tx, id, &dto)
|
||||
} else {
|
||||
err = r.store.Get(id, &dto)
|
||||
}
|
||||
if err != nil {
|
||||
if err == badgerhold.ErrNotFound {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get events with id %s: %s", id, err)
|
||||
}
|
||||
|
||||
return deserializeEvents(dto.Events)
|
||||
}
|
||||
|
||||
func (r *eventRepository) upsert(
|
||||
ctx context.Context, id string, events []domain.RoundEvent,
|
||||
) error {
|
||||
buf, err := serializeEvents(events)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if ctx.Value("tx") != nil {
|
||||
tx := ctx.Value("tx").(*badger.Txn)
|
||||
err = r.store.TxUpsert(tx, id, buf)
|
||||
} else {
|
||||
err = r.store.Upsert(id, buf)
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to upsert events with id %s: %s", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *eventRepository) listen() {
|
||||
for updatedRound := range r.chUpdates {
|
||||
r.lock.RLock()
|
||||
if r.handler != nil {
|
||||
r.handler(updatedRound)
|
||||
}
|
||||
r.lock.RUnlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (r *eventRepository) publishEvents(events []domain.RoundEvent) {
|
||||
r.lock.Lock()
|
||||
defer r.lock.Unlock()
|
||||
round := domain.NewRoundFromEvents(events)
|
||||
r.chUpdates <- round
|
||||
}
|
||||
140
server/internal/infrastructure/db/badger/round_repo.go
Normal file
140
server/internal/infrastructure/db/badger/round_repo.go
Normal file
@@ -0,0 +1,140 @@
|
||||
package badgerdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ark-network/ark/internal/core/domain"
|
||||
dbtypes "github.com/ark-network/ark/internal/infrastructure/db/types"
|
||||
"github.com/dgraph-io/badger/v4"
|
||||
"github.com/timshannon/badgerhold/v4"
|
||||
)
|
||||
|
||||
const roundStoreDir = "rounds"
|
||||
|
||||
type roundRepository struct {
|
||||
store *badgerhold.Store
|
||||
}
|
||||
|
||||
func NewRoundRepository(config ...interface{}) (dbtypes.RoundStore, error) {
|
||||
if len(config) != 2 {
|
||||
return nil, fmt.Errorf("invalid config")
|
||||
}
|
||||
baseDir, ok := config[0].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid base directory")
|
||||
}
|
||||
var logger badger.Logger
|
||||
if config[1] != nil {
|
||||
logger, ok = config[1].(badger.Logger)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid logger")
|
||||
}
|
||||
}
|
||||
|
||||
var dir string
|
||||
if len(baseDir) > 0 {
|
||||
dir = filepath.Join(baseDir, roundStoreDir)
|
||||
}
|
||||
store, err := createDB(dir, logger)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open round events store: %s", err)
|
||||
}
|
||||
|
||||
return &roundRepository{store}, nil
|
||||
}
|
||||
|
||||
func (r *roundRepository) AddOrUpdateRound(
|
||||
ctx context.Context, round domain.Round,
|
||||
) error {
|
||||
return r.addOrUpdateRound(ctx, round)
|
||||
}
|
||||
|
||||
func (r *roundRepository) GetCurrentRound(
|
||||
ctx context.Context,
|
||||
) (*domain.Round, error) {
|
||||
query := badgerhold.Where("Stage.Ended").Eq(false).And("Stage.Failed").Eq(false)
|
||||
rounds, err := r.findRound(ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(rounds) <= 0 {
|
||||
return nil, fmt.Errorf("ongoing round not found")
|
||||
}
|
||||
return &rounds[0], nil
|
||||
}
|
||||
|
||||
func (r *roundRepository) GetRoundWithId(
|
||||
ctx context.Context, id string,
|
||||
) (*domain.Round, error) {
|
||||
query := badgerhold.Where("Id").Eq(id)
|
||||
rounds, err := r.findRound(ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(rounds) <= 0 {
|
||||
return nil, fmt.Errorf("round with id %s not found", id)
|
||||
}
|
||||
round := &rounds[0]
|
||||
return round, nil
|
||||
}
|
||||
|
||||
func (r *roundRepository) GetRoundWithTxid(
|
||||
ctx context.Context, txid string,
|
||||
) (*domain.Round, error) {
|
||||
query := badgerhold.Where("Txid").Eq(txid)
|
||||
rounds, err := r.findRound(ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(rounds) <= 0 {
|
||||
return nil, fmt.Errorf("round with txid %s not found", txid)
|
||||
}
|
||||
round := &rounds[0]
|
||||
return round, nil
|
||||
}
|
||||
|
||||
func (r *roundRepository) GetSweepableRounds(
|
||||
ctx context.Context,
|
||||
) ([]domain.Round, error) {
|
||||
query := badgerhold.Where("Stage.Code").Eq(domain.FinalizationStage).
|
||||
And("Stage.Ended").Eq(true).And("Swept").Eq(false)
|
||||
rounds, err := r.findRound(ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return rounds, nil
|
||||
}
|
||||
|
||||
func (r *roundRepository) Close() {
|
||||
r.store.Close()
|
||||
}
|
||||
|
||||
func (r *roundRepository) findRound(
|
||||
ctx context.Context, query *badgerhold.Query,
|
||||
) ([]domain.Round, error) {
|
||||
var rounds []domain.Round
|
||||
var err error
|
||||
|
||||
if ctx.Value("tx") != nil {
|
||||
tx := ctx.Value("tx").(*badger.Txn)
|
||||
err = r.store.TxFind(tx, &rounds, query)
|
||||
} else {
|
||||
err = r.store.Find(&rounds, query)
|
||||
}
|
||||
|
||||
return rounds, err
|
||||
}
|
||||
|
||||
func (r *roundRepository) addOrUpdateRound(
|
||||
ctx context.Context, round domain.Round,
|
||||
) (err error) {
|
||||
if ctx.Value("tx") != nil {
|
||||
tx := ctx.Value("tx").(*badger.Txn)
|
||||
err = r.store.TxUpsert(tx, round.Id, round)
|
||||
} else {
|
||||
err = r.store.Upsert(round.Id, round)
|
||||
}
|
||||
return
|
||||
}
|
||||
116
server/internal/infrastructure/db/badger/utils.go
Normal file
116
server/internal/infrastructure/db/badger/utils.go
Normal file
@@ -0,0 +1,116 @@
|
||||
package badgerdb
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/ark-network/ark/internal/core/domain"
|
||||
"github.com/dgraph-io/badger/v4"
|
||||
"github.com/dgraph-io/badger/v4/options"
|
||||
"github.com/timshannon/badgerhold/v4"
|
||||
)
|
||||
|
||||
func createDB(dbDir string, logger badger.Logger) (*badgerhold.Store, error) {
|
||||
isInMemory := len(dbDir) <= 0
|
||||
|
||||
opts := badger.DefaultOptions(dbDir)
|
||||
opts.Logger = logger
|
||||
|
||||
if isInMemory {
|
||||
opts.InMemory = true
|
||||
} else {
|
||||
opts.Compression = options.ZSTD
|
||||
}
|
||||
|
||||
db, err := badgerhold.Open(badgerhold.Options{
|
||||
Encoder: badgerhold.DefaultEncode,
|
||||
Decoder: badgerhold.DefaultDecode,
|
||||
SequenceBandwith: 100,
|
||||
Options: opts,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !isInMemory {
|
||||
ticker := time.NewTicker(30 * time.Minute)
|
||||
|
||||
go func() {
|
||||
for {
|
||||
<-ticker.C
|
||||
if err := db.Badger().RunValueLogGC(0.5); err != nil && err != badger.ErrNoRewrite {
|
||||
logger.Errorf("%s", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func serializeEvents(events []domain.RoundEvent) (*eventsDTO, error) {
|
||||
rawEvents := make([][]byte, 0, len(events))
|
||||
for _, event := range events {
|
||||
buf, err := serializeEvent(event)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rawEvents = append(rawEvents, buf)
|
||||
}
|
||||
return &eventsDTO{rawEvents}, nil
|
||||
}
|
||||
|
||||
func deserializeEvents(rawEvents [][]byte) ([]domain.RoundEvent, error) {
|
||||
events := make([]domain.RoundEvent, 0)
|
||||
for _, buf := range rawEvents {
|
||||
event, err := deserializeEvent(buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
events = append(events, event)
|
||||
}
|
||||
return events, nil
|
||||
}
|
||||
|
||||
func serializeEvent(event domain.RoundEvent) ([]byte, error) {
|
||||
switch eventType := event.(type) {
|
||||
default:
|
||||
return json.Marshal(eventType)
|
||||
}
|
||||
}
|
||||
|
||||
func deserializeEvent(buf []byte) (domain.RoundEvent, error) {
|
||||
{
|
||||
var event = domain.RoundFailed{}
|
||||
if err := json.Unmarshal(buf, &event); err == nil && len(event.Err) > 0 {
|
||||
return event, nil
|
||||
}
|
||||
}
|
||||
{
|
||||
var event = domain.RoundFinalized{}
|
||||
if err := json.Unmarshal(buf, &event); err == nil && len(event.Txid) > 0 {
|
||||
return event, nil
|
||||
}
|
||||
}
|
||||
{
|
||||
var event = domain.RoundFinalizationStarted{}
|
||||
if err := json.Unmarshal(buf, &event); err == nil && len(event.CongestionTree) > 0 {
|
||||
return event, nil
|
||||
}
|
||||
}
|
||||
{
|
||||
var event = domain.PaymentsRegistered{}
|
||||
if err := json.Unmarshal(buf, &event); err == nil && len(event.Payments) > 0 {
|
||||
return event, nil
|
||||
}
|
||||
}
|
||||
{
|
||||
var event = domain.RoundStarted{}
|
||||
if err := json.Unmarshal(buf, &event); err == nil && event.Timestamp > 0 {
|
||||
return event, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unknown event")
|
||||
}
|
||||
245
server/internal/infrastructure/db/badger/vtxo_repo.go
Normal file
245
server/internal/infrastructure/db/badger/vtxo_repo.go
Normal file
@@ -0,0 +1,245 @@
|
||||
package badgerdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/ark-network/ark/internal/core/domain"
|
||||
dbtypes "github.com/ark-network/ark/internal/infrastructure/db/types"
|
||||
"github.com/dgraph-io/badger/v4"
|
||||
"github.com/timshannon/badgerhold/v4"
|
||||
)
|
||||
|
||||
const vtxoStoreDir = "vtxos"
|
||||
|
||||
type vtxoRepository struct {
|
||||
store *badgerhold.Store
|
||||
}
|
||||
|
||||
func NewVtxoRepository(config ...interface{}) (dbtypes.VtxoStore, error) {
|
||||
if len(config) != 2 {
|
||||
return nil, fmt.Errorf("invalid config")
|
||||
}
|
||||
baseDir, ok := config[0].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid base directory")
|
||||
}
|
||||
var logger badger.Logger
|
||||
if config[1] != nil {
|
||||
logger, ok = config[1].(badger.Logger)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid logger")
|
||||
}
|
||||
}
|
||||
|
||||
var dir string
|
||||
if len(baseDir) > 0 {
|
||||
dir = filepath.Join(baseDir, vtxoStoreDir)
|
||||
}
|
||||
store, err := createDB(dir, logger)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open round events store: %s", err)
|
||||
}
|
||||
|
||||
return &vtxoRepository{store}, nil
|
||||
}
|
||||
|
||||
func (r *vtxoRepository) AddVtxos(
|
||||
ctx context.Context, vtxos []domain.Vtxo,
|
||||
) error {
|
||||
return r.addVtxos(ctx, vtxos)
|
||||
}
|
||||
|
||||
func (r *vtxoRepository) SpendVtxos(
|
||||
ctx context.Context, vtxoKeys []domain.VtxoKey,
|
||||
) error {
|
||||
for _, vtxoKey := range vtxoKeys {
|
||||
if err := r.spendVtxo(ctx, vtxoKey); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *vtxoRepository) RedeemVtxos(
|
||||
ctx context.Context, vtxoKeys []domain.VtxoKey,
|
||||
) ([]domain.Vtxo, error) {
|
||||
vtxos := make([]domain.Vtxo, 0, len(vtxoKeys))
|
||||
for _, vtxoKey := range vtxoKeys {
|
||||
vtxo, err := r.redeemVtxo(ctx, vtxoKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if vtxo != nil {
|
||||
vtxos = append(vtxos, *vtxo)
|
||||
}
|
||||
}
|
||||
return vtxos, nil
|
||||
}
|
||||
|
||||
func (r *vtxoRepository) GetVtxos(
|
||||
ctx context.Context, vtxoKeys []domain.VtxoKey,
|
||||
) ([]domain.Vtxo, error) {
|
||||
vtxos := make([]domain.Vtxo, 0, len(vtxoKeys))
|
||||
for _, vtxoKey := range vtxoKeys {
|
||||
vtxo, err := r.getVtxo(ctx, vtxoKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
vtxos = append(vtxos, *vtxo)
|
||||
}
|
||||
return vtxos, nil
|
||||
}
|
||||
|
||||
func (r *vtxoRepository) GetVtxosForRound(
|
||||
ctx context.Context, txid string,
|
||||
) ([]domain.Vtxo, error) {
|
||||
query := badgerhold.Where("Txid").Eq(txid)
|
||||
return r.findVtxos(ctx, query)
|
||||
}
|
||||
|
||||
func (r *vtxoRepository) GetSpendableVtxos(
|
||||
ctx context.Context, pubkey string,
|
||||
) ([]domain.Vtxo, error) {
|
||||
query := badgerhold.Where("Spent").Eq(false).And("Redeemed").Eq(false).And("Swept").Eq(false)
|
||||
if len(pubkey) > 0 {
|
||||
query = query.And("Pubkey").Eq(pubkey)
|
||||
}
|
||||
return r.findVtxos(ctx, query)
|
||||
}
|
||||
|
||||
func (r *vtxoRepository) SweepVtxos(
|
||||
ctx context.Context, vtxoKeys []domain.VtxoKey,
|
||||
) error {
|
||||
for _, vtxoKey := range vtxoKeys {
|
||||
if err := r.sweepVtxo(ctx, vtxoKey); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *vtxoRepository) Close() {
|
||||
r.store.Close()
|
||||
}
|
||||
|
||||
func (r *vtxoRepository) addVtxos(
|
||||
ctx context.Context, vtxos []domain.Vtxo,
|
||||
) (err error) {
|
||||
for _, vtxo := range vtxos {
|
||||
vtxoKey := vtxo.VtxoKey.Hash()
|
||||
if ctx.Value("tx") != nil {
|
||||
tx := ctx.Value("tx").(*badger.Txn)
|
||||
err = r.store.TxInsert(tx, vtxoKey, vtxo)
|
||||
} else {
|
||||
err = r.store.Insert(vtxoKey, vtxo)
|
||||
}
|
||||
}
|
||||
if err != nil && err == badgerhold.ErrKeyExists {
|
||||
err = nil
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (r *vtxoRepository) getVtxo(
|
||||
ctx context.Context, vtxoKey domain.VtxoKey,
|
||||
) (*domain.Vtxo, error) {
|
||||
var vtxo domain.Vtxo
|
||||
var err error
|
||||
if ctx.Value("tx") != nil {
|
||||
tx := ctx.Value("tx").(*badger.Txn)
|
||||
err = r.store.TxGet(tx, vtxoKey.Hash(), &vtxo)
|
||||
} else {
|
||||
err = r.store.Get(vtxoKey.Hash(), &vtxo)
|
||||
}
|
||||
if err != nil && err == badgerhold.ErrNotFound {
|
||||
return nil, fmt.Errorf("vtxo %s:%d not found", vtxoKey.Txid, vtxoKey.VOut)
|
||||
}
|
||||
|
||||
return &vtxo, nil
|
||||
}
|
||||
|
||||
func (r *vtxoRepository) spendVtxo(ctx context.Context, vtxoKey domain.VtxoKey) error {
|
||||
vtxo, err := r.getVtxo(ctx, vtxoKey)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
if vtxo.Spent {
|
||||
return nil
|
||||
}
|
||||
|
||||
vtxo.Spent = true
|
||||
if ctx.Value("tx") != nil {
|
||||
tx := ctx.Value("tx").(*badger.Txn)
|
||||
err = r.store.TxUpdate(tx, vtxoKey.Hash(), *vtxo)
|
||||
} else {
|
||||
err = r.store.Update(vtxoKey.Hash(), *vtxo)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *vtxoRepository) redeemVtxo(ctx context.Context, vtxoKey domain.VtxoKey) (*domain.Vtxo, error) {
|
||||
vtxo, err := r.getVtxo(ctx, vtxoKey)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if vtxo.Redeemed {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
vtxo.Redeemed = true
|
||||
if ctx.Value("tx") != nil {
|
||||
tx := ctx.Value("tx").(*badger.Txn)
|
||||
err = r.store.TxUpdate(tx, vtxoKey.Hash(), *vtxo)
|
||||
} else {
|
||||
err = r.store.Update(vtxoKey.Hash(), *vtxo)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return vtxo, nil
|
||||
}
|
||||
|
||||
func (r *vtxoRepository) findVtxos(ctx context.Context, query *badgerhold.Query) ([]domain.Vtxo, error) {
|
||||
var vtxos []domain.Vtxo
|
||||
var err error
|
||||
|
||||
if ctx.Value("tx") != nil {
|
||||
tx := ctx.Value("tx").(*badger.Txn)
|
||||
err = r.store.TxFind(tx, &vtxos, query)
|
||||
} else {
|
||||
err = r.store.Find(&vtxos, query)
|
||||
}
|
||||
|
||||
return vtxos, err
|
||||
}
|
||||
|
||||
func (r *vtxoRepository) sweepVtxo(ctx context.Context, vtxoKey domain.VtxoKey) error {
|
||||
vtxo, err := r.getVtxo(ctx, vtxoKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if vtxo.Swept {
|
||||
return nil
|
||||
}
|
||||
|
||||
vtxo.Swept = true
|
||||
if ctx.Value("tx") != nil {
|
||||
tx := ctx.Value("tx").(*badger.Txn)
|
||||
err = r.store.TxUpdate(tx, vtxoKey.Hash(), *vtxo)
|
||||
} else {
|
||||
err = r.store.Update(vtxoKey.Hash(), *vtxo)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
90
server/internal/infrastructure/db/service.go
Normal file
90
server/internal/infrastructure/db/service.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/ark-network/ark/internal/core/domain"
|
||||
"github.com/ark-network/ark/internal/core/ports"
|
||||
badgerdb "github.com/ark-network/ark/internal/infrastructure/db/badger"
|
||||
dbtypes "github.com/ark-network/ark/internal/infrastructure/db/types"
|
||||
)
|
||||
|
||||
var (
|
||||
eventStoreTypes = map[string]func(...interface{}) (dbtypes.EventStore, error){
|
||||
"badger": badgerdb.NewRoundEventRepository,
|
||||
}
|
||||
roundStoreTypes = map[string]func(...interface{}) (dbtypes.RoundStore, error){
|
||||
"badger": badgerdb.NewRoundRepository,
|
||||
}
|
||||
vtxoStoreTypes = map[string]func(...interface{}) (dbtypes.VtxoStore, error){
|
||||
"badger": badgerdb.NewVtxoRepository,
|
||||
}
|
||||
)
|
||||
|
||||
type ServiceConfig struct {
|
||||
EventStoreType string
|
||||
RoundStoreType string
|
||||
VtxoStoreType string
|
||||
|
||||
EventStoreConfig []interface{}
|
||||
RoundStoreConfig []interface{}
|
||||
VtxoStoreConfig []interface{}
|
||||
}
|
||||
|
||||
type service struct {
|
||||
eventStore dbtypes.EventStore
|
||||
roundStore dbtypes.RoundStore
|
||||
vtxoStore dbtypes.VtxoStore
|
||||
}
|
||||
|
||||
func NewService(config ServiceConfig) (ports.RepoManager, error) {
|
||||
eventStoreFactory, ok := eventStoreTypes[config.EventStoreType]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("event store type not supported")
|
||||
}
|
||||
roundStoreFactory, ok := roundStoreTypes[config.RoundStoreType]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("round store type not supported")
|
||||
}
|
||||
vtxoStoreFactory, ok := vtxoStoreTypes[config.VtxoStoreType]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("vtxo store type not supported")
|
||||
}
|
||||
|
||||
eventStore, err := eventStoreFactory(config.EventStoreConfig...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open event store: %s", err)
|
||||
}
|
||||
roundStore, err := roundStoreFactory(config.RoundStoreConfig...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open round store: %s", err)
|
||||
}
|
||||
vtxoStore, err := vtxoStoreFactory(config.VtxoStoreConfig...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open vtxo store: %s", err)
|
||||
}
|
||||
|
||||
return &service{eventStore, roundStore, vtxoStore}, nil
|
||||
}
|
||||
|
||||
func (s *service) RegisterEventsHandler(handler func(round *domain.Round)) {
|
||||
s.eventStore.RegisterEventsHandler(handler)
|
||||
}
|
||||
|
||||
func (s *service) Events() domain.RoundEventRepository {
|
||||
return s.eventStore
|
||||
}
|
||||
|
||||
func (s *service) Rounds() domain.RoundRepository {
|
||||
return s.roundStore
|
||||
}
|
||||
|
||||
func (s *service) Vtxos() domain.VtxoRepository {
|
||||
return s.vtxoStore
|
||||
}
|
||||
|
||||
func (s *service) Close() {
|
||||
s.eventStore.Close()
|
||||
s.roundStore.Close()
|
||||
s.vtxoStore.Close()
|
||||
}
|
||||
417
server/internal/infrastructure/db/service_test.go
Normal file
417
server/internal/infrastructure/db/service_test.go
Normal file
@@ -0,0 +1,417 @@
|
||||
package db_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ark-network/ark/common/tree"
|
||||
"github.com/ark-network/ark/internal/core/domain"
|
||||
"github.com/ark-network/ark/internal/core/ports"
|
||||
"github.com/ark-network/ark/internal/infrastructure/db"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const (
|
||||
emptyPtx = "cHNldP8BAgQCAAAAAQQBAAEFAQABBgEDAfsEAgAAAAA="
|
||||
emptyTx = "0200000000000000000000"
|
||||
txid = "00000000000000000000000000000000000000000000000000000000000000000"
|
||||
pubkey1 = "0300000000000000000000000000000000000000000000000000000000000000001"
|
||||
pubkey2 = "0200000000000000000000000000000000000000000000000000000000000000002"
|
||||
)
|
||||
|
||||
var congestionTree = [][]tree.Node{
|
||||
{
|
||||
{
|
||||
Txid: txid,
|
||||
Tx: emptyPtx,
|
||||
ParentTxid: txid,
|
||||
},
|
||||
},
|
||||
{
|
||||
{
|
||||
Txid: txid,
|
||||
Tx: emptyPtx,
|
||||
ParentTxid: txid,
|
||||
},
|
||||
{
|
||||
Txid: txid,
|
||||
Tx: emptyPtx,
|
||||
ParentTxid: txid,
|
||||
},
|
||||
},
|
||||
{
|
||||
{
|
||||
Txid: txid,
|
||||
Tx: emptyPtx,
|
||||
ParentTxid: txid,
|
||||
},
|
||||
{
|
||||
Txid: txid,
|
||||
Tx: emptyPtx,
|
||||
ParentTxid: txid,
|
||||
},
|
||||
{
|
||||
Txid: txid,
|
||||
Tx: emptyPtx,
|
||||
ParentTxid: txid,
|
||||
},
|
||||
{
|
||||
Txid: txid,
|
||||
Tx: emptyPtx,
|
||||
ParentTxid: txid,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
func TestService(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config db.ServiceConfig
|
||||
}{
|
||||
{
|
||||
name: "repo_manager_with_badger_stores",
|
||||
config: db.ServiceConfig{
|
||||
EventStoreType: "badger",
|
||||
RoundStoreType: "badger",
|
||||
VtxoStoreType: "badger",
|
||||
EventStoreConfig: []interface{}{"", nil},
|
||||
RoundStoreConfig: []interface{}{"", nil},
|
||||
VtxoStoreConfig: []interface{}{"", nil},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
svc, err := db.NewService(tt.config)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, svc)
|
||||
|
||||
testRoundEventRepository(t, svc)
|
||||
testRoundRepository(t, svc)
|
||||
testVtxoRepository(t, svc)
|
||||
|
||||
svc.Close()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func testRoundEventRepository(t *testing.T, svc ports.RepoManager) {
|
||||
t.Run("test_event_repository", func(t *testing.T) {
|
||||
fixtures := []struct {
|
||||
roundId string
|
||||
events []domain.RoundEvent
|
||||
handler func(*domain.Round)
|
||||
}{
|
||||
{
|
||||
roundId: "42dd81f7-cadd-482c-bf69-8e9209aae9f3",
|
||||
events: []domain.RoundEvent{
|
||||
domain.RoundStarted{
|
||||
Id: "42dd81f7-cadd-482c-bf69-8e9209aae9f3",
|
||||
Timestamp: 1701190270,
|
||||
},
|
||||
},
|
||||
handler: func(round *domain.Round) {
|
||||
require.NotNil(t, round)
|
||||
require.Len(t, round.Events(), 1)
|
||||
require.True(t, round.IsStarted())
|
||||
require.False(t, round.IsFailed())
|
||||
require.False(t, round.IsEnded())
|
||||
},
|
||||
},
|
||||
{
|
||||
roundId: "1ea610ff-bf3e-4068-9bfd-b6c3f553467e",
|
||||
events: []domain.RoundEvent{
|
||||
domain.RoundStarted{
|
||||
Id: "1ea610ff-bf3e-4068-9bfd-b6c3f553467e",
|
||||
Timestamp: 1701190270,
|
||||
},
|
||||
domain.RoundFinalizationStarted{
|
||||
Id: "1ea610ff-bf3e-4068-9bfd-b6c3f553467e",
|
||||
CongestionTree: congestionTree,
|
||||
Connectors: []string{emptyPtx, emptyPtx},
|
||||
PoolTx: emptyTx,
|
||||
},
|
||||
},
|
||||
handler: func(round *domain.Round) {
|
||||
require.NotNil(t, round)
|
||||
require.Len(t, round.Events(), 2)
|
||||
require.Len(t, round.CongestionTree, 3)
|
||||
require.Equal(t, round.CongestionTree.NumberOfNodes(), 7)
|
||||
require.Len(t, round.Connectors, 2)
|
||||
},
|
||||
},
|
||||
{
|
||||
roundId: "7578231e-428d-45ae-aaa4-e62c77ad5cec",
|
||||
events: []domain.RoundEvent{
|
||||
domain.RoundStarted{
|
||||
Id: "7578231e-428d-45ae-aaa4-e62c77ad5cec",
|
||||
Timestamp: 1701190270,
|
||||
},
|
||||
domain.RoundFinalizationStarted{
|
||||
Id: "7578231e-428d-45ae-aaa4-e62c77ad5cec",
|
||||
CongestionTree: congestionTree,
|
||||
Connectors: []string{emptyPtx, emptyPtx},
|
||||
PoolTx: emptyTx,
|
||||
},
|
||||
domain.RoundFinalized{
|
||||
Id: "7578231e-428d-45ae-aaa4-e62c77ad5cec",
|
||||
Txid: txid,
|
||||
ForfeitTxs: []string{emptyPtx, emptyPtx, emptyPtx, emptyPtx},
|
||||
Timestamp: 1701190300,
|
||||
},
|
||||
},
|
||||
handler: func(round *domain.Round) {
|
||||
require.NotNil(t, round)
|
||||
require.Len(t, round.Events(), 3)
|
||||
require.False(t, round.IsStarted())
|
||||
require.False(t, round.IsFailed())
|
||||
require.True(t, round.IsEnded())
|
||||
require.NotEmpty(t, round.Txid)
|
||||
},
|
||||
},
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
for _, f := range fixtures {
|
||||
svc.RegisterEventsHandler(f.handler)
|
||||
|
||||
err := svc.Events().Save(ctx, f.roundId, f.events...)
|
||||
require.NoError(t, err)
|
||||
|
||||
round, err := svc.Events().Load(ctx, f.roundId)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, round)
|
||||
require.Equal(t, f.roundId, round.Id)
|
||||
require.Len(t, round.Events(), len(f.events))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func testRoundRepository(t *testing.T, svc ports.RepoManager) {
|
||||
t.Run("test_round_repository", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
now := time.Now()
|
||||
|
||||
roundId := uuid.New().String()
|
||||
|
||||
round, err := svc.Rounds().GetRoundWithId(ctx, roundId)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, round)
|
||||
|
||||
events := []domain.RoundEvent{
|
||||
domain.RoundStarted{
|
||||
Id: roundId,
|
||||
Timestamp: now.Unix(),
|
||||
},
|
||||
}
|
||||
round = domain.NewRoundFromEvents(events)
|
||||
err = svc.Rounds().AddOrUpdateRound(ctx, *round)
|
||||
require.NoError(t, err)
|
||||
|
||||
currentRound, err := svc.Rounds().GetCurrentRound(ctx)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, currentRound)
|
||||
require.Condition(t, roundsMatch(*round, *currentRound))
|
||||
|
||||
roundById, err := svc.Rounds().GetRoundWithId(ctx, roundId)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, roundById)
|
||||
require.Condition(t, roundsMatch(*round, *roundById))
|
||||
|
||||
newEvents := []domain.RoundEvent{
|
||||
domain.PaymentsRegistered{
|
||||
Id: roundId,
|
||||
Payments: []domain.Payment{
|
||||
{
|
||||
Id: uuid.New().String(),
|
||||
Inputs: []domain.Vtxo{{}},
|
||||
Receivers: []domain.Receiver{{}},
|
||||
},
|
||||
{
|
||||
Id: uuid.New().String(),
|
||||
Inputs: []domain.Vtxo{{}},
|
||||
Receivers: []domain.Receiver{{}, {}, {}},
|
||||
},
|
||||
},
|
||||
},
|
||||
domain.RoundFinalizationStarted{
|
||||
Id: roundId,
|
||||
CongestionTree: congestionTree,
|
||||
Connectors: []string{emptyPtx, emptyPtx},
|
||||
PoolTx: emptyTx,
|
||||
},
|
||||
}
|
||||
events = append(events, newEvents...)
|
||||
updatedRound := domain.NewRoundFromEvents(events)
|
||||
|
||||
err = svc.Rounds().AddOrUpdateRound(ctx, *updatedRound)
|
||||
require.NoError(t, err)
|
||||
|
||||
currentRound, err = svc.Rounds().GetCurrentRound(ctx)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, currentRound)
|
||||
require.Condition(t, roundsMatch(*updatedRound, *currentRound))
|
||||
|
||||
roundById, err = svc.Rounds().GetRoundWithId(ctx, updatedRound.Id)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, currentRound)
|
||||
require.Condition(t, roundsMatch(*updatedRound, *roundById))
|
||||
|
||||
newEvents = []domain.RoundEvent{
|
||||
domain.RoundFinalized{
|
||||
Id: roundId,
|
||||
Txid: txid,
|
||||
ForfeitTxs: []string{emptyPtx, emptyPtx, emptyPtx, emptyPtx},
|
||||
Timestamp: now.Add(60 * time.Second).Unix(),
|
||||
},
|
||||
}
|
||||
events = append(events, newEvents...)
|
||||
finalizedRound := domain.NewRoundFromEvents(events)
|
||||
|
||||
err = svc.Rounds().AddOrUpdateRound(ctx, *finalizedRound)
|
||||
require.NoError(t, err)
|
||||
|
||||
currentRound, err = svc.Rounds().GetCurrentRound(ctx)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, currentRound)
|
||||
|
||||
roundById, err = svc.Rounds().GetRoundWithId(ctx, roundId)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, roundById)
|
||||
require.Condition(t, roundsMatch(*finalizedRound, *roundById))
|
||||
|
||||
roundByTxid, err := svc.Rounds().GetRoundWithTxid(ctx, txid)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, roundByTxid)
|
||||
require.Condition(t, roundsMatch(*finalizedRound, *roundByTxid))
|
||||
})
|
||||
}
|
||||
|
||||
func testVtxoRepository(t *testing.T, svc ports.RepoManager) {
|
||||
t.Run("test_vtxo_repository", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
userVtxos := []domain.Vtxo{
|
||||
{
|
||||
VtxoKey: domain.VtxoKey{
|
||||
Txid: txid,
|
||||
VOut: 0,
|
||||
},
|
||||
Receiver: domain.Receiver{
|
||||
Pubkey: pubkey1,
|
||||
Amount: 1000,
|
||||
},
|
||||
},
|
||||
{
|
||||
VtxoKey: domain.VtxoKey{
|
||||
Txid: txid,
|
||||
VOut: 1,
|
||||
},
|
||||
Receiver: domain.Receiver{
|
||||
Pubkey: pubkey1,
|
||||
Amount: 2000,
|
||||
},
|
||||
},
|
||||
}
|
||||
newVtxos := append(userVtxos, domain.Vtxo{
|
||||
VtxoKey: domain.VtxoKey{
|
||||
Txid: txid,
|
||||
VOut: 1,
|
||||
},
|
||||
Receiver: domain.Receiver{
|
||||
Pubkey: pubkey2,
|
||||
Amount: 2000,
|
||||
},
|
||||
})
|
||||
|
||||
vtxoKeys := make([]domain.VtxoKey, 0, len(userVtxos))
|
||||
for _, v := range userVtxos {
|
||||
vtxoKeys = append(vtxoKeys, v.VtxoKey)
|
||||
}
|
||||
|
||||
vtxos, err := svc.Vtxos().GetVtxos(ctx, vtxoKeys)
|
||||
require.Error(t, err)
|
||||
require.Empty(t, vtxos)
|
||||
|
||||
spendableVtxos, err := svc.Vtxos().GetSpendableVtxos(ctx, pubkey1)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, spendableVtxos)
|
||||
|
||||
spendableVtxos, err = svc.Vtxos().GetSpendableVtxos(ctx, "")
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, spendableVtxos)
|
||||
|
||||
err = svc.Vtxos().AddVtxos(ctx, newVtxos)
|
||||
require.NoError(t, err)
|
||||
|
||||
vtxos, err = svc.Vtxos().GetVtxos(ctx, vtxoKeys)
|
||||
require.NoError(t, err)
|
||||
require.Exactly(t, userVtxos, vtxos)
|
||||
|
||||
spendableVtxos, err = svc.Vtxos().GetSpendableVtxos(ctx, pubkey1)
|
||||
require.NoError(t, err)
|
||||
require.Exactly(t, vtxos, spendableVtxos)
|
||||
|
||||
spendableVtxos, err = svc.Vtxos().GetSpendableVtxos(ctx, "")
|
||||
require.NoError(t, err)
|
||||
require.Exactly(t, userVtxos, spendableVtxos)
|
||||
|
||||
err = svc.Vtxos().SpendVtxos(ctx, vtxoKeys[:1])
|
||||
require.NoError(t, err)
|
||||
|
||||
spentVtxos, err := svc.Vtxos().GetVtxos(ctx, vtxoKeys[:1])
|
||||
require.NoError(t, err)
|
||||
require.Len(t, spentVtxos, len(vtxoKeys[:1]))
|
||||
for _, v := range spentVtxos {
|
||||
require.True(t, v.Spent)
|
||||
}
|
||||
|
||||
spendableVtxos, err = svc.Vtxos().GetSpendableVtxos(ctx, pubkey1)
|
||||
require.NoError(t, err)
|
||||
require.Exactly(t, vtxos[1:], spendableVtxos)
|
||||
})
|
||||
}
|
||||
|
||||
func roundsMatch(expected, got domain.Round) assert.Comparison {
|
||||
return func() bool {
|
||||
if expected.Id != got.Id {
|
||||
return false
|
||||
}
|
||||
if expected.StartingTimestamp != got.StartingTimestamp {
|
||||
return false
|
||||
}
|
||||
if expected.EndingTimestamp != got.EndingTimestamp {
|
||||
return false
|
||||
}
|
||||
if expected.Stage != got.Stage {
|
||||
return false
|
||||
}
|
||||
if !reflect.DeepEqual(expected.Payments, got.Payments) {
|
||||
return false
|
||||
}
|
||||
if expected.Txid != got.Txid {
|
||||
return false
|
||||
}
|
||||
if expected.UnsignedTx != got.UnsignedTx {
|
||||
return false
|
||||
}
|
||||
if !reflect.DeepEqual(expected.ForfeitTxs, got.ForfeitTxs) {
|
||||
return false
|
||||
}
|
||||
if !reflect.DeepEqual(expected.CongestionTree, got.CongestionTree) {
|
||||
return false
|
||||
}
|
||||
if !reflect.DeepEqual(expected.Connectors, got.Connectors) {
|
||||
return false
|
||||
}
|
||||
if expected.Version != got.Version {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
19
server/internal/infrastructure/db/types/types.go
Normal file
19
server/internal/infrastructure/db/types/types.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package dbtypes
|
||||
|
||||
import "github.com/ark-network/ark/internal/core/domain"
|
||||
|
||||
type EventStore interface {
|
||||
domain.RoundEventRepository
|
||||
RegisterEventsHandler(func(*domain.Round))
|
||||
Close()
|
||||
}
|
||||
|
||||
type RoundStore interface {
|
||||
domain.RoundRepository
|
||||
Close()
|
||||
}
|
||||
|
||||
type VtxoStore interface {
|
||||
domain.VtxoRepository
|
||||
Close()
|
||||
}
|
||||
30
server/internal/infrastructure/ocean-wallet/account.go
Normal file
30
server/internal/infrastructure/ocean-wallet/account.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package oceanwallet
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
pb "github.com/ark-network/ark/api-spec/protobuf/gen/ocean/v1"
|
||||
"github.com/vulpemventures/go-elements/address"
|
||||
)
|
||||
|
||||
func (s *service) DeriveAddresses(
|
||||
ctx context.Context, numOfAddresses int,
|
||||
) ([]string, error) {
|
||||
res, err := s.accountClient.DeriveAddresses(ctx, &pb.DeriveAddressesRequest{
|
||||
AccountName: accountLabel,
|
||||
NumOfAddresses: uint64(numOfAddresses),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
addresses := make([]string, 0, numOfAddresses)
|
||||
for _, addr := range res.GetAddresses() {
|
||||
if isConf, _ := address.IsConfidential(addr); !isConf {
|
||||
addresses = append(addresses, addr)
|
||||
continue
|
||||
}
|
||||
info, _ := address.FromConfidential(addr)
|
||||
addresses = append(addresses, info.Address)
|
||||
}
|
||||
return addresses, nil
|
||||
}
|
||||
47
server/internal/infrastructure/ocean-wallet/notification.go
Normal file
47
server/internal/infrastructure/ocean-wallet/notification.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package oceanwallet
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
|
||||
pb "github.com/ark-network/ark/api-spec/protobuf/gen/ocean/v1"
|
||||
"github.com/ark-network/ark/internal/core/domain"
|
||||
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
||||
)
|
||||
|
||||
func (s *service) WatchScripts(ctx context.Context, scripts []string) error {
|
||||
for _, script := range scripts {
|
||||
if _, err := s.notifyClient.WatchExternalScript(ctx, &pb.WatchExternalScriptRequest{
|
||||
Script: script,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *service) UnwatchScripts(ctx context.Context, scripts []string) error {
|
||||
for _, script := range scripts {
|
||||
scriptHash := calcScriptHash(script)
|
||||
if _, err := s.notifyClient.UnwatchExternalScript(ctx, &pb.UnwatchExternalScriptRequest{
|
||||
Label: scriptHash,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *service) GetNotificationChannel(ctx context.Context) chan []domain.VtxoKey {
|
||||
return s.chVtxos
|
||||
}
|
||||
|
||||
func calcScriptHash(script string) string {
|
||||
buf, _ := hex.DecodeString(script)
|
||||
hashedBuf := sha256.Sum256(buf)
|
||||
hash, _ := chainhash.NewHash(hashedBuf[:])
|
||||
return hash.String()
|
||||
}
|
||||
140
server/internal/infrastructure/ocean-wallet/service.go
Normal file
140
server/internal/infrastructure/ocean-wallet/service.go
Normal file
@@ -0,0 +1,140 @@
|
||||
package oceanwallet
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
pb "github.com/ark-network/ark/api-spec/protobuf/gen/ocean/v1"
|
||||
"github.com/ark-network/ark/internal/core/domain"
|
||||
"github.com/ark-network/ark/internal/core/ports"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
type service struct {
|
||||
addr string
|
||||
conn *grpc.ClientConn
|
||||
walletClient pb.WalletServiceClient
|
||||
accountClient pb.AccountServiceClient
|
||||
txClient pb.TransactionServiceClient
|
||||
notifyClient pb.NotificationServiceClient
|
||||
chVtxos chan []domain.VtxoKey
|
||||
}
|
||||
|
||||
func NewService(addr string) (ports.WalletService, error) {
|
||||
conn, err := grpc.Dial(addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
walletClient := pb.NewWalletServiceClient(conn)
|
||||
accountClient := pb.NewAccountServiceClient(conn)
|
||||
txClient := pb.NewTransactionServiceClient(conn)
|
||||
notifyClient := pb.NewNotificationServiceClient(conn)
|
||||
chVtxos := make(chan []domain.VtxoKey)
|
||||
svc := &service{
|
||||
addr: addr,
|
||||
conn: conn,
|
||||
walletClient: walletClient,
|
||||
accountClient: accountClient,
|
||||
txClient: txClient,
|
||||
notifyClient: notifyClient,
|
||||
chVtxos: chVtxos,
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
status, err := svc.Status(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !(status.IsInitialized() && status.IsUnlocked()) {
|
||||
return nil, fmt.Errorf("wallet must be already initialized and unlocked")
|
||||
}
|
||||
|
||||
// Create ark account at startup if needed.
|
||||
info, err := walletClient.GetInfo(ctx, &pb.GetInfoRequest{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
found := false
|
||||
for _, account := range info.GetAccounts() {
|
||||
if account.GetLabel() == accountLabel {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
if _, err := accountClient.CreateAccountBIP44(ctx, &pb.CreateAccountBIP44Request{
|
||||
Label: accountLabel,
|
||||
Unconfidential: true,
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
go svc.listenToNotificaitons()
|
||||
|
||||
return svc, nil
|
||||
}
|
||||
|
||||
func (s *service) Close() {
|
||||
close(s.chVtxos)
|
||||
s.conn.Close()
|
||||
}
|
||||
|
||||
func (s *service) listenToNotificaitons() {
|
||||
var stream pb.NotificationService_UtxosNotificationsClient
|
||||
var err error
|
||||
for {
|
||||
stream, err = s.notifyClient.UtxosNotifications(context.Background(), &pb.UtxosNotificationsRequest{})
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
for {
|
||||
msg, err := stream.Recv()
|
||||
if err != nil {
|
||||
if err == io.EOF || status.Convert(err).Code() == codes.Canceled {
|
||||
return
|
||||
}
|
||||
log.WithError(err).Warn("received unexpected error from source")
|
||||
return
|
||||
}
|
||||
|
||||
if msg.GetEventType() != pb.UtxoEventType_UTXO_EVENT_TYPE_NEW &&
|
||||
msg.GetEventType() != pb.UtxoEventType_UTXO_EVENT_TYPE_CONFIRMED {
|
||||
continue
|
||||
}
|
||||
vtxos := toVtxos(msg.GetUtxos())
|
||||
if len(vtxos) > 0 {
|
||||
go func() {
|
||||
s.chVtxos <- vtxos
|
||||
}()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func toVtxos(utxos []*pb.Utxo) []domain.VtxoKey {
|
||||
vtxos := make([]domain.VtxoKey, 0, len(utxos))
|
||||
for _, utxo := range utxos {
|
||||
// We want to notify for activity related to vtxos owner, therefore we skip
|
||||
// returning anything related to the internal accounts of the wallet, like
|
||||
// for example bip84-account0.
|
||||
if strings.HasPrefix(utxo.GetAccountName(), "bip") {
|
||||
continue
|
||||
}
|
||||
|
||||
vtxos = append(vtxos, domain.VtxoKey{
|
||||
Txid: utxo.GetTxid(),
|
||||
VOut: utxo.GetIndex(),
|
||||
})
|
||||
}
|
||||
return vtxos
|
||||
}
|
||||
270
server/internal/infrastructure/ocean-wallet/transaction.go
Normal file
270
server/internal/infrastructure/ocean-wallet/transaction.go
Normal file
@@ -0,0 +1,270 @@
|
||||
package oceanwallet
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
pb "github.com/ark-network/ark/api-spec/protobuf/gen/ocean/v1"
|
||||
"github.com/ark-network/ark/common/tree"
|
||||
"github.com/ark-network/ark/internal/core/ports"
|
||||
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
||||
"github.com/btcsuite/btcd/txscript"
|
||||
"github.com/vulpemventures/go-elements/elementsutil"
|
||||
"github.com/vulpemventures/go-elements/psetv2"
|
||||
)
|
||||
|
||||
const (
|
||||
zero32 = "0000000000000000000000000000000000000000000000000000000000000000"
|
||||
)
|
||||
|
||||
func (s *service) SignPset(
|
||||
ctx context.Context, pset string, extractRawTx bool,
|
||||
) (string, error) {
|
||||
res, err := s.txClient.SignPset(ctx, &pb.SignPsetRequest{
|
||||
Pset: pset,
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
signedPset := res.GetPset()
|
||||
|
||||
if !extractRawTx {
|
||||
return signedPset, nil
|
||||
}
|
||||
|
||||
ptx, err := psetv2.NewPsetFromBase64(signedPset)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if err := psetv2.MaybeFinalizeAll(ptx); err != nil {
|
||||
return "", fmt.Errorf("failed to finalize signed pset: %s", err)
|
||||
}
|
||||
|
||||
extractedTx, err := psetv2.Extract(ptx)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to extract signed pset: %s", err)
|
||||
}
|
||||
|
||||
txHex, err := extractedTx.ToHex()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to convert extracted tx to hex: %s", err)
|
||||
}
|
||||
|
||||
return txHex, nil
|
||||
}
|
||||
|
||||
func (s *service) SelectUtxos(ctx context.Context, asset string, amount uint64) ([]ports.TxInput, uint64, error) {
|
||||
res, err := s.txClient.SelectUtxos(ctx, &pb.SelectUtxosRequest{
|
||||
AccountName: accountLabel,
|
||||
TargetAsset: asset,
|
||||
TargetAmount: amount,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
inputs := make([]ports.TxInput, 0, len(res.GetUtxos()))
|
||||
for _, utxo := range res.GetUtxos() {
|
||||
// check that the utxos are not confidential
|
||||
if utxo.GetAssetBlinder() != zero32 || utxo.GetValueBlinder() != zero32 {
|
||||
return nil, 0, fmt.Errorf("utxo is confidential")
|
||||
}
|
||||
|
||||
inputs = append(inputs, utxo)
|
||||
}
|
||||
|
||||
return inputs, res.GetChange(), nil
|
||||
}
|
||||
|
||||
func (s *service) GetTransaction(
|
||||
ctx context.Context, txid string,
|
||||
) (string, int64, error) {
|
||||
res, err := s.txClient.GetTransaction(ctx, &pb.GetTransactionRequest{
|
||||
Txid: txid,
|
||||
})
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
|
||||
if res.GetBlockDetails().GetTimestamp() > 0 {
|
||||
return res.GetTxHex(), res.BlockDetails.GetTimestamp(), nil
|
||||
}
|
||||
|
||||
// if not confirmed, we return now + 30 secs to estimate the next blocktime
|
||||
return res.GetTxHex(), time.Now().Unix() + 30, nil
|
||||
}
|
||||
|
||||
func (s *service) BroadcastTransaction(
|
||||
ctx context.Context, txHex string,
|
||||
) (string, error) {
|
||||
res, err := s.txClient.BroadcastTransaction(
|
||||
ctx, &pb.BroadcastTransactionRequest{
|
||||
TxHex: txHex,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "non-BIP68-final") {
|
||||
return "", fmt.Errorf("non-BIP68-final")
|
||||
}
|
||||
|
||||
return "", err
|
||||
}
|
||||
return res.GetTxid(), nil
|
||||
}
|
||||
|
||||
func (s *service) IsTransactionPublished(
|
||||
ctx context.Context, txid string,
|
||||
) (bool, int64, error) {
|
||||
_, blocktime, err := s.GetTransaction(ctx, txid)
|
||||
if err != nil {
|
||||
if strings.Contains(strings.ToLower(err.Error()), "missing transaction") {
|
||||
return false, 0, nil
|
||||
}
|
||||
return false, 0, err
|
||||
}
|
||||
|
||||
return true, blocktime, nil
|
||||
}
|
||||
|
||||
func (s *service) SignPsetWithKey(ctx context.Context, b64 string, indexes []int) (string, error) {
|
||||
pset, err := psetv2.NewPsetFromBase64(b64)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if indexes == nil {
|
||||
for i := 0; i < len(pset.Inputs); i++ {
|
||||
indexes = append(indexes, i)
|
||||
}
|
||||
}
|
||||
|
||||
key, masterKey, err := s.getPubkey(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
fingerprint := binary.LittleEndian.Uint32(masterKey.FingerPrint)
|
||||
extendedKey, err := masterKey.Serialize()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
pset.Global.Xpubs = []psetv2.Xpub{{
|
||||
ExtendedKey: extendedKey[:len(extendedKey)-4],
|
||||
MasterFingerprint: fingerprint,
|
||||
DerivationPath: derivationPath,
|
||||
}}
|
||||
|
||||
updater, err := psetv2.NewUpdater(pset)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
bip32derivation := psetv2.DerivationPathWithPubKey{
|
||||
PubKey: key.SerializeCompressed(),
|
||||
MasterKeyFingerprint: fingerprint,
|
||||
Bip32Path: derivationPath,
|
||||
}
|
||||
|
||||
for _, i := range indexes {
|
||||
if len(pset.Inputs[i].TapLeafScript) == 0 {
|
||||
return "", fmt.Errorf("no tap leaf script found for input %d", i)
|
||||
}
|
||||
|
||||
leafHash := pset.Inputs[i].TapLeafScript[0].TapHash()
|
||||
|
||||
if err := updater.AddInTapBip32Derivation(i, psetv2.TapDerivationPathWithPubKey{
|
||||
DerivationPathWithPubKey: bip32derivation,
|
||||
LeafHashes: [][]byte{leafHash[:]},
|
||||
}); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if err := updater.AddInSighashType(i, txscript.SigHashDefault); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
unsignedPset, err := pset.ToBase64()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
signedPset, err := s.txClient.SignPsetWithSchnorrKey(ctx, &pb.SignPsetWithSchnorrKeyRequest{
|
||||
Tx: unsignedPset,
|
||||
SighashType: uint32(txscript.SigHashDefault),
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return signedPset.GetSignedTx(), nil
|
||||
}
|
||||
|
||||
func (s *service) EstimateFees(
|
||||
ctx context.Context, pset string,
|
||||
) (uint64, error) {
|
||||
tx, err := psetv2.NewPsetFromBase64(pset)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
inputs := make([]*pb.Input, 0, len(tx.Inputs))
|
||||
outputs := make([]*pb.Output, 0, len(tx.Outputs))
|
||||
|
||||
for _, in := range tx.Inputs {
|
||||
pbInput := &pb.Input{
|
||||
Txid: chainhash.Hash(in.PreviousTxid).String(),
|
||||
Index: in.PreviousTxIndex,
|
||||
}
|
||||
|
||||
if len(in.TapLeafScript) == 1 {
|
||||
isSweep, _, _, err := tree.DecodeSweepScript(in.TapLeafScript[0].Script)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if isSweep {
|
||||
pbInput.WitnessSize = 64
|
||||
pbInput.ScriptsigSize = 0
|
||||
}
|
||||
} else {
|
||||
if in.WitnessUtxo == nil {
|
||||
return 0, fmt.Errorf("missing witness utxo, cannot estimate fees")
|
||||
}
|
||||
|
||||
pbInput.Script = hex.EncodeToString(in.WitnessUtxo.Script)
|
||||
}
|
||||
|
||||
inputs = append(inputs, pbInput)
|
||||
}
|
||||
|
||||
for _, out := range tx.Outputs {
|
||||
outputs = append(outputs, &pb.Output{
|
||||
Asset: elementsutil.AssetHashFromBytes(
|
||||
append([]byte{0x01}, out.Asset...),
|
||||
),
|
||||
Amount: out.Value,
|
||||
Script: hex.EncodeToString(out.Script),
|
||||
})
|
||||
}
|
||||
|
||||
fee, err := s.txClient.EstimateFees(
|
||||
ctx,
|
||||
&pb.EstimateFeesRequest{
|
||||
Inputs: inputs,
|
||||
Outputs: outputs,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to estimate fees: %s", err)
|
||||
}
|
||||
|
||||
// we add 5 sats in order to avoid min-relay-fee not met errors
|
||||
return fee.GetFeeAmount() + 5, nil
|
||||
}
|
||||
95
server/internal/infrastructure/ocean-wallet/wallet.go
Normal file
95
server/internal/infrastructure/ocean-wallet/wallet.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package oceanwallet
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
pb "github.com/ark-network/ark/api-spec/protobuf/gen/ocean/v1"
|
||||
"github.com/ark-network/ark/internal/core/ports"
|
||||
"github.com/btcsuite/btcd/btcutil/hdkeychain"
|
||||
"github.com/decred/dcrd/dcrec/secp256k1/v4"
|
||||
"github.com/vulpemventures/go-bip32"
|
||||
)
|
||||
|
||||
const accountLabel = "ark"
|
||||
|
||||
var derivationPath = []uint32{0, 0}
|
||||
|
||||
func (s *service) GetPubkey(ctx context.Context) (*secp256k1.PublicKey, error) {
|
||||
key, _, err := s.getPubkey(ctx)
|
||||
return key, err
|
||||
}
|
||||
|
||||
func (s *service) Status(
|
||||
ctx context.Context,
|
||||
) (ports.WalletStatus, error) {
|
||||
res, err := s.walletClient.Status(ctx, &pb.StatusRequest{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return walletStatus{res}, nil
|
||||
}
|
||||
|
||||
type walletStatus struct {
|
||||
*pb.StatusResponse
|
||||
}
|
||||
|
||||
func (w walletStatus) IsInitialized() bool {
|
||||
return w.StatusResponse.GetInitialized()
|
||||
}
|
||||
func (w walletStatus) IsUnlocked() bool {
|
||||
return w.StatusResponse.GetUnlocked()
|
||||
}
|
||||
func (w walletStatus) IsSynced() bool {
|
||||
return w.StatusResponse.GetSynced()
|
||||
}
|
||||
|
||||
func (s *service) findAccount(ctx context.Context, label string) (*pb.AccountInfo, error) {
|
||||
res, err := s.walletClient.GetInfo(ctx, &pb.GetInfoRequest{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(res.GetAccounts()) <= 0 {
|
||||
return nil, fmt.Errorf("wallet is locked")
|
||||
}
|
||||
|
||||
for _, account := range res.GetAccounts() {
|
||||
if account.GetLabel() == label {
|
||||
return account, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("account not found")
|
||||
}
|
||||
|
||||
func (s *service) getPubkey(ctx context.Context) (*secp256k1.PublicKey, *bip32.Key, error) {
|
||||
account, err := s.findAccount(ctx, accountLabel)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
xpub := account.GetXpubs()[0]
|
||||
|
||||
node, err := hdkeychain.NewKeyFromString(xpub)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
for _, i := range derivationPath {
|
||||
node, err = node.Derive(i)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
key, err := node.ECPubKey()
|
||||
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
masterKey, err := bip32.B58Deserialize(xpub)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return key, masterKey, nil
|
||||
}
|
||||
45
server/internal/infrastructure/scheduler/gocron/service.go
Normal file
45
server/internal/infrastructure/scheduler/gocron/service.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package scheduler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/ark-network/ark/internal/core/ports"
|
||||
"github.com/go-co-op/gocron"
|
||||
)
|
||||
|
||||
type service struct {
|
||||
scheduler *gocron.Scheduler
|
||||
}
|
||||
|
||||
func NewScheduler() ports.SchedulerService {
|
||||
svc := gocron.NewScheduler(time.UTC)
|
||||
return &service{svc}
|
||||
}
|
||||
|
||||
func (s *service) Start() {
|
||||
s.scheduler.StartAsync()
|
||||
}
|
||||
|
||||
func (s *service) Stop() {
|
||||
s.scheduler.Stop()
|
||||
}
|
||||
|
||||
func (s *service) ScheduleTask(interval int64, immediate bool, task func()) error {
|
||||
if immediate {
|
||||
_, err := s.scheduler.Every(int(interval)).Seconds().Do(task)
|
||||
return err
|
||||
}
|
||||
_, err := s.scheduler.Every(int(interval)).Seconds().WaitForSchedule().Do(task)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *service) ScheduleTaskOnce(at int64, task func()) error {
|
||||
delay := at - time.Now().Unix()
|
||||
if delay < 0 {
|
||||
return fmt.Errorf("cannot schedule task in the past")
|
||||
}
|
||||
|
||||
_, err := s.scheduler.Every(int(delay)).Seconds().WaitForSchedule().LimitRunsTo(1).Do(task)
|
||||
return err
|
||||
}
|
||||
483
server/internal/infrastructure/tx-builder/covenant/builder.go
Normal file
483
server/internal/infrastructure/tx-builder/covenant/builder.go
Normal file
@@ -0,0 +1,483 @@
|
||||
package txbuilder
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
|
||||
"github.com/ark-network/ark/common/tree"
|
||||
"github.com/ark-network/ark/internal/core/domain"
|
||||
"github.com/ark-network/ark/internal/core/ports"
|
||||
"github.com/decred/dcrd/dcrec/secp256k1/v4"
|
||||
"github.com/vulpemventures/go-elements/address"
|
||||
"github.com/vulpemventures/go-elements/network"
|
||||
"github.com/vulpemventures/go-elements/psetv2"
|
||||
"github.com/vulpemventures/go-elements/taproot"
|
||||
)
|
||||
|
||||
const (
|
||||
connectorAmount = uint64(450)
|
||||
dustLimit = uint64(450)
|
||||
)
|
||||
|
||||
type txBuilder struct {
|
||||
wallet ports.WalletService
|
||||
net *network.Network
|
||||
roundLifetime int64 // in seconds
|
||||
}
|
||||
|
||||
func NewTxBuilder(
|
||||
wallet ports.WalletService, net network.Network, roundLifetime int64,
|
||||
) ports.TxBuilder {
|
||||
return &txBuilder{wallet, &net, roundLifetime}
|
||||
}
|
||||
|
||||
func (b *txBuilder) GetVtxoScript(userPubkey, aspPubkey *secp256k1.PublicKey) ([]byte, error) {
|
||||
outputScript, _, err := b.getLeafScriptAndTree(userPubkey, aspPubkey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return outputScript, nil
|
||||
}
|
||||
|
||||
func (b *txBuilder) BuildSweepTx(wallet ports.WalletService, inputs []ports.SweepInput) (signedSweepTx string, err error) {
|
||||
sweepPset, err := sweepTransaction(
|
||||
wallet,
|
||||
inputs,
|
||||
b.net.AssetID,
|
||||
)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
sweepPsetBase64, err := sweepPset.ToBase64()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
signedSweepPsetB64, err := wallet.SignPsetWithKey(ctx, sweepPsetBase64, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
signedPset, err := psetv2.NewPsetFromBase64(signedSweepPsetB64)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if err := psetv2.FinalizeAll(signedPset); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
extractedTx, err := psetv2.Extract(signedPset)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return extractedTx.ToHex()
|
||||
}
|
||||
|
||||
func (b *txBuilder) BuildForfeitTxs(
|
||||
aspPubkey *secp256k1.PublicKey, poolTx string, payments []domain.Payment,
|
||||
) (connectors []string, forfeitTxs []string, err error) {
|
||||
connectorTxs, err := b.createConnectors(poolTx, payments, aspPubkey)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
forfeitTxs, err = b.createForfeitTxs(aspPubkey, payments, connectorTxs)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
for _, tx := range connectorTxs {
|
||||
buf, _ := tx.ToBase64()
|
||||
connectors = append(connectors, buf)
|
||||
}
|
||||
return connectors, forfeitTxs, nil
|
||||
}
|
||||
|
||||
func (b *txBuilder) BuildPoolTx(
|
||||
aspPubkey *secp256k1.PublicKey, payments []domain.Payment, minRelayFee uint64,
|
||||
) (poolTx string, congestionTree tree.CongestionTree, err error) {
|
||||
// The creation of the tree and the pool tx are tightly coupled:
|
||||
// - building the tree requires knowing the shared outpoint (txid:vout)
|
||||
// - building the pool tx requires knowing the shared output script and amount
|
||||
// The idea here is to first create all the data for the outputs of the txs
|
||||
// of the congestion tree to calculate the shared output script and amount.
|
||||
// With these data the pool tx can be created, and once the shared utxo
|
||||
// outpoint is obtained, the congestion tree can be finally created.
|
||||
// The factory function `treeFactoryFn` returned below holds all outputs data
|
||||
// generated in the process and takes the shared utxo outpoint as argument.
|
||||
// This is safe as the memory allocated for `craftCongestionTree` is freed
|
||||
// only after `BuildPoolTx` returns.
|
||||
treeFactoryFn, sharedOutputScript, sharedOutputAmount, err := craftCongestionTree(
|
||||
b.net.AssetID, aspPubkey, payments, minRelayFee, b.roundLifetime,
|
||||
)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ptx, err := b.createPoolTx(
|
||||
sharedOutputAmount, sharedOutputScript, payments, aspPubkey,
|
||||
)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
unsignedTx, err := ptx.UnsignedTx()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
tree, err := treeFactoryFn(psetv2.InputArgs{
|
||||
Txid: unsignedTx.TxHash().String(),
|
||||
TxIndex: 0,
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
poolTx, err = ptx.ToBase64()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
congestionTree = tree
|
||||
return
|
||||
}
|
||||
|
||||
func (b *txBuilder) GetLeafSweepClosure(
|
||||
node tree.Node, userPubKey *secp256k1.PublicKey,
|
||||
) (*psetv2.TapLeafScript, int64, error) {
|
||||
if !node.Leaf {
|
||||
return nil, 0, fmt.Errorf("node is not a leaf")
|
||||
}
|
||||
|
||||
pset, err := psetv2.NewPsetFromBase64(node.Tx)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
input := pset.Inputs[0]
|
||||
|
||||
sweepLeaf, lifetime, err := extractSweepLeaf(input)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// craft the vtxo taproot tree
|
||||
vtxoScript, err := tree.VtxoScript(userPubKey)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
vtxoTaprootTree := taproot.AssembleTaprootScriptTree(
|
||||
*vtxoScript,
|
||||
sweepLeaf.TapElementsLeaf,
|
||||
)
|
||||
|
||||
proofIndex := vtxoTaprootTree.LeafProofIndex[sweepLeaf.TapHash()]
|
||||
proof := vtxoTaprootTree.LeafMerkleProofs[proofIndex]
|
||||
|
||||
return &psetv2.TapLeafScript{
|
||||
TapElementsLeaf: proof.TapElementsLeaf,
|
||||
ControlBlock: proof.ToControlBlock(sweepLeaf.ControlBlock.InternalKey),
|
||||
}, lifetime, nil
|
||||
}
|
||||
|
||||
func (b *txBuilder) getLeafScriptAndTree(
|
||||
userPubkey, aspPubkey *secp256k1.PublicKey,
|
||||
) ([]byte, *taproot.IndexedElementsTapScriptTree, error) {
|
||||
redeemClosure, err := tree.VtxoScript(userPubkey)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
sweepClosure, err := tree.SweepScript(aspPubkey, uint(b.roundLifetime))
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
taprootTree := taproot.AssembleTaprootScriptTree(
|
||||
*redeemClosure, *sweepClosure,
|
||||
)
|
||||
|
||||
root := taprootTree.RootNode.TapHash()
|
||||
unspendableKey := tree.UnspendableKey()
|
||||
taprootKey := taproot.ComputeTaprootOutputKey(unspendableKey, root[:])
|
||||
|
||||
outputScript, err := taprootOutputScript(taprootKey)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return outputScript, taprootTree, nil
|
||||
}
|
||||
|
||||
func (b *txBuilder) createPoolTx(
|
||||
sharedOutputAmount uint64, sharedOutputScript []byte,
|
||||
payments []domain.Payment, aspPubKey *secp256k1.PublicKey,
|
||||
) (*psetv2.Pset, error) {
|
||||
aspScript, err := p2wpkhScript(aspPubKey, b.net)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
receivers := getOnchainReceivers(payments)
|
||||
connectorsAmount := connectorAmount * countSpentVtxos(payments)
|
||||
targetAmount := sharedOutputAmount + connectorsAmount
|
||||
|
||||
outputs := []psetv2.OutputArgs{
|
||||
{
|
||||
Asset: b.net.AssetID,
|
||||
Amount: sharedOutputAmount,
|
||||
Script: sharedOutputScript,
|
||||
},
|
||||
{
|
||||
Asset: b.net.AssetID,
|
||||
Amount: connectorsAmount,
|
||||
Script: aspScript,
|
||||
},
|
||||
}
|
||||
|
||||
for _, receiver := range receivers {
|
||||
targetAmount += receiver.Amount
|
||||
|
||||
receiverScript, err := address.ToOutputScript(receiver.OnchainAddress)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
outputs = append(outputs, psetv2.OutputArgs{
|
||||
Asset: b.net.AssetID,
|
||||
Amount: receiver.Amount,
|
||||
Script: receiverScript,
|
||||
})
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
utxos, change, err := b.wallet.SelectUtxos(ctx, b.net.AssetID, targetAmount)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var dust uint64
|
||||
if change > 0 {
|
||||
if change < dustLimit {
|
||||
dust = change
|
||||
change = 0
|
||||
} else {
|
||||
outputs = append(outputs, psetv2.OutputArgs{
|
||||
Asset: b.net.AssetID,
|
||||
Amount: change,
|
||||
Script: aspScript,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
ptx, err := psetv2.New(nil, outputs, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
updater, err := psetv2.NewUpdater(ptx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := addInputs(updater, utxos); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
b64, err := ptx.ToBase64()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
feeAmount, err := b.wallet.EstimateFees(ctx, b64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if dust > feeAmount {
|
||||
feeAmount = dust
|
||||
} else {
|
||||
feeAmount += dust
|
||||
}
|
||||
|
||||
if dust == 0 {
|
||||
if feeAmount == change {
|
||||
// fees = change, remove change output
|
||||
ptx.Outputs = ptx.Outputs[:len(ptx.Outputs)-1]
|
||||
} else if feeAmount < change {
|
||||
// change covers the fees, reduce change amount
|
||||
ptx.Outputs[len(ptx.Outputs)-1].Value = change - feeAmount
|
||||
} else {
|
||||
// change is not enough to cover fees, re-select utxos
|
||||
if change > 0 {
|
||||
// remove change output if present
|
||||
ptx.Outputs = ptx.Outputs[:len(ptx.Outputs)-1]
|
||||
}
|
||||
newUtxos, change, err := b.wallet.SelectUtxos(ctx, b.net.AssetID, feeAmount-change)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if change > 0 {
|
||||
if err := updater.AddOutputs([]psetv2.OutputArgs{
|
||||
{
|
||||
Asset: b.net.AssetID,
|
||||
Amount: change,
|
||||
Script: aspScript,
|
||||
},
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if err := addInputs(updater, newUtxos); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// add fee output
|
||||
if err := updater.AddOutputs([]psetv2.OutputArgs{
|
||||
{
|
||||
Asset: b.net.AssetID,
|
||||
Amount: feeAmount,
|
||||
},
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ptx, nil
|
||||
}
|
||||
|
||||
func (b *txBuilder) createConnectors(
|
||||
poolTx string, payments []domain.Payment, aspPubkey *secp256k1.PublicKey,
|
||||
) ([]*psetv2.Pset, error) {
|
||||
txid, _ := getTxid(poolTx)
|
||||
|
||||
aspScript, err := p2wpkhScript(aspPubkey, b.net)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
connectorOutput := psetv2.OutputArgs{
|
||||
Asset: b.net.AssetID,
|
||||
Script: aspScript,
|
||||
Amount: connectorAmount,
|
||||
}
|
||||
|
||||
numberOfConnectors := countSpentVtxos(payments)
|
||||
|
||||
previousInput := psetv2.InputArgs{
|
||||
Txid: txid,
|
||||
TxIndex: 1,
|
||||
}
|
||||
|
||||
if numberOfConnectors == 1 {
|
||||
outputs := []psetv2.OutputArgs{connectorOutput}
|
||||
connectorTx, err := craftConnectorTx(previousInput, outputs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return []*psetv2.Pset{connectorTx}, nil
|
||||
}
|
||||
|
||||
totalConnectorAmount := connectorAmount * numberOfConnectors
|
||||
|
||||
connectors := make([]*psetv2.Pset, 0, numberOfConnectors-1)
|
||||
for i := uint64(0); i < numberOfConnectors-1; i++ {
|
||||
outputs := []psetv2.OutputArgs{connectorOutput}
|
||||
totalConnectorAmount -= connectorAmount
|
||||
if totalConnectorAmount > 0 {
|
||||
outputs = append(outputs, psetv2.OutputArgs{
|
||||
Asset: b.net.AssetID,
|
||||
Script: aspScript,
|
||||
Amount: totalConnectorAmount,
|
||||
})
|
||||
}
|
||||
connectorTx, err := craftConnectorTx(previousInput, outputs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
txid, _ := getPsetId(connectorTx)
|
||||
|
||||
previousInput = psetv2.InputArgs{
|
||||
Txid: txid,
|
||||
TxIndex: 1,
|
||||
}
|
||||
|
||||
connectors = append(connectors, connectorTx)
|
||||
}
|
||||
|
||||
return connectors, nil
|
||||
}
|
||||
|
||||
func (b *txBuilder) createForfeitTxs(
|
||||
aspPubkey *secp256k1.PublicKey, payments []domain.Payment, connectors []*psetv2.Pset,
|
||||
) ([]string, error) {
|
||||
aspScript, err := p2wpkhScript(aspPubkey, b.net)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
forfeitTxs := make([]string, 0)
|
||||
for _, payment := range payments {
|
||||
for _, vtxo := range payment.Inputs {
|
||||
pubkeyBytes, err := hex.DecodeString(vtxo.Pubkey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode pubkey: %s", err)
|
||||
}
|
||||
|
||||
vtxoPubkey, err := secp256k1.ParsePubKey(pubkeyBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
vtxoScript, vtxoTaprootTree, err := b.getLeafScriptAndTree(vtxoPubkey, aspPubkey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, connector := range connectors {
|
||||
txs, err := craftForfeitTxs(
|
||||
connector, vtxo, vtxoTaprootTree, vtxoScript, aspScript,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
forfeitTxs = append(forfeitTxs, txs...)
|
||||
}
|
||||
}
|
||||
}
|
||||
return forfeitTxs, nil
|
||||
}
|
||||
|
||||
// given a congestion tree input, searches and returns the sweep leaf and its lifetime in seconds
|
||||
func extractSweepLeaf(input psetv2.Input) (sweepLeaf *psetv2.TapLeafScript, lifetime int64, err error) {
|
||||
for _, leaf := range input.TapLeafScript {
|
||||
isSweep, _, seconds, err := tree.DecodeSweepScript(leaf.Script)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if isSweep {
|
||||
lifetime = int64(seconds)
|
||||
sweepLeaf = &leaf
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if sweepLeaf == nil {
|
||||
return nil, 0, fmt.Errorf("sweep leaf not found")
|
||||
}
|
||||
|
||||
return sweepLeaf, lifetime, nil
|
||||
}
|
||||
@@ -0,0 +1,229 @@
|
||||
package txbuilder_test
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/ark-network/ark/common"
|
||||
"github.com/ark-network/ark/common/tree"
|
||||
"github.com/ark-network/ark/internal/core/domain"
|
||||
"github.com/ark-network/ark/internal/core/ports"
|
||||
txbuilder "github.com/ark-network/ark/internal/infrastructure/tx-builder/covenant"
|
||||
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
||||
"github.com/decred/dcrd/dcrec/secp256k1/v4"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/vulpemventures/go-elements/network"
|
||||
"github.com/vulpemventures/go-elements/psetv2"
|
||||
)
|
||||
|
||||
const (
|
||||
testingKey = "apub1qgvdtj5ttpuhkldavhq8thtm5auyk0ec4dcmrfdgu0u5hgp9we22v3hrs4x"
|
||||
minRelayFee = uint64(30)
|
||||
roundLifetime = int64(1209344)
|
||||
)
|
||||
|
||||
var (
|
||||
wallet *mockedWallet
|
||||
pubkey *secp256k1.PublicKey
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
wallet = &mockedWallet{}
|
||||
wallet.On("EstimateFees", mock.Anything, mock.Anything).
|
||||
Return(uint64(100), nil)
|
||||
wallet.On("SelectUtxos", mock.Anything, mock.Anything, mock.Anything).
|
||||
Return(randomInput, uint64(0), nil)
|
||||
|
||||
_, pubkey, _ = common.DecodePubKey(testingKey)
|
||||
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
func TestBuildPoolTx(t *testing.T) {
|
||||
builder := txbuilder.NewTxBuilder(wallet, network.Liquid, roundLifetime)
|
||||
|
||||
fixtures, err := parsePoolTxFixtures()
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, fixtures)
|
||||
|
||||
if len(fixtures.Valid) > 0 {
|
||||
t.Run("valid", func(t *testing.T) {
|
||||
for _, f := range fixtures.Valid {
|
||||
poolTx, congestionTree, err := builder.BuildPoolTx(pubkey, f.Payments, minRelayFee)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, poolTx)
|
||||
require.NotEmpty(t, congestionTree)
|
||||
require.Equal(t, f.ExpectedNumOfNodes, congestionTree.NumberOfNodes())
|
||||
require.Len(t, congestionTree.Leaves(), f.ExpectedNumOfLeaves)
|
||||
|
||||
err = tree.ValidateCongestionTree(congestionTree, poolTx, pubkey, roundLifetime)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
if len(fixtures.Invalid) > 0 {
|
||||
t.Run("invalid", func(t *testing.T) {
|
||||
for _, f := range fixtures.Invalid {
|
||||
poolTx, congestionTree, err := builder.BuildPoolTx(pubkey, f.Payments, minRelayFee)
|
||||
require.EqualError(t, err, f.ExpectedErr)
|
||||
require.Empty(t, poolTx)
|
||||
require.Empty(t, congestionTree)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildForfeitTxs(t *testing.T) {
|
||||
builder := txbuilder.NewTxBuilder(wallet, network.Liquid, 1209344)
|
||||
|
||||
fixtures, err := parseForfeitTxsFixtures()
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, fixtures)
|
||||
|
||||
if len(fixtures.Valid) > 0 {
|
||||
t.Run("valid", func(t *testing.T) {
|
||||
for _, f := range fixtures.Valid {
|
||||
connectors, forfeitTxs, err := builder.BuildForfeitTxs(
|
||||
pubkey, f.PoolTx, f.Payments,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, connectors, f.ExpectedNumOfConnectors)
|
||||
require.Len(t, forfeitTxs, f.ExpectedNumOfForfeitTxs)
|
||||
|
||||
expectedInputTxid := f.PoolTxid
|
||||
// Verify the chain of connectors
|
||||
for _, connector := range connectors {
|
||||
tx, err := psetv2.NewPsetFromBase64(connector)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, tx)
|
||||
|
||||
require.Len(t, tx.Inputs, 1)
|
||||
require.Len(t, tx.Outputs, 2)
|
||||
|
||||
inputTxid := chainhash.Hash(tx.Inputs[0].PreviousTxid).String()
|
||||
require.Equal(t, expectedInputTxid, inputTxid)
|
||||
require.Equal(t, 1, int(tx.Inputs[0].PreviousTxIndex))
|
||||
|
||||
expectedInputTxid = getTxid(tx)
|
||||
}
|
||||
|
||||
// decode and check forfeit txs
|
||||
for _, forfeitTx := range forfeitTxs {
|
||||
tx, err := psetv2.NewPsetFromBase64(forfeitTx)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, tx.Inputs, 2)
|
||||
require.Len(t, tx.Outputs, 2)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
if len(fixtures.Invalid) > 0 {
|
||||
t.Run("invalid", func(t *testing.T) {
|
||||
for _, f := range fixtures.Invalid {
|
||||
connectors, forfeitTxs, err := builder.BuildForfeitTxs(
|
||||
pubkey, f.PoolTx, f.Payments,
|
||||
)
|
||||
require.EqualError(t, err, f.ExpectedErr)
|
||||
require.Empty(t, connectors)
|
||||
require.Empty(t, forfeitTxs)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func randomInput() []ports.TxInput {
|
||||
txid := randomHex(32)
|
||||
input := &mockedInput{}
|
||||
input.On("GetAsset").Return("5ac9f65c0efcc4775e0baec4ec03abdde22473cd3cf33c0419ca290e0751b225")
|
||||
input.On("GetValue").Return(uint64(1000))
|
||||
input.On("GetScript").Return("a914ea9f486e82efb3dd83a69fd96e3f0113757da03c87")
|
||||
input.On("GetTxid").Return(txid)
|
||||
input.On("GetIndex").Return(uint32(0))
|
||||
|
||||
return []ports.TxInput{input}
|
||||
}
|
||||
|
||||
func randomHex(len int) string {
|
||||
buf := make([]byte, len)
|
||||
// nolint
|
||||
rand.Read(buf)
|
||||
return hex.EncodeToString(buf)
|
||||
}
|
||||
|
||||
type poolTxFixtures struct {
|
||||
Valid []struct {
|
||||
Payments []domain.Payment
|
||||
ExpectedNumOfNodes int
|
||||
ExpectedNumOfLeaves int
|
||||
}
|
||||
Invalid []struct {
|
||||
Payments []domain.Payment
|
||||
ExpectedErr string
|
||||
}
|
||||
}
|
||||
|
||||
func parsePoolTxFixtures() (*poolTxFixtures, error) {
|
||||
file, err := os.ReadFile("testdata/fixtures.json")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
v := map[string]interface{}{}
|
||||
if err := json.Unmarshal(file, &v); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
vv := v["buildPoolTx"].(map[string]interface{})
|
||||
file, _ = json.Marshal(vv)
|
||||
var fixtures poolTxFixtures
|
||||
if err := json.Unmarshal(file, &fixtures); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &fixtures, nil
|
||||
}
|
||||
|
||||
type forfeitTxsFixtures struct {
|
||||
Valid []struct {
|
||||
Payments []domain.Payment
|
||||
ExpectedNumOfConnectors int
|
||||
ExpectedNumOfForfeitTxs int
|
||||
PoolTx string
|
||||
PoolTxid string
|
||||
}
|
||||
Invalid []struct {
|
||||
Payments []domain.Payment
|
||||
ExpectedErr string
|
||||
PoolTx string
|
||||
}
|
||||
}
|
||||
|
||||
func parseForfeitTxsFixtures() (*forfeitTxsFixtures, error) {
|
||||
file, err := os.ReadFile("testdata/fixtures.json")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
v := map[string]interface{}{}
|
||||
if err := json.Unmarshal(file, &v); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
vv := v["buildForfeitTxs"].(map[string]interface{})
|
||||
file, _ = json.Marshal(vv)
|
||||
var fixtures forfeitTxsFixtures
|
||||
if err := json.Unmarshal(file, &fixtures); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &fixtures, nil
|
||||
}
|
||||
|
||||
func getTxid(tx *psetv2.Pset) string {
|
||||
utx, _ := tx.UnsignedTx()
|
||||
return utx.TxHash().String()
|
||||
}
|
||||
@@ -0,0 +1,48 @@
|
||||
package txbuilder
|
||||
|
||||
import (
|
||||
"github.com/vulpemventures/go-elements/psetv2"
|
||||
"github.com/vulpemventures/go-elements/transaction"
|
||||
)
|
||||
|
||||
func craftConnectorTx(
|
||||
input psetv2.InputArgs, outputs []psetv2.OutputArgs,
|
||||
) (*psetv2.Pset, error) {
|
||||
ptx, _ := psetv2.New(nil, nil, nil)
|
||||
updater, _ := psetv2.NewUpdater(ptx)
|
||||
|
||||
if err := updater.AddInputs(
|
||||
[]psetv2.InputArgs{input},
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// TODO: add prevout.
|
||||
|
||||
if err := updater.AddOutputs(outputs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ptx, nil
|
||||
}
|
||||
|
||||
func getConnectorInputs(pset *psetv2.Pset) ([]psetv2.InputArgs, []*transaction.TxOutput) {
|
||||
txID, _ := getPsetId(pset)
|
||||
|
||||
inputs := make([]psetv2.InputArgs, 0, len(pset.Outputs))
|
||||
witnessUtxos := make([]*transaction.TxOutput, 0, len(pset.Outputs))
|
||||
|
||||
for i, output := range pset.Outputs {
|
||||
utx, _ := pset.UnsignedTx()
|
||||
|
||||
if output.Value == connectorAmount && len(output.Script) > 0 {
|
||||
inputs = append(inputs, psetv2.InputArgs{
|
||||
Txid: txID,
|
||||
TxIndex: uint32(i),
|
||||
})
|
||||
witnessUtxos = append(witnessUtxos, utx.Outputs[i])
|
||||
}
|
||||
}
|
||||
|
||||
return inputs, witnessUtxos
|
||||
}
|
||||
104
server/internal/infrastructure/tx-builder/covenant/forfeit.go
Normal file
104
server/internal/infrastructure/tx-builder/covenant/forfeit.go
Normal file
@@ -0,0 +1,104 @@
|
||||
package txbuilder
|
||||
|
||||
import (
|
||||
"github.com/ark-network/ark/common/tree"
|
||||
"github.com/ark-network/ark/internal/core/domain"
|
||||
"github.com/btcsuite/btcd/txscript"
|
||||
"github.com/vulpemventures/go-elements/elementsutil"
|
||||
"github.com/vulpemventures/go-elements/psetv2"
|
||||
"github.com/vulpemventures/go-elements/taproot"
|
||||
"github.com/vulpemventures/go-elements/transaction"
|
||||
)
|
||||
|
||||
func craftForfeitTxs(
|
||||
connectorTx *psetv2.Pset,
|
||||
vtxo domain.Vtxo,
|
||||
vtxoTaprootTree *taproot.IndexedElementsTapScriptTree,
|
||||
vtxoScript, aspScript []byte,
|
||||
) (forfeitTxs []string, err error) {
|
||||
connectors, prevouts := getConnectorInputs(connectorTx)
|
||||
|
||||
for i, connectorInput := range connectors {
|
||||
connectorPrevout := prevouts[i]
|
||||
asset := elementsutil.AssetHashFromBytes(connectorPrevout.Asset)
|
||||
|
||||
pset, err := psetv2.New(nil, nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
updater, err := psetv2.NewUpdater(pset)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
vtxoInput := psetv2.InputArgs{
|
||||
Txid: vtxo.Txid,
|
||||
TxIndex: vtxo.VOut,
|
||||
}
|
||||
|
||||
vtxoAmount, _ := elementsutil.ValueToBytes(vtxo.Amount)
|
||||
vtxoPrevout := &transaction.TxOutput{
|
||||
Asset: connectorPrevout.Asset,
|
||||
Value: vtxoAmount,
|
||||
Script: vtxoScript,
|
||||
}
|
||||
|
||||
if err := updater.AddInputs([]psetv2.InputArgs{connectorInput, vtxoInput}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = updater.AddInWitnessUtxo(0, connectorPrevout); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := updater.AddInSighashType(0, txscript.SigHashAll); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = updater.AddInWitnessUtxo(1, vtxoPrevout); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := updater.AddInSighashType(1, txscript.SigHashDefault); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
unspendableKey := tree.UnspendableKey()
|
||||
|
||||
for _, proof := range vtxoTaprootTree.LeafMerkleProofs {
|
||||
tapScript := psetv2.NewTapLeafScript(proof, unspendableKey)
|
||||
if err := updater.AddInTapLeafScript(1, tapScript); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
connectorAmount, err := elementsutil.ValueFromBytes(connectorPrevout.Value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = updater.AddOutputs([]psetv2.OutputArgs{
|
||||
{
|
||||
Asset: asset,
|
||||
Amount: vtxo.Amount + connectorAmount - 30,
|
||||
Script: aspScript,
|
||||
},
|
||||
{
|
||||
Asset: asset,
|
||||
Amount: 30,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tx, err := pset.ToBase64()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
forfeitTxs = append(forfeitTxs, tx)
|
||||
}
|
||||
return forfeitTxs, nil
|
||||
}
|
||||
203
server/internal/infrastructure/tx-builder/covenant/mocks_test.go
Normal file
203
server/internal/infrastructure/tx-builder/covenant/mocks_test.go
Normal file
@@ -0,0 +1,203 @@
|
||||
package txbuilder_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/ark-network/ark/internal/core/domain"
|
||||
"github.com/ark-network/ark/internal/core/ports"
|
||||
"github.com/decred/dcrd/dcrec/secp256k1/v4"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
type mockedWallet struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
// BroadcastTransaction implements ports.WalletService.
|
||||
func (m *mockedWallet) BroadcastTransaction(ctx context.Context, txHex string) (string, error) {
|
||||
args := m.Called(ctx, txHex)
|
||||
|
||||
var res string
|
||||
if a := args.Get(0); a != nil {
|
||||
res = a.(string)
|
||||
}
|
||||
return res, args.Error(1)
|
||||
}
|
||||
|
||||
// Close implements ports.WalletService.
|
||||
func (m *mockedWallet) Close() {
|
||||
m.Called()
|
||||
}
|
||||
|
||||
// DeriveAddresses implements ports.WalletService.
|
||||
func (m *mockedWallet) DeriveAddresses(ctx context.Context, num int) ([]string, error) {
|
||||
args := m.Called(ctx, num)
|
||||
|
||||
var res []string
|
||||
if a := args.Get(0); a != nil {
|
||||
res = a.([]string)
|
||||
}
|
||||
return res, args.Error(1)
|
||||
}
|
||||
|
||||
// GetPubkey implements ports.WalletService.
|
||||
func (m *mockedWallet) GetPubkey(ctx context.Context) (*secp256k1.PublicKey, error) {
|
||||
args := m.Called(ctx)
|
||||
|
||||
var res *secp256k1.PublicKey
|
||||
if a := args.Get(0); a != nil {
|
||||
res = a.(*secp256k1.PublicKey)
|
||||
}
|
||||
return res, args.Error(1)
|
||||
}
|
||||
|
||||
// SignPset implements ports.WalletService.
|
||||
func (m *mockedWallet) SignPset(ctx context.Context, pset string, extractRawTx bool) (string, error) {
|
||||
args := m.Called(ctx, pset, extractRawTx)
|
||||
|
||||
var res string
|
||||
if a := args.Get(0); a != nil {
|
||||
res = a.(string)
|
||||
}
|
||||
return res, args.Error(1)
|
||||
}
|
||||
|
||||
// Status implements ports.WalletService.
|
||||
func (m *mockedWallet) Status(ctx context.Context) (ports.WalletStatus, error) {
|
||||
args := m.Called(ctx)
|
||||
|
||||
var res ports.WalletStatus
|
||||
if a := args.Get(0); a != nil {
|
||||
res = a.(ports.WalletStatus)
|
||||
}
|
||||
return res, args.Error(1)
|
||||
}
|
||||
|
||||
func (m *mockedWallet) SelectUtxos(ctx context.Context, asset string, amount uint64) ([]ports.TxInput, uint64, error) {
|
||||
args := m.Called(ctx, asset, amount)
|
||||
|
||||
var res0 func() []ports.TxInput
|
||||
if a := args.Get(0); a != nil {
|
||||
res0 = a.(func() []ports.TxInput)
|
||||
}
|
||||
var res1 uint64
|
||||
if a := args.Get(1); a != nil {
|
||||
res1 = a.(uint64)
|
||||
}
|
||||
return res0(), res1, args.Error(2)
|
||||
}
|
||||
|
||||
func (m *mockedWallet) EstimateFees(ctx context.Context, pset string) (uint64, error) {
|
||||
args := m.Called(ctx, pset)
|
||||
|
||||
var res uint64
|
||||
if a := args.Get(0); a != nil {
|
||||
res = a.(uint64)
|
||||
}
|
||||
return res, args.Error(1)
|
||||
}
|
||||
|
||||
func (m *mockedWallet) IsTransactionPublished(ctx context.Context, txid string) (bool, int64, error) {
|
||||
args := m.Called(ctx, txid)
|
||||
|
||||
var res bool
|
||||
if a := args.Get(0); a != nil {
|
||||
res = a.(bool)
|
||||
}
|
||||
|
||||
var blocktime int64
|
||||
if b := args.Get(1); b != nil {
|
||||
blocktime = b.(int64)
|
||||
}
|
||||
|
||||
return res, blocktime, args.Error(2)
|
||||
}
|
||||
|
||||
func (m *mockedWallet) SignPsetWithKey(ctx context.Context, pset string, inputIndexes []int) (string, error) {
|
||||
args := m.Called(ctx, pset, inputIndexes)
|
||||
|
||||
var res string
|
||||
if a := args.Get(0); a != nil {
|
||||
res = a.(string)
|
||||
}
|
||||
return res, args.Error(1)
|
||||
}
|
||||
|
||||
func (m *mockedWallet) WatchScripts(
|
||||
ctx context.Context, scripts []string,
|
||||
) error {
|
||||
args := m.Called(ctx, scripts)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *mockedWallet) UnwatchScripts(
|
||||
ctx context.Context, scripts []string,
|
||||
) error {
|
||||
args := m.Called(ctx, scripts)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *mockedWallet) GetNotificationChannel(ctx context.Context) chan []domain.VtxoKey {
|
||||
args := m.Called(ctx)
|
||||
|
||||
var res chan []domain.VtxoKey
|
||||
if a := args.Get(0); a != nil {
|
||||
res = a.(chan []domain.VtxoKey)
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
type mockedInput struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *mockedInput) GetTxid() string {
|
||||
args := m.Called()
|
||||
|
||||
var res string
|
||||
if a := args.Get(0); a != nil {
|
||||
res = a.(string)
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
func (m *mockedInput) GetIndex() uint32 {
|
||||
args := m.Called()
|
||||
|
||||
var res uint32
|
||||
if a := args.Get(0); a != nil {
|
||||
res = a.(uint32)
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func (m *mockedInput) GetScript() string {
|
||||
args := m.Called()
|
||||
|
||||
var res string
|
||||
if a := args.Get(0); a != nil {
|
||||
res = a.(string)
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func (m *mockedInput) GetAsset() string {
|
||||
args := m.Called()
|
||||
|
||||
var res string
|
||||
if a := args.Get(0); a != nil {
|
||||
res = a.(string)
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func (m *mockedInput) GetValue() uint64 {
|
||||
args := m.Called()
|
||||
|
||||
var res uint64
|
||||
if a := args.Get(0); a != nil {
|
||||
res = a.(uint64)
|
||||
}
|
||||
return res
|
||||
}
|
||||
135
server/internal/infrastructure/tx-builder/covenant/sweep.go
Normal file
135
server/internal/infrastructure/tx-builder/covenant/sweep.go
Normal file
@@ -0,0 +1,135 @@
|
||||
package txbuilder
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/ark-network/ark/common"
|
||||
"github.com/ark-network/ark/common/tree"
|
||||
"github.com/ark-network/ark/internal/core/ports"
|
||||
"github.com/vulpemventures/go-elements/address"
|
||||
"github.com/vulpemventures/go-elements/elementsutil"
|
||||
"github.com/vulpemventures/go-elements/psetv2"
|
||||
"github.com/vulpemventures/go-elements/taproot"
|
||||
"github.com/vulpemventures/go-elements/transaction"
|
||||
)
|
||||
|
||||
func sweepTransaction(
|
||||
wallet ports.WalletService,
|
||||
sweepInputs []ports.SweepInput,
|
||||
lbtc string,
|
||||
) (*psetv2.Pset, error) {
|
||||
sweepPset, err := psetv2.New(nil, nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
updater, err := psetv2.NewUpdater(sweepPset)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
amount := uint64(0)
|
||||
|
||||
for i, input := range sweepInputs {
|
||||
leaf := input.SweepLeaf
|
||||
isSweep, _, lifetime, err := tree.DecodeSweepScript(leaf.Script)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if isSweep {
|
||||
amount += input.Amount
|
||||
|
||||
if err := updater.AddInputs([]psetv2.InputArgs{input.InputArgs}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := updater.AddInTapLeafScript(i, leaf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
assetHash, err := elementsutil.AssetHashToBytes(lbtc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
value, err := elementsutil.ValueToBytes(input.Amount)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
root := leaf.ControlBlock.RootHash(leaf.Script)
|
||||
taprootKey := taproot.ComputeTaprootOutputKey(leaf.ControlBlock.InternalKey, root)
|
||||
script, err := taprootOutputScript(taprootKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
witnessUtxo := transaction.NewTxOutput(assetHash, value, script)
|
||||
|
||||
if err := updater.AddInWitnessUtxo(i, witnessUtxo); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sequence, err := common.BIP68EncodeAsNumber(lifetime)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
updater.Pset.Inputs[i].Sequence = sequence
|
||||
continue
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("invalid sweep script")
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
sweepAddress, err := wallet.DeriveAddresses(ctx, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
script, err := address.ToOutputScript(sweepAddress[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := updater.AddOutputs([]psetv2.OutputArgs{
|
||||
{
|
||||
Asset: lbtc,
|
||||
Amount: amount,
|
||||
Script: script,
|
||||
},
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
b64, err := sweepPset.ToBase64()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fees, err := wallet.EstimateFees(ctx, b64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if amount < fees {
|
||||
return nil, fmt.Errorf("insufficient funds (%d) to cover fees (%d) for sweep transaction", amount, fees)
|
||||
}
|
||||
|
||||
updater.Pset.Outputs[0].Value = amount - fees
|
||||
|
||||
if err := updater.AddOutputs([]psetv2.OutputArgs{
|
||||
{
|
||||
Asset: lbtc,
|
||||
Amount: fees,
|
||||
},
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return sweepPset, nil
|
||||
}
|
||||
229
server/internal/infrastructure/tx-builder/covenant/testdata/fixtures.json
vendored
Normal file
229
server/internal/infrastructure/tx-builder/covenant/testdata/fixtures.json
vendored
Normal file
@@ -0,0 +1,229 @@
|
||||
{
|
||||
"buildPoolTx": {
|
||||
"valid": [
|
||||
{
|
||||
"payments": [
|
||||
{
|
||||
"id": "0",
|
||||
"inputs": [
|
||||
{
|
||||
"txid": "fd68e3c5796cc7db0a8036d486d5f625b6b2f2c014810ac020e1ac23e82c59d6",
|
||||
"vout": 0,
|
||||
"pubkey": "020000000000000000000000000000000000000000000000000000000000000002",
|
||||
"amount": 1100
|
||||
}
|
||||
],
|
||||
"receivers": [
|
||||
{
|
||||
"pubkey": "020000000000000000000000000000000000000000000000000000000000000002",
|
||||
"amount": 1100
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"expectedNumOfNodes": 1,
|
||||
"expectedNumOfLeaves": 1
|
||||
},
|
||||
{
|
||||
"payments": [
|
||||
{
|
||||
"id": "0",
|
||||
"inputs": [
|
||||
{
|
||||
"txid": "fd68e3c5796cc7db0a8036d486d5f625b6b2f2c014810ac020e1ac23e82c59d6",
|
||||
"vout": 0,
|
||||
"pubkey": "020000000000000000000000000000000000000000000000000000000000000002",
|
||||
"amount": 1100
|
||||
}
|
||||
],
|
||||
"receivers": [
|
||||
{
|
||||
"pubkey": "020000000000000000000000000000000000000000000000000000000000000002",
|
||||
"amount": 600
|
||||
},
|
||||
{
|
||||
"pubkey": "020000000000000000000000000000000000000000000000000000000000000002",
|
||||
"amount": 500
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"expectedNumOfNodes": 3,
|
||||
"expectedNumOfLeaves": 2
|
||||
},
|
||||
{
|
||||
"payments": [
|
||||
{
|
||||
"id": "0",
|
||||
"inputs": [
|
||||
{
|
||||
"txid": "fd68e3c5796cc7db0a8036d486d5f625b6b2f2c014810ac020e1ac23e82c59d6",
|
||||
"vout": 0,
|
||||
"pubkey": "020000000000000000000000000000000000000000000000000000000000000002",
|
||||
"amount": 1100
|
||||
}
|
||||
],
|
||||
"receivers": [
|
||||
{
|
||||
"pubkey": "020000000000000000000000000000000000000000000000000000000000000002",
|
||||
"amount": 600
|
||||
},
|
||||
{
|
||||
"pubkey": "020000000000000000000000000000000000000000000000000000000000000002",
|
||||
"amount": 500
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "0",
|
||||
"inputs": [
|
||||
{
|
||||
"txid": "fd68e3c5796cc7db0a8036d486d5f625b6b2f2c014810ac020e1ac23e82c59d6",
|
||||
"vout": 0,
|
||||
"pubkey": "020000000000000000000000000000000000000000000000000000000000000002",
|
||||
"amount": 1100
|
||||
}
|
||||
],
|
||||
"receivers": [
|
||||
{
|
||||
"pubkey": "020000000000000000000000000000000000000000000000000000000000000002",
|
||||
"amount": 600
|
||||
},
|
||||
{
|
||||
"pubkey": "020000000000000000000000000000000000000000000000000000000000000002",
|
||||
"amount": 500
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "0",
|
||||
"inputs": [
|
||||
{
|
||||
"txid": "fd68e3c5796cc7db0a8036d486d5f625b6b2f2c014810ac020e1ac23e82c59d6",
|
||||
"vout": 0,
|
||||
"pubkey": "020000000000000000000000000000000000000000000000000000000000000002",
|
||||
"amount": 1100
|
||||
}
|
||||
],
|
||||
"receivers": [
|
||||
{
|
||||
"pubkey": "020000000000000000000000000000000000000000000000000000000000000002",
|
||||
"amount": 600
|
||||
},
|
||||
{
|
||||
"pubkey": "020000000000000000000000000000000000000000000000000000000000000002",
|
||||
"amount": 500
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"expectedNumOfNodes": 11,
|
||||
"expectedNumOfLeaves": 6
|
||||
},
|
||||
{
|
||||
"payments": [
|
||||
{
|
||||
"id": "a242cdd8-f3d5-46c0-ae98-94135a2bee3f",
|
||||
"inputs": [
|
||||
{
|
||||
"txid": "755c820771284d85ea4bbcc246565b4eddadc44237a7e57a0f9cb78a840d1d41",
|
||||
"vout": 0,
|
||||
"pubkey": "02c87e5c1758df5ad42a918ec507b6e8dfcdcebf22f64f58eb4ad5804257d658a5",
|
||||
"amount": 1000
|
||||
},
|
||||
{
|
||||
"txid": "66a0df86fcdeb84b8877adfe0b2c556dba30305d72ddbd4c49355f6930355357",
|
||||
"vout": 0,
|
||||
"pubkey": "02c87e5c1758df5ad42a918ec507b6e8dfcdcebf22f64f58eb4ad5804257d658a5",
|
||||
"amount": 1000
|
||||
},
|
||||
{
|
||||
"txid": "9913159bc7aa493ca53cbb9cbc88f97ba01137c814009dc7ef520c3fafc67909",
|
||||
"vout": 1,
|
||||
"pubkey": "02c87e5c1758df5ad42a918ec507b6e8dfcdcebf22f64f58eb4ad5804257d658a5",
|
||||
"amount": 500
|
||||
},
|
||||
{
|
||||
"txid": "5e10e77a7cdedc153be5193a4b6055a7802706ded4f2a9efefe86ed2f9a6ae60",
|
||||
"vout": 0,
|
||||
"pubkey": "02c87e5c1758df5ad42a918ec507b6e8dfcdcebf22f64f58eb4ad5804257d658a5",
|
||||
"amount": 1000
|
||||
},
|
||||
{
|
||||
"txid": "5e10e77a7cdedc153be5193a4b6055a7802706ded4f2a9efefe86ed2f9a6ae60",
|
||||
"vout": 1,
|
||||
"pubkey": "02c87e5c1758df5ad42a918ec507b6e8dfcdcebf22f64f58eb4ad5804257d658a5",
|
||||
"amount": 1000
|
||||
}
|
||||
],
|
||||
"receivers": [
|
||||
{
|
||||
"pubkey": "02c87e5c1758df5ad42a918ec507b6e8dfcdcebf22f64f58eb4ad5804257d658a5",
|
||||
"amount": 1000
|
||||
},
|
||||
{
|
||||
"pubkey": "02c87e5c1758df5ad42a918ec507b6e8dfcdcebf22f64f58eb4ad5804257d658a5",
|
||||
"amount": 1000
|
||||
},
|
||||
{
|
||||
"pubkey": "02c87e5c1758df5ad42a918ec507b6e8dfcdcebf22f64f58eb4ad5804257d658a5",
|
||||
"amount": 1000
|
||||
},
|
||||
{
|
||||
"pubkey": "02c87e5c1758df5ad42a918ec507b6e8dfcdcebf22f64f58eb4ad5804257d658a5",
|
||||
"amount": 1000
|
||||
},
|
||||
{
|
||||
"pubkey": "02c87e5c1758df5ad42a918ec507b6e8dfcdcebf22f64f58eb4ad5804257d658a5",
|
||||
"amount": 500
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"expectedNumOfNodes": 9,
|
||||
"expectedNumOfLeaves": 5
|
||||
}
|
||||
],
|
||||
"invalid": []
|
||||
},
|
||||
"buildForfeitTxs": {
|
||||
"valid": [
|
||||
{
|
||||
"payments": [
|
||||
{
|
||||
"id": "0",
|
||||
"inputs": [
|
||||
{
|
||||
"txid": "fd68e3c5796cc7db0a8036d486d5f625b6b2f2c014810ac020e1ac23e82c59d6",
|
||||
"vout": 0,
|
||||
"pubkey": "020000000000000000000000000000000000000000000000000000000000000002",
|
||||
"amount": 600
|
||||
},
|
||||
{
|
||||
"txid": "fd68e3c5796cc7db0a8036d486d5f625b6b2f2c014810ac020e1ac23e82c59d6",
|
||||
"vout": 1,
|
||||
"pubkey": "020000000000000000000000000000000000000000000000000000000000000002",
|
||||
"amount": 500
|
||||
}
|
||||
],
|
||||
"receivers": [
|
||||
{
|
||||
"pubkey": "020000000000000000000000000000000000000000000000000000000000000002",
|
||||
"amount": 600
|
||||
},
|
||||
{
|
||||
"pubkey": "020000000000000000000000000000000000000000000000000000000000000002",
|
||||
"amount": 500
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"poolTx": "cHNldP8BAgQCAAAAAQQBAQEFAQMBBgEDAfsEAgAAAAABDiDk7dXxh4KQzgLO8i1ABtaLCe4aPL12GVhN1E9zM1ePLwEPBAAAAAABEAT/////AAEDCOgDAAAAAAAAAQQWABSNnpy01UJqd99eTg2M1IpdKId11gf8BHBzZXQCICWyUQcOKcoZBDzzPM1zJOLdqwPsxK4LXnfE/A5c9slaB/wEcHNldAgEAAAAAAABAwh4BQAAAAAAAAEEFgAUjZ6ctNVCanffXk4NjNSKXSiHddYH/ARwc2V0AiAlslEHDinKGQQ88zzNcyTi3asD7MSuC153xPwOXPbJWgf8BHBzZXQIBAAAAAAAAQMI9AEAAAAAAAABBAAH/ARwc2V0AiAlslEHDinKGQQ88zzNcyTi3asD7MSuC153xPwOXPbJWgf8BHBzZXQIBAAAAAAA",
|
||||
"poolTxid": "7981fce656f266472cc742444527cb32a8bed8c90fed6d47adbfc4c8780d4d9a",
|
||||
"expectedNumOfForfeitTxs": 4,
|
||||
"expectedNumOfConnectors": 1
|
||||
}
|
||||
],
|
||||
"invalid": []
|
||||
}
|
||||
}
|
||||
448
server/internal/infrastructure/tx-builder/covenant/tree.go
Normal file
448
server/internal/infrastructure/tx-builder/covenant/tree.go
Normal file
@@ -0,0 +1,448 @@
|
||||
package txbuilder
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
|
||||
"github.com/ark-network/ark/common/tree"
|
||||
"github.com/ark-network/ark/internal/core/domain"
|
||||
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
||||
"github.com/decred/dcrd/dcrec/secp256k1/v4"
|
||||
"github.com/vulpemventures/go-elements/psetv2"
|
||||
"github.com/vulpemventures/go-elements/taproot"
|
||||
)
|
||||
|
||||
type treeFactory func(outpoint psetv2.InputArgs) (tree.CongestionTree, error)
|
||||
|
||||
type node struct {
|
||||
sweepKey *secp256k1.PublicKey
|
||||
receivers []domain.Receiver
|
||||
left *node
|
||||
right *node
|
||||
asset string
|
||||
feeSats uint64
|
||||
roundLifetime int64
|
||||
|
||||
_inputTaprootKey *secp256k1.PublicKey
|
||||
_inputTaprootTree *taproot.IndexedElementsTapScriptTree
|
||||
}
|
||||
|
||||
func (n *node) isLeaf() bool {
|
||||
return len(n.receivers) == 1
|
||||
}
|
||||
|
||||
func (n *node) getAmount() uint64 {
|
||||
var amount uint64
|
||||
for _, r := range n.receivers {
|
||||
amount += r.Amount
|
||||
}
|
||||
|
||||
if n.isLeaf() {
|
||||
return amount
|
||||
}
|
||||
|
||||
return amount + n.feeSats*uint64(n.countChildren())
|
||||
}
|
||||
|
||||
func (n *node) countChildren() int {
|
||||
result := 0
|
||||
|
||||
if n.left != nil {
|
||||
result++
|
||||
result += n.left.countChildren()
|
||||
}
|
||||
|
||||
if n.right != nil {
|
||||
result++
|
||||
result += n.right.countChildren()
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func (n *node) getChildren() []*node {
|
||||
if n.isLeaf() {
|
||||
return nil
|
||||
}
|
||||
|
||||
children := make([]*node, 0, 2)
|
||||
|
||||
if n.left != nil {
|
||||
children = append(children, n.left)
|
||||
}
|
||||
|
||||
if n.right != nil {
|
||||
children = append(children, n.right)
|
||||
}
|
||||
|
||||
return children
|
||||
}
|
||||
|
||||
func (n *node) getOutputs() ([]psetv2.OutputArgs, error) {
|
||||
if n.isLeaf() {
|
||||
taprootKey, _, err := n.getVtxoWitnessData()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
script, err := taprootOutputScript(taprootKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
output := &psetv2.OutputArgs{
|
||||
Asset: n.asset,
|
||||
Amount: uint64(n.getAmount()),
|
||||
Script: script,
|
||||
}
|
||||
|
||||
return []psetv2.OutputArgs{*output}, nil
|
||||
}
|
||||
|
||||
outputs := make([]psetv2.OutputArgs, 0, 2)
|
||||
children := n.getChildren()
|
||||
|
||||
for _, child := range children {
|
||||
childWitnessProgram, _, err := child.getWitnessData()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
script, err := taprootOutputScript(childWitnessProgram)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
outputs = append(outputs, psetv2.OutputArgs{
|
||||
Asset: n.asset,
|
||||
Amount: child.getAmount() + child.feeSats,
|
||||
Script: script,
|
||||
})
|
||||
}
|
||||
|
||||
return outputs, nil
|
||||
}
|
||||
|
||||
func (n *node) getWitnessData() (
|
||||
*secp256k1.PublicKey, *taproot.IndexedElementsTapScriptTree, error,
|
||||
) {
|
||||
if n._inputTaprootKey != nil && n._inputTaprootTree != nil {
|
||||
return n._inputTaprootKey, n._inputTaprootTree, nil
|
||||
}
|
||||
|
||||
sweepClosure, err := tree.SweepScript(n.sweepKey, uint(n.roundLifetime))
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if n.isLeaf() {
|
||||
taprootKey, _, err := n.getVtxoWitnessData()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
branchTaprootScript := tree.BranchScript(
|
||||
taprootKey, nil, n.getAmount(), 0,
|
||||
)
|
||||
|
||||
branchTaprootTree := taproot.AssembleTaprootScriptTree(
|
||||
branchTaprootScript, *sweepClosure,
|
||||
)
|
||||
root := branchTaprootTree.RootNode.TapHash()
|
||||
|
||||
inputTapkey := taproot.ComputeTaprootOutputKey(
|
||||
tree.UnspendableKey(),
|
||||
root[:],
|
||||
)
|
||||
|
||||
n._inputTaprootKey = inputTapkey
|
||||
n._inputTaprootTree = branchTaprootTree
|
||||
|
||||
return inputTapkey, branchTaprootTree, nil
|
||||
}
|
||||
|
||||
leftKey, _, err := n.left.getWitnessData()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
rightKey, _, err := n.right.getWitnessData()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
leftAmount := n.left.getAmount() + n.feeSats
|
||||
rightAmount := n.right.getAmount() + n.feeSats
|
||||
branchTaprootLeaf := tree.BranchScript(
|
||||
leftKey, rightKey, leftAmount, rightAmount,
|
||||
)
|
||||
|
||||
branchTaprootTree := taproot.AssembleTaprootScriptTree(
|
||||
branchTaprootLeaf, *sweepClosure,
|
||||
)
|
||||
root := branchTaprootTree.RootNode.TapHash()
|
||||
|
||||
taprootKey := taproot.ComputeTaprootOutputKey(
|
||||
tree.UnspendableKey(),
|
||||
root[:],
|
||||
)
|
||||
|
||||
n._inputTaprootKey = taprootKey
|
||||
n._inputTaprootTree = branchTaprootTree
|
||||
|
||||
return taprootKey, branchTaprootTree, nil
|
||||
}
|
||||
|
||||
func (n *node) getVtxoWitnessData() (
|
||||
*secp256k1.PublicKey, *taproot.IndexedElementsTapScriptTree, error,
|
||||
) {
|
||||
if !n.isLeaf() {
|
||||
return nil, nil, fmt.Errorf("cannot call vtxoWitness on a non-leaf node")
|
||||
}
|
||||
|
||||
sweepClosure, err := tree.SweepScript(n.sweepKey, uint(n.roundLifetime))
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
key, err := hex.DecodeString(n.receivers[0].Pubkey)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
pubkey, err := secp256k1.ParsePubKey(key)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
vtxoLeaf, err := tree.VtxoScript(pubkey)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// TODO: add forfeit path
|
||||
leafTaprootTree := taproot.AssembleTaprootScriptTree(
|
||||
*vtxoLeaf, *sweepClosure,
|
||||
)
|
||||
root := leafTaprootTree.RootNode.TapHash()
|
||||
|
||||
taprootKey := taproot.ComputeTaprootOutputKey(
|
||||
tree.UnspendableKey(),
|
||||
root[:],
|
||||
)
|
||||
|
||||
return taprootKey, leafTaprootTree, nil
|
||||
}
|
||||
|
||||
func (n *node) getTreeNode(
|
||||
input psetv2.InputArgs, tapTree *taproot.IndexedElementsTapScriptTree,
|
||||
) (tree.Node, error) {
|
||||
pset, err := n.getTx(input, tapTree)
|
||||
if err != nil {
|
||||
return tree.Node{}, err
|
||||
}
|
||||
|
||||
txid, err := getPsetId(pset)
|
||||
if err != nil {
|
||||
return tree.Node{}, err
|
||||
}
|
||||
|
||||
tx, err := pset.ToBase64()
|
||||
if err != nil {
|
||||
return tree.Node{}, err
|
||||
}
|
||||
|
||||
parentTxid := chainhash.Hash(pset.Inputs[0].PreviousTxid).String()
|
||||
|
||||
return tree.Node{
|
||||
Txid: txid,
|
||||
Tx: tx,
|
||||
ParentTxid: parentTxid,
|
||||
Leaf: n.isLeaf(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (n *node) getTx(
|
||||
input psetv2.InputArgs, inputTapTree *taproot.IndexedElementsTapScriptTree,
|
||||
) (*psetv2.Pset, error) {
|
||||
pset, err := psetv2.New(nil, nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
updater, err := psetv2.NewUpdater(pset)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := addTaprootInput(
|
||||
updater, input, tree.UnspendableKey(), inputTapTree,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
feeOutput := psetv2.OutputArgs{
|
||||
Amount: uint64(n.feeSats),
|
||||
Asset: n.asset,
|
||||
}
|
||||
|
||||
outputs, err := n.getOutputs()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := updater.AddOutputs(append(outputs, feeOutput)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return pset, nil
|
||||
}
|
||||
|
||||
func (n *node) createFinalCongestionTree() treeFactory {
|
||||
return func(poolTxInput psetv2.InputArgs) (tree.CongestionTree, error) {
|
||||
congestionTree := make(tree.CongestionTree, 0)
|
||||
|
||||
_, taprootTree, err := n.getWitnessData()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ins := []psetv2.InputArgs{poolTxInput}
|
||||
inTrees := []*taproot.IndexedElementsTapScriptTree{taprootTree}
|
||||
nodes := []*node{n}
|
||||
|
||||
for len(nodes) > 0 {
|
||||
nextNodes := make([]*node, 0)
|
||||
nextInputsArgs := make([]psetv2.InputArgs, 0)
|
||||
nextTaprootTrees := make([]*taproot.IndexedElementsTapScriptTree, 0)
|
||||
|
||||
treeLevel := make([]tree.Node, 0)
|
||||
|
||||
for i, node := range nodes {
|
||||
treeNode, err := node.getTreeNode(ins[i], inTrees[i])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
treeLevel = append(treeLevel, treeNode)
|
||||
|
||||
children := node.getChildren()
|
||||
|
||||
for i, child := range children {
|
||||
_, taprootTree, err := child.getWitnessData()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nextNodes = append(nextNodes, child)
|
||||
nextInputsArgs = append(nextInputsArgs, psetv2.InputArgs{
|
||||
Txid: treeNode.Txid,
|
||||
TxIndex: uint32(i),
|
||||
})
|
||||
nextTaprootTrees = append(nextTaprootTrees, taprootTree)
|
||||
}
|
||||
}
|
||||
|
||||
congestionTree = append(congestionTree, treeLevel)
|
||||
nodes = append([]*node{}, nextNodes...)
|
||||
ins = append([]psetv2.InputArgs{}, nextInputsArgs...)
|
||||
inTrees = append(
|
||||
[]*taproot.IndexedElementsTapScriptTree{}, nextTaprootTrees...,
|
||||
)
|
||||
}
|
||||
|
||||
return congestionTree, nil
|
||||
}
|
||||
}
|
||||
|
||||
func craftCongestionTree(
|
||||
asset string, aspPublicKey *secp256k1.PublicKey,
|
||||
payments []domain.Payment, feeSatsPerNode uint64, roundLifetime int64,
|
||||
) (
|
||||
buildCongestionTree treeFactory,
|
||||
sharedOutputScript []byte, sharedOutputAmount uint64, err error,
|
||||
) {
|
||||
receivers := getOffchainReceivers(payments)
|
||||
root, err := createPartialCongestionTree(
|
||||
receivers, aspPublicKey, asset, feeSatsPerNode, roundLifetime,
|
||||
)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
taprootKey, _, err := root.getWitnessData()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
sharedOutputScript, err = taprootOutputScript(taprootKey)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
sharedOutputAmount = root.getAmount() + root.feeSats
|
||||
buildCongestionTree = root.createFinalCongestionTree()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func createPartialCongestionTree(
|
||||
receivers []domain.Receiver,
|
||||
aspPublicKey *secp256k1.PublicKey,
|
||||
asset string,
|
||||
feeSatsPerNode uint64,
|
||||
roundLifetime int64,
|
||||
) (root *node, err error) {
|
||||
if len(receivers) == 0 {
|
||||
return nil, fmt.Errorf("no receivers provided")
|
||||
}
|
||||
|
||||
nodes := make([]*node, 0, len(receivers))
|
||||
for _, r := range receivers {
|
||||
leafNode := &node{
|
||||
sweepKey: aspPublicKey,
|
||||
receivers: []domain.Receiver{r},
|
||||
asset: asset,
|
||||
feeSats: feeSatsPerNode,
|
||||
roundLifetime: roundLifetime,
|
||||
}
|
||||
nodes = append(nodes, leafNode)
|
||||
}
|
||||
|
||||
for len(nodes) > 1 {
|
||||
nodes, err = createUpperLevel(nodes)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
return nodes[0], nil
|
||||
}
|
||||
|
||||
func createUpperLevel(nodes []*node) ([]*node, error) {
|
||||
if len(nodes)%2 != 0 {
|
||||
last := nodes[len(nodes)-1]
|
||||
pairs, err := createUpperLevel(nodes[:len(nodes)-1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return append(pairs, last), nil
|
||||
}
|
||||
|
||||
pairs := make([]*node, 0, len(nodes)/2)
|
||||
for i := 0; i < len(nodes); i += 2 {
|
||||
left := nodes[i]
|
||||
right := nodes[i+1]
|
||||
branchNode := &node{
|
||||
sweepKey: left.sweepKey,
|
||||
receivers: append(left.receivers, right.receivers...),
|
||||
left: left,
|
||||
right: right,
|
||||
asset: left.asset,
|
||||
feeSats: left.feeSats,
|
||||
roundLifetime: left.roundLifetime,
|
||||
}
|
||||
pairs = append(pairs, branchNode)
|
||||
}
|
||||
return pairs, nil
|
||||
}
|
||||
167
server/internal/infrastructure/tx-builder/covenant/utils.go
Normal file
167
server/internal/infrastructure/tx-builder/covenant/utils.go
Normal file
@@ -0,0 +1,167 @@
|
||||
package txbuilder
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
|
||||
"github.com/ark-network/ark/internal/core/domain"
|
||||
"github.com/ark-network/ark/internal/core/ports"
|
||||
"github.com/btcsuite/btcd/btcec/v2/schnorr"
|
||||
"github.com/btcsuite/btcd/txscript"
|
||||
"github.com/decred/dcrd/dcrec/secp256k1/v4"
|
||||
"github.com/vulpemventures/go-elements/address"
|
||||
"github.com/vulpemventures/go-elements/elementsutil"
|
||||
"github.com/vulpemventures/go-elements/network"
|
||||
"github.com/vulpemventures/go-elements/payment"
|
||||
"github.com/vulpemventures/go-elements/psetv2"
|
||||
"github.com/vulpemventures/go-elements/taproot"
|
||||
"github.com/vulpemventures/go-elements/transaction"
|
||||
)
|
||||
|
||||
func p2wpkhScript(publicKey *secp256k1.PublicKey, net *network.Network) ([]byte, error) {
|
||||
payment := payment.FromPublicKey(publicKey, net, nil)
|
||||
addr, err := payment.WitnessPubKeyHash()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return address.ToOutputScript(addr)
|
||||
}
|
||||
|
||||
func getTxid(txStr string) (string, error) {
|
||||
pset, err := psetv2.NewPsetFromBase64(txStr)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return getPsetId(pset)
|
||||
}
|
||||
|
||||
func getPsetId(pset *psetv2.Pset) (string, error) {
|
||||
utx, err := pset.UnsignedTx()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return utx.TxHash().String(), nil
|
||||
}
|
||||
|
||||
func getOnchainReceivers(
|
||||
payments []domain.Payment,
|
||||
) []domain.Receiver {
|
||||
receivers := make([]domain.Receiver, 0)
|
||||
for _, payment := range payments {
|
||||
for _, receiver := range payment.Receivers {
|
||||
if receiver.IsOnchain() {
|
||||
receivers = append(receivers, receiver)
|
||||
}
|
||||
}
|
||||
}
|
||||
return receivers
|
||||
}
|
||||
|
||||
func getOffchainReceivers(
|
||||
payments []domain.Payment,
|
||||
) []domain.Receiver {
|
||||
receivers := make([]domain.Receiver, 0)
|
||||
for _, payment := range payments {
|
||||
for _, receiver := range payment.Receivers {
|
||||
if !receiver.IsOnchain() {
|
||||
receivers = append(receivers, receiver)
|
||||
}
|
||||
}
|
||||
}
|
||||
return receivers
|
||||
}
|
||||
|
||||
func toWitnessUtxo(in ports.TxInput) (*transaction.TxOutput, error) {
|
||||
valueBytes, err := elementsutil.ValueToBytes(in.GetValue())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to convert value to bytes: %s", err)
|
||||
}
|
||||
|
||||
assetBytes, err := elementsutil.AssetHashToBytes(in.GetAsset())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to convert asset to bytes: %s", err)
|
||||
}
|
||||
|
||||
scriptBytes, err := hex.DecodeString(in.GetScript())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode script: %s", err)
|
||||
}
|
||||
|
||||
return transaction.NewTxOutput(assetBytes, valueBytes, scriptBytes), nil
|
||||
}
|
||||
|
||||
func countSpentVtxos(payments []domain.Payment) uint64 {
|
||||
var sum uint64
|
||||
for _, payment := range payments {
|
||||
sum += uint64(len(payment.Inputs))
|
||||
}
|
||||
return sum
|
||||
}
|
||||
|
||||
func addInputs(
|
||||
updater *psetv2.Updater,
|
||||
inputs []ports.TxInput,
|
||||
) error {
|
||||
for _, in := range inputs {
|
||||
inputArg := psetv2.InputArgs{
|
||||
Txid: in.GetTxid(),
|
||||
TxIndex: in.GetIndex(),
|
||||
}
|
||||
|
||||
witnessUtxo, err := toWitnessUtxo(in)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := updater.AddInputs([]psetv2.InputArgs{inputArg}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
index := int(updater.Pset.Global.InputCount) - 1
|
||||
if err := updater.AddInWitnessUtxo(index, witnessUtxo); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := updater.AddInSighashType(index, txscript.SigHashAll); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// wrapper of updater methods adding a taproot input to the pset with all the necessary data to spend it via any taproot script
|
||||
func addTaprootInput(
|
||||
updater *psetv2.Updater,
|
||||
input psetv2.InputArgs,
|
||||
internalTaprootKey *secp256k1.PublicKey,
|
||||
taprootTree *taproot.IndexedElementsTapScriptTree,
|
||||
) error {
|
||||
if err := updater.AddInputs([]psetv2.InputArgs{input}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := updater.AddInTapInternalKey(0, schnorr.SerializePubKey(internalTaprootKey)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, proof := range taprootTree.LeafMerkleProofs {
|
||||
controlBlock := proof.ToControlBlock(internalTaprootKey)
|
||||
|
||||
if err := updater.AddInTapLeafScript(0, psetv2.TapLeafScript{
|
||||
TapElementsLeaf: taproot.NewBaseTapElementsLeaf(proof.Script),
|
||||
ControlBlock: controlBlock,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func taprootOutputScript(taprootKey *secp256k1.PublicKey) ([]byte, error) {
|
||||
return txscript.NewScriptBuilder().AddOp(txscript.OP_1).AddData(schnorr.SerializePubKey(taprootKey)).Script()
|
||||
}
|
||||
285
server/internal/infrastructure/tx-builder/dummy/builder.go
Normal file
285
server/internal/infrastructure/tx-builder/dummy/builder.go
Normal file
@@ -0,0 +1,285 @@
|
||||
package txbuilder
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/ark-network/ark/common/tree"
|
||||
"github.com/ark-network/ark/internal/core/domain"
|
||||
"github.com/ark-network/ark/internal/core/ports"
|
||||
"github.com/decred/dcrd/dcrec/secp256k1/v4"
|
||||
"github.com/vulpemventures/go-elements/address"
|
||||
"github.com/vulpemventures/go-elements/network"
|
||||
"github.com/vulpemventures/go-elements/payment"
|
||||
"github.com/vulpemventures/go-elements/psetv2"
|
||||
"github.com/vulpemventures/go-elements/transaction"
|
||||
)
|
||||
|
||||
const (
|
||||
connectorAmount = 450
|
||||
sevenDays = 7 * 24 * 60 * 60
|
||||
)
|
||||
|
||||
type txBuilder struct {
|
||||
wallet ports.WalletService
|
||||
net network.Network
|
||||
}
|
||||
|
||||
func NewTxBuilder(
|
||||
wallet ports.WalletService, net network.Network,
|
||||
) ports.TxBuilder {
|
||||
return &txBuilder{wallet, net}
|
||||
}
|
||||
|
||||
// BuildSweepTx implements ports.TxBuilder.
|
||||
func (*txBuilder) BuildSweepTx(wallet ports.WalletService, inputs []ports.SweepInput) (signedSweepTx string, err error) {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
// BuildForfeitTxs implements ports.TxBuilder.
|
||||
func (b *txBuilder) BuildForfeitTxs(
|
||||
aspPubkey *secp256k1.PublicKey, poolTx string, payments []domain.Payment,
|
||||
) (connectors []string, forfeitTxs []string, err error) {
|
||||
poolTxID, err := getTxid(poolTx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
aspScript, err := p2wpkhScript(aspPubkey, b.net)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
numberOfConnectors := countSpentVtxos(payments)
|
||||
|
||||
connectors, err = createConnectors(
|
||||
poolTxID,
|
||||
1,
|
||||
psetv2.OutputArgs{
|
||||
Asset: b.net.AssetID,
|
||||
Amount: connectorAmount,
|
||||
Script: aspScript,
|
||||
},
|
||||
aspScript,
|
||||
numberOfConnectors,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
connectorsAsInputs, err := connectorsToInputArgs(connectors)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
forfeitTxs = make([]string, 0)
|
||||
for _, payment := range payments {
|
||||
for _, vtxo := range payment.Inputs {
|
||||
for _, connector := range connectorsAsInputs {
|
||||
forfeitTx, err := createForfeitTx(
|
||||
connector,
|
||||
psetv2.InputArgs{
|
||||
Txid: vtxo.Txid,
|
||||
TxIndex: vtxo.VOut,
|
||||
},
|
||||
vtxo.Amount,
|
||||
aspScript,
|
||||
b.net,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
forfeitTxs = append(forfeitTxs, forfeitTx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return connectors, forfeitTxs, nil
|
||||
}
|
||||
|
||||
// BuildPoolTx implements ports.TxBuilder.
|
||||
func (b *txBuilder) BuildPoolTx(
|
||||
aspPubkey *secp256k1.PublicKey, payments []domain.Payment, minRelayFee uint64,
|
||||
) (poolTx string, congestionTree tree.CongestionTree, err error) {
|
||||
aspScriptBytes, err := p2wpkhScript(aspPubkey, b.net)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
offchainReceivers, onchainReceivers := receiversFromPayments(payments)
|
||||
sharedOutputAmount := sumReceivers(offchainReceivers)
|
||||
|
||||
numberOfConnectors := countSpentVtxos(payments)
|
||||
connectorOutputAmount := connectorAmount * numberOfConnectors
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
outputs := []psetv2.OutputArgs{
|
||||
{
|
||||
Asset: b.net.AssetID,
|
||||
Amount: sharedOutputAmount,
|
||||
Script: aspScriptBytes,
|
||||
},
|
||||
{
|
||||
Asset: b.net.AssetID,
|
||||
Amount: connectorOutputAmount,
|
||||
Script: aspScriptBytes,
|
||||
},
|
||||
}
|
||||
|
||||
amountToSelect := sharedOutputAmount + connectorOutputAmount
|
||||
|
||||
for _, receiver := range onchainReceivers {
|
||||
amountToSelect += receiver.Amount
|
||||
|
||||
receiverScript, err := address.ToOutputScript(receiver.OnchainAddress)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
outputs = append(outputs, psetv2.OutputArgs{
|
||||
Asset: b.net.AssetID,
|
||||
Amount: receiver.Amount,
|
||||
Script: receiverScript,
|
||||
})
|
||||
}
|
||||
|
||||
utxos, change, err := b.wallet.SelectUtxos(ctx, b.net.AssetID, amountToSelect)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if change > 0 {
|
||||
outputs = append(outputs, psetv2.OutputArgs{
|
||||
Asset: b.net.AssetID,
|
||||
Amount: change,
|
||||
Script: aspScriptBytes,
|
||||
})
|
||||
}
|
||||
|
||||
ptx, err := psetv2.New(toInputArgs(utxos), outputs, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
utx, err := ptx.UnsignedTx()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
congestionTree, err = buildCongestionTree(
|
||||
newOutputScriptFactory(aspPubkey, b.net),
|
||||
b.net,
|
||||
utx.TxHash().String(),
|
||||
offchainReceivers,
|
||||
)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
poolTx, err = ptx.ToBase64()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return poolTx, congestionTree, err
|
||||
}
|
||||
|
||||
func (b *txBuilder) GetVtxoScript(userPubkey, _ *secp256k1.PublicKey) ([]byte, error) {
|
||||
p2wpkh := payment.FromPublicKey(userPubkey, &b.net, nil)
|
||||
addr, _ := p2wpkh.WitnessPubKeyHash()
|
||||
return address.ToOutputScript(addr)
|
||||
}
|
||||
|
||||
func (b *txBuilder) GetLeafSweepClosure(
|
||||
node tree.Node, userPubKey *secp256k1.PublicKey,
|
||||
) (*psetv2.TapLeafScript, int64, error) {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
func connectorsToInputArgs(connectors []string) ([]psetv2.InputArgs, error) {
|
||||
inputs := make([]psetv2.InputArgs, 0, len(connectors)+1)
|
||||
for i, psetb64 := range connectors {
|
||||
tx, err := psetv2.NewPsetFromBase64(psetb64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
utx, err := tx.UnsignedTx()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
txid := utx.TxHash().String()
|
||||
for j := range tx.Outputs {
|
||||
inputs = append(inputs, psetv2.InputArgs{
|
||||
Txid: txid,
|
||||
TxIndex: uint32(j),
|
||||
})
|
||||
if i != len(connectors)-1 {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return inputs, nil
|
||||
}
|
||||
|
||||
func getTxid(txStr string) (string, error) {
|
||||
pset, err := psetv2.NewPsetFromBase64(txStr)
|
||||
if err != nil {
|
||||
tx, err := transaction.NewTxFromHex(txStr)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return tx.TxHash().String(), nil
|
||||
}
|
||||
|
||||
utx, err := pset.UnsignedTx()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return utx.TxHash().String(), nil
|
||||
}
|
||||
|
||||
func countSpentVtxos(payments []domain.Payment) uint64 {
|
||||
var sum uint64
|
||||
for _, payment := range payments {
|
||||
sum += uint64(len(payment.Inputs))
|
||||
}
|
||||
return sum
|
||||
}
|
||||
|
||||
func receiversFromPayments(
|
||||
payments []domain.Payment,
|
||||
) (offchainReceivers, onchainReceivers []domain.Receiver) {
|
||||
for _, payment := range payments {
|
||||
for _, receiver := range payment.Receivers {
|
||||
if receiver.IsOnchain() {
|
||||
onchainReceivers = append(onchainReceivers, receiver)
|
||||
} else {
|
||||
offchainReceivers = append(offchainReceivers, receiver)
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func sumReceivers(receivers []domain.Receiver) uint64 {
|
||||
var sum uint64
|
||||
for _, r := range receivers {
|
||||
sum += r.Amount
|
||||
}
|
||||
return sum
|
||||
}
|
||||
|
||||
func toInputArgs(
|
||||
ins []ports.TxInput,
|
||||
) []psetv2.InputArgs {
|
||||
inputs := make([]psetv2.InputArgs, 0, len(ins))
|
||||
for _, in := range ins {
|
||||
inputs = append(inputs, psetv2.InputArgs{
|
||||
Txid: in.GetTxid(),
|
||||
TxIndex: in.GetIndex(),
|
||||
})
|
||||
}
|
||||
return inputs
|
||||
}
|
||||
407
server/internal/infrastructure/tx-builder/dummy/builder_test.go
Normal file
407
server/internal/infrastructure/tx-builder/dummy/builder_test.go
Normal file
@@ -0,0 +1,407 @@
|
||||
package txbuilder_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"testing"
|
||||
|
||||
"github.com/ark-network/ark/common"
|
||||
"github.com/ark-network/ark/internal/core/domain"
|
||||
"github.com/ark-network/ark/internal/core/ports"
|
||||
txbuilder "github.com/ark-network/ark/internal/infrastructure/tx-builder/dummy"
|
||||
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
||||
secp256k1 "github.com/decred/dcrd/dcrec/secp256k1/v4"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/vulpemventures/go-elements/network"
|
||||
"github.com/vulpemventures/go-elements/psetv2"
|
||||
)
|
||||
|
||||
const (
|
||||
testingKey = "apub1qgvdtj5ttpuhkldavhq8thtm5auyk0ec4dcmrfdgu0u5hgp9we22v3hrs4x"
|
||||
fakePoolTx = "cHNldP8BAgQCAAAAAQQBAQEFAQMBBgEDAfsEAgAAAAABDiDk7dXxh4KQzgLO8i1ABtaLCe4aPL12GVhN1E9zM1ePLwEPBAAAAAABEAT/////AAEDCOgDAAAAAAAAAQQWABSNnpy01UJqd99eTg2M1IpdKId11gf8BHBzZXQCICWyUQcOKcoZBDzzPM1zJOLdqwPsxK4LXnfE/A5c9slaB/wEcHNldAgEAAAAAAABAwh4BQAAAAAAAAEEFgAUjZ6ctNVCanffXk4NjNSKXSiHddYH/ARwc2V0AiAlslEHDinKGQQ88zzNcyTi3asD7MSuC153xPwOXPbJWgf8BHBzZXQIBAAAAAAAAQMI9AEAAAAAAAABBAAH/ARwc2V0AiAlslEHDinKGQQ88zzNcyTi3asD7MSuC153xPwOXPbJWgf8BHBzZXQIBAAAAAAA"
|
||||
)
|
||||
|
||||
type input struct {
|
||||
txid string
|
||||
vout uint32
|
||||
}
|
||||
|
||||
func (i *input) GetTxid() string {
|
||||
return i.txid
|
||||
}
|
||||
|
||||
func (i *input) GetIndex() uint32 {
|
||||
return i.vout
|
||||
}
|
||||
|
||||
func (i *input) GetScript() string {
|
||||
return "a914ea9f486e82efb3dd83a69fd96e3f0113757da03c87"
|
||||
}
|
||||
|
||||
func (i *input) GetAsset() string {
|
||||
return "5ac9f65c0efcc4775e0baec4ec03abdde22473cd3cf33c0419ca290e0751b225"
|
||||
}
|
||||
|
||||
func (i *input) GetValue() uint64 {
|
||||
return 1000
|
||||
}
|
||||
|
||||
type mockedWalletService struct{}
|
||||
|
||||
// BroadcastTransaction implements ports.WalletService.
|
||||
func (*mockedWalletService) BroadcastTransaction(ctx context.Context, txHex string) (string, error) {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
// Close implements ports.WalletService.
|
||||
func (*mockedWalletService) Close() {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
// DeriveAddresses implements ports.WalletService.
|
||||
func (*mockedWalletService) DeriveAddresses(ctx context.Context, num int) ([]string, error) {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
// GetPubkey implements ports.WalletService.
|
||||
func (*mockedWalletService) GetPubkey(ctx context.Context) (*secp256k1.PublicKey, error) {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
// SignPset implements ports.WalletService.
|
||||
func (*mockedWalletService) SignPset(ctx context.Context, pset string, extractRawTx bool) (string, error) {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
// Status implements ports.WalletService.
|
||||
func (*mockedWalletService) Status(ctx context.Context) (ports.WalletStatus, error) {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
func (*mockedWalletService) WatchScripts(ctx context.Context, scripts []string) error {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
func (*mockedWalletService) UnwatchScripts(ctx context.Context, scripts []string) error {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
func (*mockedWalletService) GetNotificationChannel(ctx context.Context) chan []domain.VtxoKey {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
func (*mockedWalletService) SelectUtxos(ctx context.Context, asset string, amount uint64) ([]ports.TxInput, uint64, error) {
|
||||
// random txid
|
||||
bytes := make([]byte, 32)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
fakeInput := input{
|
||||
txid: hex.EncodeToString(bytes),
|
||||
vout: 0,
|
||||
}
|
||||
|
||||
return []ports.TxInput{&fakeInput}, 0, nil
|
||||
}
|
||||
|
||||
func (*mockedWalletService) EstimateFees(ctx context.Context, pset string) (uint64, error) {
|
||||
return 100, nil
|
||||
}
|
||||
|
||||
func (*mockedWalletService) SignPsetWithKey(ctx context.Context, pset string, inputIndex []int) (string, error) {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
func (*mockedWalletService) IsTransactionPublished(ctx context.Context, txid string) (bool, int64, error) {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
func TestBuildCongestionTree(t *testing.T) {
|
||||
builder := txbuilder.NewTxBuilder(&mockedWalletService{}, network.Liquid)
|
||||
|
||||
fixtures := []struct {
|
||||
payments []domain.Payment
|
||||
expectedNodesNum int // 2*len(receivers)-1
|
||||
expectedLeavesNum int
|
||||
}{
|
||||
{
|
||||
payments: []domain.Payment{
|
||||
{
|
||||
Id: "0",
|
||||
Inputs: []domain.Vtxo{
|
||||
{
|
||||
VtxoKey: domain.VtxoKey{
|
||||
Txid: "fd68e3c5796cc7db0a8036d486d5f625b6b2f2c014810ac020e1ac23e82c59d6",
|
||||
VOut: 0,
|
||||
},
|
||||
Receiver: domain.Receiver{
|
||||
Pubkey: "020000000000000000000000000000000000000000000000000000000000000002",
|
||||
Amount: 600,
|
||||
},
|
||||
},
|
||||
},
|
||||
Receivers: []domain.Receiver{
|
||||
{
|
||||
Pubkey: "020000000000000000000000000000000000000000000000000000000000000002",
|
||||
Amount: 600,
|
||||
},
|
||||
{
|
||||
Pubkey: "020000000000000000000000000000000000000000000000000000000000000002",
|
||||
Amount: 400,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedNodesNum: 3,
|
||||
expectedLeavesNum: 2,
|
||||
},
|
||||
{
|
||||
payments: []domain.Payment{
|
||||
{
|
||||
Id: "0",
|
||||
Inputs: []domain.Vtxo{
|
||||
{
|
||||
VtxoKey: domain.VtxoKey{
|
||||
Txid: "fd68e3c5796cc7db0a8036d486d5f625b6b2f2c014810ac020e1ac23e82c59d6",
|
||||
VOut: 0,
|
||||
},
|
||||
Receiver: domain.Receiver{
|
||||
Pubkey: "020000000000000000000000000000000000000000000000000000000000000002",
|
||||
Amount: 600,
|
||||
},
|
||||
},
|
||||
},
|
||||
Receivers: []domain.Receiver{
|
||||
{
|
||||
Pubkey: "020000000000000000000000000000000000000000000000000000000000000002",
|
||||
Amount: 600,
|
||||
},
|
||||
{
|
||||
Pubkey: "020000000000000000000000000000000000000000000000000000000000000002",
|
||||
Amount: 400,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Id: "0",
|
||||
Inputs: []domain.Vtxo{
|
||||
{
|
||||
VtxoKey: domain.VtxoKey{
|
||||
Txid: "fd68e3c5796cc7db0a8036d486d5f625b6b2f2c014810ac020e1ac23e82c59d6",
|
||||
VOut: 0,
|
||||
},
|
||||
Receiver: domain.Receiver{
|
||||
Pubkey: "020000000000000000000000000000000000000000000000000000000000000002",
|
||||
Amount: 600,
|
||||
},
|
||||
},
|
||||
},
|
||||
Receivers: []domain.Receiver{
|
||||
{
|
||||
Pubkey: "020000000000000000000000000000000000000000000000000000000000000002",
|
||||
Amount: 600,
|
||||
},
|
||||
{
|
||||
Pubkey: "020000000000000000000000000000000000000000000000000000000000000002",
|
||||
Amount: 400,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Id: "0",
|
||||
Inputs: []domain.Vtxo{
|
||||
{
|
||||
VtxoKey: domain.VtxoKey{
|
||||
Txid: "fd68e3c5796cc7db0a8036d486d5f625b6b2f2c014810ac020e1ac23e82c59d6",
|
||||
VOut: 0,
|
||||
},
|
||||
Receiver: domain.Receiver{
|
||||
Pubkey: "020000000000000000000000000000000000000000000000000000000000000002",
|
||||
Amount: 600,
|
||||
},
|
||||
},
|
||||
},
|
||||
Receivers: []domain.Receiver{
|
||||
{
|
||||
Pubkey: "020000000000000000000000000000000000000000000000000000000000000002",
|
||||
Amount: 600,
|
||||
},
|
||||
{
|
||||
Pubkey: "020000000000000000000000000000000000000000000000000000000000000002",
|
||||
Amount: 400,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedNodesNum: 11,
|
||||
expectedLeavesNum: 6,
|
||||
},
|
||||
}
|
||||
|
||||
_, key, err := common.DecodePubKey(testingKey)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, key)
|
||||
|
||||
for _, f := range fixtures {
|
||||
poolTx, tree, err := builder.BuildPoolTx(key, f.payments, 30)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, f.expectedNodesNum, tree.NumberOfNodes())
|
||||
require.Len(t, tree.Leaves(), f.expectedLeavesNum)
|
||||
|
||||
poolPset, err := psetv2.NewPsetFromBase64(poolTx)
|
||||
require.NoError(t, err)
|
||||
|
||||
poolTxUnsigned, err := poolPset.UnsignedTx()
|
||||
require.NoError(t, err)
|
||||
|
||||
poolTxID := poolTxUnsigned.TxHash().String()
|
||||
|
||||
// check the root
|
||||
require.Len(t, tree[0], 1)
|
||||
require.Equal(t, poolTxID, tree[0][0].ParentTxid)
|
||||
|
||||
// check the leaves
|
||||
for _, leaf := range tree.Leaves() {
|
||||
pset, err := psetv2.NewPsetFromBase64(leaf.Tx)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Len(t, pset.Inputs, 1)
|
||||
require.Len(t, pset.Outputs, 1)
|
||||
|
||||
inputTxID := chainhash.Hash(pset.Inputs[0].PreviousTxid).String()
|
||||
require.Equal(t, leaf.ParentTxid, inputTxID)
|
||||
}
|
||||
|
||||
// check the nodes
|
||||
for _, level := range tree[:len(tree)-2] {
|
||||
for _, node := range level {
|
||||
pset, err := psetv2.NewPsetFromBase64(node.Tx)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Len(t, pset.Inputs, 1)
|
||||
require.Len(t, pset.Outputs, 2)
|
||||
|
||||
inputTxID := chainhash.Hash(pset.Inputs[0].PreviousTxid).String()
|
||||
require.Equal(t, node.ParentTxid, inputTxID)
|
||||
|
||||
children := tree.Children(node.Txid)
|
||||
require.Len(t, children, 2)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildForfeitTxs(t *testing.T) {
|
||||
builder := txbuilder.NewTxBuilder(&mockedWalletService{}, network.Liquid)
|
||||
|
||||
poolPset, err := psetv2.NewPsetFromBase64(fakePoolTx)
|
||||
require.NoError(t, err)
|
||||
|
||||
poolTxUnsigned, err := poolPset.UnsignedTx()
|
||||
require.NoError(t, err)
|
||||
|
||||
poolTxID := poolTxUnsigned.TxHash().String()
|
||||
|
||||
fixtures := []struct {
|
||||
payments []domain.Payment
|
||||
expectedNumOfForfeitTxs int
|
||||
expectedNumOfConnectors int
|
||||
}{
|
||||
{
|
||||
payments: []domain.Payment{
|
||||
{
|
||||
Id: "0",
|
||||
Inputs: []domain.Vtxo{
|
||||
{
|
||||
VtxoKey: domain.VtxoKey{
|
||||
Txid: "fd68e3c5796cc7db0a8036d486d5f625b6b2f2c014810ac020e1ac23e82c59d6",
|
||||
VOut: 0,
|
||||
},
|
||||
Receiver: domain.Receiver{
|
||||
Pubkey: "020000000000000000000000000000000000000000000000000000000000000002",
|
||||
Amount: 600,
|
||||
},
|
||||
},
|
||||
{
|
||||
VtxoKey: domain.VtxoKey{
|
||||
Txid: "fd68e3c5796cc7db0a8036d486d5f625b6b2f2c014810ac020e1ac23e82c59d6",
|
||||
VOut: 1,
|
||||
},
|
||||
Receiver: domain.Receiver{
|
||||
Pubkey: "020000000000000000000000000000000000000000000000000000000000000002",
|
||||
Amount: 400,
|
||||
},
|
||||
},
|
||||
},
|
||||
Receivers: []domain.Receiver{
|
||||
{
|
||||
Pubkey: "020000000000000000000000000000000000000000000000000000000000000002",
|
||||
Amount: 600,
|
||||
},
|
||||
{
|
||||
Pubkey: "020000000000000000000000000000000000000000000000000000000000000002",
|
||||
Amount: 400,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedNumOfForfeitTxs: 4,
|
||||
expectedNumOfConnectors: 1,
|
||||
},
|
||||
}
|
||||
|
||||
_, key, err := common.DecodePubKey(testingKey)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, key)
|
||||
|
||||
for _, f := range fixtures {
|
||||
connectors, forfeitTxs, err := builder.BuildForfeitTxs(
|
||||
key, fakePoolTx, f.payments,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, connectors, f.expectedNumOfConnectors)
|
||||
require.Len(t, forfeitTxs, f.expectedNumOfForfeitTxs)
|
||||
|
||||
// decode and check connectors
|
||||
connectorsPsets := make([]*psetv2.Pset, 0, f.expectedNumOfConnectors)
|
||||
for _, pset := range connectors {
|
||||
p, err := psetv2.NewPsetFromBase64(pset)
|
||||
require.NoError(t, err)
|
||||
connectorsPsets = append(connectorsPsets, p)
|
||||
}
|
||||
|
||||
for i, pset := range connectorsPsets {
|
||||
require.Len(t, pset.Inputs, 1)
|
||||
require.Len(t, pset.Outputs, 2)
|
||||
|
||||
expectedInputTxid := poolTxID
|
||||
expectedInputVout := uint32(1)
|
||||
if i > 0 {
|
||||
tx, err := connectorsPsets[i-1].UnsignedTx()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, tx)
|
||||
expectedInputTxid = tx.TxHash().String()
|
||||
}
|
||||
|
||||
inputTxid := chainhash.Hash(pset.Inputs[0].PreviousTxid).String()
|
||||
require.Equal(t, expectedInputTxid, inputTxid)
|
||||
require.Equal(t, expectedInputVout, pset.Inputs[0].PreviousTxIndex)
|
||||
}
|
||||
|
||||
// decode and check forfeit txs
|
||||
forfeitTxsPsets := make([]*psetv2.Pset, 0, f.expectedNumOfForfeitTxs)
|
||||
for _, pset := range forfeitTxs {
|
||||
p, err := psetv2.NewPsetFromBase64(pset)
|
||||
require.NoError(t, err)
|
||||
forfeitTxsPsets = append(forfeitTxsPsets, p)
|
||||
}
|
||||
|
||||
// each forfeit tx should have 2 inputs and 2 outputs
|
||||
for _, pset := range forfeitTxsPsets {
|
||||
require.Len(t, pset.Inputs, 2)
|
||||
require.Len(t, pset.Outputs, 1)
|
||||
}
|
||||
}
|
||||
}
|
||||
103
server/internal/infrastructure/tx-builder/dummy/connectors.go
Normal file
103
server/internal/infrastructure/tx-builder/dummy/connectors.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package txbuilder
|
||||
|
||||
import (
|
||||
"github.com/vulpemventures/go-elements/psetv2"
|
||||
)
|
||||
|
||||
func createConnectors(
|
||||
poolTxID string,
|
||||
connectorOutputIndex uint32,
|
||||
connectorOutput psetv2.OutputArgs,
|
||||
changeScript []byte,
|
||||
numberOfConnectors uint64,
|
||||
) (connectorsPsets []string, err error) {
|
||||
previousInput := psetv2.InputArgs{
|
||||
Txid: poolTxID,
|
||||
TxIndex: connectorOutputIndex,
|
||||
}
|
||||
|
||||
if numberOfConnectors == 1 {
|
||||
pset, err := psetv2.New(nil, nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
updater, err := psetv2.NewUpdater(pset)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = updater.AddInputs([]psetv2.InputArgs{previousInput})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = updater.AddOutputs([]psetv2.OutputArgs{connectorOutput})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
base64, err := pset.ToBase64()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return []string{base64}, nil
|
||||
}
|
||||
|
||||
// compute the initial amount of the connectors output in pool transaction
|
||||
remainingAmount := connectorAmount * numberOfConnectors
|
||||
|
||||
connectorsPset := make([]string, 0, numberOfConnectors-1)
|
||||
for i := uint64(0); i < numberOfConnectors-1; i++ {
|
||||
// create a new pset
|
||||
pset, err := psetv2.New(nil, nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
updater, err := psetv2.NewUpdater(pset)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = updater.AddInputs([]psetv2.InputArgs{previousInput})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = updater.AddOutputs([]psetv2.OutputArgs{connectorOutput})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
changeAmount := remainingAmount - connectorOutput.Amount
|
||||
if changeAmount > 0 {
|
||||
changeOutput := psetv2.OutputArgs{
|
||||
Asset: connectorOutput.Asset,
|
||||
Amount: changeAmount,
|
||||
Script: changeScript,
|
||||
}
|
||||
err = updater.AddOutputs([]psetv2.OutputArgs{changeOutput})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tx, _ := pset.UnsignedTx()
|
||||
txid := tx.TxHash().String()
|
||||
|
||||
// make the change the next previousInput
|
||||
previousInput = psetv2.InputArgs{
|
||||
Txid: txid,
|
||||
TxIndex: 1,
|
||||
}
|
||||
}
|
||||
|
||||
base64, err := pset.ToBase64()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
connectorsPset = append(connectorsPset, base64)
|
||||
}
|
||||
|
||||
return connectorsPset, nil
|
||||
}
|
||||
42
server/internal/infrastructure/tx-builder/dummy/forfeit.go
Normal file
42
server/internal/infrastructure/tx-builder/dummy/forfeit.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package txbuilder
|
||||
|
||||
import (
|
||||
"github.com/vulpemventures/go-elements/network"
|
||||
"github.com/vulpemventures/go-elements/psetv2"
|
||||
)
|
||||
|
||||
func createForfeitTx(
|
||||
connectorInput psetv2.InputArgs,
|
||||
vtxoInput psetv2.InputArgs,
|
||||
vtxoAmount uint64,
|
||||
aspScript []byte,
|
||||
net network.Network,
|
||||
) (forfeitTx string, err error) {
|
||||
pset, err := psetv2.New(nil, nil, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
updater, err := psetv2.NewUpdater(pset)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
err = updater.AddInputs([]psetv2.InputArgs{connectorInput, vtxoInput})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
err = updater.AddOutputs([]psetv2.OutputArgs{
|
||||
{
|
||||
Asset: net.AssetID,
|
||||
Amount: vtxoAmount,
|
||||
Script: aspScript,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return pset.ToBase64()
|
||||
}
|
||||
309
server/internal/infrastructure/tx-builder/dummy/tree.go
Normal file
309
server/internal/infrastructure/tx-builder/dummy/tree.go
Normal file
@@ -0,0 +1,309 @@
|
||||
package txbuilder
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
|
||||
"github.com/ark-network/ark/common/tree"
|
||||
"github.com/ark-network/ark/internal/core/domain"
|
||||
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
||||
"github.com/decred/dcrd/dcrec/secp256k1/v4"
|
||||
"github.com/vulpemventures/go-elements/address"
|
||||
"github.com/vulpemventures/go-elements/network"
|
||||
"github.com/vulpemventures/go-elements/payment"
|
||||
"github.com/vulpemventures/go-elements/psetv2"
|
||||
)
|
||||
|
||||
const (
|
||||
sharedOutputIndex = 0
|
||||
)
|
||||
|
||||
type outputScriptFactory func(leaves []domain.Receiver) ([]byte, error)
|
||||
|
||||
func p2wpkhScript(publicKey *secp256k1.PublicKey, net network.Network) ([]byte, error) {
|
||||
payment := payment.FromPublicKey(publicKey, &net, nil)
|
||||
addr, err := payment.WitnessPubKeyHash()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return address.ToOutputScript(addr)
|
||||
}
|
||||
|
||||
// newOtputScriptFactory returns an output script factory func that lock funds using the ASP public key only on all branches psbt. The leaves are instead locked by the leaf public key.
|
||||
func newOutputScriptFactory(aspPublicKey *secp256k1.PublicKey, net network.Network) outputScriptFactory {
|
||||
return func(leaves []domain.Receiver) ([]byte, error) {
|
||||
aspScript, err := p2wpkhScript(aspPublicKey, net)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch len(leaves) {
|
||||
case 0:
|
||||
return nil, nil
|
||||
case 1: // it's a leaf
|
||||
buf, err := hex.DecodeString(leaves[0].Pubkey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
key, err := secp256k1.ParsePubKey(buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return p2wpkhScript(key, net)
|
||||
default: // it's a branch, lock funds with ASP public key
|
||||
return aspScript, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// congestionTree builder iteratively creates a binary tree of Pset from a set of receivers
|
||||
// it also expect createOutputScript func managing the output script creation and the network to use (mainly for L-BTC asset id)
|
||||
func buildCongestionTree(
|
||||
createOutputScript outputScriptFactory,
|
||||
net network.Network,
|
||||
poolTxID string,
|
||||
receivers []domain.Receiver,
|
||||
) (congestionTree tree.CongestionTree, err error) {
|
||||
var nodes []*node
|
||||
|
||||
for _, r := range receivers {
|
||||
nodes = append(nodes, newLeaf(createOutputScript, net, r))
|
||||
}
|
||||
|
||||
for len(nodes) > 1 {
|
||||
nodes, err = createTreeLevel(nodes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
psets, err := nodes[0].psets(psetv2.InputArgs{
|
||||
Txid: poolTxID,
|
||||
TxIndex: sharedOutputIndex,
|
||||
}, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
maxLevel := 0
|
||||
for _, psetWithLevel := range psets {
|
||||
if psetWithLevel.level > maxLevel {
|
||||
maxLevel = psetWithLevel.level
|
||||
}
|
||||
}
|
||||
|
||||
congestionTree = make(tree.CongestionTree, maxLevel+1)
|
||||
|
||||
for _, psetWithLevel := range psets {
|
||||
utx, err := psetWithLevel.pset.UnsignedTx()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
txid := utx.TxHash().String()
|
||||
|
||||
psetB64, err := psetWithLevel.pset.ToBase64()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
parentTxid := chainhash.Hash(psetWithLevel.pset.Inputs[0].PreviousTxid).String()
|
||||
|
||||
congestionTree[psetWithLevel.level] = append(congestionTree[psetWithLevel.level], tree.Node{
|
||||
Txid: txid,
|
||||
Tx: psetB64,
|
||||
ParentTxid: parentTxid,
|
||||
Leaf: psetWithLevel.leaf,
|
||||
})
|
||||
}
|
||||
|
||||
return congestionTree, nil
|
||||
}
|
||||
|
||||
func createTreeLevel(nodes []*node) ([]*node, error) {
|
||||
if len(nodes)%2 != 0 {
|
||||
last := nodes[len(nodes)-1]
|
||||
pairs, err := createTreeLevel(nodes[:len(nodes)-1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return append(pairs, last), nil
|
||||
}
|
||||
|
||||
pairs := make([]*node, 0, len(nodes)/2)
|
||||
for i := 0; i < len(nodes); i += 2 {
|
||||
pairs = append(pairs, newBranch(nodes[i], nodes[i+1]))
|
||||
}
|
||||
return pairs, nil
|
||||
}
|
||||
|
||||
// internal struct to build a binary tree of Pset
|
||||
type node struct {
|
||||
receivers []domain.Receiver
|
||||
left *node
|
||||
right *node
|
||||
createOutputScript outputScriptFactory
|
||||
network network.Network
|
||||
}
|
||||
|
||||
// create a node from a single receiver
|
||||
func newLeaf(
|
||||
createOutputScript outputScriptFactory,
|
||||
network network.Network,
|
||||
receiver domain.Receiver,
|
||||
) *node {
|
||||
return &node{
|
||||
receivers: []domain.Receiver{receiver},
|
||||
createOutputScript: createOutputScript,
|
||||
network: network,
|
||||
left: nil,
|
||||
right: nil,
|
||||
}
|
||||
}
|
||||
|
||||
// aggregate two nodes into a branch node
|
||||
func newBranch(
|
||||
left *node,
|
||||
right *node,
|
||||
) *node {
|
||||
return &node{
|
||||
receivers: append(left.receivers, right.receivers...),
|
||||
createOutputScript: left.createOutputScript,
|
||||
network: left.network,
|
||||
left: left,
|
||||
right: right,
|
||||
}
|
||||
}
|
||||
|
||||
// is it the final node of the tree
|
||||
func (n *node) isLeaf() bool {
|
||||
return n.left == nil && n.right == nil
|
||||
}
|
||||
|
||||
// compute the output amount of a node
|
||||
func (n *node) amount() uint64 {
|
||||
var amount uint64
|
||||
for _, r := range n.receivers {
|
||||
amount += r.Amount
|
||||
}
|
||||
return amount
|
||||
}
|
||||
|
||||
// compute the output script of a node
|
||||
func (n *node) script() ([]byte, error) {
|
||||
return n.createOutputScript(n.receivers)
|
||||
}
|
||||
|
||||
// use script & amount to create OutputArgs
|
||||
func (n *node) output() (*psetv2.OutputArgs, error) {
|
||||
script, err := n.script()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &psetv2.OutputArgs{
|
||||
Asset: n.network.AssetID,
|
||||
Amount: n.amount(),
|
||||
Script: script,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// create the node Pset from the previous node Pset represented by input arg
|
||||
// if node is a branch, it adds two outputs to the Pset, one for the left branch and one for the right branch
|
||||
// if node is a leaf, it only adds one output to the Pset (the node output)
|
||||
func (n *node) pset(input psetv2.InputArgs) (*psetv2.Pset, error) {
|
||||
pset, err := psetv2.New(nil, nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
updater, err := psetv2.NewUpdater(pset)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = updater.AddInputs([]psetv2.InputArgs{input})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if n.isLeaf() {
|
||||
output, err := n.output()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = updater.AddOutputs([]psetv2.OutputArgs{*output})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return pset, nil
|
||||
}
|
||||
|
||||
outputLeft, err := n.left.output()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
outputRight, err := n.right.output()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = updater.AddOutputs([]psetv2.OutputArgs{*outputLeft, *outputRight})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return pset, nil
|
||||
}
|
||||
|
||||
type psetWithLevel struct {
|
||||
pset *psetv2.Pset
|
||||
level int
|
||||
leaf bool
|
||||
}
|
||||
|
||||
// create the node pset and all the psets of its children recursively, updating the input arg at each step
|
||||
// the function stops when it reaches a leaf node
|
||||
func (n *node) psets(input psetv2.InputArgs, level int) ([]psetWithLevel, error) {
|
||||
pset, err := n.pset(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nodeResult := []psetWithLevel{
|
||||
{pset, level, n.isLeaf()},
|
||||
}
|
||||
|
||||
if n.isLeaf() {
|
||||
return nodeResult, nil
|
||||
}
|
||||
|
||||
unsignedTx, err := pset.UnsignedTx()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
txID := unsignedTx.TxHash().String()
|
||||
|
||||
psetsLeft, err := n.left.psets(psetv2.InputArgs{
|
||||
Txid: txID,
|
||||
TxIndex: 0,
|
||||
}, level+1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
psetsRight, err := n.right.psets(psetv2.InputArgs{
|
||||
Txid: txID,
|
||||
TxIndex: 1,
|
||||
}, level+1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return append(nodeResult, append(psetsLeft, psetsRight...)...), nil
|
||||
}
|
||||
46
server/internal/interface/grpc/config.go
Normal file
46
server/internal/interface/grpc/config.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package grpcservice
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Port uint32
|
||||
NoTLS bool
|
||||
}
|
||||
|
||||
func (c Config) Validate() error {
|
||||
lis, err := net.Listen("tcp", c.address())
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid port: %s", err)
|
||||
}
|
||||
defer lis.Close()
|
||||
|
||||
if !c.NoTLS {
|
||||
return fmt.Errorf("tls termination not supported yet")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c Config) insecure() bool {
|
||||
return c.NoTLS
|
||||
}
|
||||
|
||||
func (c Config) address() string {
|
||||
return fmt.Sprintf(":%d", c.Port)
|
||||
}
|
||||
|
||||
func (c Config) listener() net.Listener {
|
||||
lis, _ := net.Listen("tcp", c.address())
|
||||
|
||||
if c.insecure() {
|
||||
return lis
|
||||
}
|
||||
return tls.NewListener(lis, c.tlsConfig())
|
||||
}
|
||||
|
||||
func (c Config) tlsConfig() *tls.Config {
|
||||
return nil
|
||||
}
|
||||
305
server/internal/interface/grpc/handlers/arkservice.go
Normal file
305
server/internal/interface/grpc/handlers/arkservice.go
Normal file
@@ -0,0 +1,305 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"sync"
|
||||
|
||||
arkv1 "github.com/ark-network/ark/api-spec/protobuf/gen/ark/v1"
|
||||
"github.com/ark-network/ark/common"
|
||||
"github.com/ark-network/ark/common/tree"
|
||||
"github.com/ark-network/ark/internal/core/application"
|
||||
"github.com/ark-network/ark/internal/core/domain"
|
||||
"github.com/decred/dcrd/dcrec/secp256k1/v4"
|
||||
"github.com/google/uuid"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
type listener struct {
|
||||
id string
|
||||
ch chan *arkv1.GetEventStreamResponse
|
||||
}
|
||||
|
||||
type handler struct {
|
||||
svc application.Service
|
||||
|
||||
listenersLock *sync.Mutex
|
||||
listeners []*listener
|
||||
}
|
||||
|
||||
func NewHandler(service application.Service) arkv1.ArkServiceServer {
|
||||
h := &handler{
|
||||
svc: service,
|
||||
listenersLock: &sync.Mutex{},
|
||||
listeners: make([]*listener, 0),
|
||||
}
|
||||
|
||||
go h.listenToEvents()
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
func (h *handler) Ping(ctx context.Context, req *arkv1.PingRequest) (*arkv1.PingResponse, error) {
|
||||
if req.GetPaymentId() == "" {
|
||||
return nil, status.Error(codes.InvalidArgument, "missing payment id")
|
||||
}
|
||||
|
||||
if err := h.svc.UpdatePaymentStatus(ctx, req.GetPaymentId()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &arkv1.PingResponse{}, nil
|
||||
}
|
||||
|
||||
func (h *handler) RegisterPayment(ctx context.Context, req *arkv1.RegisterPaymentRequest) (*arkv1.RegisterPaymentResponse, error) {
|
||||
vtxosKeys := make([]domain.VtxoKey, 0, len(req.GetInputs()))
|
||||
for _, input := range req.GetInputs() {
|
||||
vtxosKeys = append(vtxosKeys, domain.VtxoKey{
|
||||
Txid: input.GetTxid(),
|
||||
VOut: input.GetVout(),
|
||||
})
|
||||
}
|
||||
|
||||
id, err := h.svc.SpendVtxos(ctx, vtxosKeys)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &arkv1.RegisterPaymentResponse{
|
||||
Id: id,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (h *handler) ClaimPayment(ctx context.Context, req *arkv1.ClaimPaymentRequest) (*arkv1.ClaimPaymentResponse, error) {
|
||||
receivers, err := parseReceivers(req.GetOutputs())
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.InvalidArgument, err.Error())
|
||||
}
|
||||
|
||||
if err := h.svc.ClaimVtxos(ctx, req.GetId(), receivers); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &arkv1.ClaimPaymentResponse{}, nil
|
||||
}
|
||||
|
||||
func (h *handler) FinalizePayment(ctx context.Context, req *arkv1.FinalizePaymentRequest) (*arkv1.FinalizePaymentResponse, error) {
|
||||
forfeitTxs, err := parseTxs(req.GetSignedForfeitTxs())
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.InvalidArgument, err.Error())
|
||||
}
|
||||
if err := h.svc.SignVtxos(ctx, forfeitTxs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &arkv1.FinalizePaymentResponse{}, nil
|
||||
}
|
||||
|
||||
func (h *handler) Faucet(ctx context.Context, req *arkv1.FaucetRequest) (*arkv1.FaucetResponse, error) {
|
||||
_, pubkey, _, err := parseAddress(req.GetAddress())
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.InvalidArgument, err.Error())
|
||||
}
|
||||
|
||||
if err := h.svc.FaucetVtxos(ctx, pubkey); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &arkv1.FaucetResponse{}, nil
|
||||
}
|
||||
|
||||
func (h *handler) GetRound(ctx context.Context, req *arkv1.GetRoundRequest) (*arkv1.GetRoundResponse, error) {
|
||||
if len(req.GetTxid()) <= 0 {
|
||||
return nil, status.Error(codes.InvalidArgument, "missing pool txid")
|
||||
}
|
||||
|
||||
round, err := h.svc.GetRoundByTxid(ctx, req.GetTxid())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &arkv1.GetRoundResponse{
|
||||
Round: &arkv1.Round{
|
||||
Id: round.Id,
|
||||
Start: round.StartingTimestamp,
|
||||
End: round.EndingTimestamp,
|
||||
Txid: round.Txid,
|
||||
CongestionTree: castCongestionTree(round.CongestionTree),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (h *handler) GetEventStream(_ *arkv1.GetEventStreamRequest, stream arkv1.ArkService_GetEventStreamServer) error {
|
||||
listener := &listener{
|
||||
id: uuid.NewString(),
|
||||
ch: make(chan *arkv1.GetEventStreamResponse),
|
||||
}
|
||||
|
||||
defer h.removeListener(listener.id)
|
||||
defer close(listener.ch)
|
||||
|
||||
h.pushListener(listener)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-stream.Context().Done():
|
||||
return nil
|
||||
|
||||
case ev := <-listener.ch:
|
||||
if err := stream.Send(ev); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch ev.Event.(type) {
|
||||
case *arkv1.GetEventStreamResponse_RoundFinalized, *arkv1.GetEventStreamResponse_RoundFailed:
|
||||
if err := stream.Send(ev); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *handler) ListVtxos(ctx context.Context, req *arkv1.ListVtxosRequest) (*arkv1.ListVtxosResponse, error) {
|
||||
hrp, userPubkey, aspPubkey, err := parseAddress(req.GetAddress())
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.InvalidArgument, err.Error())
|
||||
}
|
||||
|
||||
vtxos, err := h.svc.ListVtxos(ctx, userPubkey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &arkv1.ListVtxosResponse{
|
||||
Vtxos: vtxoList(vtxos).toProto(hrp, aspPubkey),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (h *handler) GetPubkey(ctx context.Context, req *arkv1.GetPubkeyRequest) (*arkv1.GetPubkeyResponse, error) {
|
||||
pubkey, err := h.svc.GetPubkey(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &arkv1.GetPubkeyResponse{
|
||||
Pubkey: pubkey,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (h *handler) pushListener(l *listener) {
|
||||
h.listenersLock.Lock()
|
||||
defer h.listenersLock.Unlock()
|
||||
|
||||
h.listeners = append(h.listeners, l)
|
||||
}
|
||||
|
||||
func (h *handler) removeListener(id string) {
|
||||
h.listenersLock.Lock()
|
||||
defer h.listenersLock.Unlock()
|
||||
|
||||
for i, listener := range h.listeners {
|
||||
if listener.id == id {
|
||||
h.listeners = append(h.listeners[:i], h.listeners[i+1:]...)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// listenToEvents forwards events from the application layer to the set of listeners
|
||||
func (h *handler) listenToEvents() {
|
||||
channel := h.svc.GetEventsChannel(context.Background())
|
||||
for event := range channel {
|
||||
var ev *arkv1.GetEventStreamResponse
|
||||
|
||||
switch e := event.(type) {
|
||||
case domain.RoundFinalizationStarted:
|
||||
ev = &arkv1.GetEventStreamResponse{
|
||||
Event: &arkv1.GetEventStreamResponse_RoundFinalization{
|
||||
RoundFinalization: &arkv1.RoundFinalizationEvent{
|
||||
Id: e.Id,
|
||||
PoolPartialTx: e.PoolTx,
|
||||
CongestionTree: castCongestionTree(e.CongestionTree),
|
||||
ForfeitTxs: e.UnsignedForfeitTxs,
|
||||
},
|
||||
},
|
||||
}
|
||||
case domain.RoundFinalized:
|
||||
ev = &arkv1.GetEventStreamResponse{
|
||||
Event: &arkv1.GetEventStreamResponse_RoundFinalized{
|
||||
RoundFinalized: &arkv1.RoundFinalizedEvent{
|
||||
Id: e.Id,
|
||||
PoolTxid: e.Txid,
|
||||
},
|
||||
},
|
||||
}
|
||||
case domain.RoundFailed:
|
||||
ev = &arkv1.GetEventStreamResponse{
|
||||
Event: &arkv1.GetEventStreamResponse_RoundFailed{
|
||||
RoundFailed: &arkv1.RoundFailed{
|
||||
Id: e.Id,
|
||||
Reason: e.Err,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
if ev != nil {
|
||||
for _, listener := range h.listeners {
|
||||
listener.ch <- ev
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type vtxoList []domain.Vtxo
|
||||
|
||||
func (v vtxoList) toProto(hrp string, aspKey *secp256k1.PublicKey) []*arkv1.Vtxo {
|
||||
list := make([]*arkv1.Vtxo, 0, len(v))
|
||||
for _, vv := range v {
|
||||
addr := vv.OnchainAddress
|
||||
if vv.Pubkey != "" {
|
||||
buf, _ := hex.DecodeString(vv.Pubkey)
|
||||
key, _ := secp256k1.ParsePubKey(buf)
|
||||
addr, _ = common.EncodeAddress(hrp, key, aspKey)
|
||||
}
|
||||
list = append(list, &arkv1.Vtxo{
|
||||
Outpoint: &arkv1.Input{
|
||||
Txid: vv.Txid,
|
||||
Vout: vv.VOut,
|
||||
},
|
||||
Receiver: &arkv1.Output{
|
||||
Address: addr,
|
||||
Amount: vv.Amount,
|
||||
},
|
||||
PoolTxid: vv.PoolTx,
|
||||
Spent: vv.Spent,
|
||||
})
|
||||
}
|
||||
return list
|
||||
}
|
||||
|
||||
// castCongestionTree converts a tree.CongestionTree to a repeated arkv1.TreeLevel
|
||||
func castCongestionTree(congestionTree tree.CongestionTree) *arkv1.Tree {
|
||||
levels := make([]*arkv1.TreeLevel, 0, len(congestionTree))
|
||||
for _, level := range congestionTree {
|
||||
levelProto := &arkv1.TreeLevel{
|
||||
Nodes: make([]*arkv1.Node, 0, len(level)),
|
||||
}
|
||||
|
||||
for _, node := range level {
|
||||
levelProto.Nodes = append(levelProto.Nodes, &arkv1.Node{
|
||||
Txid: node.Txid,
|
||||
Tx: node.Tx,
|
||||
ParentTxid: node.ParentTxid,
|
||||
})
|
||||
}
|
||||
|
||||
levels = append(levels, levelProto)
|
||||
}
|
||||
return &arkv1.Tree{
|
||||
Levels: levels,
|
||||
}
|
||||
}
|
||||
64
server/internal/interface/grpc/handlers/utils.go
Normal file
64
server/internal/interface/grpc/handlers/utils.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
|
||||
arkv1 "github.com/ark-network/ark/api-spec/protobuf/gen/ark/v1"
|
||||
"github.com/ark-network/ark/common"
|
||||
"github.com/ark-network/ark/internal/core/domain"
|
||||
"github.com/decred/dcrd/dcrec/secp256k1/v4"
|
||||
"github.com/vulpemventures/go-elements/address"
|
||||
"github.com/vulpemventures/go-elements/psetv2"
|
||||
)
|
||||
|
||||
func parseTxs(txs []string) ([]string, error) {
|
||||
if len(txs) <= 0 {
|
||||
return nil, fmt.Errorf("missing list of forfeit txs")
|
||||
}
|
||||
for _, tx := range txs {
|
||||
if _, err := psetv2.NewPsetFromBase64(tx); err != nil {
|
||||
return nil, fmt.Errorf("invalid tx format")
|
||||
}
|
||||
}
|
||||
return txs, nil
|
||||
}
|
||||
|
||||
func parseAddress(addr string) (string, *secp256k1.PublicKey, *secp256k1.PublicKey, error) {
|
||||
if len(addr) <= 0 {
|
||||
return "", nil, nil, fmt.Errorf("missing address")
|
||||
}
|
||||
return common.DecodeAddress(addr)
|
||||
}
|
||||
|
||||
func parseReceivers(outs []*arkv1.Output) ([]domain.Receiver, error) {
|
||||
receivers := make([]domain.Receiver, 0, len(outs))
|
||||
for _, out := range outs {
|
||||
if out.GetAmount() == 0 {
|
||||
return nil, fmt.Errorf("missing output amount")
|
||||
}
|
||||
if len(out.GetAddress()) <= 0 {
|
||||
return nil, fmt.Errorf("missing output address")
|
||||
}
|
||||
var pubkey, addr string
|
||||
_, pk, _, err := common.DecodeAddress(out.GetAddress())
|
||||
if err != nil {
|
||||
if _, err := address.ToOutputScript(out.GetAddress()); err != nil {
|
||||
return nil, fmt.Errorf("invalid output address: unknown format")
|
||||
}
|
||||
if isConf, _ := address.IsConfidential(out.GetAddress()); isConf {
|
||||
return nil, fmt.Errorf("invalid output address: must be unconfidential")
|
||||
}
|
||||
addr = out.GetAddress()
|
||||
}
|
||||
if pk != nil {
|
||||
pubkey = hex.EncodeToString(pk.SerializeCompressed())
|
||||
}
|
||||
receivers = append(receivers, domain.Receiver{
|
||||
Pubkey: pubkey,
|
||||
Amount: out.GetAmount(),
|
||||
OnchainAddress: addr,
|
||||
})
|
||||
}
|
||||
return receivers, nil
|
||||
}
|
||||
16
server/internal/interface/grpc/interceptors/interceptor.go
Normal file
16
server/internal/interface/grpc/interceptors/interceptor.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package interceptors
|
||||
|
||||
import (
|
||||
middleware "github.com/grpc-ecosystem/go-grpc-middleware"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
// UnaryInterceptor returns the unary interceptor
|
||||
func UnaryInterceptor() grpc.ServerOption {
|
||||
return grpc.UnaryInterceptor(middleware.ChainUnaryServer(unaryLogger))
|
||||
}
|
||||
|
||||
// StreamInterceptor returns the stream interceptor with a logrus log
|
||||
func StreamInterceptor() grpc.ServerOption {
|
||||
return grpc.StreamInterceptor(middleware.ChainStreamServer(streamLogger))
|
||||
}
|
||||
28
server/internal/interface/grpc/interceptors/logger.go
Normal file
28
server/internal/interface/grpc/interceptors/logger.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package interceptors
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
func unaryLogger(
|
||||
ctx context.Context,
|
||||
req interface{},
|
||||
info *grpc.UnaryServerInfo,
|
||||
handler grpc.UnaryHandler,
|
||||
) (interface{}, error) {
|
||||
log.Debugf("gRPC method: %s", info.FullMethod)
|
||||
return handler(ctx, req)
|
||||
}
|
||||
|
||||
func streamLogger(
|
||||
srv interface{},
|
||||
stream grpc.ServerStream,
|
||||
info *grpc.StreamServerInfo,
|
||||
handler grpc.StreamHandler,
|
||||
) error {
|
||||
log.Debugf("gRPC method: %s", info.FullMethod)
|
||||
return handler(srv, stream)
|
||||
}
|
||||
0
server/internal/interface/grpc/permissions/.gitkeep
Executable file
0
server/internal/interface/grpc/permissions/.gitkeep
Executable file
64
server/internal/interface/grpc/service.go
Normal file
64
server/internal/interface/grpc/service.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package grpcservice
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
arkv1 "github.com/ark-network/ark/api-spec/protobuf/gen/ark/v1"
|
||||
appconfig "github.com/ark-network/ark/internal/app-config"
|
||||
interfaces "github.com/ark-network/ark/internal/interface"
|
||||
"github.com/ark-network/ark/internal/interface/grpc/handlers"
|
||||
"github.com/ark-network/ark/internal/interface/grpc/interceptors"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
)
|
||||
|
||||
type service struct {
|
||||
config Config
|
||||
appConfig *appconfig.Config
|
||||
server *grpc.Server
|
||||
}
|
||||
|
||||
func NewService(
|
||||
svcConfig Config, appConfig *appconfig.Config,
|
||||
) (interfaces.Service, error) {
|
||||
if err := svcConfig.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("invalid service config: %s", err)
|
||||
}
|
||||
if err := appConfig.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("invalid app config: %s", err)
|
||||
}
|
||||
|
||||
grpcConfig := []grpc.ServerOption{
|
||||
interceptors.UnaryInterceptor(), interceptors.StreamInterceptor(),
|
||||
}
|
||||
if !svcConfig.NoTLS {
|
||||
return nil, fmt.Errorf("tls termination not supported yet")
|
||||
}
|
||||
creds := insecure.NewCredentials()
|
||||
grpcConfig = append(grpcConfig, grpc.Creds(creds))
|
||||
server := grpc.NewServer(grpcConfig...)
|
||||
handler := handlers.NewHandler(appConfig.AppService())
|
||||
arkv1.RegisterArkServiceServer(server, handler)
|
||||
return &service{svcConfig, appConfig, server}, nil
|
||||
}
|
||||
|
||||
func (s *service) Start() error {
|
||||
// nolint:all
|
||||
go s.server.Serve(s.config.listener())
|
||||
log.Infof("started listening at %s", s.config.address())
|
||||
|
||||
if err := s.appConfig.AppService().Start(); err != nil {
|
||||
return fmt.Errorf("failed to start app service: %s", err)
|
||||
}
|
||||
log.Info("started app service")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *service) Stop() {
|
||||
s.server.Stop()
|
||||
log.Info("stopped grpc server")
|
||||
s.appConfig.AppService().Stop()
|
||||
log.Info("stopped app service")
|
||||
}
|
||||
6
server/internal/interface/service.go
Executable file
6
server/internal/interface/service.go
Executable file
@@ -0,0 +1,6 @@
|
||||
package interfaces
|
||||
|
||||
type Service interface {
|
||||
Start() error
|
||||
Stop()
|
||||
}
|
||||
0
server/internal/test/.gitkeep
Executable file
0
server/internal/test/.gitkeep
Executable file
Reference in New Issue
Block a user