From c08377d303d7f95b45dc51933ea9c0579eb01b27 Mon Sep 17 00:00:00 2001 From: carla Date: Thu, 30 Jan 2020 10:01:10 +0200 Subject: [PATCH 1/3] htlcswitch/test: replace mock server delta with constant --- htlcswitch/link_test.go | 1 + htlcswitch/switch_test.go | 96 +++++++++++++++++++++++++++++---------- 2 files changed, 73 insertions(+), 24 deletions(-) diff --git a/htlcswitch/link_test.go b/htlcswitch/link_test.go index 16d64d70..427322ee 100644 --- a/htlcswitch/link_test.go +++ b/htlcswitch/link_test.go @@ -39,6 +39,7 @@ import ( const ( testStartingHeight = 100 + testDefaultDelta = 6 ) // concurrentTester is a thread-safe wrapper around the Fatalf method of a diff --git a/htlcswitch/switch_test.go b/htlcswitch/switch_test.go index 69425dc2..1497f009 100644 --- a/htlcswitch/switch_test.go +++ b/htlcswitch/switch_test.go @@ -32,7 +32,9 @@ func genPreimage() ([32]byte, error) { func TestSwitchAddDuplicateLink(t *testing.T) { t.Parallel() - alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil, 6) + alicePeer, err := newMockServer( + t, "alice", testStartingHeight, nil, testDefaultDelta, + ) if err != nil { t.Fatalf("unable to create alice server: %v", err) } @@ -90,7 +92,9 @@ func TestSwitchAddDuplicateLink(t *testing.T) { func TestSwitchHasActiveLink(t *testing.T) { t.Parallel() - alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil, 6) + alicePeer, err := newMockServer( + t, "alice", testStartingHeight, nil, testDefaultDelta, + ) if err != nil { t.Fatalf("unable to create alice server: %v", err) } @@ -158,7 +162,9 @@ func TestSwitchHasActiveLink(t *testing.T) { func TestSwitchSendPending(t *testing.T) { t.Parallel() - alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil, 6) + alicePeer, err := newMockServer( + t, "alice", testStartingHeight, nil, testDefaultDelta, + ) if err != nil { t.Fatalf("unable to create alice server: %v", err) } @@ -253,11 +259,15 @@ func TestSwitchSendPending(t *testing.T) { func TestSwitchForward(t *testing.T) { t.Parallel() - alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil, 6) + alicePeer, err := newMockServer( + t, "alice", testStartingHeight, nil, testDefaultDelta, + ) if err != nil { t.Fatalf("unable to create alice server: %v", err) } - bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil, 6) + bobPeer, err := newMockServer( + t, "bob", testStartingHeight, nil, testDefaultDelta, + ) if err != nil { t.Fatalf("unable to create bob server: %v", err) } @@ -358,11 +368,15 @@ func TestSwitchForwardFailAfterFullAdd(t *testing.T) { chanID1, chanID2, aliceChanID, bobChanID := genIDs() - alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil, 6) + alicePeer, err := newMockServer( + t, "alice", testStartingHeight, nil, testDefaultDelta, + ) if err != nil { t.Fatalf("unable to create alice server: %v", err) } - bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil, 6) + bobPeer, err := newMockServer( + t, "bob", testStartingHeight, nil, testDefaultDelta, + ) if err != nil { t.Fatalf("unable to create bob server: %v", err) } @@ -549,11 +563,15 @@ func TestSwitchForwardSettleAfterFullAdd(t *testing.T) { chanID1, chanID2, aliceChanID, bobChanID := genIDs() - alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil, 6) + alicePeer, err := newMockServer( + t, "alice", testStartingHeight, nil, testDefaultDelta, + ) if err != nil { t.Fatalf("unable to create alice server: %v", err) } - bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil, 6) + bobPeer, err := newMockServer( + t, "bob", testStartingHeight, nil, testDefaultDelta, + ) if err != nil { t.Fatalf("unable to create bob server: %v", err) } @@ -743,11 +761,15 @@ func TestSwitchForwardDropAfterFullAdd(t *testing.T) { chanID1, chanID2, aliceChanID, bobChanID := genIDs() - alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil, 6) + alicePeer, err := newMockServer( + t, "alice", testStartingHeight, nil, testDefaultDelta, + ) if err != nil { t.Fatalf("unable to create alice server: %v", err) } - bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil, 6) + bobPeer, err := newMockServer( + t, "bob", testStartingHeight, nil, testDefaultDelta, + ) if err != nil { t.Fatalf("unable to create bob server: %v", err) } @@ -906,11 +928,15 @@ func TestSwitchForwardFailAfterHalfAdd(t *testing.T) { chanID1, chanID2, aliceChanID, bobChanID := genIDs() - alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil, 6) + alicePeer, err := newMockServer( + t, "alice", testStartingHeight, nil, testDefaultDelta, + ) if err != nil { t.Fatalf("unable to create alice server: %v", err) } - bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil, 6) + bobPeer, err := newMockServer( + t, "bob", testStartingHeight, nil, testDefaultDelta, + ) if err != nil { t.Fatalf("unable to create bob server: %v", err) } @@ -1064,11 +1090,15 @@ func TestSwitchForwardCircuitPersistence(t *testing.T) { chanID1, chanID2, aliceChanID, bobChanID := genIDs() - alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil, 6) + alicePeer, err := newMockServer( + t, "alice", testStartingHeight, nil, testDefaultDelta, + ) if err != nil { t.Fatalf("unable to create alice server: %v", err) } - bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil, 6) + bobPeer, err := newMockServer( + t, "bob", testStartingHeight, nil, testDefaultDelta, + ) if err != nil { t.Fatalf("unable to create bob server: %v", err) } @@ -1359,11 +1389,15 @@ func testSkipIneligibleLinksMultiHopForward(t *testing.T, var packet *htlcPacket - alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil, 6) + alicePeer, err := newMockServer( + t, "alice", testStartingHeight, nil, testDefaultDelta, + ) if err != nil { t.Fatalf("unable to create alice server: %v", err) } - bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil, 6) + bobPeer, err := newMockServer( + t, "bob", testStartingHeight, nil, testDefaultDelta, + ) if err != nil { t.Fatalf("unable to create bob server: %v", err) } @@ -1470,7 +1504,9 @@ func testSkipLinkLocalForward(t *testing.T, eligible bool, // We'll create a single link for this test, marking it as being unable // to forward form the get go. - alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil, 6) + alicePeer, err := newMockServer( + t, "alice", testStartingHeight, nil, testDefaultDelta, + ) if err != nil { t.Fatalf("unable to create alice server: %v", err) } @@ -1524,11 +1560,15 @@ func testSkipLinkLocalForward(t *testing.T, eligible bool, func TestSwitchCancel(t *testing.T) { t.Parallel() - alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil, 6) + alicePeer, err := newMockServer( + t, "alice", testStartingHeight, nil, testDefaultDelta, + ) if err != nil { t.Fatalf("unable to create alice server: %v", err) } - bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil, 6) + bobPeer, err := newMockServer( + t, "bob", testStartingHeight, nil, testDefaultDelta, + ) if err != nil { t.Fatalf("unable to create bob server: %v", err) } @@ -1637,11 +1677,15 @@ func TestSwitchAddSamePayment(t *testing.T) { chanID1, chanID2, aliceChanID, bobChanID := genIDs() - alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil, 6) + alicePeer, err := newMockServer( + t, "alice", testStartingHeight, nil, testDefaultDelta, + ) if err != nil { t.Fatalf("unable to create alice server: %v", err) } - bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil, 6) + bobPeer, err := newMockServer( + t, "bob", testStartingHeight, nil, testDefaultDelta, + ) if err != nil { t.Fatalf("unable to create bob server: %v", err) } @@ -1796,7 +1840,9 @@ func TestSwitchAddSamePayment(t *testing.T) { func TestSwitchSendPayment(t *testing.T) { t.Parallel() - alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil, 6) + alicePeer, err := newMockServer( + t, "alice", testStartingHeight, nil, testDefaultDelta, + ) if err != nil { t.Fatalf("unable to create alice server: %v", err) } @@ -2334,7 +2380,9 @@ func TestSwitchGetPaymentResult(t *testing.T) { func TestInvalidFailure(t *testing.T) { t.Parallel() - alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil, 6) + alicePeer, err := newMockServer( + t, "alice", testStartingHeight, nil, testDefaultDelta, + ) if err != nil { t.Fatalf("unable to create alice server: %v", err) } From a3478f1d998c6d19efda3d7e0fba2f765a3d2888 Mon Sep 17 00:00:00 2001 From: carla Date: Thu, 30 Jan 2020 10:01:17 +0200 Subject: [PATCH 2/3] htlcswitch: add CheckCircularForward to handlePacketForward Add a CheckCircularForward function which detects packets which are forwards over the same incoming and outgoing link, and errors if the node is configured to disallow forwards of this nature. This check is added to increase the cost of a liquidity lockup attack, because it increases the length of the route required to lock up an individual node's bandwidth. Since nodes are currently limited to 20 hops, increasing the length of the route needed to lock up capital increases the number of malicious payments an attacker will have to route, which increases the capital requirement of the attack overall. --- htlcswitch/failure_detail.go | 8 ++ htlcswitch/switch.go | 51 +++++++++++ htlcswitch/switch_test.go | 166 +++++++++++++++++++++++++++++++++++ 3 files changed, 225 insertions(+) diff --git a/htlcswitch/failure_detail.go b/htlcswitch/failure_detail.go index 92b25510..976015b8 100644 --- a/htlcswitch/failure_detail.go +++ b/htlcswitch/failure_detail.go @@ -29,6 +29,11 @@ const ( // FailureDetailInsufficientBalance is returned when we cannot route a // htlc due to insufficient outgoing capacity. FailureDetailInsufficientBalance + + // FailureDetailCircularRoute is returned when an attempt is made + // to forward a htlc through our node which arrives and leaves on the + // same channel. + FailureDetailCircularRoute ) // String returns the string representation of a failure detail. @@ -52,6 +57,9 @@ func (fd FailureDetail) String() string { case FailureDetailInsufficientBalance: return "insufficient bandwidth to route htlc" + case FailureDetailCircularRoute: + return "same incoming and outgoing channel" + default: return "unknown failure detail" } diff --git a/htlcswitch/switch.go b/htlcswitch/switch.go index c08bdc0c..760aadf0 100644 --- a/htlcswitch/switch.go +++ b/htlcswitch/switch.go @@ -167,6 +167,10 @@ type Config struct { // fails in forwarding packages. AckEventTicker ticker.Ticker + // AllowCircularRoute is true if the user has configured their node to + // allow forwards that arrive and depart our node over the same channel. + AllowCircularRoute bool + // RejectHTLC is a flag that instructs the htlcswitch to reject any // HTLCs that are not from the source hop. RejectHTLC bool @@ -986,6 +990,22 @@ func (s *Switch) handlePacketForward(packet *htlcPacket) error { return s.handleLocalDispatch(packet) } + // Before we attempt to find a non-strict forwarding path for + // this htlc, check whether the htlc is being routed over the + // same incoming and outgoing channel. If our node does not + // allow forwards of this nature, we fail the htlc early. This + // check is in place to disallow inefficiently routed htlcs from + // locking up our balance. + linkErr := checkCircularForward( + packet.incomingChanID, packet.outgoingChanID, + s.cfg.AllowCircularRoute, htlc.PaymentHash, + ) + if linkErr != nil { + return s.failAddPacket( + packet, linkErr.WireMessage(), linkErr, + ) + } + s.indexMtx.RLock() targetLink, err := s.getLinkByShortID(packet.outgoingChanID) if err != nil { @@ -1170,6 +1190,37 @@ func (s *Switch) handlePacketForward(packet *htlcPacket) error { } } +// checkCircularForward checks whether a forward is circular (arrives and +// departs on the same link) and returns a link error if the switch is +// configured to disallow this behaviour. +func checkCircularForward(incoming, outgoing lnwire.ShortChannelID, + allowCircular bool, paymentHash lntypes.Hash) *LinkError { + + // If the route is not circular we do not need to perform any further + // checks. + if incoming != outgoing { + return nil + } + + // If the incoming and outgoing link are equal, the htlc is part of a + // circular route which may be used to lock up our liquidity. If the + // switch is configured to allow circular routes, log that we are + // allowing the route then return nil. + if allowCircular { + log.Debugf("allowing circular route over link: %v "+ + "(payment hash: %x)", incoming, paymentHash) + return nil + } + + // If our node disallows circular routes, return a temporary channel + // failure. There is nothing wrong with the policy used by the remote + // node, so we do not include a channel update. + return NewDetailedLinkError( + lnwire.NewTemporaryChannelFailure(nil), + FailureDetailCircularRoute, + ) +} + // failAddPacket encrypts a fail packet back to an add packet's source. // The ciphertext will be derived from the failure message proivded by context. // This method returns the failErr if all other steps complete successfully. diff --git a/htlcswitch/switch_test.go b/htlcswitch/switch_test.go index 1497f009..828ef232 100644 --- a/htlcswitch/switch_test.go +++ b/htlcswitch/switch_test.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "io/ioutil" + "reflect" "testing" "time" @@ -1324,6 +1325,171 @@ type multiHopFwdTest struct { expectedReply lnwire.FailCode } +// TestCircularForwards tests the allowing/disallowing of circular payments +// through the same channel in the case where the switch is configured to allow +// and disallow same channel circular forwards. +func TestCircularForwards(t *testing.T) { + chanID1, aliceChanID := genID() + preimage := [sha256.Size]byte{1} + hash := fastsha256.Sum256(preimage[:]) + + tests := []struct { + name string + allowCircularPayment bool + expectedErr error + }{ + { + name: "circular payment allowed", + allowCircularPayment: true, + expectedErr: nil, + }, + { + name: "circular payment disallowed", + allowCircularPayment: false, + expectedErr: NewDetailedLinkError( + lnwire.NewTemporaryChannelFailure(nil), + FailureDetailCircularRoute, + ), + }, + } + + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + alicePeer, err := newMockServer( + t, "alice", testStartingHeight, nil, + testDefaultDelta, + ) + if err != nil { + t.Fatalf("unable to create alice server: %v", + err) + } + + s, err := initSwitchWithDB(testStartingHeight, nil) + if err != nil { + t.Fatalf("unable to init switch: %v", err) + } + if err := s.Start(); err != nil { + t.Fatalf("unable to start switch: %v", err) + } + defer func() { _ = s.Stop() }() + + // Set the switch to allow or disallow circular routes + // according to the test's requirements. + s.cfg.AllowCircularRoute = test.allowCircularPayment + + aliceChannelLink := newMockChannelLink( + s, chanID1, aliceChanID, alicePeer, true, + ) + + if err := s.AddLink(aliceChannelLink); err != nil { + t.Fatalf("unable to add alice link: %v", err) + } + + // Create a new packet that loops through alice's link + // in a circle. + obfuscator := NewMockObfuscator() + packet := &htlcPacket{ + incomingChanID: aliceChannelLink.ShortChanID(), + outgoingChanID: aliceChannelLink.ShortChanID(), + htlc: &lnwire.UpdateAddHTLC{ + PaymentHash: hash, + Amount: 1, + }, + obfuscator: obfuscator, + } + + // Attempt to forward the packet and check for the expected + // error. + err = s.forward(packet) + if !reflect.DeepEqual(err, test.expectedErr) { + t.Fatalf("expected: %v, got: %v", + test.expectedErr, err) + } + + // Ensure that no circuits were opened. + if s.circuits.NumOpen() > 0 { + t.Fatal("do not expect any open circuits") + } + }) + } +} + +// TestCheckCircularForward tests the error returned by checkCircularForward +// in cases where we allow and disallow same channel circular forwards. +func TestCheckCircularForward(t *testing.T) { + tests := []struct { + name string + + // allowCircular determines whether we should allow circular + // forwards. + allowCircular bool + + // incomingLink is the link that the htlc arrived on. + incomingLink lnwire.ShortChannelID + + // outgoingLink is the link that the htlc forward + // is destined to leave on. + outgoingLink lnwire.ShortChannelID + + // expectedErr is the error we expect to be returned. + expectedErr *LinkError + }{ + { + name: "not circular, allowed in config", + allowCircular: true, + incomingLink: lnwire.NewShortChanIDFromInt(123), + outgoingLink: lnwire.NewShortChanIDFromInt(321), + expectedErr: nil, + }, + { + name: "not circular, not allowed in config", + allowCircular: false, + incomingLink: lnwire.NewShortChanIDFromInt(123), + outgoingLink: lnwire.NewShortChanIDFromInt(321), + expectedErr: nil, + }, + { + name: "circular, allowed in config", + allowCircular: true, + incomingLink: lnwire.NewShortChanIDFromInt(123), + outgoingLink: lnwire.NewShortChanIDFromInt(123), + expectedErr: nil, + }, + { + name: "circular, not allowed in config", + allowCircular: false, + incomingLink: lnwire.NewShortChanIDFromInt(123), + outgoingLink: lnwire.NewShortChanIDFromInt(123), + expectedErr: NewDetailedLinkError( + lnwire.NewTemporaryChannelFailure(nil), + FailureDetailCircularRoute, + ), + }, + } + + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + // Check for a circular forward, the hash passed can + // be nil because it is only used for logging. + err := checkCircularForward( + test.incomingLink, test.outgoingLink, + test.allowCircular, lntypes.Hash{}, + ) + if !reflect.DeepEqual(err, test.expectedErr) { + t.Fatalf("expected: %v, got: %v", + test.expectedErr, err) + } + }) + } +} + // TestSkipIneligibleLinksMultiHopForward tests that if a multi-hop HTLC comes // along, then we won't attempt to froward it down al ink that isn't yet able // to forward any HTLC's. From afc7cc7f84f546d917409234d18ac322ebfdd741 Mon Sep 17 00:00:00 2001 From: carla Date: Thu, 30 Jan 2020 10:01:18 +0200 Subject: [PATCH 3/3] htlcswitch+config: make circular forwarding defence configurable Add a bool to the switch's config which can be used to disable same channel circular route checks. --- config.go | 2 ++ server.go | 1 + 2 files changed, 3 insertions(+) diff --git a/config.go b/config.go index 10f79ba5..89b934a8 100644 --- a/config.go +++ b/config.go @@ -341,6 +341,8 @@ type config struct { Watchtower *lncfg.Watchtower `group:"watchtower" namespace:"watchtower"` LegacyProtocol *lncfg.LegacyProtocol `group:"legacyprotocol" namespace:"legacyprotocol"` + + AllowCircularRoute bool `long:"allow-circular-route" description:"If true, our node will allow htlc forwards that arrive and depart on the same channel."` } // loadConfig initializes and parses the config using a config file and command diff --git a/server.go b/server.go index 72c66614..09a4dbc0 100644 --- a/server.go +++ b/server.go @@ -470,6 +470,7 @@ func newServer(listenAddrs []net.Addr, chanDB *channeldb.DB, FwdEventTicker: ticker.New(htlcswitch.DefaultFwdEventInterval), LogEventTicker: ticker.New(htlcswitch.DefaultLogInterval), AckEventTicker: ticker.New(htlcswitch.DefaultAckInterval), + AllowCircularRoute: cfg.AllowCircularRoute, RejectHTLC: cfg.RejectHTLC, }, uint32(currentHeight)) if err != nil {