From 4be6d8c6a4d96c54c07697370051a5a24764945e Mon Sep 17 00:00:00 2001 From: Jesse de Wit Date: Mon, 6 Nov 2023 15:00:34 +0100 Subject: [PATCH] split node context in its own module --- channel_opener_server.go | 26 ++++---------------------- grpc_server.go | 10 +--------- rpc/node_context.go | 36 ++++++++++++++++++++++++++++++++++++ 3 files changed, 41 insertions(+), 31 deletions(-) create mode 100644 rpc/node_context.go diff --git a/channel_opener_server.go b/channel_opener_server.go index 48e1d17..0b065f7 100644 --- a/channel_opener_server.go +++ b/channel_opener_server.go @@ -16,8 +16,6 @@ import ( lspdrpc "github.com/breez/lspd/rpc" ecies "github.com/ecies/go/v2" "github.com/golang/protobuf/proto" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/chaincfg/chainhash" @@ -41,10 +39,8 @@ func NewChannelOpenerServer( } } -type contextKey string - func (s *channelOpenerServer) ChannelInformation(ctx context.Context, in *lspdrpc.ChannelInformationRequest) (*lspdrpc.ChannelInformationReply, error) { - node, token, err := s.getNode(ctx) + node, token, err := lspdrpc.GetNode(ctx) if err != nil { return nil, err } @@ -88,7 +84,7 @@ func (s *channelOpenerServer) RegisterPayment( ctx context.Context, in *lspdrpc.RegisterPaymentRequest, ) (*lspdrpc.RegisterPaymentReply, error) { - node, token, err := s.getNode(ctx) + node, token, err := lspdrpc.GetNode(ctx) if err != nil { return nil, err } @@ -174,7 +170,7 @@ func (s *channelOpenerServer) RegisterPayment( } func (s *channelOpenerServer) OpenChannel(ctx context.Context, in *lspdrpc.OpenChannelRequest) (*lspdrpc.OpenChannelReply, error) { - node, _, err := s.getNode(ctx) + node, _, err := lspdrpc.GetNode(ctx) if err != nil { return nil, err } @@ -258,7 +254,7 @@ func getSignedEncryptedData(n *common.Node, in *lspdrpc.Encrypted) (string, []by } func (s *channelOpenerServer) CheckChannels(ctx context.Context, in *lspdrpc.Encrypted) (*lspdrpc.Encrypted, error) { - node, _, err := s.getNode(ctx) + node, _, err := lspdrpc.GetNode(ctx) if err != nil { return nil, err } @@ -312,20 +308,6 @@ func (s *channelOpenerServer) CheckChannels(ctx context.Context, in *lspdrpc.Enc return &lspdrpc.Encrypted{Data: encrypted}, nil } -func (s *channelOpenerServer) getNode(ctx context.Context) (*common.Node, string, error) { - nd := ctx.Value(contextKey("node")) - if nd == nil { - return nil, "", status.Errorf(codes.PermissionDenied, "Not authorized") - } - - nodeContext, ok := nd.(*nodeContext) - if !ok { - return nil, "", status.Errorf(codes.PermissionDenied, "Not authorized") - } - - return nodeContext.node, nodeContext.token, nil -} - func checkPayment(params *lspdrpc.OpeningFeeParams, incomingAmountMsat, outgoingAmountMsat int64) error { fees := incomingAmountMsat * int64(params.Proportional) / 1_000_000 / 1_000 * 1_000 if fees < int64(params.MinMsat) { diff --git a/grpc_server.go b/grpc_server.go index 3d47ad2..7ed190c 100644 --- a/grpc_server.go +++ b/grpc_server.go @@ -29,11 +29,6 @@ type grpcServer struct { n notifications.NotificationsServer } -type nodeContext struct { - token string - node *common.Node -} - func NewGrpcServer( nodesService common.NodesService, address string, @@ -84,10 +79,7 @@ func (s *grpcServer) Start() error { continue } - return handler(context.WithValue(ctx, contextKey("node"), &nodeContext{ - token: token, - node: node, - }), req) + return handler(lspdrpc.WithNode(ctx, node, token), req) } } return nil, status.Errorf(codes.PermissionDenied, "Not authorized") diff --git a/rpc/node_context.go b/rpc/node_context.go new file mode 100644 index 0000000..436fb40 --- /dev/null +++ b/rpc/node_context.go @@ -0,0 +1,36 @@ +package lspd + +import ( + context "context" + + "github.com/breez/lspd/common" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" +) + +type contextKey string +type nodeContext struct { + token string + node *common.Node +} + +func GetNode(ctx context.Context) (*common.Node, string, error) { + nd := ctx.Value(contextKey("node")) + if nd == nil { + return nil, "", status.Errorf(codes.PermissionDenied, "Not authorized") + } + + nodeContext, ok := nd.(*nodeContext) + if !ok { + return nil, "", status.Errorf(codes.PermissionDenied, "Not authorized") + } + + return nodeContext.node, nodeContext.token, nil +} + +func WithNode(ctx context.Context, node *common.Node, token string) context.Context { + return context.WithValue(ctx, contextKey("node"), &nodeContext{ + token: token, + node: node, + }) +}