diff --git a/config/config.go b/config/config.go index 4587f85..0f18c94 100644 --- a/config/config.go +++ b/config/config.go @@ -11,9 +11,9 @@ type NodeConfig struct { // clients. LspdPrivateKey string `json:"lspdPrivateKey"` - // Token used to authenticate to lspd. This token must be unique for each + // Tokens used to authenticate to lspd. These tokens must be unique for each // configured node, so it's obvious which node an rpc call is meant for. - Token string `json:"token"` + Tokens []string `json:"tokens"` // The network location of the lightning node, e.g. `12.34.56.78:9012` or // `localhost:10011` diff --git a/interceptor/intercept.go b/interceptor/intercept.go index 69b8178..6cc751b 100644 --- a/interceptor/intercept.go +++ b/interceptor/intercept.go @@ -76,7 +76,7 @@ func NewInterceptor( func (i *Interceptor) Intercept(nextHop string, reqPaymentHash []byte, reqOutgoingAmountMsat uint64, reqOutgoingExpiry uint32, reqIncomingExpiry uint32) InterceptResult { reqPaymentHashStr := hex.EncodeToString(reqPaymentHash) resp, _, _ := i.payHashGroup.Do(reqPaymentHashStr, func() (interface{}, error) { - params, paymentHash, paymentSecret, destination, incomingAmountMsat, outgoingAmountMsat, channelPoint, err := i.store.PaymentInfo(reqPaymentHash) + token, params, paymentHash, paymentSecret, destination, incomingAmountMsat, outgoingAmountMsat, channelPoint, err := i.store.PaymentInfo(reqPaymentHash) if err != nil { log.Printf("paymentInfo(%x) error: %v", reqPaymentHash, err) return InterceptResult{ @@ -123,7 +123,7 @@ func (i *Interceptor) Intercept(nextHop string, reqPaymentHash []byte, reqOutgoi } if time.Now().UTC().After(validUntil) { - if !i.isCurrentChainFeeCheaper(params) { + if !i.isCurrentChainFeeCheaper(token, params) { log.Printf("Intercepted expired payment registration. Failing payment. payment hash: %x, valid until: %s", paymentHash, params.ValidUntil) return InterceptResult{ Action: INTERCEPT_FAIL_HTLC_WITH_CODE, @@ -281,8 +281,8 @@ func (i *Interceptor) Intercept(nextHop string, reqPaymentHash []byte, reqOutgoi return resp.(InterceptResult) } -func (i *Interceptor) isCurrentChainFeeCheaper(params *OpeningFeeParams) bool { - settings, err := i.store.GetFeeParamsSettings() +func (i *Interceptor) isCurrentChainFeeCheaper(token string, params *OpeningFeeParams) bool { + settings, err := i.store.GetFeeParamsSettings(token) if err != nil { log.Printf("Failed to get fee params settings: %v", err) return false diff --git a/interceptor/store.go b/interceptor/store.go index 605f9cd..6ecd33f 100644 --- a/interceptor/store.go +++ b/interceptor/store.go @@ -20,9 +20,9 @@ type OpeningFeeParams struct { } type InterceptStore interface { - PaymentInfo(htlcPaymentHash []byte) (*OpeningFeeParams, []byte, []byte, []byte, int64, int64, *wire.OutPoint, error) + PaymentInfo(htlcPaymentHash []byte) (string, *OpeningFeeParams, []byte, []byte, []byte, int64, int64, *wire.OutPoint, error) SetFundingTx(paymentHash []byte, channelPoint *wire.OutPoint) error - RegisterPayment(params *OpeningFeeParams, destination, paymentHash, paymentSecret []byte, incomingAmountMsat, outgoingAmountMsat int64, tag string) error + RegisterPayment(token string, params *OpeningFeeParams, destination, paymentHash, paymentSecret []byte, incomingAmountMsat, outgoingAmountMsat int64, tag string) error InsertChannel(initialChanID, confirmedChanId uint64, channelPoint string, nodeID []byte, lastUpdate time.Time) error - GetFeeParamsSettings() ([]*OpeningFeeParamsSetting, error) + GetFeeParamsSettings(token string) ([]*OpeningFeeParamsSetting, error) } diff --git a/postgresql/intercept_store.go b/postgresql/intercept_store.go index 05fe591..941451a 100644 --- a/postgresql/intercept_store.go +++ b/postgresql/intercept_store.go @@ -15,6 +15,11 @@ import ( "github.com/jackc/pgx/v4/pgxpool" ) +type extendedParams struct { + Token string `json:"token"` + Params interceptor.OpeningFeeParams `json:"fees_params"` +} + type PostgresInterceptStore struct { pool *pgxpool.Pool } @@ -23,7 +28,7 @@ func NewPostgresInterceptStore(pool *pgxpool.Pool) *PostgresInterceptStore { return &PostgresInterceptStore{pool: pool} } -func (s *PostgresInterceptStore) PaymentInfo(htlcPaymentHash []byte) (*interceptor.OpeningFeeParams, []byte, []byte, []byte, int64, int64, *wire.OutPoint, error) { +func (s *PostgresInterceptStore) PaymentInfo(htlcPaymentHash []byte) (string, *interceptor.OpeningFeeParams, []byte, []byte, []byte, int64, int64, *wire.OutPoint, error) { var ( p *string paymentHash, paymentSecret, destination []byte @@ -40,7 +45,7 @@ func (s *PostgresInterceptStore) PaymentInfo(htlcPaymentHash []byte) (*intercept if err == pgx.ErrNoRows { err = nil } - return nil, nil, nil, nil, 0, 0, nil, err + return "", nil, nil, nil, nil, 0, 0, nil, err } var cp *wire.OutPoint @@ -51,15 +56,15 @@ func (s *PostgresInterceptStore) PaymentInfo(htlcPaymentHash []byte) (*intercept } } - var params *interceptor.OpeningFeeParams + var extParams *extendedParams if p != nil { - err = json.Unmarshal([]byte(*p), ¶ms) + err = json.Unmarshal([]byte(*p), &extParams) if err != nil { log.Printf("Failed to unmarshal OpeningFeeParams '%s': %v", *p, err) - return nil, nil, nil, nil, 0, 0, nil, err + return "", nil, nil, nil, nil, 0, 0, nil, err } } - return params, paymentHash, paymentSecret, destination, incomingAmountMsat, outgoingAmountMsat, cp, nil + return extParams.Token, &extParams.Params, paymentHash, paymentSecret, destination, incomingAmountMsat, outgoingAmountMsat, cp, nil } func (s *PostgresInterceptStore) SetFundingTx(paymentHash []byte, channelPoint *wire.OutPoint) error { @@ -72,16 +77,20 @@ func (s *PostgresInterceptStore) SetFundingTx(paymentHash []byte, channelPoint * return err } -func (s *PostgresInterceptStore) RegisterPayment(params *interceptor.OpeningFeeParams, destination, paymentHash, paymentSecret []byte, incomingAmountMsat, outgoingAmountMsat int64, tag string) error { +func (s *PostgresInterceptStore) RegisterPayment(token string, params *interceptor.OpeningFeeParams, destination, paymentHash, paymentSecret []byte, incomingAmountMsat, outgoingAmountMsat int64, tag string) error { var t *string if tag != "" { t = &tag } - p, err := json.Marshal(params) - if err != nil { - log.Printf("Failed to marshal OpeningFeeParams: %v", err) - return err + p := []byte{} + if params != nil { + var err error + p, err = json.Marshal(extendedParams{Token: token, Params: *params}) + if err != nil { + log.Printf("Failed to marshal OpeningFeeParams: %v", err) + return err + } } commandTag, err := s.pool.Exec(context.Background(), @@ -119,10 +128,10 @@ func (s *PostgresInterceptStore) InsertChannel(initialChanID, confirmedChanId ui return nil } -func (s *PostgresInterceptStore) GetFeeParamsSettings() ([]*interceptor.OpeningFeeParamsSetting, error) { - rows, err := s.pool.Query(context.Background(), `SELECT validity, params FROM new_channel_params`) +func (s *PostgresInterceptStore) GetFeeParamsSettings(token string) ([]*interceptor.OpeningFeeParamsSetting, error) { + rows, err := s.pool.Query(context.Background(), `SELECT validity, params FROM new_channel_params WHERE token=$1`, token) if err != nil { - log.Printf("GetFeeParamsSettings() error: %v", err) + log.Printf("GetFeeParamsSettings(%v) error: %v", token, err) return nil, err } diff --git a/postgresql/migrations/000012_new_channel_params_token.down.sql b/postgresql/migrations/000012_new_channel_params_token.down.sql new file mode 100644 index 0000000..0286c25 --- /dev/null +++ b/postgresql/migrations/000012_new_channel_params_token.down.sql @@ -0,0 +1,2 @@ +ALTER TABLE public.new_channel_params +DROP COLUMN token; \ No newline at end of file diff --git a/postgresql/migrations/000012_new_channel_params_token.up.sql b/postgresql/migrations/000012_new_channel_params_token.up.sql new file mode 100644 index 0000000..7bbab49 --- /dev/null +++ b/postgresql/migrations/000012_new_channel_params_token.up.sql @@ -0,0 +1 @@ +ALTER TABLE public.new_channel_params ADD token varchar; \ No newline at end of file diff --git a/server.go b/server.go index de10bd2..db35c4a 100644 --- a/server.go +++ b/server.go @@ -58,13 +58,15 @@ type node struct { openChannelReqGroup singleflight.Group } +type contextKey string + func (s *server) ChannelInformation(ctx context.Context, in *lspdrpc.ChannelInformationRequest) (*lspdrpc.ChannelInformationReply, error) { - node, err := getNode(ctx) + node, token, err := s.getNode(ctx) if err != nil { return nil, err } - params, err := s.createOpeningParamsMenu(ctx, node) + params, err := s.createOpeningParamsMenu(ctx, node, token) if err != nil { return nil, err } @@ -90,10 +92,11 @@ func (s *server) ChannelInformation(ctx context.Context, in *lspdrpc.ChannelInfo func (s *server) createOpeningParamsMenu( ctx context.Context, node *node, + token string, ) ([]*lspdrpc.OpeningFeeParams, error) { var menu []*lspdrpc.OpeningFeeParams - settings, err := s.store.GetFeeParamsSettings() + settings, err := s.store.GetFeeParamsSettings(token) if err != nil { log.Printf("Failed to fetch fee params settings: %v", err) return nil, fmt.Errorf("failed to get opening_fee_params") @@ -208,7 +211,7 @@ func (s *server) RegisterPayment( ctx context.Context, in *lspdrpc.RegisterPaymentRequest, ) (*lspdrpc.RegisterPaymentReply, error) { - node, err := getNode(ctx) + node, token, err := s.getNode(ctx) if err != nil { return nil, err } @@ -275,7 +278,7 @@ func (s *server) RegisterPayment( MaxClientToSelfDelay: pi.OpeningFeeParams.MaxClientToSelfDelay, Promise: pi.OpeningFeeParams.Promise, } - err = s.store.RegisterPayment(params, pi.Destination, pi.PaymentHash, pi.PaymentSecret, pi.IncomingAmountMsat, pi.OutgoingAmountMsat, pi.Tag) + err = s.store.RegisterPayment(token, params, pi.Destination, pi.PaymentHash, pi.PaymentSecret, pi.IncomingAmountMsat, pi.OutgoingAmountMsat, pi.Tag) if err != nil { log.Printf("RegisterPayment() error: %v", err) return nil, fmt.Errorf("RegisterPayment() error: %w", err) @@ -284,7 +287,7 @@ func (s *server) RegisterPayment( } func (s *server) OpenChannel(ctx context.Context, in *lspdrpc.OpenChannelRequest) (*lspdrpc.OpenChannelReply, error) { - node, err := getNode(ctx) + node, _, err := s.getNode(ctx) if err != nil { return nil, err } @@ -368,7 +371,7 @@ func (n *node) getSignedEncryptedData(in *lspdrpc.Encrypted) (string, []byte, bo } func (s *server) CheckChannels(ctx context.Context, in *lspdrpc.Encrypted) (*lspdrpc.Encrypted, error) { - node, err := getNode(ctx) + node, _, err := s.getNode(ctx) if err != nil { return nil, err } @@ -473,12 +476,14 @@ func NewGrpcServer( } } - _, exists := nodes[config.Token] - if exists { - return nil, fmt.Errorf("cannot have multiple nodes with the same token") - } + for _, token := range config.Tokens { + _, exists := nodes[token] + if exists { + return nil, fmt.Errorf("cannot have multiple nodes with the same token") + } - nodes[config.Token] = node + nodes[token] = node + } } return &server{ @@ -534,12 +539,12 @@ func (s *server) Start() error { } token := strings.Replace(auth, "Bearer ", "", 1) - node, ok := s.nodes[token] + _, ok := s.nodes[token] if !ok { continue } - return handler(context.WithValue(ctx, "node", node), req) + return handler(context.WithValue(ctx, contextKey("token"), token), req) } } return nil, status.Errorf(codes.PermissionDenied, "Not authorized") @@ -563,18 +568,21 @@ func (s *server) Stop() { } } -func getNode(ctx context.Context) (*node, error) { - n := ctx.Value("node") - if n == nil { - return nil, status.Errorf(codes.PermissionDenied, "Not authorized") +func (s *server) getNode(ctx context.Context) (*node, string, error) { + tok := ctx.Value(contextKey("token")) + if tok == nil { + return nil, "", status.Errorf(codes.PermissionDenied, "Not authorized") } - node, ok := n.(*node) - if !ok || node == nil { - return nil, status.Errorf(codes.PermissionDenied, "Not authorized") + token, ok := tok.(string) + if !ok { + return nil, "", status.Errorf(codes.PermissionDenied, "Not authorized") } - - return node, nil + node, ok := s.nodes[token] + if !ok { + return nil, "", status.Errorf(codes.PermissionDenied, "Not authorized") + } + return node, token, nil } func checkPayment(params *lspdrpc.OpeningFeeParams, incomingAmountMsat, outgoingAmountMsat int64) error {