Merge pull request #46 from guggero/lnd-connection-restart

challenger: shutdown if connection to lnd is lost
This commit is contained in:
Olaoluwa Osuntokun
2020-09-30 18:02:40 -07:00
committed by GitHub
3 changed files with 74 additions and 9 deletions

View File

@@ -110,7 +110,10 @@ func run() error {
Value: price,
}, nil
}
challenger, err := NewLndChallenger(cfg.Authenticator, genInvoiceReq)
errChan := make(chan error)
challenger, err := NewLndChallenger(
cfg.Authenticator, genInvoiceReq, errChan,
)
if err != nil {
return err
}
@@ -164,7 +167,6 @@ func run() error {
)
log.Infof("Starting the server, listening on %s.", cfg.ListenAddr)
errChan := make(chan error)
wg.Add(1)
go func() {
defer wg.Done()

View File

@@ -49,6 +49,8 @@ type LndChallenger struct {
invoicesCancel func()
invoicesCond *sync.Cond
errChan chan<- error
quit chan struct{}
wg sync.WaitGroup
}
@@ -66,8 +68,8 @@ const (
// NewLndChallenger creates a new challenger that uses the given connection
// details to connect to an lnd backend to create payment challenges.
func NewLndChallenger(cfg *authConfig, genInvoiceReq InvoiceRequestGenerator) (
*LndChallenger, error) {
func NewLndChallenger(cfg *authConfig, genInvoiceReq InvoiceRequestGenerator,
errChan chan<- error) (*LndChallenger, error) {
if genInvoiceReq == nil {
return nil, fmt.Errorf("genInvoiceReq cannot be nil")
@@ -89,6 +91,7 @@ func NewLndChallenger(cfg *authConfig, genInvoiceReq InvoiceRequestGenerator) (
invoicesMtx: invoicesMtx,
invoicesCond: sync.NewCond(invoicesMtx),
quit: make(chan struct{}),
errChan: errChan,
}, nil
}
@@ -188,17 +191,39 @@ func (l *LndChallenger) readInvoiceStream(
switch {
case err == io.EOF:
// The connection is shutting down, we can't continue
// to function properly. Signal the error to the main
// goroutine to force a shutdown/restart.
select {
case l.errChan <- err:
case <-l.quit:
default:
}
return
case err != nil && strings.Contains(
err.Error(), context.Canceled.Error(),
):
// The context has been canceled, we are shutting down.
// So no need to forward the error to the main
// goroutine.
return
case err != nil:
log.Errorf("Received error from invoice subscription: "+
"%v", err)
// The connection is faulty, we can't continue to
// function properly. Signal the error to the main
// goroutine to force a shutdown/restart.
select {
case l.errChan <- err:
case <-l.quit:
default:
}
return
default:

View File

@@ -2,6 +2,7 @@ package aperture
import (
"context"
"fmt"
"sync"
"testing"
"time"
@@ -13,13 +14,14 @@ import (
)
var (
defaultTimeout = 20 * time.Millisecond
defaultTimeout = 100 * time.Millisecond
)
type invoiceStreamMock struct {
lnrpc.Lightning_SubscribeInvoicesClient
updateChan chan *lnrpc.Invoice
errChan chan error
quit chan struct{}
}
@@ -28,6 +30,9 @@ func (i *invoiceStreamMock) Recv() (*lnrpc.Invoice, error) {
case msg := <-i.updateChan:
return msg, nil
case err := <-i.errChan:
return nil, err
case <-i.quit:
return nil, context.Canceled
}
@@ -36,6 +41,7 @@ func (i *invoiceStreamMock) Recv() (*lnrpc.Invoice, error) {
type mockInvoiceClient struct {
invoices []*lnrpc.Invoice
updateChan chan *lnrpc.Invoice
errChan chan error
quit chan struct{}
lastAddIndex uint64
@@ -60,6 +66,7 @@ func (m *mockInvoiceClient) SubscribeInvoices(_ context.Context,
return &invoiceStreamMock{
updateChan: m.updateChan,
errChan: m.errChan,
quit: m.quit,
}, nil
}
@@ -81,9 +88,10 @@ func (m *mockInvoiceClient) stop() {
close(m.quit)
}
func newChallenger() (*LndChallenger, *mockInvoiceClient) {
func newChallenger() (*LndChallenger, *mockInvoiceClient, chan error) {
mockClient := &mockInvoiceClient{
updateChan: make(chan *lnrpc.Invoice),
errChan: make(chan error, 1),
quit: make(chan struct{}),
}
genInvoiceReq := func(price int64) (*lnrpc.Invoice, error) {
@@ -91,6 +99,7 @@ func newChallenger() (*LndChallenger, *mockInvoiceClient) {
nil
}
invoicesMtx := &sync.Mutex{}
mainErrChan := make(chan error)
return &LndChallenger{
client: mockClient,
genInvoiceReq: genInvoiceReq,
@@ -98,7 +107,8 @@ func newChallenger() (*LndChallenger, *mockInvoiceClient) {
quit: make(chan struct{}),
invoicesMtx: invoicesMtx,
invoicesCond: sync.NewCond(invoicesMtx),
}, mockClient
errChan: mainErrChan,
}, mockClient, mainErrChan
}
func newInvoice(hash lntypes.Hash, addIndex uint64,
@@ -119,12 +129,13 @@ func TestLndChallenger(t *testing.T) {
// First of all, test that the NewLndChallenger doesn't allow a nil
// invoice generator function.
_, err := NewLndChallenger(nil, nil)
errChan := make(chan error)
_, err := NewLndChallenger(nil, nil, errChan)
require.Error(t, err)
// Now mock the lnd backend and create a challenger instance that we can
// test.
c, invoiceMock := newChallenger()
c, invoiceMock, mainErrChan := newChallenger()
// Creating a new challenge should add an invoice to the lnd backend.
req, hash, err := c.NewChallenge(1337)
@@ -210,6 +221,33 @@ func TestLndChallenger(t *testing.T) {
}
}
// Finally test that if an error occurs in the invoice subscription the
// challenger reports it on the main error channel to cause a shutdown
// of aperture. The mock's error channel is buffered so we can send
// directly.
invoiceMock.errChan <- fmt.Errorf("an expected error")
select {
case err := <-mainErrChan:
require.Error(t, err)
// Make sure that the goroutine exited.
done := make(chan struct{})
go func() {
c.wg.Wait()
done <- struct{}{}
}()
select {
case <-done:
case <-time.After(defaultTimeout):
t.Fatalf("wait group didn't finish before timeout")
}
case <-time.After(defaultTimeout):
t.Fatalf("error not received on main chan before the timeout")
}
invoiceMock.stop()
c.Stop()
}