From 5b864f9cce70fd101b6d8e2f340a8a1b64f442db Mon Sep 17 00:00:00 2001 From: Jesse de Wit Date: Fri, 18 Nov 2022 16:33:53 +0100 Subject: [PATCH] implement cln client and interceptor --- cln_client.go | 198 ++++++++++++++++++++++++++++++++++++++++++++ cln_interceptor.go | 191 ++++++++++++++++++++++++++++++++++++++++++ go.mod | 8 +- short_channel_id.go | 39 +++++---- 4 files changed, 416 insertions(+), 20 deletions(-) create mode 100644 cln_client.go create mode 100644 cln_interceptor.go diff --git a/cln_client.go b/cln_client.go new file mode 100644 index 0000000..610c0d5 --- /dev/null +++ b/cln_client.go @@ -0,0 +1,198 @@ +package main + +import ( + "encoding/hex" + "fmt" + "log" + + "github.com/btcsuite/btcd/wire" + "github.com/niftynei/glightning/glightning" + "golang.org/x/exp/slices" +) + +type ClnClient struct { + client *glightning.Lightning +} + +var ( + OPEN_STATUSES = []string{"CHANNELD_NORMAL"} + PENDING_STATUSES = []string{"OPENINGD", "CHANNELD_AWAITING_LOCKIN"} + CLOSING_STATUSES = []string{"CHANNELD_SHUTTING_DOWN", "CLOSINGD_SIGEXCHANGE", "CLOSINGD_COMPLETE", "AWAITING_UNILATERAL", "FUNDING_SPEND_SEEN", "ONCHAIN"} + CLOSED_STATUSES = []string{"CLOSED"} +) + +func NewClnClient(rpcFile string, lightningDir string) *ClnClient { + client := glightning.NewLightning() + client.SetTimeout(60) + client.StartUp(rpcFile, lightningDir) + return &ClnClient{ + client: client, + } +} + +func (c *ClnClient) GetInfo() (*GetInfoResult, error) { + info, err := c.client.GetInfo() + if err != nil { + log.Printf("CLN: client.GetInfo() error: %v", err) + return nil, err + } + + return &GetInfoResult{ + Alias: info.Alias, + Pubkey: info.Id, + }, nil +} + +func (c *ClnClient) IsConnected(destination []byte) (*bool, error) { + pubKey := hex.EncodeToString(destination) + peers, err := c.client.ListPeers() + if err != nil { + log.Printf("CLN: client.ListPeers() error: %v", err) + return nil, fmt.Errorf("CLN: client.ListPeers() error: %w", err) + } + + for _, peer := range peers { + if pubKey == peer.Id { + log.Printf("destination online: %x", destination) + result := true + return &result, nil + } + } + + log.Printf("CLN: destination offline: %x", destination) + result := false + return &result, nil +} + +func (c *ClnClient) OpenChannel(req *OpenChannelRequest) (*wire.OutPoint, error) { + pubkey := hex.EncodeToString(req.Destination) + minConf := uint16(req.TargetConf) + if req.IsZeroConf { + minConf = 0 + } + + var minDepth *uint16 + if req.IsZeroConf { + var d uint16 = 0 + minDepth = &d + } + + fundResult, err := c.client.FundChannelExt( + pubkey, + glightning.NewSat(int(req.CapacitySat)), + &glightning.FeeRate{ + Directive: glightning.Slow, + }, + !req.IsPrivate, + &minConf, + glightning.NewMsat(0), + minDepth, + ) + + if err != nil { + log.Printf("CLN: client.FundChannelExt(%v, %v) error: %v", pubkey, req.CapacitySat, err) + return nil, err + } + + fundingTxId, err := hex.DecodeString(fundResult.FundingTxId) + if err != nil { + log.Printf("CLN: hex.DecodeString(%s) error: %v", fundResult.FundingTxId, err) + return nil, err + } + + channelPoint, err := NewOutPoint(fundingTxId, uint32(fundResult.FundingTxOutputNum)) + if err != nil { + log.Printf("CLN: NewOutPoint(%x, %d) error: %v", fundingTxId, fundResult.FundingTxOutputNum, err) + return nil, err + } + + return channelPoint, nil +} + +func (c *ClnClient) GetChannel(peerID []byte, channelPoint wire.OutPoint) (*GetChannelResult, error) { + pubkey := hex.EncodeToString(peerID) + peer, err := c.client.GetPeer(pubkey) + fundingTxID := channelPoint.Hash.String() + if err != nil { + log.Printf("CLN: client.GetPeer(%s) error: %v", pubkey, err) + return nil, err + } + + for _, c := range peer.Channels { + log.Printf("getChannel destination: %s, Short channel id: %v, local alias: %v , FundingTxID:%v, State:%v ", pubkey, c.ShortChannelId, c.Alias.Local, c.FundingTxId, c.State) + if slices.Contains(OPEN_STATUSES, c.State) && c.FundingTxId == fundingTxID { + confirmedChanID, err := NewShortChannelIDFromString(c.ShortChannelId) + if err != nil { + fmt.Printf("NewShortChannelIDFromString %v error: %v", c.ShortChannelId, err) + return nil, err + } + initialChanID, err := NewShortChannelIDFromString(c.Alias.Local) + if err != nil { + fmt.Printf("NewShortChannelIDFromString %v error: %v", c.Alias.Local, err) + return nil, err + } + return &GetChannelResult{ + InitialChannelID: *initialChanID, + ConfirmedChannelID: *confirmedChanID, + }, nil + } + } + + log.Printf("No channel found: getChannel(%v)", pubkey) + return nil, fmt.Errorf("no channel found") +} + +func (c *ClnClient) GetNodeChannelCount(nodeID []byte) (int, error) { + pubkey := hex.EncodeToString(nodeID) + peer, err := c.client.GetPeer(pubkey) + if err != nil { + log.Printf("CLN: client.GetPeer(%s) error: %v", pubkey, err) + return 0, err + } + + count := 0 + openPendingStatuses := append(OPEN_STATUSES, PENDING_STATUSES...) + for _, c := range peer.Channels { + if slices.Contains(openPendingStatuses, c.State) { + count++ + } + } + + return count, nil +} + +func (c *ClnClient) GetClosedChannels(nodeID string, channelPoints map[string]uint64) (map[string]uint64, error) { + r := make(map[string]uint64) + if len(channelPoints) == 0 { + return r, nil + } + + peer, err := c.client.GetPeer(nodeID) + if err != nil { + log.Printf("CLN: client.GetPeer(%s) error: %v", nodeID, err) + return nil, err + } + + lookup := make(map[string]uint64) + for _, c := range peer.Channels { + if slices.Contains(CLOSING_STATUSES, c.State) { + cid, err := NewShortChannelIDFromString(c.ShortChannelId) + if err != nil { + log.Printf("CLN: GetClosedChannels NewShortChannelIDFromString(%v) error: %v", c.ShortChannelId, err) + continue + } + + outnum := uint64(*cid) & 0xFFFFFF + cp := fmt.Sprintf("%s:%d", c.FundingTxId, outnum) + lookup[cp] = uint64(*cid) + } + } + + for c, h := range channelPoints { + if _, ok := lookup[c]; !ok { + r[c] = h + } + } + + return r, nil +} diff --git a/cln_interceptor.go b/cln_interceptor.go new file mode 100644 index 0000000..6414d56 --- /dev/null +++ b/cln_interceptor.go @@ -0,0 +1,191 @@ +package main + +import ( + "bytes" + "encoding/hex" + "fmt" + "io" + "log" + "os" + "time" + + sphinx "github.com/lightningnetwork/lightning-onion" + "github.com/lightningnetwork/lnd/record" + "github.com/lightningnetwork/lnd/tlv" + "github.com/niftynei/glightning/glightning" +) + +type ClnHtlcInterceptor struct { + client *ClnClient + plugin *glightning.Plugin +} + +func NewClnHtlcInterceptor(client *ClnClient) *ClnHtlcInterceptor { + return &ClnHtlcInterceptor{ + client: client, + } +} + +func (i *ClnHtlcInterceptor) Start() error { + //c-lightning plugin initiate + plugin := glightning.NewPlugin(onInit) + i.plugin = plugin + plugin.RegisterHooks(&glightning.Hooks{ + HtlcAccepted: i.OnHtlcAccepted, + }) + + err := plugin.Start(os.Stdin, os.Stdout) + if err != nil { + log.Printf("Plugin error: %v", err) + return err + } + + return nil +} + +func (i *ClnHtlcInterceptor) Stop() error { + plugin := i.plugin + if plugin != nil { + plugin.Stop() + } + + return nil +} + +func onInit(plugin *glightning.Plugin, options map[string]glightning.Option, config *glightning.Config) { + log.Printf("successfully init'd! %v\n", config.RpcFile) +} + +func (i *ClnHtlcInterceptor) OnHtlcAccepted(event *glightning.HtlcAcceptedEvent) (*glightning.HtlcAcceptedResponse, error) { + log.Printf("htlc_accepted called\n") + onion := event.Onion + + log.Printf("htlc: %v\nchanID: %v\nincoming amount: %v\noutgoing amount: %v\nincoming expiry: %v\noutgoing expiry: %v\npaymentHash: %v\nonionBlob: %v\n\n", + event.Htlc, + onion.ShortChannelId, + event.Htlc.AmountMilliSatoshi, //with fees + onion.ForwardAmount, + event.Htlc.CltvExpiryRelative, + event.Htlc.CltvExpiry, + event.Htlc.PaymentHash, + onion, + ) + + // fail htlc in case payment hash is not valid. + paymentHashBytes, err := hex.DecodeString(event.Htlc.PaymentHash) + if err != nil { + log.Printf("hex.DecodeString(%v) error: %v", event.Htlc.PaymentHash, err) + return event.Fail(uint16(FAILURE_TEMPORARY_CHANNEL_FAILURE)), nil + } + + interceptResult := intercept(paymentHashBytes, onion.ForwardAmount, uint32(event.Htlc.CltvExpiry)) + switch interceptResult.action { + case INTERCEPT_RESUME_OR_CANCEL: + return i.resumeOrCancel(event, interceptResult), nil + case INTERCEPT_FAIL_HTLC: + return event.Fail(uint16(FAILURE_TEMPORARY_CHANNEL_FAILURE)), nil + case INTERCEPT_FAIL_HTLC_WITH_CODE: + return event.Fail(uint16(interceptResult.failureCode)), nil + case INTERCEPT_RESUME: + fallthrough + default: + return event.Continue(), nil + } +} + +func (i *ClnHtlcInterceptor) resumeOrCancel(event *glightning.HtlcAcceptedEvent, interceptResult interceptResult) *glightning.HtlcAcceptedResponse { + deadline := time.Now().Add(60 * time.Second) + + for { + chanResult, _ := i.client.GetChannel(interceptResult.destination, *interceptResult.channelPoint) + if chanResult != nil { + log.Printf("channel opended successfully alias: %v, confirmed: %v", chanResult.InitialChannelID.ToString(), chanResult.ConfirmedChannelID.ToString()) + + err := insertChannel( + uint64(chanResult.InitialChannelID), + uint64(chanResult.ConfirmedChannelID), + interceptResult.channelPoint.String(), + interceptResult.destination, + time.Now(), + ) + + if err != nil { + log.Printf("insertChannel error: %v", err) + return event.Fail(uint16(FAILURE_TEMPORARY_CHANNEL_FAILURE)) + } + + channelID := uint64(chanResult.ConfirmedChannelID) + if channelID == 0 { + channelID = uint64(chanResult.InitialChannelID) + } + //decoding and encoding onion with alias in type 6 record. + newPayload, err := encodePayloadWithNextHop(event.Onion.Payload, channelID) + if err != nil { + log.Printf("encodePayloadWithNextHop error: %v", err) + return event.Fail(uint16(FAILURE_TEMPORARY_CHANNEL_FAILURE)) + } + + log.Printf("forwarding htlc to the destination node and a new private channel was opened") + return event.ContinueWithPayload(newPayload) + } + + log.Printf("waiting for channel to get opened.... %v\n", interceptResult.destination) + if time.Now().After(deadline) { + log.Printf("Stop retrying getChannel(%v, %v)", interceptResult.destination, interceptResult.channelPoint.String()) + break + } + time.Sleep(1 * time.Second) + } + log.Printf("Error: Channel failed to opened... timed out. ") + return event.Fail(uint16(FAILURE_TEMPORARY_CHANNEL_FAILURE)) +} + +func encodePayloadWithNextHop(payloadHex string, channelId uint64) (string, error) { + payload, err := hex.DecodeString(payloadHex) + if err != nil { + log.Fatalf("failed to decode types %v", err) + } + bufReader := bytes.NewBuffer(payload) + var b [8]byte + varInt, err := sphinx.ReadVarInt(bufReader, &b) + if err != nil { + return "", fmt.Errorf("failed to read payload length %v: %v", payloadHex, err) + } + + innerPayload := make([]byte, varInt) + if _, err := io.ReadFull(bufReader, innerPayload[:]); err != nil { + return "", fmt.Errorf("failed to decode payload %x: %v", innerPayload[:], err) + } + + s, _ := tlv.NewStream() + tlvMap, err := s.DecodeWithParsedTypes(bytes.NewReader(innerPayload)) + if err != nil { + return "", fmt.Errorf("DecodeWithParsedTypes failed for %x: %v", innerPayload[:], err) + } + + tt := record.NewNextHopIDRecord(&channelId) + buf := bytes.NewBuffer([]byte{}) + if err := tt.Encode(buf); err != nil { + return "", fmt.Errorf("failed to encode nexthop %x: %v", innerPayload[:], err) + } + + uTlvMap := make(map[uint64][]byte) + for t, b := range tlvMap { + if t == record.NextHopOnionType { + uTlvMap[uint64(t)] = buf.Bytes() + continue + } + uTlvMap[uint64(t)] = b + } + tlvRecords := tlv.MapToRecords(uTlvMap) + s, err = tlv.NewStream(tlvRecords...) + if err != nil { + return "", fmt.Errorf("tlv.NewStream(%x) error: %v", tlvRecords, err) + } + var newPayloadBuf bytes.Buffer + err = s.Encode(&newPayloadBuf) + if err != nil { + return "", fmt.Errorf("encode error: %v", err) + } + return hex.EncodeToString(newPayloadBuf.Bytes()), nil +} diff --git a/go.mod b/go.mod index 9860553..6566285 100644 --- a/go.mod +++ b/go.mod @@ -13,8 +13,9 @@ require ( github.com/grpc-ecosystem/go-grpc-middleware v1.3.0 github.com/jackc/pgtype v1.8.1 github.com/jackc/pgx/v4 v4.13.0 - github.com/lightningnetwork/lightning-onion v1.0.2-0.20220211021909-bb84a1ccb0c5 + github.com/lightningnetwork/lightning-onion v1.2.0 github.com/lightningnetwork/lnd v0.15.1-beta + github.com/niftynei/glightning v0.8.2 golang.org/x/sync v0.0.0-20210220032951-036812b2e83c google.golang.org/grpc v1.38.0 ) @@ -136,8 +137,9 @@ require ( go.uber.org/multierr v1.6.0 // indirect go.uber.org/zap v1.17.0 // indirect golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3 // indirect + golang.org/x/exp v0.0.0-20221114191408-850992195362 golang.org/x/net v0.0.0-20211216030914-fe4d6282115f // indirect - golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a // indirect + golang.org/x/sys v0.1.0 // indirect golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1 // indirect golang.org/x/text v0.3.7 // indirect golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba // indirect @@ -154,3 +156,5 @@ require ( ) replace github.com/lightningnetwork/lnd v0.15.1-beta => github.com/breez/lnd v0.15.0-beta.rc6.0.20220831104847-00b86a81e57a + +replace github.com/niftynei/glightning v0.8.2 => github.com/breez/glightning v0.0.0-20220822151439-7bb360481467 diff --git a/short_channel_id.go b/short_channel_id.go index c8207a2..ef87f7b 100644 --- a/short_channel_id.go +++ b/short_channel_id.go @@ -4,37 +4,40 @@ import ( "fmt" "strconv" "strings" + + "github.com/lightningnetwork/lnd/lnwire" ) type ShortChannelID uint64 func NewShortChannelIDFromString(channelID string) (*ShortChannelID, error) { - parts := strings.Split(channelID, "x") - if len(parts) != 3 { - return nil, fmt.Errorf("expected 3 parts, got %d", len(parts)) + if channelID == "" { + return nil, nil } - blockHeight, err := strconv.Atoi(parts[0]) - if err != nil { - return nil, err + fields := strings.Split(channelID, "x") + if len(fields) != 3 { + return nil, fmt.Errorf("invalid short channel id %v", channelID) } - - txIndex, err := strconv.Atoi(parts[1]) - if err != nil { - return nil, err + var blockHeight, txIndex, txPos int64 + var err error + if blockHeight, err = strconv.ParseInt(fields[0], 10, 64); err != nil { + return nil, fmt.Errorf("failed to parse block height %v", fields[0]) } - - outputIndex, err := strconv.Atoi(parts[2]) - if err != nil { - return nil, err + if txIndex, err = strconv.ParseInt(fields[1], 10, 64); err != nil { + return nil, fmt.Errorf("failed to parse block height %v", fields[1]) + } + if txPos, err = strconv.ParseInt(fields[2], 10, 64); err != nil { + return nil, fmt.Errorf("failed to parse block height %v", fields[2]) } result := ShortChannelID( - (uint64(outputIndex) & 0xFFFF) + - ((uint64(txIndex) << 16) & 0xFFFFFF0000) + - ((uint64(blockHeight) << 40) & 0xFFFFFF0000000000), + lnwire.ShortChannelID{ + BlockHeight: uint32(blockHeight), + TxIndex: uint32(txIndex), + TxPosition: uint16(txPos), + }.ToUint64(), ) - return &result, nil }