move nexthop logic inside the interceptor

This commit is contained in:
Jesse de Wit
2023-06-15 10:50:20 +02:00
parent 898b69f9a7
commit 205d39d715
6 changed files with 87 additions and 39 deletions

View File

@@ -65,7 +65,7 @@ func (c *ClnClient) IsConnected(destination []byte) (bool, error) {
} }
for _, peer := range peers { for _, peer := range peers {
if pubKey == peer.Id { if pubKey == peer.Id && peer.Connected {
log.Printf("destination online: %x", destination) log.Printf("destination online: %x", destination)
return true, nil return true, nil
} }
@@ -228,3 +228,29 @@ func (c *ClnClient) GetClosedChannels(nodeID string, channelPoints map[string]ui
return r, nil return r, nil
} }
func (c *ClnClient) GetPeerId(scid *basetypes.ShortChannelID) ([]byte, error) {
scidStr := scid.ToString()
peers, err := c.client.ListPeers()
if err != nil {
return nil, err
}
var dest *string
for _, p := range peers {
for _, ch := range p.Channels {
if ch.Alias.Local == scidStr ||
ch.Alias.Remote == scidStr ||
ch.ShortChannelId == scidStr {
dest = &p.Id
break
}
}
}
if dest == nil {
return nil, nil
}
return hex.DecodeString(*dest)
}

View File

@@ -10,6 +10,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/breez/lspd/basetypes"
"github.com/breez/lspd/cln_plugin/proto" "github.com/breez/lspd/cln_plugin/proto"
"github.com/breez/lspd/config" "github.com/breez/lspd/config"
"github.com/breez/lspd/interceptor" "github.com/breez/lspd/interceptor"
@@ -129,20 +130,6 @@ func (i *ClnHtlcInterceptor) intercept() error {
log.Printf("unexpected error in interceptor.Recv() %v", err) log.Printf("unexpected error in interceptor.Recv() %v", err)
break break
} }
nextHop := "<unknown>"
channels, err := i.client.client.GetChannel(request.Onion.ShortChannelId)
if err != nil {
for _, c := range channels {
if c.Source == i.config.NodePubkey {
nextHop = c.Destination
break
}
if c.Destination == i.config.NodePubkey {
nextHop = c.Source
break
}
}
}
i.doneWg.Add(1) i.doneWg.Add(1)
go func() { go func() {
@@ -150,8 +137,17 @@ func (i *ClnHtlcInterceptor) intercept() error {
if err != nil { if err != nil {
interceptorClient.Send(i.defaultResolution(request)) interceptorClient.Send(i.defaultResolution(request))
i.doneWg.Done() i.doneWg.Done()
return
} }
interceptResult := i.interceptor.Intercept(nextHop, paymentHash, request.Onion.ForwardMsat, request.Onion.OutgoingCltvValue, request.Htlc.CltvExpiry)
scid, err := basetypes.NewShortChannelIDFromString(request.Onion.ShortChannelId)
if err != nil {
interceptorClient.Send(i.defaultResolution(request))
i.doneWg.Done()
return
}
interceptResult := i.interceptor.Intercept(scid, paymentHash, request.Onion.ForwardMsat, request.Onion.OutgoingCltvValue, request.Htlc.CltvExpiry)
switch interceptResult.Action { switch interceptResult.Action {
case interceptor.INTERCEPT_RESUME_WITH_ONION: case interceptor.INTERCEPT_RESUME_WITH_ONION:
interceptorClient.Send(i.resumeWithOnion(request, interceptResult)) interceptorClient.Send(i.resumeWithOnion(request, interceptResult))

View File

@@ -73,7 +73,7 @@ func NewInterceptor(
} }
} }
func (i *Interceptor) Intercept(nextHop string, reqPaymentHash []byte, reqOutgoingAmountMsat uint64, reqOutgoingExpiry uint32, reqIncomingExpiry uint32) InterceptResult { func (i *Interceptor) Intercept(scid *basetypes.ShortChannelID, reqPaymentHash []byte, reqOutgoingAmountMsat uint64, reqOutgoingExpiry uint32, reqIncomingExpiry uint32) InterceptResult {
reqPaymentHashStr := hex.EncodeToString(reqPaymentHash) reqPaymentHashStr := hex.EncodeToString(reqPaymentHash)
resp, _, _ := i.payHashGroup.Do(reqPaymentHashStr, func() (interface{}, error) { resp, _, _ := i.payHashGroup.Do(reqPaymentHashStr, func() (interface{}, error) {
token, params, paymentHash, paymentSecret, destination, incomingAmountMsat, outgoingAmountMsat, channelPoint, tag, err := i.store.PaymentInfo(reqPaymentHash) token, params, paymentHash, paymentSecret, destination, incomingAmountMsat, outgoingAmountMsat, channelPoint, tag, err := i.store.PaymentInfo(reqPaymentHash)
@@ -85,7 +85,24 @@ func (i *Interceptor) Intercept(nextHop string, reqPaymentHash []byte, reqOutgoi
}, nil }, nil
} }
if paymentSecret == nil || (nextHop != "<unknown>" && nextHop != hex.EncodeToString(destination)) { nextHop, err := i.client.GetPeerId(scid)
if err != nil {
log.Printf("GetPeerId(%s) error: %v", scid.ToString(), err)
return InterceptResult{
Action: INTERCEPT_FAIL_HTLC_WITH_CODE,
FailureCode: FAILURE_TEMPORARY_NODE_FAILURE,
}, nil
}
// If the payment was registered, but the next hop is not the destination
// that means we are not the last hop of the payment, so we'll just forward.
if destination != nil && nextHop != nil && !bytes.Equal(nextHop, destination) {
return InterceptResult{
Action: INTERCEPT_RESUME,
}, nil
}
if paymentSecret == nil {
return InterceptResult{ return InterceptResult{
Action: INTERCEPT_RESUME, Action: INTERCEPT_RESUME,
}, nil }, nil

View File

@@ -31,6 +31,7 @@ type Client interface {
IsConnected(destination []byte) (bool, error) IsConnected(destination []byte) (bool, error)
OpenChannel(req *OpenChannelRequest) (*wire.OutPoint, error) OpenChannel(req *OpenChannelRequest) (*wire.OutPoint, error)
GetChannel(peerID []byte, channelPoint wire.OutPoint) (*GetChannelResult, error) GetChannel(peerID []byte, channelPoint wire.OutPoint) (*GetChannelResult, error)
GetPeerId(scid *basetypes.ShortChannelID) ([]byte, error)
GetNodeChannelCount(nodeID []byte) (int, error) GetNodeChannelCount(nodeID []byte) (int, error)
GetClosedChannels(nodeID string, channelPoints map[string]uint64) (map[string]uint64, error) GetClosedChannels(nodeID string, channelPoints map[string]uint64) (map[string]uint64, error)
} }

View File

@@ -79,18 +79,18 @@ func (c *LndClient) GetInfo() (*lightning.GetInfoResult, error) {
} }
func (c *LndClient) IsConnected(destination []byte) (bool, error) { func (c *LndClient) IsConnected(destination []byte) (bool, error) {
pubKey := hex.EncodeToString(destination) pubkey := hex.EncodeToString(destination)
r, err := c.client.ListPeers(context.Background(), &lnrpc.ListPeersRequest{LatestError: true}) r, err := c.client.GetPeerConnected(context.Background(), &lnrpc.GetPeerConnectedRequest{
Pubkey: pubkey,
})
if err != nil { if err != nil {
log.Printf("LND: client.ListPeers() error: %v", err) log.Printf("LND: client.GetPeerConnected() error: %v", err)
return false, fmt.Errorf("LND: client.ListPeers() error: %w", err) return false, fmt.Errorf("LND: client.GetPeerConnected() error: %w", err)
} }
for _, peer := range r.Peers { if r.Connected {
if pubKey == peer.PubKey { log.Printf("LND: destination online: %x", destination)
log.Printf("destination online: %x", destination) return true, nil
return true, nil
}
} }
log.Printf("LND: destination offline: %x", destination) log.Printf("LND: destination offline: %x", destination)
@@ -230,3 +230,20 @@ func (c *LndClient) getWaitingCloseChannels(nodeID string) ([]*lnrpc.PendingChan
} }
return waitingCloseChannels, nil return waitingCloseChannels, nil
} }
func (c *LndClient) GetPeerId(scid *basetypes.ShortChannelID) ([]byte, error) {
scidu64 := uint64(*scid)
peer, err := c.client.GetPeerIdByScid(context.Background(), &lnrpc.GetPeerIdByScidRequest{
Scid: scidu64,
})
if err != nil {
return nil, err
}
if peer.PeerId == "" {
return nil, nil
}
peerid, _ := hex.DecodeString(peer.PeerId)
return peerid, nil
}

View File

@@ -6,6 +6,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/breez/lspd/basetypes"
"github.com/breez/lspd/config" "github.com/breez/lspd/config"
"github.com/breez/lspd/interceptor" "github.com/breez/lspd/interceptor"
"github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lnrpc"
@@ -126,20 +127,10 @@ func (i *LndHtlcInterceptor) intercept() error {
break break
} }
nextHop := "<unknown>"
chanInfo, err := i.client.client.GetChanInfo(context.Background(), &lnrpc.ChanInfoRequest{ChanId: request.OutgoingRequestedChanId})
if err == nil && chanInfo != nil {
if chanInfo.Node1Pub == i.config.NodePubkey {
nextHop = chanInfo.Node2Pub
}
if chanInfo.Node2Pub == i.config.NodePubkey {
nextHop = chanInfo.Node1Pub
}
}
i.doneWg.Add(1) i.doneWg.Add(1)
go func() { go func() {
interceptResult := i.interceptor.Intercept(nextHop, request.PaymentHash, request.OutgoingAmountMsat, request.OutgoingExpiry, request.IncomingExpiry) scid := basetypes.ShortChannelID(request.OutgoingRequestedChanId)
interceptResult := i.interceptor.Intercept(&scid, request.PaymentHash, request.OutgoingAmountMsat, request.OutgoingExpiry, request.IncomingExpiry)
switch interceptResult.Action { switch interceptResult.Action {
case interceptor.INTERCEPT_RESUME_WITH_ONION: case interceptor.INTERCEPT_RESUME_WITH_ONION:
interceptorClient.Send(&routerrpc.ForwardHtlcInterceptResponse{ interceptorClient.Send(&routerrpc.ForwardHtlcInterceptResponse{