diff --git a/lnwire/channel_announcement.go b/lnwire/channel_announcement.go index 9e57d17b..2b34c0f9 100644 --- a/lnwire/channel_announcement.go +++ b/lnwire/channel_announcement.go @@ -88,20 +88,51 @@ func (a *ChannelAnnouncement) Decode(r io.Reader, pver uint32) error { // // This is part of the lnwire.Message interface. func (a *ChannelAnnouncement) Encode(w *bytes.Buffer, pver uint32) error { - return WriteElements(w, - a.NodeSig1, - a.NodeSig2, - a.BitcoinSig1, - a.BitcoinSig2, - a.Features, - a.ChainHash[:], - a.ShortChannelID, - a.NodeID1, - a.NodeID2, - a.BitcoinKey1, - a.BitcoinKey2, - a.ExtraOpaqueData, - ) + if err := WriteSig(w, a.NodeSig1); err != nil { + return err + } + + if err := WriteSig(w, a.NodeSig2); err != nil { + return err + } + + if err := WriteSig(w, a.BitcoinSig1); err != nil { + return err + } + + if err := WriteSig(w, a.BitcoinSig2); err != nil { + return err + } + + if err := WriteRawFeatureVector(w, a.Features); err != nil { + return err + } + + if err := WriteBytes(w, a.ChainHash[:]); err != nil { + return err + } + + if err := WriteShortChannelID(w, a.ShortChannelID); err != nil { + return err + } + + if err := WriteBytes(w, a.NodeID1[:]); err != nil { + return err + } + + if err := WriteBytes(w, a.NodeID2[:]); err != nil { + return err + } + + if err := WriteBytes(w, a.BitcoinKey1[:]); err != nil { + return err + } + + if err := WriteBytes(w, a.BitcoinKey2[:]); err != nil { + return err + } + + return WriteBytes(w, a.ExtraOpaqueData) } // MsgType returns the integer uniquely identifying this message type on the @@ -116,20 +147,40 @@ func (a *ChannelAnnouncement) MsgType() MessageType { // be signed. func (a *ChannelAnnouncement) DataToSign() ([]byte, error) { // We should not include the signatures itself. - var w bytes.Buffer - err := WriteElements(&w, - a.Features, - a.ChainHash[:], - a.ShortChannelID, - a.NodeID1, - a.NodeID2, - a.BitcoinKey1, - a.BitcoinKey2, - a.ExtraOpaqueData, - ) - if err != nil { + b := make([]byte, 0, MaxMsgBody) + buf := bytes.NewBuffer(b) + + if err := WriteRawFeatureVector(buf, a.Features); err != nil { return nil, err } - return w.Bytes(), nil + if err := WriteBytes(buf, a.ChainHash[:]); err != nil { + return nil, err + } + + if err := WriteShortChannelID(buf, a.ShortChannelID); err != nil { + return nil, err + } + + if err := WriteBytes(buf, a.NodeID1[:]); err != nil { + return nil, err + } + + if err := WriteBytes(buf, a.NodeID2[:]); err != nil { + return nil, err + } + + if err := WriteBytes(buf, a.BitcoinKey1[:]); err != nil { + return nil, err + } + + if err := WriteBytes(buf, a.BitcoinKey2[:]); err != nil { + return nil, err + } + + if err := WriteBytes(buf, a.ExtraOpaqueData); err != nil { + return nil, err + } + + return buf.Bytes(), nil } diff --git a/lnwire/channel_reestablish.go b/lnwire/channel_reestablish.go index 44cc6f3f..0de16a36 100644 --- a/lnwire/channel_reestablish.go +++ b/lnwire/channel_reestablish.go @@ -77,12 +77,15 @@ var _ Message = (*ChannelReestablish)(nil) // // This is part of the lnwire.Message interface. func (a *ChannelReestablish) Encode(w *bytes.Buffer, pver uint32) error { - err := WriteElements(w, - a.ChanID, - a.NextLocalCommitHeight, - a.RemoteCommitTailHeight, - ) - if err != nil { + if err := WriteChannelID(w, a.ChanID); err != nil { + return err + } + + if err := WriteUint64(w, a.NextLocalCommitHeight); err != nil { + return err + } + + if err := WriteUint64(w, a.RemoteCommitTailHeight); err != nil { return err } @@ -94,15 +97,18 @@ func (a *ChannelReestablish) Encode(w *bytes.Buffer, pver uint32) error { // // NOTE: This is here primarily for the quickcheck tests, in // practice, we'll always populate this field. - return WriteElements(w, a.ExtraData) + return WriteBytes(w, a.ExtraData) } // Otherwise, we'll write out the remaining elements. - return WriteElements(w, - a.LastRemoteCommitSecret[:], - a.LocalUnrevokedCommitPoint, - a.ExtraData, - ) + if err := WriteBytes(w, a.LastRemoteCommitSecret[:]); err != nil { + return err + } + + if err := WritePublicKey(w, a.LocalUnrevokedCommitPoint); err != nil { + return err + } + return WriteBytes(w, a.ExtraData) } // Decode deserializes a serialized ChannelReestablish stored in the passed diff --git a/lnwire/channel_update.go b/lnwire/channel_update.go index e1bac9f9..7881f972 100644 --- a/lnwire/channel_update.go +++ b/lnwire/channel_update.go @@ -160,32 +160,57 @@ func (a *ChannelUpdate) Decode(r io.Reader, pver uint32) error { // // This is part of the lnwire.Message interface. func (a *ChannelUpdate) Encode(w *bytes.Buffer, pver uint32) error { - err := WriteElements(w, - a.Signature, - a.ChainHash[:], - a.ShortChannelID, - a.Timestamp, - a.MessageFlags, - a.ChannelFlags, - a.TimeLockDelta, - a.HtlcMinimumMsat, - a.BaseFee, - a.FeeRate, - ) - if err != nil { + if err := WriteSig(w, a.Signature); err != nil { + return err + } + + if err := WriteBytes(w, a.ChainHash[:]); err != nil { + return err + } + + if err := WriteShortChannelID(w, a.ShortChannelID); err != nil { + return err + } + + if err := WriteUint32(w, a.Timestamp); err != nil { + return err + } + + if err := WriteChanUpdateMsgFlags(w, a.MessageFlags); err != nil { + return err + } + + if err := WriteChanUpdateChanFlags(w, a.ChannelFlags); err != nil { + return err + } + + if err := WriteUint16(w, a.TimeLockDelta); err != nil { + return err + } + + if err := WriteMilliSatoshi(w, a.HtlcMinimumMsat); err != nil { + return err + } + + if err := WriteUint32(w, a.BaseFee); err != nil { + return err + } + + if err := WriteUint32(w, a.FeeRate); err != nil { return err } // Now append optional fields if they are set. Currently, the only // optional field is max HTLC. if a.MessageFlags.HasMaxHtlc() { - if err := WriteElements(w, a.HtlcMaximumMsat); err != nil { + err := WriteMilliSatoshi(w, a.HtlcMaximumMsat) + if err != nil { return err } } // Finally, append any extra opaque data. - return a.ExtraOpaqueData.Encode(w) + return WriteBytes(w, a.ExtraOpaqueData) } // MsgType returns the integer uniquely identifying this message type on the @@ -199,36 +224,58 @@ func (a *ChannelUpdate) MsgType() MessageType { // DataToSign is used to retrieve part of the announcement message which should // be signed. func (a *ChannelUpdate) DataToSign() ([]byte, error) { - // We should not include the signatures itself. - var w bytes.Buffer - err := WriteElements(&w, - a.ChainHash[:], - a.ShortChannelID, - a.Timestamp, - a.MessageFlags, - a.ChannelFlags, - a.TimeLockDelta, - a.HtlcMinimumMsat, - a.BaseFee, - a.FeeRate, - ) - if err != nil { + b := make([]byte, 0, MaxMsgBody) + buf := bytes.NewBuffer(b) + if err := WriteBytes(buf, a.ChainHash[:]); err != nil { + return nil, err + } + + if err := WriteShortChannelID(buf, a.ShortChannelID); err != nil { + return nil, err + } + + if err := WriteUint32(buf, a.Timestamp); err != nil { + return nil, err + } + + if err := WriteChanUpdateMsgFlags(buf, a.MessageFlags); err != nil { + return nil, err + } + + if err := WriteChanUpdateChanFlags(buf, a.ChannelFlags); err != nil { + return nil, err + } + + if err := WriteUint16(buf, a.TimeLockDelta); err != nil { + return nil, err + } + + if err := WriteMilliSatoshi(buf, a.HtlcMinimumMsat); err != nil { + return nil, err + } + + if err := WriteUint32(buf, a.BaseFee); err != nil { + return nil, err + } + + if err := WriteUint32(buf, a.FeeRate); err != nil { return nil, err } // Now append optional fields if they are set. Currently, the only // optional field is max HTLC. if a.MessageFlags.HasMaxHtlc() { - if err := WriteElements(&w, a.HtlcMaximumMsat); err != nil { + err := WriteMilliSatoshi(buf, a.HtlcMaximumMsat) + if err != nil { return nil, err } } // Finally, append any extra opaque data. - if err := a.ExtraOpaqueData.Encode(&w); err != nil { + if err := WriteBytes(buf, a.ExtraOpaqueData); err != nil { return nil, err } - return w.Bytes(), nil + return buf.Bytes(), nil } diff --git a/lnwire/extra_bytes.go b/lnwire/extra_bytes.go index f2dbec45..70554f4f 100644 --- a/lnwire/extra_bytes.go +++ b/lnwire/extra_bytes.go @@ -18,7 +18,7 @@ type ExtraOpaqueData []byte // Encode attempts to encode the raw extra bytes into the passed io.Writer. func (e *ExtraOpaqueData) Encode(w *bytes.Buffer) error { eBytes := []byte((*e)[:]) - if err := WriteElements(w, eBytes); err != nil { + if err := WriteBytes(w, eBytes); err != nil { return err } diff --git a/lnwire/lnwire_test.go b/lnwire/lnwire_test.go index 14082bfe..4475b338 100644 --- a/lnwire/lnwire_test.go +++ b/lnwire/lnwire_test.go @@ -232,7 +232,7 @@ func TestMaxOutPointIndex(t *testing.T) { } var b bytes.Buffer - if err := WriteElement(&b, op); err == nil { + if err := WriteOutPoint(&b, op); err == nil { t.Fatalf("write of outPoint should fail, index exceeds 16-bits") } } diff --git a/lnwire/query_short_chan_ids.go b/lnwire/query_short_chan_ids.go index 930becdb..323a936d 100644 --- a/lnwire/query_short_chan_ids.go +++ b/lnwire/query_short_chan_ids.go @@ -293,19 +293,18 @@ func decodeShortChanIDs(r io.Reader) (ShortChanIDEncoding, []ShortChannelID, err // This is part of the lnwire.Message interface. func (q *QueryShortChanIDs) Encode(w *bytes.Buffer, pver uint32) error { // First, we'll write out the chain hash. - err := WriteElements(w, q.ChainHash[:]) - if err != nil { + if err := WriteBytes(w, q.ChainHash[:]); err != nil { return err } // Base on our encoding type, we'll write out the set of short channel // ID's. - err = encodeShortChanIDs(w, q.EncodingType, q.ShortChanIDs, q.noSort) + err := encodeShortChanIDs(w, q.EncodingType, q.ShortChanIDs, q.noSort) if err != nil { return err } - return q.ExtraData.Encode(w) + return WriteBytes(w, q.ExtraData) } // encodeShortChanIDs encodes the passed short channel ID's into the passed @@ -332,20 +331,21 @@ func encodeShortChanIDs(w *bytes.Buffer, encodingType ShortChanIDEncoding, // body. We add 1 as the response will have the encoding type // prepended to it. numBytesBody := uint16(len(shortChanIDs)*8) + 1 - if err := WriteElements(w, numBytesBody); err != nil { + if err := WriteUint16(w, numBytesBody); err != nil { return err } // We'll then write out the encoding that that follows the // actual encoded short channel ID's. - if err := WriteElements(w, encodingType); err != nil { + err := WriteShortChanIDEncoding(w, encodingType) + if err != nil { return err } // Now that we know they're sorted, we can write out each short // channel ID to the buffer. for _, chanID := range shortChanIDs { - if err := WriteElements(w, chanID); err != nil { + if err := WriteShortChannelID(w, chanID); err != nil { return fmt.Errorf("unable to write short chan "+ "ID: %v", err) } @@ -374,7 +374,7 @@ func encodeShortChanIDs(w *bytes.Buffer, encodingType ShortChanIDEncoding, // into the zlib writer, which will do compressing on // the fly. for _, chanID := range shortChanIDs { - err := WriteElements(&wb, chanID) + err := WriteShortChannelID(&wb, chanID) if err != nil { return fmt.Errorf( "unable to write short chan "+ @@ -418,15 +418,15 @@ func encodeShortChanIDs(w *bytes.Buffer, encodingType ShortChanIDEncoding, // Finally, we can write out the number of bytes, the // compression type, and finally the buffer itself. - if err := WriteElements(w, uint16(numBytesBody)); err != nil { + if err := WriteUint16(w, uint16(numBytesBody)); err != nil { return err } - if err := WriteElements(w, encodingType); err != nil { + err := WriteShortChanIDEncoding(w, encodingType) + if err != nil { return err } - _, err := w.Write(compressedPayload) - return err + return WriteBytes(w, compressedPayload) default: // If we're trying to encode with an encoding type that we diff --git a/lnwire/reply_channel_range.go b/lnwire/reply_channel_range.go index 00bb10eb..9dc0fca9 100644 --- a/lnwire/reply_channel_range.go +++ b/lnwire/reply_channel_range.go @@ -87,22 +87,28 @@ func (c *ReplyChannelRange) Decode(r io.Reader, pver uint32) error { // // This is part of the lnwire.Message interface. func (c *ReplyChannelRange) Encode(w *bytes.Buffer, pver uint32) error { - err := WriteElements(w, - c.ChainHash[:], - c.FirstBlockHeight, - c.NumBlocks, - c.Complete, - ) + if err := WriteBytes(w, c.ChainHash[:]); err != nil { + return err + } + + if err := WriteUint32(w, c.FirstBlockHeight); err != nil { + return err + } + + if err := WriteUint32(w, c.NumBlocks); err != nil { + return err + } + + if err := WriteUint8(w, c.Complete); err != nil { + return err + } + + err := encodeShortChanIDs(w, c.EncodingType, c.ShortChanIDs, c.noSort) if err != nil { return err } - err = encodeShortChanIDs(w, c.EncodingType, c.ShortChanIDs, c.noSort) - if err != nil { - return err - } - - return c.ExtraData.Encode(w) + return WriteBytes(w, c.ExtraData) } // MsgType returns the integer uniquely identifying this message type on the diff --git a/lnwire/update_add_htlc.go b/lnwire/update_add_htlc.go index a7756994..666a5494 100644 --- a/lnwire/update_add_htlc.go +++ b/lnwire/update_add_htlc.go @@ -85,20 +85,36 @@ func (c *UpdateAddHTLC) Decode(r io.Reader, pver uint32) error { ) } -// Encode serializes the target UpdateAddHTLC into the passed io.Writer observing -// the protocol version specified. +// Encode serializes the target UpdateAddHTLC into the passed io.Writer +// observing the protocol version specified. // // This is part of the lnwire.Message interface. func (c *UpdateAddHTLC) Encode(w *bytes.Buffer, pver uint32) error { - return WriteElements(w, - c.ChanID, - c.ID, - c.Amount, - c.PaymentHash[:], - c.Expiry, - c.OnionBlob[:], - c.ExtraData, - ) + if err := WriteChannelID(w, c.ChanID); err != nil { + return err + } + + if err := WriteUint64(w, c.ID); err != nil { + return err + } + + if err := WriteMilliSatoshi(w, c.Amount); err != nil { + return err + } + + if err := WriteBytes(w, c.PaymentHash[:]); err != nil { + return err + } + + if err := WriteUint32(w, c.Expiry); err != nil { + return err + } + + if err := WriteBytes(w, c.OnionBlob[:]); err != nil { + return err + } + + return WriteBytes(w, c.ExtraData) } // MsgType returns the integer uniquely identifying this message type on the diff --git a/lnwire/update_fail_htlc.go b/lnwire/update_fail_htlc.go index 2344b568..61f02bac 100644 --- a/lnwire/update_fail_htlc.go +++ b/lnwire/update_fail_htlc.go @@ -56,12 +56,19 @@ func (c *UpdateFailHTLC) Decode(r io.Reader, pver uint32) error { // // This is part of the lnwire.Message interface. func (c *UpdateFailHTLC) Encode(w *bytes.Buffer, pver uint32) error { - return WriteElements(w, - c.ChanID, - c.ID, - c.Reason, - c.ExtraData, - ) + if err := WriteChannelID(w, c.ChanID); err != nil { + return err + } + + if err := WriteUint64(w, c.ID); err != nil { + return err + } + + if err := WriteOpaqueReason(w, c.Reason); err != nil { + return err + } + + return WriteBytes(w, c.ExtraData) } // MsgType returns the integer uniquely identifying this message type on the diff --git a/lnwire/update_fail_malformed_htlc.go b/lnwire/update_fail_malformed_htlc.go index 120f8954..f28107a9 100644 --- a/lnwire/update_fail_malformed_htlc.go +++ b/lnwire/update_fail_malformed_htlc.go @@ -54,14 +54,26 @@ func (c *UpdateFailMalformedHTLC) Decode(r io.Reader, pver uint32) error { // io.Writer observing the protocol version specified. // // This is part of the lnwire.Message interface. -func (c *UpdateFailMalformedHTLC) Encode(w *bytes.Buffer, pver uint32) error { - return WriteElements(w, - c.ChanID, - c.ID, - c.ShaOnionBlob[:], - c.FailureCode, - c.ExtraData, - ) +func (c *UpdateFailMalformedHTLC) Encode(w *bytes.Buffer, + pver uint32) error { + + if err := WriteChannelID(w, c.ChanID); err != nil { + return err + } + + if err := WriteUint64(w, c.ID); err != nil { + return err + } + + if err := WriteBytes(w, c.ShaOnionBlob[:]); err != nil { + return err + } + + if err := WriteFailCode(w, c.FailureCode); err != nil { + return err + } + + return WriteBytes(w, c.ExtraData) } // MsgType returns the integer uniquely identifying this message type on the diff --git a/lnwire/update_fee.go b/lnwire/update_fee.go index a30634d7..a7026044 100644 --- a/lnwire/update_fee.go +++ b/lnwire/update_fee.go @@ -53,11 +53,15 @@ func (c *UpdateFee) Decode(r io.Reader, pver uint32) error { // // This is part of the lnwire.Message interface. func (c *UpdateFee) Encode(w *bytes.Buffer, pver uint32) error { - return WriteElements(w, - c.ChanID, - c.FeePerKw, - c.ExtraData, - ) + if err := WriteChannelID(w, c.ChanID); err != nil { + return err + } + + if err := WriteUint32(w, c.FeePerKw); err != nil { + return err + } + + return WriteBytes(w, c.ExtraData) } // MsgType returns the integer uniquely identifying this message type on the diff --git a/lnwire/update_fulfill_htlc.go b/lnwire/update_fulfill_htlc.go index 4cd3a972..275a37c8 100644 --- a/lnwire/update_fulfill_htlc.go +++ b/lnwire/update_fulfill_htlc.go @@ -62,12 +62,19 @@ func (c *UpdateFulfillHTLC) Decode(r io.Reader, pver uint32) error { // // This is part of the lnwire.Message interface. func (c *UpdateFulfillHTLC) Encode(w *bytes.Buffer, pver uint32) error { - return WriteElements(w, - c.ChanID, - c.ID, - c.PaymentPreimage[:], - c.ExtraData, - ) + if err := WriteChannelID(w, c.ChanID); err != nil { + return err + } + + if err := WriteUint64(w, c.ID); err != nil { + return err + } + + if err := WriteBytes(w, c.PaymentPreimage[:]); err != nil { + return err + } + + return WriteBytes(w, c.ExtraData) } // MsgType returns the integer uniquely identifying this message type on the