Fix errors on round finalization (#199)

* Fix ListConnectorUtxos

* Fix

* Fix

* Add log

* Store current round in memoory and drop GetCurrentRound repo api

* Skip lint
This commit is contained in:
Pietralberto Mazza
2024-07-08 14:22:35 +02:00
committed by GitHub
parent c274d670fe
commit d10c724ced
6 changed files with 35 additions and 89 deletions

View File

@@ -79,6 +79,7 @@ type service struct {
trustedOnboardingScriptLock *sync.Mutex trustedOnboardingScriptLock *sync.Mutex
trustedOnboardingScripts map[string]*secp256k1.PublicKey trustedOnboardingScripts map[string]*secp256k1.PublicKey
currentRound *domain.Round
} }
func NewService( func NewService(
@@ -106,7 +107,7 @@ func NewService(
roundLifetime, roundInterval, unilateralExitDelay, minRelayFee, roundLifetime, roundInterval, unilateralExitDelay, minRelayFee,
walletSvc, repoManager, builder, scanner, sweeper, walletSvc, repoManager, builder, scanner, sweeper,
paymentRequests, forfeitTxs, eventsCh, onboardingCh, paymentRequests, forfeitTxs, eventsCh, onboardingCh,
&sync.Mutex{}, make(map[string]*secp256k1.PublicKey), &sync.Mutex{}, make(map[string]*secp256k1.PublicKey), nil,
} }
repoManager.RegisterEventsHandler( repoManager.RegisterEventsHandler(
func(round *domain.Round) { func(round *domain.Round) {
@@ -219,7 +220,7 @@ func (s *service) GetRoundByTxid(ctx context.Context, poolTxid string) (*domain.
} }
func (s *service) GetCurrentRound(ctx context.Context) (*domain.Round, error) { func (s *service) GetCurrentRound(ctx context.Context) (*domain.Round, error) {
return s.repoManager.Rounds().GetCurrentRound(ctx) return domain.NewRoundFromEvents(s.currentRound.Events()), nil
} }
func (s *service) GetInfo(ctx context.Context) (*ServiceInfo, error) { func (s *service) GetInfo(ctx context.Context) (*ServiceInfo, error) {
@@ -324,13 +325,9 @@ func (s *service) start() {
func (s *service) startRound() { func (s *service) startRound() {
round := domain.NewRound(dustAmount) round := domain.NewRound(dustAmount)
changes, _ := round.StartRegistration() //nolint:all
if err := s.saveEvents( round.StartRegistration()
context.Background(), round.Id, changes, s.currentRound = round
); err != nil {
log.WithError(err).Warn("failed to store new round events")
return
}
defer func() { defer func() {
time.Sleep(time.Duration(s.roundInterval/2) * time.Second) time.Sleep(time.Duration(s.roundInterval/2) * time.Second)
@@ -342,15 +339,16 @@ func (s *service) startRound() {
func (s *service) startFinalization() { func (s *service) startFinalization() {
ctx := context.Background() ctx := context.Background()
round, err := s.repoManager.Rounds().GetCurrentRound(ctx) round := s.currentRound
if err != nil {
log.WithError(err).Warn("failed to retrieve current round")
return
}
var changes []domain.RoundEvent var roundAborted bool
defer func() { defer func() {
if err := s.saveEvents(ctx, round.Id, changes); err != nil { if roundAborted {
s.startRound()
return
}
if err := s.saveEvents(ctx, round.Id, round.Events()); err != nil {
log.WithError(err).Warn("failed to store new round events") log.WithError(err).Warn("failed to store new round events")
} }
@@ -369,8 +367,9 @@ func (s *service) startFinalization() {
// TODO: understand how many payments must be popped from the queue and actually registered for the round // TODO: understand how many payments must be popped from the queue and actually registered for the round
num := s.paymentRequests.len() num := s.paymentRequests.len()
if num == 0 { if num == 0 {
roundAborted = true
err := fmt.Errorf("no payments registered") err := fmt.Errorf("no payments registered")
changes = round.Fail(fmt.Errorf("round aborted: %s", err)) round.Fail(fmt.Errorf("round aborted: %s", err))
log.WithError(err).Debugf("round %s aborted", round.Id) log.WithError(err).Debugf("round %s aborted", round.Id)
return return
} }
@@ -378,23 +377,22 @@ func (s *service) startFinalization() {
num = paymentsThreshold num = paymentsThreshold
} }
payments := s.paymentRequests.pop(num) payments := s.paymentRequests.pop(num)
changes, err = round.RegisterPayments(payments) if _, err := round.RegisterPayments(payments); err != nil {
if err != nil { round.Fail(fmt.Errorf("failed to register payments: %s", err))
changes = round.Fail(fmt.Errorf("failed to register payments: %s", err))
log.WithError(err).Warn("failed to register payments") log.WithError(err).Warn("failed to register payments")
return return
} }
sweptRounds, err := s.repoManager.Rounds().GetSweptRounds(ctx) sweptRounds, err := s.repoManager.Rounds().GetSweptRounds(ctx)
if err != nil { if err != nil {
changes = round.Fail(fmt.Errorf("failed to retrieve swept rounds: %s", err)) round.Fail(fmt.Errorf("failed to retrieve swept rounds: %s", err))
log.WithError(err).Warn("failed to retrieve swept rounds") log.WithError(err).Warn("failed to retrieve swept rounds")
return return
} }
unsignedPoolTx, tree, connectorAddress, err := s.builder.BuildPoolTx(s.pubkey, payments, s.minRelayFee, sweptRounds) unsignedPoolTx, tree, connectorAddress, err := s.builder.BuildPoolTx(s.pubkey, payments, s.minRelayFee, sweptRounds)
if err != nil { if err != nil {
changes = round.Fail(fmt.Errorf("failed to create pool tx: %s", err)) round.Fail(fmt.Errorf("failed to create pool tx: %s", err))
log.WithError(err).Warn("failed to create pool tx") log.WithError(err).Warn("failed to create pool tx")
return return
} }
@@ -404,20 +402,20 @@ func (s *service) startFinalization() {
connectors, forfeitTxs, err := s.builder.BuildForfeitTxs(s.pubkey, unsignedPoolTx, payments, s.minRelayFee) connectors, forfeitTxs, err := s.builder.BuildForfeitTxs(s.pubkey, unsignedPoolTx, payments, s.minRelayFee)
if err != nil { if err != nil {
changes = round.Fail(fmt.Errorf("failed to create connectors and forfeit txs: %s", err)) round.Fail(fmt.Errorf("failed to create connectors and forfeit txs: %s", err))
log.WithError(err).Warn("failed to create connectors and forfeit txs") log.WithError(err).Warn("failed to create connectors and forfeit txs")
return return
} }
log.Debugf("forfeit transactions created for round %s", round.Id) log.Debugf("forfeit transactions created for round %s", round.Id)
events, err := round.StartFinalization(connectorAddress, connectors, tree, unsignedPoolTx) if _, err := round.StartFinalization(
if err != nil { connectorAddress, connectors, tree, unsignedPoolTx,
changes = round.Fail(fmt.Errorf("failed to start finalization: %s", err)) ); err != nil {
round.Fail(fmt.Errorf("failed to start finalization: %s", err))
log.WithError(err).Warn("failed to start finalization") log.WithError(err).Warn("failed to start finalization")
return return
} }
changes = append(changes, events...)
s.forfeitTxs.push(forfeitTxs) s.forfeitTxs.push(forfeitTxs)
@@ -428,15 +426,13 @@ func (s *service) finalizeRound() {
defer s.startRound() defer s.startRound()
ctx := context.Background() ctx := context.Background()
round, err := s.repoManager.Rounds().GetCurrentRound(ctx) round := s.currentRound
if err != nil {
log.WithError(err).Warn("failed to retrieve current round")
return
}
if round.IsFailed() { if round.IsFailed() {
return return
} }
fmt.Printf("%+v\n", *round)
var changes []domain.RoundEvent var changes []domain.RoundEvent
defer func() { defer func() {
if err := s.saveEvents(ctx, round.Id, changes); err != nil { if err := s.saveEvents(ctx, round.Id, changes); err != nil {

View File

@@ -13,7 +13,6 @@ type RoundEventRepository interface {
type RoundRepository interface { type RoundRepository interface {
AddOrUpdateRound(ctx context.Context, round Round) error AddOrUpdateRound(ctx context.Context, round Round) error
GetCurrentRound(ctx context.Context) (*Round, error)
GetRoundWithId(ctx context.Context, id string) (*Round, error) GetRoundWithId(ctx context.Context, id string) (*Round, error)
GetRoundWithTxid(ctx context.Context, txid string) (*Round, error) GetRoundWithTxid(ctx context.Context, txid string) (*Round, error)
GetSweepableRounds(ctx context.Context) ([]Round, error) GetSweepableRounds(ctx context.Context) ([]Round, error)

View File

@@ -50,20 +50,6 @@ func (r *roundRepository) AddOrUpdateRound(
return r.addOrUpdateRound(ctx, round) return r.addOrUpdateRound(ctx, round)
} }
func (r *roundRepository) GetCurrentRound(
ctx context.Context,
) (*domain.Round, error) {
query := badgerhold.Where("Stage.Ended").Eq(false).And("Stage.Failed").Eq(false)
rounds, err := r.findRound(ctx, query)
if err != nil {
return nil, err
}
if len(rounds) <= 0 {
return nil, fmt.Errorf("ongoing round not found")
}
return &rounds[0], nil
}
func (r *roundRepository) GetRoundWithId( func (r *roundRepository) GetRoundWithId(
ctx context.Context, id string, ctx context.Context, id string,
) (*domain.Round, error) { ) (*domain.Round, error) {

View File

@@ -230,11 +230,6 @@ func testRoundRepository(t *testing.T, svc ports.RepoManager) {
err = svc.Rounds().AddOrUpdateRound(ctx, *round) err = svc.Rounds().AddOrUpdateRound(ctx, *round)
require.NoError(t, err) require.NoError(t, err)
currentRound, err := svc.Rounds().GetCurrentRound(ctx)
require.NoError(t, err)
require.NotNil(t, currentRound)
require.Condition(t, roundsMatch(*round, *currentRound))
roundById, err := svc.Rounds().GetRoundWithId(ctx, roundId) roundById, err := svc.Rounds().GetRoundWithId(ctx, roundId)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, roundById) require.NotNil(t, roundById)
@@ -311,14 +306,9 @@ func testRoundRepository(t *testing.T, svc ports.RepoManager) {
err = svc.Rounds().AddOrUpdateRound(ctx, *updatedRound) err = svc.Rounds().AddOrUpdateRound(ctx, *updatedRound)
require.NoError(t, err) require.NoError(t, err)
currentRound, err = svc.Rounds().GetCurrentRound(ctx)
require.NoError(t, err)
require.NotNil(t, currentRound)
require.Condition(t, roundsMatch(*updatedRound, *currentRound))
roundById, err = svc.Rounds().GetRoundWithId(ctx, updatedRound.Id) roundById, err = svc.Rounds().GetRoundWithId(ctx, updatedRound.Id)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, currentRound) require.NotNil(t, roundById)
require.Condition(t, roundsMatch(*updatedRound, *roundById)) require.Condition(t, roundsMatch(*updatedRound, *roundById))
txid := randomString(32) txid := randomString(32)
@@ -336,10 +326,6 @@ func testRoundRepository(t *testing.T, svc ports.RepoManager) {
err = svc.Rounds().AddOrUpdateRound(ctx, *finalizedRound) err = svc.Rounds().AddOrUpdateRound(ctx, *finalizedRound)
require.NoError(t, err) require.NoError(t, err)
currentRound, err = svc.Rounds().GetCurrentRound(ctx)
require.Error(t, err)
require.Nil(t, currentRound)
roundById, err = svc.Rounds().GetRoundWithId(ctx, roundId) roundById, err = svc.Rounds().GetRoundWithId(ctx, roundId)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, roundById) require.NotNil(t, roundById)

View File

@@ -135,11 +135,10 @@ LEFT OUTER JOIN receiver ON payment.id=receiver.payment_id
LEFT OUTER JOIN vtxo ON payment.id=vtxo.payment_id LEFT OUTER JOIN vtxo ON payment.id=vtxo.payment_id
` `
selectCurrentRound = selectRound + " WHERE round.ended = false AND round.failed = false;"
selectRoundWithId = selectRound + " WHERE round.id = ?;" selectRoundWithId = selectRound + " WHERE round.id = ?;"
selectRoundWithTxId = selectRound + " WHERE round.txid = ?;" selectRoundWithTxId = selectRound + " WHERE round.txid = ?;"
selectSweepableRounds = selectRound + " WHERE round.swept = false AND round.ended = true AND round.failed = false;" selectSweepableRounds = selectRound + " WHERE round.swept = false AND round.ended = true AND round.failed = false;"
selectSweptRounds = selectRound + " WHERE round.swept = true AND round.failed = false AND round.ended = true;" selectSweptRounds = selectRound + " WHERE round.swept = true AND round.failed = false AND round.ended = true AND round.connector_address <> '';"
selectRoundIdsInRange = ` selectRoundIdsInRange = `
SELECT id FROM round WHERE starting_timestamp > ? AND starting_timestamp < ?; SELECT id FROM round WHERE starting_timestamp > ? AND starting_timestamp < ?;
@@ -379,30 +378,6 @@ func (r *roundRepository) AddOrUpdateRound(ctx context.Context, round domain.Rou
return tx.Commit() return tx.Commit()
} }
func (r *roundRepository) GetCurrentRound(ctx context.Context) (*domain.Round, error) {
stmt, err := r.db.Prepare(selectCurrentRound)
if err != nil {
return nil, err
}
defer stmt.Close()
rows, err := stmt.Query()
if err != nil {
return nil, err
}
rounds, err := readRoundRows(rows)
if err != nil {
return nil, err
}
if len(rounds) == 0 {
return nil, errors.New("no current round")
}
return rounds[0], nil
}
func (r *roundRepository) GetRoundWithId(ctx context.Context, id string) (*domain.Round, error) { func (r *roundRepository) GetRoundWithId(ctx context.Context, id string) (*domain.Round, error) {
stmt, err := r.db.Prepare(selectRoundWithId) stmt, err := r.db.Prepare(selectRoundWithId)
if err != nil { if err != nil {

View File

@@ -27,9 +27,13 @@ func (s *service) DeriveConnectorAddress(ctx context.Context) (string, error) {
func (s *service) ListConnectorUtxos( func (s *service) ListConnectorUtxos(
ctx context.Context, connectorAddress string, ctx context.Context, connectorAddress string,
) ([]ports.TxInput, error) { ) ([]ports.TxInput, error) {
addresses := make([]string, 0)
if len(connectorAddress) > 0 {
addresses = append(addresses, connectorAddress)
}
res, err := s.accountClient.ListUtxos(ctx, &pb.ListUtxosRequest{ res, err := s.accountClient.ListUtxos(ctx, &pb.ListUtxosRequest{
AccountName: connectorAccount, AccountName: connectorAccount,
Addresses: []string{connectorAddress}, Addresses: addresses,
}) })
if err != nil { if err != nil {
return nil, err return nil, err