diff --git a/.github/workflows/integration_tests.yaml b/.github/workflows/integration_tests.yaml index 6650e79..bdb8acb 100644 --- a/.github/workflows/integration_tests.yaml +++ b/.github/workflows/integration_tests.yaml @@ -88,6 +88,7 @@ jobs: max-parallel: 4 matrix: test: [ + testRestartLspNode, testOpenZeroConfChannelOnReceive, testOpenZeroConfSingleHtlc, testZeroReserve, diff --git a/cln/cln_client.go b/cln/cln_client.go index 4cb08c9..49e1da8 100644 --- a/cln/cln_client.go +++ b/cln/cln_client.go @@ -6,6 +6,7 @@ import ( "log" "path/filepath" "strings" + "sync" "time" "github.com/breez/lspd/lightning" @@ -17,7 +18,9 @@ import ( ) type ClnClient struct { - client *glightning.Lightning + socketPath string + client *glightning.Lightning + mtx sync.Mutex } var ( @@ -28,6 +31,17 @@ var ( ) func NewClnClient(socketPath string) (*ClnClient, error) { + client, err := newGlightningClient(socketPath) + if err != nil { + return nil, err + } + return &ClnClient{ + socketPath: socketPath, + client: client, + }, nil +} + +func newGlightningClient(socketPath string) (*glightning.Lightning, error) { rpcFile := filepath.Base(socketPath) if rpcFile == "" || rpcFile == "." { return nil, fmt.Errorf("invalid socketPath '%s'", socketPath) @@ -39,14 +53,35 @@ func NewClnClient(socketPath string) (*ClnClient, error) { client := glightning.NewLightning() client.SetTimeout(60) - client.StartUp(rpcFile, lightningDir) - return &ClnClient{ - client: client, - }, nil + err := client.StartUp(rpcFile, lightningDir) + return client, err +} + +func (c *ClnClient) getClient() (*glightning.Lightning, error) { + c.mtx.Lock() + defer c.mtx.Unlock() + if c.client.IsUp() { + return c.client, nil + } + + var err error + c.client, err = newGlightningClient(c.socketPath) + if err != nil { + return nil, err + } + if c.client.IsUp() { + return c.client, nil + } + + return nil, fmt.Errorf("cln is not accessible") } func (c *ClnClient) GetInfo() (*lightning.GetInfoResult, error) { - info, err := c.client.GetInfo() + client, err := c.getClient() + if err != nil { + return nil, err + } + info, err := client.GetInfo() if err != nil { log.Printf("CLN: client.GetInfo() error: %v", err) return nil, err @@ -59,8 +94,13 @@ func (c *ClnClient) GetInfo() (*lightning.GetInfoResult, error) { } func (c *ClnClient) IsConnected(destination []byte) (bool, error) { + client, err := c.getClient() + if err != nil { + return false, err + } + pubKey := hex.EncodeToString(destination) - peer, err := c.client.GetPeer(pubKey) + peer, err := client.GetPeer(pubKey) if err != nil { if strings.Contains(err.Error(), "not found") { return false, nil @@ -80,6 +120,11 @@ func (c *ClnClient) IsConnected(destination []byte) (bool, error) { } func (c *ClnClient) OpenChannel(req *lightning.OpenChannelRequest) (*wire.OutPoint, error) { + client, err := c.getClient() + if err != nil { + return nil, err + } + pubkey := hex.EncodeToString(req.Destination) var minConfs *uint16 if req.MinConfs != nil { @@ -114,7 +159,7 @@ func (c *ClnClient) OpenChannel(req *lightning.OpenChannelRequest) (*wire.OutPoi } } - fundResult, err := c.client.FundChannelExt( + fundResult, err := client.FundChannelExt( pubkey, glightning.NewSat(int(req.CapacitySat)), rate, @@ -150,8 +195,13 @@ func (c *ClnClient) OpenChannel(req *lightning.OpenChannelRequest) (*wire.OutPoi } func (c *ClnClient) GetChannel(peerID []byte, channelPoint wire.OutPoint) (*lightning.GetChannelResult, error) { + client, err := c.getClient() + if err != nil { + return nil, err + } + pubkey := hex.EncodeToString(peerID) - channels, err := c.client.GetPeerChannels(pubkey) + channels, err := client.GetPeerChannels(pubkey) if err != nil { log.Printf("CLN: client.GetPeer(%s) error: %v", pubkey, err) return nil, err @@ -184,8 +234,13 @@ func (c *ClnClient) GetChannel(peerID []byte, channelPoint wire.OutPoint) (*ligh } func (c *ClnClient) GetNodeChannelCount(nodeID []byte) (int, error) { + client, err := c.getClient() + if err != nil { + return 0, err + } + pubkey := hex.EncodeToString(nodeID) - channels, err := c.client.GetPeerChannels(pubkey) + channels, err := client.GetPeerChannels(pubkey) if err != nil { log.Printf("CLN: client.GetPeer(%s) error: %v", pubkey, err) return 0, err @@ -203,12 +258,17 @@ func (c *ClnClient) GetNodeChannelCount(nodeID []byte) (int, error) { } func (c *ClnClient) GetClosedChannels(nodeID string, channelPoints map[string]uint64) (map[string]uint64, error) { + client, err := c.getClient() + if err != nil { + return nil, err + } + r := make(map[string]uint64) if len(channelPoints) == 0 { return r, nil } - channels, err := c.client.GetPeerChannels(nodeID) + channels, err := client.GetPeerChannels(nodeID) if err != nil { log.Printf("CLN: client.GetPeer(%s) error: %v", nodeID, err) return nil, err @@ -239,8 +299,13 @@ func (c *ClnClient) GetClosedChannels(nodeID string, channelPoints map[string]ui } func (c *ClnClient) GetPeerId(scid *lightning.ShortChannelID) ([]byte, error) { + client, err := c.getClient() + if err != nil { + return nil, err + } + scidStr := scid.ToString() - channels, err := c.client.ListPeerChannels() + channels, err := client.ListPeerChannels() if err != nil { return nil, err } @@ -265,9 +330,14 @@ func (c *ClnClient) GetPeerId(scid *lightning.ShortChannelID) ([]byte, error) { var pollingInterval = 400 * time.Millisecond func (c *ClnClient) WaitOnline(peerID []byte, deadline time.Time) error { + client, err := c.getClient() + if err != nil { + return err + } + peerIDStr := hex.EncodeToString(peerID) for { - peer, err := c.client.GetPeer(peerIDStr) + peer, err := client.GetPeer(peerIDStr) if err == nil && peer.Connected { return nil } diff --git a/cln/custom_msg_client.go b/cln/custom_msg_client.go index 79121f4..9f3bb6c 100644 --- a/cln/custom_msg_client.go +++ b/cln/custom_msg_client.go @@ -161,7 +161,11 @@ func (c *CustomMsgClient) Send(msg *lightning.CustomMessage) error { binary.BigEndian.PutUint16(t[:], uint16(msg.Type)) m := hex.EncodeToString(t[:]) + hex.EncodeToString(msg.Data) - _, err := c.client.client.SendCustomMessage(msg.PeerId, m) + client, err := c.client.getClient() + if err != nil { + return err + } + _, err = client.SendCustomMessage(msg.PeerId, m) return err } diff --git a/itest/lspd_test.go b/itest/lspd_test.go index 0b00e1a..87bf004 100644 --- a/itest/lspd_test.go +++ b/itest/lspd_test.go @@ -226,4 +226,8 @@ var allTestCases = []*testCase{ isLsps2: true, skipCreateLsp: true, }, + { + name: "testRestartLspNode", + test: testRestartLspNode, + }, } diff --git a/itest/restart_lsp_node_test.go b/itest/restart_lsp_node_test.go new file mode 100644 index 0000000..b608cf8 --- /dev/null +++ b/itest/restart_lsp_node_test.go @@ -0,0 +1,75 @@ +package itest + +import ( + "log" + "time" + + "github.com/breez/lntest" + lspd "github.com/breez/lspd/rpc" + "github.com/stretchr/testify/assert" +) + +func testRestartLspNode(p *testParams) { + alice := lntest.NewClnNode(p.h, p.m, "Alice") + alice.Start() + alice.Fund(10000000) + p.lsp.LightningNode().Fund(10000000) + + log.Print("Opening channel between Alice and the lsp") + channel := alice.OpenChannel(p.lsp.LightningNode(), &lntest.OpenChannelOptions{ + AmountSat: publicChanAmount, + }) + alice.WaitForChannelReady(channel) + + log.Printf("Adding bob's invoices") + outerAmountMsat := uint64(2100000) + innerAmountMsat := calculateInnerAmountMsat(p.lsp, outerAmountMsat, nil) + description := "Please pay me" + innerInvoice, outerInvoice := GenerateInvoices(p.BreezClient(), + generateInvoicesRequest{ + innerAmountMsat: innerAmountMsat, + outerAmountMsat: outerAmountMsat, + description: description, + lsp: p.lsp, + }) + p.BreezClient().SetHtlcAcceptor(innerAmountMsat) + + log.Print("Connecting bob to lspd") + p.BreezClient().Node().ConnectPeer(p.lsp.LightningNode()) + + log.Printf("Registering payment with lsp") + RegisterPayment(p.lsp, &lspd.PaymentInformation{ + PaymentHash: innerInvoice.paymentHash, + PaymentSecret: innerInvoice.paymentSecret, + Destination: p.BreezClient().Node().NodeId(), + IncomingAmountMsat: int64(outerAmountMsat), + OutgoingAmountMsat: int64(innerAmountMsat), + }, false) + + log.Printf("stopping lsp lightning node") + p.lsp.LightningNode().Stop() + log.Printf("waiting %v to allow lsp lightning node to stop completely", htlcInterceptorDelay) + <-time.After(htlcInterceptorDelay) + log.Printf("starting lsp lightning node again") + p.lsp.LightningNode().Start() + + // TODO: Fix race waiting for htlc interceptor. + log.Printf("Waiting %v to allow htlc interceptor to activate.", htlcInterceptorDelay) + <-time.After(htlcInterceptorDelay) + + log.Printf("Connect Bob to LSP again") + p.BreezClient().Node().ConnectPeer(p.lsp.LightningNode()) + + log.Printf("Alice paying") + payResp := alice.Pay(outerInvoice.bolt11) + bobInvoice := p.BreezClient().Node().GetInvoice(payResp.PaymentHash) + + assert.Equal(p.t, payResp.PaymentPreimage, bobInvoice.PaymentPreimage) + assert.Equal(p.t, innerAmountMsat, bobInvoice.AmountReceivedMsat) + + // Make sure capacity is correct + chans := p.BreezClient().Node().GetChannels() + assert.Equal(p.t, 1, len(chans)) + c := chans[0] + AssertChannelCapacity(p.t, outerAmountMsat, c.CapacityMsat) +}