diff --git a/routing/pathfind.go b/routing/pathfind.go index 50c9e107..7788dc83 100644 --- a/routing/pathfind.go +++ b/routing/pathfind.go @@ -150,7 +150,22 @@ func newRoute(sourceVertex route.Vertex, amtToForward lnwire.MilliSatoshi fee lnwire.MilliSatoshi outgoingTimeLock uint32 + tlvPayload bool ) + + // Define a helper function that checks this edge's feature + // vector for support for a given feature. We assume at this + // point that the feature vectors transitive dependencies have + // been validated. + supports := edge.Node.Features.HasFeature + + // We start by assuming the node doesn't support TLV. We'll now + // inspect the node's feature vector to see if we can promote + // the hop. We assume already that the feature vector's + // transitive dependencies have already been validated by path + // finding or some other means. + tlvPayload = supports(lnwire.TLVOnionPayloadOptional) + if i == len(pathEdges)-1 { // If this is the last hop, then the hop payload will // contain the exact amount. In BOLT #4: Onion Routing @@ -194,19 +209,7 @@ func newRoute(sourceVertex route.Vertex, ChannelID: edge.ChannelID, AmtToForward: amtToForward, OutgoingTimeLock: outgoingTimeLock, - LegacyPayload: true, - } - - // We start out above by assuming that this node needs the - // legacy payload, as if we don't have the full - // NodeAnnouncement information for this node, then we can't - // assume it knows the latest features. If we do have a feature - // vector for this node, then we'll update the info now. - if edge.Node.Features != nil { - features := edge.Node.Features - currentHop.LegacyPayload = !features.HasFeature( - lnwire.TLVOnionPayloadOptional, - ) + LegacyPayload: !tlvPayload, } // If this is the last hop, then we'll populate any TLV records diff --git a/routing/pathfind_test.go b/routing/pathfind_test.go index 104d47b5..f7607166 100644 --- a/routing/pathfind_test.go +++ b/routing/pathfind_test.go @@ -1095,6 +1095,10 @@ func TestNewRoute(t *testing.T) { // indicated by hops. paymentAmount lnwire.MilliSatoshi + // destFeatures is a feature vector, that if non-nil, will + // overwrite the final hop's feature vector in the graph. + destFeatures *lnwire.FeatureVector + // expectedFees is a list of fees that every hop is expected // to charge for forwarding. expectedFees []lnwire.MilliSatoshi @@ -1123,6 +1127,8 @@ func TestNewRoute(t *testing.T) { // expectedErrorCode indicates the expected error code when // expectError is true. expectedErrorCode errorCode + + expectedTLVPayload bool }{ { // For a single hop payment, no fees are expected to be paid. @@ -1149,6 +1155,22 @@ func TestNewRoute(t *testing.T) { expectedTimeLocks: []uint32{1, 1}, expectedTotalAmount: 100130, expectedTotalTimeLock: 6, + }, { + // For a two hop payment, only the fee for the first hop + // needs to be paid. The destination hop does not require + // a fee to receive the payment. + name: "two hop tlv onion feature", + destFeatures: tlvFeatures, + paymentAmount: 100000, + hops: []*channeldb.ChannelEdgePolicy{ + createHop(0, 1000, 1000000, 10), + createHop(30, 1000, 1000000, 5), + }, + expectedFees: []lnwire.MilliSatoshi{130, 0}, + expectedTimeLocks: []uint32{1, 1}, + expectedTotalAmount: 100130, + expectedTotalTimeLock: 6, + expectedTLVPayload: true, }, { // A three hop payment where the first and second hop // will both charge 1 msat. The fee for the first hop @@ -1205,6 +1227,15 @@ func TestNewRoute(t *testing.T) { }} for _, testCase := range testCases { + testCase := testCase + + // Overwrite the final hop's features if the test requires a + // custom feature vector. + if testCase.destFeatures != nil { + finalHop := testCase.hops[len(testCase.hops)-1] + finalHop.Node.Features = testCase.destFeatures + } + assertRoute := func(t *testing.T, route *route.Route) { if route.TotalAmount != testCase.expectedTotalAmount { t.Errorf("Expected total amount is be %v"+ @@ -1248,6 +1279,16 @@ func TestNewRoute(t *testing.T) { route.Hops[i].OutgoingTimeLock) } } + + finalHop := route.Hops[len(route.Hops)-1] + if !finalHop.LegacyPayload != + testCase.expectedTLVPayload { + + t.Errorf("Expected tlv payload: %t, "+ + "but got: %t instead", + testCase.expectedTLVPayload, + !finalHop.LegacyPayload) + } } t.Run(testCase.name, func(t *testing.T) {