mirror of
https://github.com/lightninglabs/aperture.git
synced 2026-01-31 15:14:26 +01:00
hashmail_server: fix blocking reads
With this commit we fix a bug in the hashmail server that didn't return a read stream properly if it was closed from the client side. Co-authored-by: Elle Mouton <elle.mouton@gmail.com>
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
package aperture
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -26,6 +26,10 @@ const (
|
||||
// for messages. If a new message is about to exceed the burst rate,
|
||||
// then we'll allow it up to this burst allowance.
|
||||
DefaultMsgBurstAllowance = 10
|
||||
|
||||
// DefaultBufSize is the default number of bytes that are read in a
|
||||
// single operation.
|
||||
DefaultBufSize = 4096
|
||||
)
|
||||
|
||||
// streamID is the identifier of a stream.
|
||||
@@ -42,8 +46,6 @@ func newStreamID(id []byte) streamID {
|
||||
// readStream is the read side of the read pipe, which is implemented a
|
||||
// buffered wrapper around the core reader.
|
||||
type readStream struct {
|
||||
io.Reader
|
||||
|
||||
// parentStream is a pointer to the parent stream. We keep this around
|
||||
// so we can return the stream after we're done using it.
|
||||
parentStream *stream
|
||||
@@ -56,10 +58,22 @@ type readStream struct {
|
||||
// ReadNextMsg attempts to read the next message in the stream.
|
||||
//
|
||||
// NOTE: This will *block* until a new message is available.
|
||||
func (r *readStream) ReadNextMsg() ([]byte, error) {
|
||||
func (r *readStream) ReadNextMsg(ctx context.Context) ([]byte, error) {
|
||||
var reader io.Reader
|
||||
select {
|
||||
case b := <-r.parentStream.readBytesChan:
|
||||
reader = bytes.NewReader(b)
|
||||
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
|
||||
case err := <-r.parentStream.readErrChan:
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// First, we'll decode the length of the next message from the stream
|
||||
// so we know how many bytes we need to read.
|
||||
msgLen, err := tlv.ReadVarInt(r, &r.scratchBuf)
|
||||
msgLen, err := tlv.ReadVarInt(reader, &r.scratchBuf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -67,7 +81,7 @@ func (r *readStream) ReadNextMsg() ([]byte, error) {
|
||||
// Now that we know the length of the message, we'll make a limit
|
||||
// reader, then read all the encoded bytes until the EOF is emitted by
|
||||
// the reader.
|
||||
msgReader := io.LimitReader(r, int64(msgLen))
|
||||
msgReader := io.LimitReader(reader, int64(msgLen))
|
||||
return ioutil.ReadAll(msgReader)
|
||||
}
|
||||
|
||||
@@ -108,19 +122,18 @@ func (w *writeStream) WriteMsg(ctx context.Context, msg []byte) error {
|
||||
// As we're writing to a stream, we need to delimit each message with a
|
||||
// length prefix so the reader knows how many bytes to consume for each
|
||||
// message.
|
||||
//
|
||||
// TODO(roasbeef): actually needs to be single write?
|
||||
var buf bytes.Buffer
|
||||
msgSize := uint64(len(msg))
|
||||
err := tlv.WriteVarInt(
|
||||
w, msgSize, &w.scratchBuf,
|
||||
)
|
||||
if err != nil {
|
||||
if err := tlv.WriteVarInt(&buf, msgSize, &w.scratchBuf); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Next, we'll write the message directly to the stream.
|
||||
_, err = w.Write(msg)
|
||||
if err != nil {
|
||||
if _, err := buf.Write(msg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := w.Write(buf.Bytes()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -143,6 +156,10 @@ type stream struct {
|
||||
readStreamChan chan *readStream
|
||||
writeStreamChan chan *writeStream
|
||||
|
||||
readBytesChan chan []byte
|
||||
readErrChan chan error
|
||||
quit chan struct{}
|
||||
|
||||
// equivAuth is a method used to determine if an authentication
|
||||
// mechanism to tear down a stream is equivalent to the one used to
|
||||
// create it in the first place. WE use this to ensure that only the
|
||||
@@ -173,6 +190,9 @@ func newStream(id streamID, limiter *rate.Limiter,
|
||||
id: id,
|
||||
equivAuth: equivAuth,
|
||||
limiter: limiter,
|
||||
readBytesChan: make(chan []byte),
|
||||
readErrChan: make(chan error, 1),
|
||||
quit: make(chan struct{}),
|
||||
}
|
||||
|
||||
// Our tear down function will close the write side of the pipe, which
|
||||
@@ -183,6 +203,7 @@ func newStream(id streamID, limiter *rate.Limiter,
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
close(s.quit)
|
||||
s.wg.Wait()
|
||||
return nil
|
||||
}
|
||||
@@ -196,21 +217,37 @@ func newStream(id streamID, limiter *rate.Limiter,
|
||||
// will read from.
|
||||
_, err := io.Copy(
|
||||
readWritePipe,
|
||||
// This is where the buffering will happen, as the
|
||||
// writer writes to the write end of the pipe, this
|
||||
// goroutine will copy the bytes into the buffer until
|
||||
// its full, then attempt to write it to the write end
|
||||
// of the read pipe.
|
||||
bufio.NewReader(writeReadPipe),
|
||||
writeReadPipe,
|
||||
)
|
||||
_ = readWritePipe.CloseWithError(err)
|
||||
_ = writeReadPipe.CloseWithError(err)
|
||||
}()
|
||||
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
|
||||
var buf [DefaultBufSize]byte
|
||||
for {
|
||||
numBytes, err := readReadPipe.Read(buf[:])
|
||||
if err != nil {
|
||||
s.readErrChan <- err
|
||||
return
|
||||
}
|
||||
|
||||
c := make([]byte, numBytes)
|
||||
copy(c, buf[0:numBytes])
|
||||
|
||||
select {
|
||||
case s.readBytesChan <- c:
|
||||
case <-s.quit:
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// We'll now initialize our stream by sending the read and write ends
|
||||
// to their respective holding channels.
|
||||
s.readStreamChan <- &readStream{
|
||||
Reader: readReadPipe,
|
||||
parentStream: s,
|
||||
}
|
||||
s.writeStreamChan <- &writeStream{
|
||||
@@ -555,6 +592,7 @@ func (h *hashMailServer) SendStream(readStream hashmailrpc.HashMail_SendStreamSe
|
||||
// exit before shutting down.
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Debugf("SendStream: Context done, exiting")
|
||||
return nil
|
||||
case <-h.quit:
|
||||
return fmt.Errorf("server shutting down")
|
||||
@@ -564,6 +602,8 @@ func (h *hashMailServer) SendStream(readStream hashmailrpc.HashMail_SendStreamSe
|
||||
|
||||
cipherBox, err := readStream.Recv()
|
||||
if err != nil {
|
||||
log.Debugf("SendStream: Exiting write stream RPC "+
|
||||
"stream read: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -608,7 +648,7 @@ func (h *hashMailServer) RecvStream(desc *hashmailrpc.CipherBoxDesc,
|
||||
default:
|
||||
}
|
||||
|
||||
nextMsg, err := readStream.ReadNextMsg()
|
||||
nextMsg, err := readStream.ReadNextMsg(reader.Context())
|
||||
if err != nil {
|
||||
log.Debugf("Got error an read stream read: %v", err)
|
||||
return err
|
||||
|
||||
192
hashmail_server_test.go
Normal file
192
hashmail_server_test.go
Normal file
@@ -0,0 +1,192 @@
|
||||
package aperture
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lightninglabs/lightning-node-connect/hashmailrpc"
|
||||
"github.com/lightningnetwork/lnd/build"
|
||||
"github.com/lightningnetwork/lnd/lntest/wait"
|
||||
"github.com/lightningnetwork/lnd/signal"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
var (
|
||||
testApertureAddress = "localhost:8082"
|
||||
testSID = streamID{1, 2, 3}
|
||||
testStreamDesc = &hashmailrpc.CipherBoxDesc{
|
||||
StreamId: testSID[:],
|
||||
}
|
||||
testMessage = []byte("I'm a message!")
|
||||
apertureStartTimeout = 3 * time.Second
|
||||
)
|
||||
|
||||
func TestHashMailServerReturnStream(t *testing.T) {
|
||||
ctxb := context.Background()
|
||||
|
||||
setupAperture(t)
|
||||
|
||||
// Create a client and connect it to the server.
|
||||
conn, err := grpc.Dial(testApertureAddress, grpc.WithInsecure())
|
||||
require.NoError(t, err)
|
||||
client := hashmailrpc.NewHashMailClient(conn)
|
||||
|
||||
// We'll create a new cipher box that we're going to subscribe to
|
||||
// multiple times to check disconnecting returns the read stream.
|
||||
resp, err := client.NewCipherBox(ctxb, &hashmailrpc.CipherBoxAuth{
|
||||
Auth: &hashmailrpc.CipherBoxAuth_LndAuth{},
|
||||
Desc: testStreamDesc,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp.GetSuccess())
|
||||
|
||||
// First we make sure there is something to read on the other end of
|
||||
// that stream by writing something to it.
|
||||
sendCtx, sendCancel := context.WithCancel(context.Background())
|
||||
defer sendCancel()
|
||||
|
||||
writeStream, err := client.SendStream(sendCtx)
|
||||
require.NoError(t, err)
|
||||
err = writeStream.Send(&hashmailrpc.CipherBox{
|
||||
Desc: testStreamDesc,
|
||||
Msg: testMessage,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// We need to wait a bit to make sure the message is really sent.
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Connect, wait for the stream to be ready, read something, then
|
||||
// disconnect immediately.
|
||||
msg, err := readMsgFromStream(t, client)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, testMessage, msg.Msg)
|
||||
|
||||
// Make sure we can connect again immediately and try to read something.
|
||||
// There is no message to read before we cancel the request so we expect
|
||||
// an EOF error to be returned upon connection close/context cancel.
|
||||
_, err = readMsgFromStream(t, client)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "context canceled")
|
||||
|
||||
// Send then receive yet another message to make sure the stream is
|
||||
// still operational.
|
||||
testMessage2 := append(testMessage, []byte("test")...)
|
||||
err = writeStream.Send(&hashmailrpc.CipherBox{
|
||||
Desc: testStreamDesc,
|
||||
Msg: testMessage2,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// We need to wait a bit to make sure the message is really sent.
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
msg, err = readMsgFromStream(t, client)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, testMessage2, msg.Msg)
|
||||
}
|
||||
|
||||
func setupAperture(t *testing.T) {
|
||||
logWriter := build.NewRotatingLogWriter()
|
||||
interceptor, err := signal.Intercept()
|
||||
require.NoError(t, err)
|
||||
|
||||
SetupLoggers(logWriter, interceptor)
|
||||
|
||||
err = build.ParseAndSetDebugLevels("trace,PRXY=warn", logWriter)
|
||||
require.NoError(t, err)
|
||||
|
||||
apertureCfg := &Config{
|
||||
Insecure: true,
|
||||
ListenAddr: testApertureAddress,
|
||||
Authenticator: &AuthConfig{
|
||||
Disable: true,
|
||||
},
|
||||
Etcd: &EtcdConfig{},
|
||||
HashMail: &HashMailConfig{
|
||||
Enabled: true,
|
||||
MessageRate: time.Millisecond,
|
||||
MessageBurstAllowance: math.MaxUint32,
|
||||
},
|
||||
}
|
||||
aperture := NewAperture(apertureCfg)
|
||||
errChan := make(chan error)
|
||||
require.NoError(t, aperture.Start(errChan))
|
||||
|
||||
// Any error while starting?
|
||||
select {
|
||||
case err := <-errChan:
|
||||
t.Fatalf("error starting aperture: %v", err)
|
||||
default:
|
||||
}
|
||||
|
||||
err = wait.NoError(func() error {
|
||||
apertureAddr := fmt.Sprintf("http://%s/dummy",
|
||||
testApertureAddress)
|
||||
|
||||
resp, err := http.Get(apertureAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusNotFound {
|
||||
return fmt.Errorf("invalid status: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}, apertureStartTimeout)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func readMsgFromStream(t *testing.T,
|
||||
client hashmailrpc.HashMailClient) (*hashmailrpc.CipherBox, error) {
|
||||
|
||||
ctxc, cancel := context.WithCancel(context.Background())
|
||||
readStream, err := client.RecvStream(ctxc, testStreamDesc)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Wait a bit again to make sure the request is actually sent before our
|
||||
// context is canceled already again.
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// We'll start a read on the stream in the background.
|
||||
var (
|
||||
goroutineStarted = make(chan struct{})
|
||||
resultChan = make(chan *hashmailrpc.CipherBox)
|
||||
errChan = make(chan error)
|
||||
)
|
||||
go func() {
|
||||
close(goroutineStarted)
|
||||
box, err := readStream.Recv()
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
resultChan <- box
|
||||
}()
|
||||
|
||||
// Give the goroutine a chance to actually run, so block the main thread
|
||||
// until it did.
|
||||
<-goroutineStarted
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
// Now close and cancel the stream to make sure the server can clean it
|
||||
// up and release it.
|
||||
require.NoError(t, readStream.CloseSend())
|
||||
cancel()
|
||||
|
||||
// Interpret the result.
|
||||
select {
|
||||
case err := <-errChan:
|
||||
return nil, err
|
||||
|
||||
case box := <-resultChan:
|
||||
return box, nil
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user