mirror of
https://github.com/aljazceru/ark.git
synced 2025-12-18 20:54:20 +01:00
Unit tests (#32)
* unit tests * Fix makefile * Fix race conditions * Renaming
This commit is contained in:
committed by
GitHub
parent
c8d9db89c5
commit
46d54a227d
@@ -27,7 +27,7 @@ help:
|
||||
## intergrationtest: runs integration tests
|
||||
integrationtest:
|
||||
@echo "Running integration tests..."
|
||||
@go test -v -count=1 -race ./... $(go list ./... | grep internal/test)
|
||||
@find . -name go.mod -execdir go test -v -count=1 -race $(go list ./... | grep internal/test) \;
|
||||
|
||||
## lint: lint codebase
|
||||
lint:
|
||||
|
||||
@@ -19,7 +19,10 @@ import (
|
||||
"github.com/vulpemventures/go-elements/psetv2"
|
||||
)
|
||||
|
||||
const paymentsThreshold = 128
|
||||
const (
|
||||
paymentsThreshold = 128
|
||||
dustAmount = 450
|
||||
)
|
||||
|
||||
type Service interface {
|
||||
SpendVtxos(ctx context.Context, inputs []domain.VtxoKey) (string, error)
|
||||
@@ -144,7 +147,7 @@ func (s *service) start() error {
|
||||
}
|
||||
|
||||
func (s *service) startRound() {
|
||||
round := domain.NewRound()
|
||||
round := domain.NewRound(dustAmount)
|
||||
changes, _ := round.StartRegistration()
|
||||
if err := s.repoManager.Events().Save(
|
||||
context.Background(), round.Id, changes...,
|
||||
|
||||
@@ -26,21 +26,29 @@ func NewPayment(inputs []Vtxo) (*Payment, error) {
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func (p *Payment) AddReceivers(recievers []Receiver) (err error) {
|
||||
func (p *Payment) AddReceivers(receivers []Receiver) (err error) {
|
||||
if p.Receivers == nil {
|
||||
p.Receivers = make([]Receiver, 0)
|
||||
}
|
||||
p.Receivers = append(p.Receivers, recievers...)
|
||||
p.Receivers = append(p.Receivers, receivers...)
|
||||
defer func() {
|
||||
if err != nil {
|
||||
p.Receivers = p.Receivers[:len(p.Receivers)-len(recievers)]
|
||||
p.Receivers = p.Receivers[:len(p.Receivers)-len(receivers)]
|
||||
}
|
||||
}()
|
||||
err = p.validate(false)
|
||||
return
|
||||
}
|
||||
|
||||
func (p Payment) TotOutputAmount() uint64 {
|
||||
func (p Payment) TotalInputAmount() uint64 {
|
||||
tot := uint64(0)
|
||||
for _, in := range p.Inputs {
|
||||
tot += in.Amount
|
||||
}
|
||||
return tot
|
||||
}
|
||||
|
||||
func (p Payment) TotalOutputAmount() uint64 {
|
||||
tot := uint64(0)
|
||||
for _, r := range p.Receivers {
|
||||
tot += r.Amount
|
||||
@@ -62,14 +70,8 @@ func (p Payment) validate(ignoreOuts bool) error {
|
||||
return fmt.Errorf("missing outputs")
|
||||
}
|
||||
// Check that input and output and output amounts match.
|
||||
inAmount := uint64(0)
|
||||
for _, in := range p.Inputs {
|
||||
inAmount += in.Amount
|
||||
}
|
||||
outAmount := uint64(0)
|
||||
for _, v := range p.Receivers {
|
||||
outAmount += v.Amount
|
||||
}
|
||||
inAmount := p.TotalInputAmount()
|
||||
outAmount := p.TotalOutputAmount()
|
||||
if inAmount != outAmount {
|
||||
return fmt.Errorf("input and output amounts mismatch")
|
||||
}
|
||||
|
||||
102
asp/internal/core/domain/payment_test.go
Normal file
102
asp/internal/core/domain/payment_test.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package domain_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ark-network/ark/internal/core/domain"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var inputs = []domain.Vtxo{
|
||||
{
|
||||
VtxoKey: domain.VtxoKey{
|
||||
Txid: "0000000000000000000000000000000000000000000000000000000000000000",
|
||||
VOut: 0,
|
||||
},
|
||||
Receiver: domain.Receiver{
|
||||
Pubkey: "030000000000000000000000000000000000000000000000000000000000000001",
|
||||
Amount: 500,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
func TestPayment(t *testing.T) {
|
||||
t.Run("new_payment", func(t *testing.T) {
|
||||
t.Run("vaild", func(t *testing.T) {
|
||||
payment, err := domain.NewPayment(inputs)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, payment)
|
||||
require.NotEmpty(t, payment.Id)
|
||||
require.Exactly(t, inputs, payment.Inputs)
|
||||
require.Empty(t, payment.Receivers)
|
||||
})
|
||||
|
||||
t.Run("invaild", func(t *testing.T) {
|
||||
fixtures := []struct {
|
||||
inputs []domain.Vtxo
|
||||
expectedErr string
|
||||
}{
|
||||
{
|
||||
inputs: nil,
|
||||
expectedErr: "missing inputs",
|
||||
},
|
||||
}
|
||||
|
||||
for _, f := range fixtures {
|
||||
payment, err := domain.NewPayment(f.inputs)
|
||||
require.EqualError(t, err, f.expectedErr)
|
||||
require.Nil(t, payment)
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("add_receivers", func(t *testing.T) {
|
||||
t.Run("valid", func(t *testing.T) {
|
||||
payment, err := domain.NewPayment(inputs)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, payment)
|
||||
|
||||
err = payment.AddReceivers([]domain.Receiver{
|
||||
{
|
||||
Pubkey: "030000000000000000000000000000000000000000000000000000000000000001",
|
||||
Amount: 200,
|
||||
},
|
||||
{
|
||||
Pubkey: "020000000000000000000000000000000000000000000000000000000000000002",
|
||||
Amount: 300,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("invalid", func(t *testing.T) {
|
||||
fixtures := []struct {
|
||||
receivers []domain.Receiver
|
||||
expectedErr string
|
||||
}{
|
||||
{
|
||||
receivers: nil,
|
||||
expectedErr: "missing outputs",
|
||||
},
|
||||
{
|
||||
receivers: []domain.Receiver{
|
||||
{
|
||||
Pubkey: "030000000000000000000000000000000000000000000000000000000000000001",
|
||||
Amount: 100,
|
||||
},
|
||||
},
|
||||
expectedErr: "input and output amounts mismatch",
|
||||
},
|
||||
}
|
||||
|
||||
payment, err := domain.NewPayment(inputs)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, payment)
|
||||
|
||||
for _, f := range fixtures {
|
||||
err := payment.AddReceivers(f.receivers)
|
||||
require.EqualError(t, err, f.expectedErr)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
@@ -11,8 +11,6 @@ const (
|
||||
UndefinedStage RoundStage = iota
|
||||
RegistrationStage
|
||||
FinalizationStage
|
||||
|
||||
dustAmount = 450
|
||||
)
|
||||
|
||||
type RoundStage int
|
||||
@@ -45,15 +43,17 @@ type Round struct {
|
||||
ForfeitTxs []string
|
||||
CongestionTree []string
|
||||
Connectors []string
|
||||
DustAmount uint64
|
||||
Version uint
|
||||
changes []RoundEvent
|
||||
}
|
||||
|
||||
func NewRound() *Round {
|
||||
func NewRound(dustAmount uint64) *Round {
|
||||
return &Round{
|
||||
Id: uuid.New().String(),
|
||||
Payments: make(map[string]Payment),
|
||||
changes: make([]RoundEvent, 0),
|
||||
Id: uuid.New().String(),
|
||||
DustAmount: dustAmount,
|
||||
Payments: make(map[string]Payment),
|
||||
changes: make([]RoundEvent, 0),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -121,10 +121,44 @@ func (r *Round) StartRegistration() ([]RoundEvent, error) {
|
||||
return []RoundEvent{event}, nil
|
||||
}
|
||||
|
||||
func (r *Round) RegisterPayments(payments []Payment) ([]RoundEvent, error) {
|
||||
if r.Stage.Code != RegistrationStage || r.IsFailed() {
|
||||
return nil, fmt.Errorf("not in a valid stage to register payments")
|
||||
}
|
||||
if len(payments) <= 0 {
|
||||
return nil, fmt.Errorf("missing payments to register")
|
||||
}
|
||||
for _, p := range payments {
|
||||
if err := p.validate(false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
event := PaymentsRegistered{
|
||||
Id: r.Id,
|
||||
Payments: payments,
|
||||
}
|
||||
r.raise(event)
|
||||
|
||||
return []RoundEvent{event}, nil
|
||||
}
|
||||
|
||||
func (r *Round) StartFinalization(connectors, tree []string, poolTx string) ([]RoundEvent, error) {
|
||||
if len(connectors) <= 0 {
|
||||
return nil, fmt.Errorf("missing list of connectors")
|
||||
}
|
||||
if len(tree) <= 0 {
|
||||
return nil, fmt.Errorf("missing congestion tree")
|
||||
}
|
||||
if len(poolTx) <= 0 {
|
||||
return nil, fmt.Errorf("missing unsigned pool tx")
|
||||
}
|
||||
if r.Stage.Code != RegistrationStage || r.IsFailed() {
|
||||
return nil, fmt.Errorf("not in a valid stage to start payment finalization")
|
||||
}
|
||||
if len(r.Payments) <= 0 {
|
||||
return nil, fmt.Errorf("no payments registered")
|
||||
}
|
||||
|
||||
event := RoundFinalizationStarted{
|
||||
Id: r.Id,
|
||||
@@ -138,11 +172,17 @@ func (r *Round) StartFinalization(connectors, tree []string, poolTx string) ([]R
|
||||
}
|
||||
|
||||
func (r *Round) EndFinalization(forfeitTxs []string, txid string) ([]RoundEvent, error) {
|
||||
if len(forfeitTxs) <= 0 {
|
||||
return nil, fmt.Errorf("missing list of signed forfeit txs")
|
||||
}
|
||||
if len(txid) <= 0 {
|
||||
return nil, fmt.Errorf("missing pool txid")
|
||||
}
|
||||
if r.Stage.Code != FinalizationStage || r.IsFailed() {
|
||||
return nil, fmt.Errorf("not in a valid stage to end payment finalization")
|
||||
}
|
||||
if r.Stage.Ended {
|
||||
return nil, fmt.Errorf("payment finalization already ended")
|
||||
return nil, fmt.Errorf("round already finalized")
|
||||
}
|
||||
event := RoundFinalized{
|
||||
Id: r.Id,
|
||||
@@ -169,30 +209,9 @@ func (r *Round) Fail(err error) []RoundEvent {
|
||||
return []RoundEvent{event}
|
||||
}
|
||||
|
||||
func (r *Round) RegisterPayments(payments []Payment) ([]RoundEvent, error) {
|
||||
if !r.IsStarted() {
|
||||
return nil, fmt.Errorf("not in a valid stage to register payments")
|
||||
}
|
||||
if len(payments) <= 0 {
|
||||
return nil, fmt.Errorf("missing payments to register")
|
||||
}
|
||||
for _, p := range payments {
|
||||
if err := p.validate(false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
event := PaymentsRegistered{
|
||||
Id: r.Id,
|
||||
Payments: payments,
|
||||
}
|
||||
r.raise(event)
|
||||
|
||||
return []RoundEvent{event}, nil
|
||||
}
|
||||
|
||||
func (r *Round) IsStarted() bool {
|
||||
return !r.IsFailed() && r.Stage.Code == RegistrationStage
|
||||
empty := Stage{}
|
||||
return !r.IsFailed() && !r.IsEnded() && r.Stage != empty
|
||||
}
|
||||
|
||||
func (r *Round) IsEnded() bool {
|
||||
@@ -203,14 +222,18 @@ func (r *Round) IsFailed() bool {
|
||||
return r.Stage.Failed
|
||||
}
|
||||
|
||||
func (r *Round) TotInputAmount() uint64 {
|
||||
return uint64(len(r.Payments) * dustAmount)
|
||||
func (r *Round) TotalInputAmount() uint64 {
|
||||
totInputs := 0
|
||||
for _, p := range r.Payments {
|
||||
totInputs += len(p.Inputs)
|
||||
}
|
||||
return uint64(totInputs * int(r.DustAmount))
|
||||
}
|
||||
|
||||
func (r *Round) TotOutputAmount() uint64 {
|
||||
func (r *Round) TotalOutputAmount() uint64 {
|
||||
tot := uint64(0)
|
||||
for _, p := range r.Payments {
|
||||
tot += p.TotOutputAmount()
|
||||
tot += p.TotalOutputAmount()
|
||||
}
|
||||
return tot
|
||||
}
|
||||
|
||||
486
asp/internal/core/domain/round_test.go
Normal file
486
asp/internal/core/domain/round_test.go
Normal file
@@ -0,0 +1,486 @@
|
||||
package domain_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/ark-network/ark/internal/core/domain"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var (
|
||||
dustAmount = uint64(450)
|
||||
payments = []domain.Payment{
|
||||
{
|
||||
Id: "0",
|
||||
Inputs: []domain.Vtxo{{}},
|
||||
Receivers: []domain.Receiver{{}, {}, {}},
|
||||
},
|
||||
{
|
||||
Id: "1",
|
||||
Inputs: []domain.Vtxo{{}, {}},
|
||||
Receivers: []domain.Receiver{{}},
|
||||
},
|
||||
}
|
||||
emptyPtx = "cHNldP8BAgQCAAAAAQQBAAEFAQABBgEDAfsEAgAAAAA="
|
||||
emptyTx = "0200000000000000000000"
|
||||
txid = "0000000000000000000000000000000000000000000000000000000000000000"
|
||||
congestionTree = []string{emptyPtx, emptyPtx, emptyPtx, emptyPtx, emptyPtx, emptyPtx, emptyPtx}
|
||||
connectors = []string{emptyPtx, emptyPtx, emptyPtx}
|
||||
forfeitTxs = []string{emptyPtx, emptyPtx, emptyPtx, emptyPtx, emptyPtx, emptyPtx, emptyPtx, emptyPtx, emptyPtx}
|
||||
poolTx = emptyTx
|
||||
)
|
||||
|
||||
func TestRound(t *testing.T) {
|
||||
testStartRegistration(t)
|
||||
|
||||
testRegisterPayments(t)
|
||||
|
||||
testStartFinalization(t)
|
||||
|
||||
testEndFinalization(t)
|
||||
|
||||
testFail(t)
|
||||
}
|
||||
|
||||
func testStartRegistration(t *testing.T) {
|
||||
t.Run("start_registration", func(t *testing.T) {
|
||||
t.Run("valid", func(t *testing.T) {
|
||||
round := domain.NewRound(dustAmount)
|
||||
require.NotNil(t, round)
|
||||
require.NotEmpty(t, round.Id)
|
||||
require.Empty(t, round.Events())
|
||||
require.False(t, round.IsStarted())
|
||||
require.False(t, round.IsEnded())
|
||||
require.False(t, round.IsFailed())
|
||||
|
||||
events, err := round.StartRegistration()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, events, 1)
|
||||
require.True(t, round.IsStarted())
|
||||
require.False(t, round.IsEnded())
|
||||
require.False(t, round.IsFailed())
|
||||
|
||||
event, ok := events[0].(domain.RoundStarted)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, round.Id, event.Id)
|
||||
require.Equal(t, round.StartingTimestamp, event.Timestamp)
|
||||
})
|
||||
|
||||
t.Run("invalid", func(t *testing.T) {
|
||||
fixtures := []struct {
|
||||
round *domain.Round
|
||||
expectedErr string
|
||||
}{
|
||||
{
|
||||
round: &domain.Round{
|
||||
Id: "id",
|
||||
Stage: domain.Stage{
|
||||
Code: domain.UndefinedStage,
|
||||
Failed: true,
|
||||
},
|
||||
},
|
||||
expectedErr: "not in a valid stage to start payment registration",
|
||||
},
|
||||
{
|
||||
round: &domain.Round{
|
||||
Id: "id",
|
||||
Stage: domain.Stage{
|
||||
Code: domain.RegistrationStage,
|
||||
},
|
||||
},
|
||||
expectedErr: "not in a valid stage to start payment registration",
|
||||
},
|
||||
{
|
||||
round: &domain.Round{
|
||||
Id: "id",
|
||||
Stage: domain.Stage{
|
||||
Code: domain.FinalizationStage,
|
||||
},
|
||||
},
|
||||
expectedErr: "not in a valid stage to start payment registration",
|
||||
},
|
||||
}
|
||||
|
||||
for _, f := range fixtures {
|
||||
events, err := f.round.StartRegistration()
|
||||
require.EqualError(t, err, f.expectedErr)
|
||||
require.Empty(t, events)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func testRegisterPayments(t *testing.T) {
|
||||
t.Run("register_payments", func(t *testing.T) {
|
||||
t.Run("valid", func(t *testing.T) {
|
||||
round := domain.NewRound(dustAmount)
|
||||
events, err := round.StartRegistration()
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, events)
|
||||
|
||||
events, err = round.RegisterPayments(payments)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, events, 1)
|
||||
require.Condition(t, func() bool {
|
||||
for _, payment := range payments {
|
||||
_, ok := round.Payments[payment.Id]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
event, ok := events[0].(domain.PaymentsRegistered)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, round.Id, event.Id)
|
||||
require.Equal(t, payments, event.Payments)
|
||||
})
|
||||
|
||||
t.Run("invalid", func(t *testing.T) {
|
||||
fixtures := []struct {
|
||||
round *domain.Round
|
||||
payments []domain.Payment
|
||||
expectedErr string
|
||||
}{
|
||||
{
|
||||
round: &domain.Round{
|
||||
Id: "id",
|
||||
Stage: domain.Stage{},
|
||||
},
|
||||
payments: payments,
|
||||
expectedErr: "not in a valid stage to register payments",
|
||||
},
|
||||
{
|
||||
round: &domain.Round{
|
||||
Id: "id",
|
||||
Stage: domain.Stage{
|
||||
Code: domain.RegistrationStage,
|
||||
Failed: true,
|
||||
},
|
||||
},
|
||||
payments: payments,
|
||||
expectedErr: "not in a valid stage to register payments",
|
||||
},
|
||||
{
|
||||
round: &domain.Round{
|
||||
Id: "id",
|
||||
Stage: domain.Stage{
|
||||
Code: domain.FinalizationStage,
|
||||
},
|
||||
},
|
||||
payments: payments,
|
||||
expectedErr: "not in a valid stage to register payments",
|
||||
},
|
||||
{
|
||||
round: &domain.Round{
|
||||
Id: "id",
|
||||
Stage: domain.Stage{
|
||||
Code: domain.RegistrationStage,
|
||||
},
|
||||
},
|
||||
payments: nil,
|
||||
expectedErr: "missing payments to register",
|
||||
},
|
||||
}
|
||||
|
||||
for _, f := range fixtures {
|
||||
events, err := f.round.RegisterPayments(f.payments)
|
||||
require.EqualError(t, err, f.expectedErr)
|
||||
require.Empty(t, events)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func testStartFinalization(t *testing.T) {
|
||||
t.Run("start_finalization", func(t *testing.T) {
|
||||
t.Run("valid", func(t *testing.T) {
|
||||
round := domain.NewRound(dustAmount)
|
||||
events, err := round.StartRegistration()
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, events)
|
||||
|
||||
events, err = round.RegisterPayments(payments)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, events)
|
||||
|
||||
events, err = round.StartFinalization(connectors, congestionTree, poolTx)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, events, 1)
|
||||
require.True(t, round.IsStarted())
|
||||
require.False(t, round.IsEnded())
|
||||
require.False(t, round.IsFailed())
|
||||
|
||||
event, ok := events[0].(domain.RoundFinalizationStarted)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, round.Id, event.Id)
|
||||
require.Exactly(t, connectors, event.Connectors)
|
||||
require.Exactly(t, congestionTree, event.CongestionTree)
|
||||
require.Exactly(t, poolTx, event.PoolTx)
|
||||
})
|
||||
|
||||
t.Run("invalid", func(t *testing.T) {
|
||||
paymentsById := map[string]domain.Payment{}
|
||||
for _, p := range payments {
|
||||
paymentsById[p.Id] = p
|
||||
}
|
||||
fixtures := []struct {
|
||||
round *domain.Round
|
||||
connectors []string
|
||||
tree []string
|
||||
poolTx string
|
||||
expectedErr string
|
||||
}{
|
||||
{
|
||||
round: &domain.Round{
|
||||
Id: "0",
|
||||
Stage: domain.Stage{
|
||||
Code: domain.RegistrationStage,
|
||||
},
|
||||
Payments: paymentsById,
|
||||
},
|
||||
connectors: nil,
|
||||
tree: congestionTree,
|
||||
poolTx: poolTx,
|
||||
expectedErr: "missing list of connectors",
|
||||
},
|
||||
{
|
||||
round: &domain.Round{
|
||||
Id: "0",
|
||||
Stage: domain.Stage{
|
||||
Code: domain.RegistrationStage,
|
||||
},
|
||||
Payments: paymentsById,
|
||||
},
|
||||
connectors: connectors,
|
||||
tree: nil,
|
||||
poolTx: poolTx,
|
||||
expectedErr: "missing congestion tree",
|
||||
},
|
||||
{
|
||||
round: &domain.Round{
|
||||
Id: "0",
|
||||
Stage: domain.Stage{
|
||||
Code: domain.RegistrationStage,
|
||||
},
|
||||
Payments: paymentsById,
|
||||
},
|
||||
connectors: connectors,
|
||||
tree: congestionTree,
|
||||
poolTx: "",
|
||||
expectedErr: "missing unsigned pool tx",
|
||||
},
|
||||
{
|
||||
round: &domain.Round{
|
||||
Id: "0",
|
||||
Stage: domain.Stage{
|
||||
Code: domain.RegistrationStage,
|
||||
},
|
||||
Payments: nil,
|
||||
},
|
||||
connectors: connectors,
|
||||
tree: congestionTree,
|
||||
poolTx: poolTx,
|
||||
expectedErr: "no payments registered",
|
||||
},
|
||||
{
|
||||
round: &domain.Round{
|
||||
Id: "0",
|
||||
Stage: domain.Stage{
|
||||
Code: domain.UndefinedStage,
|
||||
},
|
||||
Payments: paymentsById,
|
||||
},
|
||||
connectors: connectors,
|
||||
tree: congestionTree,
|
||||
poolTx: poolTx,
|
||||
expectedErr: "not in a valid stage to start payment finalization",
|
||||
},
|
||||
{
|
||||
round: &domain.Round{
|
||||
Id: "0",
|
||||
Stage: domain.Stage{
|
||||
Code: domain.RegistrationStage,
|
||||
Failed: true,
|
||||
},
|
||||
Payments: paymentsById,
|
||||
},
|
||||
connectors: connectors,
|
||||
tree: congestionTree,
|
||||
poolTx: poolTx,
|
||||
expectedErr: "not in a valid stage to start payment finalization",
|
||||
},
|
||||
{
|
||||
round: &domain.Round{
|
||||
Id: "0",
|
||||
Stage: domain.Stage{
|
||||
Code: domain.FinalizationStage,
|
||||
},
|
||||
Payments: paymentsById,
|
||||
},
|
||||
connectors: connectors,
|
||||
tree: congestionTree,
|
||||
poolTx: poolTx,
|
||||
expectedErr: "not in a valid stage to start payment finalization",
|
||||
},
|
||||
}
|
||||
|
||||
for _, f := range fixtures {
|
||||
events, err := f.round.StartFinalization(f.connectors, f.tree, f.poolTx)
|
||||
require.EqualError(t, err, f.expectedErr)
|
||||
require.Empty(t, events)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func testEndFinalization(t *testing.T) {
|
||||
t.Run("end_registration", func(t *testing.T) {
|
||||
t.Run("valid", func(t *testing.T) {
|
||||
round := domain.NewRound(dustAmount)
|
||||
events, err := round.StartRegistration()
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, events)
|
||||
|
||||
events, err = round.RegisterPayments(payments)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, events)
|
||||
|
||||
events, err = round.StartFinalization(connectors, congestionTree, poolTx)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, events)
|
||||
|
||||
events, err = round.EndFinalization(forfeitTxs, txid)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, events, 1)
|
||||
require.False(t, round.IsStarted())
|
||||
require.True(t, round.IsEnded())
|
||||
require.False(t, round.IsFailed())
|
||||
|
||||
event, ok := events[0].(domain.RoundFinalized)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, round.Id, event.Id)
|
||||
require.Exactly(t, txid, event.Txid)
|
||||
require.Exactly(t, forfeitTxs, event.ForfeitTxs)
|
||||
require.Exactly(t, round.EndingTimestamp, event.Timestamp)
|
||||
})
|
||||
|
||||
t.Run("invalid", func(t *testing.T) {
|
||||
paymentsById := map[string]domain.Payment{}
|
||||
for _, p := range payments {
|
||||
paymentsById[p.Id] = p
|
||||
}
|
||||
fixtures := []struct {
|
||||
round *domain.Round
|
||||
forfeitTxs []string
|
||||
txid string
|
||||
expectedErr string
|
||||
}{
|
||||
{
|
||||
round: &domain.Round{
|
||||
Id: "0",
|
||||
Stage: domain.Stage{
|
||||
Code: domain.FinalizationStage,
|
||||
},
|
||||
},
|
||||
forfeitTxs: nil,
|
||||
txid: txid,
|
||||
expectedErr: "missing list of signed forfeit txs",
|
||||
},
|
||||
{
|
||||
round: &domain.Round{
|
||||
Id: "0",
|
||||
Stage: domain.Stage{
|
||||
Code: domain.FinalizationStage,
|
||||
},
|
||||
},
|
||||
forfeitTxs: forfeitTxs,
|
||||
txid: "",
|
||||
expectedErr: "missing pool txid",
|
||||
},
|
||||
{
|
||||
round: &domain.Round{
|
||||
Id: "0",
|
||||
},
|
||||
forfeitTxs: forfeitTxs,
|
||||
txid: txid,
|
||||
expectedErr: "not in a valid stage to end payment finalization",
|
||||
},
|
||||
{
|
||||
round: &domain.Round{
|
||||
Id: "0",
|
||||
Stage: domain.Stage{
|
||||
Code: domain.RegistrationStage,
|
||||
},
|
||||
},
|
||||
forfeitTxs: forfeitTxs,
|
||||
txid: txid,
|
||||
expectedErr: "not in a valid stage to end payment finalization",
|
||||
},
|
||||
{
|
||||
round: &domain.Round{
|
||||
Id: "0",
|
||||
Stage: domain.Stage{
|
||||
Code: domain.FinalizationStage,
|
||||
Failed: true,
|
||||
},
|
||||
},
|
||||
forfeitTxs: []string{emptyPtx, emptyPtx, emptyPtx, emptyPtx},
|
||||
txid: txid,
|
||||
expectedErr: "not in a valid stage to end payment finalization",
|
||||
},
|
||||
{
|
||||
round: &domain.Round{
|
||||
Id: "0",
|
||||
Stage: domain.Stage{
|
||||
Code: domain.FinalizationStage,
|
||||
Ended: true,
|
||||
},
|
||||
},
|
||||
forfeitTxs: []string{emptyPtx, emptyPtx, emptyPtx, emptyPtx},
|
||||
txid: txid,
|
||||
expectedErr: "round already finalized",
|
||||
},
|
||||
}
|
||||
|
||||
for _, f := range fixtures {
|
||||
events, err := f.round.EndFinalization(f.forfeitTxs, f.txid)
|
||||
require.EqualError(t, err, f.expectedErr)
|
||||
require.Empty(t, events)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func testFail(t *testing.T) {
|
||||
t.Run("fail", func(t *testing.T) {
|
||||
t.Run("valid", func(t *testing.T) {
|
||||
round := domain.NewRound(dustAmount)
|
||||
events, err := round.StartRegistration()
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, events)
|
||||
|
||||
events, err = round.RegisterPayments(payments)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, events)
|
||||
|
||||
reason := "some valid reason"
|
||||
events = round.Fail(fmt.Errorf(reason))
|
||||
require.Len(t, events, 1)
|
||||
require.False(t, round.IsStarted())
|
||||
require.False(t, round.IsEnded())
|
||||
require.True(t, round.IsFailed())
|
||||
|
||||
event, ok := events[0].(domain.RoundFailed)
|
||||
require.True(t, ok)
|
||||
require.Exactly(t, round.Id, event.Id)
|
||||
require.Exactly(t, round.EndingTimestamp, event.Timestamp)
|
||||
require.EqualError(t, event.Err, reason)
|
||||
|
||||
events = round.Fail(fmt.Errorf(reason))
|
||||
require.Empty(t, events)
|
||||
})
|
||||
})
|
||||
}
|
||||
@@ -20,7 +20,7 @@ type eventsDTO struct {
|
||||
|
||||
type eventRepository struct {
|
||||
store *badgerhold.Store
|
||||
lock *sync.Mutex
|
||||
lock *sync.RWMutex
|
||||
chUpdates chan *domain.Round
|
||||
handler func(round *domain.Round)
|
||||
}
|
||||
@@ -51,7 +51,7 @@ func NewRoundEventRepository(config ...interface{}) (dbtypes.EventStore, error)
|
||||
return nil, fmt.Errorf("failed to open round events store: %s", err)
|
||||
}
|
||||
chEvents := make(chan *domain.Round)
|
||||
lock := &sync.Mutex{}
|
||||
lock := &sync.RWMutex{}
|
||||
repo := &eventRepository{store, lock, chEvents, nil}
|
||||
go repo.listen()
|
||||
return repo, nil
|
||||
@@ -85,6 +85,9 @@ func (r *eventRepository) Load(
|
||||
func (r *eventRepository) RegisterEventsHandler(
|
||||
handler func(round *domain.Round),
|
||||
) {
|
||||
r.lock.Lock()
|
||||
defer r.lock.Unlock()
|
||||
|
||||
r.handler = handler
|
||||
}
|
||||
|
||||
@@ -135,9 +138,11 @@ func (r *eventRepository) upsert(
|
||||
|
||||
func (r *eventRepository) listen() {
|
||||
for updatedRound := range r.chUpdates {
|
||||
r.lock.RLock()
|
||||
if r.handler != nil {
|
||||
r.handler(updatedRound)
|
||||
}
|
||||
r.lock.RUnlock()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user