diff --git a/contractcourt/chain_arbitrator.go b/contractcourt/chain_arbitrator.go index 6d7c674a..1c18253e 100644 --- a/contractcourt/chain_arbitrator.go +++ b/contractcourt/chain_arbitrator.go @@ -296,7 +296,7 @@ func newActiveChannelArbitrator(channel *channeldb.OpenChannel, } arbCfg.MarkChannelResolved = func() error { - return c.resolveContract(chanPoint, chanLog) + return c.ResolveContract(chanPoint) } // Finally, we'll need to construct a series of htlc Sets based on all @@ -321,11 +321,10 @@ func newActiveChannelArbitrator(channel *channeldb.OpenChannel, ), nil } -// resolveContract marks a contract as fully resolved within the database. +// ResolveContract marks a contract as fully resolved within the database. // This is only to be done once all contracts which were live on the channel // before hitting the chain have been resolved. -func (c *ChainArbitrator) resolveContract(chanPoint wire.OutPoint, - arbLog ArbitratorLog) error { +func (c *ChainArbitrator) ResolveContract(chanPoint wire.OutPoint) error { log.Infof("Marking ChannelPoint(%v) fully resolved", chanPoint) @@ -338,27 +337,44 @@ func (c *ChainArbitrator) resolveContract(chanPoint wire.OutPoint, return err } + // Now that the channel has been marked as fully closed, we'll stop + // both the channel arbitrator and chain watcher for this channel if + // they're still active. + var arbLog ArbitratorLog + c.Lock() + chainArb := c.activeChannels[chanPoint] + delete(c.activeChannels, chanPoint) + + chainWatcher := c.activeWatchers[chanPoint] + delete(c.activeWatchers, chanPoint) + c.Unlock() + + if chainArb != nil { + arbLog = chainArb.log + + if err := chainArb.Stop(); err != nil { + log.Warnf("unable to stop ChannelArbitrator(%v): %v", + chanPoint, err) + } + } + if chainWatcher != nil { + if err := chainWatcher.Stop(); err != nil { + log.Warnf("unable to stop ChainWatcher(%v): %v", + chanPoint, err) + } + } + + // Once this has been marked as resolved, we'll wipe the log that the + // channel arbitrator was using to store its persistent state. We do + // this after marking the channel resolved, as otherwise, the + // arbitrator would be re-created, and think it was starting from the + // default state. if arbLog != nil { - // Once this has been marked as resolved, we'll wipe the log - // that the channel arbitrator was using to store its - // persistent state. We do this after marking the channel - // resolved, as otherwise, the arbitrator would be re-created, - // and think it was starting from the default state. if err := arbLog.WipeHistory(); err != nil { return err } } - c.Lock() - delete(c.activeChannels, chanPoint) - - chainWatcher, ok := c.activeWatchers[chanPoint] - if ok { - chainWatcher.Stop() - } - delete(c.activeWatchers, chanPoint) - c.Unlock() - return nil } @@ -491,7 +507,7 @@ func (c *ChainArbitrator) Start() error { return err } arbCfg.MarkChannelResolved = func() error { - return c.resolveContract(chanPoint, chanLog) + return c.ResolveContract(chanPoint) } // We can also leave off the set of HTLC's here as since the diff --git a/contractcourt/chain_arbitrator_test.go b/contractcourt/chain_arbitrator_test.go index 38ea2a35..28682c92 100644 --- a/contractcourt/chain_arbitrator_test.go +++ b/contractcourt/chain_arbitrator_test.go @@ -115,3 +115,105 @@ func TestChainArbitratorRepublishCommitment(t *testing.T) { t.Fatalf("unexpected tx published") } } + +// TestResolveContract tests that if we have an active channel being watched by +// the chain arb, then a call to ResolveContract will mark the channel as fully +// closed in the database, and also clean up all arbitrator state. +func TestResolveContract(t *testing.T) { + t.Parallel() + + // To start with, we'll create a new temp DB for the duration of this + // test. + tempPath, err := ioutil.TempDir("", "testdb") + if err != nil { + t.Fatalf("unable to make temp dir: %v", err) + } + defer os.RemoveAll(tempPath) + db, err := channeldb.Open(tempPath) + if err != nil { + t.Fatalf("unable to open db: %v", err) + } + defer db.Close() + + // With the DB created, we'll make a new channel, and mark it as + // pending open within the database. + newChannel, _, cleanup, err := lnwallet.CreateTestChannels(true) + if err != nil { + t.Fatalf("unable to make new test channel: %v", err) + } + defer cleanup() + channel := newChannel.State() + channel.Db = db + addr := &net.TCPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: 18556, + } + if err := channel.SyncPending(addr, 101); err != nil { + t.Fatalf("unable to write channel to db: %v", err) + } + + // With the channel inserted into the database, we'll now create a new + // chain arbitrator that should pick up these new channels and launch + // resolver for them. + chainArbCfg := ChainArbitratorConfig{ + ChainIO: &mockChainIO{}, + Notifier: &mockNotifier{}, + PublishTx: func(tx *wire.MsgTx) error { + return nil + }, + } + chainArb := NewChainArbitrator( + chainArbCfg, db, + ) + if err := chainArb.Start(); err != nil { + t.Fatal(err) + } + defer func() { + if err := chainArb.Stop(); err != nil { + t.Fatal(err) + } + }() + + channelArb := chainArb.activeChannels[channel.FundingOutpoint] + + // While the resolver are active, we'll now remove the channel from the + // database (mark is as closed). + err = db.AbandonChannel(&channel.FundingOutpoint, 4) + if err != nil { + t.Fatalf("unable to remove channel: %v", err) + } + + // With the channel removed, we'll now manually call ResolveContract. + // This stimulates needing to remove a channel from the chain arb due + // to any possible external consistency issues. + err = chainArb.ResolveContract(channel.FundingOutpoint) + if err != nil { + t.Fatalf("unable to resolve contract: %v", err) + } + + // The shouldn't be an active chain watcher or channel arb for this + // channel. + if len(chainArb.activeChannels) != 0 { + t.Fatalf("expected zero active channels, instead have %v", + len(chainArb.activeChannels)) + } + if len(chainArb.activeWatchers) != 0 { + t.Fatalf("expected zero active watchers, instead have %v", + len(chainArb.activeWatchers)) + } + + // At this point, the channel's arbitrator log should also be empty as + // well. + _, err = channelArb.log.FetchContractResolutions() + if err != errScopeBucketNoExist { + t.Fatalf("channel arb log state should have been "+ + "removed: %v", err) + } + + // If we attempt to call this method again, then we should get a nil + // error, as there is no more state to be cleaned up. + err = chainArb.ResolveContract(channel.FundingOutpoint) + if err != nil { + t.Fatalf("second resolve call shouldn't fail: %v", err) + } +}