From 6e673d2be45a5cbe3142b8fcc744f886e2be9a49 Mon Sep 17 00:00:00 2001 From: Jesse de Wit Date: Fri, 23 Dec 2022 18:23:38 +0100 Subject: [PATCH] more efficient cleanup, lessons learned --- forwarding_history.go | 32 ++++++++++++++++++++------- itest/lnd_lspd_node.go | 19 ++++++++++------ lnd_interceptor.go | 49 ++++++++++++++++++++++++++---------------- main.go | 11 ++++++++++ 4 files changed, 79 insertions(+), 32 deletions(-) diff --git a/forwarding_history.go b/forwarding_history.go index 56c566d..e45166d 100644 --- a/forwarding_history.go +++ b/forwarding_history.go @@ -36,19 +36,25 @@ func (cfe *copyFromEvents) Err() error { return cfe.err } -func channelsSynchronize(client *LndClient) { +func channelsSynchronize(ctx context.Context, client *LndClient) { lastSync := time.Now().Add(-6 * time.Minute) for { - cancellableCtx, cancel := context.WithCancel(context.Background()) - stream, err := client.chainNotifierClient.RegisterBlockEpochNtfn(cancellableCtx, &chainrpc.BlockEpoch{}) + if ctx.Err() != nil { + return + } + + stream, err := client.chainNotifierClient.RegisterBlockEpochNtfn(ctx, &chainrpc.BlockEpoch{}) if err != nil { log.Printf("chainNotifierClient.RegisterBlockEpochNtfn(): %v", err) - cancel() <-time.After(time.Second) continue } for { + if ctx.Err() != nil { + return + } + _, err := stream.Recv() if err != nil { log.Printf("stream.Recv: %v", err) @@ -56,13 +62,16 @@ func channelsSynchronize(client *LndClient) { break } if lastSync.Add(5 * time.Minute).Before(time.Now()) { - <-time.After(30 * time.Second) + select { + case <-ctx.Done(): + return + case <-time.After(1 * time.Minute): + } err = channelsSynchronizeOnce(client) lastSync = time.Now() log.Printf("channelsSynchronizeOnce() err: %v", err) } } - cancel() } } @@ -99,11 +108,18 @@ func channelsSynchronizeOnce(client *LndClient) error { return nil } -func forwardingHistorySynchronize(client *LndClient) { +func forwardingHistorySynchronize(ctx context.Context, client *LndClient) { for { + if ctx.Err() != nil { + return + } + err := forwardingHistorySynchronizeOnce(client) log.Printf("forwardingHistorySynchronizeOnce() err: %v", err) - <-time.After(1 * time.Minute) + select { + case <-time.After(1 * time.Minute): + case <-ctx.Done(): + } } } diff --git a/itest/lnd_lspd_node.go b/itest/lnd_lspd_node.go index 7fa3385..e096361 100644 --- a/itest/lnd_lspd_node.go +++ b/itest/lnd_lspd_node.go @@ -6,7 +6,6 @@ import ( "os" "os/exec" "path/filepath" - "runtime" "strings" "sync" @@ -153,12 +152,20 @@ func (c *LndLspNode) Start() { Name: fmt.Sprintf("%s: cmd", c.lspBase.name), Fn: func() error { proc := cmd.Process - if proc != nil { - if runtime.GOOS == "windows" { - return proc.Signal(os.Kill) - } + if proc == nil { + return nil + } - return proc.Signal(os.Interrupt) + proc.Kill() + + log.Printf("About to wait for lspd to exit") + status, err := proc.Wait() + if err != nil { + log.Printf("waiting for lspd process error: %v, status: %v", err, status) + } + err = cmd.Wait() + if err != nil { + log.Printf("waiting for lspd cmd error: %v", err) } return nil diff --git a/lnd_interceptor.go b/lnd_interceptor.go index 85f2ee0..a3775c4 100644 --- a/lnd_interceptor.go +++ b/lnd_interceptor.go @@ -14,10 +14,11 @@ import ( ) type LndHtlcInterceptor struct { - client *LndClient - stopRequested bool - initWg sync.WaitGroup - doneWg sync.WaitGroup + client *LndClient + initWg sync.WaitGroup + doneWg sync.WaitGroup + ctx context.Context + cancel context.CancelFunc } func NewLndHtlcInterceptor() *LndHtlcInterceptor { @@ -31,14 +32,18 @@ func NewLndHtlcInterceptor() *LndHtlcInterceptor { } func (i *LndHtlcInterceptor) Start() error { - go forwardingHistorySynchronize(i.client) - go channelsSynchronize(i.client) - i.initWg.Done() + ctx, cancel := context.WithCancel(context.Background()) + i.ctx = ctx + i.cancel = cancel + go forwardingHistorySynchronize(ctx, i.client) + go channelsSynchronize(ctx, i.client) + return i.intercept() } func (i *LndHtlcInterceptor) Stop() error { - i.stopRequested = true + i.cancel() + i.doneWg.Wait() return nil } @@ -48,44 +53,53 @@ func (i *LndHtlcInterceptor) WaitStarted() LightningClient { } func (i *LndHtlcInterceptor) intercept() error { + inited := false defer func() { + if !inited { + i.initWg.Done() + } log.Printf("LND intercept(): stopping. Waiting for in-progress interceptions to complete.") i.doneWg.Wait() }() for { - if i.stopRequested { - return nil + if i.ctx.Err() != nil { + return i.ctx.Err() } log.Printf("Connecting LND HTLC interceptor.") - cancellableCtx, cancel := context.WithCancel(context.Background()) - interceptorClient, err := i.client.routerClient.HtlcInterceptor(cancellableCtx) + interceptorClient, err := i.client.routerClient.HtlcInterceptor(i.ctx) if err != nil { log.Printf("routerClient.HtlcInterceptor(): %v", err) - cancel() <-time.After(time.Second) continue } for { - if i.stopRequested { - cancel() - return nil + if i.ctx.Err() != nil { + return i.ctx.Err() } + + if !inited { + inited = true + i.initWg.Done() + } + request, err := interceptorClient.Recv() if err != nil { // If it is just the error result of the context cancellation // the we exit silently. status, ok := status.FromError(err) if ok && status.Code() == codes.Canceled { + log.Printf("Got code canceled. Break.") break } + // Otherwise it an unexpected error, we fail the test. log.Printf("unexpected error in interceptor.Recv() %v", err) - cancel() break } + fmt.Printf("htlc: %v\nchanID: %v\nincoming amount: %v\noutgoing amount: %v\nincomin expiry: %v\noutgoing expiry: %v\npaymentHash: %x\nonionBlob: %x\n\n", request.IncomingCircuitKey.HtlcId, request.IncomingCircuitKey.ChanId, @@ -131,7 +145,6 @@ func (i *LndHtlcInterceptor) intercept() error { }() } - cancel() <-time.After(time.Second) } } diff --git a/main.go b/main.go index e447194..208a6a2 100644 --- a/main.go +++ b/main.go @@ -4,7 +4,9 @@ import ( "fmt" "log" "os" + "os/signal" "sync" + "syscall" "github.com/btcsuite/btcd/btcec/v2" ) @@ -86,6 +88,15 @@ func main() { wg.Done() }() + c := make(chan os.Signal, 1) + signal.Notify(c, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT) + go func() { + sig := <-c + log.Printf("Received stop signal %v. Stopping.", sig) + s.Stop() + interceptor.Stop() + }() + wg.Wait() log.Printf("lspd exited") }