mirror of
https://github.com/aljazceru/breez-lnd.git
synced 2026-02-23 07:24:21 +01:00
Merge pull request #3442 from cfromknecht/router-registry
single-shot, sender-side mpp via sendtoroute
This commit is contained in:
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/coreos/bbolt"
|
||||
"github.com/lightningnetwork/lnd/lntypes"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"github.com/lightningnetwork/lnd/record"
|
||||
"github.com/lightningnetwork/lnd/routing/route"
|
||||
"github.com/lightningnetwork/lnd/tlv"
|
||||
)
|
||||
@@ -508,7 +509,9 @@ func deserializePaymentAttemptInfo(r io.Reader) (*PaymentAttemptInfo, error) {
|
||||
|
||||
func serializeHop(w io.Writer, h *route.Hop) error {
|
||||
if err := WriteElements(w,
|
||||
h.PubKeyBytes[:], h.ChannelID, h.OutgoingTimeLock,
|
||||
h.PubKeyBytes[:],
|
||||
h.ChannelID,
|
||||
h.OutgoingTimeLock,
|
||||
h.AmtToForward,
|
||||
); err != nil {
|
||||
return err
|
||||
@@ -525,10 +528,23 @@ func serializeHop(w io.Writer, h *route.Hop) error {
|
||||
return WriteElements(w, uint32(0))
|
||||
}
|
||||
|
||||
// Gather all non-primitive TLV records so that they can be serialized
|
||||
// as a single blob.
|
||||
//
|
||||
// TODO(conner): add migration to unify all fields in a single TLV
|
||||
// blobs. The split approach will cause headaches down the road as more
|
||||
// fields are added, which we can avoid by having a single TLV stream
|
||||
// for all payload fields.
|
||||
var records []tlv.Record
|
||||
if h.MPP != nil {
|
||||
records = append(records, h.MPP.Record())
|
||||
}
|
||||
records = append(records, h.TLVRecords...)
|
||||
|
||||
// Otherwise, we'll transform our slice of records into a map of the
|
||||
// raw bytes, then serialize them in-line with a length (number of
|
||||
// elements) prefix.
|
||||
mapRecords, err := tlv.RecordsToMap(h.TLVRecords)
|
||||
mapRecords, err := tlv.RecordsToMap(records)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -604,6 +620,29 @@ func deserializeHop(r io.Reader) (*route.Hop, error) {
|
||||
tlvMap[tlvType] = rawRecordBytes
|
||||
}
|
||||
|
||||
// If the MPP type is present, remove it from the generic TLV map and
|
||||
// parse it back into a proper MPP struct.
|
||||
//
|
||||
// TODO(conner): add migration to unify all fields in a single TLV
|
||||
// blobs. The split approach will cause headaches down the road as more
|
||||
// fields are added, which we can avoid by having a single TLV stream
|
||||
// for all payload fields.
|
||||
mppType := uint64(record.MPPOnionType)
|
||||
if mppBytes, ok := tlvMap[mppType]; ok {
|
||||
delete(tlvMap, mppType)
|
||||
|
||||
var (
|
||||
mpp = &record.MPP{}
|
||||
mppRec = mpp.Record()
|
||||
r = bytes.NewReader(mppBytes)
|
||||
)
|
||||
err := mppRec.Decode(r, uint64(len(mppBytes)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
h.MPP = mpp
|
||||
}
|
||||
|
||||
tlvRecords, err := tlv.MapToRecords(tlvMap)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/btcsuite/btcd/btcec"
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
"github.com/lightningnetwork/lnd/lntypes"
|
||||
"github.com/lightningnetwork/lnd/record"
|
||||
"github.com/lightningnetwork/lnd/routing/route"
|
||||
"github.com/lightningnetwork/lnd/tlv"
|
||||
)
|
||||
@@ -31,6 +32,7 @@ var (
|
||||
tlv.MakeStaticRecord(1, nil, 3, tlvEncoder, nil),
|
||||
tlv.MakeStaticRecord(2, nil, 3, tlvEncoder, nil),
|
||||
},
|
||||
MPP: record.NewMPP(32, [32]byte{0x42}),
|
||||
}
|
||||
|
||||
testHop2 = &route.Hop{
|
||||
@@ -46,8 +48,8 @@ var (
|
||||
TotalAmount: 1234567,
|
||||
SourcePubKey: route.NewVertex(pub),
|
||||
Hops: []*route.Hop{
|
||||
testHop1,
|
||||
testHop2,
|
||||
testHop1,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
@@ -21,7 +21,7 @@ type Registry interface {
|
||||
NotifyExitHopHtlc(payHash lntypes.Hash, paidAmount lnwire.MilliSatoshi,
|
||||
expiry uint32, currentHeight int32,
|
||||
circuitKey channeldb.CircuitKey, hodlChan chan<- interface{},
|
||||
eob []byte) (*invoices.HodlEvent, error)
|
||||
payload invoices.Payload) (*invoices.HodlEvent, error)
|
||||
|
||||
// HodlUnsubscribeAll unsubscribes from all hodl events.
|
||||
HodlUnsubscribeAll(subscriber chan<- interface{})
|
||||
|
||||
@@ -24,7 +24,7 @@ type mockRegistry struct {
|
||||
func (r *mockRegistry) NotifyExitHopHtlc(payHash lntypes.Hash,
|
||||
paidAmount lnwire.MilliSatoshi, expiry uint32, currentHeight int32,
|
||||
circuitKey channeldb.CircuitKey, hodlChan chan<- interface{},
|
||||
eob []byte) (*invoices.HodlEvent, error) {
|
||||
payload invoices.Payload) (*invoices.HodlEvent, error) {
|
||||
|
||||
r.notifyChan <- notifyExitHopData{
|
||||
hodlChan: hodlChan,
|
||||
|
||||
@@ -16,15 +16,13 @@ import (
|
||||
// interpret the forwarding information encoded within the HTLC packet, and hop
|
||||
// to encode the forwarding information for the _next_ hop.
|
||||
type Iterator interface {
|
||||
// ForwardingInstructions returns the set of fields that detail exactly
|
||||
// _how_ this hop should forward the HTLC to the next hop.
|
||||
// Additionally, the information encoded within the returned
|
||||
// ForwardingInfo is to be used by each hop to authenticate the
|
||||
// information given to it by the prior hop.
|
||||
ForwardingInstructions() (ForwardingInfo, error)
|
||||
|
||||
// ExtraOnionBlob returns the additional EOB data (if available).
|
||||
ExtraOnionBlob() []byte
|
||||
// HopPayload returns the set of fields that detail exactly _how_ this
|
||||
// hop should forward the HTLC to the next hop. Additionally, the
|
||||
// information encoded within the returned ForwardingInfo is to be used
|
||||
// by each hop to authenticate the information given to it by the prior
|
||||
// hop. The payload will also contain any additional TLV fields provided
|
||||
// by the sender.
|
||||
HopPayload() (*Payload, error)
|
||||
|
||||
// EncodeNextHop encodes the onion packet destined for the next hop
|
||||
// into the passed io.Writer.
|
||||
@@ -72,50 +70,35 @@ func (r *sphinxHopIterator) EncodeNextHop(w io.Writer) error {
|
||||
return r.processedPacket.NextPacket.Encode(w)
|
||||
}
|
||||
|
||||
// ForwardingInstructions returns the set of fields that detail exactly _how_
|
||||
// this hop should forward the HTLC to the next hop. Additionally, the
|
||||
// information encoded within the returned ForwardingInfo is to be used by each
|
||||
// hop to authenticate the information given to it by the prior hop.
|
||||
// HopPayload returns the set of fields that detail exactly _how_ this hop
|
||||
// should forward the HTLC to the next hop. Additionally, the information
|
||||
// encoded within the returned ForwardingInfo is to be used by each hop to
|
||||
// authenticate the information given to it by the prior hop. The payload will
|
||||
// also contain any additional TLV fields provided by the sender.
|
||||
//
|
||||
// NOTE: Part of the HopIterator interface.
|
||||
func (r *sphinxHopIterator) ForwardingInstructions() (ForwardingInfo, error) {
|
||||
func (r *sphinxHopIterator) HopPayload() (*Payload, error) {
|
||||
switch r.processedPacket.Payload.Type {
|
||||
|
||||
// If this is the legacy payload, then we'll extract the information
|
||||
// directly from the pre-populated ForwardingInstructions field.
|
||||
case sphinx.PayloadLegacy:
|
||||
fwdInst := r.processedPacket.ForwardingInstructions
|
||||
p := NewLegacyPayload(fwdInst)
|
||||
|
||||
return p.ForwardingInfo(), nil
|
||||
return NewLegacyPayload(fwdInst), nil
|
||||
|
||||
// Otherwise, if this is the TLV payload, then we'll make a new stream
|
||||
// to decode only what we need to make routing decisions.
|
||||
case sphinx.PayloadTLV:
|
||||
p, err := NewPayloadFromReader(bytes.NewReader(
|
||||
return NewPayloadFromReader(bytes.NewReader(
|
||||
r.processedPacket.Payload.Payload,
|
||||
))
|
||||
if err != nil {
|
||||
return ForwardingInfo{}, err
|
||||
}
|
||||
|
||||
return p.ForwardingInfo(), nil
|
||||
|
||||
default:
|
||||
return ForwardingInfo{}, fmt.Errorf("unknown "+
|
||||
"sphinx payload type: %v",
|
||||
return nil, fmt.Errorf("unknown sphinx payload type: %v",
|
||||
r.processedPacket.Payload.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// ExtraOnionBlob returns the additional EOB data (if available).
|
||||
func (r *sphinxHopIterator) ExtraOnionBlob() []byte {
|
||||
if r.processedPacket.Payload.Type == sphinx.PayloadLegacy {
|
||||
return nil
|
||||
}
|
||||
|
||||
return r.processedPacket.Payload.Payload
|
||||
}
|
||||
|
||||
// ExtractErrorEncrypter decodes and returns the ErrorEncrypter for this hop,
|
||||
// along with a failure code to signal if the decoding was successful. The
|
||||
// ErrorEncrypter is used to encrypt errors back to the sender in the event that
|
||||
|
||||
@@ -85,12 +85,13 @@ func TestSphinxHopIteratorForwardingInstructions(t *testing.T) {
|
||||
for i, testCase := range testCases {
|
||||
iterator.processedPacket = testCase.sphinxPacket
|
||||
|
||||
fwdInfo, err := iterator.ForwardingInstructions()
|
||||
pld, err := iterator.HopPayload()
|
||||
if err != nil {
|
||||
t.Fatalf("#%v: unable to extract forwarding "+
|
||||
"instructions: %v", i, err)
|
||||
}
|
||||
|
||||
fwdInfo := pld.ForwardingInfo()
|
||||
if fwdInfo != testCase.expectedFwdInfo {
|
||||
t.Fatalf("#%v: wrong fwding info: expected %v, got %v",
|
||||
i, spew.Sdump(testCase.expectedFwdInfo),
|
||||
|
||||
@@ -81,6 +81,10 @@ type Payload struct {
|
||||
// FwdInfo holds the basic parameters required for HTLC forwarding, e.g.
|
||||
// amount, cltv, and next hop.
|
||||
FwdInfo ForwardingInfo
|
||||
|
||||
// MPP holds the info provided in an option_mpp record when parsed from
|
||||
// a TLV onion payload.
|
||||
MPP *record.MPP
|
||||
}
|
||||
|
||||
// NewLegacyPayload builds a Payload from the amount, cltv, and next hop
|
||||
@@ -105,12 +109,14 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) {
|
||||
cid uint64
|
||||
amt uint64
|
||||
cltv uint32
|
||||
mpp = &record.MPP{}
|
||||
)
|
||||
|
||||
tlvStream, err := tlv.NewStream(
|
||||
record.NewAmtToFwdRecord(&amt),
|
||||
record.NewLockTimeRecord(&cltv),
|
||||
record.NewNextHopIDRecord(&cid),
|
||||
mpp.Record(),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -151,6 +157,12 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If no MPP field was parsed, set the MPP field on the resulting
|
||||
// payload to nil.
|
||||
if _, ok := parsedTypes[record.MPPOnionType]; !ok {
|
||||
mpp = nil
|
||||
}
|
||||
|
||||
return &Payload{
|
||||
FwdInfo: ForwardingInfo{
|
||||
Network: BitcoinNetwork,
|
||||
@@ -158,6 +170,7 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) {
|
||||
AmountToForward: lnwire.MilliSatoshi(amt),
|
||||
OutgoingCTLV: cltv,
|
||||
},
|
||||
MPP: mpp,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -179,6 +192,7 @@ func ValidateParsedPayloadTypes(parsedTypes tlv.TypeSet,
|
||||
_, hasAmt := parsedTypes[record.AmtOnionType]
|
||||
_, hasLockTime := parsedTypes[record.LockTimeOnionType]
|
||||
_, hasNextHop := parsedTypes[record.NextHopOnionType]
|
||||
_, hasMPP := parsedTypes[record.MPPOnionType]
|
||||
|
||||
switch {
|
||||
|
||||
@@ -207,7 +221,21 @@ func ValidateParsedPayloadTypes(parsedTypes tlv.TypeSet,
|
||||
Violation: IncludedViolation,
|
||||
FinalHop: true,
|
||||
}
|
||||
|
||||
// Intermediate nodes should never receive MPP fields.
|
||||
case !isFinalHop && hasMPP:
|
||||
return ErrInvalidPayload{
|
||||
Type: record.MPPOnionType,
|
||||
Violation: IncludedViolation,
|
||||
FinalHop: isFinalHop,
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MultiPath returns the record corresponding the option_mpp parsed from the
|
||||
// onion payload.
|
||||
func (h *Payload) MultiPath() *record.MPP {
|
||||
return h.MPP
|
||||
}
|
||||
|
||||
@@ -6,13 +6,15 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/lightningnetwork/lnd/htlcswitch/hop"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"github.com/lightningnetwork/lnd/record"
|
||||
)
|
||||
|
||||
type decodePayloadTest struct {
|
||||
name string
|
||||
payload []byte
|
||||
expErr error
|
||||
name string
|
||||
payload []byte
|
||||
expErr error
|
||||
shouldHaveMPP bool
|
||||
}
|
||||
|
||||
var decodePayloadTests = []decodePayloadTest{
|
||||
@@ -79,9 +81,9 @@ var decodePayloadTests = []decodePayloadTest{
|
||||
},
|
||||
{
|
||||
name: "required type after omitted hop id",
|
||||
payload: []byte{0x02, 0x00, 0x04, 0x00, 0x08, 0x00},
|
||||
payload: []byte{0x02, 0x00, 0x04, 0x00, 0x0a, 0x00},
|
||||
expErr: hop.ErrInvalidPayload{
|
||||
Type: 8,
|
||||
Type: 10,
|
||||
Violation: hop.RequiredViolation,
|
||||
FinalHop: true,
|
||||
},
|
||||
@@ -89,10 +91,10 @@ var decodePayloadTests = []decodePayloadTest{
|
||||
{
|
||||
name: "required type after included hop id",
|
||||
payload: []byte{0x02, 0x00, 0x04, 0x00, 0x06, 0x08, 0x01, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0a, 0x00,
|
||||
},
|
||||
expErr: hop.ErrInvalidPayload{
|
||||
Type: 8,
|
||||
Type: 10,
|
||||
Violation: hop.RequiredViolation,
|
||||
FinalHop: false,
|
||||
},
|
||||
@@ -112,7 +114,7 @@ var decodePayloadTests = []decodePayloadTest{
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
},
|
||||
expErr: hop.ErrInvalidPayload{
|
||||
Type: 6,
|
||||
Type: record.NextHopOnionType,
|
||||
Violation: hop.IncludedViolation,
|
||||
FinalHop: true,
|
||||
},
|
||||
@@ -128,6 +130,60 @@ var decodePayloadTests = []decodePayloadTest{
|
||||
FinalHop: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "valid intermediate hop",
|
||||
payload: []byte{0x02, 0x00, 0x04, 0x00, 0x06, 0x08, 0x01, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
},
|
||||
expErr: nil,
|
||||
},
|
||||
{
|
||||
name: "valid final hop",
|
||||
payload: []byte{0x02, 0x00, 0x04, 0x00},
|
||||
expErr: nil,
|
||||
},
|
||||
{
|
||||
name: "intermediate hop with mpp",
|
||||
payload: []byte{
|
||||
// amount
|
||||
0x02, 0x00,
|
||||
// cltv
|
||||
0x04, 0x00,
|
||||
// next hop id
|
||||
0x06, 0x08,
|
||||
0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
// mpp
|
||||
0x08, 0x21,
|
||||
0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11,
|
||||
0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11,
|
||||
0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11,
|
||||
0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11,
|
||||
0x08,
|
||||
},
|
||||
expErr: hop.ErrInvalidPayload{
|
||||
Type: record.MPPOnionType,
|
||||
Violation: hop.IncludedViolation,
|
||||
FinalHop: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "final hop with mpp",
|
||||
payload: []byte{
|
||||
// amount
|
||||
0x02, 0x00,
|
||||
// cltv
|
||||
0x04, 0x00,
|
||||
// mpp
|
||||
0x08, 0x21,
|
||||
0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11,
|
||||
0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11,
|
||||
0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11,
|
||||
0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11,
|
||||
0x08,
|
||||
},
|
||||
expErr: nil,
|
||||
shouldHaveMPP: true,
|
||||
},
|
||||
}
|
||||
|
||||
// TestDecodeHopPayloadRecordValidation asserts that parsing the payloads in the
|
||||
@@ -142,9 +198,37 @@ func TestDecodeHopPayloadRecordValidation(t *testing.T) {
|
||||
}
|
||||
|
||||
func testDecodeHopPayloadValidation(t *testing.T, test decodePayloadTest) {
|
||||
_, err := hop.NewPayloadFromReader(bytes.NewReader(test.payload))
|
||||
var (
|
||||
testTotalMsat = lnwire.MilliSatoshi(8)
|
||||
testAddr = [32]byte{
|
||||
0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11,
|
||||
0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11,
|
||||
0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11,
|
||||
0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11,
|
||||
}
|
||||
)
|
||||
|
||||
p, err := hop.NewPayloadFromReader(bytes.NewReader(test.payload))
|
||||
if !reflect.DeepEqual(test.expErr, err) {
|
||||
t.Fatalf("expected error mismatch, want: %v, got: %v",
|
||||
test.expErr, err)
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Assert MPP fields if we expect them.
|
||||
if test.shouldHaveMPP {
|
||||
if p.MPP == nil {
|
||||
t.Fatalf("payload should have MPP record")
|
||||
}
|
||||
if p.MPP.TotalMsat() != testTotalMsat {
|
||||
t.Fatalf("invalid total msat")
|
||||
}
|
||||
if p.MPP.PaymentAddr() != testAddr {
|
||||
t.Fatalf("invalid payment addr")
|
||||
}
|
||||
} else if p.MPP != nil {
|
||||
t.Fatalf("unexpected MPP payload")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -27,7 +27,7 @@ type InvoiceDatabase interface {
|
||||
NotifyExitHopHtlc(payHash lntypes.Hash, paidAmount lnwire.MilliSatoshi,
|
||||
expiry uint32, currentHeight int32,
|
||||
circuitKey channeldb.CircuitKey, hodlChan chan<- interface{},
|
||||
eob []byte) (*invoices.HodlEvent, error)
|
||||
payload invoices.Payload) (*invoices.HodlEvent, error)
|
||||
|
||||
// CancelInvoice attempts to cancel the invoice corresponding to the
|
||||
// passed payment hash.
|
||||
|
||||
@@ -2642,7 +2642,7 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg,
|
||||
|
||||
heightNow := l.cfg.Switch.BestHeight()
|
||||
|
||||
fwdInfo, err := chanIterator.ForwardingInstructions()
|
||||
pld, err := chanIterator.HopPayload()
|
||||
if err != nil {
|
||||
// If we're unable to process the onion payload, or we
|
||||
// received invalid onion payload failure, then we
|
||||
@@ -2671,11 +2671,12 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg,
|
||||
continue
|
||||
}
|
||||
|
||||
fwdInfo := pld.ForwardingInfo()
|
||||
|
||||
switch fwdInfo.NextHop {
|
||||
case hop.Exit:
|
||||
updated, err := l.processExitHop(
|
||||
pd, obfuscator, fwdInfo, heightNow,
|
||||
chanIterator.ExtraOnionBlob(),
|
||||
pd, obfuscator, fwdInfo, heightNow, pld,
|
||||
)
|
||||
if err != nil {
|
||||
l.fail(LinkFailureError{code: ErrInternalError},
|
||||
@@ -2844,7 +2845,7 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg,
|
||||
// returns a boolean indicating whether the commitment tx needs an update.
|
||||
func (l *channelLink) processExitHop(pd *lnwallet.PaymentDescriptor,
|
||||
obfuscator hop.ErrorEncrypter, fwdInfo hop.ForwardingInfo,
|
||||
heightNow uint32, eob []byte) (bool, error) {
|
||||
heightNow uint32, payload invoices.Payload) (bool, error) {
|
||||
|
||||
// If hodl.ExitSettle is requested, we will not validate the final hop's
|
||||
// ADD, nor will we settle the corresponding invoice or respond with the
|
||||
@@ -2895,7 +2896,7 @@ func (l *channelLink) processExitHop(pd *lnwallet.PaymentDescriptor,
|
||||
|
||||
event, err := l.cfg.Registry.NotifyExitHopHtlc(
|
||||
invoiceHash, pd.Amount, pd.Timeout, int32(heightNow),
|
||||
circuitKey, l.hodlQueue.ChanIn(), eob,
|
||||
circuitKey, l.hodlQueue.ChanIn(), payload,
|
||||
)
|
||||
|
||||
switch err {
|
||||
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
"github.com/coreos/bbolt"
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
"github.com/go-errors/errors"
|
||||
sphinx "github.com/lightningnetwork/lightning-onion"
|
||||
"github.com/lightningnetwork/lnd/build"
|
||||
"github.com/lightningnetwork/lnd/channeldb"
|
||||
"github.com/lightningnetwork/lnd/contractcourt"
|
||||
@@ -563,7 +564,7 @@ func TestExitNodeTimelockPayloadMismatch(t *testing.T) {
|
||||
// per-hop payload for outgoing time lock to be the incorrect value.
|
||||
// The proper value of the outgoing CLTV should be the policy set by
|
||||
// the receiving node, instead we set it to be a random value.
|
||||
hops[0].OutgoingCTLV = 500
|
||||
hops[0].FwdInfo.OutgoingCTLV = 500
|
||||
firstHop := n.firstBobChannelLink.ShortChanID()
|
||||
_, err = makePayment(
|
||||
n.aliceServer, n.bobServer, firstHop, hops, amount, htlcAmt,
|
||||
@@ -616,7 +617,7 @@ func TestExitNodeAmountPayloadMismatch(t *testing.T) {
|
||||
// per-hop payload for amount to be the incorrect value. The proper
|
||||
// value of the amount to forward should be the amount that the
|
||||
// receiving node expects to receive.
|
||||
hops[0].AmountToForward = 1
|
||||
hops[0].FwdInfo.AmountToForward = 1
|
||||
firstHop := n.firstBobChannelLink.ShortChanID()
|
||||
_, err = makePayment(
|
||||
n.aliceServer, n.bobServer, firstHop, hops, amount, htlcAmt,
|
||||
@@ -4354,13 +4355,13 @@ func generateHtlcAndInvoice(t *testing.T,
|
||||
|
||||
htlcAmt := lnwire.NewMSatFromSatoshis(10000)
|
||||
htlcExpiry := testStartingHeight + testInvoiceCltvExpiry
|
||||
hops := []hop.ForwardingInfo{
|
||||
{
|
||||
Network: hop.BitcoinNetwork,
|
||||
NextHop: hop.Exit,
|
||||
AmountToForward: htlcAmt,
|
||||
OutgoingCTLV: uint32(htlcExpiry),
|
||||
},
|
||||
hops := []*hop.Payload{
|
||||
hop.NewLegacyPayload(&sphinx.HopData{
|
||||
Realm: [1]byte{}, // hop.BitcoinNetwork
|
||||
NextAddress: [8]byte{}, // hop.Exit,
|
||||
ForwardAmount: uint64(htlcAmt),
|
||||
OutgoingCltv: uint32(htlcExpiry),
|
||||
}),
|
||||
}
|
||||
blob, err := generateRoute(hops...)
|
||||
if err != nil {
|
||||
|
||||
@@ -265,16 +265,14 @@ func (s *mockServer) QuitSignal() <-chan struct{} {
|
||||
// mockHopIterator represents the test version of hop iterator which instead
|
||||
// of encrypting the path in onion blob just stores the path as a list of hops.
|
||||
type mockHopIterator struct {
|
||||
hops []hop.ForwardingInfo
|
||||
hops []*hop.Payload
|
||||
}
|
||||
|
||||
func newMockHopIterator(hops ...hop.ForwardingInfo) hop.Iterator {
|
||||
func newMockHopIterator(hops ...*hop.Payload) hop.Iterator {
|
||||
return &mockHopIterator{hops: hops}
|
||||
}
|
||||
|
||||
func (r *mockHopIterator) ForwardingInstructions() (
|
||||
hop.ForwardingInfo, error) {
|
||||
|
||||
func (r *mockHopIterator) HopPayload() (*hop.Payload, error) {
|
||||
h := r.hops[0]
|
||||
r.hops = r.hops[1:]
|
||||
return h, nil
|
||||
@@ -300,7 +298,8 @@ func (r *mockHopIterator) EncodeNextHop(w io.Writer) error {
|
||||
}
|
||||
|
||||
for _, hop := range r.hops {
|
||||
if err := encodeFwdInfo(w, &hop); err != nil {
|
||||
fwdInfo := hop.ForwardingInfo()
|
||||
if err := encodeFwdInfo(w, &fwdInfo); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -434,14 +433,22 @@ func (p *mockIteratorDecoder) DecodeHopIterator(r io.Reader, rHash []byte,
|
||||
}
|
||||
hopLength := binary.BigEndian.Uint32(b[:])
|
||||
|
||||
hops := make([]hop.ForwardingInfo, hopLength)
|
||||
hops := make([]*hop.Payload, hopLength)
|
||||
for i := uint32(0); i < hopLength; i++ {
|
||||
f := &hop.ForwardingInfo{}
|
||||
if err := decodeFwdInfo(r, f); err != nil {
|
||||
var f hop.ForwardingInfo
|
||||
if err := decodeFwdInfo(r, &f); err != nil {
|
||||
return nil, lnwire.CodeTemporaryChannelFailure
|
||||
}
|
||||
|
||||
hops[i] = *f
|
||||
var nextHopBytes [8]byte
|
||||
binary.BigEndian.PutUint64(nextHopBytes[:], f.NextHop.ToUint64())
|
||||
|
||||
hops[i] = hop.NewLegacyPayload(&sphinx.HopData{
|
||||
Realm: [1]byte{}, // hop.BitcoinNetwork
|
||||
NextAddress: nextHopBytes,
|
||||
ForwardAmount: uint64(f.AmountToForward),
|
||||
OutgoingCltv: f.OutgoingCTLV,
|
||||
})
|
||||
}
|
||||
|
||||
return newMockHopIterator(hops...), lnwire.CodeNone
|
||||
@@ -807,10 +814,11 @@ func (i *mockInvoiceRegistry) SettleHodlInvoice(preimage lntypes.Preimage) error
|
||||
func (i *mockInvoiceRegistry) NotifyExitHopHtlc(rhash lntypes.Hash,
|
||||
amt lnwire.MilliSatoshi, expiry uint32, currentHeight int32,
|
||||
circuitKey channeldb.CircuitKey, hodlChan chan<- interface{},
|
||||
eob []byte) (*invoices.HodlEvent, error) {
|
||||
payload invoices.Payload) (*invoices.HodlEvent, error) {
|
||||
|
||||
event, err := i.registry.NotifyExitHopHtlc(
|
||||
rhash, amt, expiry, currentHeight, circuitKey, hodlChan, eob,
|
||||
rhash, amt, expiry, currentHeight, circuitKey, hodlChan,
|
||||
payload,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -23,6 +23,7 @@ import (
|
||||
"github.com/btcsuite/fastsha256"
|
||||
"github.com/coreos/bbolt"
|
||||
"github.com/go-errors/errors"
|
||||
sphinx "github.com/lightningnetwork/lightning-onion"
|
||||
"github.com/lightningnetwork/lnd/channeldb"
|
||||
"github.com/lightningnetwork/lnd/contractcourt"
|
||||
"github.com/lightningnetwork/lnd/htlcswitch/hop"
|
||||
@@ -601,7 +602,7 @@ func generatePayment(invoiceAmt, htlcAmt lnwire.MilliSatoshi, timelock uint32,
|
||||
}
|
||||
|
||||
// generateRoute generates the path blob by given array of peers.
|
||||
func generateRoute(hops ...hop.ForwardingInfo) (
|
||||
func generateRoute(hops ...*hop.Payload) (
|
||||
[lnwire.OnionPacketSize]byte, error) {
|
||||
|
||||
var blob [lnwire.OnionPacketSize]byte
|
||||
@@ -642,13 +643,12 @@ type threeHopNetwork struct {
|
||||
// also the time lock value needed to route an HTLC with the target amount over
|
||||
// the specified path.
|
||||
func generateHops(payAmt lnwire.MilliSatoshi, startingHeight uint32,
|
||||
path ...*channelLink) (lnwire.MilliSatoshi, uint32,
|
||||
[]hop.ForwardingInfo) {
|
||||
path ...*channelLink) (lnwire.MilliSatoshi, uint32, []*hop.Payload) {
|
||||
|
||||
totalTimelock := startingHeight
|
||||
runningAmt := payAmt
|
||||
|
||||
hops := make([]hop.ForwardingInfo, len(path))
|
||||
hops := make([]*hop.Payload, len(path))
|
||||
for i := len(path) - 1; i >= 0; i-- {
|
||||
// If this is the last hop, then the next hop is the special
|
||||
// "exit node". Otherwise, we look to the "prior" hop.
|
||||
@@ -676,7 +676,7 @@ func generateHops(payAmt lnwire.MilliSatoshi, startingHeight uint32,
|
||||
amount := payAmt
|
||||
if i != len(path)-1 {
|
||||
prevHop := hops[i+1]
|
||||
prevAmount := prevHop.AmountToForward
|
||||
prevAmount := prevHop.ForwardingInfo().AmountToForward
|
||||
|
||||
fee := ExpectedFee(path[i].cfg.FwrdingPolicy, prevAmount)
|
||||
runningAmt += fee
|
||||
@@ -687,12 +687,15 @@ func generateHops(payAmt lnwire.MilliSatoshi, startingHeight uint32,
|
||||
amount = runningAmt - fee
|
||||
}
|
||||
|
||||
hops[i] = hop.ForwardingInfo{
|
||||
Network: hop.BitcoinNetwork,
|
||||
NextHop: nextHop,
|
||||
AmountToForward: amount,
|
||||
OutgoingCTLV: timeLock,
|
||||
}
|
||||
var nextHopBytes [8]byte
|
||||
binary.BigEndian.PutUint64(nextHopBytes[:], nextHop.ToUint64())
|
||||
|
||||
hops[i] = hop.NewLegacyPayload(&sphinx.HopData{
|
||||
Realm: [1]byte{}, // hop.BitcoinNetwork
|
||||
NextAddress: nextHopBytes,
|
||||
ForwardAmount: uint64(amount),
|
||||
OutgoingCltv: timeLock,
|
||||
})
|
||||
}
|
||||
|
||||
return runningAmt, totalTimelock, hops
|
||||
@@ -739,7 +742,7 @@ func waitForPayFuncResult(payFunc func() error, d time.Duration) error {
|
||||
// * from Alice to Carol through the Bob
|
||||
// * from Alice to some another peer through the Bob
|
||||
func makePayment(sendingPeer, receivingPeer lnpeer.Peer,
|
||||
firstHop lnwire.ShortChannelID, hops []hop.ForwardingInfo,
|
||||
firstHop lnwire.ShortChannelID, hops []*hop.Payload,
|
||||
invoiceAmt, htlcAmt lnwire.MilliSatoshi,
|
||||
timelock uint32) *paymentResponse {
|
||||
|
||||
@@ -773,7 +776,7 @@ func makePayment(sendingPeer, receivingPeer lnpeer.Peer,
|
||||
// preparePayment creates an invoice at the receivingPeer and returns a function
|
||||
// that, when called, launches the payment from the sendingPeer.
|
||||
func preparePayment(sendingPeer, receivingPeer lnpeer.Peer,
|
||||
firstHop lnwire.ShortChannelID, hops []hop.ForwardingInfo,
|
||||
firstHop lnwire.ShortChannelID, hops []*hop.Payload,
|
||||
invoiceAmt, htlcAmt lnwire.MilliSatoshi,
|
||||
timelock uint32) (*channeldb.Invoice, func() error, error) {
|
||||
|
||||
@@ -1265,7 +1268,7 @@ func (n *twoHopNetwork) stop() {
|
||||
}
|
||||
|
||||
func (n *twoHopNetwork) makeHoldPayment(sendingPeer, receivingPeer lnpeer.Peer,
|
||||
firstHop lnwire.ShortChannelID, hops []hop.ForwardingInfo,
|
||||
firstHop lnwire.ShortChannelID, hops []*hop.Payload,
|
||||
invoiceAmt, htlcAmt lnwire.MilliSatoshi,
|
||||
timelock uint32, preimage lntypes.Preimage) chan error {
|
||||
|
||||
|
||||
11
invoices/interface.go
Normal file
11
invoices/interface.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package invoices
|
||||
|
||||
import "github.com/lightningnetwork/lnd/record"
|
||||
|
||||
// Payload abstracts access to any additional fields provided in the final hop's
|
||||
// TLV onion payload.
|
||||
type Payload interface {
|
||||
// MultiPath returns the record corresponding the option_mpp parsed from
|
||||
// the onion payload.
|
||||
MultiPath() *record.MPP
|
||||
}
|
||||
@@ -429,7 +429,7 @@ func (i *InvoiceRegistry) LookupInvoice(rHash lntypes.Hash) (channeldb.Invoice,
|
||||
func (i *InvoiceRegistry) NotifyExitHopHtlc(rHash lntypes.Hash,
|
||||
amtPaid lnwire.MilliSatoshi, expiry uint32, currentHeight int32,
|
||||
circuitKey channeldb.CircuitKey, hodlChan chan<- interface{},
|
||||
eob []byte) (*HodlEvent, error) {
|
||||
payload Payload) (*HodlEvent, error) {
|
||||
|
||||
i.Lock()
|
||||
defer i.Unlock()
|
||||
|
||||
@@ -13,7 +13,9 @@ import (
|
||||
"github.com/btcsuite/btcd/chaincfg"
|
||||
"github.com/btcsuite/btcutil"
|
||||
"github.com/lightningnetwork/lnd/lnrpc"
|
||||
"github.com/lightningnetwork/lnd/lntypes"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"github.com/lightningnetwork/lnd/record"
|
||||
"github.com/lightningnetwork/lnd/routing"
|
||||
"github.com/lightningnetwork/lnd/routing/route"
|
||||
"github.com/lightningnetwork/lnd/tlv"
|
||||
@@ -352,6 +354,17 @@ func (r *RouterBackend) MarshallRoute(route *route.Route) (*lnrpc.Route, error)
|
||||
chanCapacity = incomingAmt.ToSatoshis()
|
||||
}
|
||||
|
||||
// Extract the MPP fields if present on this hop.
|
||||
var mpp *lnrpc.MPPRecord
|
||||
if hop.MPP != nil {
|
||||
addr := hop.MPP.PaymentAddr()
|
||||
|
||||
mpp = &lnrpc.MPPRecord{
|
||||
PaymentAddr: addr[:],
|
||||
TotalAmtMsat: int64(hop.MPP.TotalMsat()),
|
||||
}
|
||||
}
|
||||
|
||||
resp.Hops[i] = &lnrpc.Hop{
|
||||
ChanId: hop.ChannelID,
|
||||
ChanCapacity: int64(chanCapacity),
|
||||
@@ -364,6 +377,7 @@ func (r *RouterBackend) MarshallRoute(route *route.Route) (*lnrpc.Route, error)
|
||||
hop.PubKeyBytes[:],
|
||||
),
|
||||
TlvPayload: !hop.LegacyPayload,
|
||||
MppRecord: mpp,
|
||||
}
|
||||
incomingAmt = hop.AmtToForward
|
||||
}
|
||||
@@ -396,6 +410,11 @@ func (r *RouterBackend) UnmarshallHopByChannelLookup(hop *lnrpc.Hop,
|
||||
|
||||
var tlvRecords []tlv.Record
|
||||
|
||||
mpp, err := UnmarshalMPP(hop.MppRecord)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &route.Hop{
|
||||
OutgoingTimeLock: hop.Expiry,
|
||||
AmtToForward: lnwire.MilliSatoshi(hop.AmtToForwardMsat),
|
||||
@@ -403,6 +422,7 @@ func (r *RouterBackend) UnmarshallHopByChannelLookup(hop *lnrpc.Hop,
|
||||
ChannelID: hop.ChanId,
|
||||
TLVRecords: tlvRecords,
|
||||
LegacyPayload: !hop.TlvPayload,
|
||||
MPP: mpp,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -420,6 +440,11 @@ func UnmarshallKnownPubkeyHop(hop *lnrpc.Hop) (*route.Hop, error) {
|
||||
|
||||
var tlvRecords []tlv.Record
|
||||
|
||||
mpp, err := UnmarshalMPP(hop.MppRecord)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &route.Hop{
|
||||
OutgoingTimeLock: hop.Expiry,
|
||||
AmtToForward: lnwire.MilliSatoshi(hop.AmtToForwardMsat),
|
||||
@@ -427,6 +452,7 @@ func UnmarshallKnownPubkeyHop(hop *lnrpc.Hop) (*route.Hop, error) {
|
||||
ChannelID: hop.ChanId,
|
||||
TLVRecords: tlvRecords,
|
||||
LegacyPayload: !hop.TlvPayload,
|
||||
MPP: mpp,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -712,3 +738,45 @@ func ValidateCLTVLimit(val, max uint32) (uint32, error) {
|
||||
return val, nil
|
||||
}
|
||||
}
|
||||
|
||||
// UnmarshalMPP accepts the mpp_total_amt_msat and mpp_payment_addr fields from
|
||||
// an RPC request and converts into an record.MPP object. An error is returned
|
||||
// if the payment address is not 0 or 32 bytes. If the total amount and payment
|
||||
// address are zero-value, the return value will be nil signaling there is no
|
||||
// MPP record to attach to this hop. Otherwise, a non-nil reocrd will be
|
||||
// contained combining the provided values.
|
||||
func UnmarshalMPP(reqMPP *lnrpc.MPPRecord) (*record.MPP, error) {
|
||||
// If no MPP record was submitted, assume the user wants to send a
|
||||
// regular payment.
|
||||
if reqMPP == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
reqTotal := reqMPP.TotalAmtMsat
|
||||
reqAddr := reqMPP.PaymentAddr
|
||||
|
||||
switch {
|
||||
|
||||
// No MPP fields were provided.
|
||||
case reqTotal == 0 && len(reqAddr) == 0:
|
||||
return nil, fmt.Errorf("missing total_msat and payment_addr")
|
||||
|
||||
// Total is present, but payment address is missing.
|
||||
case reqTotal > 0 && len(reqAddr) == 0:
|
||||
return nil, fmt.Errorf("missing payment_addr")
|
||||
|
||||
// Payment address is present, but total is missing.
|
||||
case reqTotal == 0 && len(reqAddr) > 0:
|
||||
return nil, fmt.Errorf("missing total_msat")
|
||||
}
|
||||
|
||||
addr, err := lntypes.MakeHash(reqAddr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to parse "+
|
||||
"payment_addr: %v", err)
|
||||
}
|
||||
|
||||
total := lnwire.MilliSatoshi(reqTotal)
|
||||
|
||||
return record.NewMPP(total, addr), nil
|
||||
}
|
||||
|
||||
@@ -180,3 +180,123 @@ func (m *mockMissionControl) GetPairHistorySnapshot(fromNode,
|
||||
|
||||
return routing.TimedPairResult{}
|
||||
}
|
||||
|
||||
type mppOutcome byte
|
||||
|
||||
const (
|
||||
valid mppOutcome = iota
|
||||
invalid
|
||||
nompp
|
||||
)
|
||||
|
||||
type unmarshalMPPTest struct {
|
||||
name string
|
||||
mpp *lnrpc.MPPRecord
|
||||
outcome mppOutcome
|
||||
}
|
||||
|
||||
// TestUnmarshalMPP checks both positive and negative cases of UnmarshalMPP to
|
||||
// assert that an MPP record is only returned when both fields are properly
|
||||
// specified. It also asserts that zero-values for both inputs is also valid,
|
||||
// but returns a nil record.
|
||||
func TestUnmarshalMPP(t *testing.T) {
|
||||
tests := []unmarshalMPPTest{
|
||||
{
|
||||
name: "nil record",
|
||||
mpp: nil,
|
||||
outcome: nompp,
|
||||
},
|
||||
{
|
||||
name: "invalid total or addr",
|
||||
mpp: &lnrpc.MPPRecord{
|
||||
PaymentAddr: nil,
|
||||
TotalAmtMsat: 0,
|
||||
},
|
||||
outcome: invalid,
|
||||
},
|
||||
{
|
||||
name: "valid total only",
|
||||
mpp: &lnrpc.MPPRecord{
|
||||
PaymentAddr: nil,
|
||||
TotalAmtMsat: 8,
|
||||
},
|
||||
outcome: invalid,
|
||||
},
|
||||
{
|
||||
name: "valid addr only",
|
||||
mpp: &lnrpc.MPPRecord{
|
||||
PaymentAddr: bytes.Repeat([]byte{0x02}, 32),
|
||||
TotalAmtMsat: 0,
|
||||
},
|
||||
outcome: invalid,
|
||||
},
|
||||
{
|
||||
name: "valid total and invalid addr",
|
||||
mpp: &lnrpc.MPPRecord{
|
||||
PaymentAddr: []byte{0x02},
|
||||
TotalAmtMsat: 8,
|
||||
},
|
||||
outcome: invalid,
|
||||
},
|
||||
{
|
||||
name: "valid total and valid addr",
|
||||
mpp: &lnrpc.MPPRecord{
|
||||
PaymentAddr: bytes.Repeat([]byte{0x02}, 32),
|
||||
TotalAmtMsat: 8,
|
||||
},
|
||||
outcome: valid,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
testUnmarshalMPP(t, test)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func testUnmarshalMPP(t *testing.T, test unmarshalMPPTest) {
|
||||
mpp, err := UnmarshalMPP(test.mpp)
|
||||
switch test.outcome {
|
||||
|
||||
// Valid arguments should result in no error, a non-nil MPP record, and
|
||||
// the fields should be set correctly.
|
||||
case valid:
|
||||
if err != nil {
|
||||
t.Fatalf("unable to parse mpp record: %v", err)
|
||||
}
|
||||
if mpp == nil {
|
||||
t.Fatalf("mpp payload should be non-nil")
|
||||
}
|
||||
if int64(mpp.TotalMsat()) != test.mpp.TotalAmtMsat {
|
||||
t.Fatalf("incorrect total msat")
|
||||
}
|
||||
addr := mpp.PaymentAddr()
|
||||
if !bytes.Equal(addr[:], test.mpp.PaymentAddr) {
|
||||
t.Fatalf("incorrect payment addr")
|
||||
}
|
||||
|
||||
// Invalid arguments should produce a failure and nil MPP record.
|
||||
case invalid:
|
||||
if err == nil {
|
||||
t.Fatalf("expected failure for invalid mpp")
|
||||
}
|
||||
if mpp != nil {
|
||||
t.Fatalf("mpp payload should be nil for failure")
|
||||
}
|
||||
|
||||
// Arguments that produce no MPP field should return no error and no MPP
|
||||
// record.
|
||||
case nompp:
|
||||
if err != nil {
|
||||
t.Fatalf("failure for args resulting for no-mpp")
|
||||
}
|
||||
if mpp != nil {
|
||||
t.Fatalf("mpp payload should be nil for no-mpp")
|
||||
}
|
||||
|
||||
default:
|
||||
t.Fatalf("test case has non-standard outcome")
|
||||
}
|
||||
}
|
||||
|
||||
1288
lnrpc/rpc.pb.go
1288
lnrpc/rpc.pb.go
File diff suppressed because it is too large
Load Diff
@@ -1836,6 +1836,32 @@ message Hop {
|
||||
TLV format.
|
||||
*/
|
||||
bool tlv_payload = 9 [json_name = "tlv_payload"];
|
||||
|
||||
/**
|
||||
An optional TLV record tha singals the use of an MPP payment. If present,
|
||||
the receiver will enforce that that the same mpp_record is included in the
|
||||
final hop payload of all non-zero payments in the HTLC set. If empty, a
|
||||
regular single-shot payment is or was attempted.
|
||||
*/
|
||||
MPPRecord mpp_record = 10 [json_name = "mpp_record"];
|
||||
}
|
||||
|
||||
message MPPRecord {
|
||||
/**
|
||||
A unique, random identifier used to authenticate the sender as the intended
|
||||
payer of a multi-path payment. The payment_addr must be the same for all
|
||||
subpayments, and match the payment_addr provided in the receiver's invoice.
|
||||
The same payment_addr must be used on all subpayments.
|
||||
*/
|
||||
bytes payment_addr = 11 [json_name = "payment_addr"];
|
||||
|
||||
/**
|
||||
The total amount in milli-satoshis being sent as part of a larger multi-path
|
||||
payment. The caller is responsible for ensuring subpayments to the same node
|
||||
and payment_hash sum exactly to total_amt_msat. The same
|
||||
total_amt_msat must be used on all subpayments.
|
||||
*/
|
||||
int64 total_amt_msat = 10 [json_name = "total_amt_msat"];
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -2530,6 +2530,10 @@
|
||||
"type": "boolean",
|
||||
"format": "boolean",
|
||||
"description": "* \nIf set to true, then this hop will be encoded using the new variable length\nTLV format."
|
||||
},
|
||||
"mpp_record": {
|
||||
"$ref": "#/definitions/lnrpcMPPRecord",
|
||||
"description": "*\nAn optional TLV record tha singals the use of an MPP payment. If present,\nthe receiver will enforce that that the same mpp_record is included in the\nfinal hop payload of all non-zero payments in the HTLC set. If empty, a\nregular single-shot payment is or was attempted."
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -2873,6 +2877,21 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"lnrpcMPPRecord": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"payment_addr": {
|
||||
"type": "string",
|
||||
"format": "byte",
|
||||
"description": "*\nA unique, random identifier used to authenticate the sender as the intended\npayer of a multi-path payment. The payment_addr must be the same for all\nsubpayments, and match the payment_addr provided in the receiver's invoice.\nThe same payment_addr must be used on all subpayments."
|
||||
},
|
||||
"total_amt_msat": {
|
||||
"type": "string",
|
||||
"format": "int64",
|
||||
"description": "*\nThe total amount in milli-satoshis being sent as part of a larger multi-path\npayment. The caller is responsible for ensuring subpayments to the same node\nand payment_hash sum exactly to total_amt_msat. The same\ntotal_amt_msat must be used on all subpayments."
|
||||
}
|
||||
}
|
||||
},
|
||||
"lnrpcMacaroonPermission": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
||||
@@ -106,7 +106,6 @@ func (h *harnessTest) Fatalf(format string, a ...interface{}) {
|
||||
// RunTestCase executes a harness test case. Any errors or panics will be
|
||||
// represented as fatal.
|
||||
func (h *harnessTest) RunTestCase(testCase *testCase) {
|
||||
|
||||
h.testCase = testCase
|
||||
defer func() {
|
||||
h.testCase = nil
|
||||
@@ -4406,43 +4405,135 @@ func testMultiHopPayments(net *lntest.NetworkHarness, t *harnessTest) {
|
||||
closeChannelAndAssert(ctxt, t, net, carol, chanPointCarol, false)
|
||||
}
|
||||
|
||||
// testSingleHopSendToRoute tests that payments are properly processed
|
||||
// through a provided route with a single hop. We'll create the
|
||||
// following network topology:
|
||||
// Alice --100k--> Bob
|
||||
// We'll query the daemon for routes from Alice to Bob and then
|
||||
// send payments through the route.
|
||||
type singleHopSendToRouteCase struct {
|
||||
name string
|
||||
|
||||
// streaming tests streaming SendToRoute if true, otherwise tests
|
||||
// synchronous SenToRoute.
|
||||
streaming bool
|
||||
|
||||
// routerrpc submits the request to the routerrpc subserver if true,
|
||||
// otherwise submits to the main rpc server.
|
||||
routerrpc bool
|
||||
|
||||
// mpp sets the MPP fields on the request if true, otherwise submits a
|
||||
// regular payment.
|
||||
mpp bool
|
||||
}
|
||||
|
||||
var singleHopSendToRouteCases = []singleHopSendToRouteCase{
|
||||
{
|
||||
name: "regular main sync",
|
||||
},
|
||||
{
|
||||
name: "regular main stream",
|
||||
streaming: true,
|
||||
},
|
||||
{
|
||||
name: "regular routerrpc sync",
|
||||
routerrpc: true,
|
||||
},
|
||||
{
|
||||
name: "mpp main sync",
|
||||
mpp: true,
|
||||
},
|
||||
{
|
||||
name: "mpp main stream",
|
||||
streaming: true,
|
||||
mpp: true,
|
||||
},
|
||||
{
|
||||
name: "mpp routerrpc sync",
|
||||
routerrpc: true,
|
||||
mpp: true,
|
||||
},
|
||||
}
|
||||
|
||||
// testSingleHopSendToRoute tests that payments are properly processed through a
|
||||
// provided route with a single hop. We'll create the following network
|
||||
// topology:
|
||||
// Carol --100k--> Dave
|
||||
// We'll query the daemon for routes from Carol to Dave and then send payments
|
||||
// by feeding the route back into the various SendToRoute RPC methods. Here we
|
||||
// test all three SendToRoute endpoints, forcing each to perform both a regular
|
||||
// payment and an MPP payment.
|
||||
func testSingleHopSendToRoute(net *lntest.NetworkHarness, t *harnessTest) {
|
||||
ctxb := context.Background()
|
||||
for _, test := range singleHopSendToRouteCases {
|
||||
test := test
|
||||
|
||||
t.t.Run(test.name, func(t1 *testing.T) {
|
||||
ht := newHarnessTest(t1, t.lndHarness)
|
||||
ht.RunTestCase(&testCase{
|
||||
name: test.name,
|
||||
test: func(_ *lntest.NetworkHarness, tt *harnessTest) {
|
||||
testSingleHopSendToRouteCase(net, tt, test)
|
||||
},
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func testSingleHopSendToRouteCase(net *lntest.NetworkHarness, t *harnessTest,
|
||||
test singleHopSendToRouteCase) {
|
||||
|
||||
const chanAmt = btcutil.Amount(100000)
|
||||
const paymentAmtSat = 1000
|
||||
const numPayments = 5
|
||||
const amountPaid = int64(numPayments * paymentAmtSat)
|
||||
|
||||
ctxb := context.Background()
|
||||
var networkChans []*lnrpc.ChannelPoint
|
||||
|
||||
// Open a channel with 100k satoshis between Alice and Bob with Alice
|
||||
// Create Carol and Dave, then establish a channel between them. Carol
|
||||
// is the sole funder of the channel with 100k satoshis. The network
|
||||
// topology should look like:
|
||||
// Carol -> 100k -> Dave
|
||||
carol, err := net.NewNode("Carol", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create new nodes: %v", err)
|
||||
}
|
||||
defer shutdownAndAssert(net, t, carol)
|
||||
|
||||
dave, err := net.NewNode("Dave", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create new nodes: %v", err)
|
||||
}
|
||||
defer shutdownAndAssert(net, t, dave)
|
||||
|
||||
ctxt, _ := context.WithTimeout(ctxb, defaultTimeout)
|
||||
if err := net.ConnectNodes(ctxt, carol, dave); err != nil {
|
||||
t.Fatalf("unable to connect carol to dave: %v", err)
|
||||
}
|
||||
ctxt, _ = context.WithTimeout(ctxb, defaultTimeout)
|
||||
err = net.SendCoins(ctxt, btcutil.SatoshiPerBitcoin, carol)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to send coins to carol: %v", err)
|
||||
}
|
||||
|
||||
// Open a channel with 100k satoshis between Carol and Dave with Carol
|
||||
// being the sole funder of the channel.
|
||||
ctxt, _ := context.WithTimeout(ctxb, channelOpenTimeout)
|
||||
chanPointAlice := openChannelAndAssert(
|
||||
ctxt, t, net, net.Alice, net.Bob,
|
||||
ctxt, _ = context.WithTimeout(ctxb, channelOpenTimeout)
|
||||
chanPointCarol := openChannelAndAssert(
|
||||
ctxt, t, net, carol, dave,
|
||||
lntest.OpenChannelParams{
|
||||
Amt: chanAmt,
|
||||
},
|
||||
)
|
||||
networkChans = append(networkChans, chanPointAlice)
|
||||
networkChans = append(networkChans, chanPointCarol)
|
||||
|
||||
aliceChanTXID, err := lnd.GetChanPointFundingTxid(chanPointAlice)
|
||||
carolChanTXID, err := lnd.GetChanPointFundingTxid(chanPointCarol)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to get txid: %v", err)
|
||||
}
|
||||
aliceFundPoint := wire.OutPoint{
|
||||
Hash: *aliceChanTXID,
|
||||
Index: chanPointAlice.OutputIndex,
|
||||
carolFundPoint := wire.OutPoint{
|
||||
Hash: *carolChanTXID,
|
||||
Index: chanPointCarol.OutputIndex,
|
||||
}
|
||||
|
||||
// Wait for all nodes to have seen all channels.
|
||||
nodes := []*lntest.HarnessNode{net.Alice, net.Bob}
|
||||
nodeNames := []string{"Alice", "Bob"}
|
||||
nodes := []*lntest.HarnessNode{carol, dave}
|
||||
for _, chanPoint := range networkChans {
|
||||
for i, node := range nodes {
|
||||
for _, node := range nodes {
|
||||
txid, err := lnd.GetChanPointFundingTxid(chanPoint)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to get txid: %v", err)
|
||||
@@ -4456,111 +4547,241 @@ func testSingleHopSendToRoute(net *lntest.NetworkHarness, t *harnessTest) {
|
||||
err = node.WaitForNetworkChannelOpen(ctxt, chanPoint)
|
||||
if err != nil {
|
||||
t.Fatalf("%s(%d): timeout waiting for "+
|
||||
"channel(%s) open: %v", nodeNames[i],
|
||||
"channel(%s) open: %v", node.Name(),
|
||||
node.NodeID, point, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Query for routes to pay from Alice to Bob.
|
||||
// We set FinalCltvDelta to 40 since by default QueryRoutes returns
|
||||
// the last hop with a final cltv delta of 9 where as the default in
|
||||
// htlcswitch is 40.
|
||||
const paymentAmt = 1000
|
||||
routesReq := &lnrpc.QueryRoutesRequest{
|
||||
PubKey: net.Bob.PubKeyStr,
|
||||
Amt: paymentAmt,
|
||||
FinalCltvDelta: lnd.DefaultBitcoinTimeLockDelta,
|
||||
}
|
||||
ctxt, _ = context.WithTimeout(ctxb, defaultTimeout)
|
||||
routes, err := net.Alice.QueryRoutes(ctxt, routesReq)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to get route: %v", err)
|
||||
}
|
||||
|
||||
// Create 5 invoices for Bob, which expect a payment from Alice for 1k
|
||||
// satoshis with a different preimage each time.
|
||||
const numPayments = 5
|
||||
// Create invoices for Dave, which expect a payment from Carol.
|
||||
_, rHashes, _, err := createPayReqs(
|
||||
net.Bob, paymentAmt, numPayments,
|
||||
dave, paymentAmtSat, numPayments,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create pay reqs: %v", err)
|
||||
}
|
||||
|
||||
// We'll wait for all parties to recognize the new channels within the
|
||||
// network.
|
||||
// Query for routes to pay from Carol to Dave.
|
||||
// We set FinalCltvDelta to 40 since by default QueryRoutes returns
|
||||
// the last hop with a final cltv delta of 9 where as the default in
|
||||
// htlcswitch is 40.
|
||||
routesReq := &lnrpc.QueryRoutesRequest{
|
||||
PubKey: dave.PubKeyStr,
|
||||
Amt: paymentAmtSat,
|
||||
FinalCltvDelta: lnd.DefaultBitcoinTimeLockDelta,
|
||||
}
|
||||
ctxt, _ = context.WithTimeout(ctxb, defaultTimeout)
|
||||
err = net.Bob.WaitForNetworkChannelOpen(ctxt, chanPointAlice)
|
||||
routes, err := carol.QueryRoutes(ctxt, routesReq)
|
||||
if err != nil {
|
||||
t.Fatalf("alice didn't advertise her channel in time: %v", err)
|
||||
t.Fatalf("unable to get route from %s: %v",
|
||||
carol.Name(), err)
|
||||
}
|
||||
|
||||
time.Sleep(time.Millisecond * 50)
|
||||
// There should only be one route to try, so take the first item.
|
||||
r := routes.Routes[0]
|
||||
|
||||
// Using Alice as the source, pay to the 5 invoices from Carol created
|
||||
// above.
|
||||
ctxt, _ = context.WithTimeout(ctxb, defaultTimeout)
|
||||
alicePayStream, err := net.Alice.SendToRoute(ctxt)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create payment stream for alice: %v", err)
|
||||
}
|
||||
// Construct a closure that will set MPP fields on the route, which
|
||||
// allows us to test MPP payments.
|
||||
setMPPFields := func(i int) {
|
||||
addr := [32]byte{byte(i)}
|
||||
|
||||
for _, rHash := range rHashes {
|
||||
sendReq := &lnrpc.SendToRouteRequest{
|
||||
PaymentHash: rHash,
|
||||
Route: routes.Routes[0],
|
||||
hop := r.Hops[len(r.Hops)-1]
|
||||
hop.TlvPayload = true
|
||||
hop.MppRecord = &lnrpc.MPPRecord{
|
||||
PaymentAddr: addr[:],
|
||||
TotalAmtMsat: paymentAmtSat * 1000,
|
||||
}
|
||||
err := alicePayStream.Send(sendReq)
|
||||
}
|
||||
|
||||
// Construct closures for each of the payment types covered:
|
||||
// - main rpc server sync
|
||||
// - main rpc server streaming
|
||||
// - routerrpc server sync
|
||||
sendToRouteSync := func() {
|
||||
for i, rHash := range rHashes {
|
||||
// Populate the MPP fields for the final hop if we are
|
||||
// testing MPP payments.
|
||||
if test.mpp {
|
||||
setMPPFields(i)
|
||||
}
|
||||
|
||||
sendReq := &lnrpc.SendToRouteRequest{
|
||||
PaymentHash: rHash,
|
||||
Route: r,
|
||||
}
|
||||
ctxt, _ = context.WithTimeout(ctxb, defaultTimeout)
|
||||
resp, err := carol.SendToRouteSync(
|
||||
ctxt, sendReq,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to send to route for "+
|
||||
"%s: %v", carol.Name(), err)
|
||||
}
|
||||
if resp.PaymentError != "" {
|
||||
t.Fatalf("received payment error from %s: %v",
|
||||
carol.Name(), resp.PaymentError)
|
||||
}
|
||||
}
|
||||
}
|
||||
sendToRouteStream := func() {
|
||||
ctxt, _ = context.WithTimeout(ctxb, defaultTimeout)
|
||||
alicePayStream, err := carol.SendToRoute(ctxt)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to send payment: %v", err)
|
||||
t.Fatalf("unable to create payment stream for "+
|
||||
"carol: %v", err)
|
||||
}
|
||||
|
||||
for i, rHash := range rHashes {
|
||||
// Populate the MPP fields for the final hop if we are
|
||||
// testing MPP payments.
|
||||
if test.mpp {
|
||||
setMPPFields(i)
|
||||
}
|
||||
|
||||
sendReq := &lnrpc.SendToRouteRequest{
|
||||
PaymentHash: rHash,
|
||||
Route: routes.Routes[0],
|
||||
}
|
||||
err := alicePayStream.Send(sendReq)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("unable to send payment: %v", err)
|
||||
}
|
||||
|
||||
resp, err := alicePayStream.Recv()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to send payment: %v", err)
|
||||
}
|
||||
if resp.PaymentError != "" {
|
||||
t.Fatalf("received payment error: %v",
|
||||
resp.PaymentError)
|
||||
}
|
||||
}
|
||||
}
|
||||
sendToRouteRouterRPC := func() {
|
||||
for i, rHash := range rHashes {
|
||||
// Populate the MPP fields for the final hop if we are
|
||||
// testing MPP payments.
|
||||
if test.mpp {
|
||||
setMPPFields(i)
|
||||
}
|
||||
|
||||
sendReq := &routerrpc.SendToRouteRequest{
|
||||
PaymentHash: rHash,
|
||||
Route: r,
|
||||
}
|
||||
ctxt, _ = context.WithTimeout(ctxb, defaultTimeout)
|
||||
resp, err := carol.RouterClient.SendToRoute(
|
||||
ctxt, sendReq,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to send to route for "+
|
||||
"%s: %v", carol.Name(), err)
|
||||
}
|
||||
if resp.Failure != nil {
|
||||
t.Fatalf("received payment error from %s: %v",
|
||||
carol.Name(), resp.Failure)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for range rHashes {
|
||||
resp, err := alicePayStream.Recv()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to send payment: %v", err)
|
||||
}
|
||||
if resp.PaymentError != "" {
|
||||
t.Fatalf("received payment error: %v", resp.PaymentError)
|
||||
}
|
||||
// Using Carol as the node as the source, send the payments
|
||||
// synchronously via the the routerrpc's SendToRoute, or via the main RPC
|
||||
// server's SendToRoute streaming or sync calls.
|
||||
switch {
|
||||
case !test.routerrpc && test.streaming:
|
||||
sendToRouteStream()
|
||||
case !test.routerrpc && !test.streaming:
|
||||
sendToRouteSync()
|
||||
case test.routerrpc && !test.streaming:
|
||||
sendToRouteRouterRPC()
|
||||
default:
|
||||
t.Fatalf("routerrpc does not support streaming send_to_route")
|
||||
}
|
||||
|
||||
req := &lnrpc.ListPaymentsRequest{}
|
||||
// Verify that the payment's from Carol's PoV have the correct payment
|
||||
// hash and amount.
|
||||
ctxt, _ = context.WithTimeout(ctxt, defaultTimeout)
|
||||
paymentsResp, err := net.Alice.ListPayments(ctxt, req)
|
||||
paymentsResp, err := carol.ListPayments(
|
||||
ctxt, &lnrpc.ListPaymentsRequest{},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("error when obtaining Alice payments: %v", err)
|
||||
t.Fatalf("error when obtaining %s payments: %v",
|
||||
carol.Name(), err)
|
||||
}
|
||||
if len(paymentsResp.Payments) != 5 {
|
||||
if len(paymentsResp.Payments) != numPayments {
|
||||
t.Fatalf("incorrect number of payments, got %v, want %v",
|
||||
len(paymentsResp.Payments), 5)
|
||||
len(paymentsResp.Payments), numPayments)
|
||||
}
|
||||
|
||||
// Verify that the ListPayments displays the payment without an invoice
|
||||
// since the payment was completed with SendToRoute.
|
||||
for _, p := range paymentsResp.Payments {
|
||||
for i, p := range paymentsResp.Payments {
|
||||
// Assert that the payment hashes for each payment match up.
|
||||
rHashHex := hex.EncodeToString(rHashes[i])
|
||||
if p.PaymentHash != rHashHex {
|
||||
t.Fatalf("incorrect payment hash for payment %d, "+
|
||||
"want: %s got: %s",
|
||||
i, rHashHex, p.PaymentHash)
|
||||
}
|
||||
|
||||
// Assert that each payment has no invoice since the payment was
|
||||
// completed using SendToRoute.
|
||||
if p.PaymentRequest != "" {
|
||||
t.Fatalf("incorrect payreq, want: \"\", got: %v",
|
||||
p.PaymentRequest)
|
||||
t.Fatalf("incorrect payment request for payment: %d, "+
|
||||
"want: \"\", got: %s",
|
||||
i, p.PaymentRequest)
|
||||
}
|
||||
|
||||
// Assert the payment ammount is correct.
|
||||
if p.ValueSat != paymentAmtSat {
|
||||
t.Fatalf("incorrect payment amt for payment %d, "+
|
||||
"want: %d, got: %d",
|
||||
i, paymentAmtSat, p.ValueSat)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify that the invoices's from Dave's PoV have the correct payment
|
||||
// hash and amount.
|
||||
ctxt, _ = context.WithTimeout(ctxt, defaultTimeout)
|
||||
invoicesResp, err := dave.ListInvoices(
|
||||
ctxt, &lnrpc.ListInvoiceRequest{},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("error when obtaining %s payments: %v",
|
||||
dave.Name(), err)
|
||||
}
|
||||
if len(invoicesResp.Invoices) != numPayments {
|
||||
t.Fatalf("incorrect number of invoices, got %v, want %v",
|
||||
len(invoicesResp.Invoices), numPayments)
|
||||
}
|
||||
|
||||
for i, inv := range invoicesResp.Invoices {
|
||||
// Assert that the payment hashes match up.
|
||||
if !bytes.Equal(inv.RHash, rHashes[i]) {
|
||||
t.Fatalf("incorrect payment hash for invoice %d, "+
|
||||
"want: %x got: %x",
|
||||
i, rHashes[i], inv.RHash)
|
||||
}
|
||||
|
||||
// Assert that the amount paid to the invoice is correct.
|
||||
if inv.AmtPaidSat != paymentAmtSat {
|
||||
t.Fatalf("incorrect payment amt for invoice %d, "+
|
||||
"want: %d, got %d",
|
||||
i, paymentAmtSat, inv.AmtPaidSat)
|
||||
}
|
||||
}
|
||||
|
||||
// At this point all the channels within our proto network should be
|
||||
// shifted by 5k satoshis in the direction of Bob, the sink within the
|
||||
// shifted by 5k satoshis in the direction of Dave, the sink within the
|
||||
// payment flow generated above. The order of asserts corresponds to
|
||||
// increasing of time is needed to embed the HTLC in commitment
|
||||
// transaction, in channel Alice->Bob, order is Bob and then Alice.
|
||||
const amountPaid = int64(5000)
|
||||
assertAmountPaid(t, "Alice(local) => Bob(remote)", net.Bob,
|
||||
aliceFundPoint, int64(0), amountPaid)
|
||||
assertAmountPaid(t, "Alice(local) => Bob(remote)", net.Alice,
|
||||
aliceFundPoint, amountPaid, int64(0))
|
||||
// transaction, in channel Carol->Dave, order is Dave and then Carol.
|
||||
assertAmountPaid(t, "Carol(local) => Dave(remote)", dave,
|
||||
carolFundPoint, int64(0), amountPaid)
|
||||
assertAmountPaid(t, "Carol(local) => Dave(remote)", carol,
|
||||
carolFundPoint, amountPaid, int64(0))
|
||||
|
||||
ctxt, _ = context.WithTimeout(ctxb, channelCloseTimeout)
|
||||
closeChannelAndAssert(ctxt, t, net, net.Alice, chanPointAlice, false)
|
||||
closeChannelAndAssert(ctxt, t, net, carol, chanPointCarol, false)
|
||||
}
|
||||
|
||||
// testMultiHopSendToRoute tests that payments are properly processed
|
||||
|
||||
98
record/mpp.go
Normal file
98
record/mpp.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package record
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"github.com/lightningnetwork/lnd/tlv"
|
||||
)
|
||||
|
||||
// MPPOnionType is the type used in the onion to reference the MPP fields:
|
||||
// total_amt and payment_addr.
|
||||
const MPPOnionType tlv.Type = 8
|
||||
|
||||
// MPP is a record that encodes the fields necessary for multi-path payments.
|
||||
type MPP struct {
|
||||
// paymentAddr is a random, receiver-generated value used to avoid
|
||||
// collisions with concurrent payers.
|
||||
paymentAddr [32]byte
|
||||
|
||||
// totalMsat is the total value of the payment, potentially spread
|
||||
// across more than one HTLC.
|
||||
totalMsat lnwire.MilliSatoshi
|
||||
}
|
||||
|
||||
// NewMPP generates a new MPP record with the given total and payment address.
|
||||
func NewMPP(total lnwire.MilliSatoshi, addr [32]byte) *MPP {
|
||||
return &MPP{
|
||||
paymentAddr: addr,
|
||||
totalMsat: total,
|
||||
}
|
||||
}
|
||||
|
||||
// PaymentAddr returns the payment address contained in the MPP record.
|
||||
func (r *MPP) PaymentAddr() [32]byte {
|
||||
return r.paymentAddr
|
||||
}
|
||||
|
||||
// TotalMsat returns the total value of an MPP payment in msats.
|
||||
func (r *MPP) TotalMsat() lnwire.MilliSatoshi {
|
||||
return r.totalMsat
|
||||
}
|
||||
|
||||
// MPPEncoder writes the MPP record to the provided io.Writer.
|
||||
func MPPEncoder(w io.Writer, val interface{}, buf *[8]byte) error {
|
||||
if v, ok := val.(*MPP); ok {
|
||||
err := tlv.EBytes32(w, &v.paymentAddr, buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return tlv.ETUint64T(w, uint64(v.totalMsat), buf)
|
||||
}
|
||||
return tlv.NewTypeForEncodingErr(val, "MPP")
|
||||
}
|
||||
|
||||
const (
|
||||
// minMPPLength is the minimum length of a serialized MPP TLV record,
|
||||
// which occurs when the truncated encoding of total_amt_msat takes 0
|
||||
// bytes, leaving only the payment_addr.
|
||||
minMPPLength = 32
|
||||
|
||||
// maxMPPLength is the maximum length of a serialized MPP TLV record,
|
||||
// which occurs when the truncated encoding of total_amt_msat takes 8
|
||||
// bytes.
|
||||
maxMPPLength = 40
|
||||
)
|
||||
|
||||
// MPPDecoder reads the MPP record to the provided io.Reader.
|
||||
func MPPDecoder(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
|
||||
if v, ok := val.(*MPP); ok && minMPPLength <= l && l <= maxMPPLength {
|
||||
if err := tlv.DBytes32(r, &v.paymentAddr, buf, 32); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var total uint64
|
||||
if err := tlv.DTUint64(r, &total, buf, l-32); err != nil {
|
||||
return err
|
||||
}
|
||||
v.totalMsat = lnwire.MilliSatoshi(total)
|
||||
|
||||
return nil
|
||||
|
||||
}
|
||||
return tlv.NewTypeForDecodingErr(val, "MPP", l, maxMPPLength)
|
||||
}
|
||||
|
||||
// Record returns a tlv.Record that can be used to encode or decode this record.
|
||||
func (r *MPP) Record() tlv.Record {
|
||||
// Fixed-size, 32 byte payment address followed by truncated 64-bit
|
||||
// total msat.
|
||||
size := func() uint64 {
|
||||
return 32 + tlv.SizeTUint64(uint64(r.totalMsat))
|
||||
}
|
||||
|
||||
return tlv.MakeDynamicRecord(
|
||||
MPPOnionType, r, size, MPPEncoder, MPPDecoder,
|
||||
)
|
||||
}
|
||||
73
record/record_test.go
Normal file
73
record/record_test.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package record_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"github.com/lightningnetwork/lnd/record"
|
||||
"github.com/lightningnetwork/lnd/tlv"
|
||||
)
|
||||
|
||||
type recordEncDecTest struct {
|
||||
name string
|
||||
encRecord func() tlv.RecordProducer
|
||||
decRecord func() tlv.RecordProducer
|
||||
assert func(*testing.T, interface{})
|
||||
}
|
||||
|
||||
var (
|
||||
testTotal = lnwire.MilliSatoshi(45)
|
||||
testAddr = [32]byte{0x01, 0x02}
|
||||
)
|
||||
|
||||
var recordEncDecTests = []recordEncDecTest{
|
||||
{
|
||||
name: "mpp",
|
||||
encRecord: func() tlv.RecordProducer {
|
||||
return record.NewMPP(testTotal, testAddr)
|
||||
},
|
||||
decRecord: func() tlv.RecordProducer {
|
||||
return new(record.MPP)
|
||||
},
|
||||
assert: func(t *testing.T, r interface{}) {
|
||||
mpp := r.(*record.MPP)
|
||||
if mpp.TotalMsat() != testTotal {
|
||||
t.Fatal("incorrect total msat")
|
||||
}
|
||||
if mpp.PaymentAddr() != testAddr {
|
||||
t.Fatal("incorrect payment addr")
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// TestRecordEncodeDecode is a generic test framework for custom TLV records. It
|
||||
// asserts that records can encode and decode themselves, and that the value of
|
||||
// the original record matches the decoded record.
|
||||
func TestRecordEncodeDecode(t *testing.T) {
|
||||
for _, test := range recordEncDecTests {
|
||||
test := test
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
r := test.encRecord()
|
||||
r2 := test.decRecord()
|
||||
encStream := tlv.MustNewStream(r.Record())
|
||||
decStream := tlv.MustNewStream(r2.Record())
|
||||
|
||||
test.assert(t, r)
|
||||
|
||||
var b bytes.Buffer
|
||||
err := encStream.Encode(&b)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to encode record: %v", err)
|
||||
}
|
||||
|
||||
err = decStream.Decode(bytes.NewReader(b.Bytes()))
|
||||
if err != nil {
|
||||
t.Fatalf("unable to decode record: %v", err)
|
||||
}
|
||||
|
||||
test.assert(t, r2)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strconv"
|
||||
@@ -19,9 +20,17 @@ import (
|
||||
// VertexSize is the size of the array to store a vertex.
|
||||
const VertexSize = 33
|
||||
|
||||
// ErrNoRouteHopsProvided is returned when a caller attempts to construct a new
|
||||
// sphinx packet, but provides an empty set of hops for each route.
|
||||
var ErrNoRouteHopsProvided = fmt.Errorf("empty route hops provided")
|
||||
var (
|
||||
// ErrNoRouteHopsProvided is returned when a caller attempts to
|
||||
// construct a new sphinx packet, but provides an empty set of hops for
|
||||
// each route.
|
||||
ErrNoRouteHopsProvided = fmt.Errorf("empty route hops provided")
|
||||
|
||||
// ErrIntermediateMPPHop is returned when a hop tries to deliver an MPP
|
||||
// record to an intermediate hop, only final hops can receive MPP
|
||||
// records.
|
||||
ErrIntermediateMPPHop = errors.New("cannot send MPP to intermediate")
|
||||
)
|
||||
|
||||
// Vertex is a simple alias for the serialization of a compressed Bitcoin
|
||||
// public key.
|
||||
@@ -94,6 +103,10 @@ type Hop struct {
|
||||
// carries as a fee will be subtracted by the hop.
|
||||
AmtToForward lnwire.MilliSatoshi
|
||||
|
||||
// MPP encapsulates the data required for option_mpp. This field should
|
||||
// only be set for the final hop.
|
||||
MPP *record.MPP
|
||||
|
||||
// TLVRecords if non-nil are a set of additional TLV records that
|
||||
// should be included in the forwarding instructions for this node.
|
||||
TLVRecords []tlv.Record
|
||||
@@ -140,6 +153,17 @@ func (h *Hop) PackHopPayload(w io.Writer, nextChanID uint64) error {
|
||||
)
|
||||
}
|
||||
|
||||
// If an MPP record is destined for this hop, ensure that we only ever
|
||||
// attach it to the final hop. Otherwise the route was constructed
|
||||
// incorrectly.
|
||||
if h.MPP != nil {
|
||||
if nextChanID == 0 {
|
||||
records = append(records, h.MPP.Record())
|
||||
} else {
|
||||
return ErrIntermediateMPPHop
|
||||
}
|
||||
}
|
||||
|
||||
// Append any custom types destined for this hop.
|
||||
records = append(records, h.TLVRecords...)
|
||||
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
package route
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"github.com/lightningnetwork/lnd/record"
|
||||
)
|
||||
|
||||
// TestRouteTotalFees checks that a route reports the expected total fee.
|
||||
@@ -56,3 +58,38 @@ func TestRouteTotalFees(t *testing.T) {
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
var (
|
||||
testAmt = lnwire.MilliSatoshi(1000)
|
||||
testAddr = [32]byte{0x01, 0x02}
|
||||
)
|
||||
|
||||
// TestMPPHop asserts that a Hop will encode a non-nil to final nodes, and fail
|
||||
// when trying to send to intermediaries.
|
||||
func TestMPPHop(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
hop := Hop{
|
||||
ChannelID: 1,
|
||||
OutgoingTimeLock: 44,
|
||||
AmtToForward: testAmt,
|
||||
LegacyPayload: false,
|
||||
MPP: record.NewMPP(testAmt, testAddr),
|
||||
}
|
||||
|
||||
// Encoding an MPP record to an intermediate hop should result in a
|
||||
// failure.
|
||||
var b bytes.Buffer
|
||||
err := hop.PackHopPayload(&b, 2)
|
||||
if err != ErrIntermediateMPPHop {
|
||||
t.Fatalf("expected err: %v, got: %v",
|
||||
ErrIntermediateMPPHop, err)
|
||||
}
|
||||
|
||||
// Encoding an MPP record to a final hop should be successful.
|
||||
b.Reset()
|
||||
err = hop.PackHopPayload(&b, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("expected err: %v, got: %v", nil, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -43,6 +43,14 @@ func SizeVarBytes(e *[]byte) SizeFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// RecorderProducer is an interface for objects that can produce a Record object
|
||||
// capable of encoding and/or decoding the RecordProducer as a Record.
|
||||
type RecordProducer interface {
|
||||
// Record returns a Record that can be used to encode or decode the
|
||||
// backing object.
|
||||
Record() Record
|
||||
}
|
||||
|
||||
// Record holds the required information to encode or decode a TLV record.
|
||||
type Record struct {
|
||||
value interface{}
|
||||
@@ -77,6 +85,14 @@ func (f *Record) Encode(w io.Writer) error {
|
||||
return f.encoder(w, f.value, &b)
|
||||
}
|
||||
|
||||
// Decode read in the TLV record from the passed reader. This is useful when a
|
||||
// caller wants decode a *single* TLV record, outside the context of the Stream
|
||||
// struct.
|
||||
func (f *Record) Decode(r io.Reader, l uint64) error {
|
||||
var b [8]byte
|
||||
return f.decoder(r, f.value, &b, l)
|
||||
}
|
||||
|
||||
// MakePrimitiveRecord creates a record for common types.
|
||||
func MakePrimitiveRecord(typ Type, val interface{}) Record {
|
||||
var (
|
||||
|
||||
@@ -40,6 +40,15 @@ func ETUint16(w io.Writer, val interface{}, buf *[8]byte) error {
|
||||
return NewTypeForEncodingErr(val, "uint16")
|
||||
}
|
||||
|
||||
// ETUint16T is an Encoder for truncated uint16 values, where leading zeros will
|
||||
// be omitted. An error is returned if val is not a *uint16.
|
||||
func ETUint16T(w io.Writer, val uint16, buf *[8]byte) error {
|
||||
binary.BigEndian.PutUint16(buf[:2], val)
|
||||
numZeros := numLeadingZeroBytes16(val)
|
||||
_, err := w.Write(buf[numZeros:2])
|
||||
return err
|
||||
}
|
||||
|
||||
// DTUint16 is an Decoder for truncated uint16 values, where leading zeros will
|
||||
// be resurrected. An error is returned if val is not a *uint16.
|
||||
func DTUint16(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
|
||||
@@ -92,6 +101,15 @@ func ETUint32(w io.Writer, val interface{}, buf *[8]byte) error {
|
||||
return NewTypeForEncodingErr(val, "uint32")
|
||||
}
|
||||
|
||||
// ETUint32T is an Encoder for truncated uint32 values, where leading zeros will
|
||||
// be omitted. An error is returned if val is not a *uint32.
|
||||
func ETUint32T(w io.Writer, val uint32, buf *[8]byte) error {
|
||||
binary.BigEndian.PutUint32(buf[:4], val)
|
||||
numZeros := numLeadingZeroBytes32(val)
|
||||
_, err := w.Write(buf[numZeros:4])
|
||||
return err
|
||||
}
|
||||
|
||||
// DTUint32 is an Decoder for truncated uint32 values, where leading zeros will
|
||||
// be resurrected. An error is returned if val is not a *uint32.
|
||||
func DTUint32(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
|
||||
@@ -154,6 +172,15 @@ func ETUint64(w io.Writer, val interface{}, buf *[8]byte) error {
|
||||
return NewTypeForEncodingErr(val, "uint64")
|
||||
}
|
||||
|
||||
// ETUint64T is an Encoder for truncated uint64 values, where leading zeros will
|
||||
// be omitted. An error is returned if val is not a *uint64.
|
||||
func ETUint64T(w io.Writer, val uint64, buf *[8]byte) error {
|
||||
binary.BigEndian.PutUint64(buf[:], val)
|
||||
numZeros := numLeadingZeroBytes64(val)
|
||||
_, err := w.Write(buf[numZeros:])
|
||||
return err
|
||||
}
|
||||
|
||||
// DTUint64 is an Decoder for truncated uint64 values, where leading zeros will
|
||||
// be resurrected. An error is returned if val is not a *uint64.
|
||||
func DTUint64(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
|
||||
|
||||
@@ -60,6 +60,8 @@ func TestSizeTUint16(t *testing.T) {
|
||||
func TestTUint16(t *testing.T) {
|
||||
var buf [8]byte
|
||||
for _, test := range tuint16Tests {
|
||||
test := test
|
||||
|
||||
if len(test.bytes) != int(test.size) {
|
||||
t.Fatalf("invalid test case, "+
|
||||
"len(bytes)[%d] != size[%d]",
|
||||
@@ -68,6 +70,7 @@ func TestTUint16(t *testing.T) {
|
||||
|
||||
name := fmt.Sprintf("0x%x", test.value)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
// Test generic encoder.
|
||||
var b bytes.Buffer
|
||||
err := tlv.ETUint16(&b, &test.value, &buf)
|
||||
if err != nil {
|
||||
@@ -80,6 +83,19 @@ func TestTUint16(t *testing.T) {
|
||||
test.bytes, b.Bytes())
|
||||
}
|
||||
|
||||
// Test non-generic encoder.
|
||||
var b2 bytes.Buffer
|
||||
err = tlv.ETUint16T(&b2, test.value, &buf)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to encode tuint16: %v", err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(b2.Bytes(), test.bytes) {
|
||||
t.Fatalf("encoding mismatch, "+
|
||||
"expected: %x, got: %x",
|
||||
test.bytes, b2.Bytes())
|
||||
}
|
||||
|
||||
var value uint16
|
||||
r := bytes.NewReader(b.Bytes())
|
||||
err = tlv.DTUint16(r, &value, &buf, test.size)
|
||||
@@ -168,6 +184,8 @@ func TestSizeTUint32(t *testing.T) {
|
||||
func TestTUint32(t *testing.T) {
|
||||
var buf [8]byte
|
||||
for _, test := range tuint32Tests {
|
||||
test := test
|
||||
|
||||
if len(test.bytes) != int(test.size) {
|
||||
t.Fatalf("invalid test case, "+
|
||||
"len(bytes)[%d] != size[%d]",
|
||||
@@ -176,6 +194,7 @@ func TestTUint32(t *testing.T) {
|
||||
|
||||
name := fmt.Sprintf("0x%x", test.value)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
// Test generic encoder.
|
||||
var b bytes.Buffer
|
||||
err := tlv.ETUint32(&b, &test.value, &buf)
|
||||
if err != nil {
|
||||
@@ -188,6 +207,19 @@ func TestTUint32(t *testing.T) {
|
||||
test.bytes, b.Bytes())
|
||||
}
|
||||
|
||||
// Test non-generic encoder.
|
||||
var b2 bytes.Buffer
|
||||
err = tlv.ETUint32T(&b2, test.value, &buf)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to encode tuint32: %v", err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(b2.Bytes(), test.bytes) {
|
||||
t.Fatalf("encoding mismatch, "+
|
||||
"expected: %x, got: %x",
|
||||
test.bytes, b2.Bytes())
|
||||
}
|
||||
|
||||
var value uint32
|
||||
r := bytes.NewReader(b.Bytes())
|
||||
err = tlv.DTUint32(r, &value, &buf, test.size)
|
||||
@@ -322,6 +354,8 @@ func TestSizeTUint64(t *testing.T) {
|
||||
func TestTUint64(t *testing.T) {
|
||||
var buf [8]byte
|
||||
for _, test := range tuint64Tests {
|
||||
test := test
|
||||
|
||||
if len(test.bytes) != int(test.size) {
|
||||
t.Fatalf("invalid test case, "+
|
||||
"len(bytes)[%d] != size[%d]",
|
||||
@@ -330,6 +364,7 @@ func TestTUint64(t *testing.T) {
|
||||
|
||||
name := fmt.Sprintf("0x%x", test.value)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
// Test generic encoder.
|
||||
var b bytes.Buffer
|
||||
err := tlv.ETUint64(&b, &test.value, &buf)
|
||||
if err != nil {
|
||||
@@ -342,6 +377,19 @@ func TestTUint64(t *testing.T) {
|
||||
test.bytes, b.Bytes())
|
||||
}
|
||||
|
||||
// Test non-generic encoder.
|
||||
var b2 bytes.Buffer
|
||||
err = tlv.ETUint64T(&b2, test.value, &buf)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to encode tuint64: %v", err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(b2.Bytes(), test.bytes) {
|
||||
t.Fatalf("encoding mismatch, "+
|
||||
"expected: %x, got: %x",
|
||||
test.bytes, b2.Bytes())
|
||||
}
|
||||
|
||||
var value uint64
|
||||
r := bytes.NewReader(b.Bytes())
|
||||
err = tlv.DTUint64(r, &value, &buf, test.size)
|
||||
|
||||
Reference in New Issue
Block a user