diff --git a/cln_client.go b/cln_client.go index c59b50e..04886ae 100644 --- a/cln_client.go +++ b/cln_client.go @@ -4,7 +4,7 @@ import ( "encoding/hex" "fmt" "log" - "os" + "path/filepath" "github.com/breez/lspd/basetypes" "github.com/btcsuite/btcd/chaincfg/chainhash" @@ -24,15 +24,22 @@ var ( CLOSED_STATUSES = []string{"CLOSED"} ) -func NewClnClient() *ClnClient { - rpcFile := os.Getenv("CLN_SOCKET_NAME") - lightningDir := os.Getenv("CLN_SOCKET_DIR") +func NewClnClient(socketPath string) (*ClnClient, error) { + rpcFile := filepath.Base(socketPath) + if rpcFile == "" || rpcFile == "." { + return nil, fmt.Errorf("invalid socketPath '%s'", socketPath) + } + lightningDir := filepath.Dir(socketPath) + if lightningDir == "" || lightningDir == "." { + return nil, fmt.Errorf("invalid socketPath '%s'", socketPath) + } + client := glightning.NewLightning() client.SetTimeout(60) client.StartUp(rpcFile, lightningDir) return &ClnClient{ client: client, - } + }, nil } func (c *ClnClient) GetInfo() (*GetInfoResult, error) { diff --git a/cln_interceptor.go b/cln_interceptor.go index c3a3cc0..083ce7b 100644 --- a/cln_interceptor.go +++ b/cln_interceptor.go @@ -7,7 +7,6 @@ import ( "fmt" "io" "log" - "os" "sync" "time" @@ -24,6 +23,7 @@ import ( ) type ClnHtlcInterceptor struct { + config *NodeConfig pluginAddress string client *ClnClient pluginClient proto.ClnPluginClient @@ -33,14 +33,23 @@ type ClnHtlcInterceptor struct { cancel context.CancelFunc } -func NewClnHtlcInterceptor() *ClnHtlcInterceptor { +func NewClnHtlcInterceptor(conf *NodeConfig) (*ClnHtlcInterceptor, error) { + if conf.Cln == nil { + return nil, fmt.Errorf("missing cln config") + } + + client, err := NewClnClient(conf.Cln.SocketPath) + if err != nil { + return nil, err + } i := &ClnHtlcInterceptor{ - pluginAddress: os.Getenv("CLN_PLUGIN_ADDRESS"), - client: NewClnClient(), + config: conf, + pluginAddress: conf.Cln.PluginAddress, + client: client, } i.initWg.Add(1) - return i + return i, nil } func (i *ClnHtlcInterceptor) Start() error { @@ -135,7 +144,7 @@ func (i *ClnHtlcInterceptor) intercept() error { interceptorClient.Send(i.defaultResolution(request)) i.doneWg.Done() } - interceptResult := intercept(paymentHash, request.Onion.ForwardMsat, request.Htlc.CltvExpiry) + interceptResult := intercept(i.client, i.config, paymentHash, request.Onion.ForwardMsat, request.Htlc.CltvExpiry) switch interceptResult.action { case INTERCEPT_RESUME_WITH_ONION: interceptorClient.Send(i.resumeWithOnion(request, interceptResult)) @@ -165,9 +174,8 @@ func (i *ClnHtlcInterceptor) Stop() error { return nil } -func (i *ClnHtlcInterceptor) WaitStarted() LightningClient { +func (i *ClnHtlcInterceptor) WaitStarted() { i.initWg.Wait() - return i.client } func (i *ClnHtlcInterceptor) resumeWithOnion(request *proto.HtlcAccepted, interceptResult interceptResult) *proto.HtlcResolution { diff --git a/config.go b/config.go new file mode 100644 index 0000000..a367682 --- /dev/null +++ b/config.go @@ -0,0 +1,32 @@ +package main + +type NodeConfig struct { + LspdPrivateKey string `json:"lspdPrivateKey"` + Token string `json:"token"` + Host string `json:"host"` + PublicChannelAmount int64 `json:"publicChannelAmount,string"` + ChannelAmount uint64 `json:"channelAmount,string"` + ChannelPrivate bool `json:"channelPrivate"` + TargetConf uint32 `json:"targetConf,string"` + MinHtlcMsat uint64 `json:"minHtlcMsat,string"` + BaseFeeMsat uint64 `json:"baseFeeMsat,string"` + FeeRate float64 `json:"feeRate,string"` + TimeLockDelta uint32 `json:"timeLockDelta,string"` + ChannelFeePermyriad int64 `json:"channelFeePermyriad,string"` + ChannelMinimumFeeMsat int64 `json:"channelMinimumFeeMsat,string"` + AdditionalChannelCapacity int64 `json:"additionalChannelCapacity,string"` + MaxInactiveDuration uint64 `json:"maxInactiveDuration,string"` + Lnd *LndConfig `json:"lnd,omitempty"` + Cln *ClnConfig `json:"cln,omitempty"` +} + +type LndConfig struct { + Address string `json:"address"` + Cert string `json:"cert"` + Macaroon string `json:"macaroon"` +} + +type ClnConfig struct { + PluginAddress string `json:"pluginAddress"` + SocketPath string `json:"socketPath"` +} diff --git a/db.go b/db.go index b68cbf5..f982d69 100644 --- a/db.go +++ b/db.go @@ -5,7 +5,6 @@ import ( "encoding/hex" "fmt" "log" - "os" "time" "github.com/btcsuite/btcd/wire" @@ -18,11 +17,11 @@ var ( pgxPool *pgxpool.Pool ) -func pgConnect() error { +func pgConnect(databaseUrl string) error { var err error - pgxPool, err = pgxpool.Connect(context.Background(), os.Getenv("DATABASE_URL")) + pgxPool, err = pgxpool.Connect(context.Background(), databaseUrl) if err != nil { - return fmt.Errorf("pgxpool.Connect(%v): %w", os.Getenv("DATABASE_URL"), err) + return fmt.Errorf("pgxpool.Connect(%v): %w", databaseUrl, err) } return nil } diff --git a/htlc_interceptor.go b/htlc_interceptor.go index 356463b..e572f0d 100644 --- a/htlc_interceptor.go +++ b/htlc_interceptor.go @@ -3,5 +3,5 @@ package main type HtlcInterceptor interface { Start() error Stop() error - WaitStarted() LightningClient + WaitStarted() } diff --git a/intercept.go b/intercept.go index 81ffea2..d15c861 100644 --- a/intercept.go +++ b/intercept.go @@ -45,7 +45,7 @@ type interceptResult struct { onionBlob []byte } -func intercept(reqPaymentHash []byte, reqOutgoingAmountMsat uint64, reqOutgoingExpiry uint32) interceptResult { +func intercept(client LightningClient, config *NodeConfig, reqPaymentHash []byte, reqOutgoingAmountMsat uint64, reqOutgoingExpiry uint32) interceptResult { reqPaymentHashStr := hex.EncodeToString(reqPaymentHash) resp, _, _ := payHashGroup.Do(reqPaymentHashStr, func() (interface{}, error) { paymentHash, paymentSecret, destination, incomingAmountMsat, outgoingAmountMsat, channelPoint, err := paymentInfo(reqPaymentHash) @@ -66,7 +66,7 @@ func intercept(reqPaymentHash []byte, reqOutgoingAmountMsat uint64, reqOutgoingE if channelPoint == nil { if bytes.Equal(paymentHash, reqPaymentHash) { - channelPoint, err = openChannel(client, reqPaymentHash, destination, incomingAmountMsat) + channelPoint, err = openChannel(client, config, reqPaymentHash, destination, incomingAmountMsat) if err != nil { log.Printf("openChannel(%x, %v) err: %v", destination, incomingAmountMsat, err) return interceptResult{ @@ -213,10 +213,10 @@ func intercept(reqPaymentHash []byte, reqOutgoingAmountMsat uint64, reqOutgoingE return resp.(interceptResult) } -func checkPayment(incomingAmountMsat, outgoingAmountMsat int64) error { - fees := incomingAmountMsat * channelFeePermyriad / 10_000 / 1_000 * 1_000 - if fees < channelMinimumFeeMsat { - fees = channelMinimumFeeMsat +func checkPayment(config *NodeConfig, incomingAmountMsat, outgoingAmountMsat int64) error { + fees := incomingAmountMsat * config.ChannelFeePermyriad / 10_000 / 1_000 * 1_000 + if fees < config.ChannelMinimumFeeMsat { + fees = config.ChannelMinimumFeeMsat } if incomingAmountMsat-outgoingAmountMsat < fees { return fmt.Errorf("not enough fees") @@ -224,9 +224,9 @@ func checkPayment(incomingAmountMsat, outgoingAmountMsat int64) error { return nil } -func openChannel(client LightningClient, paymentHash, destination []byte, incomingAmountMsat int64) (*wire.OutPoint, error) { - capacity := incomingAmountMsat/1000 + additionalChannelCapacity - if capacity == publicChannelAmount { +func openChannel(client LightningClient, config *NodeConfig, paymentHash, destination []byte, incomingAmountMsat int64) (*wire.OutPoint, error) { + capacity := incomingAmountMsat/1000 + config.AdditionalChannelCapacity + if capacity == config.PublicChannelAmount { capacity++ } channelPoint, err := client.OpenChannel(&OpenChannelRequest{ diff --git a/lnd_client.go b/lnd_client.go index f98054c..8bfe813 100644 --- a/lnd_client.go +++ b/lnd_client.go @@ -3,11 +3,10 @@ package main import ( "context" "crypto/x509" + "encoding/base64" "encoding/hex" "fmt" "log" - "os" - "strings" "github.com/breez/lspd/basetypes" "github.com/btcsuite/btcd/wire" @@ -17,7 +16,6 @@ import ( "github.com/lightningnetwork/lnd/lnrpc/routerrpc" "google.golang.org/grpc" "google.golang.org/grpc/credentials" - "google.golang.org/grpc/metadata" ) type LndClient struct { @@ -27,18 +25,28 @@ type LndClient struct { conn *grpc.ClientConn } -func NewLndClient() *LndClient { +func NewLndClient(conf *LndConfig) (*LndClient, error) { + cert, err := base64.StdEncoding.DecodeString(conf.Cert) + if err != nil { + return nil, fmt.Errorf("failed to decode cert: %w", err) + } + + _, err = hex.DecodeString(conf.Macaroon) + if err != nil { + return nil, fmt.Errorf("failed to decode macaroon: %w", err) + } + // Creds file to connect to LND gRPC cp := x509.NewCertPool() - if !cp.AppendCertsFromPEM([]byte(strings.Replace(os.Getenv("LND_CERT"), "\\n", "\n", -1))) { - log.Fatalf("credentials: failed to append certificates") + if !cp.AppendCertsFromPEM(cert) { + return nil, fmt.Errorf("credentials: failed to append certificates") } creds := credentials.NewClientTLSFromCert(cp, "") - macCred := NewMacaroonCredential(os.Getenv("LND_MACAROON_HEX")) + macCred := NewMacaroonCredential(conf.Macaroon) // Address of an LND instance conn, err := grpc.Dial( - os.Getenv("LND_ADDRESS"), + conf.Address, grpc.WithTransportCredentials(creds), grpc.WithPerRPCCredentials(macCred), ) @@ -54,7 +62,7 @@ func NewLndClient() *LndClient { routerClient: routerClient, chainNotifierClient: chainNotifierClient, conn: conn, - } + }, nil } func (c *LndClient) Close() { @@ -152,13 +160,12 @@ func (c *LndClient) GetChannel(peerID []byte, channelPoint wire.OutPoint) (*GetC func (c *LndClient) GetNodeChannelCount(nodeID []byte) (int, error) { nodeIDStr := hex.EncodeToString(nodeID) - clientCtx := metadata.AppendToOutgoingContext(context.Background(), "macaroon", os.Getenv("LND_MACAROON_HEX")) - listResponse, err := c.client.ListChannels(clientCtx, &lnrpc.ListChannelsRequest{}) + listResponse, err := c.client.ListChannels(context.Background(), &lnrpc.ListChannelsRequest{}) if err != nil { return 0, err } - pendingResponse, err := c.client.PendingChannels(clientCtx, &lnrpc.PendingChannelsRequest{}) + pendingResponse, err := c.client.PendingChannels(context.Background(), &lnrpc.PendingChannelsRequest{}) if err != nil { return 0, err } diff --git a/lnd_interceptor.go b/lnd_interceptor.go index a3775c4..3cc152c 100644 --- a/lnd_interceptor.go +++ b/lnd_interceptor.go @@ -14,6 +14,7 @@ import ( ) type LndHtlcInterceptor struct { + config *NodeConfig client *LndClient initWg sync.WaitGroup doneWg sync.WaitGroup @@ -21,14 +22,22 @@ type LndHtlcInterceptor struct { cancel context.CancelFunc } -func NewLndHtlcInterceptor() *LndHtlcInterceptor { +func NewLndHtlcInterceptor(conf *NodeConfig) (*LndHtlcInterceptor, error) { + if conf.Lnd == nil { + return nil, fmt.Errorf("missing lnd configuration") + } + client, err := NewLndClient(conf.Lnd) + if err != nil { + return nil, err + } i := &LndHtlcInterceptor{ - client: NewLndClient(), + config: conf, + client: client, } i.initWg.Add(1) - return i + return i, nil } func (i *LndHtlcInterceptor) Start() error { @@ -47,9 +56,8 @@ func (i *LndHtlcInterceptor) Stop() error { return nil } -func (i *LndHtlcInterceptor) WaitStarted() LightningClient { +func (i *LndHtlcInterceptor) WaitStarted() { i.initWg.Wait() - return i.client } func (i *LndHtlcInterceptor) intercept() error { @@ -113,7 +121,7 @@ func (i *LndHtlcInterceptor) intercept() error { i.doneWg.Add(1) go func() { - interceptResult := intercept(request.PaymentHash, request.OutgoingAmountMsat, request.OutgoingExpiry) + interceptResult := intercept(i.client, i.config, request.PaymentHash, request.OutgoingAmountMsat, request.OutgoingExpiry) switch interceptResult.action { case INTERCEPT_RESUME_WITH_ONION: interceptorClient.Send(&routerrpc.ForwardHtlcInterceptResponse{ diff --git a/main.go b/main.go index 208a6a2..f28ea75 100644 --- a/main.go +++ b/main.go @@ -1,6 +1,7 @@ package main import ( + "encoding/json" "fmt" "log" "os" @@ -21,59 +22,79 @@ func main() { return } - err := pgConnect() + n := os.Getenv("NODES") + var nodes []*NodeConfig + err := json.Unmarshal([]byte(n), &nodes) + if err != nil { + log.Fatalf("failed to unmarshal NODES env: %v", err) + } + + if len(nodes) == 0 { + log.Fatalf("need at least one node configured in NODES.") + } + + var interceptors []HtlcInterceptor + for _, node := range nodes { + var interceptor HtlcInterceptor + if node.Lnd != nil { + interceptor, err = NewLndHtlcInterceptor(node) + if err != nil { + log.Fatalf("failed to initialize LND interceptor: %v", err) + } + } + + if node.Cln != nil { + interceptor, err = NewClnHtlcInterceptor(node) + if err != nil { + log.Fatalf("failed to initialize CLN interceptor: %v", err) + } + } + + if interceptor == nil { + log.Fatalf("node has to be either cln or lnd") + } + + interceptors = append(interceptors, interceptor) + } + + address := os.Getenv("LISTEN_ADDRESS") + certMagicDomain := os.Getenv("CERTMAGIC_DOMAIN") + s, err := NewGrpcServer(nodes, address, certMagicDomain) + if err != nil { + log.Fatalf("failed to initialize grpc server: %v", err) + } + + databaseUrl := os.Getenv("DATABASE_URL") + err = pgConnect(databaseUrl) if err != nil { log.Fatalf("pgConnect() error: %v", err) } - runCln := os.Getenv("RUN_CLN") == "true" - runLnd := os.Getenv("RUN_LND") == "true" - - if runCln && runLnd { - log.Fatalf("One of RUN_CLN or RUN_LND must be true, not both.") - } - - if !runCln && !runLnd { - log.Fatalf("Either RUN_CLN or RUN_LND must be true.") - } - - var interceptor HtlcInterceptor - if runCln { - interceptor = NewClnHtlcInterceptor() - } - - if runLnd { - interceptor = NewLndHtlcInterceptor() - } - - s := NewGrpcServer() - var wg sync.WaitGroup - wg.Add(2) + wg.Add(len(interceptors) + 1) - go func() { - err := interceptor.Start() - if err == nil { - log.Printf("Interceptor stopped.") - } else { - log.Printf("FATAL. Interceptor stopped with error: %v", err) + stopInterceptors := func() { + for _, interceptor := range interceptors { + interceptor.Stop() } - s.Stop() - wg.Done() - }() - - client = interceptor.WaitStarted() - info, err := client.GetInfo() - if err != nil { - log.Fatalf("client.GetInfo() error: %v", err) } - log.Printf("Connected to node '%s', alias '%s'", info.Pubkey, info.Alias) - if nodeName == "" { - nodeName = info.Alias - } - if nodePubkey == "" { - nodePubkey = info.Pubkey + for _, interceptor := range interceptors { + i := interceptor + go func() { + err := i.Start() + if err == nil { + log.Printf("Interceptor stopped.") + } else { + log.Printf("FATAL. Interceptor stopped with error: %v", err) + } + + wg.Done() + + // If any interceptor stops, stop everything, so we're able to restart using systemd. + s.Stop() + stopInterceptors() + }() } go func() { @@ -84,8 +105,10 @@ func main() { log.Printf("FATAL. GRPC server stopped with error: %v", err) } - interceptor.Stop() wg.Done() + + // If the server stops, stop everything else, so we're able to restart using systemd. + stopInterceptors() }() c := make(chan os.Signal, 1) @@ -93,8 +116,10 @@ func main() { go func() { sig := <-c log.Printf("Received stop signal %v. Stopping.", sig) + + // Stop everything gracefully on stop signal s.Stop() - interceptor.Stop() + stopInterceptors() }() wg.Wait() diff --git a/sample.env b/sample.env index 3dfb15f..407136a 100644 --- a/sample.env +++ b/sample.env @@ -3,12 +3,6 @@ LISTEN_ADDRESS= ### a certificate from Let's Encrypt #CERTMAGIC_DOMAIN= -NODE_NAME= -NODE_PUBKEY= -NODE_HOST= - -TOKEN= -LSPD_PRIVATE_KEY= DATABASE_URL= AWS_REGION= @@ -23,14 +17,4 @@ CHANNELMISMATCH_NOTIFICATION_TO='["Name1 "]' CHANNELMISMATCH_NOTIFICATION_CC='["Name2 ","Name3 "]' CHANNELMISMATCH_NOTIFICATION_FROM="Name4 " -# LND specific environment variables -LND_ADDRESS= -LND_CERT= #replace each eol by \\n -LND_MACAROON_HEX= -RUN_LND=true - -# CLN specific environment variables -CLN_PLUGIN_ADDRESS=
-CLN_SOCKET_DIR= -CLN_SOCKET_NAME= -RUN_CLN=true +NODES='[ { "lspdPrivateKey": "", "token": "", "host": "", "publicChannelAmount": "1000183", "channelAmount": "100000", "channelPrivate": false, "targetConf": "6", "minHtlcMsat": "600", "baseFeeMsat": "1000", "feeRate": "0.000001", "timeLockDelta": "144", "channelFeePermyriad": "40", "channelMinimumFeeMsat": "2000000", "additionalChannelCapacity": "100000", "maxInactiveDuration": "3888000", "lnd": { "address": "", "cert": "", "macaroon": "" } }, { "lspdPrivateKey": "", "token": "", "host": "", "publicChannelAmount": "1000183", "channelAmount": "100000", "channelPrivate": false, "targetConf": "6", "minHtlcMsat": "600", "baseFeeMsat": "1000", "feeRate": "0.000001", "timeLockDelta": "144", "channelFeePermyriad": "40", "channelMinimumFeeMsat": "2000000", "additionalChannelCapacity": "100000", "maxInactiveDuration": "3888000", "cln": { "pluginAddress": "
", "socketPath": "" } } ]' diff --git a/server.go b/server.go index 03aaba5..cfae080 100644 --- a/server.go +++ b/server.go @@ -7,14 +7,14 @@ import ( "fmt" "log" "net" - "os" - "strconv" + "strings" "github.com/breez/lspd/btceclegacy" lspdrpc "github.com/breez/lspd/rpc" ecies "github.com/ecies/go/v2" "github.com/golang/protobuf/proto" grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" + "golang.org/x/sync/singleflight" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" @@ -25,61 +25,61 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/caddyserver/certmagic" "github.com/lightningnetwork/lnd/lnwire" - "golang.org/x/sync/singleflight" -) - -const ( - publicChannelAmount = 1_000_183 - targetConf = 6 - minHtlcMsat = 600 - baseFeeMsat = 1000 - feeRate = 0.000001 - timeLockDelta = 144 - channelFeePermyriad = 40 - channelMinimumFeeMsat = 2_000_000 - additionalChannelCapacity = 100_000 - maxInactiveDuration = 45 * 24 * 3600 ) type server struct { - lis net.Listener - s *grpc.Server + address string + certmagicDomain string + lis net.Listener + s *grpc.Server + nodes map[string]*node } -var ( +type node struct { client LightningClient - openChannelReqGroup singleflight.Group + nodeName string + nodePubkey string + nodeConfig *NodeConfig privateKey *btcec.PrivateKey publicKey *btcec.PublicKey eciesPrivateKey *ecies.PrivateKey eciesPublicKey *ecies.PublicKey - nodeName = os.Getenv("NODE_NAME") - nodePubkey = os.Getenv("NODE_PUBKEY") -) + openChannelReqGroup singleflight.Group +} func (s *server) ChannelInformation(ctx context.Context, in *lspdrpc.ChannelInformationRequest) (*lspdrpc.ChannelInformationReply, error) { + node, err := getNode(ctx) + if err != nil { + return nil, err + } + return &lspdrpc.ChannelInformationReply{ - Name: nodeName, - Pubkey: nodePubkey, - Host: os.Getenv("NODE_HOST"), - ChannelCapacity: publicChannelAmount, - TargetConf: targetConf, - MinHtlcMsat: minHtlcMsat, - BaseFeeMsat: baseFeeMsat, - FeeRate: feeRate, - TimeLockDelta: timeLockDelta, - ChannelFeePermyriad: channelFeePermyriad, - ChannelMinimumFeeMsat: channelMinimumFeeMsat, - LspPubkey: publicKey.SerializeCompressed(), - MaxInactiveDuration: maxInactiveDuration, + Name: node.nodeName, + Pubkey: node.nodePubkey, + Host: node.nodeConfig.Host, + ChannelCapacity: int64(node.nodeConfig.PublicChannelAmount), + TargetConf: int32(node.nodeConfig.TargetConf), + MinHtlcMsat: int64(node.nodeConfig.MinHtlcMsat), + BaseFeeMsat: int64(node.nodeConfig.BaseFeeMsat), + FeeRate: node.nodeConfig.FeeRate, + TimeLockDelta: node.nodeConfig.TimeLockDelta, + ChannelFeePermyriad: int64(node.nodeConfig.ChannelFeePermyriad), + ChannelMinimumFeeMsat: int64(node.nodeConfig.ChannelMinimumFeeMsat), + LspPubkey: node.publicKey.SerializeCompressed(), // TODO: Is the publicKey different from the ecies public key? + MaxInactiveDuration: int64(node.nodeConfig.MaxInactiveDuration), }, nil } func (s *server) RegisterPayment(ctx context.Context, in *lspdrpc.RegisterPaymentRequest) (*lspdrpc.RegisterPaymentReply, error) { - data, err := ecies.Decrypt(eciesPrivateKey, in.Blob) + node, err := getNode(ctx) + if err != nil { + return nil, err + } + + data, err := ecies.Decrypt(node.eciesPrivateKey, in.Blob) if err != nil { log.Printf("ecies.Decrypt(%x) error: %v", in.Blob, err) - data, err = btceclegacy.Decrypt(privateKey, in.Blob) + data, err = btceclegacy.Decrypt(node.privateKey, in.Blob) if err != nil { log.Printf("btcec.Decrypt(%x) error: %v", in.Blob, err) return nil, fmt.Errorf("btcec.Decrypt(%x) error: %w", in.Blob, err) @@ -94,7 +94,7 @@ func (s *server) RegisterPayment(ctx context.Context, in *lspdrpc.RegisterPaymen } log.Printf("RegisterPayment - Destination: %x, pi.PaymentHash: %x, pi.PaymentSecret: %x, pi.IncomingAmountMsat: %v, pi.OutgoingAmountMsat: %v", pi.Destination, pi.PaymentHash, pi.PaymentSecret, pi.IncomingAmountMsat, pi.OutgoingAmountMsat) - err = checkPayment(pi.IncomingAmountMsat, pi.OutgoingAmountMsat) + err = checkPayment(node.nodeConfig, pi.IncomingAmountMsat, pi.OutgoingAmountMsat) if err != nil { log.Printf("checkPayment(%v, %v) error: %v", pi.IncomingAmountMsat, pi.OutgoingAmountMsat, err) return nil, fmt.Errorf("checkPayment(%v, %v) error: %v", pi.IncomingAmountMsat, pi.OutgoingAmountMsat, err) @@ -108,36 +108,30 @@ func (s *server) RegisterPayment(ctx context.Context, in *lspdrpc.RegisterPaymen } func (s *server) OpenChannel(ctx context.Context, in *lspdrpc.OpenChannelRequest) (*lspdrpc.OpenChannelReply, error) { - r, err, _ := openChannelReqGroup.Do(in.Pubkey, func() (interface{}, error) { + node, err := getNode(ctx) + if err != nil { + return nil, err + } + + r, err, _ := node.openChannelReqGroup.Do(in.Pubkey, func() (interface{}, error) { pubkey, err := hex.DecodeString(in.Pubkey) if err != nil { return nil, err } - channelCount, err := client.GetNodeChannelCount(pubkey) + channelCount, err := node.client.GetNodeChannelCount(pubkey) if err != nil { return nil, err } - channelAmount, err := strconv.ParseInt(os.Getenv("CHANNEL_AMOUNT"), 0, 64) - if err != nil || channelAmount <= 0 { - channelAmount = publicChannelAmount - } - log.Printf("os.Getenv(\"CHANNEL_AMOUNT\"): %v, channelAmount: %v, publicChannelAmount: %v, err: %v", - os.Getenv("CHANNEL_AMOUNT"), channelAmount, publicChannelAmount, err) - isPrivate, err := strconv.ParseBool(os.Getenv("CHANNEL_PRIVATE")) - if err != nil { - isPrivate = false - } - log.Printf("os.Getenv(\"CHANNEL_PRIVATE\"): %v, isPrivate: %v, err: %v", - os.Getenv("CHANNEL_PRIVATE"), isPrivate, err) + var outPoint *wire.OutPoint if channelCount == 0 { - outPoint, err = client.OpenChannel(&OpenChannelRequest{ - CapacitySat: uint64(channelAmount), + outPoint, err = node.client.OpenChannel(&OpenChannelRequest{ + CapacitySat: node.nodeConfig.ChannelAmount, Destination: pubkey, - TargetConf: targetConf, - MinHtlcMsat: minHtlcMsat, - IsPrivate: isPrivate, + TargetConf: node.nodeConfig.TargetConf, + MinHtlcMsat: node.nodeConfig.MinHtlcMsat, + IsPrivate: node.nodeConfig.ChannelPrivate, }) if err != nil { @@ -157,13 +151,13 @@ func (s *server) OpenChannel(ctx context.Context, in *lspdrpc.OpenChannelRequest return r.(*lspdrpc.OpenChannelReply), err } -func getSignedEncryptedData(in *lspdrpc.Encrypted) (string, []byte, bool, error) { +func (n *node) getSignedEncryptedData(in *lspdrpc.Encrypted) (string, []byte, bool, error) { usedEcies := true - signedBlob, err := ecies.Decrypt(eciesPrivateKey, in.Data) + signedBlob, err := ecies.Decrypt(n.eciesPrivateKey, in.Data) if err != nil { log.Printf("ecies.Decrypt(%x) error: %v", in.Data, err) usedEcies = false - signedBlob, err = btceclegacy.Decrypt(privateKey, in.Data) + signedBlob, err = btceclegacy.Decrypt(n.privateKey, in.Data) if err != nil { log.Printf("btcec.Decrypt(%x) error: %v", in.Data, err) return "", nil, usedEcies, fmt.Errorf("btcec.Decrypt(%x) error: %w", in.Data, err) @@ -198,7 +192,12 @@ func getSignedEncryptedData(in *lspdrpc.Encrypted) (string, []byte, bool, error) } func (s *server) CheckChannels(ctx context.Context, in *lspdrpc.Encrypted) (*lspdrpc.Encrypted, error) { - nodeID, data, usedEcies, err := getSignedEncryptedData(in) + node, err := getNode(ctx) + if err != nil { + return nil, err + } + + nodeID, data, usedEcies, err := node.getSignedEncryptedData(in) if err != nil { log.Printf("getSignedEncryptedData error: %v", err) return nil, fmt.Errorf("getSignedEncryptedData error: %v", err) @@ -214,7 +213,7 @@ func (s *server) CheckChannels(ctx context.Context, in *lspdrpc.Encrypted) (*lsp log.Printf("getNotFakeChannels(%v) error: %v", checkChannelsRequest.FakeChannels, err) return nil, fmt.Errorf("getNotFakeChannels(%v) error: %w", checkChannelsRequest.FakeChannels, err) } - closedChannels, err := client.GetClosedChannels(nodeID, checkChannelsRequest.WaitingCloseChannels) + closedChannels, err := node.client.GetClosedChannels(nodeID, checkChannelsRequest.WaitingCloseChannels) if err != nil { log.Printf("GetClosedChannels(%v) error: %v", checkChannelsRequest.FakeChannels, err) return nil, fmt.Errorf("GetClosedChannels(%v) error: %w", checkChannelsRequest.FakeChannels, err) @@ -236,7 +235,7 @@ func (s *server) CheckChannels(ctx context.Context, in *lspdrpc.Encrypted) (*lsp var encrypted []byte if usedEcies { - encrypted, err = ecies.Encrypt(eciesPublicKey, dataReply) + encrypted, err = ecies.Encrypt(node.eciesPublicKey, dataReply) if err != nil { log.Printf("ecies.Encrypt() error: %v", err) return nil, fmt.Errorf("ecies.Encrypt() error: %w", err) @@ -269,35 +268,82 @@ func getNotFakeChannels(nodeID string, channelPoints map[string]uint64) (map[str return r, nil } -func NewGrpcServer() *server { - return &server{} +func NewGrpcServer(configs []*NodeConfig, address string, certmagicDomain string) (*server, error) { + if len(configs) == 0 { + return nil, fmt.Errorf("no nodes supplied") + } + + nodes := make(map[string]*node) + for _, config := range configs { + pk, err := hex.DecodeString(config.LspdPrivateKey) + if err != nil { + return nil, fmt.Errorf("hex.DecodeString(config.lspdPrivateKey=%v) error: %v", config.LspdPrivateKey, err) + } + + eciesPrivateKey := ecies.NewPrivateKeyFromBytes(pk) + eciesPublicKey := eciesPrivateKey.PublicKey + privateKey, publicKey := btcec.PrivKeyFromBytes(pk) + + // TODO: Set nodename & nodepubkey + node := &node{ + nodeConfig: config, + privateKey: privateKey, + publicKey: publicKey, + eciesPrivateKey: eciesPrivateKey, + eciesPublicKey: eciesPublicKey, + } + + if config.Lnd == nil && config.Cln == nil { + return nil, fmt.Errorf("node has to be either cln or lnd") + } + + if config.Lnd != nil && config.Cln != nil { + return nil, fmt.Errorf("node cannot be both cln and lnd") + } + + if config.Lnd != nil { + node.client, err = NewLndClient(config.Lnd) + if err != nil { + return nil, err + } + } + + if config.Cln != nil { + node.client, err = NewClnClient(config.Cln.SocketPath) + if err != nil { + return nil, err + } + } + + _, exists := nodes[config.Token] + if exists { + return nil, fmt.Errorf("cannot have multiple nodes with the same token") + } + + nodes[config.Token] = node + } + + return &server{ + address: address, + certmagicDomain: certmagicDomain, + nodes: nodes, + }, nil } func (s *server) Start() error { - pk, err := hex.DecodeString(os.Getenv("LSPD_PRIVATE_KEY")) - if err != nil { - log.Fatalf("hex.DecodeString(os.Getenv(\"LSPD_PRIVATE_KEY\")=%v) error: %v", os.Getenv("LSPD_PRIVATE_KEY"), err) - } - - eciesPrivateKey = ecies.NewPrivateKeyFromBytes(pk) - eciesPublicKey = eciesPrivateKey.PublicKey - privateKey, publicKey = btcec.PrivKeyFromBytes(pk) - - certmagicDomain := os.Getenv("CERTMAGIC_DOMAIN") - address := os.Getenv("LISTEN_ADDRESS") var lis net.Listener - if certmagicDomain == "" { + if s.certmagicDomain == "" { var err error - lis, err = net.Listen("tcp", address) + lis, err = net.Listen("tcp", s.address) if err != nil { log.Fatalf("failed to listen: %v", err) } } else { - tlsConfig, err := certmagic.TLS([]string{certmagicDomain}) + tlsConfig, err := certmagic.TLS([]string{s.certmagicDomain}) if err != nil { log.Fatalf("failed to run certmagic: %v", err) } - lis, err = tls.Listen("tcp", address, tlsConfig) + lis, err = tls.Listen("tcp", s.address, tlsConfig) if err != nil { log.Fatalf("failed to listen: %v", err) } @@ -307,9 +353,17 @@ func (s *server) Start() error { grpc_middleware.WithUnaryServerChain(func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { if md, ok := metadata.FromIncomingContext(ctx); ok { for _, auth := range md.Get("authorization") { - if auth == "Bearer "+os.Getenv("TOKEN") { - return handler(ctx, req) + if !strings.HasPrefix(auth, "Bearer ") { + continue } + + token := strings.Replace(auth, "Bearer ", "", 1) + node, ok := s.nodes[token] + if !ok { + continue + } + + return handler(context.WithValue(ctx, "node", node), req) } } return nil, status.Errorf(codes.PermissionDenied, "Not authorized") @@ -332,3 +386,17 @@ func (s *server) Stop() { srv.GracefulStop() } } + +func getNode(ctx context.Context) (*node, error) { + n := ctx.Value("node") + if n == nil { + return nil, status.Errorf(codes.PermissionDenied, "Not authorized") + } + + node, ok := n.(*node) + if !ok || node == nil { + return nil, status.Errorf(codes.PermissionDenied, "Not authorized") + } + + return node, nil +}