lnrpc+routing: add edges and nodes restrictions to query routes

This commit allows the execution of QueryRoutes to be controlled using
lists of black-listed edges and nodes. Any path returned will not pass
through the edges and/or nodes on the list.
This commit is contained in:
Joost Jager
2019-03-05 12:42:29 +01:00
parent 4376f3e1bd
commit b09adc3219
8 changed files with 757 additions and 595 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -1480,7 +1480,31 @@ message QueryRoutesRequest {
send the payment. send the payment.
*/ */
FeeLimit fee_limit = 5; FeeLimit fee_limit = 5;
/**
A list of nodes to ignore during path finding.
*/
repeated bytes ignored_nodes = 6;
/**
A list of edges to ignore during path finding.
*/
repeated EdgeLocator ignored_edges = 7;
} }
message EdgeLocator {
/// The short channel id of this edge.
uint64 channel_id = 1;
/**
The direction of this edge. If direction_reverse is false, the direction
of this edge is from the channel endpoint with the lexicographically smaller
pub key to the endpoint with the larger pub key. If direction_reverse is
is true, the edge goes the other way.
*/
bool direction_reverse = 2;
}
message QueryRoutesResponse { message QueryRoutesResponse {
repeated Route routes = 1 [json_name = "routes"]; repeated Route routes = 1 [json_name = "routes"];
} }

View File

@@ -593,6 +593,17 @@
"required": false, "required": false,
"type": "string", "type": "string",
"format": "int64" "format": "int64"
},
{
"name": "ignored_nodes",
"description": "*\nA list of nodes to ignore during path finding.",
"in": "query",
"required": false,
"type": "array",
"items": {
"type": "string",
"format": "byte"
}
} }
], ],
"tags": [ "tags": [

View File

@@ -790,7 +790,7 @@ func findPath(g *graphParams, r *RestrictParams,
// algorithm in a block box manner. // algorithm in a block box manner.
func findPaths(tx *bbolt.Tx, graph *channeldb.ChannelGraph, func findPaths(tx *bbolt.Tx, graph *channeldb.ChannelGraph,
source *channeldb.LightningNode, target *btcec.PublicKey, source *channeldb.LightningNode, target *btcec.PublicKey,
amt lnwire.MilliSatoshi, feeLimit lnwire.MilliSatoshi, numPaths uint32, amt lnwire.MilliSatoshi, restrictions *RestrictParams, numPaths uint32,
bandwidthHints map[uint64]lnwire.MilliSatoshi) ([][]*channeldb.ChannelEdgePolicy, error) { bandwidthHints map[uint64]lnwire.MilliSatoshi) ([][]*channeldb.ChannelEdgePolicy, error) {
// TODO(roasbeef): modifying ordering within heap to eliminate final // TODO(roasbeef): modifying ordering within heap to eliminate final
@@ -809,10 +809,7 @@ func findPaths(tx *bbolt.Tx, graph *channeldb.ChannelGraph,
graph: graph, graph: graph,
bandwidthHints: bandwidthHints, bandwidthHints: bandwidthHints,
}, },
&RestrictParams{ restrictions, source, target, amt,
FeeLimit: feeLimit,
},
source, target, amt,
) )
if err != nil { if err != nil {
log.Errorf("Unable to find path: %v", err) log.Errorf("Unable to find path: %v", err)
@@ -846,6 +843,13 @@ func findPaths(tx *bbolt.Tx, graph *channeldb.ChannelGraph,
ignoredEdges := make(map[EdgeLocator]struct{}) ignoredEdges := make(map[EdgeLocator]struct{})
ignoredVertexes := make(map[Vertex]struct{}) ignoredVertexes := make(map[Vertex]struct{})
for e := range restrictions.IgnoredEdges {
ignoredEdges[e] = struct{}{}
}
for n := range restrictions.IgnoredNodes {
ignoredVertexes[n] = struct{}{}
}
// Our spur node is the i-th node in the prior shortest // Our spur node is the i-th node in the prior shortest
// path, and our root path will be all nodes in the // path, and our root path will be all nodes in the
// path leading up to our spurNode. // path leading up to our spurNode.
@@ -889,17 +893,22 @@ func findPaths(tx *bbolt.Tx, graph *channeldb.ChannelGraph,
// TODO: Fee limit passed to spur path finding isn't // TODO: Fee limit passed to spur path finding isn't
// correct, because it doesn't take into account the // correct, because it doesn't take into account the
// fees already paid on the root path. // fees already paid on the root path.
//
// TODO: Outgoing channel restriction isn't obeyed for
// spur paths.
spurRestrictions := &RestrictParams{
IgnoredEdges: ignoredEdges,
IgnoredNodes: ignoredVertexes,
FeeLimit: restrictions.FeeLimit,
}
spurPath, err := findPath( spurPath, err := findPath(
&graphParams{ &graphParams{
tx: tx, tx: tx,
graph: graph, graph: graph,
bandwidthHints: bandwidthHints, bandwidthHints: bandwidthHints,
}, },
&RestrictParams{ spurRestrictions, spurNode, target, amt,
IgnoredNodes: ignoredVertexes,
IgnoredEdges: ignoredEdges,
FeeLimit: feeLimit,
}, spurNode, target, amt,
) )
// If we weren't able to find a path, we'll continue to // If we weren't able to find a path, we'll continue to

View File

@@ -47,6 +47,12 @@ const (
noFeeLimit = lnwire.MilliSatoshi(math.MaxUint32) noFeeLimit = lnwire.MilliSatoshi(math.MaxUint32)
) )
var (
noRestrictions = &RestrictParams{
FeeLimit: noFeeLimit,
}
)
var ( var (
testSig = &btcec.Signature{ testSig = &btcec.Signature{
R: new(big.Int), R: new(big.Int),
@@ -953,9 +959,12 @@ func TestKShortestPathFinding(t *testing.T) {
paymentAmt := lnwire.NewMSatFromSatoshis(100) paymentAmt := lnwire.NewMSatFromSatoshis(100)
target := graph.aliasMap["luoji"] target := graph.aliasMap["luoji"]
restrictions := &RestrictParams{
FeeLimit: noFeeLimit,
}
paths, err := findPaths( paths, err := findPaths(
nil, graph.graph, sourceNode, target, paymentAmt, noFeeLimit, 100, nil, graph.graph, sourceNode, target, paymentAmt, restrictions,
nil, 100, nil,
) )
if err != nil { if err != nil {
t.Fatalf("unable to find paths between roasbeef and "+ t.Fatalf("unable to find paths between roasbeef and "+
@@ -1706,7 +1715,7 @@ func TestPathFindSpecExample(t *testing.T) {
// Query for a route of 4,999,999 mSAT to carol. // Query for a route of 4,999,999 mSAT to carol.
carol := ctx.aliases["C"] carol := ctx.aliases["C"]
const amt lnwire.MilliSatoshi = 4999999 const amt lnwire.MilliSatoshi = 4999999
routes, err := ctx.router.FindRoutes(carol, amt, noFeeLimit, 100) routes, err := ctx.router.FindRoutes(carol, amt, noRestrictions, 100)
if err != nil { if err != nil {
t.Fatalf("unable to find route: %v", err) t.Fatalf("unable to find route: %v", err)
} }
@@ -1767,7 +1776,7 @@ func TestPathFindSpecExample(t *testing.T) {
// We'll now request a route from A -> B -> C. // We'll now request a route from A -> B -> C.
ctx.router.routeCache = make(map[routeTuple][]*Route) ctx.router.routeCache = make(map[routeTuple][]*Route)
routes, err = ctx.router.FindRoutes(carol, amt, noFeeLimit, 100) routes, err = ctx.router.FindRoutes(carol, amt, noRestrictions, 100)
if err != nil { if err != nil {
t.Fatalf("unable to find routes: %v", err) t.Fatalf("unable to find routes: %v", err)
} }

View File

@@ -1326,7 +1326,7 @@ func pathsToFeeSortedRoutes(source Vertex, paths [][]*channeldb.ChannelEdgePolic
// route that will be ranked the highest is the one with the lowest cumulative // route that will be ranked the highest is the one with the lowest cumulative
// fee along the route. // fee along the route.
func (r *ChannelRouter) FindRoutes(target *btcec.PublicKey, func (r *ChannelRouter) FindRoutes(target *btcec.PublicKey,
amt, feeLimit lnwire.MilliSatoshi, numPaths uint32, amt lnwire.MilliSatoshi, restrictions *RestrictParams, numPaths uint32,
finalExpiry ...uint16) ([]*Route, error) { finalExpiry ...uint16) ([]*Route, error) {
var finalCLTVDelta uint16 var finalCLTVDelta uint16
@@ -1396,8 +1396,8 @@ func (r *ChannelRouter) FindRoutes(target *btcec.PublicKey,
// we'll execute our KSP algorithm to find the k-shortest paths from // we'll execute our KSP algorithm to find the k-shortest paths from
// our source to the destination. // our source to the destination.
shortestPaths, err := findPaths( shortestPaths, err := findPaths(
tx, r.cfg.Graph, r.selfNode, target, amt, feeLimit, numPaths, tx, r.cfg.Graph, r.selfNode, target, amt, restrictions,
bandwidthHints, numPaths, bandwidthHints,
) )
if err != nil { if err != nil {
tx.Rollback() tx.Rollback()

View File

@@ -184,7 +184,7 @@ func TestFindRoutesFeeSorting(t *testing.T) {
paymentAmt := lnwire.NewMSatFromSatoshis(100) paymentAmt := lnwire.NewMSatFromSatoshis(100)
target := ctx.aliases["luoji"] target := ctx.aliases["luoji"]
routes, err := ctx.router.FindRoutes( routes, err := ctx.router.FindRoutes(
target, paymentAmt, noFeeLimit, defaultNumRoutes, target, paymentAmt, noRestrictions, defaultNumRoutes,
DefaultFinalCLTVDelta, DefaultFinalCLTVDelta,
) )
if err != nil { if err != nil {
@@ -240,10 +240,12 @@ func TestFindRoutesWithFeeLimit(t *testing.T) {
// see the first route. // see the first route.
target := ctx.aliases["sophon"] target := ctx.aliases["sophon"]
paymentAmt := lnwire.NewMSatFromSatoshis(100) paymentAmt := lnwire.NewMSatFromSatoshis(100)
feeLimit := lnwire.NewMSatFromSatoshis(10) restrictions := &RestrictParams{
FeeLimit: lnwire.NewMSatFromSatoshis(10),
}
routes, err := ctx.router.FindRoutes( routes, err := ctx.router.FindRoutes(
target, paymentAmt, feeLimit, defaultNumRoutes, target, paymentAmt, restrictions, defaultNumRoutes,
DefaultFinalCLTVDelta, DefaultFinalCLTVDelta,
) )
if err != nil { if err != nil {
@@ -254,7 +256,7 @@ func TestFindRoutesWithFeeLimit(t *testing.T) {
t.Fatalf("expected 1 route, got %d", len(routes)) t.Fatalf("expected 1 route, got %d", len(routes))
} }
if routes[0].TotalFees > feeLimit { if routes[0].TotalFees > restrictions.FeeLimit {
t.Fatalf("route exceeded fee limit: %v", spew.Sdump(routes[0])) t.Fatalf("route exceeded fee limit: %v", spew.Sdump(routes[0]))
} }
@@ -1307,7 +1309,7 @@ func TestAddEdgeUnknownVertexes(t *testing.T) {
paymentAmt := lnwire.NewMSatFromSatoshis(100) paymentAmt := lnwire.NewMSatFromSatoshis(100)
targetNode := priv2.PubKey() targetNode := priv2.PubKey()
routes, err := ctx.router.FindRoutes( routes, err := ctx.router.FindRoutes(
targetNode, paymentAmt, noFeeLimit, defaultNumRoutes, targetNode, paymentAmt, noRestrictions, defaultNumRoutes,
DefaultFinalCLTVDelta, DefaultFinalCLTVDelta,
) )
if err != nil { if err != nil {
@@ -1352,7 +1354,7 @@ func TestAddEdgeUnknownVertexes(t *testing.T) {
// Should still be able to find the routes, and the info should be // Should still be able to find the routes, and the info should be
// updated. // updated.
routes, err = ctx.router.FindRoutes( routes, err = ctx.router.FindRoutes(
targetNode, paymentAmt, noFeeLimit, defaultNumRoutes, targetNode, paymentAmt, noRestrictions, defaultNumRoutes,
DefaultFinalCLTVDelta, DefaultFinalCLTVDelta,
) )
if err != nil { if err != nil {

View File

@@ -4022,8 +4022,36 @@ func (r *rpcServer) QueryRoutes(ctx context.Context,
"allowed is %v", amt, maxPaymentMSat.ToSatoshis()) "allowed is %v", amt, maxPaymentMSat.ToSatoshis())
} }
// Unmarshall restrictions from request.
feeLimit := calculateFeeLimit(in.FeeLimit, amtMSat) feeLimit := calculateFeeLimit(in.FeeLimit, amtMSat)
ignoredNodes := make(map[routing.Vertex]struct{})
for _, ignorePubKey := range in.IgnoredNodes {
if len(ignorePubKey) != 33 {
return nil, fmt.Errorf("invalid ignore node pubkey")
}
var ignoreVertex routing.Vertex
copy(ignoreVertex[:], ignorePubKey)
ignoredNodes[ignoreVertex] = struct{}{}
}
ignoredEdges := make(map[routing.EdgeLocator]struct{})
for _, ignoredEdge := range in.IgnoredEdges {
locator := routing.EdgeLocator{
ChannelID: ignoredEdge.ChannelId,
}
if ignoredEdge.DirectionReverse {
locator.Direction = 1
}
ignoredEdges[locator] = struct{}{}
}
restrictions := &routing.RestrictParams{
FeeLimit: feeLimit,
IgnoredNodes: ignoredNodes,
IgnoredEdges: ignoredEdges,
}
// numRoutes will default to 10 if not specified explicitly. // numRoutes will default to 10 if not specified explicitly.
numRoutesIn := uint32(in.NumRoutes) numRoutesIn := uint32(in.NumRoutes)
if numRoutesIn == 0 { if numRoutesIn == 0 {
@@ -4037,13 +4065,14 @@ func (r *rpcServer) QueryRoutes(ctx context.Context,
routes []*routing.Route routes []*routing.Route
findErr error findErr error
) )
if in.FinalCltvDelta == 0 { if in.FinalCltvDelta == 0 {
routes, findErr = r.server.chanRouter.FindRoutes( routes, findErr = r.server.chanRouter.FindRoutes(
pubKey, amtMSat, feeLimit, numRoutesIn, pubKey, amtMSat, restrictions, numRoutesIn,
) )
} else { } else {
routes, findErr = r.server.chanRouter.FindRoutes( routes, findErr = r.server.chanRouter.FindRoutes(
pubKey, amtMSat, feeLimit, numRoutesIn, pubKey, amtMSat, restrictions, numRoutesIn,
uint16(in.FinalCltvDelta), uint16(in.FinalCltvDelta),
) )
} }