mirror of
https://github.com/aljazceru/breez-lnd.git
synced 2025-12-18 14:44:22 +01:00
htlcswitch.test: add message interceptor handler
Add message interceptor which checks the order and may skip the messages which were denoted to be skipeed.
This commit is contained in:
committed by
Olaoluwa Osuntokun
parent
f29b4f60e4
commit
1eb906bcfb
@@ -5,13 +5,14 @@ import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"io"
|
||||
"sync/atomic"
|
||||
|
||||
"bytes"
|
||||
|
||||
"testing"
|
||||
|
||||
"github.com/btcsuite/fastsha256"
|
||||
"github.com/go-errors/errors"
|
||||
"github.com/lightningnetwork/lnd/chainntnfs"
|
||||
@@ -39,8 +40,8 @@ type mockServer struct {
|
||||
id [33]byte
|
||||
htlcSwitch *Switch
|
||||
|
||||
registry *mockInvoiceRegistry
|
||||
recordFuncs []func(lnwire.Message)
|
||||
registry *mockInvoiceRegistry
|
||||
interceptorFuncs []messageInterceptor
|
||||
}
|
||||
|
||||
var _ Peer = (*mockServer)(nil)
|
||||
@@ -51,14 +52,14 @@ func newMockServer(t *testing.T, name string) *mockServer {
|
||||
copy(id[:], h[:])
|
||||
|
||||
return &mockServer{
|
||||
t: t,
|
||||
id: id,
|
||||
name: name,
|
||||
messages: make(chan lnwire.Message, 3000),
|
||||
quit: make(chan bool),
|
||||
registry: newMockRegistry(),
|
||||
htlcSwitch: New(Config{}),
|
||||
recordFuncs: make([]func(lnwire.Message), 0),
|
||||
t: t,
|
||||
id: id,
|
||||
name: name,
|
||||
messages: make(chan lnwire.Message, 3000),
|
||||
quit: make(chan bool),
|
||||
registry: newMockRegistry(),
|
||||
htlcSwitch: New(Config{}),
|
||||
interceptorFuncs: make([]messageInterceptor, 0),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -76,8 +77,20 @@ func (s *mockServer) Start() error {
|
||||
for {
|
||||
select {
|
||||
case msg := <-s.messages:
|
||||
for _, f := range s.recordFuncs {
|
||||
f(msg)
|
||||
var shouldSkip bool
|
||||
|
||||
for _, interceptor := range s.interceptorFuncs {
|
||||
skip, err := interceptor(msg)
|
||||
if err != nil {
|
||||
s.errChan <- errors.Errorf("%v: error in the "+
|
||||
"interceptor: %v", s.name, err)
|
||||
return
|
||||
}
|
||||
shouldSkip = shouldSkip || skip
|
||||
}
|
||||
|
||||
if shouldSkip {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := s.readHandler(msg); err != nil {
|
||||
@@ -245,13 +258,13 @@ func (f *ForwardingInfo) decode(r io.Reader) error {
|
||||
}
|
||||
|
||||
// messageInterceptor is function that handles the incoming peer messages and
|
||||
// may decide should we handle it or not.
|
||||
type messageInterceptor func(m lnwire.Message)
|
||||
// may decide should the peer skip the message or not.
|
||||
type messageInterceptor func(m lnwire.Message) (bool, error)
|
||||
|
||||
// Record is used to set the function which will be triggered when new
|
||||
// lnwire message was received.
|
||||
func (s *mockServer) record(f messageInterceptor) {
|
||||
s.recordFuncs = append(s.recordFuncs, f)
|
||||
func (s *mockServer) intersect(f messageInterceptor) {
|
||||
s.interceptorFuncs = append(s.interceptorFuncs, f)
|
||||
}
|
||||
|
||||
func (s *mockServer) SendMessage(message lnwire.Message) error {
|
||||
@@ -297,11 +310,8 @@ func (s *mockServer) readHandler(message lnwire.Message) error {
|
||||
// the server when handler stacked (server unavailable)
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer func() {
|
||||
done <- struct{}{}
|
||||
}()
|
||||
|
||||
link.HandleChannelUpdate(message)
|
||||
done <- struct{}{}
|
||||
}()
|
||||
select {
|
||||
case <-done:
|
||||
|
||||
Reference in New Issue
Block a user