diff --git a/auth/server_interceptor.go b/auth/server_interceptor.go index e06745c..b418248 100644 --- a/auth/server_interceptor.go +++ b/auth/server_interceptor.go @@ -3,6 +3,7 @@ package auth import ( "bytes" "context" + "fmt" "net/http" "github.com/lightninglabs/loop/lsat" @@ -28,9 +29,64 @@ func (i *ServerInterceptor) UnaryInterceptor(ctx context.Context, // the request to the handler if anything fails. Incoming calls with // invalid metadata will therefore just be treated as non-identified or // non-authenticated. + token, err := tokenFromContext(ctx) + if err != nil { + log.Debugf("No token extracted, error was: %v", err) + return handler(ctx, req) + } + + // We got a token, create a new context that wraps its value and + // continue the call chain by invoking the handler. + idCtx := AddToContext(ctx, KeyTokenID, *token) + return handler(idCtx, req) +} + +// wrappedStream is a thin wrapper around the grpc.ServerStream that allows us +// to overwrite the context of the stream. +type wrappedStream struct { + grpc.ServerStream + WrappedContext context.Context +} + +// Context returns the context for this stream. +func (w *wrappedStream) Context() context.Context { + return w.WrappedContext +} + +// StreamInterceptor is an stream gRPC server interceptor that inspects incoming +// streams for authentication tokens. If an LSAT authentication token is found +// in the initial stream establishment request, its token ID is extracted and +// treated as client ID. The extracted ID is then attached to the request +// context in a format that is easy to extract by request handlers. +func (i *ServerInterceptor) StreamInterceptor(srv interface{}, + ss grpc.ServerStream, _ *grpc.StreamServerInfo, + handler grpc.StreamHandler) error { + + // Try getting the authentication header embedded in the context meta + // data and parse it. We ignore all errors that happen and just forward + // the request to the handler if anything fails. Incoming calls with + // invalid metadata will therefore just be treated as non-identified or + // non-authenticated. + ctx := ss.Context() + token, err := tokenFromContext(ctx) + if err != nil { + log.Debugf("No token extracted, error was: %v", err) + return handler(srv, ss) + } + + // We got a token, create a new context that wraps its value and + // continue the call chain by invoking the handler. We can't directly + // modify the server stream so we have to wrap it. + idCtx := AddToContext(ctx, KeyTokenID, *token) + wrappedStream := &wrappedStream{ss, idCtx} + return handler(srv, wrappedStream) +} + +// tokenFromContext tries to extract the LSAT from a context. +func tokenFromContext(ctx context.Context) (*lsat.TokenID, error) { md, ok := metadata.FromIncomingContext(ctx) if !ok { - return handler(ctx, req) + return nil, fmt.Errorf("context contains no metadata") } header := &http.Header{ HeaderAuthorization: md.Get(HeaderAuthorization), @@ -39,20 +95,17 @@ func (i *ServerInterceptor) UnaryInterceptor(ctx context.Context, md.Get(HeaderAuthorization)) macaroon, _, err := FromHeader(header) if err != nil { - return handler(ctx, req) + return nil, fmt.Errorf("auth header extraction failed: %v", err) } // If there is an LSAT, decode and add it to the context. identifier, err := lsat.DecodeIdentifier(bytes.NewBuffer(macaroon.Id())) if err != nil { - return handler(ctx, req) + return nil, fmt.Errorf("token ID decoding failed: %v", err) } var clientID lsat.TokenID copy(clientID[:], identifier.TokenID[:]) - idCtx := AddToContext(ctx, KeyTokenID, clientID) log.Debugf("Decoded client/token ID %s from auth header", clientID.String()) - - // Continue the call chain by invoking the handler. - return handler(idCtx, req) + return &clientID, nil }