From 8d37d8a3d91dc717987bcfc431af4119619a98bb Mon Sep 17 00:00:00 2001 From: Oliver Gugger Date: Mon, 22 Nov 2021 16:30:52 +0100 Subject: [PATCH] hashmail_server: fix blocking reads With this commit we fix a bug in the hashmail server that didn't return a read stream properly if it was closed from the client side. Co-authored-by: Elle Mouton --- hashmail_server.go | 84 +++++++++++++----- hashmail_server_test.go | 192 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 254 insertions(+), 22 deletions(-) create mode 100644 hashmail_server_test.go diff --git a/hashmail_server.go b/hashmail_server.go index 18ce3a4..7ec0301 100644 --- a/hashmail_server.go +++ b/hashmail_server.go @@ -1,7 +1,7 @@ package aperture import ( - "bufio" + "bytes" "context" "fmt" "io" @@ -26,6 +26,10 @@ const ( // for messages. If a new message is about to exceed the burst rate, // then we'll allow it up to this burst allowance. DefaultMsgBurstAllowance = 10 + + // DefaultBufSize is the default number of bytes that are read in a + // single operation. + DefaultBufSize = 4096 ) // streamID is the identifier of a stream. @@ -42,8 +46,6 @@ func newStreamID(id []byte) streamID { // readStream is the read side of the read pipe, which is implemented a // buffered wrapper around the core reader. type readStream struct { - io.Reader - // parentStream is a pointer to the parent stream. We keep this around // so we can return the stream after we're done using it. parentStream *stream @@ -56,10 +58,22 @@ type readStream struct { // ReadNextMsg attempts to read the next message in the stream. // // NOTE: This will *block* until a new message is available. -func (r *readStream) ReadNextMsg() ([]byte, error) { +func (r *readStream) ReadNextMsg(ctx context.Context) ([]byte, error) { + var reader io.Reader + select { + case b := <-r.parentStream.readBytesChan: + reader = bytes.NewReader(b) + + case <-ctx.Done(): + return nil, ctx.Err() + + case err := <-r.parentStream.readErrChan: + return nil, err + } + // First, we'll decode the length of the next message from the stream // so we know how many bytes we need to read. - msgLen, err := tlv.ReadVarInt(r, &r.scratchBuf) + msgLen, err := tlv.ReadVarInt(reader, &r.scratchBuf) if err != nil { return nil, err } @@ -67,7 +81,7 @@ func (r *readStream) ReadNextMsg() ([]byte, error) { // Now that we know the length of the message, we'll make a limit // reader, then read all the encoded bytes until the EOF is emitted by // the reader. - msgReader := io.LimitReader(r, int64(msgLen)) + msgReader := io.LimitReader(reader, int64(msgLen)) return ioutil.ReadAll(msgReader) } @@ -108,19 +122,18 @@ func (w *writeStream) WriteMsg(ctx context.Context, msg []byte) error { // As we're writing to a stream, we need to delimit each message with a // length prefix so the reader knows how many bytes to consume for each // message. - // - // TODO(roasbeef): actually needs to be single write? + var buf bytes.Buffer msgSize := uint64(len(msg)) - err := tlv.WriteVarInt( - w, msgSize, &w.scratchBuf, - ) - if err != nil { + if err := tlv.WriteVarInt(&buf, msgSize, &w.scratchBuf); err != nil { return err } // Next, we'll write the message directly to the stream. - _, err = w.Write(msg) - if err != nil { + if _, err := buf.Write(msg); err != nil { + return err + } + + if _, err := w.Write(buf.Bytes()); err != nil { return err } @@ -143,6 +156,10 @@ type stream struct { readStreamChan chan *readStream writeStreamChan chan *writeStream + readBytesChan chan []byte + readErrChan chan error + quit chan struct{} + // equivAuth is a method used to determine if an authentication // mechanism to tear down a stream is equivalent to the one used to // create it in the first place. WE use this to ensure that only the @@ -173,6 +190,9 @@ func newStream(id streamID, limiter *rate.Limiter, id: id, equivAuth: equivAuth, limiter: limiter, + readBytesChan: make(chan []byte), + readErrChan: make(chan error, 1), + quit: make(chan struct{}), } // Our tear down function will close the write side of the pipe, which @@ -183,6 +203,7 @@ func newStream(id streamID, limiter *rate.Limiter, if err != nil { return err } + close(s.quit) s.wg.Wait() return nil } @@ -196,21 +217,37 @@ func newStream(id streamID, limiter *rate.Limiter, // will read from. _, err := io.Copy( readWritePipe, - // This is where the buffering will happen, as the - // writer writes to the write end of the pipe, this - // goroutine will copy the bytes into the buffer until - // its full, then attempt to write it to the write end - // of the read pipe. - bufio.NewReader(writeReadPipe), + writeReadPipe, ) _ = readWritePipe.CloseWithError(err) _ = writeReadPipe.CloseWithError(err) }() + s.wg.Add(1) + go func() { + defer s.wg.Done() + + var buf [DefaultBufSize]byte + for { + numBytes, err := readReadPipe.Read(buf[:]) + if err != nil { + s.readErrChan <- err + return + } + + c := make([]byte, numBytes) + copy(c, buf[0:numBytes]) + + select { + case s.readBytesChan <- c: + case <-s.quit: + } + } + }() + // We'll now initialize our stream by sending the read and write ends // to their respective holding channels. s.readStreamChan <- &readStream{ - Reader: readReadPipe, parentStream: s, } s.writeStreamChan <- &writeStream{ @@ -555,6 +592,7 @@ func (h *hashMailServer) SendStream(readStream hashmailrpc.HashMail_SendStreamSe // exit before shutting down. select { case <-ctx.Done(): + log.Debugf("SendStream: Context done, exiting") return nil case <-h.quit: return fmt.Errorf("server shutting down") @@ -564,6 +602,8 @@ func (h *hashMailServer) SendStream(readStream hashmailrpc.HashMail_SendStreamSe cipherBox, err := readStream.Recv() if err != nil { + log.Debugf("SendStream: Exiting write stream RPC "+ + "stream read: %v", err) return err } @@ -608,7 +648,7 @@ func (h *hashMailServer) RecvStream(desc *hashmailrpc.CipherBoxDesc, default: } - nextMsg, err := readStream.ReadNextMsg() + nextMsg, err := readStream.ReadNextMsg(reader.Context()) if err != nil { log.Debugf("Got error an read stream read: %v", err) return err diff --git a/hashmail_server_test.go b/hashmail_server_test.go new file mode 100644 index 0000000..f49dbfd --- /dev/null +++ b/hashmail_server_test.go @@ -0,0 +1,192 @@ +package aperture + +import ( + "context" + "fmt" + "math" + "net/http" + "testing" + "time" + + "github.com/lightninglabs/lightning-node-connect/hashmailrpc" + "github.com/lightningnetwork/lnd/build" + "github.com/lightningnetwork/lnd/lntest/wait" + "github.com/lightningnetwork/lnd/signal" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" +) + +var ( + testApertureAddress = "localhost:8082" + testSID = streamID{1, 2, 3} + testStreamDesc = &hashmailrpc.CipherBoxDesc{ + StreamId: testSID[:], + } + testMessage = []byte("I'm a message!") + apertureStartTimeout = 3 * time.Second +) + +func TestHashMailServerReturnStream(t *testing.T) { + ctxb := context.Background() + + setupAperture(t) + + // Create a client and connect it to the server. + conn, err := grpc.Dial(testApertureAddress, grpc.WithInsecure()) + require.NoError(t, err) + client := hashmailrpc.NewHashMailClient(conn) + + // We'll create a new cipher box that we're going to subscribe to + // multiple times to check disconnecting returns the read stream. + resp, err := client.NewCipherBox(ctxb, &hashmailrpc.CipherBoxAuth{ + Auth: &hashmailrpc.CipherBoxAuth_LndAuth{}, + Desc: testStreamDesc, + }) + require.NoError(t, err) + require.NotNil(t, resp.GetSuccess()) + + // First we make sure there is something to read on the other end of + // that stream by writing something to it. + sendCtx, sendCancel := context.WithCancel(context.Background()) + defer sendCancel() + + writeStream, err := client.SendStream(sendCtx) + require.NoError(t, err) + err = writeStream.Send(&hashmailrpc.CipherBox{ + Desc: testStreamDesc, + Msg: testMessage, + }) + require.NoError(t, err) + + // We need to wait a bit to make sure the message is really sent. + time.Sleep(100 * time.Millisecond) + + // Connect, wait for the stream to be ready, read something, then + // disconnect immediately. + msg, err := readMsgFromStream(t, client) + require.NoError(t, err) + require.Equal(t, testMessage, msg.Msg) + + // Make sure we can connect again immediately and try to read something. + // There is no message to read before we cancel the request so we expect + // an EOF error to be returned upon connection close/context cancel. + _, err = readMsgFromStream(t, client) + require.Error(t, err) + require.Contains(t, err.Error(), "context canceled") + + // Send then receive yet another message to make sure the stream is + // still operational. + testMessage2 := append(testMessage, []byte("test")...) + err = writeStream.Send(&hashmailrpc.CipherBox{ + Desc: testStreamDesc, + Msg: testMessage2, + }) + require.NoError(t, err) + + // We need to wait a bit to make sure the message is really sent. + time.Sleep(100 * time.Millisecond) + + msg, err = readMsgFromStream(t, client) + require.NoError(t, err) + require.Equal(t, testMessage2, msg.Msg) +} + +func setupAperture(t *testing.T) { + logWriter := build.NewRotatingLogWriter() + interceptor, err := signal.Intercept() + require.NoError(t, err) + + SetupLoggers(logWriter, interceptor) + + err = build.ParseAndSetDebugLevels("trace,PRXY=warn", logWriter) + require.NoError(t, err) + + apertureCfg := &Config{ + Insecure: true, + ListenAddr: testApertureAddress, + Authenticator: &AuthConfig{ + Disable: true, + }, + Etcd: &EtcdConfig{}, + HashMail: &HashMailConfig{ + Enabled: true, + MessageRate: time.Millisecond, + MessageBurstAllowance: math.MaxUint32, + }, + } + aperture := NewAperture(apertureCfg) + errChan := make(chan error) + require.NoError(t, aperture.Start(errChan)) + + // Any error while starting? + select { + case err := <-errChan: + t.Fatalf("error starting aperture: %v", err) + default: + } + + err = wait.NoError(func() error { + apertureAddr := fmt.Sprintf("http://%s/dummy", + testApertureAddress) + + resp, err := http.Get(apertureAddr) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusNotFound { + return fmt.Errorf("invalid status: %d", resp.StatusCode) + } + + return nil + }, apertureStartTimeout) + require.NoError(t, err) +} + +func readMsgFromStream(t *testing.T, + client hashmailrpc.HashMailClient) (*hashmailrpc.CipherBox, error) { + + ctxc, cancel := context.WithCancel(context.Background()) + readStream, err := client.RecvStream(ctxc, testStreamDesc) + require.NoError(t, err) + + // Wait a bit again to make sure the request is actually sent before our + // context is canceled already again. + time.Sleep(100 * time.Millisecond) + + // We'll start a read on the stream in the background. + var ( + goroutineStarted = make(chan struct{}) + resultChan = make(chan *hashmailrpc.CipherBox) + errChan = make(chan error) + ) + go func() { + close(goroutineStarted) + box, err := readStream.Recv() + if err != nil { + errChan <- err + return + } + resultChan <- box + }() + + // Give the goroutine a chance to actually run, so block the main thread + // until it did. + <-goroutineStarted + + time.Sleep(200 * time.Millisecond) + + // Now close and cancel the stream to make sure the server can clean it + // up and release it. + require.NoError(t, readStream.CloseSend()) + cancel() + + // Interpret the result. + select { + case err := <-errChan: + return nil, err + + case box := <-resultChan: + return box, nil + } +}