diff --git a/aperture.go b/aperture.go index 669a1a7..2463eb5 100644 --- a/aperture.go +++ b/aperture.go @@ -110,7 +110,10 @@ func run() error { Value: price, }, nil } - challenger, err := NewLndChallenger(cfg.Authenticator, genInvoiceReq) + errChan := make(chan error) + challenger, err := NewLndChallenger( + cfg.Authenticator, genInvoiceReq, errChan, + ) if err != nil { return err } @@ -164,7 +167,6 @@ func run() error { ) log.Infof("Starting the server, listening on %s.", cfg.ListenAddr) - errChan := make(chan error) wg.Add(1) go func() { defer wg.Done() diff --git a/challenger.go b/challenger.go index ca01c62..f35a6ad 100644 --- a/challenger.go +++ b/challenger.go @@ -49,6 +49,8 @@ type LndChallenger struct { invoicesCancel func() invoicesCond *sync.Cond + errChan chan<- error + quit chan struct{} wg sync.WaitGroup } @@ -66,8 +68,8 @@ const ( // NewLndChallenger creates a new challenger that uses the given connection // details to connect to an lnd backend to create payment challenges. -func NewLndChallenger(cfg *authConfig, genInvoiceReq InvoiceRequestGenerator) ( - *LndChallenger, error) { +func NewLndChallenger(cfg *authConfig, genInvoiceReq InvoiceRequestGenerator, + errChan chan<- error) (*LndChallenger, error) { if genInvoiceReq == nil { return nil, fmt.Errorf("genInvoiceReq cannot be nil") @@ -89,6 +91,7 @@ func NewLndChallenger(cfg *authConfig, genInvoiceReq InvoiceRequestGenerator) ( invoicesMtx: invoicesMtx, invoicesCond: sync.NewCond(invoicesMtx), quit: make(chan struct{}), + errChan: errChan, }, nil } @@ -188,17 +191,39 @@ func (l *LndChallenger) readInvoiceStream( switch { case err == io.EOF: + // The connection is shutting down, we can't continue + // to function properly. Signal the error to the main + // goroutine to force a shutdown/restart. + select { + case l.errChan <- err: + case <-l.quit: + default: + } + return case err != nil && strings.Contains( err.Error(), context.Canceled.Error(), ): + // The context has been canceled, we are shutting down. + // So no need to forward the error to the main + // goroutine. return case err != nil: log.Errorf("Received error from invoice subscription: "+ "%v", err) + + // The connection is faulty, we can't continue to + // function properly. Signal the error to the main + // goroutine to force a shutdown/restart. + select { + case l.errChan <- err: + case <-l.quit: + default: + } + return default: diff --git a/challenger_test.go b/challenger_test.go index fa43db3..f0bf7e0 100644 --- a/challenger_test.go +++ b/challenger_test.go @@ -2,6 +2,7 @@ package aperture import ( "context" + "fmt" "sync" "testing" "time" @@ -13,13 +14,14 @@ import ( ) var ( - defaultTimeout = 20 * time.Millisecond + defaultTimeout = 100 * time.Millisecond ) type invoiceStreamMock struct { lnrpc.Lightning_SubscribeInvoicesClient updateChan chan *lnrpc.Invoice + errChan chan error quit chan struct{} } @@ -28,6 +30,9 @@ func (i *invoiceStreamMock) Recv() (*lnrpc.Invoice, error) { case msg := <-i.updateChan: return msg, nil + case err := <-i.errChan: + return nil, err + case <-i.quit: return nil, context.Canceled } @@ -36,6 +41,7 @@ func (i *invoiceStreamMock) Recv() (*lnrpc.Invoice, error) { type mockInvoiceClient struct { invoices []*lnrpc.Invoice updateChan chan *lnrpc.Invoice + errChan chan error quit chan struct{} lastAddIndex uint64 @@ -60,6 +66,7 @@ func (m *mockInvoiceClient) SubscribeInvoices(_ context.Context, return &invoiceStreamMock{ updateChan: m.updateChan, + errChan: m.errChan, quit: m.quit, }, nil } @@ -81,9 +88,10 @@ func (m *mockInvoiceClient) stop() { close(m.quit) } -func newChallenger() (*LndChallenger, *mockInvoiceClient) { +func newChallenger() (*LndChallenger, *mockInvoiceClient, chan error) { mockClient := &mockInvoiceClient{ updateChan: make(chan *lnrpc.Invoice), + errChan: make(chan error, 1), quit: make(chan struct{}), } genInvoiceReq := func(price int64) (*lnrpc.Invoice, error) { @@ -91,6 +99,7 @@ func newChallenger() (*LndChallenger, *mockInvoiceClient) { nil } invoicesMtx := &sync.Mutex{} + mainErrChan := make(chan error) return &LndChallenger{ client: mockClient, genInvoiceReq: genInvoiceReq, @@ -98,7 +107,8 @@ func newChallenger() (*LndChallenger, *mockInvoiceClient) { quit: make(chan struct{}), invoicesMtx: invoicesMtx, invoicesCond: sync.NewCond(invoicesMtx), - }, mockClient + errChan: mainErrChan, + }, mockClient, mainErrChan } func newInvoice(hash lntypes.Hash, addIndex uint64, @@ -119,12 +129,13 @@ func TestLndChallenger(t *testing.T) { // First of all, test that the NewLndChallenger doesn't allow a nil // invoice generator function. - _, err := NewLndChallenger(nil, nil) + errChan := make(chan error) + _, err := NewLndChallenger(nil, nil, errChan) require.Error(t, err) // Now mock the lnd backend and create a challenger instance that we can // test. - c, invoiceMock := newChallenger() + c, invoiceMock, mainErrChan := newChallenger() // Creating a new challenge should add an invoice to the lnd backend. req, hash, err := c.NewChallenge(1337) @@ -210,6 +221,33 @@ func TestLndChallenger(t *testing.T) { } } + // Finally test that if an error occurs in the invoice subscription the + // challenger reports it on the main error channel to cause a shutdown + // of aperture. The mock's error channel is buffered so we can send + // directly. + invoiceMock.errChan <- fmt.Errorf("an expected error") + select { + case err := <-mainErrChan: + require.Error(t, err) + + // Make sure that the goroutine exited. + done := make(chan struct{}) + go func() { + c.wg.Wait() + done <- struct{}{} + }() + + select { + case <-done: + + case <-time.After(defaultTimeout): + t.Fatalf("wait group didn't finish before timeout") + } + + case <-time.After(defaultTimeout): + t.Fatalf("error not received on main chan before the timeout") + } + invoiceMock.stop() c.Stop() }