hashmail_server: fix blocking reads

With this commit we fix a bug in the hashmail server that didn't return
a read stream properly if it was closed from the client side.

Co-authored-by: Elle Mouton <elle.mouton@gmail.com>
This commit is contained in:
Oliver Gugger
2021-11-22 16:30:52 +01:00
parent 7bcc8355d0
commit 8d37d8a3d9
2 changed files with 254 additions and 22 deletions

View File

@@ -1,7 +1,7 @@
package aperture
import (
"bufio"
"bytes"
"context"
"fmt"
"io"
@@ -26,6 +26,10 @@ const (
// for messages. If a new message is about to exceed the burst rate,
// then we'll allow it up to this burst allowance.
DefaultMsgBurstAllowance = 10
// DefaultBufSize is the default number of bytes that are read in a
// single operation.
DefaultBufSize = 4096
)
// streamID is the identifier of a stream.
@@ -42,8 +46,6 @@ func newStreamID(id []byte) streamID {
// readStream is the read side of the read pipe, which is implemented a
// buffered wrapper around the core reader.
type readStream struct {
io.Reader
// parentStream is a pointer to the parent stream. We keep this around
// so we can return the stream after we're done using it.
parentStream *stream
@@ -56,10 +58,22 @@ type readStream struct {
// ReadNextMsg attempts to read the next message in the stream.
//
// NOTE: This will *block* until a new message is available.
func (r *readStream) ReadNextMsg() ([]byte, error) {
func (r *readStream) ReadNextMsg(ctx context.Context) ([]byte, error) {
var reader io.Reader
select {
case b := <-r.parentStream.readBytesChan:
reader = bytes.NewReader(b)
case <-ctx.Done():
return nil, ctx.Err()
case err := <-r.parentStream.readErrChan:
return nil, err
}
// First, we'll decode the length of the next message from the stream
// so we know how many bytes we need to read.
msgLen, err := tlv.ReadVarInt(r, &r.scratchBuf)
msgLen, err := tlv.ReadVarInt(reader, &r.scratchBuf)
if err != nil {
return nil, err
}
@@ -67,7 +81,7 @@ func (r *readStream) ReadNextMsg() ([]byte, error) {
// Now that we know the length of the message, we'll make a limit
// reader, then read all the encoded bytes until the EOF is emitted by
// the reader.
msgReader := io.LimitReader(r, int64(msgLen))
msgReader := io.LimitReader(reader, int64(msgLen))
return ioutil.ReadAll(msgReader)
}
@@ -108,19 +122,18 @@ func (w *writeStream) WriteMsg(ctx context.Context, msg []byte) error {
// As we're writing to a stream, we need to delimit each message with a
// length prefix so the reader knows how many bytes to consume for each
// message.
//
// TODO(roasbeef): actually needs to be single write?
var buf bytes.Buffer
msgSize := uint64(len(msg))
err := tlv.WriteVarInt(
w, msgSize, &w.scratchBuf,
)
if err != nil {
if err := tlv.WriteVarInt(&buf, msgSize, &w.scratchBuf); err != nil {
return err
}
// Next, we'll write the message directly to the stream.
_, err = w.Write(msg)
if err != nil {
if _, err := buf.Write(msg); err != nil {
return err
}
if _, err := w.Write(buf.Bytes()); err != nil {
return err
}
@@ -143,6 +156,10 @@ type stream struct {
readStreamChan chan *readStream
writeStreamChan chan *writeStream
readBytesChan chan []byte
readErrChan chan error
quit chan struct{}
// equivAuth is a method used to determine if an authentication
// mechanism to tear down a stream is equivalent to the one used to
// create it in the first place. WE use this to ensure that only the
@@ -173,6 +190,9 @@ func newStream(id streamID, limiter *rate.Limiter,
id: id,
equivAuth: equivAuth,
limiter: limiter,
readBytesChan: make(chan []byte),
readErrChan: make(chan error, 1),
quit: make(chan struct{}),
}
// Our tear down function will close the write side of the pipe, which
@@ -183,6 +203,7 @@ func newStream(id streamID, limiter *rate.Limiter,
if err != nil {
return err
}
close(s.quit)
s.wg.Wait()
return nil
}
@@ -196,21 +217,37 @@ func newStream(id streamID, limiter *rate.Limiter,
// will read from.
_, err := io.Copy(
readWritePipe,
// This is where the buffering will happen, as the
// writer writes to the write end of the pipe, this
// goroutine will copy the bytes into the buffer until
// its full, then attempt to write it to the write end
// of the read pipe.
bufio.NewReader(writeReadPipe),
writeReadPipe,
)
_ = readWritePipe.CloseWithError(err)
_ = writeReadPipe.CloseWithError(err)
}()
s.wg.Add(1)
go func() {
defer s.wg.Done()
var buf [DefaultBufSize]byte
for {
numBytes, err := readReadPipe.Read(buf[:])
if err != nil {
s.readErrChan <- err
return
}
c := make([]byte, numBytes)
copy(c, buf[0:numBytes])
select {
case s.readBytesChan <- c:
case <-s.quit:
}
}
}()
// We'll now initialize our stream by sending the read and write ends
// to their respective holding channels.
s.readStreamChan <- &readStream{
Reader: readReadPipe,
parentStream: s,
}
s.writeStreamChan <- &writeStream{
@@ -555,6 +592,7 @@ func (h *hashMailServer) SendStream(readStream hashmailrpc.HashMail_SendStreamSe
// exit before shutting down.
select {
case <-ctx.Done():
log.Debugf("SendStream: Context done, exiting")
return nil
case <-h.quit:
return fmt.Errorf("server shutting down")
@@ -564,6 +602,8 @@ func (h *hashMailServer) SendStream(readStream hashmailrpc.HashMail_SendStreamSe
cipherBox, err := readStream.Recv()
if err != nil {
log.Debugf("SendStream: Exiting write stream RPC "+
"stream read: %v", err)
return err
}
@@ -608,7 +648,7 @@ func (h *hashMailServer) RecvStream(desc *hashmailrpc.CipherBoxDesc,
default:
}
nextMsg, err := readStream.ReadNextMsg()
nextMsg, err := readStream.ReadNextMsg(reader.Context())
if err != nil {
log.Debugf("Got error an read stream read: %v", err)
return err

192
hashmail_server_test.go Normal file
View File

@@ -0,0 +1,192 @@
package aperture
import (
"context"
"fmt"
"math"
"net/http"
"testing"
"time"
"github.com/lightninglabs/lightning-node-connect/hashmailrpc"
"github.com/lightningnetwork/lnd/build"
"github.com/lightningnetwork/lnd/lntest/wait"
"github.com/lightningnetwork/lnd/signal"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
)
var (
testApertureAddress = "localhost:8082"
testSID = streamID{1, 2, 3}
testStreamDesc = &hashmailrpc.CipherBoxDesc{
StreamId: testSID[:],
}
testMessage = []byte("I'm a message!")
apertureStartTimeout = 3 * time.Second
)
func TestHashMailServerReturnStream(t *testing.T) {
ctxb := context.Background()
setupAperture(t)
// Create a client and connect it to the server.
conn, err := grpc.Dial(testApertureAddress, grpc.WithInsecure())
require.NoError(t, err)
client := hashmailrpc.NewHashMailClient(conn)
// We'll create a new cipher box that we're going to subscribe to
// multiple times to check disconnecting returns the read stream.
resp, err := client.NewCipherBox(ctxb, &hashmailrpc.CipherBoxAuth{
Auth: &hashmailrpc.CipherBoxAuth_LndAuth{},
Desc: testStreamDesc,
})
require.NoError(t, err)
require.NotNil(t, resp.GetSuccess())
// First we make sure there is something to read on the other end of
// that stream by writing something to it.
sendCtx, sendCancel := context.WithCancel(context.Background())
defer sendCancel()
writeStream, err := client.SendStream(sendCtx)
require.NoError(t, err)
err = writeStream.Send(&hashmailrpc.CipherBox{
Desc: testStreamDesc,
Msg: testMessage,
})
require.NoError(t, err)
// We need to wait a bit to make sure the message is really sent.
time.Sleep(100 * time.Millisecond)
// Connect, wait for the stream to be ready, read something, then
// disconnect immediately.
msg, err := readMsgFromStream(t, client)
require.NoError(t, err)
require.Equal(t, testMessage, msg.Msg)
// Make sure we can connect again immediately and try to read something.
// There is no message to read before we cancel the request so we expect
// an EOF error to be returned upon connection close/context cancel.
_, err = readMsgFromStream(t, client)
require.Error(t, err)
require.Contains(t, err.Error(), "context canceled")
// Send then receive yet another message to make sure the stream is
// still operational.
testMessage2 := append(testMessage, []byte("test")...)
err = writeStream.Send(&hashmailrpc.CipherBox{
Desc: testStreamDesc,
Msg: testMessage2,
})
require.NoError(t, err)
// We need to wait a bit to make sure the message is really sent.
time.Sleep(100 * time.Millisecond)
msg, err = readMsgFromStream(t, client)
require.NoError(t, err)
require.Equal(t, testMessage2, msg.Msg)
}
func setupAperture(t *testing.T) {
logWriter := build.NewRotatingLogWriter()
interceptor, err := signal.Intercept()
require.NoError(t, err)
SetupLoggers(logWriter, interceptor)
err = build.ParseAndSetDebugLevels("trace,PRXY=warn", logWriter)
require.NoError(t, err)
apertureCfg := &Config{
Insecure: true,
ListenAddr: testApertureAddress,
Authenticator: &AuthConfig{
Disable: true,
},
Etcd: &EtcdConfig{},
HashMail: &HashMailConfig{
Enabled: true,
MessageRate: time.Millisecond,
MessageBurstAllowance: math.MaxUint32,
},
}
aperture := NewAperture(apertureCfg)
errChan := make(chan error)
require.NoError(t, aperture.Start(errChan))
// Any error while starting?
select {
case err := <-errChan:
t.Fatalf("error starting aperture: %v", err)
default:
}
err = wait.NoError(func() error {
apertureAddr := fmt.Sprintf("http://%s/dummy",
testApertureAddress)
resp, err := http.Get(apertureAddr)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusNotFound {
return fmt.Errorf("invalid status: %d", resp.StatusCode)
}
return nil
}, apertureStartTimeout)
require.NoError(t, err)
}
func readMsgFromStream(t *testing.T,
client hashmailrpc.HashMailClient) (*hashmailrpc.CipherBox, error) {
ctxc, cancel := context.WithCancel(context.Background())
readStream, err := client.RecvStream(ctxc, testStreamDesc)
require.NoError(t, err)
// Wait a bit again to make sure the request is actually sent before our
// context is canceled already again.
time.Sleep(100 * time.Millisecond)
// We'll start a read on the stream in the background.
var (
goroutineStarted = make(chan struct{})
resultChan = make(chan *hashmailrpc.CipherBox)
errChan = make(chan error)
)
go func() {
close(goroutineStarted)
box, err := readStream.Recv()
if err != nil {
errChan <- err
return
}
resultChan <- box
}()
// Give the goroutine a chance to actually run, so block the main thread
// until it did.
<-goroutineStarted
time.Sleep(200 * time.Millisecond)
// Now close and cancel the stream to make sure the server can clean it
// up and release it.
require.NoError(t, readStream.CloseSend())
cancel()
// Interpret the result.
select {
case err := <-errChan:
return nil, err
case box := <-resultChan:
return box, nil
}
}