Files
lspd/cln/cln_interceptor.go
2023-12-11 22:20:54 +01:00

328 lines
9.2 KiB
Go

package cln
import (
"bytes"
"context"
"encoding/hex"
"fmt"
"io"
"log"
"sync"
"time"
"github.com/breez/lspd/cln_plugin/proto"
"github.com/breez/lspd/common"
"github.com/breez/lspd/config"
"github.com/breez/lspd/lightning"
sphinx "github.com/lightningnetwork/lightning-onion"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/record"
"github.com/lightningnetwork/lnd/tlv"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/status"
)
type ClnHtlcInterceptor struct {
interceptor common.InterceptHandler
config *config.NodeConfig
pluginAddress string
client *ClnClient
pluginClient proto.ClnPluginClient
initWg sync.WaitGroup
doneWg sync.WaitGroup
stopRequested bool
ctx context.Context
cancel context.CancelFunc
}
func NewClnHtlcInterceptor(conf *config.NodeConfig, client *ClnClient, interceptor common.InterceptHandler) (*ClnHtlcInterceptor, error) {
i := &ClnHtlcInterceptor{
config: conf,
pluginAddress: conf.Cln.PluginAddress,
client: client,
interceptor: interceptor,
}
i.initWg.Add(1)
return i, nil
}
func (i *ClnHtlcInterceptor) Start() error {
ctx, cancel := context.WithCancel(context.Background())
log.Printf("Dialing cln plugin on '%s'", i.pluginAddress)
conn, err := grpc.DialContext(
ctx,
i.pluginAddress,
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: time.Duration(10) * time.Second,
Timeout: time.Duration(10) * time.Second,
}),
)
if err != nil {
log.Printf("grpc.Dial error: %v", err)
cancel()
return err
}
i.pluginClient = proto.NewClnPluginClient(conn)
i.ctx = ctx
i.cancel = cancel
i.stopRequested = false
return i.intercept()
}
func (i *ClnHtlcInterceptor) intercept() error {
inited := false
defer func() {
if !inited {
i.initWg.Done()
}
log.Printf("CLN intercept(): stopping. Waiting for in-progress interceptions to complete.")
i.doneWg.Wait()
}()
for {
if i.ctx.Err() != nil {
return i.ctx.Err()
}
log.Printf("Connecting CLN HTLC interceptor.")
interceptorClient, err := i.pluginClient.HtlcStream(i.ctx)
if err != nil {
log.Printf("pluginClient.HtlcStream(): %v", err)
<-time.After(time.Second)
continue
}
for {
if i.ctx.Err() != nil {
return i.ctx.Err()
}
if !inited {
inited = true
i.initWg.Done()
}
// Stop receiving if stop if requested. The defer func on top of this
// function will assure all htlcs that are currently being processed
// will complete.
if i.stopRequested {
return nil
}
request, err := interceptorClient.Recv()
if err != nil {
// If it is just the error result of the context cancellation
// the we exit silently.
status, ok := status.FromError(err)
if ok && status.Code() == codes.Canceled {
log.Printf("Got code canceled. Break.")
break
}
// Otherwise it an unexpected error, we fail the test.
log.Printf("unexpected error in interceptor.Recv() %v", err)
break
}
i.doneWg.Add(1)
go func() {
paymentHash, err := hex.DecodeString(request.Htlc.PaymentHash)
if err != nil {
interceptorClient.Send(i.defaultResolution(request))
i.doneWg.Done()
return
}
scid, err := lightning.NewShortChannelIDFromString(request.Onion.ShortChannelId)
if err != nil {
interceptorClient.Send(i.defaultResolution(request))
i.doneWg.Done()
return
}
interceptResult := i.interceptor.Intercept(common.InterceptRequest{
Identifier: request.Onion.SharedSecret,
Scid: *scid,
PaymentHash: paymentHash,
IncomingAmountMsat: request.Htlc.AmountMsat,
OutgoingAmountMsat: request.Onion.ForwardMsat,
IncomingExpiry: request.Htlc.CltvExpiry,
OutgoingExpiry: request.Onion.OutgoingCltvValue,
})
switch interceptResult.Action {
case common.INTERCEPT_RESUME_WITH_ONION:
interceptorClient.Send(i.resumeWithOnion(request, interceptResult))
case common.INTERCEPT_FAIL_HTLC_WITH_CODE:
interceptorClient.Send(
i.failWithCode(request, interceptResult.FailureCode),
)
case common.INTERCEPT_IGNORE:
// Do nothing
case common.INTERCEPT_RESUME:
fallthrough
default:
interceptorClient.Send(
i.defaultResolution(request),
)
}
i.doneWg.Done()
}()
}
<-time.After(time.Second)
}
}
func (i *ClnHtlcInterceptor) Stop() error {
// Setting stopRequested to true will make the interceptor stop receiving.
i.stopRequested = true
// Wait until all already received htlcs are handled, responses sent back.
i.doneWg.Wait()
// Close the grpc connection.
i.cancel()
return nil
}
func (i *ClnHtlcInterceptor) WaitStarted() {
i.initWg.Wait()
}
func (i *ClnHtlcInterceptor) resumeWithOnion(request *proto.HtlcAccepted, interceptResult common.InterceptResult) *proto.HtlcResolution {
//decoding and encoding onion with alias in type 6 record.
payload, err := hex.DecodeString(request.Onion.Payload)
if err != nil {
log.Printf("paymenthash: %s, resumeWithOnion: hex.DecodeString(%v) error: %v", request.Htlc.PaymentHash, request.Onion.Payload, err)
return i.failWithCode(request, common.FAILURE_TEMPORARY_CHANNEL_FAILURE)
}
newPayload, err := encodePayloadWithNextHop(payload, interceptResult.Scid, interceptResult.AmountMsat, interceptResult.FeeMsat)
if err != nil {
log.Printf("paymenthash: %s, encodePayloadWithNextHop error: %v", request.Htlc.PaymentHash, err)
return i.failWithCode(request, common.FAILURE_TEMPORARY_CHANNEL_FAILURE)
}
newPayloadStr := hex.EncodeToString(newPayload)
chanId := lnwire.NewChanIDFromOutPoint(interceptResult.ChannelPoint).String()
log.Printf("paymenthash: %s, forwarding htlc to the destination node and a new private channel was opened", request.Htlc.PaymentHash)
return &proto.HtlcResolution{
Correlationid: request.Correlationid,
Outcome: &proto.HtlcResolution_Continue{
Continue: &proto.HtlcContinue{
ForwardTo: &chanId,
Payload: &newPayloadStr,
},
},
}
}
func (i *ClnHtlcInterceptor) defaultResolution(request *proto.HtlcAccepted) *proto.HtlcResolution {
return &proto.HtlcResolution{
Correlationid: request.Correlationid,
Outcome: &proto.HtlcResolution_Continue{
Continue: &proto.HtlcContinue{},
},
}
}
func (i *ClnHtlcInterceptor) failWithCode(request *proto.HtlcAccepted, code common.InterceptFailureCode) *proto.HtlcResolution {
log.Printf("paymenthash: %s, failing htlc with code: '%x'", request.Htlc.PaymentHash, code)
return &proto.HtlcResolution{
Correlationid: request.Correlationid,
Outcome: &proto.HtlcResolution_Fail{
Fail: &proto.HtlcFail{
Failure: &proto.HtlcFail_FailureMessage{
FailureMessage: i.mapFailureCode(code),
},
},
},
}
}
func encodePayloadWithNextHop(payload []byte, scid lightning.ShortChannelID, amountToForward uint64, feeMsat *uint64) ([]byte, error) {
bufReader := bytes.NewBuffer(payload)
var b [8]byte
varInt, err := sphinx.ReadVarInt(bufReader, &b)
if err != nil {
return nil, fmt.Errorf("failed to read payload length %x: %v", payload, err)
}
innerPayload := make([]byte, varInt)
if _, err := io.ReadFull(bufReader, innerPayload[:]); err != nil {
return nil, fmt.Errorf("failed to decode payload %x: %v", innerPayload[:], err)
}
s, _ := tlv.NewStream()
tlvMap, err := s.DecodeWithParsedTypes(bytes.NewReader(innerPayload))
if err != nil {
return nil, fmt.Errorf("DecodeWithParsedTypes failed for %x: %v", innerPayload[:], err)
}
channelId := uint64(scid)
tt := record.NewNextHopIDRecord(&channelId)
ttbuf := bytes.NewBuffer([]byte{})
if err := tt.Encode(ttbuf); err != nil {
return nil, fmt.Errorf("failed to encode nexthop %x: %v", innerPayload[:], err)
}
amt := record.NewAmtToFwdRecord(&amountToForward)
amtbuf := bytes.NewBuffer([]byte{})
if err := amt.Encode(amtbuf); err != nil {
return nil, fmt.Errorf("failed to encode AmtToFwd %x: %v", innerPayload[:], err)
}
uTlvMap := make(map[uint64][]byte)
for t, b := range tlvMap {
if t == record.NextHopOnionType {
uTlvMap[uint64(t)] = ttbuf.Bytes()
continue
}
if t == record.AmtOnionType {
uTlvMap[uint64(t)] = amtbuf.Bytes()
continue
}
uTlvMap[uint64(t)] = b
}
tlvRecords := tlv.MapToRecords(uTlvMap)
s, err = tlv.NewStream(tlvRecords...)
if err != nil {
return nil, fmt.Errorf("tlv.NewStream(%v) error: %v", tlvRecords, err)
}
var newPayloadBuf bytes.Buffer
err = s.Encode(&newPayloadBuf)
if err != nil {
return nil, fmt.Errorf("encode error: %v", err)
}
return newPayloadBuf.Bytes(), nil
}
func (i *ClnHtlcInterceptor) mapFailureCode(original common.InterceptFailureCode) string {
switch original {
case common.FAILURE_TEMPORARY_CHANNEL_FAILURE:
return "1007"
case common.FAILURE_AMOUNT_BELOW_MINIMUM:
return "100B"
case common.FAILURE_INCORRECT_CLTV_EXPIRY:
return "100D"
case common.FAILURE_TEMPORARY_NODE_FAILURE:
return "2002"
case common.FAILURE_UNKNOWN_NEXT_PEER:
return "400A"
case common.FAILURE_INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS:
return "400F"
default:
log.Printf("Unknown failure code %v, default to temporary channel failure.", original)
return "1007" // temporary channel failure
}
}