make cln client recoverable after cln connection loss

This commit is contained in:
Jesse de Wit
2023-12-11 22:16:06 +01:00
parent 8585eff280
commit d8bee41243
5 changed files with 168 additions and 14 deletions

View File

@@ -88,6 +88,7 @@ jobs:
max-parallel: 4
matrix:
test: [
testRestartLspNode,
testOpenZeroConfChannelOnReceive,
testOpenZeroConfSingleHtlc,
testZeroReserve,

View File

@@ -6,6 +6,7 @@ import (
"log"
"path/filepath"
"strings"
"sync"
"time"
"github.com/breez/lspd/lightning"
@@ -17,7 +18,9 @@ import (
)
type ClnClient struct {
client *glightning.Lightning
socketPath string
client *glightning.Lightning
mtx sync.Mutex
}
var (
@@ -28,6 +31,17 @@ var (
)
func NewClnClient(socketPath string) (*ClnClient, error) {
client, err := newGlightningClient(socketPath)
if err != nil {
return nil, err
}
return &ClnClient{
socketPath: socketPath,
client: client,
}, nil
}
func newGlightningClient(socketPath string) (*glightning.Lightning, error) {
rpcFile := filepath.Base(socketPath)
if rpcFile == "" || rpcFile == "." {
return nil, fmt.Errorf("invalid socketPath '%s'", socketPath)
@@ -39,14 +53,35 @@ func NewClnClient(socketPath string) (*ClnClient, error) {
client := glightning.NewLightning()
client.SetTimeout(60)
client.StartUp(rpcFile, lightningDir)
return &ClnClient{
client: client,
}, nil
err := client.StartUp(rpcFile, lightningDir)
return client, err
}
func (c *ClnClient) getClient() (*glightning.Lightning, error) {
c.mtx.Lock()
defer c.mtx.Unlock()
if c.client.IsUp() {
return c.client, nil
}
var err error
c.client, err = newGlightningClient(c.socketPath)
if err != nil {
return nil, err
}
if c.client.IsUp() {
return c.client, nil
}
return nil, fmt.Errorf("cln is not accessible")
}
func (c *ClnClient) GetInfo() (*lightning.GetInfoResult, error) {
info, err := c.client.GetInfo()
client, err := c.getClient()
if err != nil {
return nil, err
}
info, err := client.GetInfo()
if err != nil {
log.Printf("CLN: client.GetInfo() error: %v", err)
return nil, err
@@ -59,8 +94,13 @@ func (c *ClnClient) GetInfo() (*lightning.GetInfoResult, error) {
}
func (c *ClnClient) IsConnected(destination []byte) (bool, error) {
client, err := c.getClient()
if err != nil {
return false, err
}
pubKey := hex.EncodeToString(destination)
peer, err := c.client.GetPeer(pubKey)
peer, err := client.GetPeer(pubKey)
if err != nil {
if strings.Contains(err.Error(), "not found") {
return false, nil
@@ -80,6 +120,11 @@ func (c *ClnClient) IsConnected(destination []byte) (bool, error) {
}
func (c *ClnClient) OpenChannel(req *lightning.OpenChannelRequest) (*wire.OutPoint, error) {
client, err := c.getClient()
if err != nil {
return nil, err
}
pubkey := hex.EncodeToString(req.Destination)
var minConfs *uint16
if req.MinConfs != nil {
@@ -114,7 +159,7 @@ func (c *ClnClient) OpenChannel(req *lightning.OpenChannelRequest) (*wire.OutPoi
}
}
fundResult, err := c.client.FundChannelExt(
fundResult, err := client.FundChannelExt(
pubkey,
glightning.NewSat(int(req.CapacitySat)),
rate,
@@ -150,8 +195,13 @@ func (c *ClnClient) OpenChannel(req *lightning.OpenChannelRequest) (*wire.OutPoi
}
func (c *ClnClient) GetChannel(peerID []byte, channelPoint wire.OutPoint) (*lightning.GetChannelResult, error) {
client, err := c.getClient()
if err != nil {
return nil, err
}
pubkey := hex.EncodeToString(peerID)
channels, err := c.client.GetPeerChannels(pubkey)
channels, err := client.GetPeerChannels(pubkey)
if err != nil {
log.Printf("CLN: client.GetPeer(%s) error: %v", pubkey, err)
return nil, err
@@ -184,8 +234,13 @@ func (c *ClnClient) GetChannel(peerID []byte, channelPoint wire.OutPoint) (*ligh
}
func (c *ClnClient) GetNodeChannelCount(nodeID []byte) (int, error) {
client, err := c.getClient()
if err != nil {
return 0, err
}
pubkey := hex.EncodeToString(nodeID)
channels, err := c.client.GetPeerChannels(pubkey)
channels, err := client.GetPeerChannels(pubkey)
if err != nil {
log.Printf("CLN: client.GetPeer(%s) error: %v", pubkey, err)
return 0, err
@@ -203,12 +258,17 @@ func (c *ClnClient) GetNodeChannelCount(nodeID []byte) (int, error) {
}
func (c *ClnClient) GetClosedChannels(nodeID string, channelPoints map[string]uint64) (map[string]uint64, error) {
client, err := c.getClient()
if err != nil {
return nil, err
}
r := make(map[string]uint64)
if len(channelPoints) == 0 {
return r, nil
}
channels, err := c.client.GetPeerChannels(nodeID)
channels, err := client.GetPeerChannels(nodeID)
if err != nil {
log.Printf("CLN: client.GetPeer(%s) error: %v", nodeID, err)
return nil, err
@@ -239,8 +299,13 @@ func (c *ClnClient) GetClosedChannels(nodeID string, channelPoints map[string]ui
}
func (c *ClnClient) GetPeerId(scid *lightning.ShortChannelID) ([]byte, error) {
client, err := c.getClient()
if err != nil {
return nil, err
}
scidStr := scid.ToString()
channels, err := c.client.ListPeerChannels()
channels, err := client.ListPeerChannels()
if err != nil {
return nil, err
}
@@ -265,9 +330,14 @@ func (c *ClnClient) GetPeerId(scid *lightning.ShortChannelID) ([]byte, error) {
var pollingInterval = 400 * time.Millisecond
func (c *ClnClient) WaitOnline(peerID []byte, deadline time.Time) error {
client, err := c.getClient()
if err != nil {
return err
}
peerIDStr := hex.EncodeToString(peerID)
for {
peer, err := c.client.GetPeer(peerIDStr)
peer, err := client.GetPeer(peerIDStr)
if err == nil && peer.Connected {
return nil
}

View File

@@ -161,7 +161,11 @@ func (c *CustomMsgClient) Send(msg *lightning.CustomMessage) error {
binary.BigEndian.PutUint16(t[:], uint16(msg.Type))
m := hex.EncodeToString(t[:]) + hex.EncodeToString(msg.Data)
_, err := c.client.client.SendCustomMessage(msg.PeerId, m)
client, err := c.client.getClient()
if err != nil {
return err
}
_, err = client.SendCustomMessage(msg.PeerId, m)
return err
}

View File

@@ -226,4 +226,8 @@ var allTestCases = []*testCase{
isLsps2: true,
skipCreateLsp: true,
},
{
name: "testRestartLspNode",
test: testRestartLspNode,
},
}

View File

@@ -0,0 +1,75 @@
package itest
import (
"log"
"time"
"github.com/breez/lntest"
lspd "github.com/breez/lspd/rpc"
"github.com/stretchr/testify/assert"
)
func testRestartLspNode(p *testParams) {
alice := lntest.NewClnNode(p.h, p.m, "Alice")
alice.Start()
alice.Fund(10000000)
p.lsp.LightningNode().Fund(10000000)
log.Print("Opening channel between Alice and the lsp")
channel := alice.OpenChannel(p.lsp.LightningNode(), &lntest.OpenChannelOptions{
AmountSat: publicChanAmount,
})
alice.WaitForChannelReady(channel)
log.Printf("Adding bob's invoices")
outerAmountMsat := uint64(2100000)
innerAmountMsat := calculateInnerAmountMsat(p.lsp, outerAmountMsat, nil)
description := "Please pay me"
innerInvoice, outerInvoice := GenerateInvoices(p.BreezClient(),
generateInvoicesRequest{
innerAmountMsat: innerAmountMsat,
outerAmountMsat: outerAmountMsat,
description: description,
lsp: p.lsp,
})
p.BreezClient().SetHtlcAcceptor(innerAmountMsat)
log.Print("Connecting bob to lspd")
p.BreezClient().Node().ConnectPeer(p.lsp.LightningNode())
log.Printf("Registering payment with lsp")
RegisterPayment(p.lsp, &lspd.PaymentInformation{
PaymentHash: innerInvoice.paymentHash,
PaymentSecret: innerInvoice.paymentSecret,
Destination: p.BreezClient().Node().NodeId(),
IncomingAmountMsat: int64(outerAmountMsat),
OutgoingAmountMsat: int64(innerAmountMsat),
}, false)
log.Printf("stopping lsp lightning node")
p.lsp.LightningNode().Stop()
log.Printf("waiting %v to allow lsp lightning node to stop completely", htlcInterceptorDelay)
<-time.After(htlcInterceptorDelay)
log.Printf("starting lsp lightning node again")
p.lsp.LightningNode().Start()
// TODO: Fix race waiting for htlc interceptor.
log.Printf("Waiting %v to allow htlc interceptor to activate.", htlcInterceptorDelay)
<-time.After(htlcInterceptorDelay)
log.Printf("Connect Bob to LSP again")
p.BreezClient().Node().ConnectPeer(p.lsp.LightningNode())
log.Printf("Alice paying")
payResp := alice.Pay(outerInvoice.bolt11)
bobInvoice := p.BreezClient().Node().GetInvoice(payResp.PaymentHash)
assert.Equal(p.t, payResp.PaymentPreimage, bobInvoice.PaymentPreimage)
assert.Equal(p.t, innerAmountMsat, bobInvoice.AmountReceivedMsat)
// Make sure capacity is correct
chans := p.BreezClient().Node().GetChannels()
assert.Equal(p.t, 1, len(chans))
c := chans[0]
AssertChannelCapacity(p.t, outerAmountMsat, c.CapacityMsat)
}