diff --git a/cln_client.go b/cln_client.go index cc9f088..c59b50e 100644 --- a/cln_client.go +++ b/cln_client.go @@ -4,6 +4,7 @@ import ( "encoding/hex" "fmt" "log" + "os" "github.com/breez/lspd/basetypes" "github.com/btcsuite/btcd/chaincfg/chainhash" @@ -23,7 +24,9 @@ var ( CLOSED_STATUSES = []string{"CLOSED"} ) -func NewClnClient(rpcFile string, lightningDir string) *ClnClient { +func NewClnClient() *ClnClient { + rpcFile := os.Getenv("CLN_SOCKET_NAME") + lightningDir := os.Getenv("CLN_SOCKET_DIR") client := glightning.NewLightning() client.SetTimeout(60) client.StartUp(rpcFile, lightningDir) diff --git a/cln_interceptor.go b/cln_interceptor.go index 02f6903..ab082fd 100644 --- a/cln_interceptor.go +++ b/cln_interceptor.go @@ -2,56 +2,159 @@ package main import ( "bytes" - "encoding/hex" + "context" "fmt" "io" "log" "os" "sync" + "time" + "github.com/breez/lspd/cln_plugin" sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/record" "github.com/lightningnetwork/lnd/tlv" - "github.com/niftynei/glightning/glightning" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/status" ) type ClnHtlcInterceptor struct { - client *ClnClient - plugin *glightning.Plugin - initWg sync.WaitGroup + pluginAddress string + client *ClnClient + pluginClient cln_plugin.ClnPluginClient + initWg sync.WaitGroup + doneWg sync.WaitGroup + ctx context.Context + cancel context.CancelFunc } func NewClnHtlcInterceptor() *ClnHtlcInterceptor { - i := &ClnHtlcInterceptor{} + i := &ClnHtlcInterceptor{ + pluginAddress: os.Getenv("CLN_PLUGIN_ADDRESS"), + client: NewClnClient(), + } i.initWg.Add(1) return i } func (i *ClnHtlcInterceptor) Start() error { - //c-lightning plugin initiate - plugin := glightning.NewPlugin(i.onInit) - i.plugin = plugin - plugin.RegisterHooks(&glightning.Hooks{ - HtlcAccepted: i.OnHtlcAccepted, - }) - - err := plugin.Start(os.Stdin, os.Stdout) + ctx, cancel := context.WithCancel(context.Background()) + log.Printf("Dialing cln plugin on '%s'", i.pluginAddress) + conn, err := grpc.DialContext(ctx, i.pluginAddress, grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { - log.Printf("Plugin error: %v", err) + log.Printf("grpc.Dial error: %v", err) + cancel() return err } - return nil + i.pluginClient = cln_plugin.NewClnPluginClient(conn) + i.ctx = ctx + i.cancel = cancel + return i.intercept() +} + +func (i *ClnHtlcInterceptor) intercept() error { + inited := false + + defer func() { + if !inited { + i.initWg.Done() + } + log.Printf("CLN intercept(): stopping. Waiting for in-progress interceptions to complete.") + i.doneWg.Wait() + }() + + for { + if i.ctx.Err() != nil { + return i.ctx.Err() + } + + log.Printf("Connecting CLN HTLC interceptor.") + interceptorClient, err := i.pluginClient.HtlcStream(i.ctx) + if err != nil { + log.Printf("pluginClient.HtlcStream(): %v", err) + <-time.After(time.Second) + continue + } + + for { + if i.ctx.Err() != nil { + return i.ctx.Err() + } + + if !inited { + inited = true + i.initWg.Done() + } + + request, err := interceptorClient.Recv() + if err != nil { + // If it is just the error result of the context cancellation + // the we exit silently. + status, ok := status.FromError(err) + if ok && status.Code() == codes.Canceled { + log.Printf("Got code canceled. Break.") + break + } + + // Otherwise it an unexpected error, we fail the test. + log.Printf("unexpected error in interceptor.Recv() %v", err) + break + } + + log.Printf("correlationid: %v\nhtlc: %v\nchanID: %v\nincoming amount: %v\noutgoing amount: %v\nincoming expiry: %v\noutgoing expiry: %v\npaymentHash: %v\nonionBlob: %v\n\n", + request.Correlationid, + request.Htlc, + request.Onion.ShortChannelId, + request.Htlc.AmountMsat, //with fees + request.Onion.ForwardAmountMsat, + request.Htlc.CltvExpiryRelative, + request.Htlc.CltvExpiry, + request.Htlc.PaymentHash, + request, + ) + + i.doneWg.Add(1) + go func() { + interceptResult := intercept(request.Htlc.PaymentHash, request.Onion.ForwardAmountMsat, request.Htlc.CltvExpiry) + switch interceptResult.action { + case INTERCEPT_RESUME_WITH_ONION: + interceptorClient.Send(i.resumeWithOnion(request, interceptResult)) + case INTERCEPT_FAIL_HTLC_WITH_CODE: + interceptorClient.Send(&cln_plugin.HtlcResolution{ + Correlationid: request.Correlationid, + Outcome: &cln_plugin.HtlcResolution_Fail{ + Fail: &cln_plugin.HtlcFail{ + FailureMessage: i.mapFailureCode(interceptResult.failureCode), + }, + }, + }) + case INTERCEPT_RESUME: + fallthrough + default: + interceptorClient.Send(&cln_plugin.HtlcResolution{ + Correlationid: request.Correlationid, + Outcome: &cln_plugin.HtlcResolution_Continue{ + Continue: &cln_plugin.HtlcContinue{}, + }, + }) + } + + i.doneWg.Done() + }() + } + + <-time.After(time.Second) + } } func (i *ClnHtlcInterceptor) Stop() error { - plugin := i.plugin - if plugin != nil { - plugin.Stop() - } - + i.cancel() + i.doneWg.Wait() return nil } @@ -60,98 +163,57 @@ func (i *ClnHtlcInterceptor) WaitStarted() LightningClient { return i.client } -func (i *ClnHtlcInterceptor) onInit(plugin *glightning.Plugin, options map[string]glightning.Option, config *glightning.Config) { - log.Printf("successfully init'd! %v\n", config.RpcFile) - - //lightning server - clientcln := glightning.NewLightning() - clientcln.SetTimeout(60) - clientcln.StartUp(config.RpcFile, config.LightningDir) - - i.client = &ClnClient{ - client: clientcln, - } - - log.Printf("successfull clientcln.StartUp") - i.initWg.Done() -} - -func (i *ClnHtlcInterceptor) OnHtlcAccepted(event *glightning.HtlcAcceptedEvent) (*glightning.HtlcAcceptedResponse, error) { - log.Printf("htlc_accepted called\n") - onion := event.Onion - - log.Printf("htlc: %v\nchanID: %v\nincoming amount: %v\noutgoing amount: %v\nincoming expiry: %v\noutgoing expiry: %v\npaymentHash: %v\nonionBlob: %v\n\n", - event.Htlc, - onion.ShortChannelId, - event.Htlc.AmountMilliSatoshi, //with fees - onion.ForwardAmount, - event.Htlc.CltvExpiryRelative, - event.Htlc.CltvExpiry, - event.Htlc.PaymentHash, - onion, - ) - - // fail htlc in case payment hash is not valid. - paymentHashBytes, err := hex.DecodeString(event.Htlc.PaymentHash) - if err != nil { - log.Printf("hex.DecodeString(%v) error: %v", event.Htlc.PaymentHash, err) - return event.Fail(i.mapFailureCode(FAILURE_TEMPORARY_CHANNEL_FAILURE)), nil - } - - interceptResult := intercept(paymentHashBytes, onion.ForwardAmount, uint32(event.Htlc.CltvExpiry)) - switch interceptResult.action { - case INTERCEPT_RESUME_WITH_ONION: - return i.resumeWithOnion(event, interceptResult), nil - case INTERCEPT_FAIL_HTLC_WITH_CODE: - return event.Fail(i.mapFailureCode(interceptResult.failureCode)), nil - case INTERCEPT_RESUME: - fallthrough - default: - return event.Continue(), nil - } -} - -func (i *ClnHtlcInterceptor) resumeWithOnion(event *glightning.HtlcAcceptedEvent, interceptResult interceptResult) *glightning.HtlcAcceptedResponse { +func (i *ClnHtlcInterceptor) resumeWithOnion(request *cln_plugin.HtlcAccepted, interceptResult interceptResult) *cln_plugin.HtlcResolution { //decoding and encoding onion with alias in type 6 record. - newPayload, err := encodePayloadWithNextHop(event.Onion.Payload, interceptResult.channelId) + newPayload, err := encodePayloadWithNextHop(request.Onion.Payload, interceptResult.channelId) if err != nil { log.Printf("encodePayloadWithNextHop error: %v", err) - return event.Fail(i.mapFailureCode(FAILURE_TEMPORARY_CHANNEL_FAILURE)) + return &cln_plugin.HtlcResolution{ + Correlationid: request.Correlationid, + Outcome: &cln_plugin.HtlcResolution_Fail{ + Fail: &cln_plugin.HtlcFail{ + FailureMessage: i.mapFailureCode(FAILURE_TEMPORARY_CHANNEL_FAILURE), + }, + }, + } } chanId := lnwire.NewChanIDFromOutPoint(interceptResult.channelPoint) log.Printf("forwarding htlc to the destination node and a new private channel was opened") - return event.ContinueWith(chanId.String(), newPayload) + return &cln_plugin.HtlcResolution{ + Correlationid: request.Correlationid, + Outcome: &cln_plugin.HtlcResolution_ContinueWith{ + ContinueWith: &cln_plugin.HtlcContinueWith{ + ChannelId: chanId[:], + Payload: newPayload, + }, + }, + } } -func encodePayloadWithNextHop(payloadHex string, channelId uint64) (string, error) { - payload, err := hex.DecodeString(payloadHex) - if err != nil { - log.Printf("failed to decode types. error: %v", err) - return "", err - } +func encodePayloadWithNextHop(payload []byte, channelId uint64) ([]byte, error) { bufReader := bytes.NewBuffer(payload) var b [8]byte varInt, err := sphinx.ReadVarInt(bufReader, &b) if err != nil { - return "", fmt.Errorf("failed to read payload length %v: %v", payloadHex, err) + return nil, fmt.Errorf("failed to read payload length %x: %v", payload, err) } innerPayload := make([]byte, varInt) if _, err := io.ReadFull(bufReader, innerPayload[:]); err != nil { - return "", fmt.Errorf("failed to decode payload %x: %v", innerPayload[:], err) + return nil, fmt.Errorf("failed to decode payload %x: %v", innerPayload[:], err) } s, _ := tlv.NewStream() tlvMap, err := s.DecodeWithParsedTypes(bytes.NewReader(innerPayload)) if err != nil { - return "", fmt.Errorf("DecodeWithParsedTypes failed for %x: %v", innerPayload[:], err) + return nil, fmt.Errorf("DecodeWithParsedTypes failed for %x: %v", innerPayload[:], err) } tt := record.NewNextHopIDRecord(&channelId) buf := bytes.NewBuffer([]byte{}) if err := tt.Encode(buf); err != nil { - return "", fmt.Errorf("failed to encode nexthop %x: %v", innerPayload[:], err) + return nil, fmt.Errorf("failed to encode nexthop %x: %v", innerPayload[:], err) } uTlvMap := make(map[uint64][]byte) @@ -165,26 +227,26 @@ func encodePayloadWithNextHop(payloadHex string, channelId uint64) (string, erro tlvRecords := tlv.MapToRecords(uTlvMap) s, err = tlv.NewStream(tlvRecords...) if err != nil { - return "", fmt.Errorf("tlv.NewStream(%x) error: %v", tlvRecords, err) + return nil, fmt.Errorf("tlv.NewStream(%x) error: %v", tlvRecords, err) } var newPayloadBuf bytes.Buffer err = s.Encode(&newPayloadBuf) if err != nil { - return "", fmt.Errorf("encode error: %v", err) + return nil, fmt.Errorf("encode error: %v", err) } - return hex.EncodeToString(newPayloadBuf.Bytes()), nil + return newPayloadBuf.Bytes(), nil } -func (i *ClnHtlcInterceptor) mapFailureCode(original interceptFailureCode) string { +func (i *ClnHtlcInterceptor) mapFailureCode(original interceptFailureCode) []byte { switch original { case FAILURE_TEMPORARY_CHANNEL_FAILURE: - return "1007" + return []byte{0x10, 0x07} case FAILURE_TEMPORARY_NODE_FAILURE: - return "2002" + return []byte{0x20, 0x02} case FAILURE_INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS: - return "400F" + return []byte{0x40, 0x0F} default: log.Printf("Unknown failure code %v, default to temporary channel failure.", original) - return "1007" // temporary channel failure + return []byte{0x10, 0x07} // temporary channel failure } } diff --git a/go.mod b/go.mod index d667517..65f21e4 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.19 require ( github.com/aws/aws-sdk-go v1.30.20 - github.com/breez/lntest v0.0.16 + github.com/breez/lntest v0.0.17 github.com/btcsuite/btcd v0.23.3 github.com/btcsuite/btcd/btcec/v2 v2.2.1 github.com/btcsuite/btcd/chaincfg/chainhash v1.0.1 @@ -163,7 +163,7 @@ require ( golang.org/x/text v0.4.0 // indirect golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba // indirect google.golang.org/genproto v0.0.0-20210617175327-b9e0b3197ced // indirect - google.golang.org/protobuf v1.27.1 // indirect + google.golang.org/protobuf v1.27.1 gopkg.in/errgo.v1 v1.0.1 // indirect gopkg.in/macaroon-bakery.v2 v2.0.1 // indirect gopkg.in/macaroon.v2 v2.0.0 // indirect diff --git a/itest/cln_lspd_node.go b/itest/cln_lspd_node.go index 8290fdb..cdbfdc8 100644 --- a/itest/cln_lspd_node.go +++ b/itest/cln_lspd_node.go @@ -1,7 +1,12 @@ package itest import ( + "flag" "fmt" + "log" + "os" + "os/exec" + "path/filepath" "sync" "github.com/breez/lntest" @@ -12,28 +17,42 @@ import ( "google.golang.org/grpc/credentials/insecure" ) +var clnPluginExec = flag.String( + "clnpluginexec", "", "full path to cln plugin wrapper binary", +) + type ClnLspNode struct { harness *lntest.TestHarness lightningNode *lntest.ClnNode lspBase *lspBase + logFilePath string runtime *clnLspNodeRuntime isInitialized bool mtx sync.Mutex + pluginBinary string + pluginFile string + pluginAddress string } type clnLspNodeRuntime struct { + logFile *os.File + cmd *exec.Cmd rpc lspd.ChannelOpenerClient cleanups []*lntest.Cleanup } func NewClnLspdNode(h *lntest.TestHarness, m *lntest.Miner, name string) LspNode { - lspbase, err := newLspd(h, name, "RUN_CLN=true") + scriptDir := h.GetDirectory("lspd") + pluginFile := filepath.Join(scriptDir, "htlc.sh") + pluginBinary := *clnPluginExec + pluginPort, err := lntest.GetPort() if err != nil { - h.T.Fatalf("failed to initialize lspd") + h.T.Fatalf("failed to get port for the htlc interceptor plugin.") } + pluginAddress := fmt.Sprintf("127.0.0.1:%d", pluginPort) args := []string{ - fmt.Sprintf("--plugin=%s", lspbase.scriptFilePath), + fmt.Sprintf("--plugin=%s", pluginFile), fmt.Sprintf("--fee-base=%d", lspBaseFeeMsat), fmt.Sprintf("--fee-per-satoshi=%d", lspFeeRatePpm), fmt.Sprintf("--cltv-delta=%d", lspCltvDelta), @@ -41,11 +60,27 @@ func NewClnLspdNode(h *lntest.TestHarness, m *lntest.Miner, name string) LspNode "--dev-allowdustreserve=true", } lightningNode := lntest.NewClnNode(h, m, name, args...) + lspbase, err := newLspd(h, name, + "RUN_CLN=true", + fmt.Sprintf("CLN_PLUGIN_ADDRESS=%s", pluginAddress), + fmt.Sprintf("CLN_SOCKET_DIR=%s", lightningNode.SocketDir()), + fmt.Sprintf("CLN_SOCKET_NAME=%s", lightningNode.SocketFile()), + ) + if err != nil { + h.T.Fatalf("failed to initialize lspd") + } + + logFilePath := filepath.Join(scriptDir, "lspd.log") + h.RegisterLogfile(logFilePath, fmt.Sprintf("lspd-%s", name)) lspNode := &ClnLspNode{ harness: h, lightningNode: lightningNode, + logFilePath: logFilePath, lspBase: lspbase, + pluginBinary: pluginBinary, + pluginFile: pluginFile, + pluginAddress: pluginAddress, } h.AddStoppable(lspNode) @@ -67,6 +102,16 @@ func (c *ClnLspNode) Start() { Name: fmt.Sprintf("%s: lsp base", c.lspBase.name), Fn: c.lspBase.Stop, }) + + pluginContent := fmt.Sprintf(`#!/bin/bash +export LISTEN_ADDRESS=%s +%s`, c.pluginAddress, c.pluginBinary) + + err = os.WriteFile(c.pluginFile, []byte(pluginContent), 0755) + if err != nil { + lntest.PerformCleanup(cleanups) + c.harness.T.Fatalf("failed create lsp cln plugin file: %v", err) + } } c.lightningNode.Start() @@ -74,6 +119,51 @@ func (c *ClnLspNode) Start() { Name: fmt.Sprintf("%s: lightning node", c.lspBase.name), Fn: c.lightningNode.Stop, }) + + cmd := exec.CommandContext(c.harness.Ctx, c.lspBase.scriptFilePath) + logFile, err := os.Create(c.logFilePath) + if err != nil { + lntest.PerformCleanup(cleanups) + c.harness.T.Fatalf("failed create lsp logfile: %v", err) + } + cleanups = append(cleanups, &lntest.Cleanup{ + Name: fmt.Sprintf("%s: logfile", c.lspBase.name), + Fn: logFile.Close, + }) + + cmd.Stdout = logFile + cmd.Stderr = logFile + + log.Printf("%s: starting lspd %s", c.lspBase.name, c.lspBase.scriptFilePath) + err = cmd.Start() + if err != nil { + lntest.PerformCleanup(cleanups) + c.harness.T.Fatalf("failed to start lspd: %v", err) + } + cleanups = append(cleanups, &lntest.Cleanup{ + Name: fmt.Sprintf("%s: cmd", c.lspBase.name), + Fn: func() error { + proc := cmd.Process + if proc == nil { + return nil + } + + proc.Kill() + + log.Printf("About to wait for lspd to exit") + status, err := proc.Wait() + if err != nil { + log.Printf("waiting for lspd process error: %v, status: %v", err, status) + } + err = cmd.Wait() + if err != nil { + log.Printf("waiting for lspd cmd error: %v", err) + } + + return nil + }, + }) + conn, err := grpc.Dial( c.lspBase.grpcAddress, grpc.WithTransportCredentials(insecure.NewCredentials()), @@ -81,11 +171,17 @@ func (c *ClnLspNode) Start() { ) if err != nil { lntest.PerformCleanup(cleanups) - c.harness.T.Fatalf("%s: failed to create grpc connection: %v", c.lspBase.name, err) + c.harness.T.Fatalf("failed to create grpc connection: %v", err) } + cleanups = append(cleanups, &lntest.Cleanup{ + Name: fmt.Sprintf("%s: grpc conn", c.lspBase.name), + Fn: conn.Close, + }) client := lspd.NewChannelOpenerClient(conn) c.runtime = &clnLspNodeRuntime{ + logFile: logFile, + cmd: cmd, rpc: client, cleanups: cleanups, } diff --git a/sample.env b/sample.env index 901d9e0..3dfb15f 100644 --- a/sample.env +++ b/sample.env @@ -30,4 +30,7 @@ LND_MACAROON_HEX= RUN_LND=true # CLN specific environment variables +CLN_PLUGIN_ADDRESS=
+CLN_SOCKET_DIR= +CLN_SOCKET_NAME= RUN_CLN=true