From 870a61765778d6a01ec08c2583948eb6fcd133da Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Mon, 26 Sep 2022 10:28:39 +0200 Subject: [PATCH] multi: gc stale mailboxes In this commit, we start a timer if a mailbox stream is completely unoccupied (neither the read nor the write stream is occupied). The timer is stopped if either of the streams becomes occupied and is reset once both streams are unoccupied again. --- aperture.go | 1 + config.go | 1 + hashmail_server.go | 138 ++++++++++++++++++++- hashmail_server_test.go | 263 ++++++++++++++++++++++++++++++++++++++++ 4 files changed, 400 insertions(+), 3 deletions(-) diff --git a/aperture.go b/aperture.go index 3da60ba..86e6ed8 100644 --- a/aperture.go +++ b/aperture.go @@ -729,6 +729,7 @@ func createHashMailServer(cfg *Config) ([]proxy.LocalService, func(), error) { hashMailServer := newHashMailServer(hashMailServerConfig{ msgRate: cfg.HashMail.MessageRate, msgBurstAllowance: cfg.HashMail.MessageBurstAllowance, + staleTimeout: cfg.HashMail.StaleTimeout, }) hashMailGRPC := grpc.NewServer(serverOpts...) hashmailrpc.RegisterHashMailServer(hashMailGRPC, hashMailServer) diff --git a/config.go b/config.go index e95b6e6..10f4596 100644 --- a/config.go +++ b/config.go @@ -64,6 +64,7 @@ type HashMailConfig struct { Enabled bool `long:"enabled"` MessageRate time.Duration `long:"messagerate" description:"The average minimum time that should pass between each message."` MessageBurstAllowance int `long:"messageburstallowance" description:"The burst rate we allow for messages."` + StaleTimeout time.Duration `long:"staletimeout" description:"The time after the last activity that a mailbox should be removed. Set to -1s to disable. "` } type TorConfig struct { diff --git a/hashmail_server.go b/hashmail_server.go index cb432d4..cf41353 100644 --- a/hashmail_server.go +++ b/hashmail_server.go @@ -28,6 +28,10 @@ const ( // then we'll allow it up to this burst allowance.
DefaultMsgBurstAllowance = 10 + // DefaultStaleTimeout is the time after which a mailbox will be torn + // down if neither of its streams are occupied. + DefaultStaleTimeout = time.Hour + // DefaultBufSize is the default number of bytes that are read in a // single operation. DefaultBufSize = 4096 @@ -185,11 +189,14 @@ type stream struct { wg sync.WaitGroup limiter *rate.Limiter + + status *streamStatus } // newStream creates a new stream independent of any given stream ID. func newStream(id streamID, limiter *rate.Limiter, - equivAuth func(auth *hashmailrpc.CipherBoxAuth) error) *stream { + equivAuth func(auth *hashmailrpc.CipherBoxAuth) error, + onStale func() error, staleTimeout time.Duration) *stream { // Our stream is actually just a plain io.Pipe. This allows us to avoid // having to do things like rate limiting, etc as we can limit the @@ -204,6 +211,7 @@ func newStream(id streamID, limiter *rate.Limiter, id: id, equivAuth: equivAuth, limiter: limiter, + status: newStreamStatus(onStale, staleTimeout), readBytesChan: make(chan []byte), readErrChan: make(chan error, 1), quit: make(chan struct{}), @@ -213,6 +221,7 @@ func newStream(id streamID, limiter *rate.Limiter, // will cause the goroutine below to get an EOF error when reading, // which will cause it to close the other ends of the pipe. s.tearDown = func() error { + s.status.stop() err := writeWritePipe.Close() if err != nil { return err @@ -284,12 +293,14 @@ func newStream(id streamID, limiter *rate.Limiter, // ReturnReadStream returns the target read stream back to its holding channel. func (s *stream) ReturnReadStream(r *readStream) { s.readStreamChan <- r + s.status.streamReturned(true) } // ReturnWriteStream returns the target write stream back to its holding // channel. 
func (s *stream) ReturnWriteStream(w *writeStream) { s.writeStreamChan <- w + s.status.streamReturned(false) } // RequestReadStream attempts to request the read stream from the main backing @@ -300,6 +311,7 @@ func (s *stream) RequestReadStream() (*readStream, error) { select { case r := <-s.readStreamChan: + s.status.streamTaken(true) return r, nil default: return nil, fmt.Errorf("read stream occupied") @@ -314,6 +326,7 @@ func (s *stream) RequestWriteStream() (*writeStream, error) { select { case w := <-s.writeStreamChan: + s.status.streamTaken(false) return w, nil default: return nil, fmt.Errorf("write stream occupied") @@ -324,6 +337,7 @@ func (s *stream) RequestWriteStream() (*writeStream, error) { type hashMailServerConfig struct { msgRate time.Duration msgBurstAllowance int + staleTimeout time.Duration } // hashMailServer is an implementation of the HashMailServer gRPC service that @@ -350,6 +364,9 @@ func newHashMailServer(cfg hashMailServerConfig) *hashMailServer { if cfg.msgBurstAllowance == 0 { cfg.msgBurstAllowance = DefaultMsgBurstAllowance } + if cfg.staleTimeout == 0 { + cfg.staleTimeout = DefaultStaleTimeout + } return &hashMailServer{ streams: make(map[streamID]*stream), @@ -372,6 +389,29 @@ func (h *hashMailServer) Stop() { } +// tearDownStaleStream can be used to tear down a stale mailbox stream. +func (h *hashMailServer) tearDownStaleStream(id streamID) error { + log.Debugf("Tearing down stale HashMail stream: id=%x", id) + + h.Lock() + defer h.Unlock() + + stream, ok := h.streams[id] + if !ok { + return fmt.Errorf("stream not found") + } + + if err := stream.tearDown(); err != nil { + return err + } + + delete(h.streams, id) + + mailboxCount.Set(float64(len(h.streams))) + + return nil +} + // ValidateStreamAuth attempts to validate the authentication mechanism that is // being used to claim or revoke a stream within the mail server. 
func (h *hashMailServer) ValidateStreamAuth(ctx context.Context, @@ -415,7 +455,9 @@ func (h *hashMailServer) InitStream( freshStream := newStream( streamID, limiter, func(auth *hashmailrpc.CipherBoxAuth) error { return nil - }, + }, func() error { + return h.tearDownStaleStream(streamID) + }, h.cfg.staleTimeout, ) h.streams[streamID] = freshStream @@ -430,7 +472,6 @@ func (h *hashMailServer) InitStream( // LookUpReadStream attempts to loop up a new stream. If the stream is found, then // the stream is marked as being active. Otherwise, an error is returned. func (h *hashMailServer) LookUpReadStream(streamID []byte) (*readStream, error) { - h.RLock() defer h.RUnlock() @@ -710,3 +751,94 @@ func (h *hashMailServer) RecvStream(desc *hashmailrpc.CipherBoxDesc, } var _ hashmailrpc.HashMailServer = (*hashMailServer)(nil) + +// streamStatus keeps track of the occupancy status of a stream's read and +// write sub-streams. It is initialised with callback functions to call on the +// event of the streams being occupied (either or both of the streams are +// occupied) or fully idle (both streams are unoccupied). +type streamStatus struct { + disabled bool + + staleTimeout time.Duration + staleTimer *time.Timer + + readStreamOccupied bool + writeStreamOccupied bool + sync.Mutex +} + +// newStreamStatus constructs a new streamStatus instance. +func newStreamStatus(onStale func() error, + staleTimeout time.Duration) *streamStatus { + + if staleTimeout < 0 { + return &streamStatus{ + disabled: true, + } + } + + staleTimer := time.AfterFunc(staleTimeout, func() { + if err := onStale(); err != nil { + log.Errorf("error in onStale callback: %v", err) + } + }) + + return &streamStatus{ + staleTimer: staleTimer, + staleTimeout: staleTimeout, + } +} + +// stop cleans up any resources held by streamStatus. 
+func (s *streamStatus) stop() { + if s.disabled { + return + } + + s.Lock() + defer s.Unlock() + + _ = s.staleTimer.Stop() +} + +// streamTaken should be called when one of the sub-streams (read or write) +// become occupied. This will stop the staleTimer. The read parameter should be +// true if the stream being returned is the read stream. +func (s *streamStatus) streamTaken(read bool) { + if s.disabled { + return + } + + s.Lock() + defer s.Unlock() + + if read { + s.readStreamOccupied = true + } else { + s.writeStreamOccupied = true + } + _ = s.staleTimer.Stop() +} + +// streamReturned should be called when one of the sub-streams are released. +// If the occupancy count after this call is zero, then the staleTimer is reset. +// The read parameter should be true if the stream being returned is the read +// stream. +func (s *streamStatus) streamReturned(read bool) { + if s.disabled { + return + } + + s.Lock() + defer s.Unlock() + + if read { + s.readStreamOccupied = false + } else { + s.writeStreamOccupied = false + } + + if !s.readStreamOccupied && !s.writeStreamOccupied { + _ = s.staleTimer.Reset(s.staleTimeout) + } +} diff --git a/hashmail_server_test.go b/hashmail_server_test.go index 0c99b0a..c9db570 100644 --- a/hashmail_server_test.go +++ b/hashmail_server_test.go @@ -5,6 +5,7 @@ import ( "crypto/rand" "fmt" "math" + "net" "net/http" "testing" "time" @@ -13,8 +14,10 @@ import ( "github.com/lightningnetwork/lnd/build" "github.com/lightningnetwork/lnd/lntest/wait" "github.com/lightningnetwork/lnd/signal" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "google.golang.org/grpc" + "google.golang.org/grpc/test/bufconn" ) var ( @@ -242,3 +245,263 @@ func readMsgFromStream(t *testing.T, return box, nil } } + +type statusState struct { + readOccupied bool + writeOccupied bool +} + +// TestStaleMailboxCleanup tests that the streamStatus behaves as expected and +// that it correctly tears down a mailbox if it becomes stale. 
+func TestStaleMailboxCleanup(t *testing.T) { + tests := []struct { + name string + staleTimeout time.Duration + senderConnected statusState + readerConnected statusState + senderDisconnected statusState + expectStaleMailboxRemoval bool + }{ + { + name: "tear down stale mailbox", + staleTimeout: 500 * time.Millisecond, + senderConnected: statusState{ + writeOccupied: true, + }, + readerConnected: statusState{ + writeOccupied: true, + readOccupied: true, + }, + senderDisconnected: statusState{ + writeOccupied: false, + readOccupied: true, + }, + expectStaleMailboxRemoval: true, + }, + { + name: "dont tear down stale mailbox", + staleTimeout: -1, + senderConnected: statusState{ + writeOccupied: false, + readOccupied: false, + }, + readerConnected: statusState{ + writeOccupied: false, + readOccupied: false, + }, + senderDisconnected: statusState{ + writeOccupied: false, + readOccupied: false, + }, + }, + } + + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + ctx := context.Background() + + // Set up a new hashmail server. + hm := newHashMailHarness(t, hashMailServerConfig{ + staleTimeout: test.staleTimeout, + }) + + // Create two clients of the hashmail server. + conn1 := hm.newClientConn() + conn2 := hm.newClientConn() + + client1 := hashmailrpc.NewHashMailClient(conn1) + client2 := hashmailrpc.NewHashMailClient(conn2) + + // Let client 1 create a mailbox on the server. + resp, err := client1.NewCipherBox( + ctx, &hashmailrpc.CipherBoxAuth{ + Auth: &hashmailrpc.CipherBoxAuth_LndAuth{}, + Desc: testStreamDesc, + }, + ) + require.NoError(t, err) + require.NotNil(t, resp.GetSuccess()) + + // Assert that neither of the mailbox streams are + // occupied to start with. + hm.assertStreamsOccupied(statusState{ + readOccupied: false, + writeOccupied: false, + }) + + // Let client 1 take the send-stream and write to it. 
+ err = sendToStream(client1) + require.NoError(t, err) + + hm.assertStreamsOccupied(test.senderConnected) + + // Let client 2 take the read stream and receive from + // it. + err = recvFromStream(client2) + require.NoError(t, err) + + hm.assertStreamsOccupied(test.readerConnected) + + // Ensure that attempting to take the read stream and + // receive from it while it is currently occupied will + // result in an error. + err = recvFromStream(client2) + require.Error(t, err) + assert.Contains(t, err.Error(), "read stream occupied") + + hm.assertStreamsOccupied(test.readerConnected) + + // Disconnect client 1. This should release the + // send-stream. + require.NoError(t, conn1.Close()) + hm.assertStreamsOccupied(test.senderDisconnected) + + // Disconnect client 1. This should release the + // read-stream. + require.NoError(t, conn2.Close()) + + // Assert that neither of the streams are occupied. + hm.assertStreamsOccupied(statusState{ + readOccupied: false, + writeOccupied: false, + }) + + // Assert that the stream is torn down. + hm.assertStreamExists(!test.expectStaleMailboxRemoval) + }) + } +} + +// hashMailHarness is a test harness that spins up a hashmail server for +// testing purposes. +type hashMailHarness struct { + t *testing.T + server *hashMailServer + lis *bufconn.Listener +} + +// newHashMailHarness spins up a new hashmail server and serves it on a bufconn +// listener. +func newHashMailHarness(t *testing.T, + cfg hashMailServerConfig) *hashMailHarness { + + hm := newHashMailServer(cfg) + + lis := bufconn.Listen(1024 * 1024) + hashMailGRPC := grpc.NewServer() + t.Cleanup(hashMailGRPC.Stop) + + hashmailrpc.RegisterHashMailServer(hashMailGRPC, hm) + go func() { + require.NoError(t, hashMailGRPC.Serve(lis)) + }() + + return &hashMailHarness{ + t: t, + server: hm, + lis: lis, + } +} + +// newClientConn creates a new client of the hashMailHarness server. 
+func (h *hashMailHarness) newClientConn() *grpc.ClientConn { + conn, err := grpc.Dial("bufnet", grpc.WithContextDialer( + func(ctx context.Context, s string) (net.Conn, error) { + return h.lis.Dial() + }), grpc.WithInsecure(), + ) + require.NoError(h.t, err) + h.t.Cleanup(func() { + _ = conn.Close() + }) + + return conn +} + +// assertStreamOccupied checks that the current state of the stream's read and +// writes streams are the same as the expected state. +func (h *hashMailHarness) assertStreamsOccupied(state statusState) { + err := wait.Predicate(func() bool { + h.server.Lock() + defer h.server.Unlock() + + stream, ok := h.server.streams[testSID] + if !ok { + return false + } + + stream.status.Lock() + defer stream.status.Unlock() + + if stream.status.readStreamOccupied != state.readOccupied { + return false + } + + return stream.status.writeStreamOccupied == state.writeOccupied + + }, time.Second) + require.NoError(h.t, err) +} + +// assertStreamExists ensures that the test stream does or does not exist +// depending on the value of the boolean passed in. +func (h *hashMailHarness) assertStreamExists(exists bool) { + err := wait.Predicate(func() bool { + h.server.Lock() + defer h.server.Unlock() + + _, ok := h.server.streams[testSID] + return ok == exists + + }, time.Second) + require.NoError(h.t, err) +} + +// sendToStream is a helper function that attempts to send dummy data to the +// test stream using the given client. +func sendToStream(client hashmailrpc.HashMailClient) error { + writeStream, err := client.SendStream(context.Background()) + if err != nil { + return err + } + + return writeStream.Send(&hashmailrpc.CipherBox{ + Desc: testStreamDesc, + Msg: testMessage, + }) +} + +// recvFromStream is a helper function that attempts to receive dummy data from +// the test stream using the given client. 
+func recvFromStream(client hashmailrpc.HashMailClient) error {
+	readStream, err := client.RecvStream(
+		context.Background(), testStreamDesc,
+	)
+	if err != nil {
+		return err
+	}
+
+	// Buffered so the goroutine can always hand off its result and exit.
+	recvChan := make(chan *hashmailrpc.CipherBox, 1)
+	errChan := make(chan error, 1)
+	go func() {
+		if box, err := readStream.Recv(); err != nil {
+			errChan <- err
+		} else {
+			recvChan <- box
+		}
+	}()
+
+	select {
+	case <-time.After(time.Second):
+		return fmt.Errorf("timed out waiting to receive from receive " +
+			"stream")
+
+	case err := <-errChan:
+		return err
+
+	case <-recvChan:
+		return nil
+	}
+}