Rename folders (#97)

* Rename arkd folder & drop cli

* Rename ark cli folder & update docs

* Update readme

* Fix

* scripts: add build-all

* Add target to build cli for all platforms

* Update build scripts

---------

Co-authored-by: tiero <3596602+tiero@users.noreply.github.com>
This commit is contained in:
Pietralberto Mazza
2024-02-09 19:32:58 +01:00
committed by GitHub
parent 0d8c7bffb2
commit dc00d60585
119 changed files with 154 additions and 449 deletions

View File

@@ -0,0 +1,154 @@
package badgerdb
import (
"context"
"fmt"
"path/filepath"
"sync"
"github.com/ark-network/ark/internal/core/domain"
dbtypes "github.com/ark-network/ark/internal/infrastructure/db/types"
"github.com/dgraph-io/badger/v4"
"github.com/timshannon/badgerhold/v4"
)
const eventStoreDir = "round-events"
type eventsDTO struct {
Events [][]byte
}
type eventRepository struct {
store *badgerhold.Store
lock *sync.RWMutex
chUpdates chan *domain.Round
handler func(round *domain.Round)
}
func NewRoundEventRepository(config ...interface{}) (dbtypes.EventStore, error) {
if len(config) != 2 {
return nil, fmt.Errorf("invalid config")
}
baseDir, ok := config[0].(string)
if !ok {
return nil, fmt.Errorf("invalid base directory")
}
var logger badger.Logger
if config[1] != nil {
logger, ok = config[1].(badger.Logger)
if !ok {
return nil, fmt.Errorf("invalid logger")
}
}
var dir string
if len(baseDir) > 0 {
dir = filepath.Join(baseDir, eventStoreDir)
}
store, err := createDB(dir, logger)
if err != nil {
return nil, fmt.Errorf("failed to open round events store: %s", err)
}
chEvents := make(chan *domain.Round)
lock := &sync.RWMutex{}
repo := &eventRepository{store, lock, chEvents, nil}
go repo.listen()
return repo, nil
}
func (r *eventRepository) Save(
ctx context.Context, id string, events ...domain.RoundEvent,
) error {
allEvents, err := r.get(ctx, id)
if err != nil {
return err
}
allEvents = append(allEvents, events...)
if err := r.upsert(ctx, id, allEvents); err != nil {
return err
}
go r.publishEvents(allEvents)
return nil
}
func (r *eventRepository) Load(
ctx context.Context, id string,
) (*domain.Round, error) {
events, err := r.get(ctx, id)
if err != nil {
return nil, err
}
return domain.NewRoundFromEvents(events), nil
}
func (r *eventRepository) RegisterEventsHandler(
handler func(round *domain.Round),
) {
r.lock.Lock()
defer r.lock.Unlock()
r.handler = handler
}
func (r *eventRepository) Close() {
close(r.chUpdates)
r.store.Close()
}
func (r *eventRepository) get(
ctx context.Context, id string,
) ([]domain.RoundEvent, error) {
dto := eventsDTO{}
var err error
if ctx.Value("tx") != nil {
tx := ctx.Value("tx").(*badger.Txn)
err = r.store.TxGet(tx, id, &dto)
} else {
err = r.store.Get(id, &dto)
}
if err != nil {
if err == badgerhold.ErrNotFound {
return nil, nil
}
return nil, fmt.Errorf("failed to get events with id %s: %s", id, err)
}
return deserializeEvents(dto.Events)
}
func (r *eventRepository) upsert(
ctx context.Context, id string, events []domain.RoundEvent,
) error {
buf, err := serializeEvents(events)
if err != nil {
return err
}
if ctx.Value("tx") != nil {
tx := ctx.Value("tx").(*badger.Txn)
err = r.store.TxUpsert(tx, id, buf)
} else {
err = r.store.Upsert(id, buf)
}
if err != nil {
return fmt.Errorf("failed to upsert events with id %s: %s", id, err)
}
return nil
}
func (r *eventRepository) listen() {
for updatedRound := range r.chUpdates {
r.lock.RLock()
if r.handler != nil {
r.handler(updatedRound)
}
r.lock.RUnlock()
}
}
func (r *eventRepository) publishEvents(events []domain.RoundEvent) {
r.lock.Lock()
defer r.lock.Unlock()
round := domain.NewRoundFromEvents(events)
r.chUpdates <- round
}

View File

@@ -0,0 +1,140 @@
package badgerdb
import (
"context"
"fmt"
"path/filepath"
"github.com/ark-network/ark/internal/core/domain"
dbtypes "github.com/ark-network/ark/internal/infrastructure/db/types"
"github.com/dgraph-io/badger/v4"
"github.com/timshannon/badgerhold/v4"
)
const roundStoreDir = "rounds"
type roundRepository struct {
store *badgerhold.Store
}
func NewRoundRepository(config ...interface{}) (dbtypes.RoundStore, error) {
if len(config) != 2 {
return nil, fmt.Errorf("invalid config")
}
baseDir, ok := config[0].(string)
if !ok {
return nil, fmt.Errorf("invalid base directory")
}
var logger badger.Logger
if config[1] != nil {
logger, ok = config[1].(badger.Logger)
if !ok {
return nil, fmt.Errorf("invalid logger")
}
}
var dir string
if len(baseDir) > 0 {
dir = filepath.Join(baseDir, roundStoreDir)
}
store, err := createDB(dir, logger)
if err != nil {
return nil, fmt.Errorf("failed to open round events store: %s", err)
}
return &roundRepository{store}, nil
}
func (r *roundRepository) AddOrUpdateRound(
ctx context.Context, round domain.Round,
) error {
return r.addOrUpdateRound(ctx, round)
}
func (r *roundRepository) GetCurrentRound(
ctx context.Context,
) (*domain.Round, error) {
query := badgerhold.Where("Stage.Ended").Eq(false).And("Stage.Failed").Eq(false)
rounds, err := r.findRound(ctx, query)
if err != nil {
return nil, err
}
if len(rounds) <= 0 {
return nil, fmt.Errorf("ongoing round not found")
}
return &rounds[0], nil
}
func (r *roundRepository) GetRoundWithId(
ctx context.Context, id string,
) (*domain.Round, error) {
query := badgerhold.Where("Id").Eq(id)
rounds, err := r.findRound(ctx, query)
if err != nil {
return nil, err
}
if len(rounds) <= 0 {
return nil, fmt.Errorf("round with id %s not found", id)
}
round := &rounds[0]
return round, nil
}
func (r *roundRepository) GetRoundWithTxid(
ctx context.Context, txid string,
) (*domain.Round, error) {
query := badgerhold.Where("Txid").Eq(txid)
rounds, err := r.findRound(ctx, query)
if err != nil {
return nil, err
}
if len(rounds) <= 0 {
return nil, fmt.Errorf("round with txid %s not found", txid)
}
round := &rounds[0]
return round, nil
}
func (r *roundRepository) GetSweepableRounds(
ctx context.Context,
) ([]domain.Round, error) {
query := badgerhold.Where("Stage.Code").Eq(domain.FinalizationStage).
And("Stage.Ended").Eq(true).And("Swept").Eq(false)
rounds, err := r.findRound(ctx, query)
if err != nil {
return nil, err
}
return rounds, nil
}
func (r *roundRepository) Close() {
r.store.Close()
}
func (r *roundRepository) findRound(
ctx context.Context, query *badgerhold.Query,
) ([]domain.Round, error) {
var rounds []domain.Round
var err error
if ctx.Value("tx") != nil {
tx := ctx.Value("tx").(*badger.Txn)
err = r.store.TxFind(tx, &rounds, query)
} else {
err = r.store.Find(&rounds, query)
}
return rounds, err
}
func (r *roundRepository) addOrUpdateRound(
ctx context.Context, round domain.Round,
) (err error) {
if ctx.Value("tx") != nil {
tx := ctx.Value("tx").(*badger.Txn)
err = r.store.TxUpsert(tx, round.Id, round)
} else {
err = r.store.Upsert(round.Id, round)
}
return
}

View File

@@ -0,0 +1,116 @@
package badgerdb
import (
"encoding/json"
"fmt"
"time"
"github.com/ark-network/ark/internal/core/domain"
"github.com/dgraph-io/badger/v4"
"github.com/dgraph-io/badger/v4/options"
"github.com/timshannon/badgerhold/v4"
)
func createDB(dbDir string, logger badger.Logger) (*badgerhold.Store, error) {
isInMemory := len(dbDir) <= 0
opts := badger.DefaultOptions(dbDir)
opts.Logger = logger
if isInMemory {
opts.InMemory = true
} else {
opts.Compression = options.ZSTD
}
db, err := badgerhold.Open(badgerhold.Options{
Encoder: badgerhold.DefaultEncode,
Decoder: badgerhold.DefaultDecode,
SequenceBandwith: 100,
Options: opts,
})
if err != nil {
return nil, err
}
if !isInMemory {
ticker := time.NewTicker(30 * time.Minute)
go func() {
for {
<-ticker.C
if err := db.Badger().RunValueLogGC(0.5); err != nil && err != badger.ErrNoRewrite {
logger.Errorf("%s", err)
}
}
}()
}
return db, nil
}
func serializeEvents(events []domain.RoundEvent) (*eventsDTO, error) {
rawEvents := make([][]byte, 0, len(events))
for _, event := range events {
buf, err := serializeEvent(event)
if err != nil {
return nil, err
}
rawEvents = append(rawEvents, buf)
}
return &eventsDTO{rawEvents}, nil
}
func deserializeEvents(rawEvents [][]byte) ([]domain.RoundEvent, error) {
events := make([]domain.RoundEvent, 0)
for _, buf := range rawEvents {
event, err := deserializeEvent(buf)
if err != nil {
return nil, err
}
events = append(events, event)
}
return events, nil
}
func serializeEvent(event domain.RoundEvent) ([]byte, error) {
switch eventType := event.(type) {
default:
return json.Marshal(eventType)
}
}
func deserializeEvent(buf []byte) (domain.RoundEvent, error) {
{
var event = domain.RoundFailed{}
if err := json.Unmarshal(buf, &event); err == nil && len(event.Err) > 0 {
return event, nil
}
}
{
var event = domain.RoundFinalized{}
if err := json.Unmarshal(buf, &event); err == nil && len(event.Txid) > 0 {
return event, nil
}
}
{
var event = domain.RoundFinalizationStarted{}
if err := json.Unmarshal(buf, &event); err == nil && len(event.CongestionTree) > 0 {
return event, nil
}
}
{
var event = domain.PaymentsRegistered{}
if err := json.Unmarshal(buf, &event); err == nil && len(event.Payments) > 0 {
return event, nil
}
}
{
var event = domain.RoundStarted{}
if err := json.Unmarshal(buf, &event); err == nil && event.Timestamp > 0 {
return event, nil
}
}
return nil, fmt.Errorf("unknown event")
}

View File

@@ -0,0 +1,245 @@
package badgerdb
import (
"context"
"fmt"
"path/filepath"
"strings"
"github.com/ark-network/ark/internal/core/domain"
dbtypes "github.com/ark-network/ark/internal/infrastructure/db/types"
"github.com/dgraph-io/badger/v4"
"github.com/timshannon/badgerhold/v4"
)
const vtxoStoreDir = "vtxos"
type vtxoRepository struct {
store *badgerhold.Store
}
func NewVtxoRepository(config ...interface{}) (dbtypes.VtxoStore, error) {
if len(config) != 2 {
return nil, fmt.Errorf("invalid config")
}
baseDir, ok := config[0].(string)
if !ok {
return nil, fmt.Errorf("invalid base directory")
}
var logger badger.Logger
if config[1] != nil {
logger, ok = config[1].(badger.Logger)
if !ok {
return nil, fmt.Errorf("invalid logger")
}
}
var dir string
if len(baseDir) > 0 {
dir = filepath.Join(baseDir, vtxoStoreDir)
}
store, err := createDB(dir, logger)
if err != nil {
return nil, fmt.Errorf("failed to open round events store: %s", err)
}
return &vtxoRepository{store}, nil
}
func (r *vtxoRepository) AddVtxos(
ctx context.Context, vtxos []domain.Vtxo,
) error {
return r.addVtxos(ctx, vtxos)
}
func (r *vtxoRepository) SpendVtxos(
ctx context.Context, vtxoKeys []domain.VtxoKey,
) error {
for _, vtxoKey := range vtxoKeys {
if err := r.spendVtxo(ctx, vtxoKey); err != nil {
return err
}
}
return nil
}
func (r *vtxoRepository) RedeemVtxos(
ctx context.Context, vtxoKeys []domain.VtxoKey,
) ([]domain.Vtxo, error) {
vtxos := make([]domain.Vtxo, 0, len(vtxoKeys))
for _, vtxoKey := range vtxoKeys {
vtxo, err := r.redeemVtxo(ctx, vtxoKey)
if err != nil {
return nil, err
}
if vtxo != nil {
vtxos = append(vtxos, *vtxo)
}
}
return vtxos, nil
}
func (r *vtxoRepository) GetVtxos(
ctx context.Context, vtxoKeys []domain.VtxoKey,
) ([]domain.Vtxo, error) {
vtxos := make([]domain.Vtxo, 0, len(vtxoKeys))
for _, vtxoKey := range vtxoKeys {
vtxo, err := r.getVtxo(ctx, vtxoKey)
if err != nil {
return nil, err
}
vtxos = append(vtxos, *vtxo)
}
return vtxos, nil
}
func (r *vtxoRepository) GetVtxosForRound(
ctx context.Context, txid string,
) ([]domain.Vtxo, error) {
query := badgerhold.Where("Txid").Eq(txid)
return r.findVtxos(ctx, query)
}
func (r *vtxoRepository) GetSpendableVtxos(
ctx context.Context, pubkey string,
) ([]domain.Vtxo, error) {
query := badgerhold.Where("Spent").Eq(false).And("Redeemed").Eq(false).And("Swept").Eq(false)
if len(pubkey) > 0 {
query = query.And("Pubkey").Eq(pubkey)
}
return r.findVtxos(ctx, query)
}
func (r *vtxoRepository) SweepVtxos(
ctx context.Context, vtxoKeys []domain.VtxoKey,
) error {
for _, vtxoKey := range vtxoKeys {
if err := r.sweepVtxo(ctx, vtxoKey); err != nil {
return err
}
}
return nil
}
func (r *vtxoRepository) Close() {
r.store.Close()
}
func (r *vtxoRepository) addVtxos(
ctx context.Context, vtxos []domain.Vtxo,
) (err error) {
for _, vtxo := range vtxos {
vtxoKey := vtxo.VtxoKey.Hash()
if ctx.Value("tx") != nil {
tx := ctx.Value("tx").(*badger.Txn)
err = r.store.TxInsert(tx, vtxoKey, vtxo)
} else {
err = r.store.Insert(vtxoKey, vtxo)
}
}
if err != nil && err == badgerhold.ErrKeyExists {
err = nil
}
return
}
func (r *vtxoRepository) getVtxo(
ctx context.Context, vtxoKey domain.VtxoKey,
) (*domain.Vtxo, error) {
var vtxo domain.Vtxo
var err error
if ctx.Value("tx") != nil {
tx := ctx.Value("tx").(*badger.Txn)
err = r.store.TxGet(tx, vtxoKey.Hash(), &vtxo)
} else {
err = r.store.Get(vtxoKey.Hash(), &vtxo)
}
if err != nil && err == badgerhold.ErrNotFound {
return nil, fmt.Errorf("vtxo %s:%d not found", vtxoKey.Txid, vtxoKey.VOut)
}
return &vtxo, nil
}
func (r *vtxoRepository) spendVtxo(ctx context.Context, vtxoKey domain.VtxoKey) error {
vtxo, err := r.getVtxo(ctx, vtxoKey)
if err != nil {
if strings.Contains(err.Error(), "not found") {
return nil
}
return err
}
if vtxo.Spent {
return nil
}
vtxo.Spent = true
if ctx.Value("tx") != nil {
tx := ctx.Value("tx").(*badger.Txn)
err = r.store.TxUpdate(tx, vtxoKey.Hash(), *vtxo)
} else {
err = r.store.Update(vtxoKey.Hash(), *vtxo)
}
return err
}
func (r *vtxoRepository) redeemVtxo(ctx context.Context, vtxoKey domain.VtxoKey) (*domain.Vtxo, error) {
vtxo, err := r.getVtxo(ctx, vtxoKey)
if err != nil {
if strings.Contains(err.Error(), "not found") {
return nil, nil
}
return nil, err
}
if vtxo.Redeemed {
return nil, nil
}
vtxo.Redeemed = true
if ctx.Value("tx") != nil {
tx := ctx.Value("tx").(*badger.Txn)
err = r.store.TxUpdate(tx, vtxoKey.Hash(), *vtxo)
} else {
err = r.store.Update(vtxoKey.Hash(), *vtxo)
}
if err != nil {
return nil, err
}
return vtxo, nil
}
func (r *vtxoRepository) findVtxos(ctx context.Context, query *badgerhold.Query) ([]domain.Vtxo, error) {
var vtxos []domain.Vtxo
var err error
if ctx.Value("tx") != nil {
tx := ctx.Value("tx").(*badger.Txn)
err = r.store.TxFind(tx, &vtxos, query)
} else {
err = r.store.Find(&vtxos, query)
}
return vtxos, err
}
func (r *vtxoRepository) sweepVtxo(ctx context.Context, vtxoKey domain.VtxoKey) error {
vtxo, err := r.getVtxo(ctx, vtxoKey)
if err != nil {
return err
}
if vtxo.Swept {
return nil
}
vtxo.Swept = true
if ctx.Value("tx") != nil {
tx := ctx.Value("tx").(*badger.Txn)
err = r.store.TxUpdate(tx, vtxoKey.Hash(), *vtxo)
} else {
err = r.store.Update(vtxoKey.Hash(), *vtxo)
}
if err != nil {
return err
}
return nil
}

View File

@@ -0,0 +1,90 @@
package db
import (
"fmt"
"github.com/ark-network/ark/internal/core/domain"
"github.com/ark-network/ark/internal/core/ports"
badgerdb "github.com/ark-network/ark/internal/infrastructure/db/badger"
dbtypes "github.com/ark-network/ark/internal/infrastructure/db/types"
)
var (
eventStoreTypes = map[string]func(...interface{}) (dbtypes.EventStore, error){
"badger": badgerdb.NewRoundEventRepository,
}
roundStoreTypes = map[string]func(...interface{}) (dbtypes.RoundStore, error){
"badger": badgerdb.NewRoundRepository,
}
vtxoStoreTypes = map[string]func(...interface{}) (dbtypes.VtxoStore, error){
"badger": badgerdb.NewVtxoRepository,
}
)
type ServiceConfig struct {
EventStoreType string
RoundStoreType string
VtxoStoreType string
EventStoreConfig []interface{}
RoundStoreConfig []interface{}
VtxoStoreConfig []interface{}
}
type service struct {
eventStore dbtypes.EventStore
roundStore dbtypes.RoundStore
vtxoStore dbtypes.VtxoStore
}
func NewService(config ServiceConfig) (ports.RepoManager, error) {
eventStoreFactory, ok := eventStoreTypes[config.EventStoreType]
if !ok {
return nil, fmt.Errorf("event store type not supported")
}
roundStoreFactory, ok := roundStoreTypes[config.RoundStoreType]
if !ok {
return nil, fmt.Errorf("round store type not supported")
}
vtxoStoreFactory, ok := vtxoStoreTypes[config.VtxoStoreType]
if !ok {
return nil, fmt.Errorf("vtxo store type not supported")
}
eventStore, err := eventStoreFactory(config.EventStoreConfig...)
if err != nil {
return nil, fmt.Errorf("failed to open event store: %s", err)
}
roundStore, err := roundStoreFactory(config.RoundStoreConfig...)
if err != nil {
return nil, fmt.Errorf("failed to open round store: %s", err)
}
vtxoStore, err := vtxoStoreFactory(config.VtxoStoreConfig...)
if err != nil {
return nil, fmt.Errorf("failed to open vtxo store: %s", err)
}
return &service{eventStore, roundStore, vtxoStore}, nil
}
func (s *service) RegisterEventsHandler(handler func(round *domain.Round)) {
s.eventStore.RegisterEventsHandler(handler)
}
func (s *service) Events() domain.RoundEventRepository {
return s.eventStore
}
func (s *service) Rounds() domain.RoundRepository {
return s.roundStore
}
func (s *service) Vtxos() domain.VtxoRepository {
return s.vtxoStore
}
func (s *service) Close() {
s.eventStore.Close()
s.roundStore.Close()
s.vtxoStore.Close()
}

View File

@@ -0,0 +1,417 @@
package db_test
import (
"context"
"reflect"
"testing"
"time"
"github.com/ark-network/ark/common/tree"
"github.com/ark-network/ark/internal/core/domain"
"github.com/ark-network/ark/internal/core/ports"
"github.com/ark-network/ark/internal/infrastructure/db"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
const (
emptyPtx = "cHNldP8BAgQCAAAAAQQBAAEFAQABBgEDAfsEAgAAAAA="
emptyTx = "0200000000000000000000"
txid = "00000000000000000000000000000000000000000000000000000000000000000"
pubkey1 = "0300000000000000000000000000000000000000000000000000000000000000001"
pubkey2 = "0200000000000000000000000000000000000000000000000000000000000000002"
)
var congestionTree = [][]tree.Node{
{
{
Txid: txid,
Tx: emptyPtx,
ParentTxid: txid,
},
},
{
{
Txid: txid,
Tx: emptyPtx,
ParentTxid: txid,
},
{
Txid: txid,
Tx: emptyPtx,
ParentTxid: txid,
},
},
{
{
Txid: txid,
Tx: emptyPtx,
ParentTxid: txid,
},
{
Txid: txid,
Tx: emptyPtx,
ParentTxid: txid,
},
{
Txid: txid,
Tx: emptyPtx,
ParentTxid: txid,
},
{
Txid: txid,
Tx: emptyPtx,
ParentTxid: txid,
},
},
}
func TestService(t *testing.T) {
tests := []struct {
name string
config db.ServiceConfig
}{
{
name: "repo_manager_with_badger_stores",
config: db.ServiceConfig{
EventStoreType: "badger",
RoundStoreType: "badger",
VtxoStoreType: "badger",
EventStoreConfig: []interface{}{"", nil},
RoundStoreConfig: []interface{}{"", nil},
VtxoStoreConfig: []interface{}{"", nil},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
svc, err := db.NewService(tt.config)
require.NoError(t, err)
require.NotNil(t, svc)
testRoundEventRepository(t, svc)
testRoundRepository(t, svc)
testVtxoRepository(t, svc)
svc.Close()
})
}
}
func testRoundEventRepository(t *testing.T, svc ports.RepoManager) {
t.Run("test_event_repository", func(t *testing.T) {
fixtures := []struct {
roundId string
events []domain.RoundEvent
handler func(*domain.Round)
}{
{
roundId: "42dd81f7-cadd-482c-bf69-8e9209aae9f3",
events: []domain.RoundEvent{
domain.RoundStarted{
Id: "42dd81f7-cadd-482c-bf69-8e9209aae9f3",
Timestamp: 1701190270,
},
},
handler: func(round *domain.Round) {
require.NotNil(t, round)
require.Len(t, round.Events(), 1)
require.True(t, round.IsStarted())
require.False(t, round.IsFailed())
require.False(t, round.IsEnded())
},
},
{
roundId: "1ea610ff-bf3e-4068-9bfd-b6c3f553467e",
events: []domain.RoundEvent{
domain.RoundStarted{
Id: "1ea610ff-bf3e-4068-9bfd-b6c3f553467e",
Timestamp: 1701190270,
},
domain.RoundFinalizationStarted{
Id: "1ea610ff-bf3e-4068-9bfd-b6c3f553467e",
CongestionTree: congestionTree,
Connectors: []string{emptyPtx, emptyPtx},
PoolTx: emptyTx,
},
},
handler: func(round *domain.Round) {
require.NotNil(t, round)
require.Len(t, round.Events(), 2)
require.Len(t, round.CongestionTree, 3)
require.Equal(t, round.CongestionTree.NumberOfNodes(), 7)
require.Len(t, round.Connectors, 2)
},
},
{
roundId: "7578231e-428d-45ae-aaa4-e62c77ad5cec",
events: []domain.RoundEvent{
domain.RoundStarted{
Id: "7578231e-428d-45ae-aaa4-e62c77ad5cec",
Timestamp: 1701190270,
},
domain.RoundFinalizationStarted{
Id: "7578231e-428d-45ae-aaa4-e62c77ad5cec",
CongestionTree: congestionTree,
Connectors: []string{emptyPtx, emptyPtx},
PoolTx: emptyTx,
},
domain.RoundFinalized{
Id: "7578231e-428d-45ae-aaa4-e62c77ad5cec",
Txid: txid,
ForfeitTxs: []string{emptyPtx, emptyPtx, emptyPtx, emptyPtx},
Timestamp: 1701190300,
},
},
handler: func(round *domain.Round) {
require.NotNil(t, round)
require.Len(t, round.Events(), 3)
require.False(t, round.IsStarted())
require.False(t, round.IsFailed())
require.True(t, round.IsEnded())
require.NotEmpty(t, round.Txid)
},
},
}
ctx := context.Background()
for _, f := range fixtures {
svc.RegisterEventsHandler(f.handler)
err := svc.Events().Save(ctx, f.roundId, f.events...)
require.NoError(t, err)
round, err := svc.Events().Load(ctx, f.roundId)
require.NoError(t, err)
require.NotNil(t, round)
require.Equal(t, f.roundId, round.Id)
require.Len(t, round.Events(), len(f.events))
}
})
}
func testRoundRepository(t *testing.T, svc ports.RepoManager) {
t.Run("test_round_repository", func(t *testing.T) {
ctx := context.Background()
now := time.Now()
roundId := uuid.New().String()
round, err := svc.Rounds().GetRoundWithId(ctx, roundId)
require.Error(t, err)
require.Nil(t, round)
events := []domain.RoundEvent{
domain.RoundStarted{
Id: roundId,
Timestamp: now.Unix(),
},
}
round = domain.NewRoundFromEvents(events)
err = svc.Rounds().AddOrUpdateRound(ctx, *round)
require.NoError(t, err)
currentRound, err := svc.Rounds().GetCurrentRound(ctx)
require.NoError(t, err)
require.NotNil(t, currentRound)
require.Condition(t, roundsMatch(*round, *currentRound))
roundById, err := svc.Rounds().GetRoundWithId(ctx, roundId)
require.NoError(t, err)
require.NotNil(t, roundById)
require.Condition(t, roundsMatch(*round, *roundById))
newEvents := []domain.RoundEvent{
domain.PaymentsRegistered{
Id: roundId,
Payments: []domain.Payment{
{
Id: uuid.New().String(),
Inputs: []domain.Vtxo{{}},
Receivers: []domain.Receiver{{}},
},
{
Id: uuid.New().String(),
Inputs: []domain.Vtxo{{}},
Receivers: []domain.Receiver{{}, {}, {}},
},
},
},
domain.RoundFinalizationStarted{
Id: roundId,
CongestionTree: congestionTree,
Connectors: []string{emptyPtx, emptyPtx},
PoolTx: emptyTx,
},
}
events = append(events, newEvents...)
updatedRound := domain.NewRoundFromEvents(events)
err = svc.Rounds().AddOrUpdateRound(ctx, *updatedRound)
require.NoError(t, err)
currentRound, err = svc.Rounds().GetCurrentRound(ctx)
require.NoError(t, err)
require.NotNil(t, currentRound)
require.Condition(t, roundsMatch(*updatedRound, *currentRound))
roundById, err = svc.Rounds().GetRoundWithId(ctx, updatedRound.Id)
require.NoError(t, err)
require.NotNil(t, currentRound)
require.Condition(t, roundsMatch(*updatedRound, *roundById))
newEvents = []domain.RoundEvent{
domain.RoundFinalized{
Id: roundId,
Txid: txid,
ForfeitTxs: []string{emptyPtx, emptyPtx, emptyPtx, emptyPtx},
Timestamp: now.Add(60 * time.Second).Unix(),
},
}
events = append(events, newEvents...)
finalizedRound := domain.NewRoundFromEvents(events)
err = svc.Rounds().AddOrUpdateRound(ctx, *finalizedRound)
require.NoError(t, err)
currentRound, err = svc.Rounds().GetCurrentRound(ctx)
require.Error(t, err)
require.Nil(t, currentRound)
roundById, err = svc.Rounds().GetRoundWithId(ctx, roundId)
require.NoError(t, err)
require.NotNil(t, roundById)
require.Condition(t, roundsMatch(*finalizedRound, *roundById))
roundByTxid, err := svc.Rounds().GetRoundWithTxid(ctx, txid)
require.NoError(t, err)
require.NotNil(t, roundByTxid)
require.Condition(t, roundsMatch(*finalizedRound, *roundByTxid))
})
}
func testVtxoRepository(t *testing.T, svc ports.RepoManager) {
t.Run("test_vtxo_repository", func(t *testing.T) {
ctx := context.Background()
userVtxos := []domain.Vtxo{
{
VtxoKey: domain.VtxoKey{
Txid: txid,
VOut: 0,
},
Receiver: domain.Receiver{
Pubkey: pubkey1,
Amount: 1000,
},
},
{
VtxoKey: domain.VtxoKey{
Txid: txid,
VOut: 1,
},
Receiver: domain.Receiver{
Pubkey: pubkey1,
Amount: 2000,
},
},
}
newVtxos := append(userVtxos, domain.Vtxo{
VtxoKey: domain.VtxoKey{
Txid: txid,
VOut: 1,
},
Receiver: domain.Receiver{
Pubkey: pubkey2,
Amount: 2000,
},
})
vtxoKeys := make([]domain.VtxoKey, 0, len(userVtxos))
for _, v := range userVtxos {
vtxoKeys = append(vtxoKeys, v.VtxoKey)
}
vtxos, err := svc.Vtxos().GetVtxos(ctx, vtxoKeys)
require.Error(t, err)
require.Empty(t, vtxos)
spendableVtxos, err := svc.Vtxos().GetSpendableVtxos(ctx, pubkey1)
require.NoError(t, err)
require.Empty(t, spendableVtxos)
spendableVtxos, err = svc.Vtxos().GetSpendableVtxos(ctx, "")
require.NoError(t, err)
require.Empty(t, spendableVtxos)
err = svc.Vtxos().AddVtxos(ctx, newVtxos)
require.NoError(t, err)
vtxos, err = svc.Vtxos().GetVtxos(ctx, vtxoKeys)
require.NoError(t, err)
require.Exactly(t, userVtxos, vtxos)
spendableVtxos, err = svc.Vtxos().GetSpendableVtxos(ctx, pubkey1)
require.NoError(t, err)
require.Exactly(t, vtxos, spendableVtxos)
spendableVtxos, err = svc.Vtxos().GetSpendableVtxos(ctx, "")
require.NoError(t, err)
require.Exactly(t, userVtxos, spendableVtxos)
err = svc.Vtxos().SpendVtxos(ctx, vtxoKeys[:1])
require.NoError(t, err)
spentVtxos, err := svc.Vtxos().GetVtxos(ctx, vtxoKeys[:1])
require.NoError(t, err)
require.Len(t, spentVtxos, len(vtxoKeys[:1]))
for _, v := range spentVtxos {
require.True(t, v.Spent)
}
spendableVtxos, err = svc.Vtxos().GetSpendableVtxos(ctx, pubkey1)
require.NoError(t, err)
require.Exactly(t, vtxos[1:], spendableVtxos)
})
}
func roundsMatch(expected, got domain.Round) assert.Comparison {
return func() bool {
if expected.Id != got.Id {
return false
}
if expected.StartingTimestamp != got.StartingTimestamp {
return false
}
if expected.EndingTimestamp != got.EndingTimestamp {
return false
}
if expected.Stage != got.Stage {
return false
}
if !reflect.DeepEqual(expected.Payments, got.Payments) {
return false
}
if expected.Txid != got.Txid {
return false
}
if expected.UnsignedTx != got.UnsignedTx {
return false
}
if !reflect.DeepEqual(expected.ForfeitTxs, got.ForfeitTxs) {
return false
}
if !reflect.DeepEqual(expected.CongestionTree, got.CongestionTree) {
return false
}
if !reflect.DeepEqual(expected.Connectors, got.Connectors) {
return false
}
if expected.Version != got.Version {
return false
}
return true
}
}

View File

@@ -0,0 +1,19 @@
package dbtypes
import "github.com/ark-network/ark/internal/core/domain"
type EventStore interface {
domain.RoundEventRepository
RegisterEventsHandler(func(*domain.Round))
Close()
}
type RoundStore interface {
domain.RoundRepository
Close()
}
type VtxoStore interface {
domain.VtxoRepository
Close()
}

View File

@@ -0,0 +1,30 @@
package oceanwallet
import (
"context"
pb "github.com/ark-network/ark/api-spec/protobuf/gen/ocean/v1"
"github.com/vulpemventures/go-elements/address"
)
func (s *service) DeriveAddresses(
ctx context.Context, numOfAddresses int,
) ([]string, error) {
res, err := s.accountClient.DeriveAddresses(ctx, &pb.DeriveAddressesRequest{
AccountName: accountLabel,
NumOfAddresses: uint64(numOfAddresses),
})
if err != nil {
return nil, err
}
addresses := make([]string, 0, numOfAddresses)
for _, addr := range res.GetAddresses() {
if isConf, _ := address.IsConfidential(addr); !isConf {
addresses = append(addresses, addr)
continue
}
info, _ := address.FromConfidential(addr)
addresses = append(addresses, info.Address)
}
return addresses, nil
}

View File

@@ -0,0 +1,47 @@
package oceanwallet
import (
"context"
"crypto/sha256"
"encoding/hex"
pb "github.com/ark-network/ark/api-spec/protobuf/gen/ocean/v1"
"github.com/ark-network/ark/internal/core/domain"
"github.com/btcsuite/btcd/chaincfg/chainhash"
)
func (s *service) WatchScripts(ctx context.Context, scripts []string) error {
for _, script := range scripts {
if _, err := s.notifyClient.WatchExternalScript(ctx, &pb.WatchExternalScriptRequest{
Script: script,
}); err != nil {
return err
}
}
return nil
}
func (s *service) UnwatchScripts(ctx context.Context, scripts []string) error {
for _, script := range scripts {
scriptHash := calcScriptHash(script)
if _, err := s.notifyClient.UnwatchExternalScript(ctx, &pb.UnwatchExternalScriptRequest{
Label: scriptHash,
}); err != nil {
return err
}
}
return nil
}
func (s *service) GetNotificationChannel(ctx context.Context) chan []domain.VtxoKey {
return s.chVtxos
}
func calcScriptHash(script string) string {
buf, _ := hex.DecodeString(script)
hashedBuf := sha256.Sum256(buf)
hash, _ := chainhash.NewHash(hashedBuf[:])
return hash.String()
}

View File

@@ -0,0 +1,140 @@
package oceanwallet
import (
"context"
"fmt"
"io"
"strings"
pb "github.com/ark-network/ark/api-spec/protobuf/gen/ocean/v1"
"github.com/ark-network/ark/internal/core/domain"
"github.com/ark-network/ark/internal/core/ports"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/status"
)
type service struct {
addr string
conn *grpc.ClientConn
walletClient pb.WalletServiceClient
accountClient pb.AccountServiceClient
txClient pb.TransactionServiceClient
notifyClient pb.NotificationServiceClient
chVtxos chan []domain.VtxoKey
}
func NewService(addr string) (ports.WalletService, error) {
conn, err := grpc.Dial(addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
}
walletClient := pb.NewWalletServiceClient(conn)
accountClient := pb.NewAccountServiceClient(conn)
txClient := pb.NewTransactionServiceClient(conn)
notifyClient := pb.NewNotificationServiceClient(conn)
chVtxos := make(chan []domain.VtxoKey)
svc := &service{
addr: addr,
conn: conn,
walletClient: walletClient,
accountClient: accountClient,
txClient: txClient,
notifyClient: notifyClient,
chVtxos: chVtxos,
}
ctx := context.Background()
status, err := svc.Status(ctx)
if err != nil {
return nil, err
}
if !(status.IsInitialized() && status.IsUnlocked()) {
return nil, fmt.Errorf("wallet must be already initialized and unlocked")
}
// Create ark account at startup if needed.
info, err := walletClient.GetInfo(ctx, &pb.GetInfoRequest{})
if err != nil {
return nil, err
}
found := false
for _, account := range info.GetAccounts() {
if account.GetLabel() == accountLabel {
found = true
break
}
}
if !found {
if _, err := accountClient.CreateAccountBIP44(ctx, &pb.CreateAccountBIP44Request{
Label: accountLabel,
Unconfidential: true,
}); err != nil {
return nil, err
}
}
go svc.listenToNotificaitons()
return svc, nil
}
func (s *service) Close() {
close(s.chVtxos)
s.conn.Close()
}
func (s *service) listenToNotificaitons() {
var stream pb.NotificationService_UtxosNotificationsClient
var err error
for {
stream, err = s.notifyClient.UtxosNotifications(context.Background(), &pb.UtxosNotificationsRequest{})
if err != nil {
continue
}
break
}
for {
msg, err := stream.Recv()
if err != nil {
if err == io.EOF || status.Convert(err).Code() == codes.Canceled {
return
}
log.WithError(err).Warn("received unexpected error from source")
return
}
if msg.GetEventType() != pb.UtxoEventType_UTXO_EVENT_TYPE_NEW &&
msg.GetEventType() != pb.UtxoEventType_UTXO_EVENT_TYPE_CONFIRMED {
continue
}
vtxos := toVtxos(msg.GetUtxos())
if len(vtxos) > 0 {
go func() {
s.chVtxos <- vtxos
}()
}
}
}
func toVtxos(utxos []*pb.Utxo) []domain.VtxoKey {
vtxos := make([]domain.VtxoKey, 0, len(utxos))
for _, utxo := range utxos {
// We want to notify for activity related to vtxos owner, therefore we skip
// returning anything related to the internal accounts of the wallet, like
// for example bip84-account0.
if strings.HasPrefix(utxo.GetAccountName(), "bip") {
continue
}
vtxos = append(vtxos, domain.VtxoKey{
Txid: utxo.GetTxid(),
VOut: utxo.GetIndex(),
})
}
return vtxos
}

View File

@@ -0,0 +1,270 @@
package oceanwallet
import (
"context"
"encoding/binary"
"encoding/hex"
"fmt"
"strings"
"time"
pb "github.com/ark-network/ark/api-spec/protobuf/gen/ocean/v1"
"github.com/ark-network/ark/common/tree"
"github.com/ark-network/ark/internal/core/ports"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/txscript"
"github.com/vulpemventures/go-elements/elementsutil"
"github.com/vulpemventures/go-elements/psetv2"
)
const (
zero32 = "0000000000000000000000000000000000000000000000000000000000000000"
)
func (s *service) SignPset(
ctx context.Context, pset string, extractRawTx bool,
) (string, error) {
res, err := s.txClient.SignPset(ctx, &pb.SignPsetRequest{
Pset: pset,
})
if err != nil {
return "", err
}
signedPset := res.GetPset()
if !extractRawTx {
return signedPset, nil
}
ptx, err := psetv2.NewPsetFromBase64(signedPset)
if err != nil {
return "", err
}
if err := psetv2.MaybeFinalizeAll(ptx); err != nil {
return "", fmt.Errorf("failed to finalize signed pset: %s", err)
}
extractedTx, err := psetv2.Extract(ptx)
if err != nil {
return "", fmt.Errorf("failed to extract signed pset: %s", err)
}
txHex, err := extractedTx.ToHex()
if err != nil {
return "", fmt.Errorf("failed to convert extracted tx to hex: %s", err)
}
return txHex, nil
}
func (s *service) SelectUtxos(ctx context.Context, asset string, amount uint64) ([]ports.TxInput, uint64, error) {
res, err := s.txClient.SelectUtxos(ctx, &pb.SelectUtxosRequest{
AccountName: accountLabel,
TargetAsset: asset,
TargetAmount: amount,
})
if err != nil {
return nil, 0, err
}
inputs := make([]ports.TxInput, 0, len(res.GetUtxos()))
for _, utxo := range res.GetUtxos() {
// check that the utxos are not confidential
if utxo.GetAssetBlinder() != zero32 || utxo.GetValueBlinder() != zero32 {
return nil, 0, fmt.Errorf("utxo is confidential")
}
inputs = append(inputs, utxo)
}
return inputs, res.GetChange(), nil
}
func (s *service) GetTransaction(
ctx context.Context, txid string,
) (string, int64, error) {
res, err := s.txClient.GetTransaction(ctx, &pb.GetTransactionRequest{
Txid: txid,
})
if err != nil {
return "", 0, err
}
if res.GetBlockDetails().GetTimestamp() > 0 {
return res.GetTxHex(), res.BlockDetails.GetTimestamp(), nil
}
// if not confirmed, we return now + 30 secs to estimate the next blocktime
return res.GetTxHex(), time.Now().Unix() + 30, nil
}
func (s *service) BroadcastTransaction(
ctx context.Context, txHex string,
) (string, error) {
res, err := s.txClient.BroadcastTransaction(
ctx, &pb.BroadcastTransactionRequest{
TxHex: txHex,
},
)
if err != nil {
if strings.Contains(err.Error(), "non-BIP68-final") {
return "", fmt.Errorf("non-BIP68-final")
}
return "", err
}
return res.GetTxid(), nil
}
func (s *service) IsTransactionPublished(
ctx context.Context, txid string,
) (bool, int64, error) {
_, blocktime, err := s.GetTransaction(ctx, txid)
if err != nil {
if strings.Contains(strings.ToLower(err.Error()), "missing transaction") {
return false, 0, nil
}
return false, 0, err
}
return true, blocktime, nil
}
func (s *service) SignPsetWithKey(ctx context.Context, b64 string, indexes []int) (string, error) {
pset, err := psetv2.NewPsetFromBase64(b64)
if err != nil {
return "", err
}
if indexes == nil {
for i := 0; i < len(pset.Inputs); i++ {
indexes = append(indexes, i)
}
}
key, masterKey, err := s.getPubkey(ctx)
if err != nil {
return "", err
}
fingerprint := binary.LittleEndian.Uint32(masterKey.FingerPrint)
extendedKey, err := masterKey.Serialize()
if err != nil {
return "", err
}
pset.Global.Xpubs = []psetv2.Xpub{{
ExtendedKey: extendedKey[:len(extendedKey)-4],
MasterFingerprint: fingerprint,
DerivationPath: derivationPath,
}}
updater, err := psetv2.NewUpdater(pset)
if err != nil {
return "", err
}
bip32derivation := psetv2.DerivationPathWithPubKey{
PubKey: key.SerializeCompressed(),
MasterKeyFingerprint: fingerprint,
Bip32Path: derivationPath,
}
for _, i := range indexes {
if len(pset.Inputs[i].TapLeafScript) == 0 {
return "", fmt.Errorf("no tap leaf script found for input %d", i)
}
leafHash := pset.Inputs[i].TapLeafScript[0].TapHash()
if err := updater.AddInTapBip32Derivation(i, psetv2.TapDerivationPathWithPubKey{
DerivationPathWithPubKey: bip32derivation,
LeafHashes: [][]byte{leafHash[:]},
}); err != nil {
return "", err
}
if err := updater.AddInSighashType(i, txscript.SigHashDefault); err != nil {
return "", err
}
}
unsignedPset, err := pset.ToBase64()
if err != nil {
return "", err
}
signedPset, err := s.txClient.SignPsetWithSchnorrKey(ctx, &pb.SignPsetWithSchnorrKeyRequest{
Tx: unsignedPset,
SighashType: uint32(txscript.SigHashDefault),
})
if err != nil {
return "", err
}
return signedPset.GetSignedTx(), nil
}
func (s *service) EstimateFees(
ctx context.Context, pset string,
) (uint64, error) {
tx, err := psetv2.NewPsetFromBase64(pset)
if err != nil {
return 0, err
}
inputs := make([]*pb.Input, 0, len(tx.Inputs))
outputs := make([]*pb.Output, 0, len(tx.Outputs))
for _, in := range tx.Inputs {
pbInput := &pb.Input{
Txid: chainhash.Hash(in.PreviousTxid).String(),
Index: in.PreviousTxIndex,
}
if len(in.TapLeafScript) == 1 {
isSweep, _, _, err := tree.DecodeSweepScript(in.TapLeafScript[0].Script)
if err != nil {
return 0, err
}
if isSweep {
pbInput.WitnessSize = 64
pbInput.ScriptsigSize = 0
}
} else {
if in.WitnessUtxo == nil {
return 0, fmt.Errorf("missing witness utxo, cannot estimate fees")
}
pbInput.Script = hex.EncodeToString(in.WitnessUtxo.Script)
}
inputs = append(inputs, pbInput)
}
for _, out := range tx.Outputs {
outputs = append(outputs, &pb.Output{
Asset: elementsutil.AssetHashFromBytes(
append([]byte{0x01}, out.Asset...),
),
Amount: out.Value,
Script: hex.EncodeToString(out.Script),
})
}
fee, err := s.txClient.EstimateFees(
ctx,
&pb.EstimateFeesRequest{
Inputs: inputs,
Outputs: outputs,
},
)
if err != nil {
return 0, fmt.Errorf("failed to estimate fees: %s", err)
}
// we add 5 sats in order to avoid min-relay-fee not met errors
return fee.GetFeeAmount() + 5, nil
}

View File

@@ -0,0 +1,95 @@
package oceanwallet
import (
"context"
"fmt"
pb "github.com/ark-network/ark/api-spec/protobuf/gen/ocean/v1"
"github.com/ark-network/ark/internal/core/ports"
"github.com/btcsuite/btcd/btcutil/hdkeychain"
"github.com/decred/dcrd/dcrec/secp256k1/v4"
"github.com/vulpemventures/go-bip32"
)
const accountLabel = "ark"
var derivationPath = []uint32{0, 0}
func (s *service) GetPubkey(ctx context.Context) (*secp256k1.PublicKey, error) {
key, _, err := s.getPubkey(ctx)
return key, err
}
func (s *service) Status(
ctx context.Context,
) (ports.WalletStatus, error) {
res, err := s.walletClient.Status(ctx, &pb.StatusRequest{})
if err != nil {
return nil, err
}
return walletStatus{res}, nil
}
type walletStatus struct {
*pb.StatusResponse
}
func (w walletStatus) IsInitialized() bool {
return w.StatusResponse.GetInitialized()
}
func (w walletStatus) IsUnlocked() bool {
return w.StatusResponse.GetUnlocked()
}
func (w walletStatus) IsSynced() bool {
return w.StatusResponse.GetSynced()
}
func (s *service) findAccount(ctx context.Context, label string) (*pb.AccountInfo, error) {
res, err := s.walletClient.GetInfo(ctx, &pb.GetInfoRequest{})
if err != nil {
return nil, err
}
if len(res.GetAccounts()) <= 0 {
return nil, fmt.Errorf("wallet is locked")
}
for _, account := range res.GetAccounts() {
if account.GetLabel() == label {
return account, nil
}
}
return nil, fmt.Errorf("account not found")
}
func (s *service) getPubkey(ctx context.Context) (*secp256k1.PublicKey, *bip32.Key, error) {
account, err := s.findAccount(ctx, accountLabel)
if err != nil {
return nil, nil, err
}
xpub := account.GetXpubs()[0]
node, err := hdkeychain.NewKeyFromString(xpub)
if err != nil {
return nil, nil, err
}
for _, i := range derivationPath {
node, err = node.Derive(i)
if err != nil {
return nil, nil, err
}
}
key, err := node.ECPubKey()
if err != nil {
return nil, nil, err
}
masterKey, err := bip32.B58Deserialize(xpub)
if err != nil {
return nil, nil, err
}
return key, masterKey, nil
}

View File

@@ -0,0 +1,45 @@
package scheduler
import (
"fmt"
"time"
"github.com/ark-network/ark/internal/core/ports"
"github.com/go-co-op/gocron"
)
type service struct {
scheduler *gocron.Scheduler
}
func NewScheduler() ports.SchedulerService {
svc := gocron.NewScheduler(time.UTC)
return &service{svc}
}
func (s *service) Start() {
s.scheduler.StartAsync()
}
func (s *service) Stop() {
s.scheduler.Stop()
}
func (s *service) ScheduleTask(interval int64, immediate bool, task func()) error {
if immediate {
_, err := s.scheduler.Every(int(interval)).Seconds().Do(task)
return err
}
_, err := s.scheduler.Every(int(interval)).Seconds().WaitForSchedule().Do(task)
return err
}
func (s *service) ScheduleTaskOnce(at int64, task func()) error {
delay := at - time.Now().Unix()
if delay < 0 {
return fmt.Errorf("cannot schedule task in the past")
}
_, err := s.scheduler.Every(int(delay)).Seconds().WaitForSchedule().LimitRunsTo(1).Do(task)
return err
}

View File

@@ -0,0 +1,483 @@
package txbuilder
import (
"context"
"encoding/hex"
"fmt"
"github.com/ark-network/ark/common/tree"
"github.com/ark-network/ark/internal/core/domain"
"github.com/ark-network/ark/internal/core/ports"
"github.com/decred/dcrd/dcrec/secp256k1/v4"
"github.com/vulpemventures/go-elements/address"
"github.com/vulpemventures/go-elements/network"
"github.com/vulpemventures/go-elements/psetv2"
"github.com/vulpemventures/go-elements/taproot"
)
const (
connectorAmount = uint64(450)
dustLimit = uint64(450)
)
type txBuilder struct {
wallet ports.WalletService
net *network.Network
roundLifetime int64 // in seconds
}
func NewTxBuilder(
wallet ports.WalletService, net network.Network, roundLifetime int64,
) ports.TxBuilder {
return &txBuilder{wallet, &net, roundLifetime}
}
func (b *txBuilder) GetVtxoScript(userPubkey, aspPubkey *secp256k1.PublicKey) ([]byte, error) {
outputScript, _, err := b.getLeafScriptAndTree(userPubkey, aspPubkey)
if err != nil {
return nil, err
}
return outputScript, nil
}
func (b *txBuilder) BuildSweepTx(wallet ports.WalletService, inputs []ports.SweepInput) (signedSweepTx string, err error) {
sweepPset, err := sweepTransaction(
wallet,
inputs,
b.net.AssetID,
)
if err != nil {
return "", err
}
sweepPsetBase64, err := sweepPset.ToBase64()
if err != nil {
return "", err
}
ctx := context.Background()
signedSweepPsetB64, err := wallet.SignPsetWithKey(ctx, sweepPsetBase64, nil)
if err != nil {
return "", err
}
signedPset, err := psetv2.NewPsetFromBase64(signedSweepPsetB64)
if err != nil {
return "", err
}
if err := psetv2.FinalizeAll(signedPset); err != nil {
return "", err
}
extractedTx, err := psetv2.Extract(signedPset)
if err != nil {
return "", err
}
return extractedTx.ToHex()
}
func (b *txBuilder) BuildForfeitTxs(
aspPubkey *secp256k1.PublicKey, poolTx string, payments []domain.Payment,
) (connectors []string, forfeitTxs []string, err error) {
connectorTxs, err := b.createConnectors(poolTx, payments, aspPubkey)
if err != nil {
return nil, nil, err
}
forfeitTxs, err = b.createForfeitTxs(aspPubkey, payments, connectorTxs)
if err != nil {
return nil, nil, err
}
for _, tx := range connectorTxs {
buf, _ := tx.ToBase64()
connectors = append(connectors, buf)
}
return connectors, forfeitTxs, nil
}
func (b *txBuilder) BuildPoolTx(
aspPubkey *secp256k1.PublicKey, payments []domain.Payment, minRelayFee uint64,
) (poolTx string, congestionTree tree.CongestionTree, err error) {
// The creation of the tree and the pool tx are tightly coupled:
// - building the tree requires knowing the shared outpoint (txid:vout)
// - building the pool tx requires knowing the shared output script and amount
// The idea here is to first create all the data for the outputs of the txs
// of the congestion tree to calculate the shared output script and amount.
// With these data the pool tx can be created, and once the shared utxo
// outpoint is obtained, the congestion tree can be finally created.
// The factory function `treeFactoryFn` returned below holds all outputs data
// generated in the process and takes the shared utxo outpoint as argument.
// This is safe as the memory allocated for `craftCongestionTree` is freed
// only after `BuildPoolTx` returns.
treeFactoryFn, sharedOutputScript, sharedOutputAmount, err := craftCongestionTree(
b.net.AssetID, aspPubkey, payments, minRelayFee, b.roundLifetime,
)
if err != nil {
return
}
ptx, err := b.createPoolTx(
sharedOutputAmount, sharedOutputScript, payments, aspPubkey,
)
if err != nil {
return
}
unsignedTx, err := ptx.UnsignedTx()
if err != nil {
return
}
tree, err := treeFactoryFn(psetv2.InputArgs{
Txid: unsignedTx.TxHash().String(),
TxIndex: 0,
})
if err != nil {
return
}
poolTx, err = ptx.ToBase64()
if err != nil {
return
}
congestionTree = tree
return
}
func (b *txBuilder) GetLeafSweepClosure(
node tree.Node, userPubKey *secp256k1.PublicKey,
) (*psetv2.TapLeafScript, int64, error) {
if !node.Leaf {
return nil, 0, fmt.Errorf("node is not a leaf")
}
pset, err := psetv2.NewPsetFromBase64(node.Tx)
if err != nil {
return nil, 0, err
}
input := pset.Inputs[0]
sweepLeaf, lifetime, err := extractSweepLeaf(input)
if err != nil {
return nil, 0, err
}
// craft the vtxo taproot tree
vtxoScript, err := tree.VtxoScript(userPubKey)
if err != nil {
return nil, 0, err
}
vtxoTaprootTree := taproot.AssembleTaprootScriptTree(
*vtxoScript,
sweepLeaf.TapElementsLeaf,
)
proofIndex := vtxoTaprootTree.LeafProofIndex[sweepLeaf.TapHash()]
proof := vtxoTaprootTree.LeafMerkleProofs[proofIndex]
return &psetv2.TapLeafScript{
TapElementsLeaf: proof.TapElementsLeaf,
ControlBlock: proof.ToControlBlock(sweepLeaf.ControlBlock.InternalKey),
}, lifetime, nil
}
func (b *txBuilder) getLeafScriptAndTree(
userPubkey, aspPubkey *secp256k1.PublicKey,
) ([]byte, *taproot.IndexedElementsTapScriptTree, error) {
redeemClosure, err := tree.VtxoScript(userPubkey)
if err != nil {
return nil, nil, err
}
sweepClosure, err := tree.SweepScript(aspPubkey, uint(b.roundLifetime))
if err != nil {
return nil, nil, err
}
taprootTree := taproot.AssembleTaprootScriptTree(
*redeemClosure, *sweepClosure,
)
root := taprootTree.RootNode.TapHash()
unspendableKey := tree.UnspendableKey()
taprootKey := taproot.ComputeTaprootOutputKey(unspendableKey, root[:])
outputScript, err := taprootOutputScript(taprootKey)
if err != nil {
return nil, nil, err
}
return outputScript, taprootTree, nil
}
func (b *txBuilder) createPoolTx(
sharedOutputAmount uint64, sharedOutputScript []byte,
payments []domain.Payment, aspPubKey *secp256k1.PublicKey,
) (*psetv2.Pset, error) {
aspScript, err := p2wpkhScript(aspPubKey, b.net)
if err != nil {
return nil, err
}
receivers := getOnchainReceivers(payments)
connectorsAmount := connectorAmount * countSpentVtxos(payments)
targetAmount := sharedOutputAmount + connectorsAmount
outputs := []psetv2.OutputArgs{
{
Asset: b.net.AssetID,
Amount: sharedOutputAmount,
Script: sharedOutputScript,
},
{
Asset: b.net.AssetID,
Amount: connectorsAmount,
Script: aspScript,
},
}
for _, receiver := range receivers {
targetAmount += receiver.Amount
receiverScript, err := address.ToOutputScript(receiver.OnchainAddress)
if err != nil {
return nil, err
}
outputs = append(outputs, psetv2.OutputArgs{
Asset: b.net.AssetID,
Amount: receiver.Amount,
Script: receiverScript,
})
}
ctx := context.Background()
utxos, change, err := b.wallet.SelectUtxos(ctx, b.net.AssetID, targetAmount)
if err != nil {
return nil, err
}
var dust uint64
if change > 0 {
if change < dustLimit {
dust = change
change = 0
} else {
outputs = append(outputs, psetv2.OutputArgs{
Asset: b.net.AssetID,
Amount: change,
Script: aspScript,
})
}
}
ptx, err := psetv2.New(nil, outputs, nil)
if err != nil {
return nil, err
}
updater, err := psetv2.NewUpdater(ptx)
if err != nil {
return nil, err
}
if err := addInputs(updater, utxos); err != nil {
return nil, err
}
b64, err := ptx.ToBase64()
if err != nil {
return nil, err
}
feeAmount, err := b.wallet.EstimateFees(ctx, b64)
if err != nil {
return nil, err
}
if dust > feeAmount {
feeAmount = dust
} else {
feeAmount += dust
}
if dust == 0 {
if feeAmount == change {
// fees = change, remove change output
ptx.Outputs = ptx.Outputs[:len(ptx.Outputs)-1]
} else if feeAmount < change {
// change covers the fees, reduce change amount
ptx.Outputs[len(ptx.Outputs)-1].Value = change - feeAmount
} else {
// change is not enough to cover fees, re-select utxos
if change > 0 {
// remove change output if present
ptx.Outputs = ptx.Outputs[:len(ptx.Outputs)-1]
}
newUtxos, change, err := b.wallet.SelectUtxos(ctx, b.net.AssetID, feeAmount-change)
if err != nil {
return nil, err
}
if change > 0 {
if err := updater.AddOutputs([]psetv2.OutputArgs{
{
Asset: b.net.AssetID,
Amount: change,
Script: aspScript,
},
}); err != nil {
return nil, err
}
}
if err := addInputs(updater, newUtxos); err != nil {
return nil, err
}
}
}
// add fee output
if err := updater.AddOutputs([]psetv2.OutputArgs{
{
Asset: b.net.AssetID,
Amount: feeAmount,
},
}); err != nil {
return nil, err
}
return ptx, nil
}
func (b *txBuilder) createConnectors(
poolTx string, payments []domain.Payment, aspPubkey *secp256k1.PublicKey,
) ([]*psetv2.Pset, error) {
txid, _ := getTxid(poolTx)
aspScript, err := p2wpkhScript(aspPubkey, b.net)
if err != nil {
return nil, err
}
connectorOutput := psetv2.OutputArgs{
Asset: b.net.AssetID,
Script: aspScript,
Amount: connectorAmount,
}
numberOfConnectors := countSpentVtxos(payments)
previousInput := psetv2.InputArgs{
Txid: txid,
TxIndex: 1,
}
if numberOfConnectors == 1 {
outputs := []psetv2.OutputArgs{connectorOutput}
connectorTx, err := craftConnectorTx(previousInput, outputs)
if err != nil {
return nil, err
}
return []*psetv2.Pset{connectorTx}, nil
}
totalConnectorAmount := connectorAmount * numberOfConnectors
connectors := make([]*psetv2.Pset, 0, numberOfConnectors-1)
for i := uint64(0); i < numberOfConnectors-1; i++ {
outputs := []psetv2.OutputArgs{connectorOutput}
totalConnectorAmount -= connectorAmount
if totalConnectorAmount > 0 {
outputs = append(outputs, psetv2.OutputArgs{
Asset: b.net.AssetID,
Script: aspScript,
Amount: totalConnectorAmount,
})
}
connectorTx, err := craftConnectorTx(previousInput, outputs)
if err != nil {
return nil, err
}
txid, _ := getPsetId(connectorTx)
previousInput = psetv2.InputArgs{
Txid: txid,
TxIndex: 1,
}
connectors = append(connectors, connectorTx)
}
return connectors, nil
}
func (b *txBuilder) createForfeitTxs(
aspPubkey *secp256k1.PublicKey, payments []domain.Payment, connectors []*psetv2.Pset,
) ([]string, error) {
aspScript, err := p2wpkhScript(aspPubkey, b.net)
if err != nil {
return nil, err
}
forfeitTxs := make([]string, 0)
for _, payment := range payments {
for _, vtxo := range payment.Inputs {
pubkeyBytes, err := hex.DecodeString(vtxo.Pubkey)
if err != nil {
return nil, fmt.Errorf("failed to decode pubkey: %s", err)
}
vtxoPubkey, err := secp256k1.ParsePubKey(pubkeyBytes)
if err != nil {
return nil, err
}
vtxoScript, vtxoTaprootTree, err := b.getLeafScriptAndTree(vtxoPubkey, aspPubkey)
if err != nil {
return nil, err
}
for _, connector := range connectors {
txs, err := craftForfeitTxs(
connector, vtxo, vtxoTaprootTree, vtxoScript, aspScript,
)
if err != nil {
return nil, err
}
forfeitTxs = append(forfeitTxs, txs...)
}
}
}
return forfeitTxs, nil
}
// given a congestion tree input, searches and returns the sweep leaf and its lifetime in seconds
func extractSweepLeaf(input psetv2.Input) (sweepLeaf *psetv2.TapLeafScript, lifetime int64, err error) {
for _, leaf := range input.TapLeafScript {
isSweep, _, seconds, err := tree.DecodeSweepScript(leaf.Script)
if err != nil {
return nil, 0, err
}
if isSweep {
lifetime = int64(seconds)
sweepLeaf = &leaf
break
}
}
if sweepLeaf == nil {
return nil, 0, fmt.Errorf("sweep leaf not found")
}
return sweepLeaf, lifetime, nil
}

View File

@@ -0,0 +1,229 @@
package txbuilder_test
import (
"crypto/rand"
"encoding/hex"
"encoding/json"
"os"
"testing"
"github.com/ark-network/ark/common"
"github.com/ark-network/ark/common/tree"
"github.com/ark-network/ark/internal/core/domain"
"github.com/ark-network/ark/internal/core/ports"
txbuilder "github.com/ark-network/ark/internal/infrastructure/tx-builder/covenant"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/decred/dcrd/dcrec/secp256k1/v4"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/vulpemventures/go-elements/network"
"github.com/vulpemventures/go-elements/psetv2"
)
const (
testingKey = "apub1qgvdtj5ttpuhkldavhq8thtm5auyk0ec4dcmrfdgu0u5hgp9we22v3hrs4x"
minRelayFee = uint64(30)
roundLifetime = int64(1209344)
)
var (
wallet *mockedWallet
pubkey *secp256k1.PublicKey
)
func TestMain(m *testing.M) {
wallet = &mockedWallet{}
wallet.On("EstimateFees", mock.Anything, mock.Anything).
Return(uint64(100), nil)
wallet.On("SelectUtxos", mock.Anything, mock.Anything, mock.Anything).
Return(randomInput, uint64(0), nil)
_, pubkey, _ = common.DecodePubKey(testingKey)
os.Exit(m.Run())
}
func TestBuildPoolTx(t *testing.T) {
builder := txbuilder.NewTxBuilder(wallet, network.Liquid, roundLifetime)
fixtures, err := parsePoolTxFixtures()
require.NoError(t, err)
require.NotEmpty(t, fixtures)
if len(fixtures.Valid) > 0 {
t.Run("valid", func(t *testing.T) {
for _, f := range fixtures.Valid {
poolTx, congestionTree, err := builder.BuildPoolTx(pubkey, f.Payments, minRelayFee)
require.NoError(t, err)
require.NotEmpty(t, poolTx)
require.NotEmpty(t, congestionTree)
require.Equal(t, f.ExpectedNumOfNodes, congestionTree.NumberOfNodes())
require.Len(t, congestionTree.Leaves(), f.ExpectedNumOfLeaves)
err = tree.ValidateCongestionTree(congestionTree, poolTx, pubkey, roundLifetime)
require.NoError(t, err)
}
})
}
if len(fixtures.Invalid) > 0 {
t.Run("invalid", func(t *testing.T) {
for _, f := range fixtures.Invalid {
poolTx, congestionTree, err := builder.BuildPoolTx(pubkey, f.Payments, minRelayFee)
require.EqualError(t, err, f.ExpectedErr)
require.Empty(t, poolTx)
require.Empty(t, congestionTree)
}
})
}
}
func TestBuildForfeitTxs(t *testing.T) {
builder := txbuilder.NewTxBuilder(wallet, network.Liquid, 1209344)
fixtures, err := parseForfeitTxsFixtures()
require.NoError(t, err)
require.NotEmpty(t, fixtures)
if len(fixtures.Valid) > 0 {
t.Run("valid", func(t *testing.T) {
for _, f := range fixtures.Valid {
connectors, forfeitTxs, err := builder.BuildForfeitTxs(
pubkey, f.PoolTx, f.Payments,
)
require.NoError(t, err)
require.Len(t, connectors, f.ExpectedNumOfConnectors)
require.Len(t, forfeitTxs, f.ExpectedNumOfForfeitTxs)
expectedInputTxid := f.PoolTxid
// Verify the chain of connectors
for _, connector := range connectors {
tx, err := psetv2.NewPsetFromBase64(connector)
require.NoError(t, err)
require.NotNil(t, tx)
require.Len(t, tx.Inputs, 1)
require.Len(t, tx.Outputs, 2)
inputTxid := chainhash.Hash(tx.Inputs[0].PreviousTxid).String()
require.Equal(t, expectedInputTxid, inputTxid)
require.Equal(t, 1, int(tx.Inputs[0].PreviousTxIndex))
expectedInputTxid = getTxid(tx)
}
// decode and check forfeit txs
for _, forfeitTx := range forfeitTxs {
tx, err := psetv2.NewPsetFromBase64(forfeitTx)
require.NoError(t, err)
require.Len(t, tx.Inputs, 2)
require.Len(t, tx.Outputs, 2)
}
}
})
}
if len(fixtures.Invalid) > 0 {
t.Run("invalid", func(t *testing.T) {
for _, f := range fixtures.Invalid {
connectors, forfeitTxs, err := builder.BuildForfeitTxs(
pubkey, f.PoolTx, f.Payments,
)
require.EqualError(t, err, f.ExpectedErr)
require.Empty(t, connectors)
require.Empty(t, forfeitTxs)
}
})
}
}
func randomInput() []ports.TxInput {
txid := randomHex(32)
input := &mockedInput{}
input.On("GetAsset").Return("5ac9f65c0efcc4775e0baec4ec03abdde22473cd3cf33c0419ca290e0751b225")
input.On("GetValue").Return(uint64(1000))
input.On("GetScript").Return("a914ea9f486e82efb3dd83a69fd96e3f0113757da03c87")
input.On("GetTxid").Return(txid)
input.On("GetIndex").Return(uint32(0))
return []ports.TxInput{input}
}
func randomHex(len int) string {
buf := make([]byte, len)
// nolint
rand.Read(buf)
return hex.EncodeToString(buf)
}
type poolTxFixtures struct {
Valid []struct {
Payments []domain.Payment
ExpectedNumOfNodes int
ExpectedNumOfLeaves int
}
Invalid []struct {
Payments []domain.Payment
ExpectedErr string
}
}
func parsePoolTxFixtures() (*poolTxFixtures, error) {
file, err := os.ReadFile("testdata/fixtures.json")
if err != nil {
return nil, err
}
v := map[string]interface{}{}
if err := json.Unmarshal(file, &v); err != nil {
return nil, err
}
vv := v["buildPoolTx"].(map[string]interface{})
file, _ = json.Marshal(vv)
var fixtures poolTxFixtures
if err := json.Unmarshal(file, &fixtures); err != nil {
return nil, err
}
return &fixtures, nil
}
type forfeitTxsFixtures struct {
Valid []struct {
Payments []domain.Payment
ExpectedNumOfConnectors int
ExpectedNumOfForfeitTxs int
PoolTx string
PoolTxid string
}
Invalid []struct {
Payments []domain.Payment
ExpectedErr string
PoolTx string
}
}
func parseForfeitTxsFixtures() (*forfeitTxsFixtures, error) {
file, err := os.ReadFile("testdata/fixtures.json")
if err != nil {
return nil, err
}
v := map[string]interface{}{}
if err := json.Unmarshal(file, &v); err != nil {
return nil, err
}
vv := v["buildForfeitTxs"].(map[string]interface{})
file, _ = json.Marshal(vv)
var fixtures forfeitTxsFixtures
if err := json.Unmarshal(file, &fixtures); err != nil {
return nil, err
}
return &fixtures, nil
}
func getTxid(tx *psetv2.Pset) string {
utx, _ := tx.UnsignedTx()
return utx.TxHash().String()
}

View File

@@ -0,0 +1,48 @@
package txbuilder
import (
"github.com/vulpemventures/go-elements/psetv2"
"github.com/vulpemventures/go-elements/transaction"
)
func craftConnectorTx(
input psetv2.InputArgs, outputs []psetv2.OutputArgs,
) (*psetv2.Pset, error) {
ptx, _ := psetv2.New(nil, nil, nil)
updater, _ := psetv2.NewUpdater(ptx)
if err := updater.AddInputs(
[]psetv2.InputArgs{input},
); err != nil {
return nil, err
}
// TODO: add prevout.
if err := updater.AddOutputs(outputs); err != nil {
return nil, err
}
return ptx, nil
}
func getConnectorInputs(pset *psetv2.Pset) ([]psetv2.InputArgs, []*transaction.TxOutput) {
txID, _ := getPsetId(pset)
inputs := make([]psetv2.InputArgs, 0, len(pset.Outputs))
witnessUtxos := make([]*transaction.TxOutput, 0, len(pset.Outputs))
for i, output := range pset.Outputs {
utx, _ := pset.UnsignedTx()
if output.Value == connectorAmount && len(output.Script) > 0 {
inputs = append(inputs, psetv2.InputArgs{
Txid: txID,
TxIndex: uint32(i),
})
witnessUtxos = append(witnessUtxos, utx.Outputs[i])
}
}
return inputs, witnessUtxos
}

View File

@@ -0,0 +1,104 @@
package txbuilder
import (
"github.com/ark-network/ark/common/tree"
"github.com/ark-network/ark/internal/core/domain"
"github.com/btcsuite/btcd/txscript"
"github.com/vulpemventures/go-elements/elementsutil"
"github.com/vulpemventures/go-elements/psetv2"
"github.com/vulpemventures/go-elements/taproot"
"github.com/vulpemventures/go-elements/transaction"
)
func craftForfeitTxs(
connectorTx *psetv2.Pset,
vtxo domain.Vtxo,
vtxoTaprootTree *taproot.IndexedElementsTapScriptTree,
vtxoScript, aspScript []byte,
) (forfeitTxs []string, err error) {
connectors, prevouts := getConnectorInputs(connectorTx)
for i, connectorInput := range connectors {
connectorPrevout := prevouts[i]
asset := elementsutil.AssetHashFromBytes(connectorPrevout.Asset)
pset, err := psetv2.New(nil, nil, nil)
if err != nil {
return nil, err
}
updater, err := psetv2.NewUpdater(pset)
if err != nil {
return nil, err
}
vtxoInput := psetv2.InputArgs{
Txid: vtxo.Txid,
TxIndex: vtxo.VOut,
}
vtxoAmount, _ := elementsutil.ValueToBytes(vtxo.Amount)
vtxoPrevout := &transaction.TxOutput{
Asset: connectorPrevout.Asset,
Value: vtxoAmount,
Script: vtxoScript,
}
if err := updater.AddInputs([]psetv2.InputArgs{connectorInput, vtxoInput}); err != nil {
return nil, err
}
if err = updater.AddInWitnessUtxo(0, connectorPrevout); err != nil {
return nil, err
}
if err := updater.AddInSighashType(0, txscript.SigHashAll); err != nil {
return nil, err
}
if err = updater.AddInWitnessUtxo(1, vtxoPrevout); err != nil {
return nil, err
}
if err := updater.AddInSighashType(1, txscript.SigHashDefault); err != nil {
return nil, err
}
unspendableKey := tree.UnspendableKey()
for _, proof := range vtxoTaprootTree.LeafMerkleProofs {
tapScript := psetv2.NewTapLeafScript(proof, unspendableKey)
if err := updater.AddInTapLeafScript(1, tapScript); err != nil {
return nil, err
}
}
connectorAmount, err := elementsutil.ValueFromBytes(connectorPrevout.Value)
if err != nil {
return nil, err
}
err = updater.AddOutputs([]psetv2.OutputArgs{
{
Asset: asset,
Amount: vtxo.Amount + connectorAmount - 30,
Script: aspScript,
},
{
Asset: asset,
Amount: 30,
},
})
if err != nil {
return nil, err
}
tx, err := pset.ToBase64()
if err != nil {
return nil, err
}
forfeitTxs = append(forfeitTxs, tx)
}
return forfeitTxs, nil
}

View File

@@ -0,0 +1,203 @@
package txbuilder_test
import (
"context"
"github.com/ark-network/ark/internal/core/domain"
"github.com/ark-network/ark/internal/core/ports"
"github.com/decred/dcrd/dcrec/secp256k1/v4"
"github.com/stretchr/testify/mock"
)
type mockedWallet struct {
mock.Mock
}
// BroadcastTransaction implements ports.WalletService.
func (m *mockedWallet) BroadcastTransaction(ctx context.Context, txHex string) (string, error) {
args := m.Called(ctx, txHex)
var res string
if a := args.Get(0); a != nil {
res = a.(string)
}
return res, args.Error(1)
}
// Close implements ports.WalletService.
func (m *mockedWallet) Close() {
m.Called()
}
// DeriveAddresses implements ports.WalletService.
func (m *mockedWallet) DeriveAddresses(ctx context.Context, num int) ([]string, error) {
args := m.Called(ctx, num)
var res []string
if a := args.Get(0); a != nil {
res = a.([]string)
}
return res, args.Error(1)
}
// GetPubkey implements ports.WalletService.
func (m *mockedWallet) GetPubkey(ctx context.Context) (*secp256k1.PublicKey, error) {
args := m.Called(ctx)
var res *secp256k1.PublicKey
if a := args.Get(0); a != nil {
res = a.(*secp256k1.PublicKey)
}
return res, args.Error(1)
}
// SignPset implements ports.WalletService.
func (m *mockedWallet) SignPset(ctx context.Context, pset string, extractRawTx bool) (string, error) {
args := m.Called(ctx, pset, extractRawTx)
var res string
if a := args.Get(0); a != nil {
res = a.(string)
}
return res, args.Error(1)
}
// Status implements ports.WalletService.
func (m *mockedWallet) Status(ctx context.Context) (ports.WalletStatus, error) {
args := m.Called(ctx)
var res ports.WalletStatus
if a := args.Get(0); a != nil {
res = a.(ports.WalletStatus)
}
return res, args.Error(1)
}
func (m *mockedWallet) SelectUtxos(ctx context.Context, asset string, amount uint64) ([]ports.TxInput, uint64, error) {
args := m.Called(ctx, asset, amount)
var res0 func() []ports.TxInput
if a := args.Get(0); a != nil {
res0 = a.(func() []ports.TxInput)
}
var res1 uint64
if a := args.Get(1); a != nil {
res1 = a.(uint64)
}
return res0(), res1, args.Error(2)
}
func (m *mockedWallet) EstimateFees(ctx context.Context, pset string) (uint64, error) {
args := m.Called(ctx, pset)
var res uint64
if a := args.Get(0); a != nil {
res = a.(uint64)
}
return res, args.Error(1)
}
func (m *mockedWallet) IsTransactionPublished(ctx context.Context, txid string) (bool, int64, error) {
args := m.Called(ctx, txid)
var res bool
if a := args.Get(0); a != nil {
res = a.(bool)
}
var blocktime int64
if b := args.Get(1); b != nil {
blocktime = b.(int64)
}
return res, blocktime, args.Error(2)
}
func (m *mockedWallet) SignPsetWithKey(ctx context.Context, pset string, inputIndexes []int) (string, error) {
args := m.Called(ctx, pset, inputIndexes)
var res string
if a := args.Get(0); a != nil {
res = a.(string)
}
return res, args.Error(1)
}
func (m *mockedWallet) WatchScripts(
ctx context.Context, scripts []string,
) error {
args := m.Called(ctx, scripts)
return args.Error(0)
}
func (m *mockedWallet) UnwatchScripts(
ctx context.Context, scripts []string,
) error {
args := m.Called(ctx, scripts)
return args.Error(0)
}
func (m *mockedWallet) GetNotificationChannel(ctx context.Context) chan []domain.VtxoKey {
args := m.Called(ctx)
var res chan []domain.VtxoKey
if a := args.Get(0); a != nil {
res = a.(chan []domain.VtxoKey)
}
return res
}
type mockedInput struct {
mock.Mock
}
func (m *mockedInput) GetTxid() string {
args := m.Called()
var res string
if a := args.Get(0); a != nil {
res = a.(string)
}
return res
}
func (m *mockedInput) GetIndex() uint32 {
args := m.Called()
var res uint32
if a := args.Get(0); a != nil {
res = a.(uint32)
}
return res
}
func (m *mockedInput) GetScript() string {
args := m.Called()
var res string
if a := args.Get(0); a != nil {
res = a.(string)
}
return res
}
func (m *mockedInput) GetAsset() string {
args := m.Called()
var res string
if a := args.Get(0); a != nil {
res = a.(string)
}
return res
}
func (m *mockedInput) GetValue() uint64 {
args := m.Called()
var res uint64
if a := args.Get(0); a != nil {
res = a.(uint64)
}
return res
}

View File

@@ -0,0 +1,135 @@
package txbuilder
import (
"context"
"fmt"
"github.com/ark-network/ark/common"
"github.com/ark-network/ark/common/tree"
"github.com/ark-network/ark/internal/core/ports"
"github.com/vulpemventures/go-elements/address"
"github.com/vulpemventures/go-elements/elementsutil"
"github.com/vulpemventures/go-elements/psetv2"
"github.com/vulpemventures/go-elements/taproot"
"github.com/vulpemventures/go-elements/transaction"
)
func sweepTransaction(
wallet ports.WalletService,
sweepInputs []ports.SweepInput,
lbtc string,
) (*psetv2.Pset, error) {
sweepPset, err := psetv2.New(nil, nil, nil)
if err != nil {
return nil, err
}
updater, err := psetv2.NewUpdater(sweepPset)
if err != nil {
return nil, err
}
amount := uint64(0)
for i, input := range sweepInputs {
leaf := input.SweepLeaf
isSweep, _, lifetime, err := tree.DecodeSweepScript(leaf.Script)
if err != nil {
return nil, err
}
if isSweep {
amount += input.Amount
if err := updater.AddInputs([]psetv2.InputArgs{input.InputArgs}); err != nil {
return nil, err
}
if err := updater.AddInTapLeafScript(i, leaf); err != nil {
return nil, err
}
assetHash, err := elementsutil.AssetHashToBytes(lbtc)
if err != nil {
return nil, err
}
value, err := elementsutil.ValueToBytes(input.Amount)
if err != nil {
return nil, err
}
root := leaf.ControlBlock.RootHash(leaf.Script)
taprootKey := taproot.ComputeTaprootOutputKey(leaf.ControlBlock.InternalKey, root)
script, err := taprootOutputScript(taprootKey)
if err != nil {
return nil, err
}
witnessUtxo := transaction.NewTxOutput(assetHash, value, script)
if err := updater.AddInWitnessUtxo(i, witnessUtxo); err != nil {
return nil, err
}
sequence, err := common.BIP68EncodeAsNumber(lifetime)
if err != nil {
return nil, err
}
updater.Pset.Inputs[i].Sequence = sequence
continue
}
return nil, fmt.Errorf("invalid sweep script")
}
ctx := context.Background()
sweepAddress, err := wallet.DeriveAddresses(ctx, 1)
if err != nil {
return nil, err
}
script, err := address.ToOutputScript(sweepAddress[0])
if err != nil {
return nil, err
}
if err := updater.AddOutputs([]psetv2.OutputArgs{
{
Asset: lbtc,
Amount: amount,
Script: script,
},
}); err != nil {
return nil, err
}
b64, err := sweepPset.ToBase64()
if err != nil {
return nil, err
}
fees, err := wallet.EstimateFees(ctx, b64)
if err != nil {
return nil, err
}
if amount < fees {
return nil, fmt.Errorf("insufficient funds (%d) to cover fees (%d) for sweep transaction", amount, fees)
}
updater.Pset.Outputs[0].Value = amount - fees
if err := updater.AddOutputs([]psetv2.OutputArgs{
{
Asset: lbtc,
Amount: fees,
},
}); err != nil {
return nil, err
}
return sweepPset, nil
}

View File

@@ -0,0 +1,229 @@
{
"buildPoolTx": {
"valid": [
{
"payments": [
{
"id": "0",
"inputs": [
{
"txid": "fd68e3c5796cc7db0a8036d486d5f625b6b2f2c014810ac020e1ac23e82c59d6",
"vout": 0,
"pubkey": "020000000000000000000000000000000000000000000000000000000000000002",
"amount": 1100
}
],
"receivers": [
{
"pubkey": "020000000000000000000000000000000000000000000000000000000000000002",
"amount": 1100
}
]
}
],
"expectedNumOfNodes": 1,
"expectedNumOfLeaves": 1
},
{
"payments": [
{
"id": "0",
"inputs": [
{
"txid": "fd68e3c5796cc7db0a8036d486d5f625b6b2f2c014810ac020e1ac23e82c59d6",
"vout": 0,
"pubkey": "020000000000000000000000000000000000000000000000000000000000000002",
"amount": 1100
}
],
"receivers": [
{
"pubkey": "020000000000000000000000000000000000000000000000000000000000000002",
"amount": 600
},
{
"pubkey": "020000000000000000000000000000000000000000000000000000000000000002",
"amount": 500
}
]
}
],
"expectedNumOfNodes": 3,
"expectedNumOfLeaves": 2
},
{
"payments": [
{
"id": "0",
"inputs": [
{
"txid": "fd68e3c5796cc7db0a8036d486d5f625b6b2f2c014810ac020e1ac23e82c59d6",
"vout": 0,
"pubkey": "020000000000000000000000000000000000000000000000000000000000000002",
"amount": 1100
}
],
"receivers": [
{
"pubkey": "020000000000000000000000000000000000000000000000000000000000000002",
"amount": 600
},
{
"pubkey": "020000000000000000000000000000000000000000000000000000000000000002",
"amount": 500
}
]
},
{
"id": "0",
"inputs": [
{
"txid": "fd68e3c5796cc7db0a8036d486d5f625b6b2f2c014810ac020e1ac23e82c59d6",
"vout": 0,
"pubkey": "020000000000000000000000000000000000000000000000000000000000000002",
"amount": 1100
}
],
"receivers": [
{
"pubkey": "020000000000000000000000000000000000000000000000000000000000000002",
"amount": 600
},
{
"pubkey": "020000000000000000000000000000000000000000000000000000000000000002",
"amount": 500
}
]
},
{
"id": "0",
"inputs": [
{
"txid": "fd68e3c5796cc7db0a8036d486d5f625b6b2f2c014810ac020e1ac23e82c59d6",
"vout": 0,
"pubkey": "020000000000000000000000000000000000000000000000000000000000000002",
"amount": 1100
}
],
"receivers": [
{
"pubkey": "020000000000000000000000000000000000000000000000000000000000000002",
"amount": 600
},
{
"pubkey": "020000000000000000000000000000000000000000000000000000000000000002",
"amount": 500
}
]
}
],
"expectedNumOfNodes": 11,
"expectedNumOfLeaves": 6
},
{
"payments": [
{
"id": "a242cdd8-f3d5-46c0-ae98-94135a2bee3f",
"inputs": [
{
"txid": "755c820771284d85ea4bbcc246565b4eddadc44237a7e57a0f9cb78a840d1d41",
"vout": 0,
"pubkey": "02c87e5c1758df5ad42a918ec507b6e8dfcdcebf22f64f58eb4ad5804257d658a5",
"amount": 1000
},
{
"txid": "66a0df86fcdeb84b8877adfe0b2c556dba30305d72ddbd4c49355f6930355357",
"vout": 0,
"pubkey": "02c87e5c1758df5ad42a918ec507b6e8dfcdcebf22f64f58eb4ad5804257d658a5",
"amount": 1000
},
{
"txid": "9913159bc7aa493ca53cbb9cbc88f97ba01137c814009dc7ef520c3fafc67909",
"vout": 1,
"pubkey": "02c87e5c1758df5ad42a918ec507b6e8dfcdcebf22f64f58eb4ad5804257d658a5",
"amount": 500
},
{
"txid": "5e10e77a7cdedc153be5193a4b6055a7802706ded4f2a9efefe86ed2f9a6ae60",
"vout": 0,
"pubkey": "02c87e5c1758df5ad42a918ec507b6e8dfcdcebf22f64f58eb4ad5804257d658a5",
"amount": 1000
},
{
"txid": "5e10e77a7cdedc153be5193a4b6055a7802706ded4f2a9efefe86ed2f9a6ae60",
"vout": 1,
"pubkey": "02c87e5c1758df5ad42a918ec507b6e8dfcdcebf22f64f58eb4ad5804257d658a5",
"amount": 1000
}
],
"receivers": [
{
"pubkey": "02c87e5c1758df5ad42a918ec507b6e8dfcdcebf22f64f58eb4ad5804257d658a5",
"amount": 1000
},
{
"pubkey": "02c87e5c1758df5ad42a918ec507b6e8dfcdcebf22f64f58eb4ad5804257d658a5",
"amount": 1000
},
{
"pubkey": "02c87e5c1758df5ad42a918ec507b6e8dfcdcebf22f64f58eb4ad5804257d658a5",
"amount": 1000
},
{
"pubkey": "02c87e5c1758df5ad42a918ec507b6e8dfcdcebf22f64f58eb4ad5804257d658a5",
"amount": 1000
},
{
"pubkey": "02c87e5c1758df5ad42a918ec507b6e8dfcdcebf22f64f58eb4ad5804257d658a5",
"amount": 500
}
]
}
],
"expectedNumOfNodes": 9,
"expectedNumOfLeaves": 5
}
],
"invalid": []
},
"buildForfeitTxs": {
"valid": [
{
"payments": [
{
"id": "0",
"inputs": [
{
"txid": "fd68e3c5796cc7db0a8036d486d5f625b6b2f2c014810ac020e1ac23e82c59d6",
"vout": 0,
"pubkey": "020000000000000000000000000000000000000000000000000000000000000002",
"amount": 600
},
{
"txid": "fd68e3c5796cc7db0a8036d486d5f625b6b2f2c014810ac020e1ac23e82c59d6",
"vout": 1,
"pubkey": "020000000000000000000000000000000000000000000000000000000000000002",
"amount": 500
}
],
"receivers": [
{
"pubkey": "020000000000000000000000000000000000000000000000000000000000000002",
"amount": 600
},
{
"pubkey": "020000000000000000000000000000000000000000000000000000000000000002",
"amount": 500
}
]
}
],
"poolTx": "cHNldP8BAgQCAAAAAQQBAQEFAQMBBgEDAfsEAgAAAAABDiDk7dXxh4KQzgLO8i1ABtaLCe4aPL12GVhN1E9zM1ePLwEPBAAAAAABEAT/////AAEDCOgDAAAAAAAAAQQWABSNnpy01UJqd99eTg2M1IpdKId11gf8BHBzZXQCICWyUQcOKcoZBDzzPM1zJOLdqwPsxK4LXnfE/A5c9slaB/wEcHNldAgEAAAAAAABAwh4BQAAAAAAAAEEFgAUjZ6ctNVCanffXk4NjNSKXSiHddYH/ARwc2V0AiAlslEHDinKGQQ88zzNcyTi3asD7MSuC153xPwOXPbJWgf8BHBzZXQIBAAAAAAAAQMI9AEAAAAAAAABBAAH/ARwc2V0AiAlslEHDinKGQQ88zzNcyTi3asD7MSuC153xPwOXPbJWgf8BHBzZXQIBAAAAAAA",
"poolTxid": "7981fce656f266472cc742444527cb32a8bed8c90fed6d47adbfc4c8780d4d9a",
"expectedNumOfForfeitTxs": 4,
"expectedNumOfConnectors": 1
}
],
"invalid": []
}
}

View File

@@ -0,0 +1,448 @@
package txbuilder
import (
"encoding/hex"
"fmt"
"github.com/ark-network/ark/common/tree"
"github.com/ark-network/ark/internal/core/domain"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/decred/dcrd/dcrec/secp256k1/v4"
"github.com/vulpemventures/go-elements/psetv2"
"github.com/vulpemventures/go-elements/taproot"
)
type treeFactory func(outpoint psetv2.InputArgs) (tree.CongestionTree, error)
type node struct {
sweepKey *secp256k1.PublicKey
receivers []domain.Receiver
left *node
right *node
asset string
feeSats uint64
roundLifetime int64
_inputTaprootKey *secp256k1.PublicKey
_inputTaprootTree *taproot.IndexedElementsTapScriptTree
}
func (n *node) isLeaf() bool {
return len(n.receivers) == 1
}
func (n *node) getAmount() uint64 {
var amount uint64
for _, r := range n.receivers {
amount += r.Amount
}
if n.isLeaf() {
return amount
}
return amount + n.feeSats*uint64(n.countChildren())
}
func (n *node) countChildren() int {
result := 0
if n.left != nil {
result++
result += n.left.countChildren()
}
if n.right != nil {
result++
result += n.right.countChildren()
}
return result
}
func (n *node) getChildren() []*node {
if n.isLeaf() {
return nil
}
children := make([]*node, 0, 2)
if n.left != nil {
children = append(children, n.left)
}
if n.right != nil {
children = append(children, n.right)
}
return children
}
func (n *node) getOutputs() ([]psetv2.OutputArgs, error) {
if n.isLeaf() {
taprootKey, _, err := n.getVtxoWitnessData()
if err != nil {
return nil, err
}
script, err := taprootOutputScript(taprootKey)
if err != nil {
return nil, err
}
output := &psetv2.OutputArgs{
Asset: n.asset,
Amount: uint64(n.getAmount()),
Script: script,
}
return []psetv2.OutputArgs{*output}, nil
}
outputs := make([]psetv2.OutputArgs, 0, 2)
children := n.getChildren()
for _, child := range children {
childWitnessProgram, _, err := child.getWitnessData()
if err != nil {
return nil, err
}
script, err := taprootOutputScript(childWitnessProgram)
if err != nil {
return nil, err
}
outputs = append(outputs, psetv2.OutputArgs{
Asset: n.asset,
Amount: child.getAmount() + child.feeSats,
Script: script,
})
}
return outputs, nil
}
func (n *node) getWitnessData() (
*secp256k1.PublicKey, *taproot.IndexedElementsTapScriptTree, error,
) {
if n._inputTaprootKey != nil && n._inputTaprootTree != nil {
return n._inputTaprootKey, n._inputTaprootTree, nil
}
sweepClosure, err := tree.SweepScript(n.sweepKey, uint(n.roundLifetime))
if err != nil {
return nil, nil, err
}
if n.isLeaf() {
taprootKey, _, err := n.getVtxoWitnessData()
if err != nil {
return nil, nil, err
}
branchTaprootScript := tree.BranchScript(
taprootKey, nil, n.getAmount(), 0,
)
branchTaprootTree := taproot.AssembleTaprootScriptTree(
branchTaprootScript, *sweepClosure,
)
root := branchTaprootTree.RootNode.TapHash()
inputTapkey := taproot.ComputeTaprootOutputKey(
tree.UnspendableKey(),
root[:],
)
n._inputTaprootKey = inputTapkey
n._inputTaprootTree = branchTaprootTree
return inputTapkey, branchTaprootTree, nil
}
leftKey, _, err := n.left.getWitnessData()
if err != nil {
return nil, nil, err
}
rightKey, _, err := n.right.getWitnessData()
if err != nil {
return nil, nil, err
}
leftAmount := n.left.getAmount() + n.feeSats
rightAmount := n.right.getAmount() + n.feeSats
branchTaprootLeaf := tree.BranchScript(
leftKey, rightKey, leftAmount, rightAmount,
)
branchTaprootTree := taproot.AssembleTaprootScriptTree(
branchTaprootLeaf, *sweepClosure,
)
root := branchTaprootTree.RootNode.TapHash()
taprootKey := taproot.ComputeTaprootOutputKey(
tree.UnspendableKey(),
root[:],
)
n._inputTaprootKey = taprootKey
n._inputTaprootTree = branchTaprootTree
return taprootKey, branchTaprootTree, nil
}
func (n *node) getVtxoWitnessData() (
*secp256k1.PublicKey, *taproot.IndexedElementsTapScriptTree, error,
) {
if !n.isLeaf() {
return nil, nil, fmt.Errorf("cannot call vtxoWitness on a non-leaf node")
}
sweepClosure, err := tree.SweepScript(n.sweepKey, uint(n.roundLifetime))
if err != nil {
return nil, nil, err
}
key, err := hex.DecodeString(n.receivers[0].Pubkey)
if err != nil {
return nil, nil, err
}
pubkey, err := secp256k1.ParsePubKey(key)
if err != nil {
return nil, nil, err
}
vtxoLeaf, err := tree.VtxoScript(pubkey)
if err != nil {
return nil, nil, err
}
// TODO: add forfeit path
leafTaprootTree := taproot.AssembleTaprootScriptTree(
*vtxoLeaf, *sweepClosure,
)
root := leafTaprootTree.RootNode.TapHash()
taprootKey := taproot.ComputeTaprootOutputKey(
tree.UnspendableKey(),
root[:],
)
return taprootKey, leafTaprootTree, nil
}
func (n *node) getTreeNode(
input psetv2.InputArgs, tapTree *taproot.IndexedElementsTapScriptTree,
) (tree.Node, error) {
pset, err := n.getTx(input, tapTree)
if err != nil {
return tree.Node{}, err
}
txid, err := getPsetId(pset)
if err != nil {
return tree.Node{}, err
}
tx, err := pset.ToBase64()
if err != nil {
return tree.Node{}, err
}
parentTxid := chainhash.Hash(pset.Inputs[0].PreviousTxid).String()
return tree.Node{
Txid: txid,
Tx: tx,
ParentTxid: parentTxid,
Leaf: n.isLeaf(),
}, nil
}
func (n *node) getTx(
input psetv2.InputArgs, inputTapTree *taproot.IndexedElementsTapScriptTree,
) (*psetv2.Pset, error) {
pset, err := psetv2.New(nil, nil, nil)
if err != nil {
return nil, err
}
updater, err := psetv2.NewUpdater(pset)
if err != nil {
return nil, err
}
if err := addTaprootInput(
updater, input, tree.UnspendableKey(), inputTapTree,
); err != nil {
return nil, err
}
feeOutput := psetv2.OutputArgs{
Amount: uint64(n.feeSats),
Asset: n.asset,
}
outputs, err := n.getOutputs()
if err != nil {
return nil, err
}
if err := updater.AddOutputs(append(outputs, feeOutput)); err != nil {
return nil, err
}
return pset, nil
}
func (n *node) createFinalCongestionTree() treeFactory {
return func(poolTxInput psetv2.InputArgs) (tree.CongestionTree, error) {
congestionTree := make(tree.CongestionTree, 0)
_, taprootTree, err := n.getWitnessData()
if err != nil {
return nil, err
}
ins := []psetv2.InputArgs{poolTxInput}
inTrees := []*taproot.IndexedElementsTapScriptTree{taprootTree}
nodes := []*node{n}
for len(nodes) > 0 {
nextNodes := make([]*node, 0)
nextInputsArgs := make([]psetv2.InputArgs, 0)
nextTaprootTrees := make([]*taproot.IndexedElementsTapScriptTree, 0)
treeLevel := make([]tree.Node, 0)
for i, node := range nodes {
treeNode, err := node.getTreeNode(ins[i], inTrees[i])
if err != nil {
return nil, err
}
treeLevel = append(treeLevel, treeNode)
children := node.getChildren()
for i, child := range children {
_, taprootTree, err := child.getWitnessData()
if err != nil {
return nil, err
}
nextNodes = append(nextNodes, child)
nextInputsArgs = append(nextInputsArgs, psetv2.InputArgs{
Txid: treeNode.Txid,
TxIndex: uint32(i),
})
nextTaprootTrees = append(nextTaprootTrees, taprootTree)
}
}
congestionTree = append(congestionTree, treeLevel)
nodes = append([]*node{}, nextNodes...)
ins = append([]psetv2.InputArgs{}, nextInputsArgs...)
inTrees = append(
[]*taproot.IndexedElementsTapScriptTree{}, nextTaprootTrees...,
)
}
return congestionTree, nil
}
}
func craftCongestionTree(
asset string, aspPublicKey *secp256k1.PublicKey,
payments []domain.Payment, feeSatsPerNode uint64, roundLifetime int64,
) (
buildCongestionTree treeFactory,
sharedOutputScript []byte, sharedOutputAmount uint64, err error,
) {
receivers := getOffchainReceivers(payments)
root, err := createPartialCongestionTree(
receivers, aspPublicKey, asset, feeSatsPerNode, roundLifetime,
)
if err != nil {
return
}
taprootKey, _, err := root.getWitnessData()
if err != nil {
return
}
sharedOutputScript, err = taprootOutputScript(taprootKey)
if err != nil {
return
}
sharedOutputAmount = root.getAmount() + root.feeSats
buildCongestionTree = root.createFinalCongestionTree()
return
}
func createPartialCongestionTree(
receivers []domain.Receiver,
aspPublicKey *secp256k1.PublicKey,
asset string,
feeSatsPerNode uint64,
roundLifetime int64,
) (root *node, err error) {
if len(receivers) == 0 {
return nil, fmt.Errorf("no receivers provided")
}
nodes := make([]*node, 0, len(receivers))
for _, r := range receivers {
leafNode := &node{
sweepKey: aspPublicKey,
receivers: []domain.Receiver{r},
asset: asset,
feeSats: feeSatsPerNode,
roundLifetime: roundLifetime,
}
nodes = append(nodes, leafNode)
}
for len(nodes) > 1 {
nodes, err = createUpperLevel(nodes)
if err != nil {
return
}
}
return nodes[0], nil
}
func createUpperLevel(nodes []*node) ([]*node, error) {
if len(nodes)%2 != 0 {
last := nodes[len(nodes)-1]
pairs, err := createUpperLevel(nodes[:len(nodes)-1])
if err != nil {
return nil, err
}
return append(pairs, last), nil
}
pairs := make([]*node, 0, len(nodes)/2)
for i := 0; i < len(nodes); i += 2 {
left := nodes[i]
right := nodes[i+1]
branchNode := &node{
sweepKey: left.sweepKey,
receivers: append(left.receivers, right.receivers...),
left: left,
right: right,
asset: left.asset,
feeSats: left.feeSats,
roundLifetime: left.roundLifetime,
}
pairs = append(pairs, branchNode)
}
return pairs, nil
}

View File

@@ -0,0 +1,167 @@
package txbuilder
import (
"encoding/hex"
"fmt"
"github.com/ark-network/ark/internal/core/domain"
"github.com/ark-network/ark/internal/core/ports"
"github.com/btcsuite/btcd/btcec/v2/schnorr"
"github.com/btcsuite/btcd/txscript"
"github.com/decred/dcrd/dcrec/secp256k1/v4"
"github.com/vulpemventures/go-elements/address"
"github.com/vulpemventures/go-elements/elementsutil"
"github.com/vulpemventures/go-elements/network"
"github.com/vulpemventures/go-elements/payment"
"github.com/vulpemventures/go-elements/psetv2"
"github.com/vulpemventures/go-elements/taproot"
"github.com/vulpemventures/go-elements/transaction"
)
func p2wpkhScript(publicKey *secp256k1.PublicKey, net *network.Network) ([]byte, error) {
payment := payment.FromPublicKey(publicKey, net, nil)
addr, err := payment.WitnessPubKeyHash()
if err != nil {
return nil, err
}
return address.ToOutputScript(addr)
}
func getTxid(txStr string) (string, error) {
pset, err := psetv2.NewPsetFromBase64(txStr)
if err != nil {
return "", err
}
return getPsetId(pset)
}
func getPsetId(pset *psetv2.Pset) (string, error) {
utx, err := pset.UnsignedTx()
if err != nil {
return "", err
}
return utx.TxHash().String(), nil
}
func getOnchainReceivers(
payments []domain.Payment,
) []domain.Receiver {
receivers := make([]domain.Receiver, 0)
for _, payment := range payments {
for _, receiver := range payment.Receivers {
if receiver.IsOnchain() {
receivers = append(receivers, receiver)
}
}
}
return receivers
}
func getOffchainReceivers(
payments []domain.Payment,
) []domain.Receiver {
receivers := make([]domain.Receiver, 0)
for _, payment := range payments {
for _, receiver := range payment.Receivers {
if !receiver.IsOnchain() {
receivers = append(receivers, receiver)
}
}
}
return receivers
}
func toWitnessUtxo(in ports.TxInput) (*transaction.TxOutput, error) {
valueBytes, err := elementsutil.ValueToBytes(in.GetValue())
if err != nil {
return nil, fmt.Errorf("failed to convert value to bytes: %s", err)
}
assetBytes, err := elementsutil.AssetHashToBytes(in.GetAsset())
if err != nil {
return nil, fmt.Errorf("failed to convert asset to bytes: %s", err)
}
scriptBytes, err := hex.DecodeString(in.GetScript())
if err != nil {
return nil, fmt.Errorf("failed to decode script: %s", err)
}
return transaction.NewTxOutput(assetBytes, valueBytes, scriptBytes), nil
}
func countSpentVtxos(payments []domain.Payment) uint64 {
var sum uint64
for _, payment := range payments {
sum += uint64(len(payment.Inputs))
}
return sum
}
func addInputs(
updater *psetv2.Updater,
inputs []ports.TxInput,
) error {
for _, in := range inputs {
inputArg := psetv2.InputArgs{
Txid: in.GetTxid(),
TxIndex: in.GetIndex(),
}
witnessUtxo, err := toWitnessUtxo(in)
if err != nil {
return err
}
if err := updater.AddInputs([]psetv2.InputArgs{inputArg}); err != nil {
return err
}
index := int(updater.Pset.Global.InputCount) - 1
if err := updater.AddInWitnessUtxo(index, witnessUtxo); err != nil {
return err
}
if err := updater.AddInSighashType(index, txscript.SigHashAll); err != nil {
return err
}
}
return nil
}
// wrapper of updater methods adding a taproot input to the pset with all the necessary data to spend it via any taproot script
func addTaprootInput(
updater *psetv2.Updater,
input psetv2.InputArgs,
internalTaprootKey *secp256k1.PublicKey,
taprootTree *taproot.IndexedElementsTapScriptTree,
) error {
if err := updater.AddInputs([]psetv2.InputArgs{input}); err != nil {
return err
}
if err := updater.AddInTapInternalKey(0, schnorr.SerializePubKey(internalTaprootKey)); err != nil {
return err
}
for _, proof := range taprootTree.LeafMerkleProofs {
controlBlock := proof.ToControlBlock(internalTaprootKey)
if err := updater.AddInTapLeafScript(0, psetv2.TapLeafScript{
TapElementsLeaf: taproot.NewBaseTapElementsLeaf(proof.Script),
ControlBlock: controlBlock,
}); err != nil {
return err
}
}
return nil
}
func taprootOutputScript(taprootKey *secp256k1.PublicKey) ([]byte, error) {
return txscript.NewScriptBuilder().AddOp(txscript.OP_1).AddData(schnorr.SerializePubKey(taprootKey)).Script()
}

View File

@@ -0,0 +1,285 @@
package txbuilder
import (
"context"
"github.com/ark-network/ark/common/tree"
"github.com/ark-network/ark/internal/core/domain"
"github.com/ark-network/ark/internal/core/ports"
"github.com/decred/dcrd/dcrec/secp256k1/v4"
"github.com/vulpemventures/go-elements/address"
"github.com/vulpemventures/go-elements/network"
"github.com/vulpemventures/go-elements/payment"
"github.com/vulpemventures/go-elements/psetv2"
"github.com/vulpemventures/go-elements/transaction"
)
const (
connectorAmount = 450
sevenDays = 7 * 24 * 60 * 60
)
type txBuilder struct {
wallet ports.WalletService
net network.Network
}
func NewTxBuilder(
wallet ports.WalletService, net network.Network,
) ports.TxBuilder {
return &txBuilder{wallet, net}
}
// BuildSweepTx implements ports.TxBuilder.
func (*txBuilder) BuildSweepTx(wallet ports.WalletService, inputs []ports.SweepInput) (signedSweepTx string, err error) {
panic("unimplemented")
}
// BuildForfeitTxs implements ports.TxBuilder.
func (b *txBuilder) BuildForfeitTxs(
aspPubkey *secp256k1.PublicKey, poolTx string, payments []domain.Payment,
) (connectors []string, forfeitTxs []string, err error) {
poolTxID, err := getTxid(poolTx)
if err != nil {
return nil, nil, err
}
aspScript, err := p2wpkhScript(aspPubkey, b.net)
if err != nil {
return nil, nil, err
}
numberOfConnectors := countSpentVtxos(payments)
connectors, err = createConnectors(
poolTxID,
1,
psetv2.OutputArgs{
Asset: b.net.AssetID,
Amount: connectorAmount,
Script: aspScript,
},
aspScript,
numberOfConnectors,
)
if err != nil {
return nil, nil, err
}
connectorsAsInputs, err := connectorsToInputArgs(connectors)
if err != nil {
return nil, nil, err
}
forfeitTxs = make([]string, 0)
for _, payment := range payments {
for _, vtxo := range payment.Inputs {
for _, connector := range connectorsAsInputs {
forfeitTx, err := createForfeitTx(
connector,
psetv2.InputArgs{
Txid: vtxo.Txid,
TxIndex: vtxo.VOut,
},
vtxo.Amount,
aspScript,
b.net,
)
if err != nil {
return nil, nil, err
}
forfeitTxs = append(forfeitTxs, forfeitTx)
}
}
}
return connectors, forfeitTxs, nil
}
// BuildPoolTx implements ports.TxBuilder.
func (b *txBuilder) BuildPoolTx(
aspPubkey *secp256k1.PublicKey, payments []domain.Payment, minRelayFee uint64,
) (poolTx string, congestionTree tree.CongestionTree, err error) {
aspScriptBytes, err := p2wpkhScript(aspPubkey, b.net)
if err != nil {
return "", nil, err
}
offchainReceivers, onchainReceivers := receiversFromPayments(payments)
sharedOutputAmount := sumReceivers(offchainReceivers)
numberOfConnectors := countSpentVtxos(payments)
connectorOutputAmount := connectorAmount * numberOfConnectors
ctx := context.Background()
outputs := []psetv2.OutputArgs{
{
Asset: b.net.AssetID,
Amount: sharedOutputAmount,
Script: aspScriptBytes,
},
{
Asset: b.net.AssetID,
Amount: connectorOutputAmount,
Script: aspScriptBytes,
},
}
amountToSelect := sharedOutputAmount + connectorOutputAmount
for _, receiver := range onchainReceivers {
amountToSelect += receiver.Amount
receiverScript, err := address.ToOutputScript(receiver.OnchainAddress)
if err != nil {
return "", nil, err
}
outputs = append(outputs, psetv2.OutputArgs{
Asset: b.net.AssetID,
Amount: receiver.Amount,
Script: receiverScript,
})
}
utxos, change, err := b.wallet.SelectUtxos(ctx, b.net.AssetID, amountToSelect)
if err != nil {
return
}
if change > 0 {
outputs = append(outputs, psetv2.OutputArgs{
Asset: b.net.AssetID,
Amount: change,
Script: aspScriptBytes,
})
}
ptx, err := psetv2.New(toInputArgs(utxos), outputs, nil)
if err != nil {
return
}
utx, err := ptx.UnsignedTx()
if err != nil {
return
}
congestionTree, err = buildCongestionTree(
newOutputScriptFactory(aspPubkey, b.net),
b.net,
utx.TxHash().String(),
offchainReceivers,
)
if err != nil {
return
}
poolTx, err = ptx.ToBase64()
if err != nil {
return
}
return poolTx, congestionTree, err
}
func (b *txBuilder) GetVtxoScript(userPubkey, _ *secp256k1.PublicKey) ([]byte, error) {
p2wpkh := payment.FromPublicKey(userPubkey, &b.net, nil)
addr, _ := p2wpkh.WitnessPubKeyHash()
return address.ToOutputScript(addr)
}
func (b *txBuilder) GetLeafSweepClosure(
node tree.Node, userPubKey *secp256k1.PublicKey,
) (*psetv2.TapLeafScript, int64, error) {
panic("unimplemented")
}
func connectorsToInputArgs(connectors []string) ([]psetv2.InputArgs, error) {
inputs := make([]psetv2.InputArgs, 0, len(connectors)+1)
for i, psetb64 := range connectors {
tx, err := psetv2.NewPsetFromBase64(psetb64)
if err != nil {
return nil, err
}
utx, err := tx.UnsignedTx()
if err != nil {
return nil, err
}
txid := utx.TxHash().String()
for j := range tx.Outputs {
inputs = append(inputs, psetv2.InputArgs{
Txid: txid,
TxIndex: uint32(j),
})
if i != len(connectors)-1 {
break
}
}
}
return inputs, nil
}
func getTxid(txStr string) (string, error) {
pset, err := psetv2.NewPsetFromBase64(txStr)
if err != nil {
tx, err := transaction.NewTxFromHex(txStr)
if err != nil {
return "", err
}
return tx.TxHash().String(), nil
}
utx, err := pset.UnsignedTx()
if err != nil {
return "", err
}
return utx.TxHash().String(), nil
}
func countSpentVtxos(payments []domain.Payment) uint64 {
var sum uint64
for _, payment := range payments {
sum += uint64(len(payment.Inputs))
}
return sum
}
func receiversFromPayments(
payments []domain.Payment,
) (offchainReceivers, onchainReceivers []domain.Receiver) {
for _, payment := range payments {
for _, receiver := range payment.Receivers {
if receiver.IsOnchain() {
onchainReceivers = append(onchainReceivers, receiver)
} else {
offchainReceivers = append(offchainReceivers, receiver)
}
}
}
return
}
func sumReceivers(receivers []domain.Receiver) uint64 {
var sum uint64
for _, r := range receivers {
sum += r.Amount
}
return sum
}
func toInputArgs(
ins []ports.TxInput,
) []psetv2.InputArgs {
inputs := make([]psetv2.InputArgs, 0, len(ins))
for _, in := range ins {
inputs = append(inputs, psetv2.InputArgs{
Txid: in.GetTxid(),
TxIndex: in.GetIndex(),
})
}
return inputs
}

View File

@@ -0,0 +1,407 @@
package txbuilder_test
import (
"context"
"crypto/rand"
"encoding/hex"
"testing"
"github.com/ark-network/ark/common"
"github.com/ark-network/ark/internal/core/domain"
"github.com/ark-network/ark/internal/core/ports"
txbuilder "github.com/ark-network/ark/internal/infrastructure/tx-builder/dummy"
"github.com/btcsuite/btcd/chaincfg/chainhash"
secp256k1 "github.com/decred/dcrd/dcrec/secp256k1/v4"
"github.com/stretchr/testify/require"
"github.com/vulpemventures/go-elements/network"
"github.com/vulpemventures/go-elements/psetv2"
)
const (
testingKey = "apub1qgvdtj5ttpuhkldavhq8thtm5auyk0ec4dcmrfdgu0u5hgp9we22v3hrs4x"
fakePoolTx = "cHNldP8BAgQCAAAAAQQBAQEFAQMBBgEDAfsEAgAAAAABDiDk7dXxh4KQzgLO8i1ABtaLCe4aPL12GVhN1E9zM1ePLwEPBAAAAAABEAT/////AAEDCOgDAAAAAAAAAQQWABSNnpy01UJqd99eTg2M1IpdKId11gf8BHBzZXQCICWyUQcOKcoZBDzzPM1zJOLdqwPsxK4LXnfE/A5c9slaB/wEcHNldAgEAAAAAAABAwh4BQAAAAAAAAEEFgAUjZ6ctNVCanffXk4NjNSKXSiHddYH/ARwc2V0AiAlslEHDinKGQQ88zzNcyTi3asD7MSuC153xPwOXPbJWgf8BHBzZXQIBAAAAAAAAQMI9AEAAAAAAAABBAAH/ARwc2V0AiAlslEHDinKGQQ88zzNcyTi3asD7MSuC153xPwOXPbJWgf8BHBzZXQIBAAAAAAA"
)
type input struct {
txid string
vout uint32
}
func (i *input) GetTxid() string {
return i.txid
}
func (i *input) GetIndex() uint32 {
return i.vout
}
func (i *input) GetScript() string {
return "a914ea9f486e82efb3dd83a69fd96e3f0113757da03c87"
}
func (i *input) GetAsset() string {
return "5ac9f65c0efcc4775e0baec4ec03abdde22473cd3cf33c0419ca290e0751b225"
}
func (i *input) GetValue() uint64 {
return 1000
}
type mockedWalletService struct{}
// BroadcastTransaction implements ports.WalletService.
func (*mockedWalletService) BroadcastTransaction(ctx context.Context, txHex string) (string, error) {
panic("unimplemented")
}
// Close implements ports.WalletService.
func (*mockedWalletService) Close() {
panic("unimplemented")
}
// DeriveAddresses implements ports.WalletService.
func (*mockedWalletService) DeriveAddresses(ctx context.Context, num int) ([]string, error) {
panic("unimplemented")
}
// GetPubkey implements ports.WalletService.
func (*mockedWalletService) GetPubkey(ctx context.Context) (*secp256k1.PublicKey, error) {
panic("unimplemented")
}
// SignPset implements ports.WalletService.
func (*mockedWalletService) SignPset(ctx context.Context, pset string, extractRawTx bool) (string, error) {
panic("unimplemented")
}
// Status implements ports.WalletService.
func (*mockedWalletService) Status(ctx context.Context) (ports.WalletStatus, error) {
panic("unimplemented")
}
func (*mockedWalletService) WatchScripts(ctx context.Context, scripts []string) error {
panic("unimplemented")
}
func (*mockedWalletService) UnwatchScripts(ctx context.Context, scripts []string) error {
panic("unimplemented")
}
func (*mockedWalletService) GetNotificationChannel(ctx context.Context) chan []domain.VtxoKey {
panic("unimplemented")
}
func (*mockedWalletService) SelectUtxos(ctx context.Context, asset string, amount uint64) ([]ports.TxInput, uint64, error) {
// random txid
bytes := make([]byte, 32)
if _, err := rand.Read(bytes); err != nil {
return nil, 0, err
}
fakeInput := input{
txid: hex.EncodeToString(bytes),
vout: 0,
}
return []ports.TxInput{&fakeInput}, 0, nil
}
func (*mockedWalletService) EstimateFees(ctx context.Context, pset string) (uint64, error) {
return 100, nil
}
func (*mockedWalletService) SignPsetWithKey(ctx context.Context, pset string, inputIndex []int) (string, error) {
panic("unimplemented")
}
func (*mockedWalletService) IsTransactionPublished(ctx context.Context, txid string) (bool, int64, error) {
panic("unimplemented")
}
func TestBuildCongestionTree(t *testing.T) {
builder := txbuilder.NewTxBuilder(&mockedWalletService{}, network.Liquid)
fixtures := []struct {
payments []domain.Payment
expectedNodesNum int // 2*len(receivers)-1
expectedLeavesNum int
}{
{
payments: []domain.Payment{
{
Id: "0",
Inputs: []domain.Vtxo{
{
VtxoKey: domain.VtxoKey{
Txid: "fd68e3c5796cc7db0a8036d486d5f625b6b2f2c014810ac020e1ac23e82c59d6",
VOut: 0,
},
Receiver: domain.Receiver{
Pubkey: "020000000000000000000000000000000000000000000000000000000000000002",
Amount: 600,
},
},
},
Receivers: []domain.Receiver{
{
Pubkey: "020000000000000000000000000000000000000000000000000000000000000002",
Amount: 600,
},
{
Pubkey: "020000000000000000000000000000000000000000000000000000000000000002",
Amount: 400,
},
},
},
},
expectedNodesNum: 3,
expectedLeavesNum: 2,
},
{
payments: []domain.Payment{
{
Id: "0",
Inputs: []domain.Vtxo{
{
VtxoKey: domain.VtxoKey{
Txid: "fd68e3c5796cc7db0a8036d486d5f625b6b2f2c014810ac020e1ac23e82c59d6",
VOut: 0,
},
Receiver: domain.Receiver{
Pubkey: "020000000000000000000000000000000000000000000000000000000000000002",
Amount: 600,
},
},
},
Receivers: []domain.Receiver{
{
Pubkey: "020000000000000000000000000000000000000000000000000000000000000002",
Amount: 600,
},
{
Pubkey: "020000000000000000000000000000000000000000000000000000000000000002",
Amount: 400,
},
},
},
{
Id: "0",
Inputs: []domain.Vtxo{
{
VtxoKey: domain.VtxoKey{
Txid: "fd68e3c5796cc7db0a8036d486d5f625b6b2f2c014810ac020e1ac23e82c59d6",
VOut: 0,
},
Receiver: domain.Receiver{
Pubkey: "020000000000000000000000000000000000000000000000000000000000000002",
Amount: 600,
},
},
},
Receivers: []domain.Receiver{
{
Pubkey: "020000000000000000000000000000000000000000000000000000000000000002",
Amount: 600,
},
{
Pubkey: "020000000000000000000000000000000000000000000000000000000000000002",
Amount: 400,
},
},
},
{
Id: "0",
Inputs: []domain.Vtxo{
{
VtxoKey: domain.VtxoKey{
Txid: "fd68e3c5796cc7db0a8036d486d5f625b6b2f2c014810ac020e1ac23e82c59d6",
VOut: 0,
},
Receiver: domain.Receiver{
Pubkey: "020000000000000000000000000000000000000000000000000000000000000002",
Amount: 600,
},
},
},
Receivers: []domain.Receiver{
{
Pubkey: "020000000000000000000000000000000000000000000000000000000000000002",
Amount: 600,
},
{
Pubkey: "020000000000000000000000000000000000000000000000000000000000000002",
Amount: 400,
},
},
},
},
expectedNodesNum: 11,
expectedLeavesNum: 6,
},
}
_, key, err := common.DecodePubKey(testingKey)
require.NoError(t, err)
require.NotNil(t, key)
for _, f := range fixtures {
poolTx, tree, err := builder.BuildPoolTx(key, f.payments, 30)
require.NoError(t, err)
require.Equal(t, f.expectedNodesNum, tree.NumberOfNodes())
require.Len(t, tree.Leaves(), f.expectedLeavesNum)
poolPset, err := psetv2.NewPsetFromBase64(poolTx)
require.NoError(t, err)
poolTxUnsigned, err := poolPset.UnsignedTx()
require.NoError(t, err)
poolTxID := poolTxUnsigned.TxHash().String()
// check the root
require.Len(t, tree[0], 1)
require.Equal(t, poolTxID, tree[0][0].ParentTxid)
// check the leaves
for _, leaf := range tree.Leaves() {
pset, err := psetv2.NewPsetFromBase64(leaf.Tx)
require.NoError(t, err)
require.Len(t, pset.Inputs, 1)
require.Len(t, pset.Outputs, 1)
inputTxID := chainhash.Hash(pset.Inputs[0].PreviousTxid).String()
require.Equal(t, leaf.ParentTxid, inputTxID)
}
// check the nodes
for _, level := range tree[:len(tree)-2] {
for _, node := range level {
pset, err := psetv2.NewPsetFromBase64(node.Tx)
require.NoError(t, err)
require.Len(t, pset.Inputs, 1)
require.Len(t, pset.Outputs, 2)
inputTxID := chainhash.Hash(pset.Inputs[0].PreviousTxid).String()
require.Equal(t, node.ParentTxid, inputTxID)
children := tree.Children(node.Txid)
require.Len(t, children, 2)
}
}
}
}
func TestBuildForfeitTxs(t *testing.T) {
builder := txbuilder.NewTxBuilder(&mockedWalletService{}, network.Liquid)
poolPset, err := psetv2.NewPsetFromBase64(fakePoolTx)
require.NoError(t, err)
poolTxUnsigned, err := poolPset.UnsignedTx()
require.NoError(t, err)
poolTxID := poolTxUnsigned.TxHash().String()
fixtures := []struct {
payments []domain.Payment
expectedNumOfForfeitTxs int
expectedNumOfConnectors int
}{
{
payments: []domain.Payment{
{
Id: "0",
Inputs: []domain.Vtxo{
{
VtxoKey: domain.VtxoKey{
Txid: "fd68e3c5796cc7db0a8036d486d5f625b6b2f2c014810ac020e1ac23e82c59d6",
VOut: 0,
},
Receiver: domain.Receiver{
Pubkey: "020000000000000000000000000000000000000000000000000000000000000002",
Amount: 600,
},
},
{
VtxoKey: domain.VtxoKey{
Txid: "fd68e3c5796cc7db0a8036d486d5f625b6b2f2c014810ac020e1ac23e82c59d6",
VOut: 1,
},
Receiver: domain.Receiver{
Pubkey: "020000000000000000000000000000000000000000000000000000000000000002",
Amount: 400,
},
},
},
Receivers: []domain.Receiver{
{
Pubkey: "020000000000000000000000000000000000000000000000000000000000000002",
Amount: 600,
},
{
Pubkey: "020000000000000000000000000000000000000000000000000000000000000002",
Amount: 400,
},
},
},
},
expectedNumOfForfeitTxs: 4,
expectedNumOfConnectors: 1,
},
}
_, key, err := common.DecodePubKey(testingKey)
require.NoError(t, err)
require.NotNil(t, key)
for _, f := range fixtures {
connectors, forfeitTxs, err := builder.BuildForfeitTxs(
key, fakePoolTx, f.payments,
)
require.NoError(t, err)
require.Len(t, connectors, f.expectedNumOfConnectors)
require.Len(t, forfeitTxs, f.expectedNumOfForfeitTxs)
// decode and check connectors
connectorsPsets := make([]*psetv2.Pset, 0, f.expectedNumOfConnectors)
for _, pset := range connectors {
p, err := psetv2.NewPsetFromBase64(pset)
require.NoError(t, err)
connectorsPsets = append(connectorsPsets, p)
}
for i, pset := range connectorsPsets {
require.Len(t, pset.Inputs, 1)
require.Len(t, pset.Outputs, 2)
expectedInputTxid := poolTxID
expectedInputVout := uint32(1)
if i > 0 {
tx, err := connectorsPsets[i-1].UnsignedTx()
require.NoError(t, err)
require.NotNil(t, tx)
expectedInputTxid = tx.TxHash().String()
}
inputTxid := chainhash.Hash(pset.Inputs[0].PreviousTxid).String()
require.Equal(t, expectedInputTxid, inputTxid)
require.Equal(t, expectedInputVout, pset.Inputs[0].PreviousTxIndex)
}
// decode and check forfeit txs
forfeitTxsPsets := make([]*psetv2.Pset, 0, f.expectedNumOfForfeitTxs)
for _, pset := range forfeitTxs {
p, err := psetv2.NewPsetFromBase64(pset)
require.NoError(t, err)
forfeitTxsPsets = append(forfeitTxsPsets, p)
}
// each forfeit tx should have 2 inputs and 2 outputs
for _, pset := range forfeitTxsPsets {
require.Len(t, pset.Inputs, 2)
require.Len(t, pset.Outputs, 1)
}
}
}

View File

@@ -0,0 +1,103 @@
package txbuilder
import (
"github.com/vulpemventures/go-elements/psetv2"
)
func createConnectors(
poolTxID string,
connectorOutputIndex uint32,
connectorOutput psetv2.OutputArgs,
changeScript []byte,
numberOfConnectors uint64,
) (connectorsPsets []string, err error) {
previousInput := psetv2.InputArgs{
Txid: poolTxID,
TxIndex: connectorOutputIndex,
}
if numberOfConnectors == 1 {
pset, err := psetv2.New(nil, nil, nil)
if err != nil {
return nil, err
}
updater, err := psetv2.NewUpdater(pset)
if err != nil {
return nil, err
}
err = updater.AddInputs([]psetv2.InputArgs{previousInput})
if err != nil {
return nil, err
}
err = updater.AddOutputs([]psetv2.OutputArgs{connectorOutput})
if err != nil {
return nil, err
}
base64, err := pset.ToBase64()
if err != nil {
return nil, err
}
return []string{base64}, nil
}
// compute the initial amount of the connectors output in pool transaction
remainingAmount := connectorAmount * numberOfConnectors
connectorsPset := make([]string, 0, numberOfConnectors-1)
for i := uint64(0); i < numberOfConnectors-1; i++ {
// create a new pset
pset, err := psetv2.New(nil, nil, nil)
if err != nil {
return nil, err
}
updater, err := psetv2.NewUpdater(pset)
if err != nil {
return nil, err
}
err = updater.AddInputs([]psetv2.InputArgs{previousInput})
if err != nil {
return nil, err
}
err = updater.AddOutputs([]psetv2.OutputArgs{connectorOutput})
if err != nil {
return nil, err
}
changeAmount := remainingAmount - connectorOutput.Amount
if changeAmount > 0 {
changeOutput := psetv2.OutputArgs{
Asset: connectorOutput.Asset,
Amount: changeAmount,
Script: changeScript,
}
err = updater.AddOutputs([]psetv2.OutputArgs{changeOutput})
if err != nil {
return nil, err
}
tx, _ := pset.UnsignedTx()
txid := tx.TxHash().String()
// make the change the next previousInput
previousInput = psetv2.InputArgs{
Txid: txid,
TxIndex: 1,
}
}
base64, err := pset.ToBase64()
if err != nil {
return nil, err
}
connectorsPset = append(connectorsPset, base64)
}
return connectorsPset, nil
}

View File

@@ -0,0 +1,42 @@
package txbuilder
import (
"github.com/vulpemventures/go-elements/network"
"github.com/vulpemventures/go-elements/psetv2"
)
func createForfeitTx(
connectorInput psetv2.InputArgs,
vtxoInput psetv2.InputArgs,
vtxoAmount uint64,
aspScript []byte,
net network.Network,
) (forfeitTx string, err error) {
pset, err := psetv2.New(nil, nil, nil)
if err != nil {
return "", err
}
updater, err := psetv2.NewUpdater(pset)
if err != nil {
return "", err
}
err = updater.AddInputs([]psetv2.InputArgs{connectorInput, vtxoInput})
if err != nil {
return "", err
}
err = updater.AddOutputs([]psetv2.OutputArgs{
{
Asset: net.AssetID,
Amount: vtxoAmount,
Script: aspScript,
},
})
if err != nil {
return "", err
}
return pset.ToBase64()
}

View File

@@ -0,0 +1,309 @@
package txbuilder
import (
"encoding/hex"
"github.com/ark-network/ark/common/tree"
"github.com/ark-network/ark/internal/core/domain"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/decred/dcrd/dcrec/secp256k1/v4"
"github.com/vulpemventures/go-elements/address"
"github.com/vulpemventures/go-elements/network"
"github.com/vulpemventures/go-elements/payment"
"github.com/vulpemventures/go-elements/psetv2"
)
const (
sharedOutputIndex = 0
)
type outputScriptFactory func(leaves []domain.Receiver) ([]byte, error)
func p2wpkhScript(publicKey *secp256k1.PublicKey, net network.Network) ([]byte, error) {
payment := payment.FromPublicKey(publicKey, &net, nil)
addr, err := payment.WitnessPubKeyHash()
if err != nil {
return nil, err
}
return address.ToOutputScript(addr)
}
// newOtputScriptFactory returns an output script factory func that lock funds using the ASP public key only on all branches psbt. The leaves are instead locked by the leaf public key.
func newOutputScriptFactory(aspPublicKey *secp256k1.PublicKey, net network.Network) outputScriptFactory {
return func(leaves []domain.Receiver) ([]byte, error) {
aspScript, err := p2wpkhScript(aspPublicKey, net)
if err != nil {
return nil, err
}
switch len(leaves) {
case 0:
return nil, nil
case 1: // it's a leaf
buf, err := hex.DecodeString(leaves[0].Pubkey)
if err != nil {
return nil, err
}
key, err := secp256k1.ParsePubKey(buf)
if err != nil {
return nil, err
}
return p2wpkhScript(key, net)
default: // it's a branch, lock funds with ASP public key
return aspScript, nil
}
}
}
// congestionTree builder iteratively creates a binary tree of Pset from a set of receivers
// it also expect createOutputScript func managing the output script creation and the network to use (mainly for L-BTC asset id)
func buildCongestionTree(
createOutputScript outputScriptFactory,
net network.Network,
poolTxID string,
receivers []domain.Receiver,
) (congestionTree tree.CongestionTree, err error) {
var nodes []*node
for _, r := range receivers {
nodes = append(nodes, newLeaf(createOutputScript, net, r))
}
for len(nodes) > 1 {
nodes, err = createTreeLevel(nodes)
if err != nil {
return nil, err
}
}
psets, err := nodes[0].psets(psetv2.InputArgs{
Txid: poolTxID,
TxIndex: sharedOutputIndex,
}, 0)
if err != nil {
return nil, err
}
maxLevel := 0
for _, psetWithLevel := range psets {
if psetWithLevel.level > maxLevel {
maxLevel = psetWithLevel.level
}
}
congestionTree = make(tree.CongestionTree, maxLevel+1)
for _, psetWithLevel := range psets {
utx, err := psetWithLevel.pset.UnsignedTx()
if err != nil {
return nil, err
}
txid := utx.TxHash().String()
psetB64, err := psetWithLevel.pset.ToBase64()
if err != nil {
return nil, err
}
parentTxid := chainhash.Hash(psetWithLevel.pset.Inputs[0].PreviousTxid).String()
congestionTree[psetWithLevel.level] = append(congestionTree[psetWithLevel.level], tree.Node{
Txid: txid,
Tx: psetB64,
ParentTxid: parentTxid,
Leaf: psetWithLevel.leaf,
})
}
return congestionTree, nil
}
func createTreeLevel(nodes []*node) ([]*node, error) {
if len(nodes)%2 != 0 {
last := nodes[len(nodes)-1]
pairs, err := createTreeLevel(nodes[:len(nodes)-1])
if err != nil {
return nil, err
}
return append(pairs, last), nil
}
pairs := make([]*node, 0, len(nodes)/2)
for i := 0; i < len(nodes); i += 2 {
pairs = append(pairs, newBranch(nodes[i], nodes[i+1]))
}
return pairs, nil
}
// internal struct to build a binary tree of Pset
type node struct {
receivers []domain.Receiver
left *node
right *node
createOutputScript outputScriptFactory
network network.Network
}
// create a node from a single receiver
func newLeaf(
createOutputScript outputScriptFactory,
network network.Network,
receiver domain.Receiver,
) *node {
return &node{
receivers: []domain.Receiver{receiver},
createOutputScript: createOutputScript,
network: network,
left: nil,
right: nil,
}
}
// aggregate two nodes into a branch node
func newBranch(
left *node,
right *node,
) *node {
return &node{
receivers: append(left.receivers, right.receivers...),
createOutputScript: left.createOutputScript,
network: left.network,
left: left,
right: right,
}
}
// is it the final node of the tree
func (n *node) isLeaf() bool {
return n.left == nil && n.right == nil
}
// compute the output amount of a node
func (n *node) amount() uint64 {
var amount uint64
for _, r := range n.receivers {
amount += r.Amount
}
return amount
}
// compute the output script of a node
func (n *node) script() ([]byte, error) {
return n.createOutputScript(n.receivers)
}
// use script & amount to create OutputArgs
func (n *node) output() (*psetv2.OutputArgs, error) {
script, err := n.script()
if err != nil {
return nil, err
}
return &psetv2.OutputArgs{
Asset: n.network.AssetID,
Amount: n.amount(),
Script: script,
}, nil
}
// create the node Pset from the previous node Pset represented by input arg
// if node is a branch, it adds two outputs to the Pset, one for the left branch and one for the right branch
// if node is a leaf, it only adds one output to the Pset (the node output)
func (n *node) pset(input psetv2.InputArgs) (*psetv2.Pset, error) {
pset, err := psetv2.New(nil, nil, nil)
if err != nil {
return nil, err
}
updater, err := psetv2.NewUpdater(pset)
if err != nil {
return nil, err
}
err = updater.AddInputs([]psetv2.InputArgs{input})
if err != nil {
return nil, err
}
if n.isLeaf() {
output, err := n.output()
if err != nil {
return nil, err
}
err = updater.AddOutputs([]psetv2.OutputArgs{*output})
if err != nil {
return nil, err
}
return pset, nil
}
outputLeft, err := n.left.output()
if err != nil {
return nil, err
}
outputRight, err := n.right.output()
if err != nil {
return nil, err
}
err = updater.AddOutputs([]psetv2.OutputArgs{*outputLeft, *outputRight})
if err != nil {
return nil, err
}
return pset, nil
}
type psetWithLevel struct {
pset *psetv2.Pset
level int
leaf bool
}
// create the node pset and all the psets of its children recursively, updating the input arg at each step
// the function stops when it reaches a leaf node
func (n *node) psets(input psetv2.InputArgs, level int) ([]psetWithLevel, error) {
pset, err := n.pset(input)
if err != nil {
return nil, err
}
nodeResult := []psetWithLevel{
{pset, level, n.isLeaf()},
}
if n.isLeaf() {
return nodeResult, nil
}
unsignedTx, err := pset.UnsignedTx()
if err != nil {
return nil, err
}
txID := unsignedTx.TxHash().String()
psetsLeft, err := n.left.psets(psetv2.InputArgs{
Txid: txID,
TxIndex: 0,
}, level+1)
if err != nil {
return nil, err
}
psetsRight, err := n.right.psets(psetv2.InputArgs{
Txid: txID,
TxIndex: 1,
}, level+1)
if err != nil {
return nil, err
}
return append(nodeResult, append(psetsLeft, psetsRight...)...), nil
}