diff --git a/cln/cln_client.go b/cln/cln_client.go index 8475ce6..6125dde 100644 --- a/cln/cln_client.go +++ b/cln/cln_client.go @@ -65,7 +65,7 @@ func (c *ClnClient) IsConnected(destination []byte) (bool, error) { } for _, peer := range peers { - if pubKey == peer.Id { + if pubKey == peer.Id && peer.Connected { log.Printf("destination online: %x", destination) return true, nil } @@ -228,3 +228,29 @@ func (c *ClnClient) GetClosedChannels(nodeID string, channelPoints map[string]ui return r, nil } + +func (c *ClnClient) GetPeerId(scid *basetypes.ShortChannelID) ([]byte, error) { + scidStr := scid.ToString() + peers, err := c.client.ListPeers() + if err != nil { + return nil, err + } + + var dest *string + for _, p := range peers { + for _, ch := range p.Channels { + if ch.Alias.Local == scidStr || + ch.Alias.Remote == scidStr || + ch.ShortChannelId == scidStr { + dest = &p.Id + break + } + } + } + + if dest == nil { + return nil, nil + } + + return hex.DecodeString(*dest) +} diff --git a/cln/cln_interceptor.go b/cln/cln_interceptor.go index a2c6162..6f3e4b1 100644 --- a/cln/cln_interceptor.go +++ b/cln/cln_interceptor.go @@ -10,6 +10,7 @@ import ( "sync" "time" + "github.com/breez/lspd/basetypes" "github.com/breez/lspd/cln_plugin/proto" "github.com/breez/lspd/config" "github.com/breez/lspd/interceptor" @@ -129,20 +130,6 @@ func (i *ClnHtlcInterceptor) intercept() error { log.Printf("unexpected error in interceptor.Recv() %v", err) break } - nextHop := "" - channels, err := i.client.client.GetChannel(request.Onion.ShortChannelId) - if err != nil { - for _, c := range channels { - if c.Source == i.config.NodePubkey { - nextHop = c.Destination - break - } - if c.Destination == i.config.NodePubkey { - nextHop = c.Source - break - } - } - } i.doneWg.Add(1) go func() { @@ -150,8 +137,17 @@ func (i *ClnHtlcInterceptor) intercept() error { if err != nil { interceptorClient.Send(i.defaultResolution(request)) i.doneWg.Done() + return } - interceptResult := i.interceptor.Intercept(nextHop, paymentHash, request.Onion.ForwardMsat, request.Onion.OutgoingCltvValue, request.Htlc.CltvExpiry) + + scid, err := basetypes.NewShortChannelIDFromString(request.Onion.ShortChannelId) + if err != nil { + interceptorClient.Send(i.defaultResolution(request)) + i.doneWg.Done() + return + } + + interceptResult := i.interceptor.Intercept(scid, paymentHash, request.Onion.ForwardMsat, request.Onion.OutgoingCltvValue, request.Htlc.CltvExpiry) switch interceptResult.Action { case interceptor.INTERCEPT_RESUME_WITH_ONION: interceptorClient.Send(i.resumeWithOnion(request, interceptResult)) diff --git a/interceptor/intercept.go b/interceptor/intercept.go index ae7545d..35b6324 100644 --- a/interceptor/intercept.go +++ b/interceptor/intercept.go @@ -73,7 +73,7 @@ func NewInterceptor( } } -func (i *Interceptor) Intercept(nextHop string, reqPaymentHash []byte, reqOutgoingAmountMsat uint64, reqOutgoingExpiry uint32, reqIncomingExpiry uint32) InterceptResult { +func (i *Interceptor) Intercept(scid *basetypes.ShortChannelID, reqPaymentHash []byte, reqOutgoingAmountMsat uint64, reqOutgoingExpiry uint32, reqIncomingExpiry uint32) InterceptResult { reqPaymentHashStr := hex.EncodeToString(reqPaymentHash) resp, _, _ := i.payHashGroup.Do(reqPaymentHashStr, func() (interface{}, error) { token, params, paymentHash, paymentSecret, destination, incomingAmountMsat, outgoingAmountMsat, channelPoint, tag, err := i.store.PaymentInfo(reqPaymentHash) @@ -85,7 +85,24 @@ func (i *Interceptor) Intercept(nextHop string, reqPaymentHash []byte, reqOutgoi }, nil } - if paymentSecret == nil || (nextHop != "" && nextHop != hex.EncodeToString(destination)) { + nextHop, err := i.client.GetPeerId(scid) + if err != nil { + log.Printf("GetPeerId(%s) error: %v", scid.ToString(), err) + return InterceptResult{ + Action: INTERCEPT_FAIL_HTLC_WITH_CODE, + FailureCode: FAILURE_TEMPORARY_NODE_FAILURE, + }, nil + } + + // If the payment was registered, but the next hop is not the destination + // that means we are not the last hop of the payment, so we'll just forward. + if destination != nil && nextHop != nil && !bytes.Equal(nextHop, destination) { + return InterceptResult{ + Action: INTERCEPT_RESUME, + }, nil + } + + if paymentSecret == nil { return InterceptResult{ Action: INTERCEPT_RESUME, }, nil diff --git a/lightning/client.go b/lightning/client.go index d414deb..4109e41 100644 --- a/lightning/client.go +++ b/lightning/client.go @@ -31,6 +31,7 @@ type Client interface { IsConnected(destination []byte) (bool, error) OpenChannel(req *OpenChannelRequest) (*wire.OutPoint, error) GetChannel(peerID []byte, channelPoint wire.OutPoint) (*GetChannelResult, error) + GetPeerId(scid *basetypes.ShortChannelID) ([]byte, error) GetNodeChannelCount(nodeID []byte) (int, error) GetClosedChannels(nodeID string, channelPoints map[string]uint64) (map[string]uint64, error) } diff --git a/lnd/client.go b/lnd/client.go index 0105569..1194b37 100644 --- a/lnd/client.go +++ b/lnd/client.go @@ -79,18 +79,18 @@ func (c *LndClient) GetInfo() (*lightning.GetInfoResult, error) { } func (c *LndClient) IsConnected(destination []byte) (bool, error) { - pubKey := hex.EncodeToString(destination) + pubkey := hex.EncodeToString(destination) - r, err := c.client.ListPeers(context.Background(), &lnrpc.ListPeersRequest{LatestError: true}) + r, err := c.client.GetPeerConnected(context.Background(), &lnrpc.GetPeerConnectedRequest{ + Pubkey: pubkey, + }) if err != nil { - log.Printf("LND: client.ListPeers() error: %v", err) - return false, fmt.Errorf("LND: client.ListPeers() error: %w", err) + log.Printf("LND: client.GetPeerConnected() error: %v", err) + return false, fmt.Errorf("LND: client.GetPeerConnected() error: %w", err) } - for _, peer := range r.Peers { - if pubKey == peer.PubKey { - log.Printf("destination online: %x", destination) - return true, nil - } + if r.Connected { + log.Printf("LND: destination online: %x", destination) + return true, nil } log.Printf("LND: destination offline: %x", destination) @@ -230,3 +230,20 @@ func (c *LndClient) getWaitingCloseChannels(nodeID string) ([]*lnrpc.PendingChan } return waitingCloseChannels, nil } + +func (c *LndClient) GetPeerId(scid *basetypes.ShortChannelID) ([]byte, error) { + scidu64 := uint64(*scid) + peer, err := c.client.GetPeerIdByScid(context.Background(), &lnrpc.GetPeerIdByScidRequest{ + Scid: scidu64, + }) + if err != nil { + return nil, err + } + + if peer.PeerId == "" { + return nil, nil + } + + peerid, _ := hex.DecodeString(peer.PeerId) + return peerid, nil +} diff --git a/lnd/interceptor.go b/lnd/interceptor.go index 925a6eb..fa296d0 100644 --- a/lnd/interceptor.go +++ b/lnd/interceptor.go @@ -6,6 +6,7 @@ import ( "sync" "time" + "github.com/breez/lspd/basetypes" "github.com/breez/lspd/config" "github.com/breez/lspd/interceptor" "github.com/lightningnetwork/lnd/lnrpc" @@ -126,20 +127,10 @@ func (i *LndHtlcInterceptor) intercept() error { break } - nextHop := "" - chanInfo, err := i.client.client.GetChanInfo(context.Background(), &lnrpc.ChanInfoRequest{ChanId: request.OutgoingRequestedChanId}) - if err == nil && chanInfo != nil { - if chanInfo.Node1Pub == i.config.NodePubkey { - nextHop = chanInfo.Node2Pub - } - if chanInfo.Node2Pub == i.config.NodePubkey { - nextHop = chanInfo.Node1Pub - } - } - i.doneWg.Add(1) go func() { - interceptResult := i.interceptor.Intercept(nextHop, request.PaymentHash, request.OutgoingAmountMsat, request.OutgoingExpiry, request.IncomingExpiry) + scid := basetypes.ShortChannelID(request.OutgoingRequestedChanId) + interceptResult := i.interceptor.Intercept(&scid, request.PaymentHash, request.OutgoingAmountMsat, request.OutgoingExpiry, request.IncomingExpiry) switch interceptResult.Action { case interceptor.INTERCEPT_RESUME_WITH_ONION: interceptorClient.Send(&routerrpc.ForwardHtlcInterceptResponse{