mirror of
https://github.com/aljazceru/breez-lnd.git
synced 2025-12-17 22:24:21 +01:00
channeldb+htlcswitch: write wire messages using length prefix
In this commit, we modify the way we write wire messages across the entire database. We'll now ensure that we always write wire messages with a length prefix. We update the `codec.go` file to always write a 2 byte length prefix, this affects the way we write the `CommitDiff` and `LogUpdates` struct to disk, and the network results bucket in the switch as it includes a wire message.
This commit is contained in:
@@ -1965,12 +1965,12 @@ func deserializeLogUpdates(r io.Reader) ([]LogUpdate, error) {
|
|||||||
return logUpdates, nil
|
return logUpdates, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func serializeCommitDiff(w io.Writer, diff *CommitDiff) error {
|
func serializeCommitDiff(w io.Writer, diff *CommitDiff) error { // nolint: dupl
|
||||||
if err := serializeChanCommit(w, &diff.Commitment); err != nil {
|
if err := serializeChanCommit(w, &diff.Commitment); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := diff.CommitSig.Encode(w, 0); err != nil {
|
if err := WriteElements(w, diff.CommitSig); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2016,10 +2016,16 @@ func deserializeCommitDiff(r io.Reader) (*CommitDiff, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
d.CommitSig = &lnwire.CommitSig{}
|
var msg lnwire.Message
|
||||||
if err := d.CommitSig.Decode(r, 0); err != nil {
|
if err := ReadElements(r, &msg); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
commitSig, ok := msg.(*lnwire.CommitSig)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("expected lnwire.CommitSig, instead "+
|
||||||
|
"read: %T", msg)
|
||||||
|
}
|
||||||
|
d.CommitSig = commitSig
|
||||||
|
|
||||||
d.LogUpdates, err = deserializeLogUpdates(r)
|
d.LogUpdates, err = deserializeLogUpdates(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package channeldb
|
package channeldb
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
@@ -178,7 +179,17 @@ func WriteElement(w io.Writer, element interface{}) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
case lnwire.Message:
|
case lnwire.Message:
|
||||||
if _, err := lnwire.WriteMessage(w, e, 0); err != nil {
|
var msgBuf bytes.Buffer
|
||||||
|
if _, err := lnwire.WriteMessage(&msgBuf, e, 0); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
msgLen := uint16(len(msgBuf.Bytes()))
|
||||||
|
if err := WriteElements(w, msgLen); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := w.Write(msgBuf.Bytes()); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -394,7 +405,13 @@ func ReadElement(r io.Reader, element interface{}) error {
|
|||||||
*e = bytes
|
*e = bytes
|
||||||
|
|
||||||
case *lnwire.Message:
|
case *lnwire.Message:
|
||||||
msg, err := lnwire.ReadMessage(r, 0)
|
var msgLen uint16
|
||||||
|
if err := ReadElement(r, &msgLen); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
msgReader := io.LimitReader(r, int64(msgLen))
|
||||||
|
msg, err := lnwire.ReadMessage(msgReader, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -61,28 +61,15 @@ type networkResult struct {
|
|||||||
|
|
||||||
// serializeNetworkResult serializes the networkResult.
|
// serializeNetworkResult serializes the networkResult.
|
||||||
func serializeNetworkResult(w io.Writer, n *networkResult) error {
|
func serializeNetworkResult(w io.Writer, n *networkResult) error {
|
||||||
if _, err := lnwire.WriteMessage(w, n.msg, 0); err != nil {
|
return channeldb.WriteElements(w, n.msg, n.unencrypted, n.isResolution)
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return channeldb.WriteElements(w, n.unencrypted, n.isResolution)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// deserializeNetworkResult deserializes the networkResult.
|
// deserializeNetworkResult deserializes the networkResult.
|
||||||
func deserializeNetworkResult(r io.Reader) (*networkResult, error) {
|
func deserializeNetworkResult(r io.Reader) (*networkResult, error) {
|
||||||
var (
|
|
||||||
err error
|
|
||||||
)
|
|
||||||
|
|
||||||
n := &networkResult{}
|
n := &networkResult{}
|
||||||
|
|
||||||
n.msg, err = lnwire.ReadMessage(r, 0)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := channeldb.ReadElements(r,
|
if err := channeldb.ReadElements(r,
|
||||||
&n.unencrypted, &n.isResolution,
|
&n.msg, &n.unencrypted, &n.isResolution,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user