diff --git a/contractcourt/chain_arbitrator.go b/contractcourt/chain_arbitrator.go index 696ddeb2..6100758a 100644 --- a/contractcourt/chain_arbitrator.go +++ b/contractcourt/chain_arbitrator.go @@ -375,6 +375,7 @@ func (c *ChainArbitrator) Start() error { contractBreach: func(retInfo *lnwallet.BreachRetribution) error { return c.cfg.ContractBreach(chanPoint, retInfo) }, + extractStateNumHint: lnwallet.GetStateNumHint, }, ) if err != nil { @@ -710,6 +711,7 @@ func (c *ChainArbitrator) WatchNewChannel(newChan *channeldb.OpenChannel) error contractBreach: func(retInfo *lnwallet.BreachRetribution) error { return c.cfg.ContractBreach(chanPoint, retInfo) }, + extractStateNumHint: lnwallet.GetStateNumHint, }, ) if err != nil { diff --git a/contractcourt/chain_watcher.go b/contractcourt/chain_watcher.go index 719b2068..ec60c847 100644 --- a/contractcourt/chain_watcher.go +++ b/contractcourt/chain_watcher.go @@ -105,6 +105,11 @@ type chainWatcherConfig struct { // isOurAddr is a function that returns true if the passed address is // known to us. isOurAddr func(btcutil.Address) bool + + // extractStateNumHint extracts the encoded state hint using the passed + // obfuscater. This is used by the chain watcher to identify which + // state was broadcast and confirmed on-chain. + extractStateNumHint func(*wire.MsgTx, [lnwallet.StateHintSize]byte) uint64 } // chainWatcher is a system that's assigned to every active channel. The duty @@ -350,10 +355,9 @@ func (c *chainWatcher) closeObserver(spendNtfn *chainntnfs.SpendEvent) { "ChannelPoint(%v) ", c.cfg.chanState.FundingOutpoint) // Decode the state hint encoded within the commitment - // transaction to determine if this is a revoked state - // or not. + // transaction to determine if this is a revoked state or not. obfuscator := c.stateHintObfuscator - broadcastStateNum := lnwallet.GetStateNumHint( + broadcastStateNum := c.cfg.extractStateNumHint( commitTxBroadcast, obfuscator, ) remoteStateNum := remoteCommit.CommitHeight @@ -402,11 +406,12 @@ func (c *chainWatcher) closeObserver(spendNtfn *chainntnfs.SpendEvent) { c.cfg.chanState.FundingOutpoint, err) } - // This is the case that somehow the commitment broadcast is - // actually greater than even one beyond our best known state - // number. This should ONLY happen in case we experienced some - // sort of data loss. - case broadcastStateNum > remoteStateNum+1: + // If the remote party has broadcasted a state beyond our best + // known state for them, and they don't have a pending + // commitment (we write them to disk before sending out), then + // this means that we've lost data. In this case, we'll enter + // the DLP protocol. + case broadcastStateNum > remoteStateNum: log.Warnf("Remote node broadcast state #%v, "+ "which is more than 1 beyond best known "+ "state #%v!!! Attempting recovery...", @@ -418,6 +423,7 @@ func (c *chainWatcher) closeObserver(spendNtfn *chainntnfs.SpendEvent) { // point, there's not much we can do other than wait // for us to retrieve it. We will attempt to retrieve // it from the peer each time we connect to it. + // // TODO(halseth): actively initiate re-connection to // the peer? var commitPoint *btcec.PublicKey @@ -458,6 +464,7 @@ func (c *chainWatcher) closeObserver(spendNtfn *chainntnfs.SpendEvent) { // state, we'll just pass an empty commitment. Note // that this means we won't be able to recover any HTLC // funds. + // // TODO(halseth): can we try to recover some HTLCs? err = c.dispatchRemoteForceClose( commitSpend, channeldb.ChannelCommitment{}, diff --git a/contractcourt/chain_watcher_test.go b/contractcourt/chain_watcher_test.go index 8abece49..341c7706 100644 --- a/contractcourt/chain_watcher_test.go +++ b/contractcourt/chain_watcher_test.go @@ -3,12 +3,18 @@ package contractcourt import ( "bytes" "crypto/sha256" + "math" + "math/rand" + "reflect" "testing" + "testing/quick" "time" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/chainntnfs" + "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" ) @@ -71,9 +77,10 @@ func TestChainWatcherRemoteUnilateralClose(t *testing.T) { spendChan: make(chan *chainntnfs.SpendDetail), } aliceChainWatcher, err := newChainWatcher(chainWatcherConfig{ - chanState: aliceChannel.State(), - notifier: aliceNotifier, - signer: aliceChannel.Signer, + chanState: aliceChannel.State(), + notifier: aliceNotifier, + signer: aliceChannel.Signer, + extractStateNumHint: lnwallet.GetStateNumHint, }) if err != nil { t.Fatalf("unable to create chain watcher: %v", err) @@ -114,6 +121,28 @@ func TestChainWatcherRemoteUnilateralClose(t *testing.T) { } } +func addFakeHTLC(t *testing.T, htlcAmount lnwire.MilliSatoshi, id uint64, + aliceChannel, bobChannel *lnwallet.LightningChannel) { + + preimage := bytes.Repeat([]byte{byte(id)}, 32) + paymentHash := sha256.Sum256(preimage) + var returnPreimage [32]byte + copy(returnPreimage[:], preimage) + htlc := &lnwire.UpdateAddHTLC{ + ID: uint64(id), + PaymentHash: paymentHash, + Amount: htlcAmount, + Expiry: uint32(5), + } + + if _, err := aliceChannel.AddHTLC(htlc, nil); err != nil { + t.Fatalf("alice unable to add htlc: %v", err) + } + if _, err := bobChannel.ReceiveHTLC(htlc); err != nil { + t.Fatalf("bob unable to recv add htlc: %v", err) + } +} + // TestChainWatcherRemoteUnilateralClosePendingCommit tests that the chain // watcher is able to properly detect a unilateral close wherein the remote // node broadcasts their newly received commitment, without first revoking the @@ -135,9 +164,10 @@ func TestChainWatcherRemoteUnilateralClosePendingCommit(t *testing.T) { spendChan: make(chan *chainntnfs.SpendDetail), } aliceChainWatcher, err := newChainWatcher(chainWatcherConfig{ - chanState: aliceChannel.State(), - notifier: aliceNotifier, - signer: aliceChannel.Signer, + chanState: aliceChannel.State(), + notifier: aliceNotifier, + signer: aliceChannel.Signer, + extractStateNumHint: lnwallet.GetStateNumHint, }) if err != nil { t.Fatalf("unable to create chain watcher: %v", err) @@ -155,23 +185,7 @@ func TestChainWatcherRemoteUnilateralClosePendingCommit(t *testing.T) { // channel state to a new pending commitment on her remote commit chain // for Bob. htlcAmount := lnwire.NewMSatFromSatoshis(20000) - preimage := bytes.Repeat([]byte{byte(1)}, 32) - paymentHash := sha256.Sum256(preimage) - var returnPreimage [32]byte - copy(returnPreimage[:], preimage) - htlc := &lnwire.UpdateAddHTLC{ - ID: uint64(0), - PaymentHash: paymentHash, - Amount: htlcAmount, - Expiry: uint32(5), - } - - if _, err := aliceChannel.AddHTLC(htlc, nil); err != nil { - t.Fatalf("alice unable to add htlc: %v", err) - } - if _, err := bobChannel.ReceiveHTLC(htlc); err != nil { - t.Fatalf("bob unable to recv add htlc: %v", err) - } + addFakeHTLC(t, htlcAmount, 0, aliceChannel, bobChannel) // With the HTLC added, we'll now manually initiate a state transition // from Alice to Bob. @@ -213,3 +227,185 @@ func TestChainWatcherRemoteUnilateralClosePendingCommit(t *testing.T) { t.Fatalf("unable to find alice's commit resolution") } } + +// dlpTestCase is a speical struct that we'll use to generate randomized test +// cases for the main TestChainWatcherDataLossProtect test. This struct has a +// special Generate method that will generate a random state number, and a +// broadcast state number which is greater than that state number. +type dlpTestCase struct { + BroadcastStateNum uint8 + NumUpdates uint8 +} + +// TestChainWatcherDataLossProtect tests that if we've lost data (and are +// behind the remote node), then we'll properly detect this case and dispatch a +// remote force close using the obtained data loss commitment point. +func TestChainWatcherDataLossProtect(t *testing.T) { + t.Parallel() + + // dlpScenario is our primary quick check testing function for this + // test as whole. It ensures that if the remote party broadcasts a + // commitment that is beyond our best known commitment for them, and + // they don't have a pending commitment (one we sent but which hasn't + // been revoked), then we'll properly detect this case, and execute the + // DLP protocol on our end. + // + // broadcastStateNum is the number that we'll trick Alice into thinking + // was broadcast, while numUpdates is the actual number of updates + // we'll execute. Both of these will be random 8-bit values generated + // by testing/quick. + dlpScenario := func(testCase dlpTestCase) bool { + // First, we'll create two channels which already have + // established a commitment contract between themselves. + aliceChannel, bobChannel, cleanUp, err := lnwallet.CreateTestChannels() + if err != nil { + t.Fatalf("unable to create test channels: %v", err) + } + defer cleanUp() + + // With the channels created, we'll now create a chain watcher + // instance which will be watching for any closes of Alice's + // channel. + aliceNotifier := &mockNotifier{ + spendChan: make(chan *chainntnfs.SpendDetail), + } + aliceChainWatcher, err := newChainWatcher(chainWatcherConfig{ + chanState: aliceChannel.State(), + notifier: aliceNotifier, + signer: aliceChannel.Signer, + extractStateNumHint: func(*wire.MsgTx, + [lnwallet.StateHintSize]byte) uint64 { + + // We'll return the "fake" broadcast commitment + // number so we can simulate broadcast of an + // arbitrary state. + return uint64(testCase.BroadcastStateNum) + }, + }) + if err != nil { + t.Fatalf("unable to create chain watcher: %v", err) + } + if err := aliceChainWatcher.Start(); err != nil { + t.Fatalf("unable to start chain watcher: %v", err) + } + defer aliceChainWatcher.Stop() + + // Based on the number of random updates for this state, make a + // new HTLC to add to the commitment, and then lock in a state + // transition. + const htlcAmt = 1000 + for i := 0; i < int(testCase.NumUpdates); i++ { + addFakeHTLC( + t, 1000, uint64(i), aliceChannel, bobChannel, + ) + + err := lnwallet.ForceStateTransition( + aliceChannel, bobChannel, + ) + if err != nil { + t.Errorf("unable to trigger state "+ + "transition: %v", err) + return false + } + } + + // We'll request a new channel event subscription from Alice's + // chain watcher so we can be notified of our fake close below. + chanEvents := aliceChainWatcher.SubscribeChannelEvents() + + // Otherwise, we'll feed in this new state number as a response + // to the query, and insert the expected DLP commit point. + dlpPoint := aliceChannel.State().RemoteCurrentRevocation + err = aliceChannel.State().MarkDataLoss(dlpPoint) + if err != nil { + t.Errorf("unable to insert dlp point: %v", err) + return false + } + + // Now we'll trigger the channel close event to trigger the + // scenario. + bobCommit := bobChannel.State().LocalCommitment.CommitTx + bobTxHash := bobCommit.TxHash() + bobSpend := &chainntnfs.SpendDetail{ + SpenderTxHash: &bobTxHash, + SpendingTx: bobCommit, + } + aliceNotifier.spendChan <- bobSpend + + // We should get a new uni close resolution that indicates we + // processed the DLP scenario. + var uniClose *lnwallet.UnilateralCloseSummary + select { + case uniClose = <-chanEvents.RemoteUnilateralClosure: + // If we processed this as a DLP case, then the remote + // party's commitment should be blank, as we don't have + // this up to date state. + blankCommit := channeldb.ChannelCommitment{} + if uniClose.RemoteCommit.FeePerKw != blankCommit.FeePerKw { + t.Errorf("DLP path not executed") + return false + } + + // The resolution should have also read the DLP point + // we stored above, and used that to derive their sweep + // key for this output. + sweepTweak := input.SingleTweakBytes( + dlpPoint, + aliceChannel.State().LocalChanCfg.PaymentBasePoint.PubKey, + ) + commitResolution := uniClose.CommitResolution + resolutionTweak := commitResolution.SelfOutputSignDesc.SingleTweak + if !bytes.Equal(sweepTweak, resolutionTweak) { + t.Errorf("sweep key mismatch: expected %x got %x", + sweepTweak, resolutionTweak) + return false + } + + return true + + case <-time.After(time.Second * 5): + t.Errorf("didn't receive unilateral close event") + return false + } + } + + // For our first scenario, we'll ensure that if we're on state 1, and + // the remote party broadcasts state 2 and we don't have a pending + // commit for them, then we'll properly detect this as a DLP scenario. + if !dlpScenario(dlpTestCase{ + BroadcastStateNum: 2, + NumUpdates: 1, + }) { + t.Fatalf("DLP test case failed at state 1!") + } + + // For the remainder of the tests, we'll perform 10 iterations with + // random values. We limit this number as set up of each test can take + // time, and also it doing up to 255 state transitions may cause the + // test to hang for a long time. + // + // TODO(roasbeef): speed up execution + err := quick.Check(dlpScenario, &quick.Config{ + MaxCount: 10, + Values: func(v []reflect.Value, rand *rand.Rand) { + // stateNum will be the random number of state updates + // we execute during the scenario. + stateNum := uint8(rand.Int31()) + + // From this state number, we'll draw a random number + // between the state and 255, ensuring that it' at + // least one state beyond the target stateNum. + broadcastRange := rand.Int31n(int32(math.MaxUint8 - stateNum)) + broadcastNum := uint8(stateNum + 1 + uint8(broadcastRange)) + + testCase := dlpTestCase{ + BroadcastStateNum: broadcastNum, + NumUpdates: stateNum, + } + v[0] = reflect.ValueOf(testCase) + }, + }) + if err != nil { + t.Fatalf("DLP test case failed: %v", err) + } +} diff --git a/lnwallet/channel_test.go b/lnwallet/channel_test.go index 5616b0fc..1093ff04 100644 --- a/lnwallet/channel_test.go +++ b/lnwallet/channel_test.go @@ -23,45 +23,6 @@ import ( "github.com/lightningnetwork/lnd/lnwire" ) -// forceStateTransition executes the necessary interaction between the two -// commitment state machines to transition to a new state locking in any -// pending updates. -func forceStateTransition(chanA, chanB *LightningChannel) error { - aliceSig, aliceHtlcSigs, err := chanA.SignNextCommitment() - if err != nil { - return err - } - if err = chanB.ReceiveNewCommitment(aliceSig, aliceHtlcSigs); err != nil { - return err - } - - bobRevocation, _, err := chanB.RevokeCurrentCommitment() - if err != nil { - return err - } - bobSig, bobHtlcSigs, err := chanB.SignNextCommitment() - if err != nil { - return err - } - - if _, _, _, err := chanA.ReceiveRevocation(bobRevocation); err != nil { - return err - } - if err := chanA.ReceiveNewCommitment(bobSig, bobHtlcSigs); err != nil { - return err - } - - aliceRevocation, _, err := chanA.RevokeCurrentCommitment() - if err != nil { - return err - } - if _, _, _, err := chanB.ReceiveRevocation(aliceRevocation); err != nil { - return err - } - - return nil -} - // createHTLC is a utility function for generating an HTLC with a given // preimage and a given amount. func createHTLC(id int, amount lnwire.MilliSatoshi) (*lnwire.UpdateAddHTLC, [32]byte) { @@ -440,7 +401,7 @@ func TestCheckCommitTxSize(t *testing.T) { t.Fatalf("bob unable to receive htlc: %v", err) } - if err := forceStateTransition(aliceChannel, bobChannel); err != nil { + if err := ForceStateTransition(aliceChannel, bobChannel); err != nil { t.Fatalf("unable to complete state update: %v", err) } checkSize(aliceChannel, i+1) @@ -462,7 +423,7 @@ func TestCheckCommitTxSize(t *testing.T) { t.Fatalf("alice unable to accept settle of outbound htlc: %v", err) } - if err := forceStateTransition(bobChannel, aliceChannel); err != nil { + if err := ForceStateTransition(bobChannel, aliceChannel); err != nil { t.Fatalf("unable to complete state update: %v", err) } checkSize(aliceChannel, i) @@ -576,10 +537,10 @@ func TestForceClose(t *testing.T) { // Next, we'll perform two state transitions to ensure that both HTLC's // get fully locked-in. - if err := forceStateTransition(aliceChannel, bobChannel); err != nil { + if err := ForceStateTransition(aliceChannel, bobChannel); err != nil { t.Fatalf("Can't update the channel state: %v", err) } - if err := forceStateTransition(bobChannel, aliceChannel); err != nil { + if err := ForceStateTransition(bobChannel, aliceChannel); err != nil { t.Fatalf("Can't update the channel state: %v", err) } @@ -862,7 +823,7 @@ func TestForceCloseDustOutput(t *testing.T) { if err != nil { t.Fatalf("bob unable to receive htlc: %v", err) } - if err := forceStateTransition(bobChannel, aliceChannel); err != nil { + if err := ForceStateTransition(bobChannel, aliceChannel); err != nil { t.Fatalf("Can't update the channel state: %v", err) } @@ -875,7 +836,7 @@ func TestForceCloseDustOutput(t *testing.T) { if err != nil { t.Fatalf("alice unable to accept settle of outbound htlc: %v", err) } - if err := forceStateTransition(aliceChannel, bobChannel); err != nil { + if err := ForceStateTransition(aliceChannel, bobChannel); err != nil { t.Fatalf("Can't update the channel state: %v", err) } @@ -967,7 +928,7 @@ func TestDustHTLCFees(t *testing.T) { if _, err := bobChannel.ReceiveHTLC(htlc); err != nil { t.Fatalf("bob unable to receive htlc: %v", err) } - if err := forceStateTransition(aliceChannel, bobChannel); err != nil { + if err := ForceStateTransition(aliceChannel, bobChannel); err != nil { t.Fatalf("Can't update the channel state: %v", err) } @@ -1048,7 +1009,7 @@ func TestHTLCDustLimit(t *testing.T) { if err != nil { t.Fatalf("bob unable to receive htlc: %v", err) } - if err := forceStateTransition(aliceChannel, bobChannel); err != nil { + if err := ForceStateTransition(aliceChannel, bobChannel); err != nil { t.Fatalf("Can't update the channel state: %v", err) } @@ -1082,7 +1043,7 @@ func TestHTLCDustLimit(t *testing.T) { if err != nil { t.Fatalf("alice unable to accept settle of outbound htlc: %v", err) } - if err := forceStateTransition(bobChannel, aliceChannel); err != nil { + if err := ForceStateTransition(bobChannel, aliceChannel); err != nil { t.Fatalf("state transition error: %v", err) } @@ -1314,7 +1275,7 @@ func TestChannelBalanceDustLimit(t *testing.T) { if err != nil { t.Fatalf("bob unable to receive htlc: %v", err) } - if err := forceStateTransition(aliceChannel, bobChannel); err != nil { + if err := ForceStateTransition(aliceChannel, bobChannel); err != nil { t.Fatalf("state transition error: %v", err) } err = bobChannel.SettleHTLC(preimage, bobHtlcIndex, nil, nil, nil) @@ -1325,7 +1286,7 @@ func TestChannelBalanceDustLimit(t *testing.T) { if err != nil { t.Fatalf("alice unable to accept settle of outbound htlc: %v", err) } - if err := forceStateTransition(bobChannel, aliceChannel); err != nil { + if err := ForceStateTransition(bobChannel, aliceChannel); err != nil { t.Fatalf("state transition error: %v", err) } @@ -1439,7 +1400,7 @@ func TestStateUpdatePersistence(t *testing.T) { // Next, Alice initiates a state transition to include the HTLC's she // added above in a new commitment state. - if err := forceStateTransition(aliceChannel, bobChannel); err != nil { + if err := ForceStateTransition(aliceChannel, bobChannel); err != nil { t.Fatalf("unable to complete alice's state transition: %v", err) } @@ -1447,7 +1408,7 @@ func TestStateUpdatePersistence(t *testing.T) { // commitment transaction (but it was in Alice's, as he ACK'd her // changes before creating a new state), Bob needs to trigger another // state update in order to re-sync their states. - if err := forceStateTransition(bobChannel, aliceChannel); err != nil { + if err := ForceStateTransition(bobChannel, aliceChannel); err != nil { t.Fatalf("unable to complete bob's state transition: %v", err) } @@ -1642,10 +1603,10 @@ func TestStateUpdatePersistence(t *testing.T) { // entries to the update log before a state transition was initiated by // either side, both sides are required to trigger an update in order // to lock in their changes. - if err := forceStateTransition(aliceChannelNew, bobChannelNew); err != nil { + if err := ForceStateTransition(aliceChannelNew, bobChannelNew); err != nil { t.Fatalf("unable to update commitments: %v", err) } - if err := forceStateTransition(bobChannelNew, aliceChannelNew); err != nil { + if err := ForceStateTransition(bobChannelNew, aliceChannelNew); err != nil { t.Fatalf("unable to update commitments: %v", err) } @@ -1720,7 +1681,7 @@ func TestCancelHTLC(t *testing.T) { if err != nil { t.Fatalf("unable to add bob htlc: %v", err) } - if err := forceStateTransition(aliceChannel, bobChannel); err != nil { + if err := ForceStateTransition(aliceChannel, bobChannel); err != nil { t.Fatalf("unable to create new commitment state: %v", err) } @@ -1748,7 +1709,7 @@ func TestCancelHTLC(t *testing.T) { // Now trigger another state transition, the HTLC should now be removed // from both sides, with balances reflected. - if err := forceStateTransition(bobChannel, aliceChannel); err != nil { + if err := ForceStateTransition(bobChannel, aliceChannel); err != nil { t.Fatalf("unable to create new commitment: %v", err) } @@ -1992,7 +1953,7 @@ func TestUpdateFeeAdjustments(t *testing.T) { // With the fee updates applied, we'll now initiate a state transition // to ensure the fee update is locked in. - if err := forceStateTransition(aliceChannel, bobChannel); err != nil { + if err := ForceStateTransition(aliceChannel, bobChannel); err != nil { t.Fatalf("unable to create new commitment: %v", err) } @@ -2014,7 +1975,7 @@ func TestUpdateFeeAdjustments(t *testing.T) { if err := bobChannel.ReceiveUpdateFee(newFee); err != nil { t.Fatalf("unable to bob update fee: %v", err) } - if err := forceStateTransition(aliceChannel, bobChannel); err != nil { + if err := ForceStateTransition(aliceChannel, bobChannel); err != nil { t.Fatalf("unable to create new commitment: %v", err) } } @@ -2642,7 +2603,7 @@ func TestChanSyncFullySynced(t *testing.T) { } // Then we'll initiate a state transition to lock in this new HTLC. - if err := forceStateTransition(aliceChannel, bobChannel); err != nil { + if err := ForceStateTransition(aliceChannel, bobChannel); err != nil { t.Fatalf("unable to complete alice's state transition: %v", err) } @@ -2663,7 +2624,7 @@ func TestChanSyncFullySynced(t *testing.T) { // Next, we'll complete Bob's state transition, and assert again that // they think they're fully synced. - if err := forceStateTransition(bobChannel, aliceChannel); err != nil { + if err := ForceStateTransition(bobChannel, aliceChannel); err != nil { t.Fatalf("unable to complete bob's state transition: %v", err) } assertNoChanSyncNeeded(t, aliceChannel, bobChannel) @@ -2773,7 +2734,7 @@ func TestChanSyncOweCommitment(t *testing.T) { // With the HTLC's applied to both update logs, we'll initiate a state // transition from Bob. - if err := forceStateTransition(bobChannel, aliceChannel); err != nil { + if err := ForceStateTransition(bobChannel, aliceChannel); err != nil { t.Fatalf("unable to complete bob's state transition: %v", err) } @@ -3010,7 +2971,7 @@ func TestChanSyncOweCommitment(t *testing.T) { if err != nil { t.Fatalf("unable to settle htlc: %v", err) } - if err := forceStateTransition(bobChannel, aliceChannel); err != nil { + if err := ForceStateTransition(bobChannel, aliceChannel); err != nil { t.Fatalf("unable to complete bob's state transition: %v", err) } @@ -3075,7 +3036,7 @@ func TestChanSyncOweRevocation(t *testing.T) { if err != nil { t.Fatalf("unable to recv bob's htlc: %v", err) } - if err := forceStateTransition(bobChannel, aliceChannel); err != nil { + if err := ForceStateTransition(bobChannel, aliceChannel); err != nil { t.Fatalf("unable to complete bob's state transition: %v", err) } @@ -3216,7 +3177,7 @@ func TestChanSyncOweRevocation(t *testing.T) { if _, err := bobChannel.ReceiveHTLC(aliceHtlc); err != nil { t.Fatalf("unable to recv alice's htlc: %v", err) } - if err := forceStateTransition(aliceChannel, bobChannel); err != nil { + if err := ForceStateTransition(aliceChannel, bobChannel); err != nil { t.Fatalf("unable to complete alice's state transition: %v", err) } @@ -3260,7 +3221,7 @@ func TestChanSyncOweRevocationAndCommit(t *testing.T) { if err != nil { t.Fatalf("unable to recv bob's htlc: %v", err) } - if err := forceStateTransition(bobChannel, aliceChannel); err != nil { + if err := ForceStateTransition(bobChannel, aliceChannel); err != nil { t.Fatalf("unable to complete bob's state transition: %v", err) } @@ -3429,7 +3390,7 @@ func TestChanSyncOweRevocationAndCommitForceTransition(t *testing.T) { if err != nil { t.Fatalf("unable to recv bob's htlc: %v", err) } - if err := forceStateTransition(bobChannel, aliceChannel); err != nil { + if err := ForceStateTransition(bobChannel, aliceChannel); err != nil { t.Fatalf("unable to complete bob's state transition: %v", err) } @@ -3664,7 +3625,7 @@ func TestChanSyncFailure(t *testing.T) { if err != nil { t.Fatalf("unable to recv bob's htlc: %v", err) } - err = forceStateTransition(bobChannel, aliceChannel) + err = ForceStateTransition(bobChannel, aliceChannel) if err != nil { t.Fatalf("unable to complete bob's state "+ "transition: %v", err) @@ -4077,7 +4038,7 @@ func TestChannelRetransmissionFeeUpdate(t *testing.T) { if _, err := aliceChannel.ReceiveHTLC(bobHtlc); err != nil { t.Fatalf("unable to recv bob's htlc: %v", err) } - if err := forceStateTransition(bobChannel, aliceChannel); err != nil { + if err := ForceStateTransition(bobChannel, aliceChannel); err != nil { t.Fatalf("unable to complete bob's state transition: %v", err) } } @@ -4283,7 +4244,7 @@ func TestFeeUpdateOldDiskFormat(t *testing.T) { if _, err := bobChannel.ReceiveHTLC(htlc); err != nil { t.Fatalf("unable to recv htlc: %v", err) } - if err := forceStateTransition(aliceChannel, bobChannel); err != nil { + if err := ForceStateTransition(aliceChannel, bobChannel); err != nil { t.Fatalf("unable to complete bob's state transition: %v", err) } @@ -4380,7 +4341,7 @@ func TestChanSyncInvalidLastSecret(t *testing.T) { } // Then we'll initiate a state transition to lock in this new HTLC. - if err := forceStateTransition(aliceChannel, bobChannel); err != nil { + if err := ForceStateTransition(aliceChannel, bobChannel); err != nil { t.Fatalf("unable to complete alice's state transition: %v", err) } @@ -4452,13 +4413,13 @@ func TestChanAvailableBandwidth(t *testing.T) { // to actually determine what the current up to date balance // is. if aliceInitiate { - err := forceStateTransition(aliceChannel, bobChannel) + err := ForceStateTransition(aliceChannel, bobChannel) if err != nil { t.Fatalf("unable to complete alice's state "+ "transition: %v", err) } } else { - err := forceStateTransition(bobChannel, aliceChannel) + err := ForceStateTransition(bobChannel, aliceChannel) if err != nil { t.Fatalf("unable to complete alice's state "+ "transition: %v", err) @@ -4538,7 +4499,7 @@ func TestChanAvailableBandwidth(t *testing.T) { // We must do a state transition before the balance is available // for Alice. - if err := forceStateTransition(aliceChannel, bobChannel); err != nil { + if err := ForceStateTransition(aliceChannel, bobChannel); err != nil { t.Fatalf("unable to complete alice's state "+ "transition: %v", err) } @@ -4968,10 +4929,10 @@ func TestChannelUnilateralCloseHtlcResolution(t *testing.T) { if _, err := aliceChannel.ReceiveHTLC(htlcBob); err != nil { t.Fatalf("alice unable to recv add htlc: %v", err) } - if err := forceStateTransition(aliceChannel, bobChannel); err != nil { + if err := ForceStateTransition(aliceChannel, bobChannel); err != nil { t.Fatalf("Can't update the channel state: %v", err) } - if err := forceStateTransition(bobChannel, aliceChannel); err != nil { + if err := ForceStateTransition(bobChannel, aliceChannel); err != nil { t.Fatalf("Can't update the channel state: %v", err) } @@ -5253,7 +5214,7 @@ func TestDesyncHTLCs(t *testing.T) { } // Lock this HTLC in. - if err := forceStateTransition(aliceChannel, bobChannel); err != nil { + if err := ForceStateTransition(aliceChannel, bobChannel); err != nil { t.Fatalf("unable to complete state update: %v", err) } @@ -5283,7 +5244,7 @@ func TestDesyncHTLCs(t *testing.T) { // Now do a state transition, which will ACK the FailHTLC, making Alice // able to add the new HTLC. - if err := forceStateTransition(aliceChannel, bobChannel); err != nil { + if err := ForceStateTransition(aliceChannel, bobChannel); err != nil { t.Fatalf("unable to complete state update: %v", err) } if _, err = aliceChannel.AddHTLC(htlc, nil); err != nil { @@ -5348,7 +5309,7 @@ func TestMaxAcceptedHTLCs(t *testing.T) { if _, err := bobChannel.ReceiveHTLC(htlc); err != nil { t.Fatalf("unable to recv htlc: %v", err) } - err = forceStateTransition(aliceChannel, bobChannel) + err = ForceStateTransition(aliceChannel, bobChannel) if err != ErrMaxHTLCNumber { t.Fatalf("expected ErrMaxHTLCNumber, instead received: %v", err) } @@ -5407,7 +5368,7 @@ func TestMaxPendingAmount(t *testing.T) { if _, err := bobChannel.ReceiveHTLC(htlc); err != nil { t.Fatalf("unable to recv htlc: %v", err) } - err = forceStateTransition(aliceChannel, bobChannel) + err = ForceStateTransition(aliceChannel, bobChannel) if err != ErrMaxPendingAmount { t.Fatalf("expected ErrMaxPendingAmount, instead received: %v", err) } @@ -5505,7 +5466,7 @@ func TestChanReserve(t *testing.T) { // Force a state transition, making sure this HTLC is considered valid // even though the channel reserves are not met. - if err := forceStateTransition(aliceChannel, bobChannel); err != nil { + if err := ForceStateTransition(aliceChannel, bobChannel); err != nil { t.Fatalf("unable to complete state update: %v", err) } @@ -5532,7 +5493,7 @@ func TestChanReserve(t *testing.T) { if _, err := aliceChannel.ReceiveHTLC(htlc); err != nil { t.Fatalf("unable to recv htlc: %v", err) } - err = forceStateTransition(aliceChannel, bobChannel) + err = ForceStateTransition(aliceChannel, bobChannel) if err != ErrBelowChanReserve { t.Fatalf("expected ErrBelowChanReserve, instead received: %v", err) } @@ -5580,7 +5541,7 @@ func TestChanReserve(t *testing.T) { if _, err := bobChannel.ReceiveHTLC(htlc); err != nil { t.Fatalf("unable to recv htlc: %v", err) } - err = forceStateTransition(aliceChannel, bobChannel) + err = ForceStateTransition(aliceChannel, bobChannel) if err != ErrBelowChanReserve { t.Fatalf("expected ErrBelowChanReserve, instead received: %v", err) } @@ -5608,7 +5569,7 @@ func TestChanReserve(t *testing.T) { if err != nil { t.Fatalf("unable to recv htlc: %v", err) } - if err := forceStateTransition(aliceChannel, bobChannel); err != nil { + if err := ForceStateTransition(aliceChannel, bobChannel); err != nil { t.Fatalf("unable to complete state update: %v", err) } @@ -5624,7 +5585,7 @@ func TestChanReserve(t *testing.T) { if err := aliceChannel.ReceiveHTLCSettle(preimage, aliceHtlcIndex); err != nil { t.Fatalf("alice unable to accept settle of outbound htlc: %v", err) } - if err := forceStateTransition(bobChannel, aliceChannel); err != nil { + if err := ForceStateTransition(bobChannel, aliceChannel); err != nil { t.Fatalf("unable to complete state update: %v", err) } @@ -5648,7 +5609,7 @@ func TestChanReserve(t *testing.T) { } // Do a last state transition, which should succeed. - if err := forceStateTransition(bobChannel, aliceChannel); err != nil { + if err := ForceStateTransition(bobChannel, aliceChannel); err != nil { t.Fatalf("unable to complete state update: %v", err) } @@ -5707,7 +5668,7 @@ func TestMinHTLC(t *testing.T) { if err != nil { t.Fatalf("error receiving htlc: %v", err) } - err = forceStateTransition(aliceChannel, bobChannel) + err = ForceStateTransition(aliceChannel, bobChannel) if err != ErrBelowMinHTLC { t.Fatalf("expected ErrBelowMinHTLC, instead received: %v", err) } @@ -5766,7 +5727,7 @@ func TestNewBreachRetributionSkipsDustHtlcs(t *testing.T) { // With the HTLC's applied to both update logs, we'll initiate a state // transition from Alice. - if err := forceStateTransition(bobChannel, aliceChannel); err != nil { + if err := ForceStateTransition(bobChannel, aliceChannel); err != nil { t.Fatalf("unable to complete bob's state transition: %v", err) } @@ -5786,7 +5747,7 @@ func TestNewBreachRetributionSkipsDustHtlcs(t *testing.T) { t.Fatalf("unable to settle htlc: %v", err) } } - if err := forceStateTransition(bobChannel, aliceChannel); err != nil { + if err := ForceStateTransition(bobChannel, aliceChannel); err != nil { t.Fatalf("unable to complete bob's state transition: %v", err) } @@ -6088,7 +6049,7 @@ func TestChannelRestoreUpdateLogsFailedHTLC(t *testing.T) { } // Lock in the Add on both sides. - if err := forceStateTransition(aliceChannel, bobChannel); err != nil { + if err := ForceStateTransition(aliceChannel, bobChannel); err != nil { t.Fatalf("unable to complete state update: %v", err) } @@ -6203,7 +6164,7 @@ func TestDuplicateFailRejection(t *testing.T) { t.Fatalf("unable to recv htlc: %v", err) } - if err := forceStateTransition(aliceChannel, bobChannel); err != nil { + if err := ForceStateTransition(aliceChannel, bobChannel); err != nil { t.Fatalf("unable to complete state update: %v", err) } @@ -6281,7 +6242,7 @@ func TestDuplicateSettleRejection(t *testing.T) { t.Fatalf("unable to recv htlc: %v", err) } - if err := forceStateTransition(aliceChannel, bobChannel); err != nil { + if err := ForceStateTransition(aliceChannel, bobChannel); err != nil { t.Fatalf("unable to complete state update: %v", err) } diff --git a/lnwallet/test_utils.go b/lnwallet/test_utils.go index 11865e16..392f4492 100644 --- a/lnwallet/test_utils.go +++ b/lnwallet/test_utils.go @@ -492,3 +492,43 @@ func calcStaticFee(numHTLCs int) btcutil.Amount { return feePerKw * (commitWeight + btcutil.Amount(htlcWeight*numHTLCs)) / 1000 } + +// ForceStateTransition executes the necessary interaction between the two +// commitment state machines to transition to a new state locking in any +// pending updates. This method is useful when testing interactions between two +// live state machines. +func ForceStateTransition(chanA, chanB *LightningChannel) error { + aliceSig, aliceHtlcSigs, err := chanA.SignNextCommitment() + if err != nil { + return err + } + if err = chanB.ReceiveNewCommitment(aliceSig, aliceHtlcSigs); err != nil { + return err + } + + bobRevocation, _, err := chanB.RevokeCurrentCommitment() + if err != nil { + return err + } + bobSig, bobHtlcSigs, err := chanB.SignNextCommitment() + if err != nil { + return err + } + + if _, _, _, err := chanA.ReceiveRevocation(bobRevocation); err != nil { + return err + } + if err := chanA.ReceiveNewCommitment(bobSig, bobHtlcSigs); err != nil { + return err + } + + aliceRevocation, _, err := chanA.RevokeCurrentCommitment() + if err != nil { + return err + } + if _, _, _, err := chanB.ReceiveRevocation(aliceRevocation); err != nil { + return err + } + + return nil +}