diff --git a/intercept.go b/intercept.go index dd5b98f..9978ada 100644 --- a/intercept.go +++ b/intercept.go @@ -2,6 +2,7 @@ package main import ( "bytes" + "encoding/hex" "fmt" "log" "math/big" @@ -12,6 +13,7 @@ import ( "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/record" "github.com/lightningnetwork/lnd/routing/route" + "golang.org/x/sync/singleflight" ) type interceptAction int @@ -30,6 +32,8 @@ var ( FAILURE_INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS interceptFailureCode = 0x4015 ) +var payHashGroup singleflight.Group + type interceptResult struct { action interceptAction failureCode interceptFailureCode @@ -40,122 +44,128 @@ type interceptResult struct { } func intercept(reqPaymentHash []byte, reqOutgoingAmountMsat uint64, reqOutgoingExpiry uint32) interceptResult { - paymentHash, paymentSecret, destination, incomingAmountMsat, outgoingAmountMsat, channelPoint, err := paymentInfo(reqPaymentHash) - if err != nil { - log.Printf("paymentInfo(%x) error: %v", reqPaymentHash, err) - return interceptResult{ - action: INTERCEPT_FAIL_HTLC, + reqPaymentHashStr := hex.EncodeToString(reqPaymentHash) + resp, _, _ := payHashGroup.Do(reqPaymentHashStr, func() (interface{}, error) { + paymentHash, paymentSecret, destination, incomingAmountMsat, outgoingAmountMsat, channelPoint, err := paymentInfo(reqPaymentHash) + if err != nil { + log.Printf("paymentInfo(%x) error: %v", reqPaymentHash, err) + return interceptResult{ + action: INTERCEPT_FAIL_HTLC, + }, nil } - } - log.Printf("paymentHash:%x\npaymentSecret:%x\ndestination:%x\nincomingAmountMsat:%v\noutgoingAmountMsat:%v\n\n", - paymentHash, paymentSecret, destination, incomingAmountMsat, outgoingAmountMsat) - if paymentSecret != nil { + log.Printf("paymentHash:%x\npaymentSecret:%x\ndestination:%x\nincomingAmountMsat:%v\noutgoingAmountMsat:%v\n\n", + paymentHash, paymentSecret, destination, incomingAmountMsat, outgoingAmountMsat) + if paymentSecret != nil { - if channelPoint == nil { - if bytes.Equal(paymentHash, reqPaymentHash) { - channelPoint, err = openChannel(client, reqPaymentHash, destination, incomingAmountMsat) - log.Printf("openChannel(%x, %v) err: %v", destination, incomingAmountMsat, err) - if err != nil { - return interceptResult{ - action: INTERCEPT_FAIL_HTLC, + if channelPoint == nil { + if bytes.Equal(paymentHash, reqPaymentHash) { + channelPoint, err = openChannel(client, reqPaymentHash, destination, incomingAmountMsat) + log.Printf("openChannel(%x, %v) err: %v", destination, incomingAmountMsat, err) + if err != nil { + return interceptResult{ + action: INTERCEPT_FAIL_HTLC, + }, nil + } + } else { //probing + failureCode := FAILURE_TEMPORARY_CHANNEL_FAILURE + isConnected, _ := client.IsConnected(destination) + if err != nil || !*isConnected { + failureCode = FAILURE_INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS } - } - } else { //probing - failureCode := FAILURE_TEMPORARY_CHANNEL_FAILURE - isConnected, _ := client.IsConnected(destination) - if err != nil || !*isConnected { - failureCode = FAILURE_INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS - } + return interceptResult{ + action: INTERCEPT_FAIL_HTLC_WITH_CODE, + failureCode: failureCode, + }, nil + } + } + + pubKey, err := btcec.ParsePubKey(destination) + if err != nil { + log.Printf("btcec.ParsePubKey(%x): %v", destination, err) return interceptResult{ - action: INTERCEPT_FAIL_HTLC_WITH_CODE, - failureCode: failureCode, - } + action: INTERCEPT_FAIL_HTLC, + }, nil + } + + sessionKey, err := btcec.NewPrivateKey() + if err != nil { + log.Printf("btcec.NewPrivateKey(): %v", err) + return interceptResult{ + action: INTERCEPT_FAIL_HTLC, + }, nil + } + + var bigProd, bigAmt big.Int + amt := (bigAmt.Div(bigProd.Mul(big.NewInt(outgoingAmountMsat), big.NewInt(int64(reqOutgoingAmountMsat))), big.NewInt(incomingAmountMsat))).Int64() + + var addr [32]byte + copy(addr[:], paymentSecret) + hop := route.Hop{ + AmtToForward: lnwire.MilliSatoshi(amt), + OutgoingTimeLock: reqOutgoingExpiry, + MPP: record.NewMPP(lnwire.MilliSatoshi(outgoingAmountMsat), addr), + CustomRecords: make(record.CustomSet), + } + + var b bytes.Buffer + err = hop.PackHopPayload(&b, uint64(0)) + if err != nil { + log.Printf("hop.PackHopPayload(): %v", err) + return interceptResult{ + action: INTERCEPT_FAIL_HTLC, + }, nil + } + + payload, err := sphinx.NewHopPayload(nil, b.Bytes()) + if err != nil { + log.Printf("sphinx.NewHopPayload(): %v", err) + return interceptResult{ + action: INTERCEPT_FAIL_HTLC, + }, nil + } + + var sphinxPath sphinx.PaymentPath + sphinxPath[0] = sphinx.OnionHop{ + NodePub: *pubKey, + HopPayload: payload, + } + sphinxPacket, err := sphinx.NewOnionPacket( + &sphinxPath, sessionKey, reqPaymentHash, + sphinx.DeterministicPacketFiller, + ) + if err != nil { + log.Printf("sphinx.NewOnionPacket(): %v", err) + return interceptResult{ + action: INTERCEPT_FAIL_HTLC, + }, nil + } + var onionBlob bytes.Buffer + err = sphinxPacket.Encode(&onionBlob) + if err != nil { + log.Printf("sphinxPacket.Encode(): %v", err) + return interceptResult{ + action: INTERCEPT_FAIL_HTLC, + }, nil } - } - pubKey, err := btcec.ParsePubKey(destination) - if err != nil { - log.Printf("btcec.ParsePubKey(%x): %v", destination, err) return interceptResult{ - action: INTERCEPT_FAIL_HTLC, - } - } - - sessionKey, err := btcec.NewPrivateKey() - if err != nil { - log.Printf("btcec.NewPrivateKey(): %v", err) + action: INTERCEPT_RESUME_OR_CANCEL, + destination: destination, + channelPoint: channelPoint, + amountMsat: uint64(amt), + onionBlob: onionBlob.Bytes(), + }, nil + } else { return interceptResult{ - action: INTERCEPT_FAIL_HTLC, - } + action: INTERCEPT_RESUME, + }, nil } + }) - var bigProd, bigAmt big.Int - amt := (bigAmt.Div(bigProd.Mul(big.NewInt(outgoingAmountMsat), big.NewInt(int64(reqOutgoingAmountMsat))), big.NewInt(incomingAmountMsat))).Int64() - - var addr [32]byte - copy(addr[:], paymentSecret) - hop := route.Hop{ - AmtToForward: lnwire.MilliSatoshi(amt), - OutgoingTimeLock: reqOutgoingExpiry, - MPP: record.NewMPP(lnwire.MilliSatoshi(outgoingAmountMsat), addr), - CustomRecords: make(record.CustomSet), - } - - var b bytes.Buffer - err = hop.PackHopPayload(&b, uint64(0)) - if err != nil { - log.Printf("hop.PackHopPayload(): %v", err) - return interceptResult{ - action: INTERCEPT_FAIL_HTLC, - } - } - - payload, err := sphinx.NewHopPayload(nil, b.Bytes()) - if err != nil { - log.Printf("sphinx.NewHopPayload(): %v", err) - return interceptResult{ - action: INTERCEPT_FAIL_HTLC, - } - } - - var sphinxPath sphinx.PaymentPath - sphinxPath[0] = sphinx.OnionHop{ - NodePub: *pubKey, - HopPayload: payload, - } - sphinxPacket, err := sphinx.NewOnionPacket( - &sphinxPath, sessionKey, reqPaymentHash, - sphinx.DeterministicPacketFiller, - ) - if err != nil { - log.Printf("sphinx.NewOnionPacket(): %v", err) - return interceptResult{ - action: INTERCEPT_FAIL_HTLC, - } - } - var onionBlob bytes.Buffer - err = sphinxPacket.Encode(&onionBlob) - if err != nil { - log.Printf("sphinxPacket.Encode(): %v", err) - return interceptResult{ - action: INTERCEPT_FAIL_HTLC, - } - } - - return interceptResult{ - action: INTERCEPT_RESUME_OR_CANCEL, - destination: destination, - channelPoint: channelPoint, - amountMsat: uint64(amt), - onionBlob: onionBlob.Bytes(), - } - } else { - return interceptResult{ - action: INTERCEPT_RESUME, - } - } + return resp.(interceptResult) } + func checkPayment(incomingAmountMsat, outgoingAmountMsat int64) error { fees := incomingAmountMsat * channelFeePermyriad / 10_000 / 1_000 * 1_000 if fees < channelMinimumFeeMsat {