Files
aperture/hashmail_server.go
Elle Mouton e60b09eb51 multi: add mailbox read counter
In this commit, we add a mailbox-read-count metric. This will be
incremented each time a mailbox with an _odd_ stream ID is read from.
We do this because we assume that a full duplex connection is being used
meaning that there will be 2 streams that have a matching ID except for
the last byte. And so to avoid duplicating the data, we only record the
odd streams. We also assume that for every read, there will be a write
and so we only record the reads.
2022-04-05 09:46:28 +02:00

713 lines
19 KiB
Go

package aperture
import (
"bytes"
"context"
"fmt"
"io"
"io/ioutil"
"sync"
"time"
"github.com/lightninglabs/lightning-node-connect/hashmailrpc"
"github.com/lightningnetwork/lnd/tlv"
"github.com/prometheus/client_golang/prometheus"
"golang.org/x/time/rate"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
const (
// DefaultMsgRate is the default minimum interval between messages for a
// given mailbox. With the default of 500 milliseconds, a mailbox
// sustains 2 messages per second. It is used when the server config
// leaves the message rate at zero.
DefaultMsgRate = time.Millisecond * 500
// DefaultMsgBurstAllowance is the default burst size handed to each
// stream's rate limiter, i.e. the number of messages that may be sent
// back to back before the message rate above is enforced. It is used
// when the server config leaves the burst allowance at zero.
DefaultMsgBurstAllowance = 10
// DefaultBufSize is the default number of bytes that are read in a
// single operation.
DefaultBufSize = 4096
)
// streamID is the fixed-size, 64-byte identifier of a stream.
type streamID [64]byte

// newStreamID builds a streamID from the given byte slice. Input shorter than
// 64 bytes is zero-padded on the right; bytes beyond the 64th are ignored.
func newStreamID(id []byte) streamID {
	sid := streamID{}
	copy(sid[:], id)

	return sid
}

// baseID returns the leading 16 bytes of the streamID. This prefix is shared
// by the two streams that make up a bidirectional pair.
func (s *streamID) baseID() [16]byte {
	var prefix [16]byte
	for i := range prefix {
		prefix[i] = s[i]
	}

	return prefix
}

// isOdd reports whether the streamID, read as a number, is odd. Only the
// lowest bit of the final byte decides this.
func (s *streamID) isOdd() bool {
	lastByte := s[len(s)-1]

	return lastByte&1 == 1
}
// readStream is the read side of a stream. It receives length-prefixed
// message bytes from its parent stream and decodes them into individual
// messages.
type readStream struct {
// parentStream is a pointer to the parent stream. We read incoming
// bytes and errors from its channels, and keep it around so we can
// return the read stream after we're done using it.
parentStream *stream
// scratchBuf is a scratch buffer we'll use for decoding the length
// prefix of each message read from the stream.
scratchBuf [8]byte
}
// ReadNextMsg returns the payload of the next message in the stream.
//
// The call waits for the first of three events: the parent stream delivers a
// chunk of bytes, the parent stream reports a read error, or ctx is done. In
// the latter two cases the corresponding error is returned.
//
// A delivered chunk is expected to start with a variable-length integer
// holding the payload length. At most that many bytes following the prefix are
// returned; anything after them in the same chunk is discarded.
//
// NOTE: This will *block* until a new message is available.
func (r *readStream) ReadNextMsg(ctx context.Context) ([]byte, error) {
	parent := r.parentStream

	var chunk []byte
	select {
	case <-ctx.Done():
		return nil, ctx.Err()

	case err := <-parent.readErrChan:
		return nil, err

	case chunk = <-parent.readBytesChan:
	}

	src := bytes.NewReader(chunk)

	// The length prefix tells us how many payload bytes belong to this
	// message.
	payloadLen, err := tlv.ReadVarInt(src, &r.scratchBuf)
	if err != nil {
		return nil, err
	}

	// Cap the reader at the payload length and drain it; the limit reader
	// signals EOF once the payload (or the chunk) is exhausted.
	return ioutil.ReadAll(io.LimitReader(src, int64(payloadLen)))
}
// ReturnStream releases the read stream by handing it back to its parent
// stream, which makes it available to the next reader.
func (r *readStream) ReturnStream() {
	parent := r.parentStream

	log.Debugf("Returning read stream %x", parent.id[:])
	parent.ReturnReadStream(r)
}
// writeStream is the write side of a stream. It wraps an io.Writer (the write
// end of the stream's pipe when created by newStream) and prefixes every
// message with its length.
type writeStream struct {
// Writer is the destination that length-prefixed messages are written
// to.
io.Writer
// parentStream is a pointer to the parent stream. We use its rate
// limiter before each write, and keep it around so we can return the
// write stream after we're done using it.
parentStream *stream
// scratchBuf is a scratch buffer we'll use for encoding the length
// prefix of each message written to the stream.
scratchBuf [8]byte
}
// WriteMsg writes a single message to the stream so that it can be picked up
// from the read end.
//
// The call first waits on the parent stream's rate limiter, which fails if ctx
// is cancelled. The message is then prefixed with its length, encoded as a
// variable-length integer, and prefix plus payload are handed to the
// underlying writer in one Write call.
//
// NOTE: If the buffer is full, then this call will block until the reader
// consumes bytes from the other end.
func (w *writeStream) WriteMsg(ctx context.Context, msg []byte) error {
	// Block until the rate limiter grants us a slot for this message.
	err := w.parentStream.limiter.Wait(ctx)
	if err != nil {
		return err
	}

	// Assemble the frame: the length prefix lets the reader know how many
	// bytes make up this message.
	var frame bytes.Buffer
	err = tlv.WriteVarInt(&frame, uint64(len(msg)), &w.scratchBuf)
	if err != nil {
		return err
	}
	if _, err = frame.Write(msg); err != nil {
		return err
	}

	// Push the complete frame into the stream in one go.
	_, err = w.Write(frame.Bytes())

	return err
}
// ReturnStream releases the write stream by handing it back to its parent
// stream, which makes it available to the next writer.
func (w *writeStream) ReturnStream() {
	parent := w.parentStream
	parent.ReturnWriteStream(w)
}
// stream is a unique pipe implemented using a subscription server, and exposed
// over gRPC. Only a single writer and reader can exist within the stream at
// any given time.
type stream struct {
// NOTE(review): the embedded mutex is not locked anywhere in the
// visible code — confirm whether it is still needed.
sync.Mutex
// id is the identifier of this stream.
id streamID
// readStreamChan holds the stream's single read handle while no reader
// is using it.
readStreamChan chan *readStream
// writeStreamChan holds the stream's single write handle while no
// writer is using it.
writeStreamChan chan *writeStream
// readBytesChan carries length-prefixed message bytes from the
// stream's reader goroutine to ReadNextMsg.
readBytesChan chan []byte
// readErrChan carries the error that terminated the stream's reader
// goroutine.
readErrChan chan error
// quit is closed by tearDown to release the reader goroutine if it is
// blocked delivering bytes.
quit chan struct{}
// equivAuth is a method used to determine if an authentication
// mechanism to tear down a stream is equivalent to the one used to
// create it in the first place. We use this to ensure that only the
// original creator of a stream can tear it down.
equivAuth func(auth *hashmailrpc.CipherBoxAuth) error
// tearDown shuts the stream down and waits for its goroutines to exit.
tearDown func() error
// wg tracks the goroutines started for this stream.
wg sync.WaitGroup
// limiter rate limits the messages written to this stream.
limiter *rate.Limiter
}
// newStream creates a new stream with the given ID, write rate limiter and
// auth-equivalence check, starts its two background goroutines and returns it
// with both the read and the write handle available.
//
// Data flows through two chained pipes: writers write length-prefixed frames
// into the first pipe, a copier goroutine moves the bytes into the second
// pipe, and a reader goroutine pulls one complete frame at a time off the
// second pipe and offers it on readBytesChan. Any error hit by the reader
// goroutine (including the io.EOF seen after tearDown) is sent on readErrChan
// and ends that goroutine.
func newStream(id streamID, limiter *rate.Limiter,
	equivAuth func(auth *hashmailrpc.CipherBoxAuth) error) *stream {

	// Our stream is actually just a plain io.Pipe. This allows us to avoid
	// having to do things like rate limiting, etc as we can limit the
	// buffer size. In order to allow non-blocking writes (up to the buffer
	// size), but blocking reads, we'll utilize a series of two pipes.
	writeReadPipe, writeWritePipe := io.Pipe()
	readReadPipe, readWritePipe := io.Pipe()

	s := &stream{
		readStreamChan:  make(chan *readStream, 1),
		writeStreamChan: make(chan *writeStream, 1),
		id:              id,
		equivAuth:       equivAuth,
		limiter:         limiter,
		readBytesChan:   make(chan []byte),
		readErrChan:     make(chan error, 1),
		quit:            make(chan struct{}),
	}

	// Our tear down function will close the write side of the pipe, which
	// will cause the copier goroutine below to get an EOF when reading,
	// which will cause it to close the other ends of the pipes. Closing
	// quit releases the reader goroutine if it is blocked handing over a
	// frame, and finally we wait for both goroutines to exit.
	s.tearDown = func() error {
		err := writeWritePipe.Close()
		if err != nil {
			return err
		}

		close(s.quit)
		s.wg.Wait()

		return nil
	}

	// Launch the copier goroutine, which moves the bytes from the pipe the
	// writer writes to into the pipe the reader reads from.
	s.wg.Add(1)
	go func() {
		defer s.wg.Done()

		_, err := io.Copy(
			readWritePipe,
			writeReadPipe,
		)
		_ = readWritePipe.CloseWithError(err)
		_ = writeReadPipe.CloseWithError(err)
	}()

	// readFrame pulls exactly one length-prefixed message off the read
	// pipe and returns it with its prefix intact, so that ReadNextMsg
	// remains the single place where the prefix is decoded and validated.
	//
	// Frame boundaries must come from the length prefix rather than from
	// the size of individual pipe reads: a read that happens to fill its
	// buffer exactly says nothing about whether more bytes of the same
	// message follow, which previously caused messages whose frame size
	// was a multiple of DefaultBufSize to be held back and then merged
	// with (and swallow) the following message.
	readFrame := func() ([]byte, error) {
		// The prefix is the variable-length integer emitted by
		// WriteMsg: a first byte below 0xfd is the length itself,
		// while 0xfd, 0xfe and 0xff announce a big-endian uint16,
		// uint32 or uint64 respectively.
		var prefix [9]byte
		_, err := io.ReadFull(readReadPipe, prefix[:1])
		if err != nil {
			return nil, err
		}

		prefixLen := 1
		switch prefix[0] {
		case 0xfd:
			prefixLen = 3
		case 0xfe:
			prefixLen = 5
		case 0xff:
			prefixLen = 9
		}

		msgLen := uint64(prefix[0])
		if prefixLen > 1 {
			_, err := io.ReadFull(
				readReadPipe, prefix[1:prefixLen],
			)
			if err != nil {
				return nil, err
			}

			msgLen = 0
			for _, b := range prefix[1:prefixLen] {
				msgLen = msgLen<<8 | uint64(b)
			}
		}

		// Copy the payload incrementally so that the buffer only
		// grows as bytes actually arrive.
		var frame bytes.Buffer
		frame.Write(prefix[:prefixLen])
		_, err = io.CopyN(&frame, readReadPipe, int64(msgLen))
		if err != nil {
			return nil, err
		}

		return frame.Bytes(), nil
	}

	// Launch the reader goroutine, which hands complete frames over to
	// whoever currently holds the read stream.
	s.wg.Add(1)
	go func() {
		defer s.wg.Done()

		for {
			frame, err := readFrame()
			if err != nil {
				s.readErrChan <- err
				return
			}

			select {
			case s.readBytesChan <- frame:
			case <-s.quit:
			}
		}
	}()

	// We'll now initialize our stream by sending the read and write ends
	// to their respective holding channels.
	s.readStreamChan <- &readStream{
		parentStream: s,
	}
	s.writeStreamChan <- &writeStream{
		Writer:       writeWritePipe,
		parentStream: s,
	}

	return s
}
// ReturnReadStream puts the given read stream back into the stream's holding
// channel so that a later RequestReadStream call can claim it.
func (s *stream) ReturnReadStream(r *readStream) {
	holding := s.readStreamChan
	holding <- r
}
// ReturnWriteStream puts the given write stream back into the stream's holding
// channel so that a later RequestWriteStream call can claim it.
func (s *stream) ReturnWriteStream(w *writeStream) {
	holding := s.writeStreamChan
	holding <- w
}
// RequestReadStream tries to claim the read stream of the backing stream. The
// attempt never blocks: if the read stream is currently held by someone else,
// an error is returned right away.
func (s *stream) RequestReadStream() (*readStream, error) {
	log.Tracef("HashMailStream(%x): requesting read stream", s.id[:])

	var handle *readStream
	select {
	case handle = <-s.readStreamChan:
	default:
		return nil, fmt.Errorf("read stream occupied")
	}

	return handle, nil
}
// RequestWriteStream tries to claim the write stream of the backing stream.
// The attempt never blocks: if the write stream is currently held by someone
// else, an error is returned right away.
func (s *stream) RequestWriteStream() (*writeStream, error) {
	log.Tracef("HashMailStream(%x): requesting write stream", s.id[:])

	var handle *writeStream
	select {
	case handle = <-s.writeStreamChan:
	default:
		return nil, fmt.Errorf("write stream occupied")
	}

	return handle, nil
}
// hashMailServerConfig is the main config of the mail server.
type hashMailServerConfig struct {
// msgRate is the minimum interval between messages on a single stream.
// A zero value is replaced by DefaultMsgRate when the server is
// created.
msgRate time.Duration
// msgBurstAllowance is the burst size of each stream's rate limiter. A
// zero value is replaced by DefaultMsgBurstAllowance when the server is
// created.
msgBurstAllowance int
}
// hashMailServer is an implementation of the HashMailServer gRPC service that
// implements a simple encrypted mailbox implemented as a series of read and
// write pipes.
type hashMailServer struct {
hashmailrpc.UnimplementedHashMailServer
// The embedded RWMutex guards the streams map below.
sync.RWMutex
// streams holds all currently active streams, keyed by their ID.
streams map[streamID]*stream
// TODO(roasbeef): index to keep track of total stream tallies
// quit is checked by the streaming RPC handlers to detect a server
// shutdown.
// NOTE(review): nothing in the visible code closes this channel —
// confirm where shutdown is signalled.
quit chan struct{}
// cfg is the server's configuration with defaults already applied.
cfg hashMailServerConfig
}
// newHashMailServer creates a mail server from the given config. A zero
// message rate or a zero burst allowance is replaced by its respective
// default.
func newHashMailServer(cfg hashMailServerConfig) *hashMailServer {
	server := &hashMailServer{
		streams: make(map[streamID]*stream),
		quit:    make(chan struct{}),
		cfg:     cfg,
	}

	// Fill in defaults for everything the caller left unset.
	if server.cfg.msgRate == 0 {
		server.cfg.msgRate = DefaultMsgRate
	}
	if server.cfg.msgBurstAllowance == 0 {
		server.cfg.msgBurstAllowance = DefaultMsgBurstAllowance
	}

	return server
}
// Stop attempts to gracefully stop the server by tearing down all active
// streams along with the goroutines feeding off them.
//
// Every stream that was torn down successfully is removed from the server, so
// calling Stop again, or tearing down one of those streams afterwards, is
// safe. A stream whose teardown fails is logged and left in place.
func (h *hashMailServer) Stop() {
	h.Lock()
	defer h.Unlock()

	for id, s := range h.streams {
		if err := s.tearDown(); err != nil {
			log.Warnf("unable to tear down stream: %v", err)
			continue
		}

		// A stream must only be torn down once, as its tear down
		// closes the stream's quit channel and closing a channel twice
		// panics. Forgetting the stream here guarantees that nobody
		// can reach it again.
		delete(h.streams, id)
	}
}
// ValidateStreamAuth validates the authentication mechanism that is being
// used to claim or revoke a stream within the mail server.
//
// NOTE: Authentication is not implemented yet, so every request is currently
// accepted and nil is always returned.
func (h *hashMailServer) ValidateStreamAuth(ctx context.Context,
	init *hashmailrpc.CipherBoxAuth) error {

	// TODO(guggero): Implement auth.
	//
	// TODO(roasbeef): throttle the number of streams a given
	// ticket/account can have
	return nil
}
// InitStream creates and registers a new stream for the stream ID found in
// the given descriptor. If a stream with that ID is already active, an
// AlreadyExists status error is returned. On success the mailbox count metric
// is updated to the new number of active streams.
func (h *hashMailServer) InitStream(
	init *hashmailrpc.CipherBoxAuth) (*hashmailrpc.CipherInitResp, error) {

	h.Lock()
	defer h.Unlock()

	sid := newStreamID(init.Desc.StreamId)

	log.Debugf("Creating new HashMail Stream: %x", sid)

	// Only a single session may exist for a given stream, so bail out if
	// this one is already active.
	if _, active := h.streams[sid]; active {
		return nil, status.Error(
			codes.AlreadyExists, "stream already active",
		)
	}

	// TODO(roasbeef): validate that ticket or node doesn't already have
	// the same stream going

	// Each stream gets its own rate limiter built from the server config.
	// The auth-equivalence check accepts everything for now.
	acceptAll := func(auth *hashmailrpc.CipherBoxAuth) error {
		return nil
	}
	msgLimiter := rate.NewLimiter(
		rate.Every(h.cfg.msgRate), h.cfg.msgBurstAllowance,
	)
	h.streams[sid] = newStream(sid, msgLimiter, acceptAll)

	mailboxCount.Set(float64(len(h.streams)))

	resp := &hashmailrpc.CipherInitResp{
		Resp: &hashmailrpc.CipherInitResp_Success{},
	}

	return resp, nil
}
// LookUpReadStream looks up the stream with the given ID and tries to claim
// its read end. An error is returned if no such stream exists or if its read
// end is currently in use.
func (h *hashMailServer) LookUpReadStream(streamID []byte) (*readStream, error) {
	h.RLock()
	defer h.RUnlock()

	if s, found := h.streams[newStreamID(streamID)]; found {
		return s.RequestReadStream()
	}

	return nil, fmt.Errorf("stream not found")
}
// LookUpWriteStream looks up the stream with the given ID and tries to claim
// its write end. An error is returned if no such stream exists or if its write
// end is currently in use.
func (h *hashMailServer) LookUpWriteStream(streamID []byte) (*writeStream, error) {
	h.RLock()
	defer h.RUnlock()

	if s, found := h.streams[newStreamID(streamID)]; found {
		return s.RequestWriteStream()
	}

	return nil, fmt.Errorf("stream not found")
}
// TearDownStream tears down the stream with the given ID, which renders both
// sides of the stream unusable and reclaims its resources.
//
// The supplied auth must pass the stream's equivalence check as well as the
// regular auth validation. Once the stream is torn down it is removed from the
// server and the mailbox count metric is updated. An error is returned if the
// stream is unknown, if either auth check fails or if the teardown itself
// fails.
func (h *hashMailServer) TearDownStream(ctx context.Context, streamID []byte,
	auth *hashmailrpc.CipherBoxAuth) error {

	h.Lock()
	defer h.Unlock()

	key := newStreamID(streamID)
	target, found := h.streams[key]
	if !found {
		return fmt.Errorf("stream not found")
	}

	// Only the creator may tear down a stream, so the auth presented here
	// has to be equivalent to the one the stream was created with.
	err := target.equivAuth(auth)
	if err != nil {
		return fmt.Errorf("invalid auth: %v", err)
	}

	// With the auth type matching, run the regular validation of the
	// authentication mechanism.
	err = h.ValidateStreamAuth(ctx, auth)
	if err != nil {
		return err
	}

	log.Debugf("Tearing down HashMail stream: id=%x, auth=%v",
		auth.Desc.StreamId, auth.Auth)

	// The auth checked out, so the stream can go.
	err = target.tearDown()
	if err != nil {
		return err
	}

	delete(h.streams, key)
	mailboxCount.Set(float64(len(h.streams)))

	return nil
}
// validateAuthReq performs basic sanity checks on an incoming auth request:
// the descriptor, the stream ID inside the descriptor and the auth type must
// all be set. The first missing item determines the returned error.
func validateAuthReq(req *hashmailrpc.CipherBoxAuth) error {
	if req.Desc == nil {
		return fmt.Errorf("cipher box descriptor required")
	}

	if req.Desc.StreamId == nil {
		return fmt.Errorf("stream_id required")
	}

	if req.Auth == nil {
		return fmt.Errorf("auth type required")
	}

	return nil
}
// NewCipherBox creates a new cipher box stream for the given authentication
// request. The request is sanity checked and its auth validated before the
// stream is initialized. The call fails if the request is malformed, if the
// auth is rejected or if the stream is already active.
func (h *hashMailServer) NewCipherBox(ctx context.Context,
	init *hashmailrpc.CipherBoxAuth) (*hashmailrpc.CipherInitResp, error) {

	// Reject malformed user input before doing anything else.
	err := validateAuthReq(init)
	if err != nil {
		return nil, err
	}

	sid := init.Desc.StreamId
	log.Debugf("New HashMail stream init: id=%x, auth=%v", sid, init.Auth)

	err = h.ValidateStreamAuth(ctx, init)
	if err != nil {
		log.Debugf("Stream creation validation failed (id=%x): %v",
			sid, err)

		return nil, err
	}

	return h.InitStream(init)
}
// DelCipherBox tears down an existing cipher box pipe. The same authentication
// mechanism used to initially create the stream MUST be specified. The request
// is sanity checked before the stream named in its descriptor is torn down.
func (h *hashMailServer) DelCipherBox(ctx context.Context,
	auth *hashmailrpc.CipherBoxAuth) (*hashmailrpc.DelCipherBoxResp, error) {

	// Reject malformed user input before doing anything else.
	err := validateAuthReq(auth)
	if err != nil {
		return nil, err
	}

	sid := auth.Desc.StreamId
	log.Debugf("New HashMail stream deletion: id=%x, auth=%v", sid,
		auth.Auth)

	err = h.TearDownStream(ctx, sid, auth)
	if err != nil {
		return nil, err
	}

	return &hashmailrpc.DelCipherBoxResp{}, nil
}
// SendStream implements the client streaming call to utilize the write end of
// a stream to send messages to the read end.
//
// The first message received from the client selects the target stream: it
// must carry a descriptor with a stream ID, and the write end of that stream
// must be free. Its payload, and the payload of every message that follows, is
// written to that stream. Descriptors on subsequent messages are ignored (and
// may be absent), as the stream is fixed by the first message.
//
// The call returns nil once the client's context is done, and an error if
// receiving from the client or writing to the stream fails, or if the server
// is shutting down. The write end is handed back to the stream on return.
func (h *hashMailServer) SendStream(readStream hashmailrpc.HashMail_SendStreamServer) error {
	log.Debugf("New HashMail write stream pending...")

	// We'll need to receive the first message in order to determine if
	// this stream exists or not
	//
	// TODO(roasbeef): better way to control?
	cipherBox, err := readStream.Recv()
	if err != nil {
		return err
	}

	switch {
	case cipherBox.Desc == nil:
		return fmt.Errorf("cipher box descriptor required")

	case cipherBox.Desc.StreamId == nil:
		return fmt.Errorf("stream_id required")
	}

	// Remember the validated stream ID. Later messages are not required to
	// carry a descriptor, so they must never be dereferenced for it.
	sid := cipherBox.Desc.StreamId

	log.Debugf("New HashMail write stream: id=%x", sid)

	// Now that we have the first message, we can attempt to look up the
	// given stream.
	writeStream, err := h.LookUpWriteStream(sid)
	if err != nil {
		return err
	}

	// Now that we know the stream is found, we'll make sure to mark the
	// write inactive if the client hangs up on their end.
	defer writeStream.ReturnStream()

	log.Tracef("Sending msg_len=%v to stream_id=%x", len(cipherBox.Msg),
		sid)

	// We'll send the first message into the stream, then enter our loop
	// below to continue to read from the stream and send it to the read
	// end.
	ctx := readStream.Context()
	if err := writeStream.WriteMsg(ctx, cipherBox.Msg); err != nil {
		return err
	}

	for {
		// Check to see if the stream has been closed or if we need to
		// exit before shutting down.
		select {
		case <-ctx.Done():
			log.Debugf("SendStream: Context done, exiting")
			return nil

		case <-h.quit:
			return fmt.Errorf("server shutting down")

		default:
		}

		cipherBox, err := readStream.Recv()
		if err != nil {
			log.Debugf("SendStream: Exiting write stream RPC "+
				"stream read: %v", err)
			return err
		}

		log.Tracef("Sending msg_len=%v to stream_id=%x",
			len(cipherBox.Msg), sid)

		if err := writeStream.WriteMsg(ctx, cipherBox.Msg); err != nil {
			return err
		}
	}
}
// RecvStream implements the read end of the stream. A single client will have
// all messages written to the opposite side of the stream written to it for
// consumption.
//
// The read end of the stream named by desc is claimed for the duration of the
// call and handed back on return. Every message read from the stream is sent
// to the client together with desc. For streams with an odd ID, each read
// additionally increments the mailbox read counter, labelled with the
// stream's base ID.
//
// The call returns nil once the client's context is found to be done between
// messages, and an error if the stream cannot be claimed, if reading from the
// stream or sending to the client fails, or if the server is shutting down.
func (h *hashMailServer) RecvStream(desc *hashmailrpc.CipherBoxDesc,
	reader hashmailrpc.HashMail_RecvStreamServer) error {

	// First, we'll attempt to locate the stream. We allow any single
	// entity that knows of the full stream ID to access the read end.
	readStream, err := h.LookUpReadStream(desc.StreamId)
	if err != nil {
		return err
	}

	log.Debugf("New HashMail read stream: id=%x", desc.StreamId)

	// If the reader hangs up, then we'll mark the stream as inactive so
	// another can take its place.
	defer readStream.ReturnStream()

	// In order not to duplicate metric data, we only record reads for
	// streams with an odd ID. We use the base stream ID as the label. For
	// this to work, it is expected that the read and write streams of a
	// bidirectional pair have the same IDs with the last bit flipped for
	// one of them. None of this changes between messages, so we work it
	// out once up front.
	sid := newStreamID(desc.StreamId)
	recordReads := sid.isOdd()

	var baseIDLabel string
	if recordReads {
		baseIDLabel = fmt.Sprintf("%x", sid.baseID())
	}

	ctx := reader.Context()
	for {
		// Check to see if the stream has been closed or if we need to
		// exit before shutting down.
		select {
		case <-ctx.Done():
			log.Debugf("Read stream context done.")
			return nil

		case <-h.quit:
			return fmt.Errorf("server shutting down")

		default:
		}

		nextMsg, err := readStream.ReadNextMsg(ctx)
		if err != nil {
			log.Debugf("Got error an read stream read: %v", err)
			return err
		}

		log.Tracef("Read %v bytes for HashMail stream_id=%x",
			len(nextMsg), desc.StreamId)

		// The counter is only touched once a message was actually
		// read, so no metric series shows up for idle streams.
		if recordReads {
			mailboxReadCount.With(prometheus.Labels{
				streamIDLabel: baseIDLabel,
			}).Inc()
		}

		err = reader.Send(&hashmailrpc.CipherBox{
			Desc: desc,
			Msg:  nextMsg,
		})
		if err != nil {
			log.Debugf("Got error when sending on read stream: %v",
				err)
			return err
		}
	}
}
// Compile-time assertion that hashMailServer implements the
// hashmailrpc.HashMailServer interface.
var _ hashmailrpc.HashMailServer = (*hashMailServer)(nil)