diff --git a/htlcswitch/link_test.go b/htlcswitch/link_test.go index ac3a1b12..6ae141ab 100644 --- a/htlcswitch/link_test.go +++ b/htlcswitch/link_test.go @@ -3909,8 +3909,8 @@ func TestChannelLinkAcceptDuplicatePayment(t *testing.T) { err = n.aliceServer.htlcSwitch.SendHTLC( n.firstBobChannelLink.ShortChanID(), pid, htlc, ) - if err != ErrPaymentIDAlreadyExists { - t.Fatalf("ErrPaymentIDAlreadyExists should have been "+ + if err != ErrDuplicateAdd { + t.Fatalf("ErrDuplicateAdd should have been "+ "received got: %v", err) } diff --git a/htlcswitch/switch.go b/htlcswitch/switch.go index efbcf608..ee2c090a 100644 --- a/htlcswitch/switch.go +++ b/htlcswitch/switch.go @@ -64,16 +64,6 @@ var ( zeroPreimage [sha256.Size]byte ) -// pendingPayment represents the payment which made by user and waits for -// updates to be received whether the payment has been rejected or proceed -// successfully. -type pendingPayment struct { - paymentHash lntypes.Hash - amount lnwire.MilliSatoshi - - resultChan chan *networkResult -} - // plexPacket encapsulates switch packet and adds error channel to receive // error from request handler. type plexPacket struct { @@ -201,12 +191,12 @@ type Switch struct { // service was initialized with. cfg *Config - // pendingPayments stores payments initiated by the user that are not yet - // settled. The map is used to later look up the payments and notify the - // user of the result when they are complete. Each payment is given a unique - // integer ID when it is created. - pendingPayments map[uint64]*pendingPayment - pendingMutex sync.RWMutex + // networkResults stores the results of payments initiated by the user. + // results. The store is used to later look up the payments and notify + // the user of the result when they are complete. Each payment attempt + // should be given a unique integer ID when it is created, otherwise + // results might be overwritten. + networkResults *networkResultStore // circuits is storage for payment circuits which are used to // forward the settle/fail htlc updates back to the add htlc initiator. @@ -292,7 +282,7 @@ func New(cfg Config, currentHeight uint32) (*Switch, error) { forwardingIndex: make(map[lnwire.ShortChannelID]ChannelLink), interfaceIndex: make(map[[33]byte]map[lnwire.ChannelID]ChannelLink), pendingLinkIndex: make(map[lnwire.ChannelID]ChannelLink), - pendingPayments: make(map[uint64]*pendingPayment), + networkResults: newNetworkResultStore(cfg.DB), htlcPlex: make(chan *plexPacket), chanCloseRequests: make(chan *ChanClose), resolutionMsgs: make(chan *resolutionMsg), @@ -345,12 +335,33 @@ func (s *Switch) ProcessContractResolution(msg contractcourt.ResolutionMsg) erro func (s *Switch) GetPaymentResult(paymentID uint64, paymentHash lntypes.Hash, deobfuscator ErrorDecrypter) (<-chan *PaymentResult, error) { - s.pendingMutex.Lock() - payment, ok := s.pendingPayments[paymentID] - s.pendingMutex.Unlock() + var ( + nChan <-chan *networkResult + err error + outKey = CircuitKey{ + ChanID: sourceHop, + HtlcID: paymentID, + } + ) - if !ok { - return nil, ErrPaymentIDNotFound + // If the payment is not found in the circuit map, check whether a + // result is already available. + // Assumption: no one will add this payment ID other than the caller. + if s.circuits.LookupCircuit(outKey) == nil { + res, err := s.networkResults.getResult(paymentID) + if err != nil { + return nil, err + } + c := make(chan *networkResult, 1) + c <- res + nChan = c + } else { + // The payment was committed to the circuits, subscribe for a + // result. + nChan, err = s.networkResults.subscribeResult(paymentID) + if err != nil { + return nil, err + } } resultChan := make(chan *PaymentResult, 1) @@ -364,7 +375,7 @@ func (s *Switch) GetPaymentResult(paymentID uint64, paymentHash lntypes.Hash, var n *networkResult select { - case n = <-payment.resultChan: + case n = <-nChan: case <-s.quit: // We close the result channel to signal a shutdown. We // don't send any result in this case since the HTLC is @@ -398,24 +409,6 @@ func (s *Switch) GetPaymentResult(paymentID uint64, paymentHash lntypes.Hash, func (s *Switch) SendHTLC(firstHop lnwire.ShortChannelID, paymentID uint64, htlc *lnwire.UpdateAddHTLC) error { - // Create payment and add to the map of payment in order later to be - // able to retrieve it and return response to the user. - payment := &pendingPayment{ - resultChan: make(chan *networkResult, 1), - paymentHash: htlc.PaymentHash, - amount: htlc.Amount, - } - - s.pendingMutex.Lock() - if _, ok := s.pendingPayments[paymentID]; ok { - s.pendingMutex.Unlock() - - return ErrPaymentIDAlreadyExists - } - - s.pendingPayments[paymentID] = payment - s.pendingMutex.Unlock() - // Generate and send new update packet, if error will be received on // this stage it means that packet haven't left boundaries of our // system and something wrong happened. @@ -426,12 +419,7 @@ func (s *Switch) SendHTLC(firstHop lnwire.ShortChannelID, paymentID uint64, htlc: htlc, } - if err := s.forward(packet); err != nil { - s.removePendingPayment(paymentID) - return err - } - - return nil + return s.forward(packet) } // UpdateForwardingPolicies sends a message to the switch to update the @@ -856,15 +844,34 @@ func (s *Switch) handleLocalDispatch(pkt *htlcPacket) error { // multiple db transactions. The guarantees of the circuit map are stringent // enough such that we are able to tolerate reordering of these operations // without side effects. The primary operations handled are: -// 1. Ack settle/fail references, to avoid resending this response internally -// 2. Teardown the closing circuit in the circuit map -// 3. Transition the payment status to grounded or completed. -// 4. Respond to an in-mem pending payment, if it is found. +// 1. Save the payment result to the pending payment store. +// 2. Notify subscribers about the payment result. +// 3. Ack settle/fail references, to avoid resending this response internally +// 4. Teardown the closing circuit in the circuit map // // NOTE: This method MUST be spawned as a goroutine. func (s *Switch) handleLocalResponse(pkt *htlcPacket) { defer s.wg.Done() + paymentID := pkt.incomingHTLCID + + // The error reason will be unencypted in case this a local + // failure or a converted error. + unencrypted := pkt.localFailure || pkt.convertedError + n := &networkResult{ + msg: pkt.htlc, + unencrypted: unencrypted, + isResolution: pkt.isResolution, + } + + // Store the result to the db. This will also notify subscribers about + // the result. + if err := s.networkResults.storeResult(paymentID, n); err != nil { + log.Errorf("Unable to complete payment for pid=%v: %v", + paymentID, err) + return + } + // First, we'll clean up any fwdpkg references, circuit entries, and // mark in our db that the payment for this payment hash has either // succeeded or failed. @@ -892,26 +899,6 @@ func (s *Switch) handleLocalResponse(pkt *htlcPacket) { pkt.inKey(), err) return } - - // Locate the pending payment to notify the application that this - // payment has failed. If one is not found, it likely means the daemon - // has been restarted since sending the payment. - payment := s.findPayment(pkt.incomingHTLCID) - - // The error reason will be unencypted in case this a local - // failure or a converted error. - unencrypted := pkt.localFailure || pkt.convertedError - n := &networkResult{ - msg: pkt.htlc, - unencrypted: unencrypted, - isResolution: pkt.isResolution, - } - - // Deliver the payment error and preimage to the application, if it is - // waiting for a response. - if payment != nil { - payment.resultChan <- n - } } // extractResult uses the given deobfuscator to extract the payment result from @@ -2173,30 +2160,6 @@ func (s *Switch) getLinks(destination [33]byte) ([]ChannelLink, error) { return channelLinks, nil } -// removePendingPayment is the helper function which removes the pending user -// payment. -func (s *Switch) removePendingPayment(paymentID uint64) { - s.pendingMutex.Lock() - defer s.pendingMutex.Unlock() - - delete(s.pendingPayments, paymentID) -} - -// findPayment is the helper function which find the payment. -func (s *Switch) findPayment(paymentID uint64) *pendingPayment { - s.pendingMutex.RLock() - defer s.pendingMutex.RUnlock() - - payment, ok := s.pendingPayments[paymentID] - if !ok { - log.Errorf("Cannot find pending payment with ID %d", - paymentID) - return nil - } - - return payment -} - // CircuitModifier returns a reference to subset of the interfaces provided by // the circuit map, to allow links to open and close circuits. func (s *Switch) CircuitModifier() CircuitModifier {