diff --git a/go.mod b/go.mod index ef373b1..77041de 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.19 require ( github.com/aws/aws-sdk-go v1.30.20 - github.com/breez/lntest v0.0.10 + github.com/breez/lntest v0.0.11 github.com/btcsuite/btcd v0.23.3 github.com/btcsuite/btcd/btcec/v2 v2.2.1 github.com/btcsuite/btcd/chaincfg/chainhash v1.0.1 diff --git a/itest/breez_client.go b/itest/breez_client.go index 7465c33..8334acd 100644 --- a/itest/breez_client.go +++ b/itest/breez_client.go @@ -1,15 +1,9 @@ package itest import ( - "bufio" "crypto/sha256" - "flag" - "fmt" - "os" - "path/filepath" "github.com/breez/lntest" - "github.com/breez/lntest/lnd" "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcec/v2/ecdsa" "github.com/btcsuite/btcd/chaincfg" @@ -17,133 +11,12 @@ import ( "github.com/lightningnetwork/lnd/zpay32" ) -type breezClient struct { - name string - harness *lntest.TestHarness - lightningNode lntest.LightningNode - scriptDir string -} - -var pluginContent string = `#!/usr/bin/env python3 -"""Use the openchannel hook to selectively opt-into zeroconf -""" - -from pyln.client import Plugin - -plugin = Plugin() - - -@plugin.hook('openchannel') -def on_openchannel(openchannel, plugin, **kwargs): - plugin.log(repr(openchannel)) - mindepth = int(0) - - plugin.log(f"This peer is in the zeroconf allowlist, setting mindepth={mindepth}") - return {'result': 'continue', 'mindepth': mindepth} - -plugin.run() -` - -var pluginStartupContent string = `python3 -m venv %s > /dev/null 2>&1 -source %s > /dev/null 2>&1 -pip install pyln-client > /dev/null 2>&1 -python %s -` - -func newClnBreezClient(h *lntest.TestHarness, m *lntest.Miner, name string) *breezClient { - scriptDir, err := os.MkdirTemp(h.Dir, name) - lntest.CheckError(h.T, err) - pythonFilePath := filepath.Join(scriptDir, "zero_conf_plugin.py") - pythonFile, err := os.OpenFile(pythonFilePath, os.O_CREATE|os.O_WRONLY, 0755) - lntest.CheckError(h.T, err) - - pythonWriter := bufio.NewWriter(pythonFile) - _, err = pythonWriter.WriteString(pluginContent) - lntest.CheckError(h.T, err) - - err = pythonWriter.Flush() - lntest.CheckError(h.T, err) - pythonFile.Close() - - pluginFilePath := filepath.Join(scriptDir, "start_zero_conf_plugin.sh") - pluginFile, err := os.OpenFile(pluginFilePath, os.O_CREATE|os.O_WRONLY, 0755) - lntest.CheckError(h.T, err) - - pluginWriter := bufio.NewWriter(pluginFile) - venvDir := filepath.Join(scriptDir, "venv") - activatePath := filepath.Join(venvDir, "bin", "activate") - _, err = pluginWriter.WriteString(fmt.Sprintf(pluginStartupContent, venvDir, activatePath, pythonFilePath)) - lntest.CheckError(h.T, err) - - err = pluginWriter.Flush() - lntest.CheckError(h.T, err) - pluginFile.Close() - - node := lntest.NewClnNode( - h, - m, - name, - fmt.Sprintf("--plugin=%s", pluginFilePath), - // NOTE: max-concurrent-htlcs is 30 on mainnet by default. In cln V22.11 - // there is a check for 'all dust' commitment transactions. The max - // concurrent HTLCs of both sides of the channel * dust limit must be - // lower than the channel capacity in order to open a zero conf zero - // reserve channel. Relevant code: - // https://github.com/ElementsProject/lightning/blob/774d16a72e125e4ae4e312b9e3307261983bec0e/openingd/openingd.c#L481-L520 - "--max-concurrent-htlcs=30", - ) - - return &breezClient{ - name: name, - harness: h, - lightningNode: node, - scriptDir: scriptDir, - } -} - -var lndMobileExecutable = flag.String( - "lndmobileexec", "", "full path to lnd mobile binary", -) - -func newLndBreezClient(h *lntest.TestHarness, m *lntest.Miner, name string) *breezClient { - lnd := lntest.NewLndNodeFromBinary(h, m, name, *lndMobileExecutable, - "--protocol.zero-conf", - "--protocol.option-scid-alias", - "--bitcoin.defaultchanconfs=0", - ) - - go startChannelAcceptor(h, lnd) - - return &breezClient{ - name: name, - harness: h, - lightningNode: lnd, - } -} - -func startChannelAcceptor(h *lntest.TestHarness, n *lntest.LndNode) error { - client, err := n.LightningClient().ChannelAcceptor(h.Ctx) - lntest.CheckError(h.T, err) - - for { - request, err := client.Recv() - if err != nil { - return err - } - - private := request.ChannelFlags&uint32(lnwire.FFAnnounceChannel) == 0 - resp := &lnd.ChannelAcceptResponse{ - PendingChanId: request.PendingChanId, - Accept: private, - } - if request.WantsZeroConf { - resp.MinAcceptDepth = 0 - resp.ZeroConf = true - } - - err = client.Send(resp) - lntest.CheckError(h.T, err) - } +type BreezClient interface { + Name() string + Harness() *lntest.TestHarness + Node() lntest.LightningNode + Start() + Stop() error } type generateInvoicesRequest struct { @@ -160,20 +33,20 @@ type invoice struct { paymentPreimage []byte } -func (n *breezClient) GenerateInvoices(req generateInvoicesRequest) (invoice, invoice) { +func GenerateInvoices(n BreezClient, req generateInvoicesRequest) (invoice, invoice) { preimage, err := GenerateRandomBytes(32) - lntest.CheckError(n.harness.T, err) + lntest.CheckError(n.Harness().T, err) lspNodeId, err := btcec.ParsePubKey(req.lsp.NodeId()) - lntest.CheckError(n.harness.T, err) + lntest.CheckError(n.Harness().T, err) - innerInvoice := n.lightningNode.CreateBolt11Invoice(&lntest.CreateInvoiceOptions{ + innerInvoice := n.Node().CreateBolt11Invoice(&lntest.CreateInvoiceOptions{ AmountMsat: req.innerAmountMsat, Description: &req.description, Preimage: &preimage, }) outerInvoiceRaw, err := zpay32.Decode(innerInvoice.Bolt11, &chaincfg.RegressionNetParams) - lntest.CheckError(n.harness.T, err) + lntest.CheckError(n.Harness().T, err) milliSat := lnwire.MilliSatoshi(req.outerAmountMsat) outerInvoiceRaw.MilliSat = &milliSat @@ -191,10 +64,10 @@ func (n *breezClient) GenerateInvoices(req generateInvoicesRequest) (invoice, in outerInvoice, err := outerInvoiceRaw.Encode(zpay32.MessageSigner{ SignCompact: func(msg []byte) ([]byte, error) { hash := sha256.Sum256(msg) - return ecdsa.SignCompact(n.lightningNode.PrivateKey(), hash[:], true) + return ecdsa.SignCompact(n.Node().PrivateKey(), hash[:], true) }, }) - lntest.CheckError(n.harness.T, err) + lntest.CheckError(n.Harness().T, err) inner := invoice{ bolt11: innerInvoice.Bolt11, diff --git a/itest/cln_breez_client.go b/itest/cln_breez_client.go new file mode 100644 index 0000000..ebd3afe --- /dev/null +++ b/itest/cln_breez_client.go @@ -0,0 +1,155 @@ +package itest + +import ( + "bufio" + "fmt" + "os" + "path/filepath" + "sync" + + "github.com/breez/lntest" +) + +var pluginContent string = `#!/usr/bin/env python3 +"""Use the openchannel hook to selectively opt-into zeroconf +""" + +from pyln.client import Plugin + +plugin = Plugin() + + +@plugin.hook('openchannel') +def on_openchannel(openchannel, plugin, **kwargs): + plugin.log(repr(openchannel)) + mindepth = int(0) + + plugin.log(f"This peer is in the zeroconf allowlist, setting mindepth={mindepth}") + return {'result': 'continue', 'mindepth': mindepth} + +plugin.run() +` + +var pluginStartupContent string = `python3 -m venv %s > /dev/null 2>&1 +source %s > /dev/null 2>&1 +pip install pyln-client > /dev/null 2>&1 +python %s +` + +type clnBreezClient struct { + name string + scriptDir string + pluginFilePath string + harness *lntest.TestHarness + isInitialized bool + node *lntest.ClnNode + mtx sync.Mutex +} + +func newClnBreezClient(h *lntest.TestHarness, m *lntest.Miner, name string) BreezClient { + scriptDir := h.GetDirectory(name) + pluginFilePath := filepath.Join(scriptDir, "start_zero_conf_plugin.sh") + node := lntest.NewClnNode( + h, + m, + name, + fmt.Sprintf("--plugin=%s", pluginFilePath), + // NOTE: max-concurrent-htlcs is 30 on mainnet by default. In cln V22.11 + // there is a check for 'all dust' commitment transactions. The max + // concurrent HTLCs of both sides of the channel * dust limit must be + // lower than the channel capacity in order to open a zero conf zero + // reserve channel. Relevant code: + // https://github.com/ElementsProject/lightning/blob/774d16a72e125e4ae4e312b9e3307261983bec0e/openingd/openingd.c#L481-L520 + "--max-concurrent-htlcs=30", + ) + + return &clnBreezClient{ + name: name, + harness: h, + node: node, + scriptDir: scriptDir, + pluginFilePath: pluginFilePath, + } +} + +func (c *clnBreezClient) Name() string { + return c.name +} + +func (c *clnBreezClient) Harness() *lntest.TestHarness { + return c.harness +} + +func (c *clnBreezClient) Node() lntest.LightningNode { + return c.node +} + +func (c *clnBreezClient) Start() { + c.mtx.Lock() + defer c.mtx.Unlock() + + if !c.isInitialized { + c.initialize() + c.isInitialized = true + } + + c.node.Start() +} + +func (c *clnBreezClient) initialize() error { + var cleanups []*lntest.Cleanup + + pythonFilePath := filepath.Join(c.scriptDir, "zero_conf_plugin.py") + pythonFile, err := os.OpenFile(pythonFilePath, os.O_CREATE|os.O_WRONLY, 0755) + if err != nil { + return fmt.Errorf("failed to create python file '%s': %v", pythonFilePath, err) + } + cleanups = append(cleanups, &lntest.Cleanup{ + Name: fmt.Sprintf("%s: python file", c.name), + Fn: pythonFile.Close, + }) + + pythonWriter := bufio.NewWriter(pythonFile) + _, err = pythonWriter.WriteString(pluginContent) + if err != nil { + lntest.PerformCleanup(cleanups) + return fmt.Errorf("failed to write content to python file '%s': %v", pythonFilePath, err) + } + + err = pythonWriter.Flush() + if err != nil { + lntest.PerformCleanup(cleanups) + return fmt.Errorf("failed to flush python file '%s': %v", pythonFilePath, err) + } + + pluginFile, err := os.OpenFile(c.pluginFilePath, os.O_CREATE|os.O_WRONLY, 0755) + if err != nil { + lntest.PerformCleanup(cleanups) + return fmt.Errorf("failed to create plugin file '%s': %v", c.pluginFilePath, err) + } + cleanups = append(cleanups, &lntest.Cleanup{ + Name: fmt.Sprintf("%s: python file", c.name), + Fn: pluginFile.Close, + }) + + pluginWriter := bufio.NewWriter(pluginFile) + venvDir := filepath.Join(c.scriptDir, "venv") + activatePath := filepath.Join(venvDir, "bin", "activate") + _, err = pluginWriter.WriteString(fmt.Sprintf(pluginStartupContent, venvDir, activatePath, pythonFilePath)) + if err != nil { + lntest.PerformCleanup(cleanups) + return fmt.Errorf("failed to write content to plugin file '%s': %v", c.pluginFilePath, err) + } + + err = pluginWriter.Flush() + if err != nil { + lntest.PerformCleanup(cleanups) + return fmt.Errorf("failed to flush plugin file '%s': %v", c.pluginFilePath, err) + } + lntest.PerformCleanup(cleanups) + return nil +} + +func (c *clnBreezClient) Stop() error { + return c.node.Stop() +} diff --git a/itest/cln_lspd_node.go b/itest/cln_lspd_node.go new file mode 100644 index 0000000..8290fdb --- /dev/null +++ b/itest/cln_lspd_node.go @@ -0,0 +1,133 @@ +package itest + +import ( + "fmt" + "sync" + + "github.com/breez/lntest" + lspd "github.com/breez/lspd/rpc" + "github.com/btcsuite/btcd/btcec/v2" + ecies "github.com/ecies/go/v2" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" +) + +type ClnLspNode struct { + harness *lntest.TestHarness + lightningNode *lntest.ClnNode + lspBase *lspBase + runtime *clnLspNodeRuntime + isInitialized bool + mtx sync.Mutex +} + +type clnLspNodeRuntime struct { + rpc lspd.ChannelOpenerClient + cleanups []*lntest.Cleanup +} + +func NewClnLspdNode(h *lntest.TestHarness, m *lntest.Miner, name string) LspNode { + lspbase, err := newLspd(h, name, "RUN_CLN=true") + if err != nil { + h.T.Fatalf("failed to initialize lspd") + } + + args := []string{ + fmt.Sprintf("--plugin=%s", lspbase.scriptFilePath), + fmt.Sprintf("--fee-base=%d", lspBaseFeeMsat), + fmt.Sprintf("--fee-per-satoshi=%d", lspFeeRatePpm), + fmt.Sprintf("--cltv-delta=%d", lspCltvDelta), + "--max-concurrent-htlcs=30", + "--dev-allowdustreserve=true", + } + lightningNode := lntest.NewClnNode(h, m, name, args...) + + lspNode := &ClnLspNode{ + harness: h, + lightningNode: lightningNode, + lspBase: lspbase, + } + + h.AddStoppable(lspNode) + return lspNode +} + +func (c *ClnLspNode) Start() { + c.mtx.Lock() + defer c.mtx.Unlock() + + var cleanups []*lntest.Cleanup + if !c.isInitialized { + err := c.lspBase.Initialize() + if err != nil { + c.harness.T.Fatalf("failed to initialize lsp: %v", err) + } + c.isInitialized = true + cleanups = append(cleanups, &lntest.Cleanup{ + Name: fmt.Sprintf("%s: lsp base", c.lspBase.name), + Fn: c.lspBase.Stop, + }) + } + + c.lightningNode.Start() + cleanups = append(cleanups, &lntest.Cleanup{ + Name: fmt.Sprintf("%s: lightning node", c.lspBase.name), + Fn: c.lightningNode.Stop, + }) + conn, err := grpc.Dial( + c.lspBase.grpcAddress, + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithPerRPCCredentials(&token{token: "hello"}), + ) + if err != nil { + lntest.PerformCleanup(cleanups) + c.harness.T.Fatalf("%s: failed to create grpc connection: %v", c.lspBase.name, err) + } + + client := lspd.NewChannelOpenerClient(conn) + c.runtime = &clnLspNodeRuntime{ + rpc: client, + cleanups: cleanups, + } +} + +func (c *ClnLspNode) Stop() error { + c.mtx.Lock() + defer c.mtx.Unlock() + + if c.runtime == nil { + return nil + } + + lntest.PerformCleanup(c.runtime.cleanups) + c.runtime = nil + return nil +} + +func (c *ClnLspNode) Harness() *lntest.TestHarness { + return c.harness +} + +func (c *ClnLspNode) PublicKey() *btcec.PublicKey { + return c.lspBase.pubkey +} + +func (c *ClnLspNode) EciesPublicKey() *ecies.PublicKey { + return c.lspBase.eciesPubkey +} + +func (c *ClnLspNode) Rpc() lspd.ChannelOpenerClient { + return c.runtime.rpc +} + +func (l *ClnLspNode) NodeId() []byte { + return l.lightningNode.NodeId() +} + +func (l *ClnLspNode) LightningNode() lntest.LightningNode { + return l.lightningNode +} + +func (l *ClnLspNode) SupportsChargingFees() bool { + return false +} diff --git a/itest/intercept_zero_conf_test.go b/itest/intercept_zero_conf_test.go index f5d5ab5..9413b6d 100644 --- a/itest/intercept_zero_conf_test.go +++ b/itest/intercept_zero_conf_test.go @@ -13,6 +13,7 @@ var htlcInterceptorDelay = time.Second * 7 func testOpenZeroConfChannelOnReceive(p *testParams) { alice := lntest.NewClnNode(p.h, p.m, "Alice") + alice.Start() alice.Fund(10000000) p.lsp.LightningNode().Fund(10000000) @@ -26,15 +27,16 @@ func testOpenZeroConfChannelOnReceive(p *testParams) { outerAmountMsat := uint64(2100000) innerAmountMsat := calculateInnerAmountMsat(p.lsp, outerAmountMsat) description := "Please pay me" - innerInvoice, outerInvoice := p.BreezClient().GenerateInvoices(generateInvoicesRequest{ - innerAmountMsat: innerAmountMsat, - outerAmountMsat: outerAmountMsat, - description: description, - lsp: p.lsp, - }) + innerInvoice, outerInvoice := GenerateInvoices(p.BreezClient(), + generateInvoicesRequest{ + innerAmountMsat: innerAmountMsat, + outerAmountMsat: outerAmountMsat, + description: description, + lsp: p.lsp, + }) log.Print("Connecting bob to lspd") - p.BreezClient().lightningNode.ConnectPeer(p.lsp.LightningNode()) + p.BreezClient().Node().ConnectPeer(p.lsp.LightningNode()) // NOTE: We pretend to be paying fees to the lsp, but actually we won't. log.Printf("Registering payment with lsp") @@ -42,7 +44,7 @@ func testOpenZeroConfChannelOnReceive(p *testParams) { RegisterPayment(p.lsp, &lspd.PaymentInformation{ PaymentHash: innerInvoice.paymentHash, PaymentSecret: innerInvoice.paymentSecret, - Destination: p.BreezClient().lightningNode.NodeId(), + Destination: p.BreezClient().Node().NodeId(), IncomingAmountMsat: int64(outerAmountMsat), OutgoingAmountMsat: int64(pretendAmount), }) @@ -52,13 +54,13 @@ func testOpenZeroConfChannelOnReceive(p *testParams) { <-time.After(htlcInterceptorDelay) log.Printf("Alice paying") payResp := alice.Pay(outerInvoice.bolt11) - bobInvoice := p.BreezClient().lightningNode.GetInvoice(payResp.PaymentHash) + bobInvoice := p.BreezClient().Node().GetInvoice(payResp.PaymentHash) assert.Equal(p.t, payResp.PaymentPreimage, bobInvoice.PaymentPreimage) assert.Equal(p.t, innerAmountMsat, bobInvoice.AmountReceivedMsat) // Make sure capacity is correct - chans := p.BreezClient().lightningNode.GetChannels() + chans := p.BreezClient().Node().GetChannels() assert.Len(p.t, chans, 1) c := chans[0] AssertChannelCapacity(p.t, outerAmountMsat, c.CapacityMsat) @@ -66,7 +68,7 @@ func testOpenZeroConfChannelOnReceive(p *testParams) { func testOpenZeroConfSingleHtlc(p *testParams) { alice := lntest.NewClnNode(p.h, p.m, "Alice") - + alice.Start() alice.Fund(10000000) p.lsp.LightningNode().Fund(10000000) @@ -80,15 +82,16 @@ func testOpenZeroConfSingleHtlc(p *testParams) { outerAmountMsat := uint64(2100000) innerAmountMsat := calculateInnerAmountMsat(p.lsp, outerAmountMsat) description := "Please pay me" - innerInvoice, outerInvoice := p.BreezClient().GenerateInvoices(generateInvoicesRequest{ - innerAmountMsat: innerAmountMsat, - outerAmountMsat: outerAmountMsat, - description: description, - lsp: p.lsp, - }) + innerInvoice, outerInvoice := GenerateInvoices(p.BreezClient(), + generateInvoicesRequest{ + innerAmountMsat: innerAmountMsat, + outerAmountMsat: outerAmountMsat, + description: description, + lsp: p.lsp, + }) log.Print("Connecting bob to lspd") - p.BreezClient().lightningNode.ConnectPeer(p.lsp.LightningNode()) + p.BreezClient().Node().ConnectPeer(p.lsp.LightningNode()) // NOTE: We pretend to be paying fees to the lsp, but actually we won't. log.Printf("Registering payment with lsp") @@ -96,7 +99,7 @@ func testOpenZeroConfSingleHtlc(p *testParams) { RegisterPayment(p.lsp, &lspd.PaymentInformation{ PaymentHash: innerInvoice.paymentHash, PaymentSecret: innerInvoice.paymentSecret, - Destination: p.BreezClient().lightningNode.NodeId(), + Destination: p.BreezClient().Node().NodeId(), IncomingAmountMsat: int64(outerAmountMsat), OutgoingAmountMsat: int64(pretendAmount), }) @@ -105,15 +108,15 @@ func testOpenZeroConfSingleHtlc(p *testParams) { log.Printf("Waiting %v to allow htlc interceptor to activate.", htlcInterceptorDelay) <-time.After(htlcInterceptorDelay) log.Printf("Alice paying") - route := constructRoute(p.lsp.LightningNode(), p.BreezClient().lightningNode, channelId, lntest.NewShortChanIDFromString("1x0x0"), outerAmountMsat) + route := constructRoute(p.lsp.LightningNode(), p.BreezClient().Node(), channelId, lntest.NewShortChanIDFromString("1x0x0"), outerAmountMsat) payResp := alice.PayViaRoute(outerAmountMsat, outerInvoice.paymentHash, outerInvoice.paymentSecret, route) - bobInvoice := p.BreezClient().lightningNode.GetInvoice(payResp.PaymentHash) + bobInvoice := p.BreezClient().Node().GetInvoice(payResp.PaymentHash) assert.Equal(p.t, payResp.PaymentPreimage, bobInvoice.PaymentPreimage) assert.Equal(p.t, innerAmountMsat, bobInvoice.AmountReceivedMsat) // Make sure capacity is correct - chans := p.BreezClient().lightningNode.GetChannels() + chans := p.BreezClient().Node().GetChannels() assert.Len(p.t, chans, 1) c := chans[0] AssertChannelCapacity(p.t, outerAmountMsat, c.CapacityMsat) diff --git a/itest/lnd_breez_client.go b/itest/lnd_breez_client.go new file mode 100644 index 0000000..6da835a --- /dev/null +++ b/itest/lnd_breez_client.go @@ -0,0 +1,108 @@ +package itest + +import ( + "context" + "flag" + "sync" + + "github.com/breez/lntest" + "github.com/breez/lntest/lnd" + "github.com/lightningnetwork/lnd/lnwire" +) + +var lndMobileExecutable = flag.String( + "lndmobileexec", "", "full path to lnd mobile binary", +) + +type lndBreezClient struct { + name string + harness *lntest.TestHarness + node *lntest.LndNode + cancel context.CancelFunc + mtx sync.Mutex +} + +func newLndBreezClient(h *lntest.TestHarness, m *lntest.Miner, name string) BreezClient { + lnd := lntest.NewLndNodeFromBinary(h, m, name, *lndMobileExecutable, + "--protocol.zero-conf", + "--protocol.option-scid-alias", + "--bitcoin.defaultchanconfs=0", + ) + + c := &lndBreezClient{ + name: name, + harness: h, + node: lnd, + } + h.AddStoppable(c) + return c +} + +func (c *lndBreezClient) Name() string { + return c.name +} + +func (c *lndBreezClient) Harness() *lntest.TestHarness { + return c.harness +} + +func (c *lndBreezClient) Node() lntest.LightningNode { + return c.node +} + +func (c *lndBreezClient) Start() { + c.mtx.Lock() + defer c.mtx.Unlock() + + if c.node.IsStarted() { + return + } + + c.node.Start() + + ctx, cancel := context.WithCancel(c.harness.Ctx) + c.cancel = cancel + go c.startChannelAcceptor(ctx) +} + +func (c *lndBreezClient) Stop() error { + c.mtx.Lock() + defer c.mtx.Unlock() + + // Stop the channel acceptor + if c.cancel != nil { + c.cancel() + c.cancel = nil + } + + return c.node.Stop() +} + +func (c *lndBreezClient) startChannelAcceptor(ctx context.Context) error { + client, err := c.node.LightningClient().ChannelAcceptor(ctx) + if err != nil { + c.harness.T.Fatalf("%s: failed to create channel acceptor: %v", c.name, err) + } + + for { + request, err := client.Recv() + if err != nil { + return err + } + + private := request.ChannelFlags&uint32(lnwire.FFAnnounceChannel) == 0 + resp := &lnd.ChannelAcceptResponse{ + PendingChanId: request.PendingChanId, + Accept: private, + } + if request.WantsZeroConf { + resp.MinAcceptDepth = 0 + resp.ZeroConf = true + } + + err = client.Send(resp) + if err != nil { + c.harness.T.Fatalf("%s: failed to send acceptor response: %v", c.name, err) + } + } +} diff --git a/itest/lnd_lspd_node.go b/itest/lnd_lspd_node.go new file mode 100644 index 0000000..7fa3385 --- /dev/null +++ b/itest/lnd_lspd_node.go @@ -0,0 +1,230 @@ +package itest + +import ( + "fmt" + "log" + "os" + "os/exec" + "path/filepath" + "runtime" + "strings" + "sync" + + "github.com/breez/lntest" + lspd "github.com/breez/lspd/rpc" + "github.com/btcsuite/btcd/btcec/v2" + ecies "github.com/ecies/go/v2" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" +) + +type LndLspNode struct { + harness *lntest.TestHarness + lightningNode *lntest.LndNode + logFilePath string + isInitialized bool + lspBase *lspBase + runtime *lndLspNodeRuntime + mtx sync.Mutex +} + +type lndLspNodeRuntime struct { + logFile *os.File + cmd *exec.Cmd + rpc lspd.ChannelOpenerClient + cleanups []*lntest.Cleanup +} + +func NewLndLspdNode(h *lntest.TestHarness, m *lntest.Miner, name string) LspNode { + args := []string{ + "--protocol.zero-conf", + "--protocol.option-scid-alias", + "--requireinterceptor", + "--bitcoin.defaultchanconfs=0", + fmt.Sprintf("--bitcoin.chanreservescript=\"0 if (chanAmt != %d) else chanAmt/100\"", publicChanAmount), + fmt.Sprintf("--bitcoin.basefee=%d", lspBaseFeeMsat), + fmt.Sprintf("--bitcoin.feerate=%d", lspFeeRatePpm), + fmt.Sprintf("--bitcoin.timelockdelta=%d", lspCltvDelta), + } + + lightningNode := lntest.NewLndNode(h, m, name, args...) + tlsCert := strings.Replace(string(lightningNode.TlsCert()), "\n", "\\n", -1) + lspBase, err := newLspd(h, name, + "RUN_LND=true", + fmt.Sprintf("LND_CERT=\"%s\"", tlsCert), + fmt.Sprintf("LND_ADDRESS=%s", lightningNode.GrpcHost()), + fmt.Sprintf("LND_MACAROON_HEX=%x", lightningNode.Macaroon()), + ) + if err != nil { + h.T.Fatalf("failed to initialize lspd") + } + scriptDir := filepath.Dir(lspBase.scriptFilePath) + logFilePath := filepath.Join(scriptDir, "lspd.log") + h.RegisterLogfile(logFilePath, fmt.Sprintf("lspd-%s", name)) + + lspNode := &LndLspNode{ + harness: h, + lightningNode: lightningNode, + logFilePath: logFilePath, + lspBase: lspBase, + } + + h.AddStoppable(lspNode) + return lspNode +} + +func (c *LndLspNode) Start() { + c.mtx.Lock() + defer c.mtx.Unlock() + + var cleanups []*lntest.Cleanup + wasInitialized := c.isInitialized + if !c.isInitialized { + err := c.lspBase.Initialize() + if err != nil { + c.harness.T.Fatalf("failed to initialize lsp: %v", err) + } + c.isInitialized = true + cleanups = append(cleanups, &lntest.Cleanup{ + Name: fmt.Sprintf("%s: lsp base", c.lspBase.name), + Fn: c.lspBase.Stop, + }) + } + + c.lightningNode.Start() + cleanups = append(cleanups, &lntest.Cleanup{ + Name: fmt.Sprintf("%s: lsp lightning node", c.lspBase.name), + Fn: c.lightningNode.Stop, + }) + + if !wasInitialized { + scriptFile, err := os.ReadFile(c.lspBase.scriptFilePath) + if err != nil { + lntest.PerformCleanup(cleanups) + c.harness.T.Fatalf("failed to open scriptfile '%s': %v", c.lspBase.scriptFilePath, err) + } + + err = os.Remove(c.lspBase.scriptFilePath) + if err != nil { + lntest.PerformCleanup(cleanups) + c.harness.T.Fatalf("failed to remove scriptfile '%s': %v", c.lspBase.scriptFilePath, err) + } + + split := strings.Split(string(scriptFile), "\n") + for i, s := range split { + if strings.HasPrefix(s, "export LND_CERT") { + tlsCert := strings.Replace(string(c.lightningNode.TlsCert()), "\n", "\\n", -1) + split[i] = fmt.Sprintf("export LND_CERT=\"%s\"", tlsCert) + } + + if strings.HasPrefix(s, "export LND_MACAROON_HEX") { + split[i] = fmt.Sprintf("export LND_MACAROON_HEX=%x", c.lightningNode.Macaroon()) + } + } + newContent := strings.Join(split, "\n") + err = os.WriteFile(c.lspBase.scriptFilePath, []byte(newContent), 0755) + if err != nil { + lntest.PerformCleanup(cleanups) + c.harness.T.Fatalf("failed to rewrite scriptfile '%s': %v", c.lspBase.scriptFilePath, err) + } + } + + cmd := exec.CommandContext(c.harness.Ctx, c.lspBase.scriptFilePath) + logFile, err := os.Create(c.logFilePath) + if err != nil { + lntest.PerformCleanup(cleanups) + c.harness.T.Fatalf("failed create lsp logfile: %v", err) + } + cleanups = append(cleanups, &lntest.Cleanup{ + Name: fmt.Sprintf("%s: logfile", c.lspBase.name), + Fn: logFile.Close, + }) + + cmd.Stdout = logFile + cmd.Stderr = logFile + + log.Printf("%s: starting lspd %s", c.lspBase.name, c.lspBase.scriptFilePath) + err = cmd.Start() + if err != nil { + lntest.PerformCleanup(cleanups) + c.harness.T.Fatalf("failed to start lspd: %v", err) + } + cleanups = append(cleanups, &lntest.Cleanup{ + Name: fmt.Sprintf("%s: cmd", c.lspBase.name), + Fn: func() error { + proc := cmd.Process + if proc != nil { + if runtime.GOOS == "windows" { + return proc.Signal(os.Kill) + } + + return proc.Signal(os.Interrupt) + } + + return nil + }, + }) + + conn, err := grpc.Dial( + c.lspBase.grpcAddress, + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithPerRPCCredentials(&token{token: "hello"}), + ) + if err != nil { + lntest.PerformCleanup(cleanups) + c.harness.T.Fatalf("failed to create grpc connection: %v", err) + } + cleanups = append(cleanups, &lntest.Cleanup{ + Name: fmt.Sprintf("%s: grpc conn", c.lspBase.name), + Fn: conn.Close, + }) + + client := lspd.NewChannelOpenerClient(conn) + c.runtime = &lndLspNodeRuntime{ + logFile: logFile, + cmd: cmd, + rpc: client, + cleanups: cleanups, + } +} + +func (c *LndLspNode) Stop() error { + c.mtx.Lock() + defer c.mtx.Unlock() + + if c.runtime == nil { + return nil + } + + lntest.PerformCleanup(c.runtime.cleanups) + c.runtime = nil + return nil +} + +func (c *LndLspNode) Harness() *lntest.TestHarness { + return c.harness +} + +func (c *LndLspNode) PublicKey() *btcec.PublicKey { + return c.lspBase.pubkey +} + +func (c *LndLspNode) EciesPublicKey() *ecies.PublicKey { + return c.lspBase.eciesPubkey +} + +func (c *LndLspNode) Rpc() lspd.ChannelOpenerClient { + return c.runtime.rpc +} + +func (l *LndLspNode) SupportsChargingFees() bool { + return true +} + +func (l *LndLspNode) NodeId() []byte { + return l.lightningNode.NodeId() +} + +func (l *LndLspNode) LightningNode() lntest.LightningNode { + return l.lightningNode +} diff --git a/itest/lspd_node.go b/itest/lspd_node.go index e845d5b..c4601bc 100644 --- a/itest/lspd_node.go +++ b/itest/lspd_node.go @@ -9,7 +9,6 @@ import ( "os" "os/exec" "path/filepath" - "strings" "github.com/breez/lntest" lspd "github.com/breez/lspd/rpc" @@ -17,8 +16,6 @@ import ( "github.com/decred/dcrd/dcrec/secp256k1/v4" ecies "github.com/ecies/go/v2" "github.com/golang/protobuf/proto" - "google.golang.org/grpc" - "google.golang.org/grpc/credentials/insecure" ) var ( @@ -37,6 +34,8 @@ var ( ) type LspNode interface { + Start() + Stop() error Harness() *lntest.TestHarness PublicKey() *btcec.PublicKey EciesPublicKey() *ecies.PublicKey @@ -46,228 +45,43 @@ type LspNode interface { SupportsChargingFees() bool } -type ClnLspNode struct { +type lspBase struct { harness *lntest.TestHarness - lightningNode *lntest.ClnNode - rpc lspd.ChannelOpenerClient - publicKey btcec.PublicKey - eciesPublicKey ecies.PublicKey + name string + binary string + env []string + scriptFilePath string + grpcAddress string + pubkey *secp256k1.PublicKey + eciesPubkey *ecies.PublicKey postgresBackend *PostgresContainer } -func (c *ClnLspNode) Harness() *lntest.TestHarness { - return c.harness -} - -func (c *ClnLspNode) PublicKey() *btcec.PublicKey { - return &c.publicKey -} - -func (c *ClnLspNode) EciesPublicKey() *ecies.PublicKey { - return &c.eciesPublicKey -} - -func (c *ClnLspNode) Rpc() lspd.ChannelOpenerClient { - return c.rpc -} - -func (l *ClnLspNode) TearDown() error { - // NOTE: The lightningnode will be torn down on its own. - return l.postgresBackend.Shutdown(context.Background()) -} - -func (l *ClnLspNode) Cleanup() error { - return l.postgresBackend.Cleanup(context.Background()) -} - -func (l *ClnLspNode) NodeId() []byte { - return l.lightningNode.NodeId() -} - -func (l *ClnLspNode) LightningNode() lntest.LightningNode { - return l.lightningNode -} - -func (l *ClnLspNode) SupportsChargingFees() bool { - return false -} - -type LndLspNode struct { - harness *lntest.TestHarness - lightningNode *lntest.LndNode - rpc lspd.ChannelOpenerClient - publicKey btcec.PublicKey - eciesPublicKey ecies.PublicKey - postgresBackend *PostgresContainer - logFile *os.File - lspdCmd *exec.Cmd -} - -func (c *LndLspNode) Harness() *lntest.TestHarness { - return c.harness -} - -func (c *LndLspNode) PublicKey() *btcec.PublicKey { - return &c.publicKey -} - -func (c *LndLspNode) EciesPublicKey() *ecies.PublicKey { - return &c.eciesPublicKey -} - -func (c *LndLspNode) Rpc() lspd.ChannelOpenerClient { - return c.rpc -} -func (l *LndLspNode) SupportsChargingFees() bool { - return true -} - -func (l *LndLspNode) TearDown() error { - // NOTE: The lightningnode will be torn down on its own. - if l.lspdCmd != nil && l.lspdCmd.Process != nil { - err := l.lspdCmd.Process.Kill() - if err != nil { - log.Printf("error stopping lspd process: %v", err) - } - } - - if l.logFile != nil { - err := l.logFile.Close() - if err != nil { - log.Printf("error closing logfile: %v", err) - } - } - - return l.postgresBackend.Shutdown(context.Background()) -} - -func (l *LndLspNode) Cleanup() error { - return l.postgresBackend.Cleanup(context.Background()) -} - -func (l *LndLspNode) NodeId() []byte { - return l.lightningNode.NodeId() -} - -func (l *LndLspNode) LightningNode() lntest.LightningNode { - return l.lightningNode -} - -func NewClnLspdNode(h *lntest.TestHarness, m *lntest.Miner, name string) LspNode { - scriptFilePath, grpcAddress, publ, eciesPubl, postgresBackend := setupLspd(h, name, "RUN_CLN=true") - args := []string{ - fmt.Sprintf("--plugin=%s", scriptFilePath), - fmt.Sprintf("--fee-base=%d", lspBaseFeeMsat), - fmt.Sprintf("--fee-per-satoshi=%d", lspFeeRatePpm), - fmt.Sprintf("--cltv-delta=%d", lspCltvDelta), - "--max-concurrent-htlcs=30", - "--dev-allowdustreserve=true", - } - - lightningNode := lntest.NewClnNode(h, m, name, args...) - - conn, err := grpc.Dial( - grpcAddress, - grpc.WithTransportCredentials(insecure.NewCredentials()), - grpc.WithPerRPCCredentials(&token{token: "hello"}), - ) - lntest.CheckError(h.T, err) - - client := lspd.NewChannelOpenerClient(conn) - - lspNode := &ClnLspNode{ - harness: h, - lightningNode: lightningNode, - rpc: client, - publicKey: *publ, - eciesPublicKey: *eciesPubl, - postgresBackend: postgresBackend, - } - - h.AddStoppable(lspNode) - h.AddCleanable(lspNode) - return lspNode -} - -func NewLndLspdNode(h *lntest.TestHarness, m *lntest.Miner, name string) LspNode { - args := []string{ - "--protocol.zero-conf", - "--protocol.option-scid-alias", - "--requireinterceptor", - "--bitcoin.defaultchanconfs=0", - fmt.Sprintf("--bitcoin.chanreservescript=\"0 if (chanAmt != %d) else chanAmt/100\"", publicChanAmount), - fmt.Sprintf("--bitcoin.basefee=%d", lspBaseFeeMsat), - fmt.Sprintf("--bitcoin.feerate=%d", lspFeeRatePpm), - fmt.Sprintf("--bitcoin.timelockdelta=%d", lspCltvDelta), - } - - lightningNode := lntest.NewLndNode(h, m, name, args...) - tlsCert := strings.Replace(string(lightningNode.TlsCert()), "\n", "\\n", -1) - scriptFilePath, grpcAddress, publ, eciesPubl, postgresBackend := setupLspd(h, name, - "RUN_LND=true", - fmt.Sprintf("LND_CERT=\"%s\"", tlsCert), - fmt.Sprintf("LND_ADDRESS=%s", lightningNode.GrpcHost()), - fmt.Sprintf("LND_MACAROON_HEX=%x", lightningNode.Macaroon()), - ) - scriptDir := filepath.Dir(scriptFilePath) - logFilePath := filepath.Join(scriptDir, "lspd.log") - h.RegisterLogfile(logFilePath, fmt.Sprintf("lspd-%s", name)) - - lspdCmd := exec.CommandContext(h.Ctx, scriptFilePath) - logFile, err := os.Create(logFilePath) - lntest.CheckError(h.T, err) - - lspdCmd.Stdout = logFile - lspdCmd.Stderr = logFile - - log.Printf("%s: starting lspd %s", name, scriptFilePath) - err = lspdCmd.Start() - lntest.CheckError(h.T, err) - - conn, err := grpc.Dial( - grpcAddress, - grpc.WithTransportCredentials(insecure.NewCredentials()), - grpc.WithPerRPCCredentials(&token{token: "hello"}), - ) - lntest.CheckError(h.T, err) - - client := lspd.NewChannelOpenerClient(conn) - - lspNode := &LndLspNode{ - harness: h, - lightningNode: lightningNode, - rpc: client, - publicKey: *publ, - eciesPublicKey: *eciesPubl, - postgresBackend: postgresBackend, - logFile: logFile, - lspdCmd: lspdCmd, - } - - h.AddStoppable(lspNode) - h.AddCleanable(lspNode) - return lspNode -} - -func setupLspd(h *lntest.TestHarness, name string, envExt ...string) (string, string, *secp256k1.PublicKey, *ecies.PublicKey, *PostgresContainer) { +func newLspd(h *lntest.TestHarness, name string, envExt ...string) (*lspBase, error) { scriptDir := h.GetDirectory(fmt.Sprintf("lspd-%s", name)) log.Printf("%s: Creating LSPD in dir %s", name, scriptDir) - migrationsDir, err := getMigrationsDir() - lntest.CheckError(h.T, err) pgLogfile := filepath.Join(scriptDir, "postgres.log") h.RegisterLogfile(pgLogfile, fmt.Sprintf("%s-postgres", name)) - postgresBackend := StartPostgresContainer(h.T, h.Ctx, pgLogfile) - postgresBackend.RunMigrations(h.T, h.Ctx, migrationsDir) + postgresBackend, err := NewPostgresContainer(pgLogfile) + if err != nil { + return nil, err + } lspdBinary, err := getLspdBinary() - lntest.CheckError(h.T, err) + if err != nil { + return nil, err + } lspdPort, err := lntest.GetPort() - lntest.CheckError(h.T, err) + if err != nil { + return nil, err + } lspdPrivateKeyBytes, err := GenerateRandomBytes(32) - lntest.CheckError(h.T, err) + if err != nil { + return nil, err + } _, publ := btcec.PrivKeyFromBytes(lspdPrivateKeyBytes) eciesPubl := ecies.NewPrivateKeyFromBytes(lspdPrivateKeyBytes).PublicKey @@ -286,27 +100,91 @@ func setupLspd(h *lntest.TestHarness, name string, envExt ...string) (string, st env = append(env, envExt...) scriptFilePath := filepath.Join(scriptDir, "start-lspd.sh") - log.Printf("%s: Creating lspd startup script at %s", name, scriptFilePath) - scriptFile, err := os.OpenFile(scriptFilePath, os.O_CREATE|os.O_WRONLY, 0755) - lntest.CheckError(h.T, err) - writer := bufio.NewWriter(scriptFile) - _, err = writer.WriteString("#!/bin/bash\n") - lntest.CheckError(h.T, err) + l := &lspBase{ + harness: h, + name: name, + env: env, + binary: lspdBinary, + scriptFilePath: scriptFilePath, + grpcAddress: grpcAddress, + pubkey: publ, + eciesPubkey: eciesPubl, + postgresBackend: postgresBackend, + } + h.AddStoppable(l) + h.AddCleanable(l) + return l, nil +} - for _, str := range env { - _, err = writer.WriteString("export " + str + "\n") - lntest.CheckError(h.T, err) +func (l *lspBase) Stop() error { + return l.postgresBackend.Stop(context.Background()) +} + +func (l *lspBase) Cleanup() error { + return l.postgresBackend.Cleanup(context.Background()) +} + +func (l *lspBase) Initialize() error { + var cleanups []*lntest.Cleanup + migrationsDir, err := getMigrationsDir() + if err != nil { + return err } - _, err = writer.WriteString(lspdBinary + "\n") - lntest.CheckError(h.T, err) + err = l.postgresBackend.Start(l.harness.Ctx) + if err != nil { + return err + } + + cleanups = append(cleanups, &lntest.Cleanup{ + Name: fmt.Sprintf("%s: postgres container", l.name), + Fn: func() error { + return l.postgresBackend.Stop(context.Background()) + }, + }) + err = l.postgresBackend.RunMigrations(l.harness.Ctx, migrationsDir) + if err != nil { + lntest.PerformCleanup(cleanups) + return err + } + + log.Printf("%s: Creating lspd startup script at %s", l.name, l.scriptFilePath) + scriptFile, err := os.OpenFile(l.scriptFilePath, os.O_CREATE|os.O_WRONLY, 0755) + if err != nil { + lntest.PerformCleanup(cleanups) + return err + } + + defer scriptFile.Close() + writer := bufio.NewWriter(scriptFile) + _, err = writer.WriteString("#!/bin/bash\n") + if err != nil { + lntest.PerformCleanup(cleanups) + return err + } + + for _, str := range l.env { + _, err = writer.WriteString("export " + str + "\n") + if err != nil { + lntest.PerformCleanup(cleanups) + return err + } + } + + _, err = writer.WriteString(l.binary + "\n") + if err != nil { + lntest.PerformCleanup(cleanups) + return err + } err = writer.Flush() - lntest.CheckError(h.T, err) - scriptFile.Close() + if err != nil { + lntest.PerformCleanup(cleanups) + return err + } - return scriptFilePath, grpcAddress, publ, eciesPubl, postgresBackend + return nil } func RegisterPayment(l LspNode, paymentInfo *lspd.PaymentInformation) { diff --git a/itest/lspd_test.go b/itest/lspd_test.go index 6ae6c37..14b4ff5 100644 --- a/itest/lspd_test.go +++ b/itest/lspd_test.go @@ -3,6 +3,7 @@ package itest import ( "fmt" "log" + "sync" "testing" "time" @@ -13,16 +14,16 @@ var defaultTimeout time.Duration = time.Second * 120 func TestLspd(t *testing.T) { testCases := allTestCases - runTests(t, testCases, "LND-lspd", func(h *lntest.TestHarness, m *lntest.Miner) (LspNode, *breezClient) { + runTests(t, testCases, "LND-lspd", func(h *lntest.TestHarness, m *lntest.Miner) (LspNode, BreezClient) { return NewLndLspdNode(h, m, "lsp"), newLndBreezClient(h, m, "breez-client") }) - runTests(t, testCases, "CLN-lspd", func(h *lntest.TestHarness, m *lntest.Miner) (LspNode, *breezClient) { + runTests(t, testCases, "CLN-lspd", func(h *lntest.TestHarness, m *lntest.Miner) (LspNode, BreezClient) { return NewClnLspdNode(h, m, "lsp"), newClnBreezClient(h, m, "breez-client") }) } -func runTests(t *testing.T, testCases []*testCase, prefix string, nodesFunc func(h *lntest.TestHarness, m *lntest.Miner) (LspNode, *breezClient)) { +func runTests(t *testing.T, testCases []*testCase, prefix string, nodesFunc func(h *lntest.TestHarness, m *lntest.Miner) (LspNode, BreezClient)) { for _, testCase := range testCases { testCase := testCase t.Run(fmt.Sprintf("%s: %s", prefix, testCase.name), func(t *testing.T) { @@ -31,7 +32,7 @@ func runTests(t *testing.T, testCases []*testCase, prefix string, nodesFunc func } } -func runTest(t *testing.T, testCase *testCase, prefix string, nodesFunc func(h *lntest.TestHarness, m *lntest.Miner) (LspNode, *breezClient)) { +func runTest(t *testing.T, testCase *testCase, prefix string, nodesFunc func(h *lntest.TestHarness, m *lntest.Miner) (LspNode, BreezClient)) { log.Printf("%s: Running test case '%s'", prefix, testCase.name) var dd time.Duration to := testCase.timeout @@ -47,8 +48,21 @@ func runTest(t *testing.T, testCase *testCase, prefix string, nodesFunc func(h * log.Printf("Creating miner") miner := lntest.NewMiner(h) + miner.Start() log.Printf("Creating lsp") lsp, c := nodesFunc(h, miner) + var wg sync.WaitGroup + wg.Add(2) + go func() { + lsp.Start() + wg.Done() + }() + + go func() { + c.Start() + wg.Done() + }() + wg.Wait() log.Printf("Run testcase") testCase.test(&testParams{ t: t, diff --git a/itest/postgres.go b/itest/postgres.go index bc3a89d..a8c4f54 100644 --- a/itest/postgres.go +++ b/itest/postgres.go @@ -3,7 +3,6 @@ package itest import ( "context" "encoding/binary" - "errors" "fmt" "io" "log" @@ -11,7 +10,7 @@ import ( "path/filepath" "sort" "strconv" - "testing" + "sync" "time" "github.com/breez/lntest" @@ -23,33 +22,112 @@ import ( ) type PostgresContainer struct { - id string - password string - port uint32 - cli *client.Client + id string + password string + port uint32 + cli *client.Client + logfile string + isInitialized bool + isStarted bool + mtx sync.Mutex } -func StartPostgresContainer(t *testing.T, ctx context.Context, logfile string) *PostgresContainer { - cli, err := client.NewClientWithOpts(client.FromEnv) - lntest.CheckError(t, err) - - image := "postgres:15" - _, _, err = cli.ImageInspectWithRaw(ctx, image) +func NewPostgresContainer(logfile string) (*PostgresContainer, error) { + port, err := lntest.GetPort() if err != nil { - if !client.IsErrNotFound(err) { - lntest.CheckError(t, err) - } - - pullReader, err := cli.ImagePull(ctx, image, types.ImagePullOptions{}) - lntest.CheckError(t, err) - _, err = io.Copy(io.Discard, pullReader) - pullReader.Close() - lntest.CheckError(t, err) + return nil, fmt.Errorf("could not get port: %w", err) } - port, err := lntest.GetPort() - lntest.CheckError(t, err) - createResp, err := cli.ContainerCreate(ctx, &container.Config{ + return &PostgresContainer{ + password: "pgpassword", + port: port, + }, nil +} + +func (c *PostgresContainer) Start(ctx context.Context) error { + c.mtx.Lock() + defer c.mtx.Unlock() + + var err error + if c.isStarted { + return nil + } + + c.cli, err = client.NewClientWithOpts(client.FromEnv) + if err != nil { + return fmt.Errorf("could not create docker client: %w", err) + } + + if !c.isInitialized { + err := c.initialize(ctx) + if err != nil { + c.cli.Close() + return err + } + } + + err = c.cli.ContainerStart(ctx, c.id, types.ContainerStartOptions{}) + if err != nil { + c.cli.Close() + return fmt.Errorf("failed to start docker container '%s': %w", c.id, err) + } + c.isStarted = true + +HealthCheck: + for { + inspect, err := c.cli.ContainerInspect(ctx, c.id) + if err != nil { + c.cli.ContainerStop(ctx, c.id, nil) + c.cli.Close() + return fmt.Errorf("failed to inspect container '%s' during healthcheck: %w", c.id, err) + } + + status := inspect.State.Health.Status + switch status { + case "unhealthy": + c.cli.ContainerStop(ctx, c.id, nil) + c.cli.Close() + return fmt.Errorf("container '%s' unhealthy", c.id) + case "healthy": + for { + pgxPool, err := pgxpool.Connect(ctx, c.ConnectionString()) + if err == nil { + pgxPool.Close() + break HealthCheck + } + + <-time.After(50 * time.Millisecond) + } + default: + <-time.After(200 * time.Millisecond) + } + } + + go c.monitorLogs(ctx) + return nil +} + +func (c *PostgresContainer) initialize(ctx context.Context) error { + image := "postgres:15" + _, _, err := c.cli.ImageInspectWithRaw(ctx, image) + if err != nil { + if !client.IsErrNotFound(err) { + return fmt.Errorf("could not find docker image '%s': %w", image, err) + } + + pullReader, err := c.cli.ImagePull(ctx, image, types.ImagePullOptions{}) + if err != nil { + return fmt.Errorf("failed to pull docker image '%s': %w", image, err) + } + defer pullReader.Close() + + _, err = io.Copy(io.Discard, pullReader) + if err != nil { + return fmt.Errorf("failed to download docker image '%s': %w", image, err) + } + } + + createResp, err := c.cli.ContainerCreate(ctx, &container.Config{ Image: image, Cmd: []string{ "postgres", @@ -70,7 +148,7 @@ func StartPostgresContainer(t *testing.T, ctx context.Context, logfile string) * }, &container.HostConfig{ PortBindings: nat.PortMap{ "5432/tcp": []nat.PortBinding{ - {HostPort: strconv.FormatUint(uint64(port), 10)}, + {HostPort: strconv.FormatUint(uint64(c.port), 10)}, }, }, }, @@ -78,48 +156,45 @@ func StartPostgresContainer(t *testing.T, ctx context.Context, logfile string) * nil, "", ) - lntest.CheckError(t, err) - err = cli.ContainerStart(ctx, createResp.ID, types.ContainerStartOptions{}) - lntest.CheckError(t, err) - - ct := &PostgresContainer{ - id: createResp.ID, - password: "pgpassword", - port: port, - cli: cli, + if err != nil { + return fmt.Errorf("failed to create docker container: %w", err) } -HealthCheck: - for { - inspect, err := cli.ContainerInspect(ctx, createResp.ID) - lntest.CheckError(t, err) - - status := inspect.State.Health.Status - switch status { - case "unhealthy": - lntest.CheckError(t, errors.New("container unhealthy")) - case "healthy": - for { - pgxPool, err := pgxpool.Connect(context.Background(), ct.ConnectionString()) - if err == nil { - pgxPool.Close() - break HealthCheck - } - - <-time.After(50 * time.Millisecond) - } - default: - <-time.After(200 * time.Millisecond) - } - } - - go ct.monitorLogs(logfile) - return ct + c.id = createResp.ID + c.isInitialized = true + return nil } -func (c *PostgresContainer) monitorLogs(logfile string) { - i, err := c.cli.ContainerLogs(context.Background(), c.id, types.ContainerLogsOptions{ +func (c *PostgresContainer) Stop(ctx context.Context) error { + c.mtx.Lock() + defer c.mtx.Unlock() + + if !c.isStarted { + return nil + } + + defer c.cli.Close() + err := c.cli.ContainerStop(ctx, c.id, nil) + c.isStarted = false + return err +} + +func (c *PostgresContainer) Cleanup(ctx context.Context) error { + c.mtx.Lock() + defer c.mtx.Unlock() + cli, err := client.NewClientWithOpts(client.FromEnv) + if err != nil { + return err + } + defer cli.Close() + return cli.ContainerRemove(ctx, c.id, types.ContainerRemoveOptions{ + Force: true, + }) +} + +func (c *PostgresContainer) monitorLogs(ctx context.Context) { + i, err := c.cli.ContainerLogs(ctx, c.id, types.ContainerLogsOptions{ ShowStderr: true, ShowStdout: true, Timestamps: false, @@ -132,7 +207,7 @@ func (c *PostgresContainer) monitorLogs(logfile string) { } defer i.Close() - file, err := os.OpenFile(logfile, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0600) + file, err := os.OpenFile(c.logfile, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0600) if err != nil { log.Printf("Could not create container log file: %v", err) return @@ -162,39 +237,31 @@ func (c *PostgresContainer) ConnectionString() string { return fmt.Sprintf("postgres://postgres:%s@127.0.0.1:%d/postgres", c.password, c.port) } -func (c *PostgresContainer) Shutdown(ctx context.Context) error { - defer c.cli.Close() - timeout := time.Second - err := c.cli.ContainerStop(ctx, c.id, &timeout) - return err -} - -func (c *PostgresContainer) Cleanup(ctx context.Context) error { - cli, err := client.NewClientWithOpts(client.FromEnv) - if err != nil { - return err - } - defer cli.Close() - return cli.ContainerRemove(ctx, c.id, types.ContainerRemoveOptions{ - Force: true, - }) -} - -func (c *PostgresContainer) RunMigrations(t *testing.T, ctx context.Context, migrationDir string) { +func (c *PostgresContainer) RunMigrations(ctx context.Context, migrationDir string) error { filenames, err := filepath.Glob(filepath.Join(migrationDir, "*.up.sql")) - lntest.CheckError(t, err) + if err != nil { + return fmt.Errorf("failed to glob migration files: %w", err) + } sort.Strings(filenames) - pgxPool, err := pgxpool.Connect(context.Background(), c.ConnectionString()) - lntest.CheckError(t, err) + pgxPool, err := pgxpool.Connect(ctx, c.ConnectionString()) + if err != nil { + return fmt.Errorf("failed to connect to postgres: %w", err) + } defer pgxPool.Close() for _, filename := range filenames { data, err := os.ReadFile(filename) - lntest.CheckError(t, err) + if err != nil { + return fmt.Errorf("failed to read migration file '%s': %w", filename, err) + } _, err = pgxPool.Exec(ctx, string(data)) - lntest.CheckError(t, err) + if err != nil { + return fmt.Errorf("failed to execute migration file '%s': %w", filename, err) + } } + + return nil } diff --git a/itest/test_params.go b/itest/test_params.go index 5343940..d67f708 100644 --- a/itest/test_params.go +++ b/itest/test_params.go @@ -10,7 +10,7 @@ type testParams struct { t *testing.T h *lntest.TestHarness m *lntest.Miner - c *breezClient + c BreezClient lsp LspNode } @@ -30,6 +30,6 @@ func (h *testParams) Harness() *lntest.TestHarness { return h.h } -func (h *testParams) BreezClient() *breezClient { +func (h *testParams) BreezClient() BreezClient { return h.c } diff --git a/itest/zero_reserve_test.go b/itest/zero_reserve_test.go index ea6e8d6..f9f7486 100644 --- a/itest/zero_reserve_test.go +++ b/itest/zero_reserve_test.go @@ -11,6 +11,7 @@ import ( func testZeroReserve(p *testParams) { alice := lntest.NewClnNode(p.h, p.m, "Alice") + alice.Start() alice.Fund(10000000) p.lsp.LightningNode().Fund(10000000) @@ -24,15 +25,16 @@ func testZeroReserve(p *testParams) { outerAmountMsat := uint64(2100000) innerAmountMsat := calculateInnerAmountMsat(p.lsp, outerAmountMsat) description := "Please pay me" - innerInvoice, outerInvoice := p.BreezClient().GenerateInvoices(generateInvoicesRequest{ - innerAmountMsat: innerAmountMsat, - outerAmountMsat: outerAmountMsat, - description: description, - lsp: p.lsp, - }) + innerInvoice, outerInvoice := GenerateInvoices(p.BreezClient(), + generateInvoicesRequest{ + innerAmountMsat: innerAmountMsat, + outerAmountMsat: outerAmountMsat, + description: description, + lsp: p.lsp, + }) log.Print("Connecting bob to lspd") - p.BreezClient().lightningNode.ConnectPeer(p.lsp.LightningNode()) + p.BreezClient().Node().ConnectPeer(p.lsp.LightningNode()) // NOTE: We pretend to be paying fees to the lsp, but actually we won't. log.Printf("Registering payment with lsp") @@ -40,7 +42,7 @@ func testZeroReserve(p *testParams) { RegisterPayment(p.lsp, &lspd.PaymentInformation{ PaymentHash: innerInvoice.paymentHash, PaymentSecret: innerInvoice.paymentSecret, - Destination: p.BreezClient().lightningNode.NodeId(), + Destination: p.BreezClient().Node().NodeId(), IncomingAmountMsat: int64(outerAmountMsat), OutgoingAmountMsat: int64(pretendAmount), }) @@ -49,11 +51,11 @@ func testZeroReserve(p *testParams) { log.Printf("Waiting %v to allow htlc interceptor to activate.", htlcInterceptorDelay) <-time.After(htlcInterceptorDelay) log.Printf("Alice paying") - route := constructRoute(p.lsp.LightningNode(), p.BreezClient().lightningNode, channelId, lntest.NewShortChanIDFromString("1x0x0"), outerAmountMsat) + route := constructRoute(p.lsp.LightningNode(), p.BreezClient().Node(), channelId, lntest.NewShortChanIDFromString("1x0x0"), outerAmountMsat) alice.PayViaRoute(outerAmountMsat, outerInvoice.paymentHash, outerInvoice.paymentSecret, route) // Make sure balance is correct - chans := p.BreezClient().lightningNode.GetChannels() + chans := p.BreezClient().Node().GetChannels() assert.Len(p.t, chans, 1) c := chans[0]