Unit tests (#32)

* unit tests

* Fix makefile

* Fix race conditions

* Renaming
This commit is contained in:
Pietralberto Mazza
2023-12-01 17:50:42 +01:00
committed by GitHub
parent c8d9db89c5
commit 46d54a227d
7 changed files with 672 additions and 51 deletions

View File

@@ -27,7 +27,7 @@ help:
## intergrationtest: runs integration tests
integrationtest:
@echo "Running integration tests..."
@go test -v -count=1 -race ./... $(go list ./... | grep internal/test)
@find . -name go.mod -execdir go test -v -count=1 -race $(go list ./... | grep internal/test) \;
## lint: lint codebase
lint:

View File

@@ -19,7 +19,10 @@ import (
"github.com/vulpemventures/go-elements/psetv2"
)
const paymentsThreshold = 128
const (
paymentsThreshold = 128
dustAmount = 450
)
type Service interface {
SpendVtxos(ctx context.Context, inputs []domain.VtxoKey) (string, error)
@@ -144,7 +147,7 @@ func (s *service) start() error {
}
func (s *service) startRound() {
round := domain.NewRound()
round := domain.NewRound(dustAmount)
changes, _ := round.StartRegistration()
if err := s.repoManager.Events().Save(
context.Background(), round.Id, changes...,

View File

@@ -26,21 +26,29 @@ func NewPayment(inputs []Vtxo) (*Payment, error) {
return p, nil
}
func (p *Payment) AddReceivers(recievers []Receiver) (err error) {
func (p *Payment) AddReceivers(receivers []Receiver) (err error) {
if p.Receivers == nil {
p.Receivers = make([]Receiver, 0)
}
p.Receivers = append(p.Receivers, recievers...)
p.Receivers = append(p.Receivers, receivers...)
defer func() {
if err != nil {
p.Receivers = p.Receivers[:len(p.Receivers)-len(recievers)]
p.Receivers = p.Receivers[:len(p.Receivers)-len(receivers)]
}
}()
err = p.validate(false)
return
}
func (p Payment) TotOutputAmount() uint64 {
func (p Payment) TotalInputAmount() uint64 {
tot := uint64(0)
for _, in := range p.Inputs {
tot += in.Amount
}
return tot
}
func (p Payment) TotalOutputAmount() uint64 {
tot := uint64(0)
for _, r := range p.Receivers {
tot += r.Amount
@@ -62,14 +70,8 @@ func (p Payment) validate(ignoreOuts bool) error {
return fmt.Errorf("missing outputs")
}
// Check that input and output and output amounts match.
inAmount := uint64(0)
for _, in := range p.Inputs {
inAmount += in.Amount
}
outAmount := uint64(0)
for _, v := range p.Receivers {
outAmount += v.Amount
}
inAmount := p.TotalInputAmount()
outAmount := p.TotalOutputAmount()
if inAmount != outAmount {
return fmt.Errorf("input and output amounts mismatch")
}

View File

@@ -0,0 +1,102 @@
package domain_test
import (
"testing"
"github.com/ark-network/ark/internal/core/domain"
"github.com/stretchr/testify/require"
)
var inputs = []domain.Vtxo{
{
VtxoKey: domain.VtxoKey{
Txid: "0000000000000000000000000000000000000000000000000000000000000000",
VOut: 0,
},
Receiver: domain.Receiver{
Pubkey: "030000000000000000000000000000000000000000000000000000000000000001",
Amount: 500,
},
},
}
func TestPayment(t *testing.T) {
t.Run("new_payment", func(t *testing.T) {
t.Run("vaild", func(t *testing.T) {
payment, err := domain.NewPayment(inputs)
require.NoError(t, err)
require.NotNil(t, payment)
require.NotEmpty(t, payment.Id)
require.Exactly(t, inputs, payment.Inputs)
require.Empty(t, payment.Receivers)
})
t.Run("invaild", func(t *testing.T) {
fixtures := []struct {
inputs []domain.Vtxo
expectedErr string
}{
{
inputs: nil,
expectedErr: "missing inputs",
},
}
for _, f := range fixtures {
payment, err := domain.NewPayment(f.inputs)
require.EqualError(t, err, f.expectedErr)
require.Nil(t, payment)
}
})
})
t.Run("add_receivers", func(t *testing.T) {
t.Run("valid", func(t *testing.T) {
payment, err := domain.NewPayment(inputs)
require.NoError(t, err)
require.NotNil(t, payment)
err = payment.AddReceivers([]domain.Receiver{
{
Pubkey: "030000000000000000000000000000000000000000000000000000000000000001",
Amount: 200,
},
{
Pubkey: "020000000000000000000000000000000000000000000000000000000000000002",
Amount: 300,
},
})
require.NoError(t, err)
})
t.Run("invalid", func(t *testing.T) {
fixtures := []struct {
receivers []domain.Receiver
expectedErr string
}{
{
receivers: nil,
expectedErr: "missing outputs",
},
{
receivers: []domain.Receiver{
{
Pubkey: "030000000000000000000000000000000000000000000000000000000000000001",
Amount: 100,
},
},
expectedErr: "input and output amounts mismatch",
},
}
payment, err := domain.NewPayment(inputs)
require.NoError(t, err)
require.NotNil(t, payment)
for _, f := range fixtures {
err := payment.AddReceivers(f.receivers)
require.EqualError(t, err, f.expectedErr)
}
})
})
}

View File

@@ -11,8 +11,6 @@ const (
UndefinedStage RoundStage = iota
RegistrationStage
FinalizationStage
dustAmount = 450
)
type RoundStage int
@@ -45,15 +43,17 @@ type Round struct {
ForfeitTxs []string
CongestionTree []string
Connectors []string
DustAmount uint64
Version uint
changes []RoundEvent
}
func NewRound() *Round {
func NewRound(dustAmount uint64) *Round {
return &Round{
Id: uuid.New().String(),
Payments: make(map[string]Payment),
changes: make([]RoundEvent, 0),
Id: uuid.New().String(),
DustAmount: dustAmount,
Payments: make(map[string]Payment),
changes: make([]RoundEvent, 0),
}
}
@@ -121,10 +121,44 @@ func (r *Round) StartRegistration() ([]RoundEvent, error) {
return []RoundEvent{event}, nil
}
func (r *Round) RegisterPayments(payments []Payment) ([]RoundEvent, error) {
if r.Stage.Code != RegistrationStage || r.IsFailed() {
return nil, fmt.Errorf("not in a valid stage to register payments")
}
if len(payments) <= 0 {
return nil, fmt.Errorf("missing payments to register")
}
for _, p := range payments {
if err := p.validate(false); err != nil {
return nil, err
}
}
event := PaymentsRegistered{
Id: r.Id,
Payments: payments,
}
r.raise(event)
return []RoundEvent{event}, nil
}
func (r *Round) StartFinalization(connectors, tree []string, poolTx string) ([]RoundEvent, error) {
if len(connectors) <= 0 {
return nil, fmt.Errorf("missing list of connectors")
}
if len(tree) <= 0 {
return nil, fmt.Errorf("missing congestion tree")
}
if len(poolTx) <= 0 {
return nil, fmt.Errorf("missing unsigned pool tx")
}
if r.Stage.Code != RegistrationStage || r.IsFailed() {
return nil, fmt.Errorf("not in a valid stage to start payment finalization")
}
if len(r.Payments) <= 0 {
return nil, fmt.Errorf("no payments registered")
}
event := RoundFinalizationStarted{
Id: r.Id,
@@ -138,11 +172,17 @@ func (r *Round) StartFinalization(connectors, tree []string, poolTx string) ([]R
}
func (r *Round) EndFinalization(forfeitTxs []string, txid string) ([]RoundEvent, error) {
if len(forfeitTxs) <= 0 {
return nil, fmt.Errorf("missing list of signed forfeit txs")
}
if len(txid) <= 0 {
return nil, fmt.Errorf("missing pool txid")
}
if r.Stage.Code != FinalizationStage || r.IsFailed() {
return nil, fmt.Errorf("not in a valid stage to end payment finalization")
}
if r.Stage.Ended {
return nil, fmt.Errorf("payment finalization already ended")
return nil, fmt.Errorf("round already finalized")
}
event := RoundFinalized{
Id: r.Id,
@@ -169,30 +209,9 @@ func (r *Round) Fail(err error) []RoundEvent {
return []RoundEvent{event}
}
func (r *Round) RegisterPayments(payments []Payment) ([]RoundEvent, error) {
if !r.IsStarted() {
return nil, fmt.Errorf("not in a valid stage to register payments")
}
if len(payments) <= 0 {
return nil, fmt.Errorf("missing payments to register")
}
for _, p := range payments {
if err := p.validate(false); err != nil {
return nil, err
}
}
event := PaymentsRegistered{
Id: r.Id,
Payments: payments,
}
r.raise(event)
return []RoundEvent{event}, nil
}
func (r *Round) IsStarted() bool {
return !r.IsFailed() && r.Stage.Code == RegistrationStage
empty := Stage{}
return !r.IsFailed() && !r.IsEnded() && r.Stage != empty
}
func (r *Round) IsEnded() bool {
@@ -203,14 +222,18 @@ func (r *Round) IsFailed() bool {
return r.Stage.Failed
}
func (r *Round) TotInputAmount() uint64 {
return uint64(len(r.Payments) * dustAmount)
func (r *Round) TotalInputAmount() uint64 {
totInputs := 0
for _, p := range r.Payments {
totInputs += len(p.Inputs)
}
return uint64(totInputs * int(r.DustAmount))
}
func (r *Round) TotOutputAmount() uint64 {
func (r *Round) TotalOutputAmount() uint64 {
tot := uint64(0)
for _, p := range r.Payments {
tot += p.TotOutputAmount()
tot += p.TotalOutputAmount()
}
return tot
}

View File

@@ -0,0 +1,486 @@
package domain_test
import (
"fmt"
"testing"
"github.com/ark-network/ark/internal/core/domain"
"github.com/stretchr/testify/require"
)
var (
dustAmount = uint64(450)
payments = []domain.Payment{
{
Id: "0",
Inputs: []domain.Vtxo{{}},
Receivers: []domain.Receiver{{}, {}, {}},
},
{
Id: "1",
Inputs: []domain.Vtxo{{}, {}},
Receivers: []domain.Receiver{{}},
},
}
emptyPtx = "cHNldP8BAgQCAAAAAQQBAAEFAQABBgEDAfsEAgAAAAA="
emptyTx = "0200000000000000000000"
txid = "0000000000000000000000000000000000000000000000000000000000000000"
congestionTree = []string{emptyPtx, emptyPtx, emptyPtx, emptyPtx, emptyPtx, emptyPtx, emptyPtx}
connectors = []string{emptyPtx, emptyPtx, emptyPtx}
forfeitTxs = []string{emptyPtx, emptyPtx, emptyPtx, emptyPtx, emptyPtx, emptyPtx, emptyPtx, emptyPtx, emptyPtx}
poolTx = emptyTx
)
func TestRound(t *testing.T) {
testStartRegistration(t)
testRegisterPayments(t)
testStartFinalization(t)
testEndFinalization(t)
testFail(t)
}
func testStartRegistration(t *testing.T) {
t.Run("start_registration", func(t *testing.T) {
t.Run("valid", func(t *testing.T) {
round := domain.NewRound(dustAmount)
require.NotNil(t, round)
require.NotEmpty(t, round.Id)
require.Empty(t, round.Events())
require.False(t, round.IsStarted())
require.False(t, round.IsEnded())
require.False(t, round.IsFailed())
events, err := round.StartRegistration()
require.NoError(t, err)
require.Len(t, events, 1)
require.True(t, round.IsStarted())
require.False(t, round.IsEnded())
require.False(t, round.IsFailed())
event, ok := events[0].(domain.RoundStarted)
require.True(t, ok)
require.Equal(t, round.Id, event.Id)
require.Equal(t, round.StartingTimestamp, event.Timestamp)
})
t.Run("invalid", func(t *testing.T) {
fixtures := []struct {
round *domain.Round
expectedErr string
}{
{
round: &domain.Round{
Id: "id",
Stage: domain.Stage{
Code: domain.UndefinedStage,
Failed: true,
},
},
expectedErr: "not in a valid stage to start payment registration",
},
{
round: &domain.Round{
Id: "id",
Stage: domain.Stage{
Code: domain.RegistrationStage,
},
},
expectedErr: "not in a valid stage to start payment registration",
},
{
round: &domain.Round{
Id: "id",
Stage: domain.Stage{
Code: domain.FinalizationStage,
},
},
expectedErr: "not in a valid stage to start payment registration",
},
}
for _, f := range fixtures {
events, err := f.round.StartRegistration()
require.EqualError(t, err, f.expectedErr)
require.Empty(t, events)
}
})
})
}
func testRegisterPayments(t *testing.T) {
t.Run("register_payments", func(t *testing.T) {
t.Run("valid", func(t *testing.T) {
round := domain.NewRound(dustAmount)
events, err := round.StartRegistration()
require.NoError(t, err)
require.NotEmpty(t, events)
events, err = round.RegisterPayments(payments)
require.NoError(t, err)
require.Len(t, events, 1)
require.Condition(t, func() bool {
for _, payment := range payments {
_, ok := round.Payments[payment.Id]
if !ok {
return false
}
}
return true
})
event, ok := events[0].(domain.PaymentsRegistered)
require.True(t, ok)
require.Equal(t, round.Id, event.Id)
require.Equal(t, payments, event.Payments)
})
t.Run("invalid", func(t *testing.T) {
fixtures := []struct {
round *domain.Round
payments []domain.Payment
expectedErr string
}{
{
round: &domain.Round{
Id: "id",
Stage: domain.Stage{},
},
payments: payments,
expectedErr: "not in a valid stage to register payments",
},
{
round: &domain.Round{
Id: "id",
Stage: domain.Stage{
Code: domain.RegistrationStage,
Failed: true,
},
},
payments: payments,
expectedErr: "not in a valid stage to register payments",
},
{
round: &domain.Round{
Id: "id",
Stage: domain.Stage{
Code: domain.FinalizationStage,
},
},
payments: payments,
expectedErr: "not in a valid stage to register payments",
},
{
round: &domain.Round{
Id: "id",
Stage: domain.Stage{
Code: domain.RegistrationStage,
},
},
payments: nil,
expectedErr: "missing payments to register",
},
}
for _, f := range fixtures {
events, err := f.round.RegisterPayments(f.payments)
require.EqualError(t, err, f.expectedErr)
require.Empty(t, events)
}
})
})
}
func testStartFinalization(t *testing.T) {
t.Run("start_finalization", func(t *testing.T) {
t.Run("valid", func(t *testing.T) {
round := domain.NewRound(dustAmount)
events, err := round.StartRegistration()
require.NoError(t, err)
require.NotEmpty(t, events)
events, err = round.RegisterPayments(payments)
require.NoError(t, err)
require.NotEmpty(t, events)
events, err = round.StartFinalization(connectors, congestionTree, poolTx)
require.NoError(t, err)
require.Len(t, events, 1)
require.True(t, round.IsStarted())
require.False(t, round.IsEnded())
require.False(t, round.IsFailed())
event, ok := events[0].(domain.RoundFinalizationStarted)
require.True(t, ok)
require.Equal(t, round.Id, event.Id)
require.Exactly(t, connectors, event.Connectors)
require.Exactly(t, congestionTree, event.CongestionTree)
require.Exactly(t, poolTx, event.PoolTx)
})
t.Run("invalid", func(t *testing.T) {
paymentsById := map[string]domain.Payment{}
for _, p := range payments {
paymentsById[p.Id] = p
}
fixtures := []struct {
round *domain.Round
connectors []string
tree []string
poolTx string
expectedErr string
}{
{
round: &domain.Round{
Id: "0",
Stage: domain.Stage{
Code: domain.RegistrationStage,
},
Payments: paymentsById,
},
connectors: nil,
tree: congestionTree,
poolTx: poolTx,
expectedErr: "missing list of connectors",
},
{
round: &domain.Round{
Id: "0",
Stage: domain.Stage{
Code: domain.RegistrationStage,
},
Payments: paymentsById,
},
connectors: connectors,
tree: nil,
poolTx: poolTx,
expectedErr: "missing congestion tree",
},
{
round: &domain.Round{
Id: "0",
Stage: domain.Stage{
Code: domain.RegistrationStage,
},
Payments: paymentsById,
},
connectors: connectors,
tree: congestionTree,
poolTx: "",
expectedErr: "missing unsigned pool tx",
},
{
round: &domain.Round{
Id: "0",
Stage: domain.Stage{
Code: domain.RegistrationStage,
},
Payments: nil,
},
connectors: connectors,
tree: congestionTree,
poolTx: poolTx,
expectedErr: "no payments registered",
},
{
round: &domain.Round{
Id: "0",
Stage: domain.Stage{
Code: domain.UndefinedStage,
},
Payments: paymentsById,
},
connectors: connectors,
tree: congestionTree,
poolTx: poolTx,
expectedErr: "not in a valid stage to start payment finalization",
},
{
round: &domain.Round{
Id: "0",
Stage: domain.Stage{
Code: domain.RegistrationStage,
Failed: true,
},
Payments: paymentsById,
},
connectors: connectors,
tree: congestionTree,
poolTx: poolTx,
expectedErr: "not in a valid stage to start payment finalization",
},
{
round: &domain.Round{
Id: "0",
Stage: domain.Stage{
Code: domain.FinalizationStage,
},
Payments: paymentsById,
},
connectors: connectors,
tree: congestionTree,
poolTx: poolTx,
expectedErr: "not in a valid stage to start payment finalization",
},
}
for _, f := range fixtures {
events, err := f.round.StartFinalization(f.connectors, f.tree, f.poolTx)
require.EqualError(t, err, f.expectedErr)
require.Empty(t, events)
}
})
})
}
func testEndFinalization(t *testing.T) {
t.Run("end_registration", func(t *testing.T) {
t.Run("valid", func(t *testing.T) {
round := domain.NewRound(dustAmount)
events, err := round.StartRegistration()
require.NoError(t, err)
require.NotEmpty(t, events)
events, err = round.RegisterPayments(payments)
require.NoError(t, err)
require.NotEmpty(t, events)
events, err = round.StartFinalization(connectors, congestionTree, poolTx)
require.NoError(t, err)
require.NotEmpty(t, events)
events, err = round.EndFinalization(forfeitTxs, txid)
require.NoError(t, err)
require.Len(t, events, 1)
require.False(t, round.IsStarted())
require.True(t, round.IsEnded())
require.False(t, round.IsFailed())
event, ok := events[0].(domain.RoundFinalized)
require.True(t, ok)
require.Equal(t, round.Id, event.Id)
require.Exactly(t, txid, event.Txid)
require.Exactly(t, forfeitTxs, event.ForfeitTxs)
require.Exactly(t, round.EndingTimestamp, event.Timestamp)
})
t.Run("invalid", func(t *testing.T) {
paymentsById := map[string]domain.Payment{}
for _, p := range payments {
paymentsById[p.Id] = p
}
fixtures := []struct {
round *domain.Round
forfeitTxs []string
txid string
expectedErr string
}{
{
round: &domain.Round{
Id: "0",
Stage: domain.Stage{
Code: domain.FinalizationStage,
},
},
forfeitTxs: nil,
txid: txid,
expectedErr: "missing list of signed forfeit txs",
},
{
round: &domain.Round{
Id: "0",
Stage: domain.Stage{
Code: domain.FinalizationStage,
},
},
forfeitTxs: forfeitTxs,
txid: "",
expectedErr: "missing pool txid",
},
{
round: &domain.Round{
Id: "0",
},
forfeitTxs: forfeitTxs,
txid: txid,
expectedErr: "not in a valid stage to end payment finalization",
},
{
round: &domain.Round{
Id: "0",
Stage: domain.Stage{
Code: domain.RegistrationStage,
},
},
forfeitTxs: forfeitTxs,
txid: txid,
expectedErr: "not in a valid stage to end payment finalization",
},
{
round: &domain.Round{
Id: "0",
Stage: domain.Stage{
Code: domain.FinalizationStage,
Failed: true,
},
},
forfeitTxs: []string{emptyPtx, emptyPtx, emptyPtx, emptyPtx},
txid: txid,
expectedErr: "not in a valid stage to end payment finalization",
},
{
round: &domain.Round{
Id: "0",
Stage: domain.Stage{
Code: domain.FinalizationStage,
Ended: true,
},
},
forfeitTxs: []string{emptyPtx, emptyPtx, emptyPtx, emptyPtx},
txid: txid,
expectedErr: "round already finalized",
},
}
for _, f := range fixtures {
events, err := f.round.EndFinalization(f.forfeitTxs, f.txid)
require.EqualError(t, err, f.expectedErr)
require.Empty(t, events)
}
})
})
}
func testFail(t *testing.T) {
t.Run("fail", func(t *testing.T) {
t.Run("valid", func(t *testing.T) {
round := domain.NewRound(dustAmount)
events, err := round.StartRegistration()
require.NoError(t, err)
require.NotEmpty(t, events)
events, err = round.RegisterPayments(payments)
require.NoError(t, err)
require.NotEmpty(t, events)
reason := "some valid reason"
events = round.Fail(fmt.Errorf(reason))
require.Len(t, events, 1)
require.False(t, round.IsStarted())
require.False(t, round.IsEnded())
require.True(t, round.IsFailed())
event, ok := events[0].(domain.RoundFailed)
require.True(t, ok)
require.Exactly(t, round.Id, event.Id)
require.Exactly(t, round.EndingTimestamp, event.Timestamp)
require.EqualError(t, event.Err, reason)
events = round.Fail(fmt.Errorf(reason))
require.Empty(t, events)
})
})
}

View File

@@ -20,7 +20,7 @@ type eventsDTO struct {
type eventRepository struct {
store *badgerhold.Store
lock *sync.Mutex
lock *sync.RWMutex
chUpdates chan *domain.Round
handler func(round *domain.Round)
}
@@ -51,7 +51,7 @@ func NewRoundEventRepository(config ...interface{}) (dbtypes.EventStore, error)
return nil, fmt.Errorf("failed to open round events store: %s", err)
}
chEvents := make(chan *domain.Round)
lock := &sync.Mutex{}
lock := &sync.RWMutex{}
repo := &eventRepository{store, lock, chEvents, nil}
go repo.listen()
return repo, nil
@@ -85,6 +85,9 @@ func (r *eventRepository) Load(
func (r *eventRepository) RegisterEventsHandler(
handler func(round *domain.Round),
) {
r.lock.Lock()
defer r.lock.Unlock()
r.handler = handler
}
@@ -135,9 +138,11 @@ func (r *eventRepository) upsert(
func (r *eventRepository) listen() {
for updatedRound := range r.chUpdates {
r.lock.RLock()
if r.handler != nil {
r.handler(updatedRound)
}
r.lock.RUnlock()
}
}