Support more than one token per node and params per token

This commit is contained in:
Yaacov Akiba Slama
2023-06-02 18:51:52 +03:00
parent 59f01bd30c
commit 81f24accad
7 changed files with 66 additions and 46 deletions

View File

@@ -11,9 +11,9 @@ type NodeConfig struct {
// clients.
LspdPrivateKey string `json:"lspdPrivateKey"`
// Token used to authenticate to lspd. This token must be unique for each
// Tokens used to authenticate to lspd. These tokens must be unique for each
// configured node, so it's obvious which node an rpc call is meant for.
Token string `json:"token"`
Tokens []string `json:"tokens"`
// The network location of the lightning node, e.g. `12.34.56.78:9012` or
// `localhost:10011`

View File

@@ -76,7 +76,7 @@ func NewInterceptor(
func (i *Interceptor) Intercept(nextHop string, reqPaymentHash []byte, reqOutgoingAmountMsat uint64, reqOutgoingExpiry uint32, reqIncomingExpiry uint32) InterceptResult {
reqPaymentHashStr := hex.EncodeToString(reqPaymentHash)
resp, _, _ := i.payHashGroup.Do(reqPaymentHashStr, func() (interface{}, error) {
params, paymentHash, paymentSecret, destination, incomingAmountMsat, outgoingAmountMsat, channelPoint, err := i.store.PaymentInfo(reqPaymentHash)
token, params, paymentHash, paymentSecret, destination, incomingAmountMsat, outgoingAmountMsat, channelPoint, err := i.store.PaymentInfo(reqPaymentHash)
if err != nil {
log.Printf("paymentInfo(%x) error: %v", reqPaymentHash, err)
return InterceptResult{
@@ -123,7 +123,7 @@ func (i *Interceptor) Intercept(nextHop string, reqPaymentHash []byte, reqOutgoi
}
if time.Now().UTC().After(validUntil) {
if !i.isCurrentChainFeeCheaper(params) {
if !i.isCurrentChainFeeCheaper(token, params) {
log.Printf("Intercepted expired payment registration. Failing payment. payment hash: %x, valid until: %s", paymentHash, params.ValidUntil)
return InterceptResult{
Action: INTERCEPT_FAIL_HTLC_WITH_CODE,
@@ -281,8 +281,8 @@ func (i *Interceptor) Intercept(nextHop string, reqPaymentHash []byte, reqOutgoi
return resp.(InterceptResult)
}
func (i *Interceptor) isCurrentChainFeeCheaper(params *OpeningFeeParams) bool {
settings, err := i.store.GetFeeParamsSettings()
func (i *Interceptor) isCurrentChainFeeCheaper(token string, params *OpeningFeeParams) bool {
settings, err := i.store.GetFeeParamsSettings(token)
if err != nil {
log.Printf("Failed to get fee params settings: %v", err)
return false

View File

@@ -20,9 +20,9 @@ type OpeningFeeParams struct {
}
type InterceptStore interface {
PaymentInfo(htlcPaymentHash []byte) (*OpeningFeeParams, []byte, []byte, []byte, int64, int64, *wire.OutPoint, error)
PaymentInfo(htlcPaymentHash []byte) (string, *OpeningFeeParams, []byte, []byte, []byte, int64, int64, *wire.OutPoint, error)
SetFundingTx(paymentHash []byte, channelPoint *wire.OutPoint) error
RegisterPayment(params *OpeningFeeParams, destination, paymentHash, paymentSecret []byte, incomingAmountMsat, outgoingAmountMsat int64, tag string) error
RegisterPayment(token string, params *OpeningFeeParams, destination, paymentHash, paymentSecret []byte, incomingAmountMsat, outgoingAmountMsat int64, tag string) error
InsertChannel(initialChanID, confirmedChanId uint64, channelPoint string, nodeID []byte, lastUpdate time.Time) error
GetFeeParamsSettings() ([]*OpeningFeeParamsSetting, error)
GetFeeParamsSettings(token string) ([]*OpeningFeeParamsSetting, error)
}

View File

@@ -15,6 +15,11 @@ import (
"github.com/jackc/pgx/v4/pgxpool"
)
type extendedParams struct {
Token string `json:"token"`
Params interceptor.OpeningFeeParams `json:"fees_params"`
}
type PostgresInterceptStore struct {
pool *pgxpool.Pool
}
@@ -23,7 +28,7 @@ func NewPostgresInterceptStore(pool *pgxpool.Pool) *PostgresInterceptStore {
return &PostgresInterceptStore{pool: pool}
}
func (s *PostgresInterceptStore) PaymentInfo(htlcPaymentHash []byte) (*interceptor.OpeningFeeParams, []byte, []byte, []byte, int64, int64, *wire.OutPoint, error) {
func (s *PostgresInterceptStore) PaymentInfo(htlcPaymentHash []byte) (string, *interceptor.OpeningFeeParams, []byte, []byte, []byte, int64, int64, *wire.OutPoint, error) {
var (
p *string
paymentHash, paymentSecret, destination []byte
@@ -40,7 +45,7 @@ func (s *PostgresInterceptStore) PaymentInfo(htlcPaymentHash []byte) (*intercept
if err == pgx.ErrNoRows {
err = nil
}
return nil, nil, nil, nil, 0, 0, nil, err
return "", nil, nil, nil, nil, 0, 0, nil, err
}
var cp *wire.OutPoint
@@ -51,15 +56,15 @@ func (s *PostgresInterceptStore) PaymentInfo(htlcPaymentHash []byte) (*intercept
}
}
var params *interceptor.OpeningFeeParams
var extParams *extendedParams
if p != nil {
err = json.Unmarshal([]byte(*p), &params)
err = json.Unmarshal([]byte(*p), &extParams)
if err != nil {
log.Printf("Failed to unmarshal OpeningFeeParams '%s': %v", *p, err)
return nil, nil, nil, nil, 0, 0, nil, err
return "", nil, nil, nil, nil, 0, 0, nil, err
}
}
return params, paymentHash, paymentSecret, destination, incomingAmountMsat, outgoingAmountMsat, cp, nil
return extParams.Token, &extParams.Params, paymentHash, paymentSecret, destination, incomingAmountMsat, outgoingAmountMsat, cp, nil
}
func (s *PostgresInterceptStore) SetFundingTx(paymentHash []byte, channelPoint *wire.OutPoint) error {
@@ -72,16 +77,20 @@ func (s *PostgresInterceptStore) SetFundingTx(paymentHash []byte, channelPoint *
return err
}
func (s *PostgresInterceptStore) RegisterPayment(params *interceptor.OpeningFeeParams, destination, paymentHash, paymentSecret []byte, incomingAmountMsat, outgoingAmountMsat int64, tag string) error {
func (s *PostgresInterceptStore) RegisterPayment(token string, params *interceptor.OpeningFeeParams, destination, paymentHash, paymentSecret []byte, incomingAmountMsat, outgoingAmountMsat int64, tag string) error {
var t *string
if tag != "" {
t = &tag
}
p, err := json.Marshal(params)
if err != nil {
log.Printf("Failed to marshal OpeningFeeParams: %v", err)
return err
p := []byte{}
if params != nil {
var err error
p, err = json.Marshal(extendedParams{Token: token, Params: *params})
if err != nil {
log.Printf("Failed to marshal OpeningFeeParams: %v", err)
return err
}
}
commandTag, err := s.pool.Exec(context.Background(),
@@ -119,10 +128,10 @@ func (s *PostgresInterceptStore) InsertChannel(initialChanID, confirmedChanId ui
return nil
}
func (s *PostgresInterceptStore) GetFeeParamsSettings() ([]*interceptor.OpeningFeeParamsSetting, error) {
rows, err := s.pool.Query(context.Background(), `SELECT validity, params FROM new_channel_params`)
func (s *PostgresInterceptStore) GetFeeParamsSettings(token string) ([]*interceptor.OpeningFeeParamsSetting, error) {
rows, err := s.pool.Query(context.Background(), `SELECT validity, params FROM new_channel_params WHERE token=$1`, token)
if err != nil {
log.Printf("GetFeeParamsSettings() error: %v", err)
log.Printf("GetFeeParamsSettings(%v) error: %v", token, err)
return nil, err
}

View File

@@ -0,0 +1,2 @@
ALTER TABLE public.new_channel_params
DROP COLUMN token;

View File

@@ -0,0 +1 @@
ALTER TABLE public.new_channel_params ADD token varchar;

View File

@@ -58,13 +58,15 @@ type node struct {
openChannelReqGroup singleflight.Group
}
type contextKey string
func (s *server) ChannelInformation(ctx context.Context, in *lspdrpc.ChannelInformationRequest) (*lspdrpc.ChannelInformationReply, error) {
node, err := getNode(ctx)
node, token, err := s.getNode(ctx)
if err != nil {
return nil, err
}
params, err := s.createOpeningParamsMenu(ctx, node)
params, err := s.createOpeningParamsMenu(ctx, node, token)
if err != nil {
return nil, err
}
@@ -90,10 +92,11 @@ func (s *server) ChannelInformation(ctx context.Context, in *lspdrpc.ChannelInfo
func (s *server) createOpeningParamsMenu(
ctx context.Context,
node *node,
token string,
) ([]*lspdrpc.OpeningFeeParams, error) {
var menu []*lspdrpc.OpeningFeeParams
settings, err := s.store.GetFeeParamsSettings()
settings, err := s.store.GetFeeParamsSettings(token)
if err != nil {
log.Printf("Failed to fetch fee params settings: %v", err)
return nil, fmt.Errorf("failed to get opening_fee_params")
@@ -208,7 +211,7 @@ func (s *server) RegisterPayment(
ctx context.Context,
in *lspdrpc.RegisterPaymentRequest,
) (*lspdrpc.RegisterPaymentReply, error) {
node, err := getNode(ctx)
node, token, err := s.getNode(ctx)
if err != nil {
return nil, err
}
@@ -275,7 +278,7 @@ func (s *server) RegisterPayment(
MaxClientToSelfDelay: pi.OpeningFeeParams.MaxClientToSelfDelay,
Promise: pi.OpeningFeeParams.Promise,
}
err = s.store.RegisterPayment(params, pi.Destination, pi.PaymentHash, pi.PaymentSecret, pi.IncomingAmountMsat, pi.OutgoingAmountMsat, pi.Tag)
err = s.store.RegisterPayment(token, params, pi.Destination, pi.PaymentHash, pi.PaymentSecret, pi.IncomingAmountMsat, pi.OutgoingAmountMsat, pi.Tag)
if err != nil {
log.Printf("RegisterPayment() error: %v", err)
return nil, fmt.Errorf("RegisterPayment() error: %w", err)
@@ -284,7 +287,7 @@ func (s *server) RegisterPayment(
}
func (s *server) OpenChannel(ctx context.Context, in *lspdrpc.OpenChannelRequest) (*lspdrpc.OpenChannelReply, error) {
node, err := getNode(ctx)
node, _, err := s.getNode(ctx)
if err != nil {
return nil, err
}
@@ -368,7 +371,7 @@ func (n *node) getSignedEncryptedData(in *lspdrpc.Encrypted) (string, []byte, bo
}
func (s *server) CheckChannels(ctx context.Context, in *lspdrpc.Encrypted) (*lspdrpc.Encrypted, error) {
node, err := getNode(ctx)
node, _, err := s.getNode(ctx)
if err != nil {
return nil, err
}
@@ -473,12 +476,14 @@ func NewGrpcServer(
}
}
_, exists := nodes[config.Token]
if exists {
return nil, fmt.Errorf("cannot have multiple nodes with the same token")
}
for _, token := range config.Tokens {
_, exists := nodes[token]
if exists {
return nil, fmt.Errorf("cannot have multiple nodes with the same token")
}
nodes[config.Token] = node
nodes[token] = node
}
}
return &server{
@@ -534,12 +539,12 @@ func (s *server) Start() error {
}
token := strings.Replace(auth, "Bearer ", "", 1)
node, ok := s.nodes[token]
_, ok := s.nodes[token]
if !ok {
continue
}
return handler(context.WithValue(ctx, "node", node), req)
return handler(context.WithValue(ctx, contextKey("token"), token), req)
}
}
return nil, status.Errorf(codes.PermissionDenied, "Not authorized")
@@ -563,18 +568,21 @@ func (s *server) Stop() {
}
}
func getNode(ctx context.Context) (*node, error) {
n := ctx.Value("node")
if n == nil {
return nil, status.Errorf(codes.PermissionDenied, "Not authorized")
func (s *server) getNode(ctx context.Context) (*node, string, error) {
tok := ctx.Value(contextKey("token"))
if tok == nil {
return nil, "", status.Errorf(codes.PermissionDenied, "Not authorized")
}
node, ok := n.(*node)
if !ok || node == nil {
return nil, status.Errorf(codes.PermissionDenied, "Not authorized")
token, ok := tok.(string)
if !ok {
return nil, "", status.Errorf(codes.PermissionDenied, "Not authorized")
}
return node, nil
node, ok := s.nodes[token]
if !ok {
return nil, "", status.Errorf(codes.PermissionDenied, "Not authorized")
}
return node, token, nil
}
func checkPayment(params *lspdrpc.OpeningFeeParams, incomingAmountMsat, outgoingAmountMsat int64) error {