diff --git a/lsps2/intercept_handler.go b/lsps2/intercept_handler.go new file mode 100644 index 0000000..6b106b2 --- /dev/null +++ b/lsps2/intercept_handler.go @@ -0,0 +1,674 @@ +package lsps2 + +import ( + "context" + "encoding/hex" + "fmt" + "log" + "math" + "strings" + "time" + + "github.com/breez/lspd/chain" + "github.com/breez/lspd/lightning" + "github.com/breez/lspd/lsps0" + "github.com/breez/lspd/shared" + "github.com/btcsuite/btcd/wire" +) + +type InterceptorConfig struct { + AdditionalChannelCapacitySat uint64 + MinConfs *uint32 + TargetConf uint32 + FeeStrategy chain.FeeStrategy + MinPaymentSizeMsat uint64 + MaxPaymentSizeMsat uint64 + TimeLockDelta uint32 + HtlcMinimumMsat uint64 + MppTimeout time.Duration +} + +type Interceptor struct { + store Lsps2Store + openingService shared.OpeningService + client lightning.Client + feeEstimator chain.FeeEstimator + config *InterceptorConfig + newPart chan *partState + registrationFetched chan *registrationFetchedEvent + paymentReady chan string + paymentFailure chan *paymentFailureEvent + paymentChanOpened chan *paymentChanOpenedEvent + inflightPayments map[string]*paymentState +} + +func NewInterceptHandler( + store Lsps2Store, + openingService shared.OpeningService, + client lightning.Client, + feeEstimator chain.FeeEstimator, + config *InterceptorConfig, +) *Interceptor { + if config.MppTimeout.Nanoseconds() == 0 { + config.MppTimeout = time.Duration(90 * time.Second) + } + + return &Interceptor{ + store: store, + openingService: openingService, + client: client, + feeEstimator: feeEstimator, + config: config, + // TODO: make sure the chan sizes do not lead to deadlocks. + newPart: make(chan *partState, 1000), + registrationFetched: make(chan *registrationFetchedEvent, 1000), + paymentReady: make(chan string, 1000), + paymentFailure: make(chan *paymentFailureEvent, 1000), + paymentChanOpened: make(chan *paymentChanOpenedEvent, 1000), + inflightPayments: make(map[string]*paymentState), + } +} + +type paymentState struct { + id string + fakeScid lightning.ShortChannelID + outgoingSumMsat uint64 + paymentSizeMsat uint64 + feeMsat uint64 + registration *BuyRegistration + parts map[string]*partState + isFinal bool + timoutChanClosed bool + timeoutChan chan struct{} +} + +func (p *paymentState) closeTimeoutChan() { + if p.timoutChanClosed { + return + } + + close(p.timeoutChan) + p.timoutChanClosed = true +} + +type partState struct { + req *shared.InterceptRequest + resolution chan *shared.InterceptResult +} + +type registrationFetchedEvent struct { + paymentId string + isRegistered bool + registration *BuyRegistration +} + +type paymentChanOpenedEvent struct { + paymentId string + scid lightning.ShortChannelID + channelPoint *wire.OutPoint + htlcMinimumMsat uint64 +} + +type paymentFailureEvent struct { + paymentId string + code shared.InterceptFailureCode +} + +func (i *Interceptor) Start(ctx context.Context) { + // Main event loop for stages of htlcs to be handled. Note that the event + // loop has to execute quickly, so any code running in the 'handle' methods + // must execute quickly. If there is i/o involved, or any waiting, run that + // code in a goroutine, and place an event onto the event loop to continue + // processing after the slow operation is done. + // The nice thing about an event loop is that it runs on a single thread. + // So there's no locking needed, as long as everything that needs + // synchronization goes through the event loop. + for { + select { + case part := <-i.newPart: + i.handleNewPart(part) + case ev := <-i.registrationFetched: + i.handleRegistrationFetched(ev) + case paymentId := <-i.paymentReady: + i.handlePaymentReady(paymentId) + case ev := <-i.paymentFailure: + i.handlePaymentFailure(ev.paymentId, ev.code) + case ev := <-i.paymentChanOpened: + i.handlePaymentChanOpened(ev) + } + } +} + +func (i *Interceptor) handleNewPart(part *partState) { + // Get the associated payment for this part, or create a new payment if it + // doesn't exist for this part yet. + paymentId := part.req.PaymentId() + payment, paymentExisted := i.inflightPayments[paymentId] + if !paymentExisted { + payment = &paymentState{ + id: paymentId, + fakeScid: part.req.Scid, + parts: make(map[string]*partState), + timeoutChan: make(chan struct{}), + } + i.inflightPayments[paymentId] = payment + } + + // Check whether we already have this part, because it may have been + // replayed. + existingPart, partExisted := payment.parts[part.req.HtlcId()] + // Adds the part to the in-progress parts. Or replaces it, if it already + // exists, to make sure we always reply to the correct identifier. If a htlc + // was replayed, assume the latest event is the truth to respond to. + payment.parts[part.req.HtlcId()] = part + + if partExisted { + // If the part already existed, that means it has been replayed. In this + // case the first occurence can be safely ignored, because we won't be + // able to reply to that htlc anyway. Keep the last replayed version for + // further processing. This result below tells the caller to ignore the + // htlc. + existingPart.resolution <- &shared.InterceptResult{ + Action: shared.INTERCEPT_IGNORE, + } + + return + } + + // If this is the first part for this payment, setup the timeout, and fetch + // the registration. + if !paymentExisted { + go func() { + select { + case <-time.After(i.config.MppTimeout): + // Handle timeout inside the event loop, to make sure there are + // no race conditions, since this timeout watcher is running in + // a goroutine. + i.paymentFailure <- &paymentFailureEvent{ + paymentId: paymentId, + code: shared.FAILURE_TEMPORARY_CHANNEL_FAILURE, + } + case <-payment.timeoutChan: + // Stop listening for timeouts when the payment is ready. + } + }() + + // Fetch the buy registration in a goroutine, to avoid blocking the + // event loop. + go i.fetchRegistration(part.req.PaymentId(), part.req.Scid) + } + + // If the registration was already fetched, this part might complete the + // payment. Process the part. Otherwise, the part will be processed after + // the registration was fetched, as a result of 'registrationFetched'. + if payment.registration != nil { + i.processPart(payment, part) + } +} + +func (i *Interceptor) processPart(payment *paymentState, part *partState) { + if payment.registration.IsComplete { + i.failPart(payment, part, shared.FAILURE_UNKNOWN_NEXT_PEER) + return + } + + // Fail parts that come in after the payment is already final. To avoid + // inconsistencies in the payment state. + if payment.isFinal { + i.failPart(payment, part, shared.FAILURE_UNKNOWN_NEXT_PEER) + return + } + + var err error + if payment.registration.Mode == OpeningMode_NoMppVarInvoice { + // Mode == no-MPP+var-invoice + if payment.paymentSizeMsat != 0 { + // Another part is already processed for this payment, and with + // no-MPP+var-invoice there can be only a single part, so this + // part will be failed back. + i.failPart(payment, part, shared.FAILURE_UNKNOWN_NEXT_PEER) + return + } + + // If the mode is no-MPP+var-invoice, the payment size comes from + // the actual forwarded amount. + payment.paymentSizeMsat = part.req.OutgoingAmountMsat + + // Make sure the minimum and maximum are not exceeded. + if payment.paymentSizeMsat > i.config.MaxPaymentSizeMsat || + payment.paymentSizeMsat < i.config.MinPaymentSizeMsat { + i.failPart(payment, part, shared.FAILURE_UNKNOWN_NEXT_PEER) + return + } + + // Make sure there is enough fee to deduct. + payment.feeMsat, err = computeOpeningFee( + payment.paymentSizeMsat, + payment.registration.OpeningFeeParams.Proportional, + payment.registration.OpeningFeeParams.MinFeeMsat, + ) + if err != nil { + i.failPart(payment, part, shared.FAILURE_UNKNOWN_NEXT_PEER) + return + } + + // Make sure the part fits the htlc and fee constraints. + if payment.feeMsat+i.config.HtlcMinimumMsat > + payment.paymentSizeMsat { + i.failPart(payment, part, shared.FAILURE_UNKNOWN_NEXT_PEER) + return + } + } else { + // Mode == MPP+fixed-invoice + payment.paymentSizeMsat = *payment.registration.PaymentSizeMsat + payment.feeMsat, err = computeOpeningFee( + payment.paymentSizeMsat, + payment.registration.OpeningFeeParams.Proportional, + payment.registration.OpeningFeeParams.MinFeeMsat, + ) + if err != nil { + log.Printf( + "Opening fee calculation error while trying to open channel "+ + "for scid %s: %v", + payment.registration.Scid.ToString(), + err, + ) + i.failPart(payment, part, shared.FAILURE_UNKNOWN_NEXT_PEER) + return + } + } + + // Make sure the cltv delta is enough (actual cltv delta + 2). + if int64(part.req.IncomingExpiry)-int64(part.req.OutgoingExpiry) < + int64(i.config.TimeLockDelta)+2 { + i.failPart(payment, part, shared.FAILURE_INCORRECT_CLTV_EXPIRY) + return + } + + // Make sure htlc minimum is enough + if part.req.OutgoingAmountMsat < i.config.HtlcMinimumMsat { + i.failPart(payment, part, shared.FAILURE_AMOUNT_BELOW_MINIMUM) + return + } + + // Make sure we're not getting tricked + if part.req.IncomingAmountMsat < part.req.OutgoingAmountMsat { + i.failPart(payment, part, shared.FAILURE_AMOUNT_BELOW_MINIMUM) + return + } + + // Update the sum of htlcs currently in-flight. + payment.outgoingSumMsat += part.req.OutgoingAmountMsat + + // If payment_size_msat is reached, the payment is ready to forward. (this + // is always true in no-MPP+var-invoice mode) + if payment.outgoingSumMsat >= payment.paymentSizeMsat { + payment.isFinal = true + i.paymentReady <- part.req.PaymentId() + } +} + +// Fetches the registration from the store. If the registration exists, posts +// a registrationReady event. If it doesn't, posts a 'notRegistered' event. +func (i *Interceptor) fetchRegistration( + paymentId string, + scid lightning.ShortChannelID, +) { + registration, err := i.store.GetBuyRegistration( + context.TODO(), + scid, + ) + + if err != nil && err != ErrNotFound { + log.Printf( + "Failed to get buy registration for %v: %v", + uint64(scid), + err, + ) + } + + i.registrationFetched <- ®istrationFetchedEvent{ + paymentId: paymentId, + isRegistered: err == nil, + registration: registration, + } +} + +func (i *Interceptor) handleRegistrationFetched(ev *registrationFetchedEvent) { + if !ev.isRegistered { + i.finalizeAllParts(ev.paymentId, &shared.InterceptResult{ + Action: shared.INTERCEPT_RESUME, + }) + return + } + + payment, ok := i.inflightPayments[ev.paymentId] + if !ok { + // Apparently the payment is already finished. + return + } + + payment.registration = ev.registration + for _, part := range payment.parts { + i.processPart(payment, part) + } +} + +func (i *Interceptor) handlePaymentReady(paymentId string) { + payment, ok := i.inflightPayments[paymentId] + if !ok { + // Apparently this payment is already finalized. + return + } + + // TODO: Handle notifications. + // Stops the timeout listeners + payment.closeTimeoutChan() + + go i.ensureChannelOpen(payment) +} + +// Opens a channel to the destination and waits for the channel to become +// active. When the channel is active, sends an openChanEvent. Should be run in +// a goroutine. +func (i *Interceptor) ensureChannelOpen(payment *paymentState) { + destination, _ := hex.DecodeString(payment.registration.PeerId) + + if payment.registration.ChannelPoint == nil { + + validUntil, err := time.Parse( + lsps0.TIME_FORMAT, + payment.registration.OpeningFeeParams.ValidUntil, + ) + if err != nil { + log.Printf( + "Failed parse validUntil '%s' for paymentId %s: %v", + payment.registration.OpeningFeeParams.ValidUntil, + payment.id, + err, + ) + i.paymentFailure <- &paymentFailureEvent{ + paymentId: payment.id, + code: shared.FAILURE_UNKNOWN_NEXT_PEER, + } + return + } + + // With expired fee params, the current chainfees are checked. If + // they're not cheaper now, fail the payment. + if time.Now().After(validUntil) && + !i.openingService.IsCurrentChainFeeCheaper( + payment.registration.Token, + &payment.registration.OpeningFeeParams, + ) { + log.Printf("LSPS2: Intercepted expired payment registration. "+ + "Failing payment. scid: %s, valid until: %s, destination: %s", + payment.fakeScid.ToString(), + payment.registration.OpeningFeeParams.ValidUntil, + payment.registration.PeerId, + ) + i.paymentFailure <- &paymentFailureEvent{ + paymentId: payment.id, + code: shared.FAILURE_UNKNOWN_NEXT_PEER, + } + return + } + + var targetConf *uint32 + confStr := "" + var feeEstimation *float64 + feeStr := "" + if i.feeEstimator != nil { + fee, err := i.feeEstimator.EstimateFeeRate( + context.Background(), + i.config.FeeStrategy, + ) + if err == nil { + feeEstimation = &fee.SatPerVByte + feeStr = fmt.Sprintf("%.5f", *feeEstimation) + } else { + log.Printf("Error estimating chain fee, fallback to target "+ + "conf: %v", err) + targetConf = &i.config.TargetConf + confStr = fmt.Sprintf("%v", *targetConf) + } + } + + capacity := ((payment.paymentSizeMsat - payment.feeMsat + + 999) / 1000) + i.config.AdditionalChannelCapacitySat + + log.Printf( + "LSPS2: Opening zero conf channel. Destination: %x, capacity: %v, "+ + "fee: %s, targetConf: %s", + destination, + capacity, + feeStr, + confStr, + ) + + channelPoint, err := i.client.OpenChannel(&lightning.OpenChannelRequest{ + Destination: destination, + CapacitySat: uint64(capacity), + MinConfs: i.config.MinConfs, + IsPrivate: true, + IsZeroConf: true, + FeeSatPerVByte: feeEstimation, + TargetConf: targetConf, + }) + if err != nil { + log.Printf( + "LSPS2 openChannel: client.OpenChannel(%x, %v) error: %v", + destination, + capacity, + err, + ) + + code := shared.FAILURE_UNKNOWN_NEXT_PEER + if strings.Contains(err.Error(), "not enough funds") { + code = shared.FAILURE_TEMPORARY_CHANNEL_FAILURE + } + + // TODO: Verify that a client disconnect before receiving + // funding_signed doesn't cause the OpenChannel call to error. + // unknown_next_peer should only be returned if the client rejects + // the channel, or the channel cannot be opened at all. If the + // client disconnects before receiving funding_signed, + // temporary_channel_failure should be returned. + i.paymentFailure <- &paymentFailureEvent{ + paymentId: payment.id, + code: code, + } + return + } + + err = i.store.SetChannelOpened( + context.TODO(), + &ChannelOpened{ + RegistrationId: payment.registration.Id, + Outpoint: channelPoint, + FeeMsat: payment.feeMsat, + PaymentSizeMsat: payment.paymentSizeMsat, + }, + ) + if err != nil { + log.Printf( + "LSPS2 openChannel: store.SetOpenedChannel(%d, %s) error: %v", + payment.registration.Id, + channelPoint.String(), + err, + ) + i.paymentFailure <- &paymentFailureEvent{ + paymentId: payment.id, + code: shared.FAILURE_TEMPORARY_CHANNEL_FAILURE, + } + return + } + + payment.registration.ChannelPoint = channelPoint + // TODO: Send open channel email notification. + } + deadline := time.Now().Add(time.Minute) + // Wait for the channel to open. + for { + chanResult, _ := i.client.GetChannel( + destination, + *payment.registration.ChannelPoint, + ) + if chanResult == nil { + select { + case <-time.After(time.Second): + continue + case <-time.After(time.Until(deadline)): + i.paymentFailure <- &paymentFailureEvent{ + paymentId: payment.id, + code: shared.FAILURE_TEMPORARY_CHANNEL_FAILURE, + } + return + } + } + log.Printf( + "Got new channel for forward successfully. scid alias: %v, "+ + "confirmed scid: %v", + chanResult.InitialChannelID.ToString(), + chanResult.ConfirmedChannelID.ToString(), + ) + + scid := chanResult.ConfirmedChannelID + if uint64(scid) == 0 { + scid = chanResult.InitialChannelID + } + + i.paymentChanOpened <- &paymentChanOpenedEvent{ + paymentId: payment.id, + scid: scid, + channelPoint: payment.registration.ChannelPoint, + htlcMinimumMsat: chanResult.HtlcMinimumMsat, + } + break + } +} + +func (i *Interceptor) handlePaymentChanOpened(event *paymentChanOpenedEvent) { + payment, ok := i.inflightPayments[event.paymentId] + if !ok { + // Apparently this payment is already finalized. + return + } + feeRemainingMsat := payment.feeMsat + + destination, _ := hex.DecodeString(payment.registration.PeerId) + // Deduct the lsp fee from the parts to forward. + resolutions := []*struct { + part *partState + resolution *shared.InterceptResult + }{} + for _, part := range payment.parts { + deductMsat := uint64(math.Min( + float64(feeRemainingMsat), + float64(part.req.OutgoingAmountMsat-event.htlcMinimumMsat), + )) + feeRemainingMsat -= deductMsat + amountMsat := part.req.OutgoingAmountMsat - deductMsat + var feeMsat *uint64 + if deductMsat > 0 { + feeMsat = &deductMsat + } + resolutions = append(resolutions, &struct { + part *partState + resolution *shared.InterceptResult + }{ + part: part, + resolution: &shared.InterceptResult{ + Action: shared.INTERCEPT_RESUME_WITH_ONION, + Destination: destination, + ChannelPoint: event.channelPoint, + AmountMsat: amountMsat, + FeeMsat: feeMsat, + Scid: event.scid, + }, + }) + } + + if feeRemainingMsat > 0 { + // It is possible this case happens if the htlc_minimum_msat is larger + // than 1. We might not be able to deduct the opening fees from the + // payment entirely. This is an edge case, and we'll fail the payment. + log.Printf( + "After deducting fees from payment parts, there was still fee "+ + "remaining. payment id: %s, fee remaining msat: %d. Failing "+ + "payment.", + event.paymentId, + feeRemainingMsat, + ) + // TODO: Verify temporary_channel_failure is the way to go here, maybe + // unknown_next_peer is more appropriate. + i.paymentFailure <- &paymentFailureEvent{ + paymentId: event.paymentId, + code: shared.FAILURE_TEMPORARY_CHANNEL_FAILURE, + } + return + } + + for _, resolution := range resolutions { + resolution.part.resolution <- resolution.resolution + } + + payment.registration.IsComplete = true + go i.store.SetCompleted(context.TODO(), payment.registration.Id) + delete(i.inflightPayments, event.paymentId) +} + +func (i *Interceptor) handlePaymentFailure( + paymentId string, + code shared.InterceptFailureCode, +) { + i.finalizeAllParts(paymentId, &shared.InterceptResult{ + Action: shared.INTERCEPT_FAIL_HTLC_WITH_CODE, + FailureCode: code, + }) +} + +func (i *Interceptor) finalizeAllParts( + paymentId string, + result *shared.InterceptResult, +) { + payment, ok := i.inflightPayments[paymentId] + if !ok { + // Apparently this payment is already finalized. + return + } + + // Stops the timeout listeners + payment.closeTimeoutChan() + + for _, part := range payment.parts { + part.resolution <- result + } + delete(i.inflightPayments, paymentId) +} + +func (i *Interceptor) Intercept(req shared.InterceptRequest) shared.InterceptResult { + resolution := make(chan *shared.InterceptResult, 1) + i.newPart <- &partState{ + req: &req, + resolution: resolution, + } + res := <-resolution + return *res +} + +func (i *Interceptor) failPart( + payment *paymentState, + part *partState, + code shared.InterceptFailureCode, +) { + part.resolution <- &shared.InterceptResult{ + Action: shared.INTERCEPT_FAIL_HTLC_WITH_CODE, + FailureCode: code, + } + delete(payment.parts, part.req.HtlcId()) + if len(payment.parts) == 0 { + payment.closeTimeoutChan() + delete(i.inflightPayments, part.req.PaymentId()) + } +} diff --git a/lsps2/intercept_test.go b/lsps2/intercept_test.go new file mode 100644 index 0000000..25eaaa8 --- /dev/null +++ b/lsps2/intercept_test.go @@ -0,0 +1,758 @@ +package lsps2 + +import ( + "context" + "crypto/sha256" + "encoding/binary" + "fmt" + "strconv" + "sync" + "testing" + "time" + + "github.com/breez/lspd/chain" + "github.com/breez/lspd/lightning" + "github.com/breez/lspd/lsps0" + "github.com/breez/lspd/shared" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/wire" + "github.com/stretchr/testify/assert" +) + +var defaultScid uint64 = 123 +var defaultPaymentSizeMsat uint64 = 1_000_000 +var defaultMinViableAmount uint64 = defaultOpeningFeeParams().MinFeeMsat + defaultConfig().HtlcMinimumMsat +var defaultFee, _ = computeOpeningFee( + defaultPaymentSizeMsat, + defaultOpeningFeeParams().Proportional, + defaultOpeningFeeParams().MinFeeMsat, +) +var defaultChainHash = chainhash.Hash([32]byte{}) +var defaultOutPoint = wire.NewOutPoint(&defaultChainHash, 0) +var defaultChannelScid uint64 = 456 +var defaultChanResult = &lightning.GetChannelResult{ + HtlcMinimumMsat: defaultConfig().HtlcMinimumMsat, + InitialChannelID: lightning.ShortChannelID(defaultChannelScid), + ConfirmedChannelID: lightning.ShortChannelID(defaultChannelScid), +} + +func defaultOpeningFeeParams() shared.OpeningFeeParams { + return shared.OpeningFeeParams{ + MinFeeMsat: 1000, + Proportional: 1000, + ValidUntil: time.Now().UTC().Add(5 * time.Hour).Format(lsps0.TIME_FORMAT), + MinLifetime: 1000, + MaxClientToSelfDelay: 2016, + Promise: "fake", + } +} +func defaultStore() *mockLsps2Store { + return &mockLsps2Store{ + registrations: map[uint64]*BuyRegistration{ + defaultScid: { + PeerId: "peer", + Scid: lightning.ShortChannelID(defaultScid), + Mode: OpeningMode_NoMppVarInvoice, + OpeningFeeParams: defaultOpeningFeeParams(), + }, + }, + } +} + +func mppStore() *mockLsps2Store { + s := defaultStore() + for _, r := range s.registrations { + r.Mode = OpeningMode_MppFixedInvoice + r.PaymentSizeMsat = &defaultPaymentSizeMsat + } + return s +} + +func defaultClient() *mockLightningClient { + return &mockLightningClient{ + openResponses: []*wire.OutPoint{ + defaultOutPoint, + }, + getChanResponses: []*lightning.GetChannelResult{ + defaultChanResult, + }, + } +} + +func defaultFeeEstimator() *mockFeeEstimator { + return nil +} + +func defaultopeningService() *mockOpeningService { + return &mockOpeningService{ + isCurrentChainFeeCheaper: false, + } +} + +func defaultConfig() *InterceptorConfig { + var minConfs uint32 = 1 + return &InterceptorConfig{ + AdditionalChannelCapacitySat: 100_000, + MinConfs: &minConfs, + TargetConf: 6, + FeeStrategy: chain.FeeStrategyEconomy, + MinPaymentSizeMsat: 1_000, + MaxPaymentSizeMsat: 4_000_000_000, + TimeLockDelta: 144, + HtlcMinimumMsat: 100, + } +} + +type interceptP struct { + store *mockLsps2Store + openingService *mockOpeningService + client *mockLightningClient + feeEstimator *mockFeeEstimator + config *InterceptorConfig +} + +func setupInterceptor( + ctx context.Context, + p *interceptP, +) *Interceptor { + var store *mockLsps2Store + if p != nil && p.store != nil { + store = p.store + } else { + store = defaultStore() + } + + var client *mockLightningClient + if p != nil && p.client != nil { + client = p.client + } else { + client = defaultClient() + } + + var f *mockFeeEstimator + if p != nil && p.feeEstimator != nil { + f = p.feeEstimator + } else { + f = defaultFeeEstimator() + } + + var config *InterceptorConfig + if p != nil && p.config != nil { + config = p.config + } else { + config = defaultConfig() + } + + var openingService *mockOpeningService + if p != nil && p.openingService != nil { + openingService = p.openingService + } else { + openingService = defaultopeningService() + } + + i := NewInterceptHandler(store, openingService, client, f, config) + go i.Start(ctx) + return i +} + +type part struct { + id string + scid uint64 + ph []byte + amt uint64 + cltvDelta uint32 +} + +func createPart(p *part) shared.InterceptRequest { + id := "first" + if p != nil && p.id != "" { + id = p.id + } + + scid := lightning.ShortChannelID(defaultScid) + if p != nil && p.scid != 0 { + scid = lightning.ShortChannelID(p.scid) + } + + ph := []byte("fake payment hash") + if p != nil && p.ph != nil && len(p.ph) > 0 { + ph = p.ph + } + + var amt uint64 = 1_000_000 + if p != nil && p.amt != 0 { + amt = p.amt + } + + var cltv uint32 = 146 + if p != nil && p.cltvDelta != 0 { + cltv = p.cltvDelta + } + + return shared.InterceptRequest{ + Identifier: id, + Scid: scid, + PaymentHash: ph, + IncomingAmountMsat: amt, + OutgoingAmountMsat: amt, + IncomingExpiry: 100 + cltv, + OutgoingExpiry: 100, + } +} + +func runIntercept(i *Interceptor, req shared.InterceptRequest, res *shared.InterceptResult, wg *sync.WaitGroup) { + go func() { + *res = i.Intercept(req) + wg.Done() + }() +} + +func assertEmpty(t *testing.T, i *Interceptor) { + assert.Empty(t, i.inflightPayments) + assert.Empty(t, i.newPart) + assert.Empty(t, i.registrationFetched) + assert.Empty(t, i.paymentChanOpened) + assert.Empty(t, i.paymentFailure) + assert.Empty(t, i.paymentReady) +} + +// Asserts that a part that is not associated with a bought channel is not +// handled by the interceptor. This allows the legacy interceptor to pick up +// from there. +func Test_NotBought_SinglePart(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + i := setupInterceptor(ctx, nil) + res := i.Intercept(createPart(&part{scid: 999})) + assert.Equal(t, shared.INTERCEPT_RESUME, res.Action) + assertEmpty(t, i) +} + +func Test_NotBought_TwoParts(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + i := setupInterceptor(ctx, nil) + + var wg sync.WaitGroup + wg.Add(2) + var res1 shared.InterceptResult + runIntercept(i, createPart(&part{id: "first", scid: 999}), &res1, &wg) + + var res2 shared.InterceptResult + runIntercept(i, createPart(&part{id: "second", scid: 999}), &res2, &wg) + wg.Wait() + assert.Equal(t, shared.INTERCEPT_RESUME, res1.Action) + assert.Equal(t, shared.INTERCEPT_RESUME, res2.Action) + assertEmpty(t, i) +} + +// Asserts that a no-MPP+var-invoice mode payment works in the happy flow. +func Test_NoMpp_Happyflow(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + i := setupInterceptor(ctx, nil) + + res := i.Intercept(createPart(nil)) + assert.Equal(t, shared.INTERCEPT_RESUME_WITH_ONION, res.Action) + assert.Equal(t, defaultPaymentSizeMsat-defaultFee, res.AmountMsat) + assert.Equal(t, defaultFee, *res.FeeMsat) + assert.Equal(t, defaultChannelScid, uint64(res.Scid)) + assertEmpty(t, i) +} + +// Asserts that a no-MPP+var-invoice mode payment works with the exact minimum +// amount. +func Test_NoMpp_AmountMinFeePlusHtlcMinPlusOne(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + i := setupInterceptor(ctx, nil) + + res := i.Intercept(createPart(&part{amt: defaultMinViableAmount})) + assert.Equal(t, shared.INTERCEPT_RESUME_WITH_ONION, res.Action) + assert.Equal(t, defaultConfig().HtlcMinimumMsat, res.AmountMsat) + assert.Equal(t, defaultOpeningFeeParams().MinFeeMsat, *res.FeeMsat) + assert.Equal(t, defaultChannelScid, uint64(res.Scid)) + assertEmpty(t, i) +} + +// Asserts that a no-MPP+var-invoice mode payment fails with the exact minimum +// amount minus one. +func Test_NoMpp_AmtBelowMinimum(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + i := setupInterceptor(ctx, nil) + + res := i.Intercept(createPart(&part{amt: defaultMinViableAmount - 1})) + assert.Equal(t, shared.INTERCEPT_FAIL_HTLC_WITH_CODE, res.Action) + assert.Equal(t, shared.FAILURE_UNKNOWN_NEXT_PEER, res.FailureCode) + assertEmpty(t, i) +} + +// Asserts that a no-MPP+var-invoice mode payment succeeds with the exact +// maximum amount. +func Test_NoMpp_AmtAtMaximum(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + i := setupInterceptor(ctx, nil) + + res := i.Intercept(createPart(&part{amt: defaultConfig().MaxPaymentSizeMsat})) + assert.Equal(t, shared.INTERCEPT_RESUME_WITH_ONION, res.Action) + assertEmpty(t, i) +} + +// Asserts that a no-MPP+var-invoice mode payment fails with the exact +// maximum amount plus one. +func Test_NoMpp_AmtAboveMaximum(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + i := setupInterceptor(ctx, nil) + + res := i.Intercept(createPart(&part{amt: defaultConfig().MaxPaymentSizeMsat + 1})) + assert.Equal(t, shared.INTERCEPT_FAIL_HTLC_WITH_CODE, res.Action) + assert.Equal(t, shared.FAILURE_UNKNOWN_NEXT_PEER, res.FailureCode) + assertEmpty(t, i) +} + +// Asserts that a no-MPP+var-invoice mode payment fails when the cltv delta is +// less than cltv delta + 2. +func Test_NoMpp_CltvDeltaBelowMinimum(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + i := setupInterceptor(ctx, nil) + + res := i.Intercept(createPart(&part{cltvDelta: 145})) + assert.Equal(t, shared.INTERCEPT_FAIL_HTLC_WITH_CODE, res.Action) + assert.Equal(t, shared.FAILURE_INCORRECT_CLTV_EXPIRY, res.FailureCode) + assertEmpty(t, i) +} + +// Asserts that a no-MPP+var-invoice mode payment succeeds when the cltv delta +// is higher than expected. +func Test_NoMpp_HigherCltvDelta(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + i := setupInterceptor(ctx, nil) + + res := i.Intercept(createPart(&part{cltvDelta: 1000})) + assert.Equal(t, shared.INTERCEPT_RESUME_WITH_ONION, res.Action) + assertEmpty(t, i) +} + +// Asserts that a no-MPP+var-invoice mode payment fails if the opening params +// have expired. +func Test_NoMpp_ParamsExpired(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + store := defaultStore() + store.registrations[defaultScid].OpeningFeeParams.ValidUntil = time.Now(). + UTC().Add(-time.Nanosecond).Format(lsps0.TIME_FORMAT) + i := setupInterceptor(ctx, &interceptP{store: store}) + + res := i.Intercept(createPart(nil)) + assert.Equal(t, shared.INTERCEPT_FAIL_HTLC_WITH_CODE, res.Action) + assert.Equal(t, shared.FAILURE_UNKNOWN_NEXT_PEER, res.FailureCode) + assertEmpty(t, i) +} + +func Test_NoMpp_ChannelAlreadyOpened_NotComplete_Forwards(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + store := defaultStore() + store.registrations[defaultScid].ChannelPoint = defaultOutPoint + i := setupInterceptor(ctx, &interceptP{store: store}) + + res := i.Intercept(createPart(nil)) + assert.Equal(t, shared.INTERCEPT_RESUME_WITH_ONION, res.Action) + assertEmpty(t, i) +} + +func Test_NoMpp_ChannelAlreadyOpened_Complete_Fails(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + store := defaultStore() + store.registrations[defaultScid].ChannelPoint = defaultOutPoint + store.registrations[defaultScid].IsComplete = true + i := setupInterceptor(ctx, &interceptP{store: store}) + + res := i.Intercept(createPart(nil)) + assert.Equal(t, shared.INTERCEPT_FAIL_HTLC_WITH_CODE, res.Action) + assert.Equal(t, shared.FAILURE_UNKNOWN_NEXT_PEER, res.FailureCode) + assertEmpty(t, i) +} + +// Asserts that a MPP+fixed-invoice mode payment succeeds in the happy flow +// case. +func Test_Mpp_SinglePart_Happyflow(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + i := setupInterceptor(ctx, &interceptP{store: mppStore()}) + + res := i.Intercept(createPart(&part{amt: defaultPaymentSizeMsat})) + assert.Equal(t, shared.INTERCEPT_RESUME_WITH_ONION, res.Action) + assert.Equal(t, defaultPaymentSizeMsat-defaultFee, res.AmountMsat) + assert.Equal(t, defaultFee, *res.FeeMsat) + assert.Equal(t, defaultChannelScid, uint64(res.Scid)) + assertEmpty(t, i) +} + +// Asserts that a MPP+fixed-invoice mode payment times out when it receives only +// a single part below payment_size_msat. +func Test_Mpp_SinglePart_AmtTooSmall(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + config := defaultConfig() + config.MppTimeout = time.Millisecond * 500 + i := setupInterceptor(ctx, &interceptP{store: mppStore(), config: config}) + + start := time.Now() + res := i.Intercept(createPart(&part{amt: defaultPaymentSizeMsat - 1})) + end := time.Now() + assert.Equal(t, shared.INTERCEPT_FAIL_HTLC_WITH_CODE, res.Action) + assert.Equal(t, shared.FAILURE_TEMPORARY_CHANNEL_FAILURE, res.FailureCode) + assert.GreaterOrEqual(t, end.Sub(start).Milliseconds(), config.MppTimeout.Milliseconds()) + assertEmpty(t, i) +} + +// Asserts that a MPP+fixed-invoice mode payment finalizes after it receives the +// second part that finalizes the payment. +func Test_Mpp_TwoParts_FinalizedOnSecond(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + config := defaultConfig() + config.MppTimeout = time.Millisecond * 500 + i := setupInterceptor(ctx, &interceptP{store: mppStore(), config: config}) + + var wg sync.WaitGroup + wg.Add(2) + var res1 shared.InterceptResult + var res2 shared.InterceptResult + var t1 time.Time + var t2 time.Time + start := time.Now() + go func() { + res1 = i.Intercept(createPart(&part{ + id: "first", + amt: defaultPaymentSizeMsat - defaultConfig().HtlcMinimumMsat, + })) + t1 = time.Now() + wg.Done() + }() + + <-time.After(time.Millisecond * 250) + + go func() { + res2 = i.Intercept(createPart(&part{ + id: "second", + amt: defaultConfig().HtlcMinimumMsat, + })) + t2 = time.Now() + wg.Done() + }() + + wg.Wait() + assert.Equal(t, shared.INTERCEPT_RESUME_WITH_ONION, res1.Action) + assert.Equal(t, defaultPaymentSizeMsat-defaultConfig().HtlcMinimumMsat-defaultFee, res1.AmountMsat) + assert.Equal(t, defaultFee, *res1.FeeMsat) + + assert.Equal(t, shared.INTERCEPT_RESUME_WITH_ONION, res2.Action) + assert.Equal(t, defaultConfig().HtlcMinimumMsat, res2.AmountMsat) + assert.Nil(t, res2.FeeMsat) + + assert.LessOrEqual(t, int64(250), t1.Sub(start).Milliseconds()) + assert.LessOrEqual(t, int64(250), t2.Sub(start).Milliseconds()) + assertEmpty(t, i) +} + +// Asserts that a MPP+fixed-invoice mode payment with the following parts +// 1) payment size - htlc minimum +// 2) htlc minimum - 1 +// 3) htlc minimum +// still succeeds. The second part is dropped, but the third part completes the +// payment. +func Test_Mpp_BadSecondPart_ThirdPartCompletes(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + config := defaultConfig() + config.MppTimeout = time.Millisecond * 500 + i := setupInterceptor(ctx, &interceptP{store: mppStore(), config: config}) + + var wg sync.WaitGroup + wg.Add(2) + var res1 shared.InterceptResult + var res2 shared.InterceptResult + var res3 shared.InterceptResult + var t1 time.Time + var t2 time.Time + var t3 time.Time + start := time.Now() + go func() { + res1 = i.Intercept(createPart(&part{ + id: "first", + amt: defaultPaymentSizeMsat - defaultConfig().HtlcMinimumMsat, + })) + t1 = time.Now() + wg.Done() + }() + + <-time.After(time.Millisecond * 100) + res2 = i.Intercept(createPart(&part{ + id: "second", + amt: defaultConfig().HtlcMinimumMsat - 1, + })) + t2 = time.Now() + + <-time.After(time.Millisecond * 100) + go func() { + res3 = i.Intercept(createPart(&part{ + id: "third", + amt: defaultConfig().HtlcMinimumMsat, + })) + t3 = time.Now() + wg.Done() + }() + + wg.Wait() + assert.Equal(t, shared.INTERCEPT_RESUME_WITH_ONION, res1.Action) + assert.Equal(t, defaultPaymentSizeMsat-defaultConfig().HtlcMinimumMsat-defaultFee, res1.AmountMsat) + assert.Equal(t, defaultFee, *res1.FeeMsat) + + assert.Equal(t, shared.INTERCEPT_FAIL_HTLC_WITH_CODE, res2.Action) + assert.Equal(t, shared.FAILURE_AMOUNT_BELOW_MINIMUM, res2.FailureCode) + + assert.Equal(t, shared.INTERCEPT_RESUME_WITH_ONION, res3.Action) + assert.Equal(t, defaultConfig().HtlcMinimumMsat, res3.AmountMsat) + assert.Nil(t, res3.FeeMsat) + + assert.LessOrEqual(t, int64(200), t1.Sub(start).Milliseconds()) + assert.Greater(t, int64(200), t2.Sub(start).Milliseconds()) + assert.LessOrEqual(t, int64(200), t3.Sub(start).Milliseconds()) + + assertEmpty(t, i) +} + +// Asserts that a MPP+fixed-invoice mode payment fails when the cltv delta is +// less than cltv delta + 2. +func Test_Mpp_CltvDeltaBelowMinimum(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + i := setupInterceptor(ctx, &interceptP{store: mppStore()}) + + res := i.Intercept(createPart(&part{cltvDelta: 145})) + assert.Equal(t, shared.INTERCEPT_FAIL_HTLC_WITH_CODE, res.Action) + assert.Equal(t, shared.FAILURE_INCORRECT_CLTV_EXPIRY, res.FailureCode) + assertEmpty(t, i) +} + +// Asserts that a MPP+fixed-invoice mode payment succeeds when the cltv delta +// is higher than expected. +func Test_Mpp_HigherCltvDelta(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + i := setupInterceptor(ctx, &interceptP{store: mppStore()}) + + res := i.Intercept(createPart(&part{cltvDelta: 1000})) + assert.Equal(t, shared.INTERCEPT_RESUME_WITH_ONION, res.Action) + assertEmpty(t, i) +} + +// Asserts that a MPP+fixed-invoice mode payment fails if the opening params +// have expired. +func Test_Mpp_ParamsExpired(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + store := mppStore() + store.registrations[defaultScid].OpeningFeeParams.ValidUntil = time.Now(). + UTC().Add(-time.Nanosecond).Format(lsps0.TIME_FORMAT) + i := setupInterceptor(ctx, &interceptP{store: store}) + + res := i.Intercept(createPart(nil)) + assert.Equal(t, shared.INTERCEPT_FAIL_HTLC_WITH_CODE, res.Action) + assert.Equal(t, shared.FAILURE_UNKNOWN_NEXT_PEER, res.FailureCode) + assertEmpty(t, i) +} + +// Asserts that a MPP+fixed-invoice mode payment fails if the opening params +// expire while the part is in-flight. +func Test_Mpp_ParamsExpireInFlight(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + store := mppStore() + i := setupInterceptor(ctx, &interceptP{store: store}) + + start := time.Now() + store.registrations[defaultScid].OpeningFeeParams.ValidUntil = start. + UTC().Add(time.Millisecond * 250).Format(lsps0.TIME_FORMAT) + + var res1 shared.InterceptResult + var res2 shared.InterceptResult + var wg sync.WaitGroup + wg.Add(1) + go func() { + res1 = i.Intercept(createPart(&part{ + id: "first", + amt: defaultPaymentSizeMsat - defaultConfig().HtlcMinimumMsat, + })) + wg.Done() + }() + + <-time.After(time.Millisecond * 300) + res2 = i.Intercept(createPart(&part{ + id: "second", + amt: defaultConfig().HtlcMinimumMsat, + })) + + wg.Wait() + assert.Equal(t, shared.INTERCEPT_FAIL_HTLC_WITH_CODE, res1.Action) + assert.Equal(t, shared.FAILURE_UNKNOWN_NEXT_PEER, res1.FailureCode) + assert.Equal(t, shared.INTERCEPT_FAIL_HTLC_WITH_CODE, res2.Action) + assert.Equal(t, shared.FAILURE_UNKNOWN_NEXT_PEER, res2.FailureCode) + + assertEmpty(t, i) +} + +// Asserts that a MPP+fixed-invoice mode replacement of a part ignores that +// part, and the replacement is used for completing the payment +func Test_Mpp_PartReplacement(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + i := setupInterceptor(ctx, &interceptP{store: mppStore()}) + + var wg sync.WaitGroup + wg.Add(3) + var res1 shared.InterceptResult + var res2 shared.InterceptResult + var res3 shared.InterceptResult + var t1 time.Time + var t2 time.Time + var t3 time.Time + start := time.Now() + go func() { + res1 = i.Intercept(createPart(&part{ + id: "first", + amt: defaultPaymentSizeMsat - defaultConfig().HtlcMinimumMsat, + })) + t1 = time.Now() + wg.Done() + }() + + <-time.After(time.Millisecond * 100) + go func() { + res2 = i.Intercept(createPart(&part{ + id: "first", + amt: defaultPaymentSizeMsat - defaultConfig().HtlcMinimumMsat, + })) + t2 = time.Now() + wg.Done() + }() + + <-time.After(time.Millisecond * 100) + go func() { + res3 = i.Intercept(createPart(&part{ + id: "second", + amt: defaultConfig().HtlcMinimumMsat, + })) + t3 = time.Now() + wg.Done() + }() + + wg.Wait() + assert.Equal(t, shared.INTERCEPT_IGNORE, res1.Action) + + assert.Equal(t, shared.INTERCEPT_RESUME_WITH_ONION, res2.Action) + assert.Equal(t, defaultPaymentSizeMsat-defaultConfig().HtlcMinimumMsat-defaultFee, res2.AmountMsat) + assert.Equal(t, defaultFee, *res2.FeeMsat) + + assert.Equal(t, shared.INTERCEPT_RESUME_WITH_ONION, res3.Action) + assert.Equal(t, defaultConfig().HtlcMinimumMsat, res3.AmountMsat) + assert.Nil(t, res3.FeeMsat) + + assert.LessOrEqual(t, int64(100), t1.Sub(start).Milliseconds()) + assert.LessOrEqual(t, int64(200), t2.Sub(start).Milliseconds()) + assert.LessOrEqual(t, int64(200), t3.Sub(start).Milliseconds()) + + assertEmpty(t, i) +} + +func Test_Mpp_ChannelAlreadyOpened_NotComplete_Forwards(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + store := mppStore() + store.registrations[defaultScid].ChannelPoint = defaultOutPoint + i := setupInterceptor(ctx, &interceptP{store: store}) + + res := i.Intercept(createPart(nil)) + assert.Equal(t, shared.INTERCEPT_RESUME_WITH_ONION, res.Action) + assertEmpty(t, i) +} + +func Test_Mpp_ChannelAlreadyOpened_Complete_Fails(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + store := mppStore() + store.registrations[defaultScid].ChannelPoint = defaultOutPoint + store.registrations[defaultScid].IsComplete = true + i := setupInterceptor(ctx, &interceptP{store: store}) + + res := i.Intercept(createPart(nil)) + assert.Equal(t, shared.INTERCEPT_FAIL_HTLC_WITH_CODE, res.Action) + assert.Equal(t, shared.FAILURE_UNKNOWN_NEXT_PEER, res.FailureCode) + assertEmpty(t, i) +} + +func Test_Mpp_Performance(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + paymentCount := 100 + partCount := 10 + store := &mockLsps2Store{ + delay: time.Millisecond * 500, + registrations: make(map[uint64]*BuyRegistration), + } + + client := &mockLightningClient{} + for paymentNo := 0; paymentNo < paymentCount; paymentNo++ { + scid := uint64(paymentNo + 1_000_000) + client.getChanResponses = append(client.getChanResponses, defaultChanResult) + client.openResponses = append(client.openResponses, defaultOutPoint) + store.registrations[scid] = &BuyRegistration{ + PeerId: strconv.FormatUint(scid, 10), + Scid: lightning.ShortChannelID(scid), + Mode: OpeningMode_MppFixedInvoice, + OpeningFeeParams: defaultOpeningFeeParams(), + PaymentSizeMsat: &defaultPaymentSizeMsat, + } + } + i := setupInterceptor(ctx, &interceptP{store: store, client: client}) + var wg sync.WaitGroup + wg.Add(partCount * paymentCount) + start := time.Now() + for paymentNo := 0; paymentNo < paymentCount; paymentNo++ { + for partNo := 0; partNo < partCount; partNo++ { + scid := paymentNo + 1_000_000 + id := fmt.Sprintf("%d|%d", paymentNo, partNo) + var a [8]byte + binary.BigEndian.PutUint64(a[:], uint64(scid)) + ph := sha256.Sum256(a[:]) + + go func() { + res := i.Intercept(createPart(&part{ + scid: uint64(scid), + id: id, + ph: ph[:], + amt: defaultPaymentSizeMsat / uint64(partCount), + })) + + assert.Equal(t, shared.INTERCEPT_RESUME_WITH_ONION, res.Action) + wg.Done() + }() + } + } + wg.Wait() + end := time.Now() + + assert.LessOrEqual(t, end.Sub(start).Milliseconds(), int64(1000)) + assertEmpty(t, i) +}