mirror of
https://github.com/lightninglabs/aperture.git
synced 2026-01-07 11:24:24 +01:00
Merge pull request #46 from guggero/lnd-connection-restart
challenger: shutdown if connection to lnd is lost
This commit is contained in:
@@ -110,7 +110,10 @@ func run() error {
|
||||
Value: price,
|
||||
}, nil
|
||||
}
|
||||
challenger, err := NewLndChallenger(cfg.Authenticator, genInvoiceReq)
|
||||
errChan := make(chan error)
|
||||
challenger, err := NewLndChallenger(
|
||||
cfg.Authenticator, genInvoiceReq, errChan,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -164,7 +167,6 @@ func run() error {
|
||||
)
|
||||
log.Infof("Starting the server, listening on %s.", cfg.ListenAddr)
|
||||
|
||||
errChan := make(chan error)
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
@@ -49,6 +49,8 @@ type LndChallenger struct {
|
||||
invoicesCancel func()
|
||||
invoicesCond *sync.Cond
|
||||
|
||||
errChan chan<- error
|
||||
|
||||
quit chan struct{}
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
@@ -66,8 +68,8 @@ const (
|
||||
|
||||
// NewLndChallenger creates a new challenger that uses the given connection
|
||||
// details to connect to an lnd backend to create payment challenges.
|
||||
func NewLndChallenger(cfg *authConfig, genInvoiceReq InvoiceRequestGenerator) (
|
||||
*LndChallenger, error) {
|
||||
func NewLndChallenger(cfg *authConfig, genInvoiceReq InvoiceRequestGenerator,
|
||||
errChan chan<- error) (*LndChallenger, error) {
|
||||
|
||||
if genInvoiceReq == nil {
|
||||
return nil, fmt.Errorf("genInvoiceReq cannot be nil")
|
||||
@@ -89,6 +91,7 @@ func NewLndChallenger(cfg *authConfig, genInvoiceReq InvoiceRequestGenerator) (
|
||||
invoicesMtx: invoicesMtx,
|
||||
invoicesCond: sync.NewCond(invoicesMtx),
|
||||
quit: make(chan struct{}),
|
||||
errChan: errChan,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -188,17 +191,39 @@ func (l *LndChallenger) readInvoiceStream(
|
||||
switch {
|
||||
|
||||
case err == io.EOF:
|
||||
// The connection is shutting down, we can't continue
|
||||
// to function properly. Signal the error to the main
|
||||
// goroutine to force a shutdown/restart.
|
||||
select {
|
||||
case l.errChan <- err:
|
||||
case <-l.quit:
|
||||
default:
|
||||
}
|
||||
|
||||
return
|
||||
|
||||
case err != nil && strings.Contains(
|
||||
err.Error(), context.Canceled.Error(),
|
||||
):
|
||||
|
||||
// The context has been canceled, we are shutting down.
|
||||
// So no need to forward the error to the main
|
||||
// goroutine.
|
||||
return
|
||||
|
||||
case err != nil:
|
||||
log.Errorf("Received error from invoice subscription: "+
|
||||
"%v", err)
|
||||
|
||||
// The connection is faulty, we can't continue to
|
||||
// function properly. Signal the error to the main
|
||||
// goroutine to force a shutdown/restart.
|
||||
select {
|
||||
case l.errChan <- err:
|
||||
case <-l.quit:
|
||||
default:
|
||||
}
|
||||
|
||||
return
|
||||
|
||||
default:
|
||||
|
||||
@@ -2,6 +2,7 @@ package aperture
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -13,13 +14,14 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
defaultTimeout = 20 * time.Millisecond
|
||||
defaultTimeout = 100 * time.Millisecond
|
||||
)
|
||||
|
||||
type invoiceStreamMock struct {
|
||||
lnrpc.Lightning_SubscribeInvoicesClient
|
||||
|
||||
updateChan chan *lnrpc.Invoice
|
||||
errChan chan error
|
||||
quit chan struct{}
|
||||
}
|
||||
|
||||
@@ -28,6 +30,9 @@ func (i *invoiceStreamMock) Recv() (*lnrpc.Invoice, error) {
|
||||
case msg := <-i.updateChan:
|
||||
return msg, nil
|
||||
|
||||
case err := <-i.errChan:
|
||||
return nil, err
|
||||
|
||||
case <-i.quit:
|
||||
return nil, context.Canceled
|
||||
}
|
||||
@@ -36,6 +41,7 @@ func (i *invoiceStreamMock) Recv() (*lnrpc.Invoice, error) {
|
||||
type mockInvoiceClient struct {
|
||||
invoices []*lnrpc.Invoice
|
||||
updateChan chan *lnrpc.Invoice
|
||||
errChan chan error
|
||||
quit chan struct{}
|
||||
|
||||
lastAddIndex uint64
|
||||
@@ -60,6 +66,7 @@ func (m *mockInvoiceClient) SubscribeInvoices(_ context.Context,
|
||||
|
||||
return &invoiceStreamMock{
|
||||
updateChan: m.updateChan,
|
||||
errChan: m.errChan,
|
||||
quit: m.quit,
|
||||
}, nil
|
||||
}
|
||||
@@ -81,9 +88,10 @@ func (m *mockInvoiceClient) stop() {
|
||||
close(m.quit)
|
||||
}
|
||||
|
||||
func newChallenger() (*LndChallenger, *mockInvoiceClient) {
|
||||
func newChallenger() (*LndChallenger, *mockInvoiceClient, chan error) {
|
||||
mockClient := &mockInvoiceClient{
|
||||
updateChan: make(chan *lnrpc.Invoice),
|
||||
errChan: make(chan error, 1),
|
||||
quit: make(chan struct{}),
|
||||
}
|
||||
genInvoiceReq := func(price int64) (*lnrpc.Invoice, error) {
|
||||
@@ -91,6 +99,7 @@ func newChallenger() (*LndChallenger, *mockInvoiceClient) {
|
||||
nil
|
||||
}
|
||||
invoicesMtx := &sync.Mutex{}
|
||||
mainErrChan := make(chan error)
|
||||
return &LndChallenger{
|
||||
client: mockClient,
|
||||
genInvoiceReq: genInvoiceReq,
|
||||
@@ -98,7 +107,8 @@ func newChallenger() (*LndChallenger, *mockInvoiceClient) {
|
||||
quit: make(chan struct{}),
|
||||
invoicesMtx: invoicesMtx,
|
||||
invoicesCond: sync.NewCond(invoicesMtx),
|
||||
}, mockClient
|
||||
errChan: mainErrChan,
|
||||
}, mockClient, mainErrChan
|
||||
}
|
||||
|
||||
func newInvoice(hash lntypes.Hash, addIndex uint64,
|
||||
@@ -119,12 +129,13 @@ func TestLndChallenger(t *testing.T) {
|
||||
|
||||
// First of all, test that the NewLndChallenger doesn't allow a nil
|
||||
// invoice generator function.
|
||||
_, err := NewLndChallenger(nil, nil)
|
||||
errChan := make(chan error)
|
||||
_, err := NewLndChallenger(nil, nil, errChan)
|
||||
require.Error(t, err)
|
||||
|
||||
// Now mock the lnd backend and create a challenger instance that we can
|
||||
// test.
|
||||
c, invoiceMock := newChallenger()
|
||||
c, invoiceMock, mainErrChan := newChallenger()
|
||||
|
||||
// Creating a new challenge should add an invoice to the lnd backend.
|
||||
req, hash, err := c.NewChallenge(1337)
|
||||
@@ -210,6 +221,33 @@ func TestLndChallenger(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// Finally test that if an error occurs in the invoice subscription the
|
||||
// challenger reports it on the main error channel to cause a shutdown
|
||||
// of aperture. The mock's error channel is buffered so we can send
|
||||
// directly.
|
||||
invoiceMock.errChan <- fmt.Errorf("an expected error")
|
||||
select {
|
||||
case err := <-mainErrChan:
|
||||
require.Error(t, err)
|
||||
|
||||
// Make sure that the goroutine exited.
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
c.wg.Wait()
|
||||
done <- struct{}{}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
|
||||
case <-time.After(defaultTimeout):
|
||||
t.Fatalf("wait group didn't finish before timeout")
|
||||
}
|
||||
|
||||
case <-time.After(defaultTimeout):
|
||||
t.Fatalf("error not received on main chan before the timeout")
|
||||
}
|
||||
|
||||
invoiceMock.stop()
|
||||
c.Stop()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user