Files
lspd/lnd/client.go
2024-02-23 08:54:00 +01:00

571 lines
14 KiB
Go

package lnd
import (
"context"
"crypto/x509"
"encoding/hex"
"fmt"
"log"
"sync"
"time"
"github.com/breez/lspd/config"
"github.com/breez/lspd/lightning"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/lnrpc"
"github.com/lightningnetwork/lnd/lnrpc/chainrpc"
"github.com/lightningnetwork/lnd/lnrpc/routerrpc"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/status"
)
// LndClient talks to an LND node over gRPC and relays peer-online and
// channel-active events to in-process waiters registered in peersubs and
// chansubs.
type LndClient struct {
	// gRPC service clients; all of them share conn.
	client              lnrpc.LightningClient
	routerClient        routerrpc.RouterClient
	chainNotifierClient chainrpc.ChainNotifierClient
	conn                *grpc.ClientConn

	// Context for the goroutines launched by StartListeners; Close calls
	// listenerCancel when it is set.
	listenerCtx    context.Context
	listenerCancel context.CancelFunc

	// Waiter registries. peersubs is keyed by hex-encoded node pubkey and
	// chansubs by "txid:index" channel point; the inner key is a unique
	// subscription id taken from index. Both maps and index are accessed
	// under submtx.
	peersubs, chansubs map[string]map[uint64]chan struct{}
	submtx             sync.RWMutex
	index              uint64
}
// NewLndClient builds a gRPC client for the LND node described by conf.
//
// conf.Cert must hold a PEM certificate used to verify the node's TLS
// endpoint, and conf.Macaroon must be hex-encoded; it is attached to every
// RPC through NewMacaroonCredential. Background listeners are not started;
// call StartListeners for that.
//
// An error is returned when the macaroon is not valid hex, the certificate
// cannot be parsed, or dialing conf.Address fails.
func NewLndClient(conf *config.LndConfig) (*LndClient, error) {
	// The decoded bytes are not needed; this only validates the encoding
	// before it is handed to NewMacaroonCredential as a string.
	if _, err := hex.DecodeString(conf.Macaroon); err != nil {
		return nil, fmt.Errorf("failed to decode macaroon: %w", err)
	}

	// Creds file to connect to LND gRPC
	cp := x509.NewCertPool()
	if !cp.AppendCertsFromPEM([]byte(conf.Cert)) {
		return nil, fmt.Errorf("credentials: failed to append certificates")
	}
	creds := credentials.NewClientTLSFromCert(cp, "")
	macCred := NewMacaroonCredential(conf.Macaroon)

	// Address of an LND instance
	conn, err := grpc.Dial(
		conf.Address,
		grpc.WithTransportCredentials(creds),
		grpc.WithPerRPCCredentials(macCred),
	)
	if err != nil {
		// Report the failure to the caller instead of terminating the
		// process: this constructor already advertises an error return.
		return nil, fmt.Errorf("failed to connect to LND gRPC: %w", err)
	}

	return &LndClient{
		client:              lnrpc.NewLightningClient(conn),
		routerClient:        routerrpc.NewRouterClient(conn),
		chainNotifierClient: chainrpc.NewChainNotifierClient(conn),
		conn:                conn,
		peersubs:            make(map[string]map[uint64]chan struct{}),
		chansubs:            make(map[string]map[uint64]chan struct{}),
	}, nil
}
// Close cancels the listener context if StartListeners created one, then
// closes the underlying gRPC connection.
func (c *LndClient) Close() {
	if stop := c.listenerCancel; stop != nil {
		stop()
	}
	c.conn.Close()
}
// StartListeners creates a cancellable listener context (cancelled by Close)
// and launches the peer-event and channel-event listener goroutines.
func (c *LndClient) StartListeners() {
	ctx, cancel := context.WithCancel(context.Background())
	c.listenerCtx = ctx
	c.listenerCancel = cancel

	go c.listenPeerEvents()
	go c.listenChannelEvents()
}
// listenPeerEvents runs until c.listenerCtx is cancelled. It keeps a
// SubscribePeerEvents stream open, re-subscribing one second after the
// stream fails or ends. For every PEER_ONLINE event it signals each waiter
// registered in c.peersubs under that event's PubKey.
func (c *LndClient) listenPeerEvents() {
	ctx := c.listenerCtx
	for {
		if ctx.Err() != nil {
			return
		}
		sub, err := c.client.SubscribePeerEvents(
			ctx,
			&lnrpc.PeerEventSubscription{},
		)
		if err != nil {
			log.Printf("SubscribePeerEvents: %v", err)
			// Back off, but stop promptly if the listener is cancelled.
			select {
			case <-ctx.Done():
				return
			case <-time.After(time.Second):
			}
			continue
		}

		for {
			if ctx.Err() != nil {
				return
			}
			msg, err := sub.Recv()
			if err != nil {
				if st, ok := status.FromError(err); ok && st.Code() == codes.Canceled {
					log.Printf("listenPeerEvents: Got code canceled. Break.")
				} else {
					log.Printf("unexpected error in listenPeerEvents: %v", err)
				}
				break
			}
			if msg.Type != lnrpc.PeerEvent_PEER_ONLINE {
				continue
			}

			c.submtx.RLock()
			for _, waiter := range c.peersubs[msg.PubKey] {
				// Non-blocking send: a full buffer already holds a pending
				// wake-up. Blocking here while holding the read lock would
				// stall every other event and deadlock the waiter's
				// unsubscribe, which needs the write lock.
				select {
				case waiter <- struct{}{}:
				default:
				}
			}
			c.submtx.RUnlock()
		}

		select {
		case <-ctx.Done():
			return
		case <-time.After(time.Second):
		}
	}
}
// listenChannelEvents runs until c.listenerCtx is cancelled. It keeps a
// SubscribeChannelEvents stream open, re-subscribing one second after the
// stream fails or ends. For every ACTIVE_CHANNEL event it formats the
// channel point with extractChannelPoint and signals each waiter registered
// in c.chansubs under that key.
func (c *LndClient) listenChannelEvents() {
	ctx := c.listenerCtx
	for {
		if ctx.Err() != nil {
			return
		}
		sub, err := c.client.SubscribeChannelEvents(
			ctx,
			&lnrpc.ChannelEventSubscription{},
		)
		if err != nil {
			log.Printf("listenChannelEvents: SubscribeChannelEvents: %v", err)
			// Back off, but stop promptly if the listener is cancelled.
			select {
			case <-ctx.Done():
				return
			case <-time.After(time.Second):
			}
			continue
		}

		for {
			if ctx.Err() != nil {
				return
			}
			msg, err := sub.Recv()
			if err != nil {
				if st, ok := status.FromError(err); ok && st.Code() == codes.Canceled {
					log.Printf("listenChannelEvents: Got code canceled. Break.")
				} else {
					log.Printf("unexpected error in listenChannelEvents: %v", err)
				}
				break
			}
			if msg.Type != lnrpc.ChannelEventUpdate_ACTIVE_CHANNEL {
				continue
			}
			ch := msg.GetActiveChannel()
			point, err := extractChannelPoint(ch)
			if err != nil {
				log.Printf("listenChannelEvents: Failed to extract channel point %+v: %v", ch, err)
				continue
			}

			c.submtx.RLock()
			for _, waiter := range c.chansubs[point] {
				// Non-blocking send: a full buffer already holds a pending
				// wake-up. Blocking here while holding the read lock would
				// stall every other event and deadlock the waiter's
				// unsubscribe, which needs the write lock.
				select {
				case waiter <- struct{}{}:
				default:
				}
			}
			c.submtx.RUnlock()
		}

		select {
		case <-ctx.Done():
			return
		case <-time.After(time.Second):
		}
	}
}
// extractChannelPoint formats cp as "<txid>:<output index>", the key format
// used by chansubs. The txid comes from FundingTxidStr when it is set;
// otherwise FundingTxidBytes is converted with chainhash.NewHash and
// rendered with its String method.
//
// It returns an error for a nil cp (instead of panicking in the caller's
// goroutine) or when the txid bytes are not a valid hash.
func extractChannelPoint(cp *lnrpc.ChannelPoint) (string, error) {
	if cp == nil {
		return "", fmt.Errorf("nil channel point")
	}
	txid := cp.GetFundingTxidStr()
	if txid == "" {
		h, err := chainhash.NewHash(cp.GetFundingTxidBytes())
		if err != nil {
			return "", err
		}
		txid = h.String()
	}
	return fmt.Sprintf("%s:%d", txid, cp.GetOutputIndex()), nil
}
// GetInfo calls LND's GetInfo RPC and returns the node alias and identity
// pubkey. An RPC error is logged and returned unchanged.
func (c *LndClient) GetInfo() (*lightning.GetInfoResult, error) {
	req := &lnrpc.GetInfoRequest{}
	resp, err := c.client.GetInfo(context.Background(), req)
	if err != nil {
		log.Printf("LND: client.GetInfo() error: %v", err)
		return nil, err
	}

	result := &lightning.GetInfoResult{
		Alias:  resp.Alias,
		Pubkey: resp.IdentityPubkey,
	}
	return result, nil
}
// IsConnected reports whether LND currently has a connection to the peer
// whose raw public key is destination. The result is logged either way;
// an RPC failure is logged and returned wrapped.
func (c *LndClient) IsConnected(destination []byte) (bool, error) {
	req := &lnrpc.GetPeerConnectedRequest{
		Pubkey: hex.EncodeToString(destination),
	}
	resp, err := c.client.GetPeerConnected(context.Background(), req)
	if err != nil {
		log.Printf("LND: client.GetPeerConnected() error: %v", err)
		return false, fmt.Errorf("LND: client.GetPeerConnected() error: %w", err)
	}

	online := resp.Connected
	if online {
		log.Printf("LND: destination online: %x", destination)
	} else {
		log.Printf("LND: destination offline: %x", destination)
	}
	return online, nil
}
// OpenChannel synchronously opens a private, zero-conf, anchor-commitment
// channel of req.CapacitySat to req.Destination with no push amount, and
// returns its funding outpoint.
//
// Optional settings:
//   - MinConfs: passed through; a value of 0 also enables SpendUnconfirmed.
//   - FeeSatPerVByte: used as the fee rate and takes precedence over
//     TargetConf, which is used only when no explicit fee rate is given.
func (c *LndClient) OpenChannel(req *lightning.OpenChannelRequest) (*wire.OutPoint, error) {
	lnReq := &lnrpc.OpenChannelRequest{
		NodePubkey:         req.Destination,
		LocalFundingAmount: int64(req.CapacitySat),
		PushSat:            0,
		Private:            true,
		CommitmentType:     lnrpc.CommitmentType_ANCHORS,
		ZeroConf:           true,
	}

	if mc := req.MinConfs; mc != nil {
		lnReq.MinConfs = int32(*mc)
		lnReq.SpendUnconfirmed = *mc == 0
	}

	switch {
	case req.FeeSatPerVByte != nil:
		lnReq.SatPerVbyte = uint64(*req.FeeSatPerVByte)
	case req.TargetConf != nil:
		lnReq.TargetConf = int32(*req.TargetConf)
	}

	cp, err := c.client.OpenChannelSync(context.Background(), lnReq)
	if err != nil {
		log.Printf("LND: client.OpenChannelSync(%x, %v) error: %v", req.Destination, req.CapacitySat, err)
		return nil, fmt.Errorf("LND: OpenChannel() error: %w", err)
	}

	outpoint, err := lightning.NewOutPoint(cp.GetFundingTxidBytes(), cp.OutputIndex)
	if err != nil {
		log.Printf("LND: OpenChannel returned invalid outpoint. error: %v", err)
		return nil, err
	}
	return outpoint, nil
}
// GetChannel looks among LND's channels with peerID for an active channel
// whose ChannelPoint string equals channelPoint.String(), and returns its
// alias/confirmed SCIDs and local minimum HTLC amount.
//
// An error is returned when ListChannels fails or when no active matching
// channel exists.
func (c *LndClient) GetChannel(peerID []byte, channelPoint wire.OutPoint) (*lightning.GetChannelResult, error) {
	r, err := c.client.ListChannels(context.Background(), &lnrpc.ListChannelsRequest{Peer: peerID})
	if err != nil {
		log.Printf("client.ListChannels(%x) error: %v", peerID, err)
		return nil, err
	}

	channelPointStr := channelPoint.String()
	for _, ch := range r.Channels {
		log.Printf("getChannel(%x): %v", peerID, ch.ChanId)
		if ch.ChannelPoint != channelPointStr || !ch.Active {
			continue
		}
		aliasScid, confirmedScid := mapScidsFromChannel(ch)
		return &lightning.GetChannelResult{
			AliasScid:     aliasScid,
			ConfirmedScid: confirmedScid,
			// Getter chain avoids a nil dereference if LND omits the
			// local constraints.
			HtlcMinimumMsat: ch.GetLocalConstraints().GetMinHtlcMsat(),
		}, nil
	}

	log.Printf("No channel found: getChannel(%x)", peerID)
	return nil, fmt.Errorf("no channel found")
}
// GetClosedChannels returns the subset of channelPoints whose keys are not
// among nodeID's waiting-close channels in LND, with each value passed
// through unchanged. An empty input returns an empty map without calling
// LND.
func (c *LndClient) GetClosedChannels(nodeID string, channelPoints map[string]uint64) (map[string]uint64, error) {
	closed := make(map[string]uint64)
	if len(channelPoints) == 0 {
		return closed, nil
	}

	waiting, err := c.getWaitingCloseChannels(nodeID)
	if err != nil {
		return nil, err
	}

	stillClosing := make(map[string]bool, len(waiting))
	for _, w := range waiting {
		stillClosing[w.Channel.ChannelPoint] = true
	}

	for point, value := range channelPoints {
		if !stillClosing[point] {
			closed[point] = value
		}
	}
	return closed, nil
}
// getWaitingCloseChannels fetches LND's pending channels and returns the
// waiting-close entries whose remote node pubkey equals nodeID. The result
// is nil when there are no matches.
func (c *LndClient) getWaitingCloseChannels(nodeID string) ([]*lnrpc.PendingChannelsResponse_WaitingCloseChannel, error) {
	resp, err := c.client.PendingChannels(context.Background(), &lnrpc.PendingChannelsRequest{})
	if err != nil {
		return nil, err
	}

	var matches []*lnrpc.PendingChannelsResponse_WaitingCloseChannel
	for i := range resp.WaitingCloseChannels {
		if wc := resp.WaitingCloseChannels[i]; wc.Channel.RemoteNodePub == nodeID {
			matches = append(matches, wc)
		}
	}
	return matches, nil
}
// GetPeerId asks LND which peer owns scid and returns that peer's raw
// public key.
//
// It returns (nil, nil) when LND reports no peer for the SCID. It returns an
// error when the RPC fails or when LND's peer id is not valid hex; the
// decode error is no longer silently discarded.
func (c *LndClient) GetPeerId(scid *lightning.ShortChannelID) ([]byte, error) {
	scidu64 := uint64(*scid)
	peer, err := c.client.GetPeerIdByScid(context.Background(), &lnrpc.GetPeerIdByScidRequest{
		Scid: scidu64,
	})
	if err != nil {
		return nil, err
	}
	if peer.PeerId == "" {
		return nil, nil
	}

	peerID, err := hex.DecodeString(peer.PeerId)
	if err != nil {
		return nil, fmt.Errorf("invalid peer id %q for scid %d: %w", peer.PeerId, scidu64, err)
	}
	return peerID, nil
}
// WaitOnline blocks until the peer identified by peerID is connected to LND
// or deadline passes, returning an error on timeout or if IsConnected fails.
//
// A waiter is registered in c.peersubs before IsConnected is checked, so a
// PEER_ONLINE event delivered by listenPeerEvents between the check and the
// wait still wakes this call.
func (c *LndClient) WaitOnline(peerID []byte, deadline time.Time) error {
	key := hex.EncodeToString(peerID)
	wake := make(chan struct{}, 10)

	c.submtx.Lock()
	id := c.index
	c.index++
	waiters, found := c.peersubs[key]
	if !found {
		waiters = make(map[uint64]chan struct{})
		c.peersubs[key] = waiters
	}
	waiters[id] = wake
	c.submtx.Unlock()

	defer func() {
		c.submtx.Lock()
		if remaining, found := c.peersubs[key]; found {
			delete(remaining, id)
			if len(remaining) == 0 {
				delete(c.peersubs, key)
			}
		}
		c.submtx.Unlock()
		// Close only after unsubscribing: the listener sends while holding
		// the read lock, so no send can reach wake past this point.
		close(wake)
	}()

	connected, err := c.IsConnected(peerID)
	switch {
	case err != nil:
		return err
	case connected:
		return nil
	}

	select {
	case <-wake:
		return nil
	case <-time.After(time.Until(deadline)):
		return fmt.Errorf("deadline exceeded")
	}
}
// WaitChannelActive blocks until at least one of LND's channels with peerID
// is active, or until deadline passes. It returns an error when ListChannels
// fails, when there are no channels with the peer, or on timeout.
//
// One signal channel is registered in c.chansubs under every channel point
// of the peer before the channel list is fetched a second time, so an
// activation racing with the subscription is still observed. Signals are
// delivered by listenChannelEvents.
func (c *LndClient) WaitChannelActive(peerID []byte, deadline time.Time) error {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Fetch the channels for this peer
	chans, err := c.client.ListChannels(ctx, &lnrpc.ListChannelsRequest{
		Peer: peerID,
	})
	if err != nil {
		return err
	}
	if len(chans.Channels) == 0 {
		return fmt.Errorf("no channels with peer")
	}
	// Exit now if a channel is already active.
	for _, ch := range chans.Channels {
		if ch.Active {
			return nil
		}
	}

	// Register the same signal channel directly for every channel point.
	// There are no forwarding goroutines, so nothing can send on signal
	// after the deferred cleanup below has closed it, and no goroutine is
	// left blocked on a full buffer.
	type subscription struct {
		outpoint string
		id       uint64
	}
	signal := make(chan struct{}, 10)
	subscriptions := make([]subscription, 0, len(chans.Channels))

	c.submtx.Lock()
	for _, ch := range chans.Channels {
		subid := c.index
		c.index++
		subs, ok := c.chansubs[ch.ChannelPoint]
		if !ok {
			subs = make(map[uint64]chan struct{})
			c.chansubs[ch.ChannelPoint] = subs
		}
		subs[subid] = signal
		subscriptions = append(subscriptions, subscription{outpoint: ch.ChannelPoint, id: subid})
	}
	c.submtx.Unlock()

	defer func() {
		// Unsubscribe under the write lock before closing: listeners send
		// while holding the read lock, so after this no send targets signal.
		c.submtx.Lock()
		for _, s := range subscriptions {
			if subs, ok := c.chansubs[s.outpoint]; ok {
				delete(subs, s.id)
				if len(subs) == 0 {
					delete(c.chansubs, s.outpoint)
				}
			}
		}
		c.submtx.Unlock()
		close(signal)
	}()

	// Fetch the channels for this peer again, so there is no gap between the
	// subscription and the call.
	chans, err = c.client.ListChannels(ctx, &lnrpc.ListChannelsRequest{
		Peer: peerID,
	})
	if err != nil {
		return err
	}
	// Exit now if a channel is already active.
	for _, ch := range chans.Channels {
		if ch.Active {
			return nil
		}
	}

	select {
	case <-signal:
		return nil
	case <-time.After(time.Until(deadline)):
		return fmt.Errorf("deadline exceeded")
	}
}
// ListChannels returns LND's open channels followed by its pending-open
// channels. Open channels carry their alias and confirmed SCIDs (see
// mapScidsFromChannel); pending-open channels have neither.
//
// Entries whose remote pubkey is not valid hex are logged and omitted. An
// unparsable channel point is logged and the entry is kept with a nil
// ChannelPoint. The result never contains nil elements.
func (c *LndClient) ListChannels() ([]*lightning.Channel, error) {
	channels, err := c.client.ListChannels(
		context.TODO(),
		&lnrpc.ListChannelsRequest{},
	)
	if err != nil {
		return nil, err
	}
	pendingChannels, err := c.client.PendingChannels(
		context.TODO(),
		&lnrpc.PendingChannelsRequest{},
	)
	if err != nil {
		return nil, err
	}

	// Append rather than assign by index so that skipped entries do not
	// leave nil holes in the returned slice.
	result := make([]*lightning.Channel, 0, len(channels.Channels)+len(pendingChannels.PendingOpenChannels))
	for _, ch := range channels.Channels {
		peerId, err := hex.DecodeString(ch.RemotePubkey)
		if err != nil {
			log.Printf("hex.DecodeString in LndClient.ListChannels error: %v", err)
			continue
		}
		alias, confirmedScid := mapScidsFromChannel(ch)
		outpoint, err := lightning.NewOutPointFromString(ch.ChannelPoint)
		if err != nil {
			log.Printf("lightning.NewOutPointFromString(%s) in LndClient.ListChannels error: %v", ch.ChannelPoint, err)
		}
		result = append(result, &lightning.Channel{
			AliasScid:     alias,
			ConfirmedScid: confirmedScid,
			ChannelPoint:  outpoint,
			PeerId:        peerId,
		})
	}

	for _, pc := range pendingChannels.PendingOpenChannels {
		peerId, err := hex.DecodeString(pc.Channel.RemoteNodePub)
		if err != nil {
			log.Printf("hex.DecodeString in LndClient.ListChannels error: %v", err)
			continue
		}
		outpoint, err := lightning.NewOutPointFromString(pc.Channel.ChannelPoint)
		if err != nil {
			log.Printf("lightning.NewOutPointFromString(%s) in LndClient.ListChannels error: %v", pc.Channel.ChannelPoint, err)
		}
		result = append(result, &lightning.Channel{
			AliasScid:     nil,
			ConfirmedScid: nil,
			ChannelPoint:  outpoint,
			PeerId:        peerId,
		})
	}
	return result, nil
}
// mapScidsFromChannel derives the (alias, confirmed) SCID pair of an open
// LND channel.
//
// For a zero-conf channel, ChanId is the alias and ZeroConfConfirmedScid, if
// non-zero, is the confirmed SCID. Otherwise ChanId is the confirmed SCID
// and the first entry of AliasScids, if any, is the alias. Either result may
// be nil. The returned pointers point into c's own fields, not copies.
func mapScidsFromChannel(c *lnrpc.Channel) (alias, confirmedScid *lightning.ShortChannelID) {
	if !c.ZeroConf {
		confirmedScid = (*lightning.ShortChannelID)(&c.ChanId)
		if len(c.AliasScids) != 0 {
			alias = (*lightning.ShortChannelID)(&c.AliasScids[0])
		}
		return alias, confirmedScid
	}

	alias = (*lightning.ShortChannelID)(&c.ChanId)
	if c.ZeroConfConfirmedScid != 0 {
		confirmedScid = (*lightning.ShortChannelID)(&c.ZeroConfConfirmedScid)
	}
	return alias, confirmedScid
}