split node context in its own module

This commit is contained in:
Jesse de Wit
2023-11-06 15:00:34 +01:00
parent f6af1e5442
commit 4be6d8c6a4
3 changed files with 41 additions and 31 deletions

View File

@@ -16,8 +16,6 @@ import (
lspdrpc "github.com/breez/lspd/rpc" lspdrpc "github.com/breez/lspd/rpc"
ecies "github.com/ecies/go/v2" ecies "github.com/ecies/go/v2"
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/chaincfg/chainhash"
@@ -41,10 +39,8 @@ func NewChannelOpenerServer(
} }
} }
type contextKey string
func (s *channelOpenerServer) ChannelInformation(ctx context.Context, in *lspdrpc.ChannelInformationRequest) (*lspdrpc.ChannelInformationReply, error) { func (s *channelOpenerServer) ChannelInformation(ctx context.Context, in *lspdrpc.ChannelInformationRequest) (*lspdrpc.ChannelInformationReply, error) {
node, token, err := s.getNode(ctx) node, token, err := lspdrpc.GetNode(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -88,7 +84,7 @@ func (s *channelOpenerServer) RegisterPayment(
ctx context.Context, ctx context.Context,
in *lspdrpc.RegisterPaymentRequest, in *lspdrpc.RegisterPaymentRequest,
) (*lspdrpc.RegisterPaymentReply, error) { ) (*lspdrpc.RegisterPaymentReply, error) {
node, token, err := s.getNode(ctx) node, token, err := lspdrpc.GetNode(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -174,7 +170,7 @@ func (s *channelOpenerServer) RegisterPayment(
} }
func (s *channelOpenerServer) OpenChannel(ctx context.Context, in *lspdrpc.OpenChannelRequest) (*lspdrpc.OpenChannelReply, error) { func (s *channelOpenerServer) OpenChannel(ctx context.Context, in *lspdrpc.OpenChannelRequest) (*lspdrpc.OpenChannelReply, error) {
node, _, err := s.getNode(ctx) node, _, err := lspdrpc.GetNode(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -258,7 +254,7 @@ func getSignedEncryptedData(n *common.Node, in *lspdrpc.Encrypted) (string, []by
} }
func (s *channelOpenerServer) CheckChannels(ctx context.Context, in *lspdrpc.Encrypted) (*lspdrpc.Encrypted, error) { func (s *channelOpenerServer) CheckChannels(ctx context.Context, in *lspdrpc.Encrypted) (*lspdrpc.Encrypted, error) {
node, _, err := s.getNode(ctx) node, _, err := lspdrpc.GetNode(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -312,20 +308,6 @@ func (s *channelOpenerServer) CheckChannels(ctx context.Context, in *lspdrpc.Enc
return &lspdrpc.Encrypted{Data: encrypted}, nil return &lspdrpc.Encrypted{Data: encrypted}, nil
} }
func (s *channelOpenerServer) getNode(ctx context.Context) (*common.Node, string, error) {
nd := ctx.Value(contextKey("node"))
if nd == nil {
return nil, "", status.Errorf(codes.PermissionDenied, "Not authorized")
}
nodeContext, ok := nd.(*nodeContext)
if !ok {
return nil, "", status.Errorf(codes.PermissionDenied, "Not authorized")
}
return nodeContext.node, nodeContext.token, nil
}
func checkPayment(params *lspdrpc.OpeningFeeParams, incomingAmountMsat, outgoingAmountMsat int64) error { func checkPayment(params *lspdrpc.OpeningFeeParams, incomingAmountMsat, outgoingAmountMsat int64) error {
fees := incomingAmountMsat * int64(params.Proportional) / 1_000_000 / 1_000 * 1_000 fees := incomingAmountMsat * int64(params.Proportional) / 1_000_000 / 1_000 * 1_000
if fees < int64(params.MinMsat) { if fees < int64(params.MinMsat) {

View File

@@ -29,11 +29,6 @@ type grpcServer struct {
n notifications.NotificationsServer n notifications.NotificationsServer
} }
type nodeContext struct {
token string
node *common.Node
}
func NewGrpcServer( func NewGrpcServer(
nodesService common.NodesService, nodesService common.NodesService,
address string, address string,
@@ -84,10 +79,7 @@ func (s *grpcServer) Start() error {
continue continue
} }
return handler(context.WithValue(ctx, contextKey("node"), &nodeContext{ return handler(lspdrpc.WithNode(ctx, node, token), req)
token: token,
node: node,
}), req)
} }
} }
return nil, status.Errorf(codes.PermissionDenied, "Not authorized") return nil, status.Errorf(codes.PermissionDenied, "Not authorized")

36
rpc/node_context.go Normal file
View File

@@ -0,0 +1,36 @@
package lspd
import (
context "context"
"github.com/breez/lspd/common"
codes "google.golang.org/grpc/codes"
status "google.golang.org/grpc/status"
)
type contextKey string
type nodeContext struct {
token string
node *common.Node
}
func GetNode(ctx context.Context) (*common.Node, string, error) {
nd := ctx.Value(contextKey("node"))
if nd == nil {
return nil, "", status.Errorf(codes.PermissionDenied, "Not authorized")
}
nodeContext, ok := nd.(*nodeContext)
if !ok {
return nil, "", status.Errorf(codes.PermissionDenied, "Not authorized")
}
return nodeContext.node, nodeContext.token, nil
}
func WithNode(ctx context.Context, node *common.Node, token string) context.Context {
return context.WithValue(ctx, contextKey("node"), &nodeContext{
token: token,
node: node,
})
}