diff --git a/intercept.go b/intercept.go index f8ed8f4..dd5b98f 100644 --- a/intercept.go +++ b/intercept.go @@ -17,16 +17,17 @@ import ( type interceptAction int const ( - INTERCEPT_RESUME interceptAction = 0 - INTERCEPT_RESUME_OR_CANCEL interceptAction = 1 - INTERCEPT_FAIL_HTLC interceptAction = 2 + INTERCEPT_RESUME interceptAction = 0 + INTERCEPT_RESUME_OR_CANCEL interceptAction = 1 + INTERCEPT_FAIL_HTLC interceptAction = 2 + INTERCEPT_FAIL_HTLC_WITH_CODE interceptAction = 3 ) -type interceptFailureCode int +type interceptFailureCode uint16 -const ( - FAILURE_TEMPORARY_CHANNEL_FAILURE interceptFailureCode = 0 - FAILURE_INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS interceptFailureCode = 1 +var ( + FAILURE_TEMPORARY_CHANNEL_FAILURE interceptFailureCode = 0x1007 + FAILURE_INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS interceptFailureCode = 0x4015 ) type interceptResult struct { @@ -67,7 +68,7 @@ func intercept(reqPaymentHash []byte, reqOutgoingAmountMsat uint64, reqOutgoingE } return interceptResult{ - action: INTERCEPT_FAIL_HTLC, + action: INTERCEPT_FAIL_HTLC_WITH_CODE, failureCode: failureCode, } } diff --git a/lnd_interceptor.go b/lnd_interceptor.go index 6c9f319..a5e5434 100644 --- a/lnd_interceptor.go +++ b/lnd_interceptor.go @@ -7,7 +7,7 @@ import ( "os" "time" - "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lnrpc/routerrpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" @@ -84,10 +84,15 @@ func (i *LndHtlcInterceptor) intercept() error { switch interceptResult.action { case INTERCEPT_RESUME_OR_CANCEL: go resumeOrCancel(clientCtx, interceptorClient, request.IncomingCircuitKey, - interceptResult.destination, *interceptResult.channelPoint, - interceptResult.amountMsat, interceptResult.onionBlob) + interceptResult) case INTERCEPT_FAIL_HTLC: failForwardSend(interceptorClient, request.IncomingCircuitKey) + case INTERCEPT_FAIL_HTLC_WITH_CODE: + interceptorClient.Send(&routerrpc.ForwardHtlcInterceptResponse{ + IncomingCircuitKey: request.IncomingCircuitKey, + Action: routerrpc.ResolveHoldForwardAction_FAIL, + FailureCode: mapFailureCode(interceptResult.failureCode), + }) case INTERCEPT_RESUME: fallthrough default: @@ -105,6 +110,17 @@ func (i *LndHtlcInterceptor) intercept() error { } } +func mapFailureCode(original interceptFailureCode) lnrpc.Failure_FailureCode { + switch original { + case FAILURE_TEMPORARY_CHANNEL_FAILURE: + return lnrpc.Failure_TEMPORARY_CHANNEL_FAILURE + case FAILURE_INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS: + return lnrpc.Failure_INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS + default: + return lnrpc.Failure_TEMPORARY_CHANNEL_FAILURE + } +} + func failForwardSend(interceptorClient routerrpc.Router_HtlcInterceptorClient, incomingCircuitKey *routerrpc.CircuitKey) { interceptorClient.Send(&routerrpc.ForwardHtlcInterceptResponse{ IncomingCircuitKey: incomingCircuitKey, @@ -116,14 +132,11 @@ func resumeOrCancel( ctx context.Context, interceptorClient routerrpc.Router_HtlcInterceptorClient, incomingCircuitKey *routerrpc.CircuitKey, - destination []byte, - channelPoint wire.OutPoint, - outgoingAmountMsat uint64, - onionBlob []byte, + interceptResult interceptResult, ) { deadline := time.Now().Add(10 * time.Second) for { - ch, err := client.GetChannel(destination, channelPoint) + ch, err := client.GetChannel(interceptResult.destination, *interceptResult.channelPoint) if err != nil { failForwardSend(interceptorClient, incomingCircuitKey) return @@ -133,19 +146,19 @@ func resumeOrCancel( interceptorClient.Send(&routerrpc.ForwardHtlcInterceptResponse{ IncomingCircuitKey: incomingCircuitKey, Action: routerrpc.ResolveHoldForwardAction_RESUME, - OutgoingAmountMsat: outgoingAmountMsat, + OutgoingAmountMsat: interceptResult.amountMsat, OutgoingRequestedChanId: uint64(ch.InitialChannelID), - OnionBlob: onionBlob, + OnionBlob: interceptResult.onionBlob, }) - err := insertChannel(uint64(ch.InitialChannelID), uint64(ch.ConfirmedChannelID), channelPoint.String(), destination, time.Now()) + err := insertChannel(uint64(ch.InitialChannelID), uint64(ch.ConfirmedChannelID), interceptResult.channelPoint.String(), interceptResult.destination, time.Now()) if err != nil { log.Printf("insertChannel error: %v", err) } return } - log.Printf("getChannel(%x, %v) returns 0", destination, channelPoint) + log.Printf("getChannel(%x, %v) returns 0", interceptResult.destination, interceptResult.channelPoint.String()) if time.Now().After(deadline) { - log.Printf("Stop retrying getChannel(%x, %v)", destination, channelPoint) + log.Printf("Stop retrying getChannel(%x, %v)", interceptResult.destination, interceptResult.channelPoint.String()) break } time.Sleep(1 * time.Second)