Merge pull request #5642 from guggero/in-memory-graph

In-memory graph cache for faster pathfinding
This commit is contained in:
Oliver Gugger
2021-10-04 11:20:23 +02:00
committed by GitHub
65 changed files with 2595 additions and 1171 deletions

View File

@@ -148,7 +148,7 @@ func (d *databaseChannelGraph) addRandChannel(node1, node2 *btcec.PublicKey,
return nil, err return nil, err
} }
dbNode, err := d.db.FetchLightningNode(nil, vertex) dbNode, err := d.db.FetchLightningNode(vertex)
switch { switch {
case err == channeldb.ErrGraphNodeNotFound: case err == channeldb.ErrGraphNodeNotFound:
fallthrough fallthrough

View File

@@ -75,7 +75,7 @@ type Config struct {
// ChanStateDB is a pointer to the database that stores the channel // ChanStateDB is a pointer to the database that stores the channel
// state. // state.
ChanStateDB *channeldb.DB ChanStateDB *channeldb.ChannelStateDB
// BlockCacheSize is the size (in bytes) of blocks kept in memory. // BlockCacheSize is the size (in bytes) of blocks kept in memory.
BlockCacheSize uint64 BlockCacheSize uint64

View File

@@ -21,7 +21,11 @@ type LiveChannelSource interface {
// passed chanPoint. Optionally an existing db tx can be supplied. // passed chanPoint. Optionally an existing db tx can be supplied.
FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) ( FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) (
*channeldb.OpenChannel, error) *channeldb.OpenChannel, error)
}
// AddressSource is an interface that allows us to query for the set of
// addresses a node can be connected to.
type AddressSource interface {
// AddrsForNode returns all known addresses for the target node public // AddrsForNode returns all known addresses for the target node public
// key. // key.
AddrsForNode(nodePub *btcec.PublicKey) ([]net.Addr, error) AddrsForNode(nodePub *btcec.PublicKey) ([]net.Addr, error)
@@ -31,15 +35,15 @@ type LiveChannelSource interface {
// passed open channel. The backup includes all information required to restore // passed open channel. The backup includes all information required to restore
// the channel, as well as addressing information so we can find the peer and // the channel, as well as addressing information so we can find the peer and
// reconnect to them to initiate the protocol. // reconnect to them to initiate the protocol.
func assembleChanBackup(chanSource LiveChannelSource, func assembleChanBackup(addrSource AddressSource,
openChan *channeldb.OpenChannel) (*Single, error) { openChan *channeldb.OpenChannel) (*Single, error) {
log.Debugf("Crafting backup for ChannelPoint(%v)", log.Debugf("Crafting backup for ChannelPoint(%v)",
openChan.FundingOutpoint) openChan.FundingOutpoint)
// First, we'll query the channel source to obtain all the addresses // First, we'll query the channel source to obtain all the addresses
// that are are associated with the peer for this channel. // that are associated with the peer for this channel.
nodeAddrs, err := chanSource.AddrsForNode(openChan.IdentityPub) nodeAddrs, err := addrSource.AddrsForNode(openChan.IdentityPub)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -52,8 +56,8 @@ func assembleChanBackup(chanSource LiveChannelSource,
// FetchBackupForChan attempts to create a plaintext static channel backup for // FetchBackupForChan attempts to create a plaintext static channel backup for
// the target channel identified by its channel point. If we're unable to find // the target channel identified by its channel point. If we're unable to find
// the target channel, then an error will be returned. // the target channel, then an error will be returned.
func FetchBackupForChan(chanPoint wire.OutPoint, func FetchBackupForChan(chanPoint wire.OutPoint, chanSource LiveChannelSource,
chanSource LiveChannelSource) (*Single, error) { addrSource AddressSource) (*Single, error) {
// First, we'll query the channel source to see if the channel is known // First, we'll query the channel source to see if the channel is known
// and open within the database. // and open within the database.
@@ -66,7 +70,7 @@ func FetchBackupForChan(chanPoint wire.OutPoint,
// Once we have the target channel, we can assemble the backup using // Once we have the target channel, we can assemble the backup using
// the source to obtain any extra information that we may need. // the source to obtain any extra information that we may need.
staticChanBackup, err := assembleChanBackup(chanSource, targetChan) staticChanBackup, err := assembleChanBackup(addrSource, targetChan)
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to create chan backup: %v", err) return nil, fmt.Errorf("unable to create chan backup: %v", err)
} }
@@ -76,7 +80,9 @@ func FetchBackupForChan(chanPoint wire.OutPoint,
// FetchStaticChanBackups will return a plaintext static channel back up for // FetchStaticChanBackups will return a plaintext static channel back up for
// all known active/open channels within the passed channel source. // all known active/open channels within the passed channel source.
func FetchStaticChanBackups(chanSource LiveChannelSource) ([]Single, error) { func FetchStaticChanBackups(chanSource LiveChannelSource,
addrSource AddressSource) ([]Single, error) {
// First, we'll query the backup source for information concerning all // First, we'll query the backup source for information concerning all
// currently open and available channels. // currently open and available channels.
openChans, err := chanSource.FetchAllChannels() openChans, err := chanSource.FetchAllChannels()
@@ -89,7 +95,7 @@ func FetchStaticChanBackups(chanSource LiveChannelSource) ([]Single, error) {
// channel. // channel.
staticChanBackups := make([]Single, 0, len(openChans)) staticChanBackups := make([]Single, 0, len(openChans))
for _, openChan := range openChans { for _, openChan := range openChans {
chanBackup, err := assembleChanBackup(chanSource, openChan) chanBackup, err := assembleChanBackup(addrSource, openChan)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -124,7 +124,9 @@ func TestFetchBackupForChan(t *testing.T) {
}, },
} }
for i, testCase := range testCases { for i, testCase := range testCases {
_, err := FetchBackupForChan(testCase.chanPoint, chanSource) _, err := FetchBackupForChan(
testCase.chanPoint, chanSource, chanSource,
)
switch { switch {
// If this is a valid test case, and we failed, then we'll // If this is a valid test case, and we failed, then we'll
// return an error. // return an error.
@@ -167,7 +169,7 @@ func TestFetchStaticChanBackups(t *testing.T) {
// With the channel source populated, we'll now attempt to create a set // With the channel source populated, we'll now attempt to create a set
// of backups for all the channels. This should succeed, as all items // of backups for all the channels. This should succeed, as all items
// are populated within the channel source. // are populated within the channel source.
backups, err := FetchStaticChanBackups(chanSource) backups, err := FetchStaticChanBackups(chanSource, chanSource)
if err != nil { if err != nil {
t.Fatalf("unable to create chan back ups: %v", err) t.Fatalf("unable to create chan back ups: %v", err)
} }
@@ -184,7 +186,7 @@ func TestFetchStaticChanBackups(t *testing.T) {
copy(n[:], randomChan2.IdentityPub.SerializeCompressed()) copy(n[:], randomChan2.IdentityPub.SerializeCompressed())
delete(chanSource.addrs, n) delete(chanSource.addrs, n)
_, err = FetchStaticChanBackups(chanSource) _, err = FetchStaticChanBackups(chanSource, chanSource)
if err == nil { if err == nil {
t.Fatalf("query with incomplete information should fail") t.Fatalf("query with incomplete information should fail")
} }
@@ -193,7 +195,7 @@ func TestFetchStaticChanBackups(t *testing.T) {
// source at all, then we'll fail as well. // source at all, then we'll fail as well.
chanSource = newMockChannelSource() chanSource = newMockChannelSource()
chanSource.failQuery = true chanSource.failQuery = true
_, err = FetchStaticChanBackups(chanSource) _, err = FetchStaticChanBackups(chanSource, chanSource)
if err == nil { if err == nil {
t.Fatalf("query should fail") t.Fatalf("query should fail")
} }

View File

@@ -729,7 +729,7 @@ type OpenChannel struct {
RevocationKeyLocator keychain.KeyLocator RevocationKeyLocator keychain.KeyLocator
// TODO(roasbeef): eww // TODO(roasbeef): eww
Db *DB Db *ChannelStateDB
// TODO(roasbeef): just need to store local and remote HTLC's? // TODO(roasbeef): just need to store local and remote HTLC's?
@@ -800,7 +800,7 @@ func (c *OpenChannel) RefreshShortChanID() error {
c.Lock() c.Lock()
defer c.Unlock() defer c.Unlock()
err := kvdb.View(c.Db, func(tx kvdb.RTx) error { err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error {
chanBucket, err := fetchChanBucket( chanBucket, err := fetchChanBucket(
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
) )
@@ -875,12 +875,43 @@ func fetchChanBucket(tx kvdb.RTx, nodeKey *btcec.PublicKey,
func fetchChanBucketRw(tx kvdb.RwTx, nodeKey *btcec.PublicKey, // nolint:interfacer func fetchChanBucketRw(tx kvdb.RwTx, nodeKey *btcec.PublicKey, // nolint:interfacer
outPoint *wire.OutPoint, chainHash chainhash.Hash) (kvdb.RwBucket, error) { outPoint *wire.OutPoint, chainHash chainhash.Hash) (kvdb.RwBucket, error) {
readBucket, err := fetchChanBucket(tx, nodeKey, outPoint, chainHash) // First fetch the top level bucket which stores all data related to
if err != nil { // current, active channels.
return nil, err openChanBucket := tx.ReadWriteBucket(openChannelBucket)
if openChanBucket == nil {
return nil, ErrNoChanDBExists
} }
return readBucket.(kvdb.RwBucket), nil // TODO(roasbeef): CreateTopLevelBucket on the interface isn't like
// CreateIfNotExists, will return error
// Within this top level bucket, fetch the bucket dedicated to storing
// open channel data specific to the remote node.
nodePub := nodeKey.SerializeCompressed()
nodeChanBucket := openChanBucket.NestedReadWriteBucket(nodePub)
if nodeChanBucket == nil {
return nil, ErrNoActiveChannels
}
// We'll then recurse down an additional layer in order to fetch the
// bucket for this particular chain.
chainBucket := nodeChanBucket.NestedReadWriteBucket(chainHash[:])
if chainBucket == nil {
return nil, ErrNoActiveChannels
}
// With the bucket for the node and chain fetched, we can now go down
// another level, for this channel itself.
var chanPointBuf bytes.Buffer
if err := writeOutpoint(&chanPointBuf, outPoint); err != nil {
return nil, err
}
chanBucket := chainBucket.NestedReadWriteBucket(chanPointBuf.Bytes())
if chanBucket == nil {
return nil, ErrChannelNotFound
}
return chanBucket, nil
} }
// fullSync syncs the contents of an OpenChannel while re-using an existing // fullSync syncs the contents of an OpenChannel while re-using an existing
@@ -964,8 +995,8 @@ func (c *OpenChannel) MarkAsOpen(openLoc lnwire.ShortChannelID) error {
c.Lock() c.Lock()
defer c.Unlock() defer c.Unlock()
if err := kvdb.Update(c.Db, func(tx kvdb.RwTx) error { if err := kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error {
chanBucket, err := fetchChanBucket( chanBucket, err := fetchChanBucketRw(
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
) )
if err != nil { if err != nil {
@@ -980,7 +1011,7 @@ func (c *OpenChannel) MarkAsOpen(openLoc lnwire.ShortChannelID) error {
channel.IsPending = false channel.IsPending = false
channel.ShortChannelID = openLoc channel.ShortChannelID = openLoc
return putOpenChannel(chanBucket.(kvdb.RwBucket), channel) return putOpenChannel(chanBucket, channel)
}, func() {}); err != nil { }, func() {}); err != nil {
return err return err
} }
@@ -1016,7 +1047,7 @@ func (c *OpenChannel) MarkDataLoss(commitPoint *btcec.PublicKey) error {
func (c *OpenChannel) DataLossCommitPoint() (*btcec.PublicKey, error) { func (c *OpenChannel) DataLossCommitPoint() (*btcec.PublicKey, error) {
var commitPoint *btcec.PublicKey var commitPoint *btcec.PublicKey
err := kvdb.View(c.Db, func(tx kvdb.RTx) error { err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error {
chanBucket, err := fetchChanBucket( chanBucket, err := fetchChanBucket(
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
) )
@@ -1240,7 +1271,7 @@ func (c *OpenChannel) BroadcastedCooperative() (*wire.MsgTx, error) {
func (c *OpenChannel) getClosingTx(key []byte) (*wire.MsgTx, error) { func (c *OpenChannel) getClosingTx(key []byte) (*wire.MsgTx, error) {
var closeTx *wire.MsgTx var closeTx *wire.MsgTx
err := kvdb.View(c.Db, func(tx kvdb.RTx) error { err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error {
chanBucket, err := fetchChanBucket( chanBucket, err := fetchChanBucket(
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
) )
@@ -1274,7 +1305,7 @@ func (c *OpenChannel) getClosingTx(key []byte) (*wire.MsgTx, error) {
func (c *OpenChannel) putChanStatus(status ChannelStatus, func (c *OpenChannel) putChanStatus(status ChannelStatus,
fs ...func(kvdb.RwBucket) error) error { fs ...func(kvdb.RwBucket) error) error {
if err := kvdb.Update(c.Db, func(tx kvdb.RwTx) error { if err := kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error {
chanBucket, err := fetchChanBucketRw( chanBucket, err := fetchChanBucketRw(
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
) )
@@ -1318,7 +1349,7 @@ func (c *OpenChannel) putChanStatus(status ChannelStatus,
} }
func (c *OpenChannel) clearChanStatus(status ChannelStatus) error { func (c *OpenChannel) clearChanStatus(status ChannelStatus) error {
if err := kvdb.Update(c.Db, func(tx kvdb.RwTx) error { if err := kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error {
chanBucket, err := fetchChanBucketRw( chanBucket, err := fetchChanBucketRw(
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
) )
@@ -1442,7 +1473,7 @@ func (c *OpenChannel) SyncPending(addr net.Addr, pendingHeight uint32) error {
c.FundingBroadcastHeight = pendingHeight c.FundingBroadcastHeight = pendingHeight
return kvdb.Update(c.Db, func(tx kvdb.RwTx) error { return kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error {
return syncNewChannel(tx, c, []net.Addr{addr}) return syncNewChannel(tx, c, []net.Addr{addr})
}, func() {}) }, func() {})
} }
@@ -1470,7 +1501,10 @@ func syncNewChannel(tx kvdb.RwTx, c *OpenChannel, addrs []net.Addr) error {
// Next, we need to establish a (possibly) new LinkNode relationship // Next, we need to establish a (possibly) new LinkNode relationship
// for this channel. The LinkNode metadata contains reachability, // for this channel. The LinkNode metadata contains reachability,
// up-time, and service bits related information. // up-time, and service bits related information.
linkNode := c.Db.NewLinkNode(wire.MainNet, c.IdentityPub, addrs...) linkNode := NewLinkNode(
&LinkNodeDB{backend: c.Db.backend},
wire.MainNet, c.IdentityPub, addrs...,
)
// TODO(roasbeef): do away with link node all together? // TODO(roasbeef): do away with link node all together?
@@ -1498,7 +1532,7 @@ func (c *OpenChannel) UpdateCommitment(newCommitment *ChannelCommitment,
return ErrNoRestoredChannelMutation return ErrNoRestoredChannelMutation
} }
err := kvdb.Update(c.Db, func(tx kvdb.RwTx) error { err := kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error {
chanBucket, err := fetchChanBucketRw( chanBucket, err := fetchChanBucketRw(
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
) )
@@ -2090,7 +2124,7 @@ func (c *OpenChannel) AppendRemoteCommitChain(diff *CommitDiff) error {
return ErrNoRestoredChannelMutation return ErrNoRestoredChannelMutation
} }
return kvdb.Update(c.Db, func(tx kvdb.RwTx) error { return kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error {
// First, we'll grab the writable bucket where this channel's // First, we'll grab the writable bucket where this channel's
// data resides. // data resides.
chanBucket, err := fetchChanBucketRw( chanBucket, err := fetchChanBucketRw(
@@ -2160,7 +2194,7 @@ func (c *OpenChannel) AppendRemoteCommitChain(diff *CommitDiff) error {
// these pointers, causing the tip and the tail to point to the same entry. // these pointers, causing the tip and the tail to point to the same entry.
func (c *OpenChannel) RemoteCommitChainTip() (*CommitDiff, error) { func (c *OpenChannel) RemoteCommitChainTip() (*CommitDiff, error) {
var cd *CommitDiff var cd *CommitDiff
err := kvdb.View(c.Db, func(tx kvdb.RTx) error { err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error {
chanBucket, err := fetchChanBucket( chanBucket, err := fetchChanBucket(
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
) )
@@ -2199,7 +2233,7 @@ func (c *OpenChannel) RemoteCommitChainTip() (*CommitDiff, error) {
// updates that still need to be signed for. // updates that still need to be signed for.
func (c *OpenChannel) UnsignedAckedUpdates() ([]LogUpdate, error) { func (c *OpenChannel) UnsignedAckedUpdates() ([]LogUpdate, error) {
var updates []LogUpdate var updates []LogUpdate
err := kvdb.View(c.Db, func(tx kvdb.RTx) error { err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error {
chanBucket, err := fetchChanBucket( chanBucket, err := fetchChanBucket(
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
) )
@@ -2233,7 +2267,7 @@ func (c *OpenChannel) UnsignedAckedUpdates() ([]LogUpdate, error) {
// updates that the remote still needs to sign for. // updates that the remote still needs to sign for.
func (c *OpenChannel) RemoteUnsignedLocalUpdates() ([]LogUpdate, error) { func (c *OpenChannel) RemoteUnsignedLocalUpdates() ([]LogUpdate, error) {
var updates []LogUpdate var updates []LogUpdate
err := kvdb.View(c.Db, func(tx kvdb.RTx) error { err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error {
chanBucket, err := fetchChanBucket( chanBucket, err := fetchChanBucket(
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
) )
@@ -2277,7 +2311,7 @@ func (c *OpenChannel) InsertNextRevocation(revKey *btcec.PublicKey) error {
c.RemoteNextRevocation = revKey c.RemoteNextRevocation = revKey
err := kvdb.Update(c.Db, func(tx kvdb.RwTx) error { err := kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error {
chanBucket, err := fetchChanBucketRw( chanBucket, err := fetchChanBucketRw(
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
) )
@@ -2318,7 +2352,7 @@ func (c *OpenChannel) AdvanceCommitChainTail(fwdPkg *FwdPkg,
var newRemoteCommit *ChannelCommitment var newRemoteCommit *ChannelCommitment
err := kvdb.Update(c.Db, func(tx kvdb.RwTx) error { err := kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error {
chanBucket, err := fetchChanBucketRw( chanBucket, err := fetchChanBucketRw(
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
) )
@@ -2493,7 +2527,7 @@ func (c *OpenChannel) LoadFwdPkgs() ([]*FwdPkg, error) {
defer c.RUnlock() defer c.RUnlock()
var fwdPkgs []*FwdPkg var fwdPkgs []*FwdPkg
if err := kvdb.View(c.Db, func(tx kvdb.RTx) error { if err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error {
var err error var err error
fwdPkgs, err = c.Packager.LoadFwdPkgs(tx) fwdPkgs, err = c.Packager.LoadFwdPkgs(tx)
return err return err
@@ -2513,7 +2547,7 @@ func (c *OpenChannel) AckAddHtlcs(addRefs ...AddRef) error {
c.Lock() c.Lock()
defer c.Unlock() defer c.Unlock()
return kvdb.Update(c.Db, func(tx kvdb.RwTx) error { return kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error {
return c.Packager.AckAddHtlcs(tx, addRefs...) return c.Packager.AckAddHtlcs(tx, addRefs...)
}, func() {}) }, func() {})
} }
@@ -2526,7 +2560,7 @@ func (c *OpenChannel) AckSettleFails(settleFailRefs ...SettleFailRef) error {
c.Lock() c.Lock()
defer c.Unlock() defer c.Unlock()
return kvdb.Update(c.Db, func(tx kvdb.RwTx) error { return kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error {
return c.Packager.AckSettleFails(tx, settleFailRefs...) return c.Packager.AckSettleFails(tx, settleFailRefs...)
}, func() {}) }, func() {})
} }
@@ -2537,7 +2571,7 @@ func (c *OpenChannel) SetFwdFilter(height uint64, fwdFilter *PkgFilter) error {
c.Lock() c.Lock()
defer c.Unlock() defer c.Unlock()
return kvdb.Update(c.Db, func(tx kvdb.RwTx) error { return kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error {
return c.Packager.SetFwdFilter(tx, height, fwdFilter) return c.Packager.SetFwdFilter(tx, height, fwdFilter)
}, func() {}) }, func() {})
} }
@@ -2551,7 +2585,7 @@ func (c *OpenChannel) RemoveFwdPkgs(heights ...uint64) error {
c.Lock() c.Lock()
defer c.Unlock() defer c.Unlock()
return kvdb.Update(c.Db, func(tx kvdb.RwTx) error { return kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error {
for _, height := range heights { for _, height := range heights {
err := c.Packager.RemovePkg(tx, height) err := c.Packager.RemovePkg(tx, height)
if err != nil { if err != nil {
@@ -2579,7 +2613,7 @@ func (c *OpenChannel) RevocationLogTail() (*ChannelCommitment, error) {
} }
var commit ChannelCommitment var commit ChannelCommitment
if err := kvdb.View(c.Db, func(tx kvdb.RTx) error { if err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error {
chanBucket, err := fetchChanBucket( chanBucket, err := fetchChanBucket(
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
) )
@@ -2626,7 +2660,7 @@ func (c *OpenChannel) CommitmentHeight() (uint64, error) {
defer c.RUnlock() defer c.RUnlock()
var height uint64 var height uint64
err := kvdb.View(c.Db, func(tx kvdb.RTx) error { err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error {
// Get the bucket dedicated to storing the metadata for open // Get the bucket dedicated to storing the metadata for open
// channels. // channels.
chanBucket, err := fetchChanBucket( chanBucket, err := fetchChanBucket(
@@ -2663,7 +2697,7 @@ func (c *OpenChannel) FindPreviousState(updateNum uint64) (*ChannelCommitment, e
defer c.RUnlock() defer c.RUnlock()
var commit ChannelCommitment var commit ChannelCommitment
err := kvdb.View(c.Db, func(tx kvdb.RTx) error { err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error {
chanBucket, err := fetchChanBucket( chanBucket, err := fetchChanBucket(
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
) )
@@ -2821,7 +2855,7 @@ func (c *OpenChannel) CloseChannel(summary *ChannelCloseSummary,
c.Lock() c.Lock()
defer c.Unlock() defer c.Unlock()
return kvdb.Update(c.Db, func(tx kvdb.RwTx) error { return kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error {
openChanBucket := tx.ReadWriteBucket(openChannelBucket) openChanBucket := tx.ReadWriteBucket(openChannelBucket)
if openChanBucket == nil { if openChanBucket == nil {
return ErrNoChanDBExists return ErrNoChanDBExists
@@ -3033,7 +3067,7 @@ func (c *OpenChannel) Snapshot() *ChannelSnapshot {
// latest fully committed state is returned. The first commitment returned is // latest fully committed state is returned. The first commitment returned is
// the local commitment, and the second returned is the remote commitment. // the local commitment, and the second returned is the remote commitment.
func (c *OpenChannel) LatestCommitments() (*ChannelCommitment, *ChannelCommitment, error) { func (c *OpenChannel) LatestCommitments() (*ChannelCommitment, *ChannelCommitment, error) {
err := kvdb.View(c.Db, func(tx kvdb.RTx) error { err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error {
chanBucket, err := fetchChanBucket( chanBucket, err := fetchChanBucket(
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
) )
@@ -3055,7 +3089,7 @@ func (c *OpenChannel) LatestCommitments() (*ChannelCommitment, *ChannelCommitmen
// acting on a possible contract breach to ensure, that the caller has the most // acting on a possible contract breach to ensure, that the caller has the most
// up to date information required to deliver justice. // up to date information required to deliver justice.
func (c *OpenChannel) RemoteRevocationStore() (shachain.Store, error) { func (c *OpenChannel) RemoteRevocationStore() (shachain.Store, error) {
err := kvdb.View(c.Db, func(tx kvdb.RTx) error { err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error {
chanBucket, err := fetchChanBucket( chanBucket, err := fetchChanBucket(
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
) )

View File

@@ -183,7 +183,7 @@ var channelIDOption = func(chanID lnwire.ShortChannelID) testChannelOption {
// createTestChannel writes a test channel to the database. It takes a set of // createTestChannel writes a test channel to the database. It takes a set of
// functional options which can be used to overwrite the default of creating // functional options which can be used to overwrite the default of creating
// a pending channel that was broadcast at height 100. // a pending channel that was broadcast at height 100.
func createTestChannel(t *testing.T, cdb *DB, func createTestChannel(t *testing.T, cdb *ChannelStateDB,
opts ...testChannelOption) *OpenChannel { opts ...testChannelOption) *OpenChannel {
// Create a default set of parameters. // Create a default set of parameters.
@@ -221,7 +221,7 @@ func createTestChannel(t *testing.T, cdb *DB,
return params.channel return params.channel
} }
func createTestChannelState(t *testing.T, cdb *DB) *OpenChannel { func createTestChannelState(t *testing.T, cdb *ChannelStateDB) *OpenChannel {
// Simulate 1000 channel updates. // Simulate 1000 channel updates.
producer, err := shachain.NewRevocationProducerFromBytes(key[:]) producer, err := shachain.NewRevocationProducerFromBytes(key[:])
if err != nil { if err != nil {
@@ -359,12 +359,14 @@ func createTestChannelState(t *testing.T, cdb *DB) *OpenChannel {
func TestOpenChannelPutGetDelete(t *testing.T) { func TestOpenChannelPutGetDelete(t *testing.T) {
t.Parallel() t.Parallel()
cdb, cleanUp, err := MakeTestDB() fullDB, cleanUp, err := MakeTestDB()
if err != nil { if err != nil {
t.Fatalf("unable to make test database: %v", err) t.Fatalf("unable to make test database: %v", err)
} }
defer cleanUp() defer cleanUp()
cdb := fullDB.ChannelStateDB()
// Create the test channel state, with additional htlcs on the local // Create the test channel state, with additional htlcs on the local
// and remote commitment. // and remote commitment.
localHtlcs := []HTLC{ localHtlcs := []HTLC{
@@ -508,12 +510,14 @@ func TestOptionalShutdown(t *testing.T) {
test := test test := test
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
cdb, cleanUp, err := MakeTestDB() fullDB, cleanUp, err := MakeTestDB()
if err != nil { if err != nil {
t.Fatalf("unable to make test database: %v", err) t.Fatalf("unable to make test database: %v", err)
} }
defer cleanUp() defer cleanUp()
cdb := fullDB.ChannelStateDB()
// Create a channel with upfront scripts set as // Create a channel with upfront scripts set as
// specified in the test. // specified in the test.
state := createTestChannel( state := createTestChannel(
@@ -565,12 +569,14 @@ func assertCommitmentEqual(t *testing.T, a, b *ChannelCommitment) {
func TestChannelStateTransition(t *testing.T) { func TestChannelStateTransition(t *testing.T) {
t.Parallel() t.Parallel()
cdb, cleanUp, err := MakeTestDB() fullDB, cleanUp, err := MakeTestDB()
if err != nil { if err != nil {
t.Fatalf("unable to make test database: %v", err) t.Fatalf("unable to make test database: %v", err)
} }
defer cleanUp() defer cleanUp()
cdb := fullDB.ChannelStateDB()
// First create a minimal channel, then perform a full sync in order to // First create a minimal channel, then perform a full sync in order to
// persist the data. // persist the data.
channel := createTestChannel(t, cdb) channel := createTestChannel(t, cdb)
@@ -842,7 +848,7 @@ func TestChannelStateTransition(t *testing.T) {
} }
// At this point, we should have 2 forwarding packages added. // At this point, we should have 2 forwarding packages added.
fwdPkgs := loadFwdPkgs(t, cdb, channel.Packager) fwdPkgs := loadFwdPkgs(t, cdb.backend, channel.Packager)
require.Len(t, fwdPkgs, 2, "wrong number of forwarding packages") require.Len(t, fwdPkgs, 2, "wrong number of forwarding packages")
// Now attempt to delete the channel from the database. // Now attempt to delete the channel from the database.
@@ -877,19 +883,21 @@ func TestChannelStateTransition(t *testing.T) {
} }
// All forwarding packages of this channel has been deleted too. // All forwarding packages of this channel has been deleted too.
fwdPkgs = loadFwdPkgs(t, cdb, channel.Packager) fwdPkgs = loadFwdPkgs(t, cdb.backend, channel.Packager)
require.Empty(t, fwdPkgs, "no forwarding packages should exist") require.Empty(t, fwdPkgs, "no forwarding packages should exist")
} }
func TestFetchPendingChannels(t *testing.T) { func TestFetchPendingChannels(t *testing.T) {
t.Parallel() t.Parallel()
cdb, cleanUp, err := MakeTestDB() fullDB, cleanUp, err := MakeTestDB()
if err != nil { if err != nil {
t.Fatalf("unable to make test database: %v", err) t.Fatalf("unable to make test database: %v", err)
} }
defer cleanUp() defer cleanUp()
cdb := fullDB.ChannelStateDB()
// Create a pending channel that was broadcast at height 99. // Create a pending channel that was broadcast at height 99.
const broadcastHeight = 99 const broadcastHeight = 99
createTestChannel(t, cdb, pendingHeightOption(broadcastHeight)) createTestChannel(t, cdb, pendingHeightOption(broadcastHeight))
@@ -963,12 +971,14 @@ func TestFetchPendingChannels(t *testing.T) {
func TestFetchClosedChannels(t *testing.T) { func TestFetchClosedChannels(t *testing.T) {
t.Parallel() t.Parallel()
cdb, cleanUp, err := MakeTestDB() fullDB, cleanUp, err := MakeTestDB()
if err != nil { if err != nil {
t.Fatalf("unable to make test database: %v", err) t.Fatalf("unable to make test database: %v", err)
} }
defer cleanUp() defer cleanUp()
cdb := fullDB.ChannelStateDB()
// Create an open channel in the database. // Create an open channel in the database.
state := createTestChannel(t, cdb, openChannelOption()) state := createTestChannel(t, cdb, openChannelOption())
@@ -1054,18 +1064,20 @@ func TestFetchWaitingCloseChannels(t *testing.T) {
// We'll start by creating two channels within our test database. One of // We'll start by creating two channels within our test database. One of
// them will have their funding transaction confirmed on-chain, while // them will have their funding transaction confirmed on-chain, while
// the other one will remain unconfirmed. // the other one will remain unconfirmed.
db, cleanUp, err := MakeTestDB() fullDB, cleanUp, err := MakeTestDB()
if err != nil { if err != nil {
t.Fatalf("unable to make test database: %v", err) t.Fatalf("unable to make test database: %v", err)
} }
defer cleanUp() defer cleanUp()
cdb := fullDB.ChannelStateDB()
channels := make([]*OpenChannel, numChannels) channels := make([]*OpenChannel, numChannels)
for i := 0; i < numChannels; i++ { for i := 0; i < numChannels; i++ {
// Create a pending channel in the database at the broadcast // Create a pending channel in the database at the broadcast
// height. // height.
channels[i] = createTestChannel( channels[i] = createTestChannel(
t, db, pendingHeightOption(broadcastHeight), t, cdb, pendingHeightOption(broadcastHeight),
) )
} }
@@ -1116,7 +1128,7 @@ func TestFetchWaitingCloseChannels(t *testing.T) {
// Now, we'll fetch all the channels waiting to be closed from the // Now, we'll fetch all the channels waiting to be closed from the
// database. We should expect to see both channels above, even if any of // database. We should expect to see both channels above, even if any of
// them haven't had their funding transaction confirm on-chain. // them haven't had their funding transaction confirm on-chain.
waitingCloseChannels, err := db.FetchWaitingCloseChannels() waitingCloseChannels, err := cdb.FetchWaitingCloseChannels()
if err != nil { if err != nil {
t.Fatalf("unable to fetch all waiting close channels: %v", err) t.Fatalf("unable to fetch all waiting close channels: %v", err)
} }
@@ -1169,12 +1181,14 @@ func TestFetchWaitingCloseChannels(t *testing.T) {
func TestRefreshShortChanID(t *testing.T) { func TestRefreshShortChanID(t *testing.T) {
t.Parallel() t.Parallel()
cdb, cleanUp, err := MakeTestDB() fullDB, cleanUp, err := MakeTestDB()
if err != nil { if err != nil {
t.Fatalf("unable to make test database: %v", err) t.Fatalf("unable to make test database: %v", err)
} }
defer cleanUp() defer cleanUp()
cdb := fullDB.ChannelStateDB()
// First create a test channel. // First create a test channel.
state := createTestChannel(t, cdb) state := createTestChannel(t, cdb)
@@ -1317,13 +1331,15 @@ func TestCloseInitiator(t *testing.T) {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
t.Parallel() t.Parallel()
cdb, cleanUp, err := MakeTestDB() fullDB, cleanUp, err := MakeTestDB()
if err != nil { if err != nil {
t.Fatalf("unable to make test database: %v", t.Fatalf("unable to make test database: %v",
err) err)
} }
defer cleanUp() defer cleanUp()
cdb := fullDB.ChannelStateDB()
// Create an open channel. // Create an open channel.
channel := createTestChannel( channel := createTestChannel(
t, cdb, openChannelOption(), t, cdb, openChannelOption(),
@@ -1362,13 +1378,15 @@ func TestCloseInitiator(t *testing.T) {
// TestCloseChannelStatus tests setting of a channel status on the historical // TestCloseChannelStatus tests setting of a channel status on the historical
// channel on channel close. // channel on channel close.
func TestCloseChannelStatus(t *testing.T) { func TestCloseChannelStatus(t *testing.T) {
cdb, cleanUp, err := MakeTestDB() fullDB, cleanUp, err := MakeTestDB()
if err != nil { if err != nil {
t.Fatalf("unable to make test database: %v", t.Fatalf("unable to make test database: %v",
err) err)
} }
defer cleanUp() defer cleanUp()
cdb := fullDB.ChannelStateDB()
// Create an open channel. // Create an open channel.
channel := createTestChannel( channel := createTestChannel(
t, cdb, openChannelOption(), t, cdb, openChannelOption(),
@@ -1427,7 +1445,7 @@ func TestBalanceAtHeight(t *testing.T) {
putRevokedState := func(c *OpenChannel, height uint64, local, putRevokedState := func(c *OpenChannel, height uint64, local,
remote lnwire.MilliSatoshi) error { remote lnwire.MilliSatoshi) error {
err := kvdb.Update(c.Db, func(tx kvdb.RwTx) error { err := kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error {
chanBucket, err := fetchChanBucketRw( chanBucket, err := fetchChanBucketRw(
tx, c.IdentityPub, &c.FundingOutpoint, tx, c.IdentityPub, &c.FundingOutpoint,
c.ChainHash, c.ChainHash,
@@ -1508,13 +1526,15 @@ func TestBalanceAtHeight(t *testing.T) {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
t.Parallel() t.Parallel()
cdb, cleanUp, err := MakeTestDB() fullDB, cleanUp, err := MakeTestDB()
if err != nil { if err != nil {
t.Fatalf("unable to make test database: %v", t.Fatalf("unable to make test database: %v",
err) err)
} }
defer cleanUp() defer cleanUp()
cdb := fullDB.ChannelStateDB()
// Create options to set the heights and balances of // Create options to set the heights and balances of
// our local and remote commitments. // our local and remote commitments.
localCommitOpt := channelCommitmentOption( localCommitOpt := channelCommitmentOption(

View File

@@ -23,6 +23,7 @@ import (
"github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/clock"
"github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing/route"
) )
const ( const (
@@ -209,6 +210,11 @@ var (
// Big endian is the preferred byte order, due to cursor scans over // Big endian is the preferred byte order, due to cursor scans over
// integer keys iterating in order. // integer keys iterating in order.
byteOrder = binary.BigEndian byteOrder = binary.BigEndian
// channelOpeningStateBucket is the database bucket used to store the
// channelOpeningState for each channel that is currently in the process
// of being opened.
channelOpeningStateBucket = []byte("channelOpeningState")
) )
// DB is the primary datastore for the lnd daemon. The database stores // DB is the primary datastore for the lnd daemon. The database stores
@@ -217,6 +223,9 @@ var (
type DB struct { type DB struct {
kvdb.Backend kvdb.Backend
// channelStateDB separates all DB operations on channel state.
channelStateDB *ChannelStateDB
dbPath string dbPath string
graph *ChannelGraph graph *ChannelGraph
clock clock.Clock clock clock.Clock
@@ -265,13 +274,27 @@ func CreateWithBackend(backend kvdb.Backend, modifiers ...OptionModifier) (*DB,
chanDB := &DB{ chanDB := &DB{
Backend: backend, Backend: backend,
clock: opts.clock, channelStateDB: &ChannelStateDB{
dryRun: opts.dryRun, linkNodeDB: &LinkNodeDB{
backend: backend,
},
backend: backend,
},
clock: opts.clock,
dryRun: opts.dryRun,
} }
chanDB.graph = newChannelGraph(
chanDB, opts.RejectCacheSize, opts.ChannelCacheSize, // Set the parent pointer (only used in tests).
opts.BatchCommitInterval, chanDB.channelStateDB.parent = chanDB
var err error
chanDB.graph, err = NewChannelGraph(
backend, opts.RejectCacheSize, opts.ChannelCacheSize,
opts.BatchCommitInterval, opts.PreAllocCacheNumNodes,
) )
if err != nil {
return nil, err
}
// Synchronize the version of database and apply migrations if needed. // Synchronize the version of database and apply migrations if needed.
if err := chanDB.syncVersions(dbVersions); err != nil { if err := chanDB.syncVersions(dbVersions); err != nil {
@@ -287,7 +310,7 @@ func (d *DB) Path() string {
return d.dbPath return d.dbPath
} }
var topLevelBuckets = [][]byte{ var dbTopLevelBuckets = [][]byte{
openChannelBucket, openChannelBucket,
closedChannelBucket, closedChannelBucket,
forwardingLogBucket, forwardingLogBucket,
@@ -298,10 +321,6 @@ var topLevelBuckets = [][]byte{
paymentsIndexBucket, paymentsIndexBucket,
peersBucket, peersBucket,
nodeInfoBucket, nodeInfoBucket,
nodeBucket,
edgeBucket,
edgeIndexBucket,
graphMetaBucket,
metaBucket, metaBucket,
closeSummaryBucket, closeSummaryBucket,
outpointBucket, outpointBucket,
@@ -312,7 +331,7 @@ var topLevelBuckets = [][]byte{
// operation is fully atomic. // operation is fully atomic.
func (d *DB) Wipe() error { func (d *DB) Wipe() error {
err := kvdb.Update(d, func(tx kvdb.RwTx) error { err := kvdb.Update(d, func(tx kvdb.RwTx) error {
for _, tlb := range topLevelBuckets { for _, tlb := range dbTopLevelBuckets {
err := tx.DeleteTopLevelBucket(tlb) err := tx.DeleteTopLevelBucket(tlb)
if err != nil && err != kvdb.ErrBucketNotFound { if err != nil && err != kvdb.ErrBucketNotFound {
return err return err
@@ -327,10 +346,10 @@ func (d *DB) Wipe() error {
return initChannelDB(d.Backend) return initChannelDB(d.Backend)
} }
// createChannelDB creates and initializes a fresh version of channeldb. In // initChannelDB creates and initializes a fresh version of channeldb. In the
// the case that the target path has not yet been created or doesn't yet exist, // case that the target path has not yet been created or doesn't yet exist, then
// then the path is created. Additionally, all required top-level buckets used // the path is created. Additionally, all required top-level buckets used within
// within the database are created. // the database are created.
func initChannelDB(db kvdb.Backend) error { func initChannelDB(db kvdb.Backend) error {
err := kvdb.Update(db, func(tx kvdb.RwTx) error { err := kvdb.Update(db, func(tx kvdb.RwTx) error {
meta := &Meta{} meta := &Meta{}
@@ -340,42 +359,12 @@ func initChannelDB(db kvdb.Backend) error {
return nil return nil
} }
for _, tlb := range topLevelBuckets { for _, tlb := range dbTopLevelBuckets {
if _, err := tx.CreateTopLevelBucket(tlb); err != nil { if _, err := tx.CreateTopLevelBucket(tlb); err != nil {
return err return err
} }
} }
nodes := tx.ReadWriteBucket(nodeBucket)
_, err = nodes.CreateBucket(aliasIndexBucket)
if err != nil {
return err
}
_, err = nodes.CreateBucket(nodeUpdateIndexBucket)
if err != nil {
return err
}
edges := tx.ReadWriteBucket(edgeBucket)
if _, err := edges.CreateBucket(edgeIndexBucket); err != nil {
return err
}
if _, err := edges.CreateBucket(edgeUpdateIndexBucket); err != nil {
return err
}
if _, err := edges.CreateBucket(channelPointBucket); err != nil {
return err
}
if _, err := edges.CreateBucket(zombieBucket); err != nil {
return err
}
graphMeta := tx.ReadWriteBucket(graphMetaBucket)
_, err = graphMeta.CreateBucket(pruneLogBucket)
if err != nil {
return err
}
meta.DbVersionNumber = getLatestDBVersion(dbVersions) meta.DbVersionNumber = getLatestDBVersion(dbVersions)
return putMeta(meta, tx) return putMeta(meta, tx)
}, func() {}) }, func() {})
@@ -397,15 +386,45 @@ func fileExists(path string) bool {
return true return true
} }
// ChannelStateDB is a database that keeps track of all channel state.
type ChannelStateDB struct {
// linkNodeDB separates all DB operations on LinkNodes.
linkNodeDB *LinkNodeDB
// parent holds a pointer to the "main" channeldb.DB object. This is
// only used for testing and should never be used in production code.
// For testing use the ChannelStateDB.GetParentDB() function to retrieve
// this pointer.
parent *DB
// backend points to the actual backend holding the channel state
// database. This may be a real backend or a cache middleware.
backend kvdb.Backend
}
// GetParentDB returns the "main" channeldb.DB object that is the owner of this
// ChannelStateDB instance. Use this function only in tests where passing around
// pointers makes testing less readable. Never to be used in production code!
func (c *ChannelStateDB) GetParentDB() *DB {
return c.parent
}
// LinkNodeDB returns the current instance of the link node database.
func (c *ChannelStateDB) LinkNodeDB() *LinkNodeDB {
return c.linkNodeDB
}
// FetchOpenChannels starts a new database transaction and returns all stored // FetchOpenChannels starts a new database transaction and returns all stored
// currently active/open channels associated with the target nodeID. In the case // currently active/open channels associated with the target nodeID. In the case
// that no active channels are known to have been created with this node, then a // that no active channels are known to have been created with this node, then a
// zero-length slice is returned. // zero-length slice is returned.
func (d *DB) FetchOpenChannels(nodeID *btcec.PublicKey) ([]*OpenChannel, error) { func (c *ChannelStateDB) FetchOpenChannels(nodeID *btcec.PublicKey) (
[]*OpenChannel, error) {
var channels []*OpenChannel var channels []*OpenChannel
err := kvdb.View(d, func(tx kvdb.RTx) error { err := kvdb.View(c.backend, func(tx kvdb.RTx) error {
var err error var err error
channels, err = d.fetchOpenChannels(tx, nodeID) channels, err = c.fetchOpenChannels(tx, nodeID)
return err return err
}, func() { }, func() {
channels = nil channels = nil
@@ -418,7 +437,7 @@ func (d *DB) FetchOpenChannels(nodeID *btcec.PublicKey) ([]*OpenChannel, error)
// stored currently active/open channels associated with the target nodeID. In // stored currently active/open channels associated with the target nodeID. In
// the case that no active channels are known to have been created with this // the case that no active channels are known to have been created with this
// node, then a zero-length slice is returned. // node, then a zero-length slice is returned.
func (d *DB) fetchOpenChannels(tx kvdb.RTx, func (c *ChannelStateDB) fetchOpenChannels(tx kvdb.RTx,
nodeID *btcec.PublicKey) ([]*OpenChannel, error) { nodeID *btcec.PublicKey) ([]*OpenChannel, error) {
// Get the bucket dedicated to storing the metadata for open channels. // Get the bucket dedicated to storing the metadata for open channels.
@@ -454,7 +473,7 @@ func (d *DB) fetchOpenChannels(tx kvdb.RTx,
// Finally, we both of the necessary buckets retrieved, fetch // Finally, we both of the necessary buckets retrieved, fetch
// all the active channels related to this node. // all the active channels related to this node.
nodeChannels, err := d.fetchNodeChannels(chainBucket) nodeChannels, err := c.fetchNodeChannels(chainBucket)
if err != nil { if err != nil {
return fmt.Errorf("unable to read channel for "+ return fmt.Errorf("unable to read channel for "+
"chain_hash=%x, node_key=%x: %v", "chain_hash=%x, node_key=%x: %v",
@@ -471,7 +490,8 @@ func (d *DB) fetchOpenChannels(tx kvdb.RTx,
// fetchNodeChannels retrieves all active channels from the target chainBucket // fetchNodeChannels retrieves all active channels from the target chainBucket
// which is under a node's dedicated channel bucket. This function is typically // which is under a node's dedicated channel bucket. This function is typically
// used to fetch all the active channels related to a particular node. // used to fetch all the active channels related to a particular node.
func (d *DB) fetchNodeChannels(chainBucket kvdb.RBucket) ([]*OpenChannel, error) { func (c *ChannelStateDB) fetchNodeChannels(chainBucket kvdb.RBucket) (
[]*OpenChannel, error) {
var channels []*OpenChannel var channels []*OpenChannel
@@ -497,7 +517,7 @@ func (d *DB) fetchNodeChannels(chainBucket kvdb.RBucket) ([]*OpenChannel, error)
return fmt.Errorf("unable to read channel data for "+ return fmt.Errorf("unable to read channel data for "+
"chan_point=%v: %v", outPoint, err) "chan_point=%v: %v", outPoint, err)
} }
oChannel.Db = d oChannel.Db = c
channels = append(channels, oChannel) channels = append(channels, oChannel)
@@ -514,8 +534,8 @@ func (d *DB) fetchNodeChannels(chainBucket kvdb.RBucket) ([]*OpenChannel, error)
// point. If the channel cannot be found, then an error will be returned. // point. If the channel cannot be found, then an error will be returned.
// Optionally an existing db tx can be supplied. Optionally an existing db tx // Optionally an existing db tx can be supplied. Optionally an existing db tx
// can be supplied. // can be supplied.
func (d *DB) FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) (*OpenChannel, func (c *ChannelStateDB) FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) (
error) { *OpenChannel, error) {
var ( var (
targetChan *OpenChannel targetChan *OpenChannel
@@ -591,7 +611,7 @@ func (d *DB) FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) (*OpenChannel,
} }
targetChan = channel targetChan = channel
targetChan.Db = d targetChan.Db = c
return nil return nil
}) })
@@ -600,7 +620,7 @@ func (d *DB) FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) (*OpenChannel,
var err error var err error
if tx == nil { if tx == nil {
err = kvdb.View(d, chanScan, func() {}) err = kvdb.View(c.backend, chanScan, func() {})
} else { } else {
err = chanScan(tx) err = chanScan(tx)
} }
@@ -620,16 +640,16 @@ func (d *DB) FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) (*OpenChannel,
// FetchAllChannels attempts to retrieve all open channels currently stored // FetchAllChannels attempts to retrieve all open channels currently stored
// within the database, including pending open, fully open and channels waiting // within the database, including pending open, fully open and channels waiting
// for a closing transaction to confirm. // for a closing transaction to confirm.
func (d *DB) FetchAllChannels() ([]*OpenChannel, error) { func (c *ChannelStateDB) FetchAllChannels() ([]*OpenChannel, error) {
return fetchChannels(d) return fetchChannels(c)
} }
// FetchAllOpenChannels will return all channels that have the funding // FetchAllOpenChannels will return all channels that have the funding
// transaction confirmed, and is not waiting for a closing transaction to be // transaction confirmed, and is not waiting for a closing transaction to be
// confirmed. // confirmed.
func (d *DB) FetchAllOpenChannels() ([]*OpenChannel, error) { func (c *ChannelStateDB) FetchAllOpenChannels() ([]*OpenChannel, error) {
return fetchChannels( return fetchChannels(
d, c,
pendingChannelFilter(false), pendingChannelFilter(false),
waitingCloseFilter(false), waitingCloseFilter(false),
) )
@@ -638,8 +658,8 @@ func (d *DB) FetchAllOpenChannels() ([]*OpenChannel, error) {
// FetchPendingChannels will return channels that have completed the process of // FetchPendingChannels will return channels that have completed the process of
// generating and broadcasting funding transactions, but whose funding // generating and broadcasting funding transactions, but whose funding
// transactions have yet to be confirmed on the blockchain. // transactions have yet to be confirmed on the blockchain.
func (d *DB) FetchPendingChannels() ([]*OpenChannel, error) { func (c *ChannelStateDB) FetchPendingChannels() ([]*OpenChannel, error) {
return fetchChannels(d, return fetchChannels(c,
pendingChannelFilter(true), pendingChannelFilter(true),
waitingCloseFilter(false), waitingCloseFilter(false),
) )
@@ -649,9 +669,9 @@ func (d *DB) FetchPendingChannels() ([]*OpenChannel, error) {
// but are now waiting for a closing transaction to be confirmed. // but are now waiting for a closing transaction to be confirmed.
// //
// NOTE: This includes channels that are also pending to be opened. // NOTE: This includes channels that are also pending to be opened.
func (d *DB) FetchWaitingCloseChannels() ([]*OpenChannel, error) { func (c *ChannelStateDB) FetchWaitingCloseChannels() ([]*OpenChannel, error) {
return fetchChannels( return fetchChannels(
d, waitingCloseFilter(true), c, waitingCloseFilter(true),
) )
} }
@@ -692,10 +712,12 @@ func waitingCloseFilter(waitingClose bool) fetchChannelsFilter {
// which have a true value returned for *all* of the filters will be returned. // which have a true value returned for *all* of the filters will be returned.
// If no filters are provided, every channel in the open channels bucket will // If no filters are provided, every channel in the open channels bucket will
// be returned. // be returned.
func fetchChannels(d *DB, filters ...fetchChannelsFilter) ([]*OpenChannel, error) { func fetchChannels(c *ChannelStateDB, filters ...fetchChannelsFilter) (
[]*OpenChannel, error) {
var channels []*OpenChannel var channels []*OpenChannel
err := kvdb.View(d, func(tx kvdb.RTx) error { err := kvdb.View(c.backend, func(tx kvdb.RTx) error {
// Get the bucket dedicated to storing the metadata for open // Get the bucket dedicated to storing the metadata for open
// channels. // channels.
openChanBucket := tx.ReadBucket(openChannelBucket) openChanBucket := tx.ReadBucket(openChannelBucket)
@@ -737,7 +759,7 @@ func fetchChannels(d *DB, filters ...fetchChannelsFilter) ([]*OpenChannel, error
"bucket for chain=%x", chainHash[:]) "bucket for chain=%x", chainHash[:])
} }
nodeChans, err := d.fetchNodeChannels(chainBucket) nodeChans, err := c.fetchNodeChannels(chainBucket)
if err != nil { if err != nil {
return fmt.Errorf("unable to read "+ return fmt.Errorf("unable to read "+
"channel for chain_hash=%x, "+ "channel for chain_hash=%x, "+
@@ -786,10 +808,12 @@ func fetchChannels(d *DB, filters ...fetchChannelsFilter) ([]*OpenChannel, error
// it becomes fully closed after a single confirmation. When a channel was // it becomes fully closed after a single confirmation. When a channel was
// forcibly closed, it will become fully closed after _all_ the pending funds // forcibly closed, it will become fully closed after _all_ the pending funds
// (if any) have been swept. // (if any) have been swept.
func (d *DB) FetchClosedChannels(pendingOnly bool) ([]*ChannelCloseSummary, error) { func (c *ChannelStateDB) FetchClosedChannels(pendingOnly bool) (
[]*ChannelCloseSummary, error) {
var chanSummaries []*ChannelCloseSummary var chanSummaries []*ChannelCloseSummary
if err := kvdb.View(d, func(tx kvdb.RTx) error { if err := kvdb.View(c.backend, func(tx kvdb.RTx) error {
closeBucket := tx.ReadBucket(closedChannelBucket) closeBucket := tx.ReadBucket(closedChannelBucket)
if closeBucket == nil { if closeBucket == nil {
return ErrNoClosedChannels return ErrNoClosedChannels
@@ -827,9 +851,11 @@ var ErrClosedChannelNotFound = errors.New("unable to find closed channel summary
// FetchClosedChannel queries for a channel close summary using the channel // FetchClosedChannel queries for a channel close summary using the channel
// point of the channel in question. // point of the channel in question.
func (d *DB) FetchClosedChannel(chanID *wire.OutPoint) (*ChannelCloseSummary, error) { func (c *ChannelStateDB) FetchClosedChannel(chanID *wire.OutPoint) (
*ChannelCloseSummary, error) {
var chanSummary *ChannelCloseSummary var chanSummary *ChannelCloseSummary
if err := kvdb.View(d, func(tx kvdb.RTx) error { if err := kvdb.View(c.backend, func(tx kvdb.RTx) error {
closeBucket := tx.ReadBucket(closedChannelBucket) closeBucket := tx.ReadBucket(closedChannelBucket)
if closeBucket == nil { if closeBucket == nil {
return ErrClosedChannelNotFound return ErrClosedChannelNotFound
@@ -861,11 +887,11 @@ func (d *DB) FetchClosedChannel(chanID *wire.OutPoint) (*ChannelCloseSummary, er
// FetchClosedChannelForID queries for a channel close summary using the // FetchClosedChannelForID queries for a channel close summary using the
// channel ID of the channel in question. // channel ID of the channel in question.
func (d *DB) FetchClosedChannelForID(cid lnwire.ChannelID) ( func (c *ChannelStateDB) FetchClosedChannelForID(cid lnwire.ChannelID) (
*ChannelCloseSummary, error) { *ChannelCloseSummary, error) {
var chanSummary *ChannelCloseSummary var chanSummary *ChannelCloseSummary
if err := kvdb.View(d, func(tx kvdb.RTx) error { if err := kvdb.View(c.backend, func(tx kvdb.RTx) error {
closeBucket := tx.ReadBucket(closedChannelBucket) closeBucket := tx.ReadBucket(closedChannelBucket)
if closeBucket == nil { if closeBucket == nil {
return ErrClosedChannelNotFound return ErrClosedChannelNotFound
@@ -914,8 +940,12 @@ func (d *DB) FetchClosedChannelForID(cid lnwire.ChannelID) (
// cooperatively closed and it's reached a single confirmation, or after all // cooperatively closed and it's reached a single confirmation, or after all
// the pending funds in a channel that has been forcibly closed have been // the pending funds in a channel that has been forcibly closed have been
// swept. // swept.
func (d *DB) MarkChanFullyClosed(chanPoint *wire.OutPoint) error { func (c *ChannelStateDB) MarkChanFullyClosed(chanPoint *wire.OutPoint) error {
return kvdb.Update(d, func(tx kvdb.RwTx) error { var (
openChannels []*OpenChannel
pruneLinkNode *btcec.PublicKey
)
err := kvdb.Update(c.backend, func(tx kvdb.RwTx) error {
var b bytes.Buffer var b bytes.Buffer
if err := writeOutpoint(&b, chanPoint); err != nil { if err := writeOutpoint(&b, chanPoint); err != nil {
return err return err
@@ -961,19 +991,35 @@ func (d *DB) MarkChanFullyClosed(chanPoint *wire.OutPoint) error {
// other open channels with this peer. If we don't we'll // other open channels with this peer. If we don't we'll
// garbage collect it to ensure we don't establish persistent // garbage collect it to ensure we don't establish persistent
// connections to peers without open channels. // connections to peers without open channels.
return d.pruneLinkNode(tx, chanSummary.RemotePub) pruneLinkNode = chanSummary.RemotePub
}, func() {}) openChannels, err = c.fetchOpenChannels(
tx, pruneLinkNode,
)
if err != nil {
return fmt.Errorf("unable to fetch open channels for "+
"peer %x: %v",
pruneLinkNode.SerializeCompressed(), err)
}
return nil
}, func() {
openChannels = nil
pruneLinkNode = nil
})
if err != nil {
return err
}
// Decide whether we want to remove the link node, based upon the number
// of still open channels.
return c.pruneLinkNode(openChannels, pruneLinkNode)
} }
// pruneLinkNode determines whether we should garbage collect a link node from // pruneLinkNode determines whether we should garbage collect a link node from
// the database due to no longer having any open channels with it. If there are // the database due to no longer having any open channels with it. If there are
// any left, then this acts as a no-op. // any left, then this acts as a no-op.
func (d *DB) pruneLinkNode(tx kvdb.RwTx, remotePub *btcec.PublicKey) error { func (c *ChannelStateDB) pruneLinkNode(openChannels []*OpenChannel,
openChannels, err := d.fetchOpenChannels(tx, remotePub) remotePub *btcec.PublicKey) error {
if err != nil {
return fmt.Errorf("unable to fetch open channels for peer %x: "+
"%v", remotePub.SerializeCompressed(), err)
}
if len(openChannels) > 0 { if len(openChannels) > 0 {
return nil return nil
@@ -982,27 +1028,42 @@ func (d *DB) pruneLinkNode(tx kvdb.RwTx, remotePub *btcec.PublicKey) error {
log.Infof("Pruning link node %x with zero open channels from database", log.Infof("Pruning link node %x with zero open channels from database",
remotePub.SerializeCompressed()) remotePub.SerializeCompressed())
return d.deleteLinkNode(tx, remotePub) return c.linkNodeDB.DeleteLinkNode(remotePub)
} }
// PruneLinkNodes attempts to prune all link nodes found within the databse with // PruneLinkNodes attempts to prune all link nodes found within the databse with
// whom we no longer have any open channels with. // whom we no longer have any open channels with.
func (d *DB) PruneLinkNodes() error { func (c *ChannelStateDB) PruneLinkNodes() error {
return kvdb.Update(d, func(tx kvdb.RwTx) error { allLinkNodes, err := c.linkNodeDB.FetchAllLinkNodes()
linkNodes, err := d.fetchAllLinkNodes(tx) if err != nil {
return err
}
for _, linkNode := range allLinkNodes {
var (
openChannels []*OpenChannel
linkNode = linkNode
)
err := kvdb.View(c.backend, func(tx kvdb.RTx) error {
var err error
openChannels, err = c.fetchOpenChannels(
tx, linkNode.IdentityPub,
)
return err
}, func() {
openChannels = nil
})
if err != nil { if err != nil {
return err return err
} }
for _, linkNode := range linkNodes { err = c.pruneLinkNode(openChannels, linkNode.IdentityPub)
err := d.pruneLinkNode(tx, linkNode.IdentityPub) if err != nil {
if err != nil { return err
return err
}
} }
}
return nil return nil
}, func() {})
} }
// ChannelShell is a shell of a channel that is meant to be used for channel // ChannelShell is a shell of a channel that is meant to be used for channel
@@ -1024,8 +1085,8 @@ type ChannelShell struct {
// addresses, and finally create an edge within the graph for the channel as // addresses, and finally create an edge within the graph for the channel as
// well. This method is idempotent, so repeated calls with the same set of // well. This method is idempotent, so repeated calls with the same set of
// channel shells won't modify the database after the initial call. // channel shells won't modify the database after the initial call.
func (d *DB) RestoreChannelShells(channelShells ...*ChannelShell) error { func (c *ChannelStateDB) RestoreChannelShells(channelShells ...*ChannelShell) error {
err := kvdb.Update(d, func(tx kvdb.RwTx) error { err := kvdb.Update(c.backend, func(tx kvdb.RwTx) error {
for _, channelShell := range channelShells { for _, channelShell := range channelShells {
channel := channelShell.Chan channel := channelShell.Chan
@@ -1039,7 +1100,7 @@ func (d *DB) RestoreChannelShells(channelShells ...*ChannelShell) error {
// and link node for this channel. If the channel // and link node for this channel. If the channel
// already exists, then in order to ensure this method // already exists, then in order to ensure this method
// is idempotent, we'll continue to the next step. // is idempotent, we'll continue to the next step.
channel.Db = d channel.Db = c
err := syncNewChannel( err := syncNewChannel(
tx, channel, channelShell.NodeAddrs, tx, channel, channelShell.NodeAddrs,
) )
@@ -1059,41 +1120,28 @@ func (d *DB) RestoreChannelShells(channelShells ...*ChannelShell) error {
// AddrsForNode consults the graph and channel database for all addresses known // AddrsForNode consults the graph and channel database for all addresses known
// to the passed node public key. // to the passed node public key.
func (d *DB) AddrsForNode(nodePub *btcec.PublicKey) ([]net.Addr, error) { func (d *DB) AddrsForNode(nodePub *btcec.PublicKey) ([]net.Addr,
var ( error) {
linkNode *LinkNode
graphNode LightningNode
)
dbErr := kvdb.View(d, func(tx kvdb.RTx) error { linkNode, err := d.channelStateDB.linkNodeDB.FetchLinkNode(nodePub)
var err error if err != nil {
return nil, err
}
linkNode, err = fetchLinkNode(tx, nodePub) // We'll also query the graph for this peer to see if they have any
if err != nil { // addresses that we don't currently have stored within the link node
return err // database.
} pubKey, err := route.NewVertexFromBytes(nodePub.SerializeCompressed())
if err != nil {
// We'll also query the graph for this peer to see if they have return nil, err
// any addresses that we don't currently have stored within the }
// link node database. graphNode, err := d.graph.FetchLightningNode(pubKey)
nodes := tx.ReadBucket(nodeBucket) if err != nil && err != ErrGraphNodeNotFound {
if nodes == nil { return nil, err
return ErrGraphNotFound } else if err == ErrGraphNodeNotFound {
} // If the node isn't found, then that's OK, as we still have the
compressedPubKey := nodePub.SerializeCompressed() // link node data. But any other error needs to be returned.
graphNode, err = fetchLightningNode(nodes, compressedPubKey) graphNode = &LightningNode{}
if err != nil && err != ErrGraphNodeNotFound {
// If the node isn't found, then that's OK, as we still
// have the link node data.
return err
}
return nil
}, func() {
linkNode = nil
})
if dbErr != nil {
return nil, dbErr
} }
// Now that we have both sources of addrs for this node, we'll use a // Now that we have both sources of addrs for this node, we'll use a
@@ -1118,16 +1166,18 @@ func (d *DB) AddrsForNode(nodePub *btcec.PublicKey) ([]net.Addr, error) {
// database. If the channel was already removed (has a closed channel entry), // database. If the channel was already removed (has a closed channel entry),
// then we'll return a nil error. Otherwise, we'll insert a new close summary // then we'll return a nil error. Otherwise, we'll insert a new close summary
// into the database. // into the database.
func (d *DB) AbandonChannel(chanPoint *wire.OutPoint, bestHeight uint32) error { func (c *ChannelStateDB) AbandonChannel(chanPoint *wire.OutPoint,
bestHeight uint32) error {
// With the chanPoint constructed, we'll attempt to find the target // With the chanPoint constructed, we'll attempt to find the target
// channel in the database. If we can't find the channel, then we'll // channel in the database. If we can't find the channel, then we'll
// return the error back to the caller. // return the error back to the caller.
dbChan, err := d.FetchChannel(nil, *chanPoint) dbChan, err := c.FetchChannel(nil, *chanPoint)
switch { switch {
// If the channel wasn't found, then it's possible that it was already // If the channel wasn't found, then it's possible that it was already
// abandoned from the database. // abandoned from the database.
case err == ErrChannelNotFound: case err == ErrChannelNotFound:
_, closedErr := d.FetchClosedChannel(chanPoint) _, closedErr := c.FetchClosedChannel(chanPoint)
if closedErr != nil { if closedErr != nil {
return closedErr return closedErr
} }
@@ -1163,6 +1213,58 @@ func (d *DB) AbandonChannel(chanPoint *wire.OutPoint, bestHeight uint32) error {
return dbChan.CloseChannel(summary, ChanStatusLocalCloseInitiator) return dbChan.CloseChannel(summary, ChanStatusLocalCloseInitiator)
} }
// SaveChannelOpeningState saves the serialized channel state for the provided
// chanPoint to the channelOpeningStateBucket.
func (c *ChannelStateDB) SaveChannelOpeningState(outPoint,
serializedState []byte) error {
return kvdb.Update(c.backend, func(tx kvdb.RwTx) error {
bucket, err := tx.CreateTopLevelBucket(channelOpeningStateBucket)
if err != nil {
return err
}
return bucket.Put(outPoint, serializedState)
}, func() {})
}
// GetChannelOpeningState fetches the serialized channel state for the provided
// outPoint from the database, or returns ErrChannelNotFound if the channel
// is not found.
func (c *ChannelStateDB) GetChannelOpeningState(outPoint []byte) ([]byte, error) {
var serializedState []byte
err := kvdb.View(c.backend, func(tx kvdb.RTx) error {
bucket := tx.ReadBucket(channelOpeningStateBucket)
if bucket == nil {
// If the bucket does not exist, it means we never added
// a channel to the db, so return ErrChannelNotFound.
return ErrChannelNotFound
}
serializedState = bucket.Get(outPoint)
if serializedState == nil {
return ErrChannelNotFound
}
return nil
}, func() {
serializedState = nil
})
return serializedState, err
}
// DeleteChannelOpeningState removes any state for outPoint from the database.
func (c *ChannelStateDB) DeleteChannelOpeningState(outPoint []byte) error {
return kvdb.Update(c.backend, func(tx kvdb.RwTx) error {
bucket := tx.ReadWriteBucket(channelOpeningStateBucket)
if bucket == nil {
return ErrChannelNotFound
}
return bucket.Delete(outPoint)
}, func() {})
}
// syncVersions function is used for safe db version synchronization. It // syncVersions function is used for safe db version synchronization. It
// applies migration functions to the current database and recovers the // applies migration functions to the current database and recovers the
// previous state of db if at least one error/panic appeared during migration. // previous state of db if at least one error/panic appeared during migration.
@@ -1236,11 +1338,17 @@ func (d *DB) syncVersions(versions []version) error {
}, func() {}) }, func() {})
} }
// ChannelGraph returns a new instance of the directed channel graph. // ChannelGraph returns the current instance of the directed channel graph.
func (d *DB) ChannelGraph() *ChannelGraph { func (d *DB) ChannelGraph() *ChannelGraph {
return d.graph return d.graph
} }
// ChannelStateDB returns the sub database that is concerned with the channel
// state.
func (d *DB) ChannelStateDB() *ChannelStateDB {
return d.channelStateDB
}
func getLatestDBVersion(versions []version) uint32 { func getLatestDBVersion(versions []version) uint32 {
return versions[len(versions)-1].number return versions[len(versions)-1].number
} }
@@ -1290,9 +1398,11 @@ func fetchHistoricalChanBucket(tx kvdb.RTx,
// FetchHistoricalChannel fetches open channel data from the historical channel // FetchHistoricalChannel fetches open channel data from the historical channel
// bucket. // bucket.
func (d *DB) FetchHistoricalChannel(outPoint *wire.OutPoint) (*OpenChannel, error) { func (c *ChannelStateDB) FetchHistoricalChannel(outPoint *wire.OutPoint) (
*OpenChannel, error) {
var channel *OpenChannel var channel *OpenChannel
err := kvdb.View(d, func(tx kvdb.RTx) error { err := kvdb.View(c.backend, func(tx kvdb.RTx) error {
chanBucket, err := fetchHistoricalChanBucket(tx, outPoint) chanBucket, err := fetchHistoricalChanBucket(tx, outPoint)
if err != nil { if err != nil {
return err return err
@@ -1300,7 +1410,7 @@ func (d *DB) FetchHistoricalChannel(outPoint *wire.OutPoint) (*OpenChannel, erro
channel, err = fetchOpenChannel(chanBucket, outPoint) channel, err = fetchOpenChannel(chanBucket, outPoint)
channel.Db = d channel.Db = c
return err return err
}, func() { }, func() {
channel = nil channel = nil

View File

@@ -87,15 +87,18 @@ func TestWipe(t *testing.T) {
} }
defer cleanup() defer cleanup()
cdb, err := CreateWithBackend(backend) fullDB, err := CreateWithBackend(backend)
if err != nil { if err != nil {
t.Fatalf("unable to create channeldb: %v", err) t.Fatalf("unable to create channeldb: %v", err)
} }
defer cdb.Close() defer fullDB.Close()
if err := cdb.Wipe(); err != nil { if err := fullDB.Wipe(); err != nil {
t.Fatalf("unable to wipe channeldb: %v", err) t.Fatalf("unable to wipe channeldb: %v", err)
} }
cdb := fullDB.ChannelStateDB()
// Check correct errors are returned // Check correct errors are returned
openChannels, err := cdb.FetchAllOpenChannels() openChannels, err := cdb.FetchAllOpenChannels()
require.NoError(t, err, "fetching open channels") require.NoError(t, err, "fetching open channels")
@@ -113,12 +116,14 @@ func TestFetchClosedChannelForID(t *testing.T) {
const numChans = 101 const numChans = 101
cdb, cleanUp, err := MakeTestDB() fullDB, cleanUp, err := MakeTestDB()
if err != nil { if err != nil {
t.Fatalf("unable to make test database: %v", err) t.Fatalf("unable to make test database: %v", err)
} }
defer cleanUp() defer cleanUp()
cdb := fullDB.ChannelStateDB()
// Create the test channel state, that we will mutate the index of the // Create the test channel state, that we will mutate the index of the
// funding point. // funding point.
state := createTestChannelState(t, cdb) state := createTestChannelState(t, cdb)
@@ -184,18 +189,18 @@ func TestFetchClosedChannelForID(t *testing.T) {
func TestAddrsForNode(t *testing.T) { func TestAddrsForNode(t *testing.T) {
t.Parallel() t.Parallel()
cdb, cleanUp, err := MakeTestDB() fullDB, cleanUp, err := MakeTestDB()
if err != nil { if err != nil {
t.Fatalf("unable to make test database: %v", err) t.Fatalf("unable to make test database: %v", err)
} }
defer cleanUp() defer cleanUp()
graph := cdb.ChannelGraph() graph := fullDB.ChannelGraph()
// We'll make a test vertex to insert into the database, as the source // We'll make a test vertex to insert into the database, as the source
// node, but this node will only have half the number of addresses it // node, but this node will only have half the number of addresses it
// usually does. // usually does.
testNode, err := createTestVertex(cdb) testNode, err := createTestVertex(fullDB)
if err != nil { if err != nil {
t.Fatalf("unable to create test node: %v", err) t.Fatalf("unable to create test node: %v", err)
} }
@@ -210,8 +215,9 @@ func TestAddrsForNode(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("unable to recv node pub: %v", err) t.Fatalf("unable to recv node pub: %v", err)
} }
linkNode := cdb.NewLinkNode( linkNode := NewLinkNode(
wire.MainNet, nodePub, anotherAddr, fullDB.channelStateDB.linkNodeDB, wire.MainNet, nodePub,
anotherAddr,
) )
if err := linkNode.Sync(); err != nil { if err := linkNode.Sync(); err != nil {
t.Fatalf("unable to sync link node: %v", err) t.Fatalf("unable to sync link node: %v", err)
@@ -219,7 +225,7 @@ func TestAddrsForNode(t *testing.T) {
// Now that we've created a link node, as well as a vertex for the // Now that we've created a link node, as well as a vertex for the
// node, we'll query for all its addresses. // node, we'll query for all its addresses.
nodeAddrs, err := cdb.AddrsForNode(nodePub) nodeAddrs, err := fullDB.AddrsForNode(nodePub)
if err != nil { if err != nil {
t.Fatalf("unable to obtain node addrs: %v", err) t.Fatalf("unable to obtain node addrs: %v", err)
} }
@@ -245,12 +251,14 @@ func TestAddrsForNode(t *testing.T) {
func TestFetchChannel(t *testing.T) { func TestFetchChannel(t *testing.T) {
t.Parallel() t.Parallel()
cdb, cleanUp, err := MakeTestDB() fullDB, cleanUp, err := MakeTestDB()
if err != nil { if err != nil {
t.Fatalf("unable to make test database: %v", err) t.Fatalf("unable to make test database: %v", err)
} }
defer cleanUp() defer cleanUp()
cdb := fullDB.ChannelStateDB()
// Create an open channel. // Create an open channel.
channelState := createTestChannel(t, cdb, openChannelOption()) channelState := createTestChannel(t, cdb, openChannelOption())
@@ -349,12 +357,14 @@ func genRandomChannelShell() (*ChannelShell, error) {
func TestRestoreChannelShells(t *testing.T) { func TestRestoreChannelShells(t *testing.T) {
t.Parallel() t.Parallel()
cdb, cleanUp, err := MakeTestDB() fullDB, cleanUp, err := MakeTestDB()
if err != nil { if err != nil {
t.Fatalf("unable to make test database: %v", err) t.Fatalf("unable to make test database: %v", err)
} }
defer cleanUp() defer cleanUp()
cdb := fullDB.ChannelStateDB()
// First, we'll make our channel shell, it will only have the minimal // First, we'll make our channel shell, it will only have the minimal
// amount of information required for us to initiate the data loss // amount of information required for us to initiate the data loss
// protection feature. // protection feature.
@@ -423,7 +433,9 @@ func TestRestoreChannelShells(t *testing.T) {
// We should also be able to find the link node that was inserted by // We should also be able to find the link node that was inserted by
// its public key. // its public key.
linkNode, err := cdb.FetchLinkNode(channelShell.Chan.IdentityPub) linkNode, err := fullDB.channelStateDB.linkNodeDB.FetchLinkNode(
channelShell.Chan.IdentityPub,
)
if err != nil { if err != nil {
t.Fatalf("unable to fetch link node: %v", err) t.Fatalf("unable to fetch link node: %v", err)
} }
@@ -443,12 +455,14 @@ func TestRestoreChannelShells(t *testing.T) {
func TestAbandonChannel(t *testing.T) { func TestAbandonChannel(t *testing.T) {
t.Parallel() t.Parallel()
cdb, cleanUp, err := MakeTestDB() fullDB, cleanUp, err := MakeTestDB()
if err != nil { if err != nil {
t.Fatalf("unable to make test database: %v", err) t.Fatalf("unable to make test database: %v", err)
} }
defer cleanUp() defer cleanUp()
cdb := fullDB.ChannelStateDB()
// If we attempt to abandon the state of a channel that doesn't exist // If we attempt to abandon the state of a channel that doesn't exist
// in the open or closed channel bucket, then we should receive an // in the open or closed channel bucket, then we should receive an
// error. // error.
@@ -616,13 +630,15 @@ func TestFetchChannels(t *testing.T) {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
t.Parallel() t.Parallel()
cdb, cleanUp, err := MakeTestDB() fullDB, cleanUp, err := MakeTestDB()
if err != nil { if err != nil {
t.Fatalf("unable to make test "+ t.Fatalf("unable to make test "+
"database: %v", err) "database: %v", err)
} }
defer cleanUp() defer cleanUp()
cdb := fullDB.ChannelStateDB()
// Create a pending channel that is not awaiting close. // Create a pending channel that is not awaiting close.
createTestChannel( createTestChannel(
t, cdb, channelIDOption(pendingChan), t, cdb, channelIDOption(pendingChan),
@@ -685,12 +701,14 @@ func TestFetchChannels(t *testing.T) {
// TestFetchHistoricalChannel tests lookup of historical channels. // TestFetchHistoricalChannel tests lookup of historical channels.
func TestFetchHistoricalChannel(t *testing.T) { func TestFetchHistoricalChannel(t *testing.T) {
cdb, cleanUp, err := MakeTestDB() fullDB, cleanUp, err := MakeTestDB()
if err != nil { if err != nil {
t.Fatalf("unable to make test database: %v", err) t.Fatalf("unable to make test database: %v", err)
} }
defer cleanUp() defer cleanUp()
cdb := fullDB.ChannelStateDB()
// Create a an open channel in the database. // Create a an open channel in the database.
channel := createTestChannel(t, cdb, openChannelOption()) channel := createTestChannel(t, cdb, openChannelOption())

View File

@@ -174,39 +174,132 @@ const (
// independently. Edge removal results in the deletion of all edge information // independently. Edge removal results in the deletion of all edge information
// for that edge. // for that edge.
type ChannelGraph struct { type ChannelGraph struct {
db *DB db kvdb.Backend
cacheMu sync.RWMutex cacheMu sync.RWMutex
rejectCache *rejectCache rejectCache *rejectCache
chanCache *channelCache chanCache *channelCache
graphCache *GraphCache
chanScheduler batch.Scheduler chanScheduler batch.Scheduler
nodeScheduler batch.Scheduler nodeScheduler batch.Scheduler
} }
// newChannelGraph allocates a new ChannelGraph backed by a DB instance. The // NewChannelGraph allocates a new ChannelGraph backed by a DB instance. The
// returned instance has its own unique reject cache and channel cache. // returned instance has its own unique reject cache and channel cache.
func newChannelGraph(db *DB, rejectCacheSize, chanCacheSize int, func NewChannelGraph(db kvdb.Backend, rejectCacheSize, chanCacheSize int,
batchCommitInterval time.Duration) *ChannelGraph { batchCommitInterval time.Duration,
preAllocCacheNumNodes int) (*ChannelGraph, error) {
if err := initChannelGraph(db); err != nil {
return nil, err
}
g := &ChannelGraph{ g := &ChannelGraph{
db: db, db: db,
rejectCache: newRejectCache(rejectCacheSize), rejectCache: newRejectCache(rejectCacheSize),
chanCache: newChannelCache(chanCacheSize), chanCache: newChannelCache(chanCacheSize),
graphCache: NewGraphCache(preAllocCacheNumNodes),
} }
g.chanScheduler = batch.NewTimeScheduler( g.chanScheduler = batch.NewTimeScheduler(
db.Backend, &g.cacheMu, batchCommitInterval, db, &g.cacheMu, batchCommitInterval,
) )
g.nodeScheduler = batch.NewTimeScheduler( g.nodeScheduler = batch.NewTimeScheduler(
db.Backend, nil, batchCommitInterval, db, nil, batchCommitInterval,
) )
return g startTime := time.Now()
log.Debugf("Populating in-memory channel graph, this might take a " +
"while...")
err := g.ForEachNodeCacheable(func(tx kvdb.RTx, node GraphCacheNode) error {
return g.graphCache.AddNode(tx, node)
})
if err != nil {
return nil, err
}
log.Debugf("Finished populating in-memory channel graph (took %v, %s)",
time.Since(startTime), g.graphCache.Stats())
return g, nil
} }
// Database returns a pointer to the underlying database. var graphTopLevelBuckets = [][]byte{
func (c *ChannelGraph) Database() *DB { nodeBucket,
return c.db edgeBucket,
edgeIndexBucket,
graphMetaBucket,
}
// Wipe completely deletes all saved state within all used buckets within the
// database. The deletion is done in a single transaction, therefore this
// operation is fully atomic.
func (c *ChannelGraph) Wipe() error {
err := kvdb.Update(c.db, func(tx kvdb.RwTx) error {
for _, tlb := range graphTopLevelBuckets {
err := tx.DeleteTopLevelBucket(tlb)
if err != nil && err != kvdb.ErrBucketNotFound {
return err
}
}
return nil
}, func() {})
if err != nil {
return err
}
return initChannelGraph(c.db)
}
// createChannelDB creates and initializes a fresh version of channeldb. In
// the case that the target path has not yet been created or doesn't yet exist,
// then the path is created. Additionally, all required top-level buckets used
// within the database are created.
func initChannelGraph(db kvdb.Backend) error {
err := kvdb.Update(db, func(tx kvdb.RwTx) error {
for _, tlb := range graphTopLevelBuckets {
if _, err := tx.CreateTopLevelBucket(tlb); err != nil {
return err
}
}
nodes := tx.ReadWriteBucket(nodeBucket)
_, err := nodes.CreateBucketIfNotExists(aliasIndexBucket)
if err != nil {
return err
}
_, err = nodes.CreateBucketIfNotExists(nodeUpdateIndexBucket)
if err != nil {
return err
}
edges := tx.ReadWriteBucket(edgeBucket)
_, err = edges.CreateBucketIfNotExists(edgeIndexBucket)
if err != nil {
return err
}
_, err = edges.CreateBucketIfNotExists(edgeUpdateIndexBucket)
if err != nil {
return err
}
_, err = edges.CreateBucketIfNotExists(channelPointBucket)
if err != nil {
return err
}
_, err = edges.CreateBucketIfNotExists(zombieBucket)
if err != nil {
return err
}
graphMeta := tx.ReadWriteBucket(graphMetaBucket)
_, err = graphMeta.CreateBucketIfNotExists(pruneLogBucket)
return err
}, func() {})
if err != nil {
return fmt.Errorf("unable to create new channel graph: %v", err)
}
return nil
} }
// ForEachChannel iterates through all the channel edges stored within the // ForEachChannel iterates through all the channel edges stored within the
@@ -218,7 +311,9 @@ func (c *ChannelGraph) Database() *DB {
// NOTE: If an edge can't be found, or wasn't advertised, then a nil pointer // NOTE: If an edge can't be found, or wasn't advertised, then a nil pointer
// for that particular channel edge routing policy will be passed into the // for that particular channel edge routing policy will be passed into the
// callback. // callback.
func (c *ChannelGraph) ForEachChannel(cb func(*ChannelEdgeInfo, *ChannelEdgePolicy, *ChannelEdgePolicy) error) error { func (c *ChannelGraph) ForEachChannel(cb func(*ChannelEdgeInfo,
*ChannelEdgePolicy, *ChannelEdgePolicy) error) error {
// TODO(roasbeef): ptr map to reduce # of allocs? no duplicates // TODO(roasbeef): ptr map to reduce # of allocs? no duplicates
return kvdb.View(c.db, func(tx kvdb.RTx) error { return kvdb.View(c.db, func(tx kvdb.RTx) error {
@@ -270,23 +365,22 @@ func (c *ChannelGraph) ForEachChannel(cb func(*ChannelEdgeInfo, *ChannelEdgePoli
// ForEachNodeChannel iterates through all channels of a given node, executing the // ForEachNodeChannel iterates through all channels of a given node, executing the
// passed callback with an edge info structure and the policies of each end // passed callback with an edge info structure and the policies of each end
// of the channel. The first edge policy is the outgoing edge *to* the // of the channel. The first edge policy is the outgoing edge *to* the
// the connecting node, while the second is the incoming edge *from* the // connecting node, while the second is the incoming edge *from* the
// connecting node. If the callback returns an error, then the iteration is // connecting node. If the callback returns an error, then the iteration is
// halted with the error propagated back up to the caller. // halted with the error propagated back up to the caller.
// //
// Unknown policies are passed into the callback as nil values. // Unknown policies are passed into the callback as nil values.
// func (c *ChannelGraph) ForEachNodeChannel(node route.Vertex,
// If the caller wishes to re-use an existing boltdb transaction, then it cb func(channel *DirectedChannel) error) error {
// should be passed as the first argument. Otherwise the first argument should
// be nil and a fresh transaction will be created to execute the graph
// traversal.
func (c *ChannelGraph) ForEachNodeChannel(tx kvdb.RTx, nodePub []byte,
cb func(kvdb.RTx, *ChannelEdgeInfo, *ChannelEdgePolicy,
*ChannelEdgePolicy) error) error {
db := c.db return c.graphCache.ForEachChannel(node, cb)
}
return nodeTraversal(tx, nodePub, db, cb) // FetchNodeFeatures returns the features of a given node.
func (c *ChannelGraph) FetchNodeFeatures(
node route.Vertex) (*lnwire.FeatureVector, error) {
return c.graphCache.GetFeatures(node), nil
} }
// DisabledChannelIDs returns the channel ids of disabled channels. // DisabledChannelIDs returns the channel ids of disabled channels.
@@ -374,6 +468,47 @@ func (c *ChannelGraph) ForEachNode(cb func(kvdb.RTx, *LightningNode) error) erro
return kvdb.View(c.db, traversal, func() {}) return kvdb.View(c.db, traversal, func() {})
} }
// ForEachNodeCacheable iterates through all the stored vertices/nodes in the
// graph, executing the passed callback with each node encountered. If the
// callback returns an error, then the transaction is aborted and the iteration
// stops early.
func (c *ChannelGraph) ForEachNodeCacheable(cb func(kvdb.RTx,
GraphCacheNode) error) error {
traversal := func(tx kvdb.RTx) error {
// First grab the nodes bucket which stores the mapping from
// pubKey to node information.
nodes := tx.ReadBucket(nodeBucket)
if nodes == nil {
return ErrGraphNotFound
}
cacheableNode := newGraphCacheNode(route.Vertex{}, nil)
return nodes.ForEach(func(pubKey, nodeBytes []byte) error {
// If this is the source key, then we skip this
// iteration as the value for this key is a pubKey
// rather than raw node information.
if bytes.Equal(pubKey, sourceKey) || len(pubKey) != 33 {
return nil
}
nodeReader := bytes.NewReader(nodeBytes)
err := deserializeLightningNodeCacheable(
nodeReader, cacheableNode,
)
if err != nil {
return err
}
// Execute the callback, the transaction will abort if
// this returns an error.
return cb(tx, cacheableNode)
})
}
return kvdb.View(c.db, traversal, func() {})
}
// SourceNode returns the source node of the graph. The source node is treated // SourceNode returns the source node of the graph. The source node is treated
// as the center node within a star-graph. This method may be used to kick off // as the center node within a star-graph. This method may be used to kick off
// a path finding algorithm in order to explore the reachability of another // a path finding algorithm in order to explore the reachability of another
@@ -465,6 +600,13 @@ func (c *ChannelGraph) AddLightningNode(node *LightningNode,
r := &batch.Request{ r := &batch.Request{
Update: func(tx kvdb.RwTx) error { Update: func(tx kvdb.RwTx) error {
cNode := newGraphCacheNode(
node.PubKeyBytes, node.Features,
)
if err := c.graphCache.AddNode(tx, cNode); err != nil {
return err
}
return addLightningNode(tx, node) return addLightningNode(tx, node)
}, },
} }
@@ -543,6 +685,8 @@ func (c *ChannelGraph) DeleteLightningNode(nodePub route.Vertex) error {
return ErrGraphNodeNotFound return ErrGraphNodeNotFound
} }
c.graphCache.RemoveNode(nodePub)
return c.deleteLightningNode(nodes, nodePub[:]) return c.deleteLightningNode(nodes, nodePub[:])
}, func() {}) }, func() {})
} }
@@ -669,6 +813,8 @@ func (c *ChannelGraph) addChannelEdge(tx kvdb.RwTx, edge *ChannelEdgeInfo) error
return ErrEdgeAlreadyExist return ErrEdgeAlreadyExist
} }
c.graphCache.AddChannel(edge, nil, nil)
// Before we insert the channel into the database, we'll ensure that // Before we insert the channel into the database, we'll ensure that
// both nodes already exist in the channel graph. If either node // both nodes already exist in the channel graph. If either node
// doesn't, then we'll insert a "shell" node that just includes its // doesn't, then we'll insert a "shell" node that just includes its
@@ -868,6 +1014,8 @@ func (c *ChannelGraph) UpdateChannelEdge(edge *ChannelEdgeInfo) error {
return ErrEdgeNotFound return ErrEdgeNotFound
} }
c.graphCache.UpdateChannel(edge)
return putChanEdgeInfo(edgeIndex, edge, chanKey) return putChanEdgeInfo(edgeIndex, edge, chanKey)
}, func() {}) }, func() {})
} }
@@ -953,7 +1101,7 @@ func (c *ChannelGraph) PruneGraph(spentOutputs []*wire.OutPoint,
// will be returned if that outpoint isn't known to be // will be returned if that outpoint isn't known to be
// a channel. If no error is returned, then a channel // a channel. If no error is returned, then a channel
// was successfully pruned. // was successfully pruned.
err = delChannelEdge( err = c.delChannelEdge(
edges, edgeIndex, chanIndex, zombieIndex, nodes, edges, edgeIndex, chanIndex, zombieIndex, nodes,
chanID, false, false, chanID, false, false,
) )
@@ -1004,6 +1152,8 @@ func (c *ChannelGraph) PruneGraph(spentOutputs []*wire.OutPoint,
c.chanCache.remove(channel.ChannelID) c.chanCache.remove(channel.ChannelID)
} }
log.Debugf("Pruned graph, cache now has %s", c.graphCache.Stats())
return chansClosed, nil return chansClosed, nil
} }
@@ -1104,6 +1254,8 @@ func (c *ChannelGraph) pruneGraphNodes(nodes kvdb.RwBucket,
continue continue
} }
c.graphCache.RemoveNode(nodePubKey)
// If we reach this point, then there are no longer any edges // If we reach this point, then there are no longer any edges
// that connect this node, so we can delete it. // that connect this node, so we can delete it.
if err := c.deleteLightningNode(nodes, nodePubKey[:]); err != nil { if err := c.deleteLightningNode(nodes, nodePubKey[:]); err != nil {
@@ -1202,7 +1354,7 @@ func (c *ChannelGraph) DisconnectBlockAtHeight(height uint32) ([]*ChannelEdgeInf
} }
for _, k := range keys { for _, k := range keys {
err = delChannelEdge( err = c.delChannelEdge(
edges, edgeIndex, chanIndex, zombieIndex, nodes, edges, edgeIndex, chanIndex, zombieIndex, nodes,
k, false, false, k, false, false,
) )
@@ -1310,7 +1462,9 @@ func (c *ChannelGraph) PruneTip() (*chainhash.Hash, uint32, error) {
// true, then when we mark these edges as zombies, we'll set up the keys such // true, then when we mark these edges as zombies, we'll set up the keys such
// that we require the node that failed to send the fresh update to be the one // that we require the node that failed to send the fresh update to be the one
// that resurrects the channel from its zombie state. // that resurrects the channel from its zombie state.
func (c *ChannelGraph) DeleteChannelEdges(strictZombiePruning bool, chanIDs ...uint64) error { func (c *ChannelGraph) DeleteChannelEdges(strictZombiePruning bool,
chanIDs ...uint64) error {
// TODO(roasbeef): possibly delete from node bucket if node has no more // TODO(roasbeef): possibly delete from node bucket if node has no more
// channels // channels
// TODO(roasbeef): don't delete both edges? // TODO(roasbeef): don't delete both edges?
@@ -1343,7 +1497,7 @@ func (c *ChannelGraph) DeleteChannelEdges(strictZombiePruning bool, chanIDs ...u
var rawChanID [8]byte var rawChanID [8]byte
for _, chanID := range chanIDs { for _, chanID := range chanIDs {
byteOrder.PutUint64(rawChanID[:], chanID) byteOrder.PutUint64(rawChanID[:], chanID)
err := delChannelEdge( err := c.delChannelEdge(
edges, edgeIndex, chanIndex, zombieIndex, nodes, edges, edgeIndex, chanIndex, zombieIndex, nodes,
rawChanID[:], true, strictZombiePruning, rawChanID[:], true, strictZombiePruning,
) )
@@ -1472,7 +1626,9 @@ type ChannelEdge struct {
// ChanUpdatesInHorizon returns all the known channel edges which have at least // ChanUpdatesInHorizon returns all the known channel edges which have at least
// one edge that has an update timestamp within the specified horizon. // one edge that has an update timestamp within the specified horizon.
func (c *ChannelGraph) ChanUpdatesInHorizon(startTime, endTime time.Time) ([]ChannelEdge, error) { func (c *ChannelGraph) ChanUpdatesInHorizon(startTime,
endTime time.Time) ([]ChannelEdge, error) {
// To ensure we don't return duplicate ChannelEdges, we'll use an // To ensure we don't return duplicate ChannelEdges, we'll use an
// additional map to keep track of the edges already seen to prevent // additional map to keep track of the edges already seen to prevent
// re-adding it. // re-adding it.
@@ -1605,7 +1761,9 @@ func (c *ChannelGraph) ChanUpdatesInHorizon(startTime, endTime time.Time) ([]Cha
// update timestamp within the passed range. This method can be used by two // update timestamp within the passed range. This method can be used by two
// nodes to quickly determine if they have the same set of up to date node // nodes to quickly determine if they have the same set of up to date node
// announcements. // announcements.
func (c *ChannelGraph) NodeUpdatesInHorizon(startTime, endTime time.Time) ([]LightningNode, error) { func (c *ChannelGraph) NodeUpdatesInHorizon(startTime,
endTime time.Time) ([]LightningNode, error) {
var nodesInHorizon []LightningNode var nodesInHorizon []LightningNode
err := kvdb.View(c.db, func(tx kvdb.RTx) error { err := kvdb.View(c.db, func(tx kvdb.RTx) error {
@@ -1933,7 +2091,7 @@ func delEdgeUpdateIndexEntry(edgesBucket kvdb.RwBucket, chanID uint64,
return nil return nil
} }
func delChannelEdge(edges, edgeIndex, chanIndex, zombieIndex, func (c *ChannelGraph) delChannelEdge(edges, edgeIndex, chanIndex, zombieIndex,
nodes kvdb.RwBucket, chanID []byte, isZombie, strictZombie bool) error { nodes kvdb.RwBucket, chanID []byte, isZombie, strictZombie bool) error {
edgeInfo, err := fetchChanEdgeInfo(edgeIndex, chanID) edgeInfo, err := fetchChanEdgeInfo(edgeIndex, chanID)
@@ -1941,6 +2099,11 @@ func delChannelEdge(edges, edgeIndex, chanIndex, zombieIndex,
return err return err
} }
c.graphCache.RemoveChannel(
edgeInfo.NodeKey1Bytes, edgeInfo.NodeKey2Bytes,
edgeInfo.ChannelID,
)
// We'll also remove the entry in the edge update index bucket before // We'll also remove the entry in the edge update index bucket before
// we delete the edges themselves so we can access their last update // we delete the edges themselves so we can access their last update
// times. // times.
@@ -2075,7 +2238,9 @@ func (c *ChannelGraph) UpdateEdgePolicy(edge *ChannelEdgePolicy,
}, },
Update: func(tx kvdb.RwTx) error { Update: func(tx kvdb.RwTx) error {
var err error var err error
isUpdate1, err = updateEdgePolicy(tx, edge) isUpdate1, err = updateEdgePolicy(
tx, edge, c.graphCache,
)
// Silence ErrEdgeNotFound so that the batch can // Silence ErrEdgeNotFound so that the batch can
// succeed, but propagate the error via local state. // succeed, but propagate the error via local state.
@@ -2138,7 +2303,9 @@ func (c *ChannelGraph) updateEdgeCache(e *ChannelEdgePolicy, isUpdate1 bool) {
// buckets using an existing database transaction. The returned boolean will be // buckets using an existing database transaction. The returned boolean will be
// true if the updated policy belongs to node1, and false if the policy belonged // true if the updated policy belongs to node1, and false if the policy belonged
// to node2. // to node2.
func updateEdgePolicy(tx kvdb.RwTx, edge *ChannelEdgePolicy) (bool, error) { func updateEdgePolicy(tx kvdb.RwTx, edge *ChannelEdgePolicy,
graphCache *GraphCache) (bool, error) {
edges := tx.ReadWriteBucket(edgeBucket) edges := tx.ReadWriteBucket(edgeBucket)
if edges == nil { if edges == nil {
return false, ErrEdgeNotFound return false, ErrEdgeNotFound
@@ -2186,6 +2353,14 @@ func updateEdgePolicy(tx kvdb.RwTx, edge *ChannelEdgePolicy) (bool, error) {
return false, err return false, err
} }
var (
fromNodePubKey route.Vertex
toNodePubKey route.Vertex
)
copy(fromNodePubKey[:], fromNode)
copy(toNodePubKey[:], toNode)
graphCache.UpdatePolicy(edge, fromNodePubKey, toNodePubKey, isUpdate1)
return isUpdate1, nil return isUpdate1, nil
} }
@@ -2232,7 +2407,7 @@ type LightningNode struct {
// compatible manner. // compatible manner.
ExtraOpaqueData []byte ExtraOpaqueData []byte
db *DB db kvdb.Backend
// TODO(roasbeef): discovery will need storage to keep it's last IP // TODO(roasbeef): discovery will need storage to keep it's last IP
// address and re-announce if interface changes? // address and re-announce if interface changes?
@@ -2356,17 +2531,11 @@ func (l *LightningNode) isPublic(tx kvdb.RTx, sourcePubKey []byte) (bool, error)
// FetchLightningNode attempts to look up a target node by its identity public // FetchLightningNode attempts to look up a target node by its identity public
// key. If the node isn't found in the database, then ErrGraphNodeNotFound is // key. If the node isn't found in the database, then ErrGraphNodeNotFound is
// returned. // returned.
// func (c *ChannelGraph) FetchLightningNode(nodePub route.Vertex) (
// If the caller wishes to re-use an existing boltdb transaction, then it
// should be passed as the first argument. Otherwise the first argument should
// be nil and a fresh transaction will be created to execute the graph
// traversal.
func (c *ChannelGraph) FetchLightningNode(tx kvdb.RTx, nodePub route.Vertex) (
*LightningNode, error) { *LightningNode, error) {
var node *LightningNode var node *LightningNode
err := kvdb.View(c.db, func(tx kvdb.RTx) error {
fetchNode := func(tx kvdb.RTx) error {
// First grab the nodes bucket which stores the mapping from // First grab the nodes bucket which stores the mapping from
// pubKey to node information. // pubKey to node information.
nodes := tx.ReadBucket(nodeBucket) nodes := tx.ReadBucket(nodeBucket)
@@ -2393,14 +2562,9 @@ func (c *ChannelGraph) FetchLightningNode(tx kvdb.RTx, nodePub route.Vertex) (
node = &n node = &n
return nil return nil
} }, func() {
node = nil
var err error })
if tx == nil {
err = kvdb.View(c.db, fetchNode, func() {})
} else {
err = fetchNode(tx)
}
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -2408,6 +2572,52 @@ func (c *ChannelGraph) FetchLightningNode(tx kvdb.RTx, nodePub route.Vertex) (
return node, nil return node, nil
} }
// graphCacheNode is a struct that wraps a LightningNode in a way that it can be
// cached in the graph cache.
type graphCacheNode struct {
pubKeyBytes route.Vertex
features *lnwire.FeatureVector
nodeScratch [8]byte
}
// newGraphCacheNode returns a new cache optimized node.
func newGraphCacheNode(pubKey route.Vertex,
features *lnwire.FeatureVector) *graphCacheNode {
return &graphCacheNode{
pubKeyBytes: pubKey,
features: features,
}
}
// PubKey returns the node's public identity key.
func (n *graphCacheNode) PubKey() route.Vertex {
return n.pubKeyBytes
}
// Features returns the node's features.
func (n *graphCacheNode) Features() *lnwire.FeatureVector {
return n.features
}
// ForEachChannel iterates through all channels of this node, executing the
// passed callback with an edge info structure and the policies of each end
// of the channel. The first edge policy is the outgoing edge *to* the
// connecting node, while the second is the incoming edge *from* the
// connecting node. If the callback returns an error, then the iteration is
// halted with the error propagated back up to the caller.
//
// Unknown policies are passed into the callback as nil values.
func (n *graphCacheNode) ForEachChannel(tx kvdb.RTx,
cb func(kvdb.RTx, *ChannelEdgeInfo, *ChannelEdgePolicy,
*ChannelEdgePolicy) error) error {
return nodeTraversal(tx, n.pubKeyBytes[:], nil, cb)
}
var _ GraphCacheNode = (*graphCacheNode)(nil)
// HasLightningNode determines if the graph has a vertex identified by the // HasLightningNode determines if the graph has a vertex identified by the
// target node identity public key. If the node exists in the database, a // target node identity public key. If the node exists in the database, a
// timestamp of when the data for the node was lasted updated is returned along // timestamp of when the data for the node was lasted updated is returned along
@@ -2460,7 +2670,7 @@ func (c *ChannelGraph) HasLightningNode(nodePub [33]byte) (time.Time, bool, erro
// nodeTraversal is used to traverse all channels of a node given by its // nodeTraversal is used to traverse all channels of a node given by its
// public key and passes channel information into the specified callback. // public key and passes channel information into the specified callback.
func nodeTraversal(tx kvdb.RTx, nodePub []byte, db *DB, func nodeTraversal(tx kvdb.RTx, nodePub []byte, db kvdb.Backend,
cb func(kvdb.RTx, *ChannelEdgeInfo, *ChannelEdgePolicy, *ChannelEdgePolicy) error) error { cb func(kvdb.RTx, *ChannelEdgeInfo, *ChannelEdgePolicy, *ChannelEdgePolicy) error) error {
traversal := func(tx kvdb.RTx) error { traversal := func(tx kvdb.RTx) error {
@@ -2548,7 +2758,7 @@ func nodeTraversal(tx kvdb.RTx, nodePub []byte, db *DB,
// ForEachChannel iterates through all channels of this node, executing the // ForEachChannel iterates through all channels of this node, executing the
// passed callback with an edge info structure and the policies of each end // passed callback with an edge info structure and the policies of each end
// of the channel. The first edge policy is the outgoing edge *to* the // of the channel. The first edge policy is the outgoing edge *to* the
// the connecting node, while the second is the incoming edge *from* the // connecting node, while the second is the incoming edge *from* the
// connecting node. If the callback returns an error, then the iteration is // connecting node. If the callback returns an error, then the iteration is
// halted with the error propagated back up to the caller. // halted with the error propagated back up to the caller.
// //
@@ -2559,7 +2769,8 @@ func nodeTraversal(tx kvdb.RTx, nodePub []byte, db *DB,
// be nil and a fresh transaction will be created to execute the graph // be nil and a fresh transaction will be created to execute the graph
// traversal. // traversal.
func (l *LightningNode) ForEachChannel(tx kvdb.RTx, func (l *LightningNode) ForEachChannel(tx kvdb.RTx,
cb func(kvdb.RTx, *ChannelEdgeInfo, *ChannelEdgePolicy, *ChannelEdgePolicy) error) error { cb func(kvdb.RTx, *ChannelEdgeInfo, *ChannelEdgePolicy,
*ChannelEdgePolicy) error) error {
nodePub := l.PubKeyBytes[:] nodePub := l.PubKeyBytes[:]
db := l.db db := l.db
@@ -2627,7 +2838,7 @@ type ChannelEdgeInfo struct {
// compatible manner. // compatible manner.
ExtraOpaqueData []byte ExtraOpaqueData []byte
db *DB db kvdb.Backend
} }
// AddNodeKeys is a setter-like method that can be used to replace the set of // AddNodeKeys is a setter-like method that can be used to replace the set of
@@ -2988,7 +3199,7 @@ type ChannelEdgePolicy struct {
// compatible manner. // compatible manner.
ExtraOpaqueData []byte ExtraOpaqueData []byte
db *DB db kvdb.Backend
} }
// Signature is a channel announcement signature, which is needed for proper // Signature is a channel announcement signature, which is needed for proper
@@ -3406,7 +3617,7 @@ func (c *ChannelGraph) MarkEdgeZombie(chanID uint64,
c.cacheMu.Lock() c.cacheMu.Lock()
defer c.cacheMu.Unlock() defer c.cacheMu.Unlock()
err := kvdb.Batch(c.db.Backend, func(tx kvdb.RwTx) error { err := kvdb.Batch(c.db, func(tx kvdb.RwTx) error {
edges := tx.ReadWriteBucket(edgeBucket) edges := tx.ReadWriteBucket(edgeBucket)
if edges == nil { if edges == nil {
return ErrGraphNoEdgesFound return ErrGraphNoEdgesFound
@@ -3417,6 +3628,8 @@ func (c *ChannelGraph) MarkEdgeZombie(chanID uint64,
"bucket: %w", err) "bucket: %w", err)
} }
c.graphCache.RemoveChannel(pubKey1, pubKey2, chanID)
return markEdgeZombie(zombieIndex, chanID, pubKey1, pubKey2) return markEdgeZombie(zombieIndex, chanID, pubKey1, pubKey2)
}) })
if err != nil { if err != nil {
@@ -3471,6 +3684,18 @@ func (c *ChannelGraph) MarkEdgeLive(chanID uint64) error {
c.rejectCache.remove(chanID) c.rejectCache.remove(chanID)
c.chanCache.remove(chanID) c.chanCache.remove(chanID)
// We need to add the channel back into our graph cache, otherwise we
// won't use it for path finding.
edgeInfos, err := c.FetchChanInfos([]uint64{chanID})
if err != nil {
return err
}
for _, edgeInfo := range edgeInfos {
c.graphCache.AddChannel(
edgeInfo.Info, edgeInfo.Policy1, edgeInfo.Policy2,
)
}
return nil return nil
} }
@@ -3696,6 +3921,53 @@ func fetchLightningNode(nodeBucket kvdb.RBucket,
return deserializeLightningNode(nodeReader) return deserializeLightningNode(nodeReader)
} }
func deserializeLightningNodeCacheable(r io.Reader, node *graphCacheNode) error {
// Always populate a feature vector, even if we don't have a node
// announcement and short circuit below.
node.features = lnwire.EmptyFeatureVector()
// Skip ahead:
// - LastUpdate (8 bytes)
if _, err := r.Read(node.nodeScratch[:]); err != nil {
return err
}
if _, err := io.ReadFull(r, node.pubKeyBytes[:]); err != nil {
return err
}
// Read the node announcement flag.
if _, err := r.Read(node.nodeScratch[:2]); err != nil {
return err
}
hasNodeAnn := byteOrder.Uint16(node.nodeScratch[:2])
// The rest of the data is optional, and will only be there if we got a
// node announcement for this node.
if hasNodeAnn == 0 {
return nil
}
// We did get a node announcement for this node, so we'll have the rest
// of the data available.
var rgb uint8
if err := binary.Read(r, byteOrder, &rgb); err != nil {
return err
}
if err := binary.Read(r, byteOrder, &rgb); err != nil {
return err
}
if err := binary.Read(r, byteOrder, &rgb); err != nil {
return err
}
if _, err := wire.ReadVarString(r, 0); err != nil {
return err
}
return node.features.Decode(r)
}
func deserializeLightningNode(r io.Reader) (LightningNode, error) { func deserializeLightningNode(r io.Reader) (LightningNode, error) {
var ( var (
node LightningNode node LightningNode
@@ -4102,7 +4374,7 @@ func fetchChanEdgePolicy(edges kvdb.RBucket, chanID []byte,
func fetchChanEdgePolicies(edgeIndex kvdb.RBucket, edges kvdb.RBucket, func fetchChanEdgePolicies(edgeIndex kvdb.RBucket, edges kvdb.RBucket,
nodes kvdb.RBucket, chanID []byte, nodes kvdb.RBucket, chanID []byte,
db *DB) (*ChannelEdgePolicy, *ChannelEdgePolicy, error) { db kvdb.Backend) (*ChannelEdgePolicy, *ChannelEdgePolicy, error) {
edgeInfo := edgeIndex.Get(chanID) edgeInfo := edgeIndex.Get(chanID)
if edgeInfo == nil { if edgeInfo == nil {

460
channeldb/graph_cache.go Normal file
View File

@@ -0,0 +1,460 @@
package channeldb
import (
"fmt"
"sync"
"github.com/btcsuite/btcutil"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing/route"
)
// GraphCacheNode is an interface for all the information the cache needs to know
// about a lightning node.
type GraphCacheNode interface {
// PubKey is the node's public identity key.
PubKey() route.Vertex
// Features returns the node's p2p features.
Features() *lnwire.FeatureVector
// ForEachChannel iterates through all channels of a given node,
// executing the passed callback with an edge info structure and the
// policies of each end of the channel. The first edge policy is the
// outgoing edge *to* the connecting node, while the second is the
// incoming edge *from* the connecting node. If the callback returns an
// error, then the iteration is halted with the error propagated back up
// to the caller.
ForEachChannel(kvdb.RTx,
func(kvdb.RTx, *ChannelEdgeInfo, *ChannelEdgePolicy,
*ChannelEdgePolicy) error) error
}
// CachedEdgePolicy is a struct that only caches the information of a
// ChannelEdgePolicy that we actually use for pathfinding and therefore need to
// store in the cache.
type CachedEdgePolicy struct {
// ChannelID is the unique channel ID for the channel. The first 3
// bytes are the block height, the next 3 the index within the block,
// and the last 2 bytes are the output index for the channel.
ChannelID uint64
// MessageFlags is a bitfield which indicates the presence of optional
// fields (like max_htlc) in the policy.
MessageFlags lnwire.ChanUpdateMsgFlags
// ChannelFlags is a bitfield which signals the capabilities of the
// channel as well as the directed edge this update applies to.
ChannelFlags lnwire.ChanUpdateChanFlags
// TimeLockDelta is the number of blocks this node will subtract from
// the expiry of an incoming HTLC. This value expresses the time buffer
// the node would like to HTLC exchanges.
TimeLockDelta uint16
// MinHTLC is the smallest value HTLC this node will forward, expressed
// in millisatoshi.
MinHTLC lnwire.MilliSatoshi
// MaxHTLC is the largest value HTLC this node will forward, expressed
// in millisatoshi.
MaxHTLC lnwire.MilliSatoshi
// FeeBaseMSat is the base HTLC fee that will be charged for forwarding
// ANY HTLC, expressed in mSAT's.
FeeBaseMSat lnwire.MilliSatoshi
// FeeProportionalMillionths is the rate that the node will charge for
// HTLCs for each millionth of a satoshi forwarded.
FeeProportionalMillionths lnwire.MilliSatoshi
// ToNodePubKey is a function that returns the to node of a policy.
// Since we only ever store the inbound policy, this is always the node
// that we query the channels for in ForEachChannel(). Therefore, we can
// save a lot of space by not storing this information in the memory and
// instead just set this function when we copy the policy from cache in
// ForEachChannel().
ToNodePubKey func() route.Vertex
// ToNodeFeatures are the to node's features. They are never set while
// the edge is in the cache, only on the copy that is returned in
// ForEachChannel().
ToNodeFeatures *lnwire.FeatureVector
}
// ComputeFee computes the fee to forward an HTLC of `amt` milli-satoshis over
// the passed active payment channel. This value is currently computed as
// specified in BOLT07, but will likely change in the near future.
func (c *CachedEdgePolicy) ComputeFee(
amt lnwire.MilliSatoshi) lnwire.MilliSatoshi {
return c.FeeBaseMSat + (amt*c.FeeProportionalMillionths)/feeRateParts
}
// ComputeFeeFromIncoming computes the fee to forward an HTLC given the incoming
// amount.
func (c *CachedEdgePolicy) ComputeFeeFromIncoming(
incomingAmt lnwire.MilliSatoshi) lnwire.MilliSatoshi {
return incomingAmt - divideCeil(
feeRateParts*(incomingAmt-c.FeeBaseMSat),
feeRateParts+c.FeeProportionalMillionths,
)
}
// NewCachedPolicy turns a full policy into a minimal one that can be cached.
func NewCachedPolicy(policy *ChannelEdgePolicy) *CachedEdgePolicy {
return &CachedEdgePolicy{
ChannelID: policy.ChannelID,
MessageFlags: policy.MessageFlags,
ChannelFlags: policy.ChannelFlags,
TimeLockDelta: policy.TimeLockDelta,
MinHTLC: policy.MinHTLC,
MaxHTLC: policy.MaxHTLC,
FeeBaseMSat: policy.FeeBaseMSat,
FeeProportionalMillionths: policy.FeeProportionalMillionths,
}
}
// DirectedChannel is a type that stores the channel information as seen from
// one side of the channel.
type DirectedChannel struct {
// ChannelID is the unique identifier of this channel.
ChannelID uint64
// IsNode1 indicates if this is the node with the smaller public key.
IsNode1 bool
// OtherNode is the public key of the node on the other end of this
// channel.
OtherNode route.Vertex
// Capacity is the announced capacity of this channel in satoshis.
Capacity btcutil.Amount
// OutPolicySet is a boolean that indicates whether the node has an
// outgoing policy set. For pathfinding only the existence of the policy
// is important to know, not the actual content.
OutPolicySet bool
// InPolicy is the incoming policy *from* the other node to this node.
// In path finding, we're walking backward from the destination to the
// source, so we're always interested in the edge that arrives to us
// from the other node.
InPolicy *CachedEdgePolicy
}
// DeepCopy creates a deep copy of the channel, including the incoming policy.
func (c *DirectedChannel) DeepCopy() *DirectedChannel {
channelCopy := *c
if channelCopy.InPolicy != nil {
inPolicyCopy := *channelCopy.InPolicy
channelCopy.InPolicy = &inPolicyCopy
// The fields for the ToNode can be overwritten by the path
// finding algorithm, which is why we need a deep copy in the
// first place. So we always start out with nil values, just to
// be sure they don't contain any old data.
channelCopy.InPolicy.ToNodePubKey = nil
channelCopy.InPolicy.ToNodeFeatures = nil
}
return &channelCopy
}
// GraphCache is a type that holds a minimal set of information of the public
// channel graph that can be used for pathfinding.
type GraphCache struct {
nodeChannels map[route.Vertex]map[uint64]*DirectedChannel
nodeFeatures map[route.Vertex]*lnwire.FeatureVector
mtx sync.RWMutex
}
// NewGraphCache creates a new graphCache.
func NewGraphCache(preAllocNumNodes int) *GraphCache {
return &GraphCache{
nodeChannels: make(
map[route.Vertex]map[uint64]*DirectedChannel,
// A channel connects two nodes, so we can look it up
// from both sides, meaning we get double the number of
// entries.
preAllocNumNodes*2,
),
nodeFeatures: make(
map[route.Vertex]*lnwire.FeatureVector,
preAllocNumNodes,
),
}
}
// Stats returns statistics about the current cache size.
func (c *GraphCache) Stats() string {
c.mtx.RLock()
defer c.mtx.RUnlock()
numChannels := 0
for node := range c.nodeChannels {
numChannels += len(c.nodeChannels[node])
}
return fmt.Sprintf("num_node_features=%d, num_nodes=%d, "+
"num_channels=%d", len(c.nodeFeatures), len(c.nodeChannels),
numChannels)
}
// AddNode adds a graph node, including all the (directed) channels of that
// node.
func (c *GraphCache) AddNode(tx kvdb.RTx, node GraphCacheNode) error {
nodePubKey := node.PubKey()
// Only hold the lock for a short time. The `ForEachChannel()` below is
// possibly slow as it has to go to the backend, so we can unlock
// between the calls. And the AddChannel() method will acquire its own
// lock anyway.
c.mtx.Lock()
c.nodeFeatures[nodePubKey] = node.Features()
c.mtx.Unlock()
return node.ForEachChannel(
tx, func(tx kvdb.RTx, info *ChannelEdgeInfo,
outPolicy *ChannelEdgePolicy,
inPolicy *ChannelEdgePolicy) error {
c.AddChannel(info, outPolicy, inPolicy)
return nil
},
)
}
// AddChannel adds a non-directed channel, meaning that the order of policy 1
// and policy 2 does not matter, the directionality is extracted from the info
// and policy flags automatically. The policy will be set as the outgoing policy
// on one node and the incoming policy on the peer's side.
func (c *GraphCache) AddChannel(info *ChannelEdgeInfo,
policy1 *ChannelEdgePolicy, policy2 *ChannelEdgePolicy) {
if info == nil {
return
}
if policy1 != nil && policy1.IsDisabled() &&
policy2 != nil && policy2.IsDisabled() {
return
}
// Create the edge entry for both nodes.
c.mtx.Lock()
c.updateOrAddEdge(info.NodeKey1Bytes, &DirectedChannel{
ChannelID: info.ChannelID,
IsNode1: true,
OtherNode: info.NodeKey2Bytes,
Capacity: info.Capacity,
})
c.updateOrAddEdge(info.NodeKey2Bytes, &DirectedChannel{
ChannelID: info.ChannelID,
IsNode1: false,
OtherNode: info.NodeKey1Bytes,
Capacity: info.Capacity,
})
c.mtx.Unlock()
// The policy's node is always the to_node. So if policy 1 has to_node
// of node 2 then we have the policy 1 as seen from node 1.
if policy1 != nil {
fromNode, toNode := info.NodeKey1Bytes, info.NodeKey2Bytes
if policy1.Node.PubKeyBytes != info.NodeKey2Bytes {
fromNode, toNode = toNode, fromNode
}
isEdge1 := policy1.ChannelFlags&lnwire.ChanUpdateDirection == 0
c.UpdatePolicy(policy1, fromNode, toNode, isEdge1)
}
if policy2 != nil {
fromNode, toNode := info.NodeKey2Bytes, info.NodeKey1Bytes
if policy2.Node.PubKeyBytes != info.NodeKey1Bytes {
fromNode, toNode = toNode, fromNode
}
isEdge1 := policy2.ChannelFlags&lnwire.ChanUpdateDirection == 0
c.UpdatePolicy(policy2, fromNode, toNode, isEdge1)
}
}
// updateOrAddEdge makes sure the edge information for a node is either updated
// if it already exists or is added to that node's list of channels.
func (c *GraphCache) updateOrAddEdge(node route.Vertex, edge *DirectedChannel) {
if len(c.nodeChannels[node]) == 0 {
c.nodeChannels[node] = make(map[uint64]*DirectedChannel)
}
c.nodeChannels[node][edge.ChannelID] = edge
}
// UpdatePolicy updates a single policy on both the from and to node. The order
// of the from and to node is not strictly important. But we assume that a
// channel edge was added beforehand so that the directed channel struct already
// exists in the cache.
func (c *GraphCache) UpdatePolicy(policy *ChannelEdgePolicy, fromNode,
toNode route.Vertex, edge1 bool) {
c.mtx.Lock()
defer c.mtx.Unlock()
updatePolicy := func(nodeKey route.Vertex) {
if len(c.nodeChannels[nodeKey]) == 0 {
return
}
channel, ok := c.nodeChannels[nodeKey][policy.ChannelID]
if !ok {
return
}
// Edge 1 is defined as the policy for the direction of node1 to
// node2.
switch {
// This is node 1, and it is edge 1, so this is the outgoing
// policy for node 1.
case channel.IsNode1 && edge1:
channel.OutPolicySet = true
// This is node 2, and it is edge 2, so this is the outgoing
// policy for node 2.
case !channel.IsNode1 && !edge1:
channel.OutPolicySet = true
// The other two cases left mean it's the inbound policy for the
// node.
default:
channel.InPolicy = NewCachedPolicy(policy)
}
}
updatePolicy(fromNode)
updatePolicy(toNode)
}
// RemoveNode completely removes a node and all its channels (including the
// peer's side).
func (c *GraphCache) RemoveNode(node route.Vertex) {
c.mtx.Lock()
defer c.mtx.Unlock()
delete(c.nodeFeatures, node)
// First remove all channels from the other nodes' lists.
for _, channel := range c.nodeChannels[node] {
c.removeChannelIfFound(channel.OtherNode, channel.ChannelID)
}
// Then remove our whole node completely.
delete(c.nodeChannels, node)
}
// RemoveChannel removes a single channel between two nodes.
func (c *GraphCache) RemoveChannel(node1, node2 route.Vertex, chanID uint64) {
c.mtx.Lock()
defer c.mtx.Unlock()
// Remove that one channel from both sides.
c.removeChannelIfFound(node1, chanID)
c.removeChannelIfFound(node2, chanID)
}
// removeChannelIfFound removes a single channel from one side.
func (c *GraphCache) removeChannelIfFound(node route.Vertex, chanID uint64) {
if len(c.nodeChannels[node]) == 0 {
return
}
delete(c.nodeChannels[node], chanID)
}
// UpdateChannel updates the channel edge information for a specific edge. We
// expect the edge to already exist and be known. If it does not yet exist, this
// call is a no-op.
func (c *GraphCache) UpdateChannel(info *ChannelEdgeInfo) {
c.mtx.Lock()
defer c.mtx.Unlock()
if len(c.nodeChannels[info.NodeKey1Bytes]) == 0 ||
len(c.nodeChannels[info.NodeKey2Bytes]) == 0 {
return
}
channel, ok := c.nodeChannels[info.NodeKey1Bytes][info.ChannelID]
if ok {
// We only expect to be called when the channel is already
// known.
channel.Capacity = info.Capacity
channel.OtherNode = info.NodeKey2Bytes
}
channel, ok = c.nodeChannels[info.NodeKey2Bytes][info.ChannelID]
if ok {
channel.Capacity = info.Capacity
channel.OtherNode = info.NodeKey1Bytes
}
}
// ForEachChannel invokes the given callback for each channel of the given node.
func (c *GraphCache) ForEachChannel(node route.Vertex,
cb func(channel *DirectedChannel) error) error {
c.mtx.RLock()
defer c.mtx.RUnlock()
channels, ok := c.nodeChannels[node]
if !ok {
return nil
}
features, ok := c.nodeFeatures[node]
if !ok {
log.Warnf("Node %v has no features defined, falling back to "+
"default feature vector for path finding", node)
features = lnwire.EmptyFeatureVector()
}
toNodeCallback := func() route.Vertex {
return node
}
for _, channel := range channels {
// We need to copy the channel and policy to avoid it being
// updated in the cache if the path finding algorithm sets
// fields on it (currently only the ToNodeFeatures of the
// policy).
channelCopy := channel.DeepCopy()
if channelCopy.InPolicy != nil {
channelCopy.InPolicy.ToNodePubKey = toNodeCallback
channelCopy.InPolicy.ToNodeFeatures = features
}
if err := cb(channelCopy); err != nil {
return err
}
}
return nil
}
// GetFeatures returns the features of the node with the given ID.
func (c *GraphCache) GetFeatures(node route.Vertex) *lnwire.FeatureVector {
c.mtx.RLock()
defer c.mtx.RUnlock()
features, ok := c.nodeFeatures[node]
if !ok || features == nil {
// The router expects the features to never be nil, so we return
// an empty feature set instead.
return lnwire.EmptyFeatureVector()
}
return features
}

View File

@@ -0,0 +1,147 @@
package channeldb
import (
"encoding/hex"
"testing"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing/route"
"github.com/stretchr/testify/require"
)
var (
pubKey1Bytes, _ = hex.DecodeString(
"0248f5cba4c6da2e4c9e01e81d1404dfac0cbaf3ee934a4fc117d2ea9a64" +
"22c91d",
)
pubKey2Bytes, _ = hex.DecodeString(
"038155ba86a8d3b23c806c855097ca5c9fa0f87621f1e7a7d2835ad057f6" +
"f4484f",
)
pubKey1, _ = route.NewVertexFromBytes(pubKey1Bytes)
pubKey2, _ = route.NewVertexFromBytes(pubKey2Bytes)
)
type node struct {
pubKey route.Vertex
features *lnwire.FeatureVector
edgeInfos []*ChannelEdgeInfo
outPolicies []*ChannelEdgePolicy
inPolicies []*ChannelEdgePolicy
}
func (n *node) PubKey() route.Vertex {
return n.pubKey
}
func (n *node) Features() *lnwire.FeatureVector {
return n.features
}
func (n *node) ForEachChannel(tx kvdb.RTx,
cb func(kvdb.RTx, *ChannelEdgeInfo, *ChannelEdgePolicy,
*ChannelEdgePolicy) error) error {
for idx := range n.edgeInfos {
err := cb(
tx, n.edgeInfos[idx], n.outPolicies[idx],
n.inPolicies[idx],
)
if err != nil {
return err
}
}
return nil
}
// TestGraphCacheAddNode tests that a channel going from node A to node B can be
// cached correctly, independent of the direction we add the channel as.
func TestGraphCacheAddNode(t *testing.T) {
runTest := func(nodeA, nodeB route.Vertex) {
t.Helper()
channelFlagA, channelFlagB := 0, 1
if nodeA == pubKey2 {
channelFlagA, channelFlagB = 1, 0
}
outPolicy1 := &ChannelEdgePolicy{
ChannelID: 1000,
ChannelFlags: lnwire.ChanUpdateChanFlags(channelFlagA),
Node: &LightningNode{
PubKeyBytes: nodeB,
Features: lnwire.EmptyFeatureVector(),
},
}
inPolicy1 := &ChannelEdgePolicy{
ChannelID: 1000,
ChannelFlags: lnwire.ChanUpdateChanFlags(channelFlagB),
Node: &LightningNode{
PubKeyBytes: nodeA,
Features: lnwire.EmptyFeatureVector(),
},
}
node := &node{
pubKey: nodeA,
features: lnwire.EmptyFeatureVector(),
edgeInfos: []*ChannelEdgeInfo{{
ChannelID: 1000,
// Those are direction independent!
NodeKey1Bytes: pubKey1,
NodeKey2Bytes: pubKey2,
Capacity: 500,
}},
outPolicies: []*ChannelEdgePolicy{outPolicy1},
inPolicies: []*ChannelEdgePolicy{inPolicy1},
}
cache := NewGraphCache(10)
require.NoError(t, cache.AddNode(nil, node))
var fromChannels, toChannels []*DirectedChannel
_ = cache.ForEachChannel(nodeA, func(c *DirectedChannel) error {
fromChannels = append(fromChannels, c)
return nil
})
_ = cache.ForEachChannel(nodeB, func(c *DirectedChannel) error {
toChannels = append(toChannels, c)
return nil
})
require.Len(t, fromChannels, 1)
require.Len(t, toChannels, 1)
require.Equal(t, outPolicy1 != nil, fromChannels[0].OutPolicySet)
assertCachedPolicyEqual(t, inPolicy1, fromChannels[0].InPolicy)
require.Equal(t, inPolicy1 != nil, toChannels[0].OutPolicySet)
assertCachedPolicyEqual(t, outPolicy1, toChannels[0].InPolicy)
}
runTest(pubKey1, pubKey2)
runTest(pubKey2, pubKey1)
}
func assertCachedPolicyEqual(t *testing.T, original *ChannelEdgePolicy,
cached *CachedEdgePolicy) {
require.Equal(t, original.ChannelID, cached.ChannelID)
require.Equal(t, original.MessageFlags, cached.MessageFlags)
require.Equal(t, original.ChannelFlags, cached.ChannelFlags)
require.Equal(t, original.TimeLockDelta, cached.TimeLockDelta)
require.Equal(t, original.MinHTLC, cached.MinHTLC)
require.Equal(t, original.MaxHTLC, cached.MaxHTLC)
require.Equal(t, original.FeeBaseMSat, cached.FeeBaseMSat)
require.Equal(
t, original.FeeProportionalMillionths,
cached.FeeProportionalMillionths,
)
require.Equal(
t,
route.Vertex(original.Node.PubKeyBytes),
cached.ToNodePubKey(),
)
require.Equal(t, original.Node.Features, cached.ToNodeFeatures)
}

File diff suppressed because it is too large Load Diff

View File

@@ -56,12 +56,14 @@ type LinkNode struct {
// authenticated connection for the stored identity public key. // authenticated connection for the stored identity public key.
Addresses []net.Addr Addresses []net.Addr
db *DB // db is the database instance this node was fetched from. This is used
// to sync back the node's state if it is updated.
db *LinkNodeDB
} }
// NewLinkNode creates a new LinkNode from the provided parameters, which is // NewLinkNode creates a new LinkNode from the provided parameters, which is
// backed by an instance of channeldb. // backed by an instance of a link node DB.
func (d *DB) NewLinkNode(bitNet wire.BitcoinNet, pub *btcec.PublicKey, func NewLinkNode(db *LinkNodeDB, bitNet wire.BitcoinNet, pub *btcec.PublicKey,
addrs ...net.Addr) *LinkNode { addrs ...net.Addr) *LinkNode {
return &LinkNode{ return &LinkNode{
@@ -69,7 +71,7 @@ func (d *DB) NewLinkNode(bitNet wire.BitcoinNet, pub *btcec.PublicKey,
IdentityPub: pub, IdentityPub: pub,
LastSeen: time.Now(), LastSeen: time.Now(),
Addresses: addrs, Addresses: addrs,
db: d, db: db,
} }
} }
@@ -98,10 +100,9 @@ func (l *LinkNode) AddAddress(addr net.Addr) error {
// Sync performs a full database sync which writes the current up-to-date data // Sync performs a full database sync which writes the current up-to-date data
// within the struct to the database. // within the struct to the database.
func (l *LinkNode) Sync() error { func (l *LinkNode) Sync() error {
// Finally update the database by storing the link node and updating // Finally update the database by storing the link node and updating
// any relevant indexes. // any relevant indexes.
return kvdb.Update(l.db, func(tx kvdb.RwTx) error { return kvdb.Update(l.db.backend, func(tx kvdb.RwTx) error {
nodeMetaBucket := tx.ReadWriteBucket(nodeInfoBucket) nodeMetaBucket := tx.ReadWriteBucket(nodeInfoBucket)
if nodeMetaBucket == nil { if nodeMetaBucket == nil {
return ErrLinkNodesNotFound return ErrLinkNodesNotFound
@@ -127,15 +128,20 @@ func putLinkNode(nodeMetaBucket kvdb.RwBucket, l *LinkNode) error {
return nodeMetaBucket.Put(nodePub, b.Bytes()) return nodeMetaBucket.Put(nodePub, b.Bytes())
} }
// LinkNodeDB is a database that keeps track of all link nodes.
type LinkNodeDB struct {
backend kvdb.Backend
}
// DeleteLinkNode removes the link node with the given identity from the // DeleteLinkNode removes the link node with the given identity from the
// database. // database.
func (d *DB) DeleteLinkNode(identity *btcec.PublicKey) error { func (l *LinkNodeDB) DeleteLinkNode(identity *btcec.PublicKey) error {
return kvdb.Update(d, func(tx kvdb.RwTx) error { return kvdb.Update(l.backend, func(tx kvdb.RwTx) error {
return d.deleteLinkNode(tx, identity) return deleteLinkNode(tx, identity)
}, func() {}) }, func() {})
} }
func (d *DB) deleteLinkNode(tx kvdb.RwTx, identity *btcec.PublicKey) error { func deleteLinkNode(tx kvdb.RwTx, identity *btcec.PublicKey) error {
nodeMetaBucket := tx.ReadWriteBucket(nodeInfoBucket) nodeMetaBucket := tx.ReadWriteBucket(nodeInfoBucket)
if nodeMetaBucket == nil { if nodeMetaBucket == nil {
return ErrLinkNodesNotFound return ErrLinkNodesNotFound
@@ -148,9 +154,9 @@ func (d *DB) deleteLinkNode(tx kvdb.RwTx, identity *btcec.PublicKey) error {
// FetchLinkNode attempts to lookup the data for a LinkNode based on a target // FetchLinkNode attempts to lookup the data for a LinkNode based on a target
// identity public key. If a particular LinkNode for the passed identity public // identity public key. If a particular LinkNode for the passed identity public
// key cannot be found, then ErrNodeNotFound if returned. // key cannot be found, then ErrNodeNotFound if returned.
func (d *DB) FetchLinkNode(identity *btcec.PublicKey) (*LinkNode, error) { func (l *LinkNodeDB) FetchLinkNode(identity *btcec.PublicKey) (*LinkNode, error) {
var linkNode *LinkNode var linkNode *LinkNode
err := kvdb.View(d, func(tx kvdb.RTx) error { err := kvdb.View(l.backend, func(tx kvdb.RTx) error {
node, err := fetchLinkNode(tx, identity) node, err := fetchLinkNode(tx, identity)
if err != nil { if err != nil {
return err return err
@@ -191,10 +197,10 @@ func fetchLinkNode(tx kvdb.RTx, targetPub *btcec.PublicKey) (*LinkNode, error) {
// FetchAllLinkNodes starts a new database transaction to fetch all nodes with // FetchAllLinkNodes starts a new database transaction to fetch all nodes with
// whom we have active channels with. // whom we have active channels with.
func (d *DB) FetchAllLinkNodes() ([]*LinkNode, error) { func (l *LinkNodeDB) FetchAllLinkNodes() ([]*LinkNode, error) {
var linkNodes []*LinkNode var linkNodes []*LinkNode
err := kvdb.View(d, func(tx kvdb.RTx) error { err := kvdb.View(l.backend, func(tx kvdb.RTx) error {
nodes, err := d.fetchAllLinkNodes(tx) nodes, err := fetchAllLinkNodes(tx)
if err != nil { if err != nil {
return err return err
} }
@@ -213,7 +219,7 @@ func (d *DB) FetchAllLinkNodes() ([]*LinkNode, error) {
// fetchAllLinkNodes uses an existing database transaction to fetch all nodes // fetchAllLinkNodes uses an existing database transaction to fetch all nodes
// with whom we have active channels with. // with whom we have active channels with.
func (d *DB) fetchAllLinkNodes(tx kvdb.RTx) ([]*LinkNode, error) { func fetchAllLinkNodes(tx kvdb.RTx) ([]*LinkNode, error) {
nodeMetaBucket := tx.ReadBucket(nodeInfoBucket) nodeMetaBucket := tx.ReadBucket(nodeInfoBucket)
if nodeMetaBucket == nil { if nodeMetaBucket == nil {
return nil, ErrLinkNodesNotFound return nil, ErrLinkNodesNotFound

View File

@@ -13,12 +13,14 @@ import (
func TestLinkNodeEncodeDecode(t *testing.T) { func TestLinkNodeEncodeDecode(t *testing.T) {
t.Parallel() t.Parallel()
cdb, cleanUp, err := MakeTestDB() fullDB, cleanUp, err := MakeTestDB()
if err != nil { if err != nil {
t.Fatalf("unable to make test database: %v", err) t.Fatalf("unable to make test database: %v", err)
} }
defer cleanUp() defer cleanUp()
cdb := fullDB.ChannelStateDB()
// First we'll create some initial data to use for populating our test // First we'll create some initial data to use for populating our test
// LinkNode instances. // LinkNode instances.
_, pub1 := btcec.PrivKeyFromBytes(btcec.S256(), key[:]) _, pub1 := btcec.PrivKeyFromBytes(btcec.S256(), key[:])
@@ -34,8 +36,8 @@ func TestLinkNodeEncodeDecode(t *testing.T) {
// Create two fresh link node instances with the above dummy data, then // Create two fresh link node instances with the above dummy data, then
// fully sync both instances to disk. // fully sync both instances to disk.
node1 := cdb.NewLinkNode(wire.MainNet, pub1, addr1) node1 := NewLinkNode(cdb.linkNodeDB, wire.MainNet, pub1, addr1)
node2 := cdb.NewLinkNode(wire.TestNet3, pub2, addr2) node2 := NewLinkNode(cdb.linkNodeDB, wire.TestNet3, pub2, addr2)
if err := node1.Sync(); err != nil { if err := node1.Sync(); err != nil {
t.Fatalf("unable to sync node: %v", err) t.Fatalf("unable to sync node: %v", err)
} }
@@ -46,7 +48,7 @@ func TestLinkNodeEncodeDecode(t *testing.T) {
// Fetch all current link nodes from the database, they should exactly // Fetch all current link nodes from the database, they should exactly
// match the two created above. // match the two created above.
originalNodes := []*LinkNode{node2, node1} originalNodes := []*LinkNode{node2, node1}
linkNodes, err := cdb.FetchAllLinkNodes() linkNodes, err := cdb.linkNodeDB.FetchAllLinkNodes()
if err != nil { if err != nil {
t.Fatalf("unable to fetch nodes: %v", err) t.Fatalf("unable to fetch nodes: %v", err)
} }
@@ -82,7 +84,7 @@ func TestLinkNodeEncodeDecode(t *testing.T) {
} }
// Fetch the same node from the database according to its public key. // Fetch the same node from the database according to its public key.
node1DB, err := cdb.FetchLinkNode(pub1) node1DB, err := cdb.linkNodeDB.FetchLinkNode(pub1)
if err != nil { if err != nil {
t.Fatalf("unable to find node: %v", err) t.Fatalf("unable to find node: %v", err)
} }
@@ -110,31 +112,33 @@ func TestLinkNodeEncodeDecode(t *testing.T) {
func TestDeleteLinkNode(t *testing.T) { func TestDeleteLinkNode(t *testing.T) {
t.Parallel() t.Parallel()
cdb, cleanUp, err := MakeTestDB() fullDB, cleanUp, err := MakeTestDB()
if err != nil { if err != nil {
t.Fatalf("unable to make test database: %v", err) t.Fatalf("unable to make test database: %v", err)
} }
defer cleanUp() defer cleanUp()
cdb := fullDB.ChannelStateDB()
_, pubKey := btcec.PrivKeyFromBytes(btcec.S256(), key[:]) _, pubKey := btcec.PrivKeyFromBytes(btcec.S256(), key[:])
addr := &net.TCPAddr{ addr := &net.TCPAddr{
IP: net.ParseIP("127.0.0.1"), IP: net.ParseIP("127.0.0.1"),
Port: 1337, Port: 1337,
} }
linkNode := cdb.NewLinkNode(wire.TestNet3, pubKey, addr) linkNode := NewLinkNode(cdb.linkNodeDB, wire.TestNet3, pubKey, addr)
if err := linkNode.Sync(); err != nil { if err := linkNode.Sync(); err != nil {
t.Fatalf("unable to write link node to db: %v", err) t.Fatalf("unable to write link node to db: %v", err)
} }
if _, err := cdb.FetchLinkNode(pubKey); err != nil { if _, err := cdb.linkNodeDB.FetchLinkNode(pubKey); err != nil {
t.Fatalf("unable to find link node: %v", err) t.Fatalf("unable to find link node: %v", err)
} }
if err := cdb.DeleteLinkNode(pubKey); err != nil { if err := cdb.linkNodeDB.DeleteLinkNode(pubKey); err != nil {
t.Fatalf("unable to delete link node from db: %v", err) t.Fatalf("unable to delete link node from db: %v", err)
} }
if _, err := cdb.FetchLinkNode(pubKey); err == nil { if _, err := cdb.linkNodeDB.FetchLinkNode(pubKey); err == nil {
t.Fatal("should not have found link node in db, but did") t.Fatal("should not have found link node in db, but did")
} }
} }

View File

@@ -17,6 +17,12 @@ const (
// in order to reply to gossip queries. This produces a cache size of // in order to reply to gossip queries. This produces a cache size of
// around 40MB. // around 40MB.
DefaultChannelCacheSize = 20000 DefaultChannelCacheSize = 20000
// DefaultPreAllocCacheNumNodes is the default number of channels we
// assume for mainnet for pre-allocating the graph cache. As of
// September 2021, there currently are 14k nodes in a strictly pruned
// graph, so we choose a number that is slightly higher.
DefaultPreAllocCacheNumNodes = 15000
) )
// Options holds parameters for tuning and customizing a channeldb.DB. // Options holds parameters for tuning and customizing a channeldb.DB.
@@ -35,6 +41,10 @@ type Options struct {
// wait before attempting to commit a pending set of updates. // wait before attempting to commit a pending set of updates.
BatchCommitInterval time.Duration BatchCommitInterval time.Duration
// PreAllocCacheNumNodes is the number of nodes we expect to be in the
// graph cache, so we can pre-allocate the map accordingly.
PreAllocCacheNumNodes int
// clock is the time source used by the database. // clock is the time source used by the database.
clock clock.Clock clock clock.Clock
@@ -52,9 +62,10 @@ func DefaultOptions() Options {
AutoCompactMinAge: kvdb.DefaultBoltAutoCompactMinAge, AutoCompactMinAge: kvdb.DefaultBoltAutoCompactMinAge,
DBTimeout: kvdb.DefaultDBTimeout, DBTimeout: kvdb.DefaultDBTimeout,
}, },
RejectCacheSize: DefaultRejectCacheSize, RejectCacheSize: DefaultRejectCacheSize,
ChannelCacheSize: DefaultChannelCacheSize, ChannelCacheSize: DefaultChannelCacheSize,
clock: clock.NewDefaultClock(), PreAllocCacheNumNodes: DefaultPreAllocCacheNumNodes,
clock: clock.NewDefaultClock(),
} }
} }
@@ -75,6 +86,13 @@ func OptionSetChannelCacheSize(n int) OptionModifier {
} }
} }
// OptionSetPreAllocCacheNumNodes sets the PreAllocCacheNumNodes to n.
func OptionSetPreAllocCacheNumNodes(n int) OptionModifier {
return func(o *Options) {
o.PreAllocCacheNumNodes = n
}
}
// OptionSetSyncFreelist allows the database to sync its freelist. // OptionSetSyncFreelist allows the database to sync its freelist.
func OptionSetSyncFreelist(b bool) OptionModifier { func OptionSetSyncFreelist(b bool) OptionModifier {
return func(o *Options) { return func(o *Options) {

View File

@@ -36,12 +36,12 @@ type WaitingProofStore struct {
// cache is used in order to reduce the number of redundant get // cache is used in order to reduce the number of redundant get
// calls, when object isn't stored in it. // calls, when object isn't stored in it.
cache map[WaitingProofKey]struct{} cache map[WaitingProofKey]struct{}
db *DB db kvdb.Backend
mu sync.RWMutex mu sync.RWMutex
} }
// NewWaitingProofStore creates new instance of proofs storage. // NewWaitingProofStore creates new instance of proofs storage.
func NewWaitingProofStore(db *DB) (*WaitingProofStore, error) { func NewWaitingProofStore(db kvdb.Backend) (*WaitingProofStore, error) {
s := &WaitingProofStore{ s := &WaitingProofStore{
db: db, db: db,
} }

View File

@@ -17,7 +17,7 @@ type ChannelNotifier struct {
ntfnServer *subscribe.Server ntfnServer *subscribe.Server
chanDB *channeldb.DB chanDB *channeldb.ChannelStateDB
} }
// PendingOpenChannelEvent represents a new event where a new channel has // PendingOpenChannelEvent represents a new event where a new channel has
@@ -76,7 +76,7 @@ type FullyResolvedChannelEvent struct {
// New creates a new channel notifier. The ChannelNotifier gets channel // New creates a new channel notifier. The ChannelNotifier gets channel
// events from peers and from the chain arbitrator, and dispatches them to // events from peers and from the chain arbitrator, and dispatches them to
// its clients. // its clients.
func New(chanDB *channeldb.DB) *ChannelNotifier { func New(chanDB *channeldb.ChannelStateDB) *ChannelNotifier {
return &ChannelNotifier{ return &ChannelNotifier{
ntfnServer: subscribe.NewServer(), ntfnServer: subscribe.NewServer(),
chanDB: chanDB, chanDB: chanDB,

View File

@@ -34,7 +34,7 @@ const (
// need the secret key chain in order obtain the prior shachain root so we can // need the secret key chain in order obtain the prior shachain root so we can
// verify the DLP protocol as initiated by the remote node. // verify the DLP protocol as initiated by the remote node.
type chanDBRestorer struct { type chanDBRestorer struct {
db *channeldb.DB db *channeldb.ChannelStateDB
secretKeys keychain.SecretKeyRing secretKeys keychain.SecretKeyRing

View File

@@ -136,7 +136,7 @@ type BreachConfig struct {
// DB provides access to the user's channels, allowing the breach // DB provides access to the user's channels, allowing the breach
// arbiter to determine the current state of a user's channels, and how // arbiter to determine the current state of a user's channels, and how
// it should respond to channel closure. // it should respond to channel closure.
DB *channeldb.DB DB *channeldb.ChannelStateDB
// Estimator is used by the breach arbiter to determine an appropriate // Estimator is used by the breach arbiter to determine an appropriate
// fee level when generating, signing, and broadcasting sweep // fee level when generating, signing, and broadcasting sweep
@@ -1432,11 +1432,11 @@ func (b *BreachArbiter) sweepSpendableOutputsTxn(txWeight int64,
// store is to ensure that we can recover from a restart in the middle of a // store is to ensure that we can recover from a restart in the middle of a
// breached contract retribution. // breached contract retribution.
type RetributionStore struct { type RetributionStore struct {
db *channeldb.DB db kvdb.Backend
} }
// NewRetributionStore creates a new instance of a RetributionStore. // NewRetributionStore creates a new instance of a RetributionStore.
func NewRetributionStore(db *channeldb.DB) *RetributionStore { func NewRetributionStore(db kvdb.Backend) *RetributionStore {
return &RetributionStore{ return &RetributionStore{
db: db, db: db,
} }

View File

@@ -987,7 +987,7 @@ func initBreachedState(t *testing.T) (*BreachArbiter,
contractBreaches := make(chan *ContractBreachEvent) contractBreaches := make(chan *ContractBreachEvent)
brar, cleanUpArb, err := createTestArbiter( brar, cleanUpArb, err := createTestArbiter(
t, contractBreaches, alice.State().Db, t, contractBreaches, alice.State().Db.GetParentDB(),
) )
if err != nil { if err != nil {
t.Fatalf("unable to initialize test breach arbiter: %v", err) t.Fatalf("unable to initialize test breach arbiter: %v", err)
@@ -1164,7 +1164,7 @@ func TestBreachHandoffFail(t *testing.T) {
assertNotPendingClosed(t, alice) assertNotPendingClosed(t, alice)
brar, cleanUpArb, err := createTestArbiter( brar, cleanUpArb, err := createTestArbiter(
t, contractBreaches, alice.State().Db, t, contractBreaches, alice.State().Db.GetParentDB(),
) )
if err != nil { if err != nil {
t.Fatalf("unable to initialize test breach arbiter: %v", err) t.Fatalf("unable to initialize test breach arbiter: %v", err)
@@ -2075,7 +2075,7 @@ func assertNoArbiterBreach(t *testing.T, brar *BreachArbiter,
// assertBrarCleanup blocks until the given channel point has been removed the // assertBrarCleanup blocks until the given channel point has been removed the
// retribution store and the channel is fully closed in the database. // retribution store and the channel is fully closed in the database.
func assertBrarCleanup(t *testing.T, brar *BreachArbiter, func assertBrarCleanup(t *testing.T, brar *BreachArbiter,
chanPoint *wire.OutPoint, db *channeldb.DB) { chanPoint *wire.OutPoint, db *channeldb.ChannelStateDB) {
t.Helper() t.Helper()
@@ -2174,7 +2174,7 @@ func createTestArbiter(t *testing.T, contractBreaches chan *ContractBreachEvent,
notifier := mock.MakeMockSpendNotifier() notifier := mock.MakeMockSpendNotifier()
ba := NewBreachArbiter(&BreachConfig{ ba := NewBreachArbiter(&BreachConfig{
CloseLink: func(_ *wire.OutPoint, _ ChannelCloseType) {}, CloseLink: func(_ *wire.OutPoint, _ ChannelCloseType) {},
DB: db, DB: db.ChannelStateDB(),
Estimator: chainfee.NewStaticEstimator(12500, 0), Estimator: chainfee.NewStaticEstimator(12500, 0),
GenSweepScript: func() ([]byte, error) { return nil, nil }, GenSweepScript: func() ([]byte, error) { return nil, nil },
ContractBreaches: contractBreaches, ContractBreaches: contractBreaches,
@@ -2375,7 +2375,7 @@ func createInitChannels(revocationWindow int) (*lnwallet.LightningChannel, *lnwa
RevocationStore: shachain.NewRevocationStore(), RevocationStore: shachain.NewRevocationStore(),
LocalCommitment: aliceCommit, LocalCommitment: aliceCommit,
RemoteCommitment: aliceCommit, RemoteCommitment: aliceCommit,
Db: dbAlice, Db: dbAlice.ChannelStateDB(),
Packager: channeldb.NewChannelPackager(shortChanID), Packager: channeldb.NewChannelPackager(shortChanID),
FundingTxn: channels.TestFundingTx, FundingTxn: channels.TestFundingTx,
} }
@@ -2393,7 +2393,7 @@ func createInitChannels(revocationWindow int) (*lnwallet.LightningChannel, *lnwa
RevocationStore: shachain.NewRevocationStore(), RevocationStore: shachain.NewRevocationStore(),
LocalCommitment: bobCommit, LocalCommitment: bobCommit,
RemoteCommitment: bobCommit, RemoteCommitment: bobCommit,
Db: dbBob, Db: dbBob.ChannelStateDB(),
Packager: channeldb.NewChannelPackager(shortChanID), Packager: channeldb.NewChannelPackager(shortChanID),
} }

View File

@@ -258,7 +258,9 @@ func (a *arbChannel) NewAnchorResolutions() (*lnwallet.AnchorResolutions,
// same instance that is used by the link. // same instance that is used by the link.
chanPoint := a.channel.FundingOutpoint chanPoint := a.channel.FundingOutpoint
channel, err := a.c.chanSource.FetchChannel(nil, chanPoint) channel, err := a.c.chanSource.ChannelStateDB().FetchChannel(
nil, chanPoint,
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -301,7 +303,9 @@ func (a *arbChannel) ForceCloseChan() (*lnwallet.LocalForceCloseSummary, error)
// Now that we know the link can't mutate the channel // Now that we know the link can't mutate the channel
// state, we'll read the channel from disk the target // state, we'll read the channel from disk the target
// channel according to its channel point. // channel according to its channel point.
channel, err := a.c.chanSource.FetchChannel(nil, chanPoint) channel, err := a.c.chanSource.ChannelStateDB().FetchChannel(
nil, chanPoint,
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -422,7 +426,7 @@ func (c *ChainArbitrator) ResolveContract(chanPoint wire.OutPoint) error {
// First, we'll we'll mark the channel as fully closed from the PoV of // First, we'll we'll mark the channel as fully closed from the PoV of
// the channel source. // the channel source.
err := c.chanSource.MarkChanFullyClosed(&chanPoint) err := c.chanSource.ChannelStateDB().MarkChanFullyClosed(&chanPoint)
if err != nil { if err != nil {
log.Errorf("ChainArbitrator: unable to mark ChannelPoint(%v) "+ log.Errorf("ChainArbitrator: unable to mark ChannelPoint(%v) "+
"fully closed: %v", chanPoint, err) "fully closed: %v", chanPoint, err)
@@ -480,7 +484,7 @@ func (c *ChainArbitrator) Start() error {
// First, we'll fetch all the channels that are still open, in order to // First, we'll fetch all the channels that are still open, in order to
// collect them within our set of active contracts. // collect them within our set of active contracts.
openChannels, err := c.chanSource.FetchAllChannels() openChannels, err := c.chanSource.ChannelStateDB().FetchAllChannels()
if err != nil { if err != nil {
return err return err
} }
@@ -538,7 +542,9 @@ func (c *ChainArbitrator) Start() error {
// In addition to the channels that we know to be open, we'll also // In addition to the channels that we know to be open, we'll also
// launch arbitrators to finishing resolving any channels that are in // launch arbitrators to finishing resolving any channels that are in
// the pending close state. // the pending close state.
closingChannels, err := c.chanSource.FetchClosedChannels(true) closingChannels, err := c.chanSource.ChannelStateDB().FetchClosedChannels(
true,
)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -49,7 +49,7 @@ func TestChainArbitratorRepublishCloses(t *testing.T) {
// We manually set the db here to make sure all channels are // We manually set the db here to make sure all channels are
// synced to the same db. // synced to the same db.
channel.Db = db channel.Db = db.ChannelStateDB()
addr := &net.TCPAddr{ addr := &net.TCPAddr{
IP: net.ParseIP("127.0.0.1"), IP: net.ParseIP("127.0.0.1"),
@@ -165,7 +165,7 @@ func TestResolveContract(t *testing.T) {
} }
defer cleanup() defer cleanup()
channel := newChannel.State() channel := newChannel.State()
channel.Db = db channel.Db = db.ChannelStateDB()
addr := &net.TCPAddr{ addr := &net.TCPAddr{
IP: net.ParseIP("127.0.0.1"), IP: net.ParseIP("127.0.0.1"),
Port: 18556, Port: 18556,
@@ -205,7 +205,7 @@ func TestResolveContract(t *testing.T) {
// While the resolver are active, we'll now remove the channel from the // While the resolver are active, we'll now remove the channel from the
// database (mark is as closed). // database (mark is as closed).
err = db.AbandonChannel(&channel.FundingOutpoint, 4) err = db.ChannelStateDB().AbandonChannel(&channel.FundingOutpoint, 4)
if err != nil { if err != nil {
t.Fatalf("unable to remove channel: %v", err) t.Fatalf("unable to remove channel: %v", err)
} }

View File

@@ -58,7 +58,7 @@ func copyChannelState(state *channeldb.OpenChannel) (
*channeldb.OpenChannel, func(), error) { *channeldb.OpenChannel, func(), error) {
// Make a copy of the DB. // Make a copy of the DB.
dbFile := filepath.Join(state.Db.Path(), "channel.db") dbFile := filepath.Join(state.Db.GetParentDB().Path(), "channel.db")
tempDbPath, err := ioutil.TempDir("", "past-state") tempDbPath, err := ioutil.TempDir("", "past-state")
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
@@ -81,7 +81,7 @@ func copyChannelState(state *channeldb.OpenChannel) (
return nil, nil, err return nil, nil, err
} }
chans, err := newDb.FetchAllChannels() chans, err := newDb.ChannelStateDB().FetchAllChannels()
if err != nil { if err != nil {
cleanup() cleanup()
return nil, nil, err return nil, nil, err

View File

@@ -6,7 +6,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
) )
@@ -59,7 +58,7 @@ type GossipMessageStore interface {
// version of a message (like in the case of multiple ChannelUpdate's) for a // version of a message (like in the case of multiple ChannelUpdate's) for a
// channel with a peer. // channel with a peer.
type MessageStore struct { type MessageStore struct {
db *channeldb.DB db kvdb.Backend
} }
// A compile-time assertion to ensure messageStore implements the // A compile-time assertion to ensure messageStore implements the
@@ -67,8 +66,8 @@ type MessageStore struct {
var _ GossipMessageStore = (*MessageStore)(nil) var _ GossipMessageStore = (*MessageStore)(nil)
// NewMessageStore creates a new message store backed by a channeldb instance. // NewMessageStore creates a new message store backed by a channeldb instance.
func NewMessageStore(db *channeldb.DB) (*MessageStore, error) { func NewMessageStore(db kvdb.Backend) (*MessageStore, error) {
err := kvdb.Batch(db.Backend, func(tx kvdb.RwTx) error { err := kvdb.Batch(db, func(tx kvdb.RwTx) error {
_, err := tx.CreateTopLevelBucket(messageStoreBucket) _, err := tx.CreateTopLevelBucket(messageStoreBucket)
return err return err
}) })
@@ -124,7 +123,7 @@ func (s *MessageStore) AddMessage(msg lnwire.Message, peerPubKey [33]byte) error
return err return err
} }
return kvdb.Batch(s.db.Backend, func(tx kvdb.RwTx) error { return kvdb.Batch(s.db, func(tx kvdb.RwTx) error {
messageStore := tx.ReadWriteBucket(messageStoreBucket) messageStore := tx.ReadWriteBucket(messageStoreBucket)
if messageStore == nil { if messageStore == nil {
return ErrCorruptedMessageStore return ErrCorruptedMessageStore
@@ -145,7 +144,7 @@ func (s *MessageStore) DeleteMessage(msg lnwire.Message,
return err return err
} }
return kvdb.Batch(s.db.Backend, func(tx kvdb.RwTx) error { return kvdb.Batch(s.db, func(tx kvdb.RwTx) error {
messageStore := tx.ReadWriteBucket(messageStoreBucket) messageStore := tx.ReadWriteBucket(messageStoreBucket)
if messageStore == nil { if messageStore == nil {
return ErrCorruptedMessageStore return ErrCorruptedMessageStore

View File

@@ -59,6 +59,18 @@ in `lnd`, saving developer time and limiting the potential for bugs.
Instructions for enabling Postgres can be found in Instructions for enabling Postgres can be found in
[docs/postgres.md](../postgres.md). [docs/postgres.md](../postgres.md).
### In-memory path finding
Finding a path through the channel graph for sending a payment doesn't involve
any database queries anymore. The [channel graph is now kept fully
in-memory](https://github.com/lightningnetwork/lnd/pull/5642) for up a massive
performance boost when calling `QueryRoutes` or any of the `SendPayment`
variants. Keeping the full graph in memory naturally comes with increased RAM
usage. Users running `lnd` on low-memory systems are advised to run with the
`routing.strictgraphpruning=true` configuration option that more aggressively
removes zombie channels from the graph, reducing the number of channels that
need to be kept in memory.
## Protocol Extensions ## Protocol Extensions
### Explicit Channel Negotiation ### Explicit Channel Negotiation

View File

@@ -23,7 +23,6 @@ import (
"github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/htlcswitch"
"github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/labels" "github.com/lightningnetwork/lnd/labels"
"github.com/lightningnetwork/lnd/lnpeer" "github.com/lightningnetwork/lnd/lnpeer"
"github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lnrpc"
@@ -550,19 +549,6 @@ const (
addedToRouterGraph addedToRouterGraph
) )
var (
// channelOpeningStateBucket is the database bucket used to store the
// channelOpeningState for each channel that is currently in the process
// of being opened.
channelOpeningStateBucket = []byte("channelOpeningState")
// ErrChannelNotFound is an error returned when a channel is not known
// to us. In this case of the fundingManager, this error is returned
// when the channel in question is not considered being in an opening
// state.
ErrChannelNotFound = fmt.Errorf("channel not found")
)
// NewFundingManager creates and initializes a new instance of the // NewFundingManager creates and initializes a new instance of the
// fundingManager. // fundingManager.
func NewFundingManager(cfg Config) (*Manager, error) { func NewFundingManager(cfg Config) (*Manager, error) {
@@ -887,7 +873,7 @@ func (f *Manager) advanceFundingState(channel *channeldb.OpenChannel,
channelState, shortChanID, err := f.getChannelOpeningState( channelState, shortChanID, err := f.getChannelOpeningState(
&channel.FundingOutpoint, &channel.FundingOutpoint,
) )
if err == ErrChannelNotFound { if err == channeldb.ErrChannelNotFound {
// Channel not in fundingManager's opening database, // Channel not in fundingManager's opening database,
// meaning it was successfully announced to the // meaning it was successfully announced to the
// network. // network.
@@ -3551,26 +3537,20 @@ func copyPubKey(pub *btcec.PublicKey) *btcec.PublicKey {
// chanPoint to the channelOpeningStateBucket. // chanPoint to the channelOpeningStateBucket.
func (f *Manager) saveChannelOpeningState(chanPoint *wire.OutPoint, func (f *Manager) saveChannelOpeningState(chanPoint *wire.OutPoint,
state channelOpeningState, shortChanID *lnwire.ShortChannelID) error { state channelOpeningState, shortChanID *lnwire.ShortChannelID) error {
return kvdb.Update(f.cfg.Wallet.Cfg.Database, func(tx kvdb.RwTx) error {
bucket, err := tx.CreateTopLevelBucket(channelOpeningStateBucket) var outpointBytes bytes.Buffer
if err != nil { if err := WriteOutpoint(&outpointBytes, chanPoint); err != nil {
return err return err
} }
var outpointBytes bytes.Buffer // Save state and the uint64 representation of the shortChanID
if err = WriteOutpoint(&outpointBytes, chanPoint); err != nil { // for later use.
return err scratch := make([]byte, 10)
} byteOrder.PutUint16(scratch[:2], uint16(state))
byteOrder.PutUint64(scratch[2:], shortChanID.ToUint64())
// Save state and the uint64 representation of the shortChanID return f.cfg.Wallet.Cfg.Database.SaveChannelOpeningState(
// for later use. outpointBytes.Bytes(), scratch,
scratch := make([]byte, 10) )
byteOrder.PutUint16(scratch[:2], uint16(state))
byteOrder.PutUint64(scratch[2:], shortChanID.ToUint64())
return bucket.Put(outpointBytes.Bytes(), scratch)
}, func() {})
} }
// getChannelOpeningState fetches the channelOpeningState for the provided // getChannelOpeningState fetches the channelOpeningState for the provided
@@ -3579,51 +3559,31 @@ func (f *Manager) saveChannelOpeningState(chanPoint *wire.OutPoint,
func (f *Manager) getChannelOpeningState(chanPoint *wire.OutPoint) ( func (f *Manager) getChannelOpeningState(chanPoint *wire.OutPoint) (
channelOpeningState, *lnwire.ShortChannelID, error) { channelOpeningState, *lnwire.ShortChannelID, error) {
var state channelOpeningState var outpointBytes bytes.Buffer
var shortChanID lnwire.ShortChannelID if err := WriteOutpoint(&outpointBytes, chanPoint); err != nil {
err := kvdb.View(f.cfg.Wallet.Cfg.Database, func(tx kvdb.RTx) error { return 0, nil, err
}
bucket := tx.ReadBucket(channelOpeningStateBucket) value, err := f.cfg.Wallet.Cfg.Database.GetChannelOpeningState(
if bucket == nil { outpointBytes.Bytes(),
// If the bucket does not exist, it means we never added )
// a channel to the db, so return ErrChannelNotFound.
return ErrChannelNotFound
}
var outpointBytes bytes.Buffer
if err := WriteOutpoint(&outpointBytes, chanPoint); err != nil {
return err
}
value := bucket.Get(outpointBytes.Bytes())
if value == nil {
return ErrChannelNotFound
}
state = channelOpeningState(byteOrder.Uint16(value[:2]))
shortChanID = lnwire.NewShortChanIDFromInt(byteOrder.Uint64(value[2:]))
return nil
}, func() {})
if err != nil { if err != nil {
return 0, nil, err return 0, nil, err
} }
state := channelOpeningState(byteOrder.Uint16(value[:2]))
shortChanID := lnwire.NewShortChanIDFromInt(byteOrder.Uint64(value[2:]))
return state, &shortChanID, nil return state, &shortChanID, nil
} }
// deleteChannelOpeningState removes any state for chanPoint from the database. // deleteChannelOpeningState removes any state for chanPoint from the database.
func (f *Manager) deleteChannelOpeningState(chanPoint *wire.OutPoint) error { func (f *Manager) deleteChannelOpeningState(chanPoint *wire.OutPoint) error {
return kvdb.Update(f.cfg.Wallet.Cfg.Database, func(tx kvdb.RwTx) error { var outpointBytes bytes.Buffer
bucket := tx.ReadWriteBucket(channelOpeningStateBucket) if err := WriteOutpoint(&outpointBytes, chanPoint); err != nil {
if bucket == nil { return err
return fmt.Errorf("bucket not found") }
}
var outpointBytes bytes.Buffer return f.cfg.Wallet.Cfg.Database.DeleteChannelOpeningState(
if err := WriteOutpoint(&outpointBytes, chanPoint); err != nil { outpointBytes.Bytes(),
return err )
}
return bucket.Delete(outpointBytes.Bytes())
}, func() {})
} }

View File

@@ -262,7 +262,7 @@ func (n *testNode) AddNewChannel(channel *channeldb.OpenChannel,
} }
} }
func createTestWallet(cdb *channeldb.DB, netParams *chaincfg.Params, func createTestWallet(cdb *channeldb.ChannelStateDB, netParams *chaincfg.Params,
notifier chainntnfs.ChainNotifier, wc lnwallet.WalletController, notifier chainntnfs.ChainNotifier, wc lnwallet.WalletController,
signer input.Signer, keyRing keychain.SecretKeyRing, signer input.Signer, keyRing keychain.SecretKeyRing,
bio lnwallet.BlockChainIO, bio lnwallet.BlockChainIO,
@@ -330,11 +330,13 @@ func createTestFundingManager(t *testing.T, privKey *btcec.PrivateKey,
} }
dbDir := filepath.Join(tempTestDir, "cdb") dbDir := filepath.Join(tempTestDir, "cdb")
cdb, err := channeldb.Open(dbDir) fullDB, err := channeldb.Open(dbDir)
if err != nil { if err != nil {
return nil, err return nil, err
} }
cdb := fullDB.ChannelStateDB()
keyRing := &mock.SecretKeyRing{ keyRing := &mock.SecretKeyRing{
RootKey: alicePrivKey, RootKey: alicePrivKey,
} }
@@ -923,12 +925,12 @@ func assertDatabaseState(t *testing.T, node *testNode,
} }
state, _, err = node.fundingMgr.getChannelOpeningState( state, _, err = node.fundingMgr.getChannelOpeningState(
fundingOutPoint) fundingOutPoint)
if err != nil && err != ErrChannelNotFound { if err != nil && err != channeldb.ErrChannelNotFound {
t.Fatalf("unable to get channel state: %v", err) t.Fatalf("unable to get channel state: %v", err)
} }
// If we found the channel, check if it had the expected state. // If we found the channel, check if it had the expected state.
if err != ErrChannelNotFound && state == expectedState { if err != channeldb.ErrChannelNotFound && state == expectedState {
// Got expected state, return with success. // Got expected state, return with success.
return return
} }
@@ -1166,7 +1168,7 @@ func assertErrChannelNotFound(t *testing.T, node *testNode,
} }
state, _, err = node.fundingMgr.getChannelOpeningState( state, _, err = node.fundingMgr.getChannelOpeningState(
fundingOutPoint) fundingOutPoint)
if err == ErrChannelNotFound { if err == channeldb.ErrChannelNotFound {
// Got expected state, return with success. // Got expected state, return with success.
return return
} else if err != nil { } else if err != nil {

View File

@@ -199,9 +199,16 @@ type circuitMap struct {
// parameterize an instance of circuitMap. // parameterize an instance of circuitMap.
type CircuitMapConfig struct { type CircuitMapConfig struct {
// DB provides the persistent storage engine for the circuit map. // DB provides the persistent storage engine for the circuit map.
// TODO(conner): create abstraction to allow for the substitution of DB kvdb.Backend
// other persistence engines.
DB *channeldb.DB // FetchAllOpenChannels is a function that fetches all currently open
// channels from the channel database.
FetchAllOpenChannels func() ([]*channeldb.OpenChannel, error)
// FetchClosedChannels is a function that fetches all closed channels
// from the channel database.
FetchClosedChannels func(
pendingOnly bool) ([]*channeldb.ChannelCloseSummary, error)
// ExtractErrorEncrypter derives the shared secret used to encrypt // ExtractErrorEncrypter derives the shared secret used to encrypt
// errors from the obfuscator's ephemeral public key. // errors from the obfuscator's ephemeral public key.
@@ -296,7 +303,7 @@ func (cm *circuitMap) cleanClosedChannels() error {
// Find closed channels and cache their ShortChannelIDs into a map. // Find closed channels and cache their ShortChannelIDs into a map.
// This map will be used for looking up relative circuits and keystones. // This map will be used for looking up relative circuits and keystones.
closedChannels, err := cm.cfg.DB.FetchClosedChannels(false) closedChannels, err := cm.cfg.FetchClosedChannels(false)
if err != nil { if err != nil {
return err return err
} }
@@ -629,7 +636,7 @@ func (cm *circuitMap) decodeCircuit(v []byte) (*PaymentCircuit, error) {
// channels. Therefore, it must be called before any links are created to avoid // channels. Therefore, it must be called before any links are created to avoid
// interfering with normal operation. // interfering with normal operation.
func (cm *circuitMap) trimAllOpenCircuits() error { func (cm *circuitMap) trimAllOpenCircuits() error {
activeChannels, err := cm.cfg.DB.FetchAllOpenChannels() activeChannels, err := cm.cfg.FetchAllOpenChannels()
if err != nil { if err != nil {
return err return err
} }
@@ -860,7 +867,7 @@ func (cm *circuitMap) CommitCircuits(circuits ...*PaymentCircuit) (
// Write the entire batch of circuits to the persistent circuit bucket // Write the entire batch of circuits to the persistent circuit bucket
// using bolt's Batch write. This method must be called from multiple, // using bolt's Batch write. This method must be called from multiple,
// distinct goroutines to have any impact on performance. // distinct goroutines to have any impact on performance.
err := kvdb.Batch(cm.cfg.DB.Backend, func(tx kvdb.RwTx) error { err := kvdb.Batch(cm.cfg.DB, func(tx kvdb.RwTx) error {
circuitBkt := tx.ReadWriteBucket(circuitAddKey) circuitBkt := tx.ReadWriteBucket(circuitAddKey)
if circuitBkt == nil { if circuitBkt == nil {
return ErrCorruptedCircuitMap return ErrCorruptedCircuitMap
@@ -1091,7 +1098,7 @@ func (cm *circuitMap) DeleteCircuits(inKeys ...CircuitKey) error {
} }
cm.mtx.Unlock() cm.mtx.Unlock()
err := kvdb.Batch(cm.cfg.DB.Backend, func(tx kvdb.RwTx) error { err := kvdb.Batch(cm.cfg.DB, func(tx kvdb.RwTx) error {
for _, circuit := range removedCircuits { for _, circuit := range removedCircuits {
// If this htlc made it to an outgoing link, load the // If this htlc made it to an outgoing link, load the
// keystone bucket from which we will remove the // keystone bucket from which we will remove the

View File

@@ -103,8 +103,11 @@ func newCircuitMap(t *testing.T) (*htlcswitch.CircuitMapConfig,
onionProcessor := newOnionProcessor(t) onionProcessor := newOnionProcessor(t)
db := makeCircuitDB(t, "")
circuitMapCfg := &htlcswitch.CircuitMapConfig{ circuitMapCfg := &htlcswitch.CircuitMapConfig{
DB: makeCircuitDB(t, ""), DB: db,
FetchAllOpenChannels: db.ChannelStateDB().FetchAllOpenChannels,
FetchClosedChannels: db.ChannelStateDB().FetchClosedChannels,
ExtractErrorEncrypter: onionProcessor.ExtractErrorEncrypter, ExtractErrorEncrypter: onionProcessor.ExtractErrorEncrypter,
} }
@@ -634,13 +637,17 @@ func makeCircuitDB(t *testing.T, path string) *channeldb.DB {
func restartCircuitMap(t *testing.T, cfg *htlcswitch.CircuitMapConfig) ( func restartCircuitMap(t *testing.T, cfg *htlcswitch.CircuitMapConfig) (
*htlcswitch.CircuitMapConfig, htlcswitch.CircuitMap) { *htlcswitch.CircuitMapConfig, htlcswitch.CircuitMap) {
// Record the current temp path and close current db. // Record the current temp path and close current db. We know we have
dbPath := cfg.DB.Path() // a full channeldb.DB here since we created it just above.
dbPath := cfg.DB.(*channeldb.DB).Path()
cfg.DB.Close() cfg.DB.Close()
// Reinitialize circuit map with same db path. // Reinitialize circuit map with same db path.
db := makeCircuitDB(t, dbPath)
cfg2 := &htlcswitch.CircuitMapConfig{ cfg2 := &htlcswitch.CircuitMapConfig{
DB: makeCircuitDB(t, dbPath), DB: db,
FetchAllOpenChannels: db.ChannelStateDB().FetchAllOpenChannels,
FetchClosedChannels: db.ChannelStateDB().FetchClosedChannels,
ExtractErrorEncrypter: cfg.ExtractErrorEncrypter, ExtractErrorEncrypter: cfg.ExtractErrorEncrypter,
} }
cm2, err := htlcswitch.NewCircuitMap(cfg2) cm2, err := htlcswitch.NewCircuitMap(cfg2)

View File

@@ -1938,7 +1938,7 @@ func newSingleLinkTestHarness(chanAmt, chanReserve btcutil.Amount) (
pCache := newMockPreimageCache() pCache := newMockPreimageCache()
aliceDb := aliceLc.channel.State().Db aliceDb := aliceLc.channel.State().Db.GetParentDB()
aliceSwitch, err := initSwitchWithDB(testStartingHeight, aliceDb) aliceSwitch, err := initSwitchWithDB(testStartingHeight, aliceDb)
if err != nil { if err != nil {
return nil, nil, nil, nil, nil, nil, err return nil, nil, nil, nil, nil, nil, err
@@ -4438,7 +4438,7 @@ func (h *persistentLinkHarness) restartLink(
pCache = newMockPreimageCache() pCache = newMockPreimageCache()
) )
aliceDb := aliceChannel.State().Db aliceDb := aliceChannel.State().Db.GetParentDB()
aliceSwitch := h.coreLink.cfg.Switch aliceSwitch := h.coreLink.cfg.Switch
if restartSwitch { if restartSwitch {
var err error var err error

View File

@@ -170,8 +170,10 @@ func initSwitchWithDB(startingHeight uint32, db *channeldb.DB) (*Switch, error)
} }
cfg := Config{ cfg := Config{
DB: db, DB: db,
SwitchPackager: channeldb.NewSwitchPackager(), FetchAllOpenChannels: db.ChannelStateDB().FetchAllOpenChannels,
FetchClosedChannels: db.ChannelStateDB().FetchClosedChannels,
SwitchPackager: channeldb.NewSwitchPackager(),
FwdingLog: &mockForwardingLog{ FwdingLog: &mockForwardingLog{
events: make(map[time.Time]channeldb.ForwardingEvent), events: make(map[time.Time]channeldb.ForwardingEvent),
}, },

View File

@@ -83,7 +83,7 @@ func deserializeNetworkResult(r io.Reader) (*networkResult, error) {
// is back. The Switch will checkpoint any received result to the store, and // is back. The Switch will checkpoint any received result to the store, and
// the store will keep results and notify the callers about them. // the store will keep results and notify the callers about them.
type networkResultStore struct { type networkResultStore struct {
db *channeldb.DB backend kvdb.Backend
// results is a map from paymentIDs to channels where subscribers to // results is a map from paymentIDs to channels where subscribers to
// payment results will be notified. // payment results will be notified.
@@ -96,9 +96,9 @@ type networkResultStore struct {
paymentIDMtx *multimutex.Mutex paymentIDMtx *multimutex.Mutex
} }
func newNetworkResultStore(db *channeldb.DB) *networkResultStore { func newNetworkResultStore(db kvdb.Backend) *networkResultStore {
return &networkResultStore{ return &networkResultStore{
db: db, backend: db,
results: make(map[uint64][]chan *networkResult), results: make(map[uint64][]chan *networkResult),
paymentIDMtx: multimutex.NewMutex(), paymentIDMtx: multimutex.NewMutex(),
} }
@@ -126,7 +126,7 @@ func (store *networkResultStore) storeResult(paymentID uint64,
var paymentIDBytes [8]byte var paymentIDBytes [8]byte
binary.BigEndian.PutUint64(paymentIDBytes[:], paymentID) binary.BigEndian.PutUint64(paymentIDBytes[:], paymentID)
err := kvdb.Batch(store.db.Backend, func(tx kvdb.RwTx) error { err := kvdb.Batch(store.backend, func(tx kvdb.RwTx) error {
networkResults, err := tx.CreateTopLevelBucket( networkResults, err := tx.CreateTopLevelBucket(
networkResultStoreBucketKey, networkResultStoreBucketKey,
) )
@@ -171,7 +171,7 @@ func (store *networkResultStore) subscribeResult(paymentID uint64) (
resultChan = make(chan *networkResult, 1) resultChan = make(chan *networkResult, 1)
) )
err := kvdb.View(store.db, func(tx kvdb.RTx) error { err := kvdb.View(store.backend, func(tx kvdb.RTx) error {
var err error var err error
result, err = fetchResult(tx, paymentID) result, err = fetchResult(tx, paymentID)
switch { switch {
@@ -219,7 +219,7 @@ func (store *networkResultStore) getResult(pid uint64) (
*networkResult, error) { *networkResult, error) {
var result *networkResult var result *networkResult
err := kvdb.View(store.db, func(tx kvdb.RTx) error { err := kvdb.View(store.backend, func(tx kvdb.RTx) error {
var err error var err error
result, err = fetchResult(tx, pid) result, err = fetchResult(tx, pid)
return err return err
@@ -260,7 +260,7 @@ func fetchResult(tx kvdb.RTx, pid uint64) (*networkResult, error) {
// concurrently while this process is ongoing, as its result might end up being // concurrently while this process is ongoing, as its result might end up being
// deleted. // deleted.
func (store *networkResultStore) cleanStore(keep map[uint64]struct{}) error { func (store *networkResultStore) cleanStore(keep map[uint64]struct{}) error {
return kvdb.Update(store.db.Backend, func(tx kvdb.RwTx) error { return kvdb.Update(store.backend, func(tx kvdb.RwTx) error {
networkResults, err := tx.CreateTopLevelBucket( networkResults, err := tx.CreateTopLevelBucket(
networkResultStoreBucketKey, networkResultStoreBucketKey,
) )

View File

@@ -130,9 +130,18 @@ type Config struct {
// subsystem. // subsystem.
LocalChannelClose func(pubKey []byte, request *ChanClose) LocalChannelClose func(pubKey []byte, request *ChanClose)
// DB is the channeldb instance that will be used to back the switch's // DB is the database backend that will be used to back the switch's
// persistent circuit map. // persistent circuit map.
DB *channeldb.DB DB kvdb.Backend
// FetchAllOpenChannels is a function that fetches all currently open
// channels from the channel database.
FetchAllOpenChannels func() ([]*channeldb.OpenChannel, error)
// FetchClosedChannels is a function that fetches all closed channels
// from the channel database.
FetchClosedChannels func(
pendingOnly bool) ([]*channeldb.ChannelCloseSummary, error)
// SwitchPackager provides access to the forwarding packages of all // SwitchPackager provides access to the forwarding packages of all
// active channels. This gives the switch the ability to read arbitrary // active channels. This gives the switch the ability to read arbitrary
@@ -294,6 +303,8 @@ type Switch struct {
func New(cfg Config, currentHeight uint32) (*Switch, error) { func New(cfg Config, currentHeight uint32) (*Switch, error) {
circuitMap, err := NewCircuitMap(&CircuitMapConfig{ circuitMap, err := NewCircuitMap(&CircuitMapConfig{
DB: cfg.DB, DB: cfg.DB,
FetchAllOpenChannels: cfg.FetchAllOpenChannels,
FetchClosedChannels: cfg.FetchClosedChannels,
ExtractErrorEncrypter: cfg.ExtractErrorEncrypter, ExtractErrorEncrypter: cfg.ExtractErrorEncrypter,
}) })
if err != nil { if err != nil {
@@ -1455,7 +1466,7 @@ func (s *Switch) closeCircuit(pkt *htlcPacket) (*PaymentCircuit, error) {
// we're the originator of the payment, so the link stops attempting to // we're the originator of the payment, so the link stops attempting to
// re-broadcast. // re-broadcast.
func (s *Switch) ackSettleFail(settleFailRefs ...channeldb.SettleFailRef) error { func (s *Switch) ackSettleFail(settleFailRefs ...channeldb.SettleFailRef) error {
return kvdb.Batch(s.cfg.DB.Backend, func(tx kvdb.RwTx) error { return kvdb.Batch(s.cfg.DB, func(tx kvdb.RwTx) error {
return s.cfg.SwitchPackager.AckSettleFails(tx, settleFailRefs...) return s.cfg.SwitchPackager.AckSettleFails(tx, settleFailRefs...)
}) })
} }
@@ -1859,7 +1870,7 @@ func (s *Switch) Start() error {
// forwarding packages and reforwards any Settle or Fail HTLCs found. This is // forwarding packages and reforwards any Settle or Fail HTLCs found. This is
// used to resurrect the switch's mailboxes after a restart. // used to resurrect the switch's mailboxes after a restart.
func (s *Switch) reforwardResponses() error { func (s *Switch) reforwardResponses() error {
openChannels, err := s.cfg.DB.FetchAllOpenChannels() openChannels, err := s.cfg.FetchAllOpenChannels()
if err != nil { if err != nil {
return err return err
} }
@@ -2122,6 +2133,17 @@ func (s *Switch) getLink(chanID lnwire.ChannelID) (ChannelLink, error) {
return link, nil return link, nil
} }
// GetLinkByShortID attempts to return the link which possesses the target short
// channel ID.
func (s *Switch) GetLinkByShortID(chanID lnwire.ShortChannelID) (ChannelLink,
error) {
s.indexMtx.RLock()
defer s.indexMtx.RUnlock()
return s.getLinkByShortID(chanID)
}
// getLinkByShortID attempts to return the link which possesses the target // getLinkByShortID attempts to return the link which possesses the target
// short channel ID. // short channel ID.
// //

View File

@@ -306,7 +306,7 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte,
LocalCommitment: aliceCommit, LocalCommitment: aliceCommit,
RemoteCommitment: aliceCommit, RemoteCommitment: aliceCommit,
ShortChannelID: chanID, ShortChannelID: chanID,
Db: dbAlice, Db: dbAlice.ChannelStateDB(),
Packager: channeldb.NewChannelPackager(chanID), Packager: channeldb.NewChannelPackager(chanID),
FundingTxn: channels.TestFundingTx, FundingTxn: channels.TestFundingTx,
} }
@@ -325,7 +325,7 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte,
LocalCommitment: bobCommit, LocalCommitment: bobCommit,
RemoteCommitment: bobCommit, RemoteCommitment: bobCommit,
ShortChannelID: chanID, ShortChannelID: chanID,
Db: dbBob, Db: dbBob.ChannelStateDB(),
Packager: channeldb.NewChannelPackager(chanID), Packager: channeldb.NewChannelPackager(chanID),
} }
@@ -384,7 +384,8 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte,
} }
restoreAlice := func() (*lnwallet.LightningChannel, error) { restoreAlice := func() (*lnwallet.LightningChannel, error) {
aliceStoredChannels, err := dbAlice.FetchOpenChannels(aliceKeyPub) aliceStoredChannels, err := dbAlice.ChannelStateDB().
FetchOpenChannels(aliceKeyPub)
switch err { switch err {
case nil: case nil:
case kvdb.ErrDatabaseNotOpen: case kvdb.ErrDatabaseNotOpen:
@@ -394,7 +395,8 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte,
"db: %v", err) "db: %v", err)
} }
aliceStoredChannels, err = dbAlice.FetchOpenChannels(aliceKeyPub) aliceStoredChannels, err = dbAlice.ChannelStateDB().
FetchOpenChannels(aliceKeyPub)
if err != nil { if err != nil {
return nil, errors.Errorf("unable to fetch alice "+ return nil, errors.Errorf("unable to fetch alice "+
"channel: %v", err) "channel: %v", err)
@@ -428,7 +430,8 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte,
} }
restoreBob := func() (*lnwallet.LightningChannel, error) { restoreBob := func() (*lnwallet.LightningChannel, error) {
bobStoredChannels, err := dbBob.FetchOpenChannels(bobKeyPub) bobStoredChannels, err := dbBob.ChannelStateDB().
FetchOpenChannels(bobKeyPub)
switch err { switch err {
case nil: case nil:
case kvdb.ErrDatabaseNotOpen: case kvdb.ErrDatabaseNotOpen:
@@ -438,7 +441,8 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte,
"db: %v", err) "db: %v", err)
} }
bobStoredChannels, err = dbBob.FetchOpenChannels(bobKeyPub) bobStoredChannels, err = dbBob.ChannelStateDB().
FetchOpenChannels(bobKeyPub)
if err != nil { if err != nil {
return nil, errors.Errorf("unable to fetch bob "+ return nil, errors.Errorf("unable to fetch bob "+
"channel: %v", err) "channel: %v", err)
@@ -950,9 +954,9 @@ func newThreeHopNetwork(t testing.TB, aliceChannel, firstBobChannel,
secondBobChannel, carolChannel *lnwallet.LightningChannel, secondBobChannel, carolChannel *lnwallet.LightningChannel,
startingHeight uint32, opts ...serverOption) *threeHopNetwork { startingHeight uint32, opts ...serverOption) *threeHopNetwork {
aliceDb := aliceChannel.State().Db aliceDb := aliceChannel.State().Db.GetParentDB()
bobDb := firstBobChannel.State().Db bobDb := firstBobChannel.State().Db.GetParentDB()
carolDb := carolChannel.State().Db carolDb := carolChannel.State().Db.GetParentDB()
hopNetwork := newHopNetwork() hopNetwork := newHopNetwork()
@@ -1201,8 +1205,8 @@ func newTwoHopNetwork(t testing.TB,
aliceChannel, bobChannel *lnwallet.LightningChannel, aliceChannel, bobChannel *lnwallet.LightningChannel,
startingHeight uint32) *twoHopNetwork { startingHeight uint32) *twoHopNetwork {
aliceDb := aliceChannel.State().Db aliceDb := aliceChannel.State().Db.GetParentDB()
bobDb := bobChannel.State().Db bobDb := bobChannel.State().Db.GetParentDB()
hopNetwork := newHopNetwork() hopNetwork := newHopNetwork()

24
lnd.go
View File

@@ -22,6 +22,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/btcsuite/btcd/chaincfg"
"github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcutil" "github.com/btcsuite/btcutil"
"github.com/btcsuite/btcwallet/wallet" "github.com/btcsuite/btcwallet/wallet"
@@ -697,7 +698,7 @@ func Main(cfg *Config, lisCfg ListenerCfg, interceptor signal.Interceptor) error
BtcdMode: cfg.BtcdMode, BtcdMode: cfg.BtcdMode,
LtcdMode: cfg.LtcdMode, LtcdMode: cfg.LtcdMode,
HeightHintDB: dbs.heightHintDB, HeightHintDB: dbs.heightHintDB,
ChanStateDB: dbs.chanStateDB, ChanStateDB: dbs.chanStateDB.ChannelStateDB(),
PrivateWalletPw: privateWalletPw, PrivateWalletPw: privateWalletPw,
PublicWalletPw: publicWalletPw, PublicWalletPw: publicWalletPw,
Birthday: walletInitParams.Birthday, Birthday: walletInitParams.Birthday,
@@ -1679,14 +1680,27 @@ func initializeDatabases(ctx context.Context,
"instances") "instances")
} }
// Otherwise, we'll open two instances, one for the state we only need dbOptions := []channeldb.OptionModifier{
// locally, and the other for things we want to ensure are replicated.
dbs.graphDB, err = channeldb.CreateWithBackend(
databaseBackends.GraphDB,
channeldb.OptionSetRejectCacheSize(cfg.Caches.RejectCacheSize), channeldb.OptionSetRejectCacheSize(cfg.Caches.RejectCacheSize),
channeldb.OptionSetChannelCacheSize(cfg.Caches.ChannelCacheSize), channeldb.OptionSetChannelCacheSize(cfg.Caches.ChannelCacheSize),
channeldb.OptionSetBatchCommitInterval(cfg.DB.BatchCommitInterval), channeldb.OptionSetBatchCommitInterval(cfg.DB.BatchCommitInterval),
channeldb.OptionDryRunMigration(cfg.DryRunMigration), channeldb.OptionDryRunMigration(cfg.DryRunMigration),
}
// We want to pre-allocate the channel graph cache according to what we
// expect for mainnet to speed up memory allocation.
if cfg.ActiveNetParams.Name == chaincfg.MainNetParams.Name {
dbOptions = append(
dbOptions, channeldb.OptionSetPreAllocCacheNumNodes(
channeldb.DefaultPreAllocCacheNumNodes,
),
)
}
// Otherwise, we'll open two instances, one for the state we only need
// locally, and the other for things we want to ensure are replicated.
dbs.graphDB, err = channeldb.CreateWithBackend(
databaseBackends.GraphDB, dbOptions...,
) )
switch { switch {
// Give the DB a chance to dry run the migration. Since we know that // Give the DB a chance to dry run the migration. Since we know that

View File

@@ -56,7 +56,7 @@ type AddInvoiceConfig struct {
// ChanDB is a global boltdb instance which is needed to access the // ChanDB is a global boltdb instance which is needed to access the
// channel graph. // channel graph.
ChanDB *channeldb.DB ChanDB *channeldb.ChannelStateDB
// Graph holds a reference to the ChannelGraph database. // Graph holds a reference to the ChannelGraph database.
Graph *channeldb.ChannelGraph Graph *channeldb.ChannelGraph

View File

@@ -51,7 +51,7 @@ type Config struct {
// ChanStateDB is a possibly replicated db instance which contains the // ChanStateDB is a possibly replicated db instance which contains the
// channel state // channel state
ChanStateDB *channeldb.DB ChanStateDB *channeldb.ChannelStateDB
// GenInvoiceFeatures returns a feature containing feature bits that // GenInvoiceFeatures returns a feature containing feature bits that
// should be advertised on freshly generated invoices. // should be advertised on freshly generated invoices.

View File

@@ -55,7 +55,7 @@ type RouterBackend struct {
FindRoute func(source, target route.Vertex, FindRoute func(source, target route.Vertex,
amt lnwire.MilliSatoshi, restrictions *routing.RestrictParams, amt lnwire.MilliSatoshi, restrictions *routing.RestrictParams,
destCustomRecords record.CustomSet, destCustomRecords record.CustomSet,
routeHints map[route.Vertex][]*channeldb.ChannelEdgePolicy, routeHints map[route.Vertex][]*channeldb.CachedEdgePolicy,
finalExpiry uint16) (*route.Route, error) finalExpiry uint16) (*route.Route, error)
MissionControl MissionControl MissionControl MissionControl

View File

@@ -126,7 +126,7 @@ func testQueryRoutes(t *testing.T, useMissionControl bool, useMsat bool,
findRoute := func(source, target route.Vertex, findRoute := func(source, target route.Vertex,
amt lnwire.MilliSatoshi, restrictions *routing.RestrictParams, amt lnwire.MilliSatoshi, restrictions *routing.RestrictParams,
_ record.CustomSet, _ record.CustomSet,
routeHints map[route.Vertex][]*channeldb.ChannelEdgePolicy, routeHints map[route.Vertex][]*channeldb.CachedEdgePolicy,
finalExpiry uint16) (*route.Route, error) { finalExpiry uint16) (*route.Route, error) {
if int64(amt) != amtSat*1000 { if int64(amt) != amtSat*1000 {

View File

@@ -25,24 +25,20 @@ func testUpdateChanStatus(net *lntest.NetworkHarness, t *harnessTest) {
ctxb := context.Background() ctxb := context.Background()
// Create two fresh nodes and open a channel between them. // Create two fresh nodes and open a channel between them.
alice := net.NewNode( alice := net.NewNode(t.t, "Alice", []string{
t.t, "Alice", []string{ "--minbackoff=10s",
"--minbackoff=10s", "--chan-enable-timeout=1.5s",
"--chan-enable-timeout=1.5s", "--chan-disable-timeout=3s",
"--chan-disable-timeout=3s", "--chan-status-sample-interval=.5s",
"--chan-status-sample-interval=.5s", })
},
)
defer shutdownAndAssert(net, t, alice) defer shutdownAndAssert(net, t, alice)
bob := net.NewNode( bob := net.NewNode(t.t, "Bob", []string{
t.t, "Bob", []string{ "--minbackoff=10s",
"--minbackoff=10s", "--chan-enable-timeout=1.5s",
"--chan-enable-timeout=1.5s", "--chan-disable-timeout=3s",
"--chan-disable-timeout=3s", "--chan-status-sample-interval=.5s",
"--chan-status-sample-interval=.5s", })
},
)
defer shutdownAndAssert(net, t, bob) defer shutdownAndAssert(net, t, bob)
// Connect Alice to Bob. // Connect Alice to Bob.
@@ -55,36 +51,32 @@ func testUpdateChanStatus(net *lntest.NetworkHarness, t *harnessTest) {
// being the sole funder of the channel. // being the sole funder of the channel.
chanAmt := btcutil.Amount(100000) chanAmt := btcutil.Amount(100000)
chanPoint := openChannelAndAssert( chanPoint := openChannelAndAssert(
t, net, alice, bob, t, net, alice, bob, lntest.OpenChannelParams{
lntest.OpenChannelParams{
Amt: chanAmt, Amt: chanAmt,
}, },
) )
// Wait for Alice and Bob to receive the channel edge from the // Wait for Alice and Bob to receive the channel edge from the
// funding manager. // funding manager.
ctxt, _ := context.WithTimeout(ctxb, defaultTimeout) ctxt, cancel := context.WithTimeout(ctxb, defaultTimeout)
defer cancel()
err := alice.WaitForNetworkChannelOpen(ctxt, chanPoint) err := alice.WaitForNetworkChannelOpen(ctxt, chanPoint)
if err != nil { require.NoError(t.t, err, "alice didn't see the alice->bob channel")
t.Fatalf("alice didn't see the alice->bob channel before "+
"timeout: %v", err)
}
ctxt, _ = context.WithTimeout(ctxb, defaultTimeout)
err = bob.WaitForNetworkChannelOpen(ctxt, chanPoint) err = bob.WaitForNetworkChannelOpen(ctxt, chanPoint)
if err != nil { require.NoError(t.t, err, "bob didn't see the alice->bob channel")
t.Fatalf("bob didn't see the bob->alice channel before "+
"timeout: %v", err)
}
// Launch a node for Carol which will connect to Alice and Bob in // Launch a node for Carol which will connect to Alice and Bob in order
// order to receive graph updates. This will ensure that the // to receive graph updates. This will ensure that the channel updates
// channel updates are propagated throughout the network. // are propagated throughout the network.
carol := net.NewNode(t.t, "Carol", nil) carol := net.NewNode(t.t, "Carol", nil)
defer shutdownAndAssert(net, t, carol) defer shutdownAndAssert(net, t, carol)
// Connect both Alice and Bob to the new node Carol, so she can sync her
// graph.
net.ConnectNodes(t.t, alice, carol) net.ConnectNodes(t.t, alice, carol)
net.ConnectNodes(t.t, bob, carol) net.ConnectNodes(t.t, bob, carol)
waitForGraphSync(t, carol)
// assertChannelUpdate checks that the required policy update has // assertChannelUpdate checks that the required policy update has
// happened on the given node. // happened on the given node.
@@ -109,12 +101,11 @@ func testUpdateChanStatus(net *lntest.NetworkHarness, t *harnessTest) {
ChanPoint: chanPoint, ChanPoint: chanPoint,
Action: action, Action: action,
} }
ctxt, _ = context.WithTimeout(ctxb, defaultTimeout) ctxt, cancel := context.WithTimeout(ctxb, defaultTimeout)
defer cancel()
_, err = node.RouterClient.UpdateChanStatus(ctxt, req) _, err = node.RouterClient.UpdateChanStatus(ctxt, req)
if err != nil { require.NoErrorf(t.t, err, "UpdateChanStatus")
t.Fatalf("unable to call UpdateChanStatus for %s's node: %v",
node.Name(), err)
}
} }
// assertEdgeDisabled ensures that a given node has the correct // assertEdgeDisabled ensures that a given node has the correct
@@ -122,26 +113,30 @@ func testUpdateChanStatus(net *lntest.NetworkHarness, t *harnessTest) {
assertEdgeDisabled := func(node *lntest.HarnessNode, assertEdgeDisabled := func(node *lntest.HarnessNode,
chanPoint *lnrpc.ChannelPoint, disabled bool) { chanPoint *lnrpc.ChannelPoint, disabled bool) {
var predErr error outPoint, err := lntest.MakeOutpoint(chanPoint)
err = wait.Predicate(func() bool { require.NoError(t.t, err)
err = wait.NoError(func() error {
req := &lnrpc.ChannelGraphRequest{ req := &lnrpc.ChannelGraphRequest{
IncludeUnannounced: true, IncludeUnannounced: true,
} }
ctxt, _ = context.WithTimeout(ctxb, defaultTimeout) ctxt, cancel := context.WithTimeout(ctxb, defaultTimeout)
defer cancel()
chanGraph, err := node.DescribeGraph(ctxt, req) chanGraph, err := node.DescribeGraph(ctxt, req)
if err != nil { if err != nil {
predErr = fmt.Errorf("unable to query node %v's graph: %v", node, err) return fmt.Errorf("unable to query node %v's "+
return false "graph: %v", node, err)
} }
numEdges := len(chanGraph.Edges) numEdges := len(chanGraph.Edges)
if numEdges != 1 { if numEdges != 1 {
predErr = fmt.Errorf("expected to find 1 edge in the graph, found %d", numEdges) return fmt.Errorf("expected to find 1 edge in "+
return false "the graph, found %d", numEdges)
} }
edge := chanGraph.Edges[0] edge := chanGraph.Edges[0]
if edge.ChanPoint != chanPoint.GetFundingTxidStr() { if edge.ChanPoint != outPoint.String() {
predErr = fmt.Errorf("expected chan_point %v, got %v", return fmt.Errorf("expected chan_point %v, "+
chanPoint.GetFundingTxidStr(), edge.ChanPoint) "got %v", outPoint, edge.ChanPoint)
} }
var policy *lnrpc.RoutingPolicy var policy *lnrpc.RoutingPolicy
if node.PubKeyStr == edge.Node1Pub { if node.PubKeyStr == edge.Node1Pub {
@@ -150,15 +145,14 @@ func testUpdateChanStatus(net *lntest.NetworkHarness, t *harnessTest) {
policy = edge.Node2Policy policy = edge.Node2Policy
} }
if disabled != policy.Disabled { if disabled != policy.Disabled {
predErr = fmt.Errorf("expected policy.Disabled to be %v, "+ return fmt.Errorf("expected policy.Disabled "+
"but policy was %v", disabled, policy) "to be %v, but policy was %v", disabled,
return false policy)
} }
return true
return nil
}, defaultTimeout) }, defaultTimeout)
if err != nil { require.NoError(t.t, err)
t.Fatalf("%v", predErr)
}
} }
// When updating the state of the channel between Alice and Bob, we // When updating the state of the channel between Alice and Bob, we
@@ -193,9 +187,7 @@ func testUpdateChanStatus(net *lntest.NetworkHarness, t *harnessTest) {
// disconnections from automatically disabling the channel again // disconnections from automatically disabling the channel again
// (we don't want to clutter the network with channels that are // (we don't want to clutter the network with channels that are
// falsely advertised as enabled when they don't work). // falsely advertised as enabled when they don't work).
if err := net.DisconnectNodes(alice, bob); err != nil { require.NoError(t.t, net.DisconnectNodes(alice, bob))
t.Fatalf("unable to disconnect Alice from Bob: %v", err)
}
expectedPolicy.Disabled = true expectedPolicy.Disabled = true
assertChannelUpdate(alice, expectedPolicy) assertChannelUpdate(alice, expectedPolicy)
assertChannelUpdate(bob, expectedPolicy) assertChannelUpdate(bob, expectedPolicy)
@@ -217,9 +209,7 @@ func testUpdateChanStatus(net *lntest.NetworkHarness, t *harnessTest) {
expectedPolicy.Disabled = true expectedPolicy.Disabled = true
assertChannelUpdate(alice, expectedPolicy) assertChannelUpdate(alice, expectedPolicy)
if err := net.DisconnectNodes(alice, bob); err != nil { require.NoError(t.t, net.DisconnectNodes(alice, bob))
t.Fatalf("unable to disconnect Alice from Bob: %v", err)
}
// Bob sends a "Disabled = true" update upon detecting the // Bob sends a "Disabled = true" update upon detecting the
// disconnect. // disconnect.
@@ -237,9 +227,7 @@ func testUpdateChanStatus(net *lntest.NetworkHarness, t *harnessTest) {
// note the asymmetry between manual enable and manual disable! // note the asymmetry between manual enable and manual disable!
assertEdgeDisabled(alice, chanPoint, true) assertEdgeDisabled(alice, chanPoint, true)
if err := net.DisconnectNodes(alice, bob); err != nil { require.NoError(t.t, net.DisconnectNodes(alice, bob))
t.Fatalf("unable to disconnect Alice from Bob: %v", err)
}
// Bob sends a "Disabled = true" update upon detecting the // Bob sends a "Disabled = true" update upon detecting the
// disconnect. // disconnect.

View File

@@ -18,7 +18,7 @@ type Config struct {
// Database is a wrapper around a namespace within boltdb reserved for // Database is a wrapper around a namespace within boltdb reserved for
// ln-based wallet metadata. See the 'channeldb' package for further // ln-based wallet metadata. See the 'channeldb' package for further
// information. // information.
Database *channeldb.DB Database *channeldb.ChannelStateDB
// Notifier is used by in order to obtain notifications about funding // Notifier is used by in order to obtain notifications about funding
// transaction reaching a specified confirmation depth, and to catch // transaction reaching a specified confirmation depth, and to catch

View File

@@ -327,13 +327,13 @@ func createTestWallet(tempTestDir string, miningNode *rpctest.Harness,
signer input.Signer, bio lnwallet.BlockChainIO) (*lnwallet.LightningWallet, error) { signer input.Signer, bio lnwallet.BlockChainIO) (*lnwallet.LightningWallet, error) {
dbDir := filepath.Join(tempTestDir, "cdb") dbDir := filepath.Join(tempTestDir, "cdb")
cdb, err := channeldb.Open(dbDir) fullDB, err := channeldb.Open(dbDir)
if err != nil { if err != nil {
return nil, err return nil, err
} }
cfg := lnwallet.Config{ cfg := lnwallet.Config{
Database: cdb, Database: fullDB.ChannelStateDB(),
Notifier: notifier, Notifier: notifier,
SecretKeyRing: keyRing, SecretKeyRing: keyRing,
WalletController: wc, WalletController: wc,
@@ -2944,11 +2944,11 @@ func clearWalletStates(a, b *lnwallet.LightningWallet) error {
a.ResetReservations() a.ResetReservations()
b.ResetReservations() b.ResetReservations()
if err := a.Cfg.Database.Wipe(); err != nil { if err := a.Cfg.Database.GetParentDB().Wipe(); err != nil {
return err return err
} }
return b.Cfg.Database.Wipe() return b.Cfg.Database.GetParentDB().Wipe()
} }
func waitForMempoolTx(r *rpctest.Harness, txid *chainhash.Hash) error { func waitForMempoolTx(r *rpctest.Harness, txid *chainhash.Hash) error {

View File

@@ -323,7 +323,7 @@ func CreateTestChannels(chanType channeldb.ChannelType) (
RevocationStore: shachain.NewRevocationStore(), RevocationStore: shachain.NewRevocationStore(),
LocalCommitment: aliceLocalCommit, LocalCommitment: aliceLocalCommit,
RemoteCommitment: aliceRemoteCommit, RemoteCommitment: aliceRemoteCommit,
Db: dbAlice, Db: dbAlice.ChannelStateDB(),
Packager: channeldb.NewChannelPackager(shortChanID), Packager: channeldb.NewChannelPackager(shortChanID),
FundingTxn: testTx, FundingTxn: testTx,
} }
@@ -341,7 +341,7 @@ func CreateTestChannels(chanType channeldb.ChannelType) (
RevocationStore: shachain.NewRevocationStore(), RevocationStore: shachain.NewRevocationStore(),
LocalCommitment: bobLocalCommit, LocalCommitment: bobLocalCommit,
RemoteCommitment: bobRemoteCommit, RemoteCommitment: bobRemoteCommit,
Db: dbBob, Db: dbBob.ChannelStateDB(),
Packager: channeldb.NewChannelPackager(shortChanID), Packager: channeldb.NewChannelPackager(shortChanID),
} }

View File

@@ -940,7 +940,7 @@ func createTestChannelsForVectors(tc *testContext, chanType channeldb.ChannelTyp
RevocationStore: shachain.NewRevocationStore(), RevocationStore: shachain.NewRevocationStore(),
LocalCommitment: remoteCommit, LocalCommitment: remoteCommit,
RemoteCommitment: remoteCommit, RemoteCommitment: remoteCommit,
Db: dbRemote, Db: dbRemote.ChannelStateDB(),
Packager: channeldb.NewChannelPackager(shortChanID), Packager: channeldb.NewChannelPackager(shortChanID),
FundingTxn: tc.fundingTx.MsgTx(), FundingTxn: tc.fundingTx.MsgTx(),
} }
@@ -958,7 +958,7 @@ func createTestChannelsForVectors(tc *testContext, chanType channeldb.ChannelTyp
RevocationStore: shachain.NewRevocationStore(), RevocationStore: shachain.NewRevocationStore(),
LocalCommitment: localCommit, LocalCommitment: localCommit,
RemoteCommitment: localCommit, RemoteCommitment: localCommit,
Db: dbLocal, Db: dbLocal.ChannelStateDB(),
Packager: channeldb.NewChannelPackager(shortChanID), Packager: channeldb.NewChannelPackager(shortChanID),
FundingTxn: tc.fundingTx.MsgTx(), FundingTxn: tc.fundingTx.MsgTx(),
} }

View File

@@ -185,7 +185,7 @@ type Config struct {
InterceptSwitch *htlcswitch.InterceptableSwitch InterceptSwitch *htlcswitch.InterceptableSwitch
// ChannelDB is used to fetch opened channels, and closed channels. // ChannelDB is used to fetch opened channels, and closed channels.
ChannelDB *channeldb.DB ChannelDB *channeldb.ChannelStateDB
// ChannelGraph is a pointer to the channel graph which is used to // ChannelGraph is a pointer to the channel graph which is used to
// query information about the set of known active channels. // query information about the set of known active channels.

View File

@@ -229,7 +229,7 @@ func createTestPeer(notifier chainntnfs.ChainNotifier,
RevocationStore: shachain.NewRevocationStore(), RevocationStore: shachain.NewRevocationStore(),
LocalCommitment: aliceCommit, LocalCommitment: aliceCommit,
RemoteCommitment: aliceCommit, RemoteCommitment: aliceCommit,
Db: dbAlice, Db: dbAlice.ChannelStateDB(),
Packager: channeldb.NewChannelPackager(shortChanID), Packager: channeldb.NewChannelPackager(shortChanID),
FundingTxn: channels.TestFundingTx, FundingTxn: channels.TestFundingTx,
} }
@@ -246,7 +246,7 @@ func createTestPeer(notifier chainntnfs.ChainNotifier,
RevocationStore: shachain.NewRevocationStore(), RevocationStore: shachain.NewRevocationStore(),
LocalCommitment: bobCommit, LocalCommitment: bobCommit,
RemoteCommitment: bobCommit, RemoteCommitment: bobCommit,
Db: dbBob, Db: dbBob.ChannelStateDB(),
Packager: channeldb.NewChannelPackager(shortChanID), Packager: channeldb.NewChannelPackager(shortChanID),
} }
@@ -321,7 +321,7 @@ func createTestPeer(notifier chainntnfs.ChainNotifier,
ChanStatusSampleInterval: 30 * time.Second, ChanStatusSampleInterval: 30 * time.Second,
ChanEnableTimeout: chanActiveTimeout, ChanEnableTimeout: chanActiveTimeout,
ChanDisableTimeout: 2 * time.Minute, ChanDisableTimeout: 2 * time.Minute,
DB: dbAlice, DB: dbAlice.ChannelStateDB(),
Graph: dbAlice.ChannelGraph(), Graph: dbAlice.ChannelGraph(),
MessageSigner: nodeSignerAlice, MessageSigner: nodeSignerAlice,
OurPubKey: aliceKeyPub, OurPubKey: aliceKeyPub,
@@ -359,7 +359,7 @@ func createTestPeer(notifier chainntnfs.ChainNotifier,
ChanActiveTimeout: chanActiveTimeout, ChanActiveTimeout: chanActiveTimeout,
InterceptSwitch: htlcswitch.NewInterceptableSwitch(nil), InterceptSwitch: htlcswitch.NewInterceptableSwitch(nil),
ChannelDB: dbAlice, ChannelDB: dbAlice.ChannelStateDB(),
FeeEstimator: estimator, FeeEstimator: estimator,
Wallet: wallet, Wallet: wallet,
ChainNotifier: notifier, ChainNotifier: notifier,

View File

@@ -2,7 +2,6 @@ package routing
import ( import (
"github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/routing/route"
) )
@@ -10,10 +9,10 @@ import (
// routingGraph is an abstract interface that provides information about nodes // routingGraph is an abstract interface that provides information about nodes
// and edges to pathfinding. // and edges to pathfinding.
type routingGraph interface { type routingGraph interface {
// forEachNodeChannel calls the callback for every channel of the given node. // forEachNodeChannel calls the callback for every channel of the given
// node.
forEachNodeChannel(nodePub route.Vertex, forEachNodeChannel(nodePub route.Vertex,
cb func(*channeldb.ChannelEdgeInfo, *channeldb.ChannelEdgePolicy, cb func(channel *channeldb.DirectedChannel) error) error
*channeldb.ChannelEdgePolicy) error) error
// sourceNode returns the source node of the graph. // sourceNode returns the source node of the graph.
sourceNode() route.Vertex sourceNode() route.Vertex
@@ -22,59 +21,44 @@ type routingGraph interface {
fetchNodeFeatures(nodePub route.Vertex) (*lnwire.FeatureVector, error) fetchNodeFeatures(nodePub route.Vertex) (*lnwire.FeatureVector, error)
} }
// dbRoutingTx is a routingGraph implementation that retrieves from the // CachedGraph is a routingGraph implementation that retrieves from the
// database. // database.
type dbRoutingTx struct { type CachedGraph struct {
graph *channeldb.ChannelGraph graph *channeldb.ChannelGraph
tx kvdb.RTx
source route.Vertex source route.Vertex
} }
// newDbRoutingTx instantiates a new db-connected routing graph. It implictly // A compile time assertion to make sure CachedGraph implements the routingGraph
// interface.
var _ routingGraph = (*CachedGraph)(nil)
// NewCachedGraph instantiates a new db-connected routing graph. It implictly
// instantiates a new read transaction. // instantiates a new read transaction.
func newDbRoutingTx(graph *channeldb.ChannelGraph) (*dbRoutingTx, error) { func NewCachedGraph(graph *channeldb.ChannelGraph) (*CachedGraph, error) {
sourceNode, err := graph.SourceNode() sourceNode, err := graph.SourceNode()
if err != nil { if err != nil {
return nil, err return nil, err
} }
tx, err := graph.Database().BeginReadTx() return &CachedGraph{
if err != nil {
return nil, err
}
return &dbRoutingTx{
graph: graph, graph: graph,
tx: tx,
source: sourceNode.PubKeyBytes, source: sourceNode.PubKeyBytes,
}, nil }, nil
} }
// close closes the underlying db transaction.
func (g *dbRoutingTx) close() error {
return g.tx.Rollback()
}
// forEachNodeChannel calls the callback for every channel of the given node. // forEachNodeChannel calls the callback for every channel of the given node.
// //
// NOTE: Part of the routingGraph interface. // NOTE: Part of the routingGraph interface.
func (g *dbRoutingTx) forEachNodeChannel(nodePub route.Vertex, func (g *CachedGraph) forEachNodeChannel(nodePub route.Vertex,
cb func(*channeldb.ChannelEdgeInfo, *channeldb.ChannelEdgePolicy, cb func(channel *channeldb.DirectedChannel) error) error {
*channeldb.ChannelEdgePolicy) error) error {
txCb := func(_ kvdb.RTx, info *channeldb.ChannelEdgeInfo, return g.graph.ForEachNodeChannel(nodePub, cb)
p1, p2 *channeldb.ChannelEdgePolicy) error {
return cb(info, p1, p2)
}
return g.graph.ForEachNodeChannel(g.tx, nodePub[:], txCb)
} }
// sourceNode returns the source node of the graph. // sourceNode returns the source node of the graph.
// //
// NOTE: Part of the routingGraph interface. // NOTE: Part of the routingGraph interface.
func (g *dbRoutingTx) sourceNode() route.Vertex { func (g *CachedGraph) sourceNode() route.Vertex {
return g.source return g.source
} }
@@ -82,23 +66,8 @@ func (g *dbRoutingTx) sourceNode() route.Vertex {
// unknown, assume no additional features are supported. // unknown, assume no additional features are supported.
// //
// NOTE: Part of the routingGraph interface. // NOTE: Part of the routingGraph interface.
func (g *dbRoutingTx) fetchNodeFeatures(nodePub route.Vertex) ( func (g *CachedGraph) fetchNodeFeatures(nodePub route.Vertex) (
*lnwire.FeatureVector, error) { *lnwire.FeatureVector, error) {
targetNode, err := g.graph.FetchLightningNode(g.tx, nodePub) return g.graph.FetchNodeFeatures(nodePub)
switch err {
// If the node exists and has features, return them directly.
case nil:
return targetNode.Features, nil
// If we couldn't find a node announcement, populate a blank feature
// vector.
case channeldb.ErrGraphNodeNotFound:
return lnwire.EmptyFeatureVector(), nil
// Otherwise bubble the error up.
default:
return nil, err
}
} }

View File

@@ -39,7 +39,7 @@ type nodeWithDist struct {
weight int64 weight int64
// nextHop is the edge this route comes from. // nextHop is the edge this route comes from.
nextHop *channeldb.ChannelEdgePolicy nextHop *channeldb.CachedEdgePolicy
// routingInfoSize is the total size requirement for the payloads field // routingInfoSize is the total size requirement for the payloads field
// in the onion packet from this hop towards the final destination. // in the onion packet from this hop towards the final destination.

View File

@@ -162,11 +162,7 @@ func (c *integratedRoutingContext) testPayment(maxParts uint32,
} }
session, err := newPaymentSession( session, err := newPaymentSession(
&payment, getBandwidthHints, &payment, getBandwidthHints, c.graph, mc, c.pathFindingCfg,
func() (routingGraph, func(), error) {
return c.graph, func() {}, nil
},
mc, c.pathFindingCfg,
) )
if err != nil { if err != nil {
c.t.Fatal(err) c.t.Fatal(err)

View File

@@ -159,8 +159,7 @@ func (m *mockGraph) addChannel(id uint64, node1id, node2id byte,
// //
// NOTE: Part of the routingGraph interface. // NOTE: Part of the routingGraph interface.
func (m *mockGraph) forEachNodeChannel(nodePub route.Vertex, func (m *mockGraph) forEachNodeChannel(nodePub route.Vertex,
cb func(*channeldb.ChannelEdgeInfo, *channeldb.ChannelEdgePolicy, cb func(channel *channeldb.DirectedChannel) error) error {
*channeldb.ChannelEdgePolicy) error) error {
// Look up the mock node. // Look up the mock node.
node, ok := m.nodes[nodePub] node, ok := m.nodes[nodePub]
@@ -171,36 +170,31 @@ func (m *mockGraph) forEachNodeChannel(nodePub route.Vertex,
// Iterate over all of its channels. // Iterate over all of its channels.
for peer, channel := range node.channels { for peer, channel := range node.channels {
// Lexicographically sort the pubkeys. // Lexicographically sort the pubkeys.
var node1, node2 route.Vertex var node1 route.Vertex
if bytes.Compare(nodePub[:], peer[:]) == -1 { if bytes.Compare(nodePub[:], peer[:]) == -1 {
node1, node2 = peer, nodePub node1 = peer
} else { } else {
node1, node2 = nodePub, peer node1 = nodePub
} }
peerNode := m.nodes[peer] peerNode := m.nodes[peer]
// Call the per channel callback. // Call the per channel callback.
err := cb( err := cb(
&channeldb.ChannelEdgeInfo{ &channeldb.DirectedChannel{
NodeKey1Bytes: node1, ChannelID: channel.id,
NodeKey2Bytes: node2, IsNode1: nodePub == node1,
}, OtherNode: peer,
&channeldb.ChannelEdgePolicy{ Capacity: channel.capacity,
ChannelID: channel.id, OutPolicySet: true,
Node: &channeldb.LightningNode{ InPolicy: &channeldb.CachedEdgePolicy{
PubKeyBytes: peer, ChannelID: channel.id,
Features: lnwire.EmptyFeatureVector(), ToNodePubKey: func() route.Vertex {
return nodePub
},
ToNodeFeatures: lnwire.EmptyFeatureVector(),
FeeBaseMSat: peerNode.baseFee,
}, },
FeeBaseMSat: node.baseFee,
},
&channeldb.ChannelEdgePolicy{
ChannelID: channel.id,
Node: &channeldb.LightningNode{
PubKeyBytes: nodePub,
Features: lnwire.EmptyFeatureVector(),
},
FeeBaseMSat: peerNode.baseFee,
}, },
) )
if err != nil { if err != nil {

View File

@@ -173,13 +173,13 @@ func (m *mockPaymentSessionOld) RequestRoute(_, _ lnwire.MilliSatoshi,
} }
func (m *mockPaymentSessionOld) UpdateAdditionalEdge(_ *lnwire.ChannelUpdate, func (m *mockPaymentSessionOld) UpdateAdditionalEdge(_ *lnwire.ChannelUpdate,
_ *btcec.PublicKey, _ *channeldb.ChannelEdgePolicy) bool { _ *btcec.PublicKey, _ *channeldb.CachedEdgePolicy) bool {
return false return false
} }
func (m *mockPaymentSessionOld) GetAdditionalEdgePolicy(_ *btcec.PublicKey, func (m *mockPaymentSessionOld) GetAdditionalEdgePolicy(_ *btcec.PublicKey,
_ uint64) *channeldb.ChannelEdgePolicy { _ uint64) *channeldb.CachedEdgePolicy {
return nil return nil
} }
@@ -637,17 +637,17 @@ func (m *mockPaymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi,
} }
func (m *mockPaymentSession) UpdateAdditionalEdge(msg *lnwire.ChannelUpdate, func (m *mockPaymentSession) UpdateAdditionalEdge(msg *lnwire.ChannelUpdate,
pubKey *btcec.PublicKey, policy *channeldb.ChannelEdgePolicy) bool { pubKey *btcec.PublicKey, policy *channeldb.CachedEdgePolicy) bool {
args := m.Called(msg, pubKey, policy) args := m.Called(msg, pubKey, policy)
return args.Bool(0) return args.Bool(0)
} }
func (m *mockPaymentSession) GetAdditionalEdgePolicy(pubKey *btcec.PublicKey, func (m *mockPaymentSession) GetAdditionalEdgePolicy(pubKey *btcec.PublicKey,
channelID uint64) *channeldb.ChannelEdgePolicy { channelID uint64) *channeldb.CachedEdgePolicy {
args := m.Called(pubKey, channelID) args := m.Called(pubKey, channelID)
return args.Get(0).(*channeldb.ChannelEdgePolicy) return args.Get(0).(*channeldb.CachedEdgePolicy)
} }
type mockControlTower struct { type mockControlTower struct {

View File

@@ -42,7 +42,7 @@ const (
type pathFinder = func(g *graphParams, r *RestrictParams, type pathFinder = func(g *graphParams, r *RestrictParams,
cfg *PathFindingConfig, source, target route.Vertex, cfg *PathFindingConfig, source, target route.Vertex,
amt lnwire.MilliSatoshi, finalHtlcExpiry int32) ( amt lnwire.MilliSatoshi, finalHtlcExpiry int32) (
[]*channeldb.ChannelEdgePolicy, error) []*channeldb.CachedEdgePolicy, error)
var ( var (
// DefaultAttemptCost is the default fixed virtual cost in path finding // DefaultAttemptCost is the default fixed virtual cost in path finding
@@ -76,7 +76,7 @@ var (
// of the edge. // of the edge.
type edgePolicyWithSource struct { type edgePolicyWithSource struct {
sourceNode route.Vertex sourceNode route.Vertex
edge *channeldb.ChannelEdgePolicy edge *channeldb.CachedEdgePolicy
} }
// finalHopParams encapsulates various parameters for route construction that // finalHopParams encapsulates various parameters for route construction that
@@ -102,7 +102,7 @@ type finalHopParams struct {
// any feature vectors on all hops have been validated for transitive // any feature vectors on all hops have been validated for transitive
// dependencies. // dependencies.
func newRoute(sourceVertex route.Vertex, func newRoute(sourceVertex route.Vertex,
pathEdges []*channeldb.ChannelEdgePolicy, currentHeight uint32, pathEdges []*channeldb.CachedEdgePolicy, currentHeight uint32,
finalHop finalHopParams) (*route.Route, error) { finalHop finalHopParams) (*route.Route, error) {
var ( var (
@@ -147,10 +147,10 @@ func newRoute(sourceVertex route.Vertex,
supports := func(feature lnwire.FeatureBit) bool { supports := func(feature lnwire.FeatureBit) bool {
// If this edge comes from router hints, the features // If this edge comes from router hints, the features
// could be nil. // could be nil.
if edge.Node.Features == nil { if edge.ToNodeFeatures == nil {
return false return false
} }
return edge.Node.Features.HasFeature(feature) return edge.ToNodeFeatures.HasFeature(feature)
} }
// We start by assuming the node doesn't support TLV. We'll now // We start by assuming the node doesn't support TLV. We'll now
@@ -225,7 +225,7 @@ func newRoute(sourceVertex route.Vertex,
// each new hop such that, the final slice of hops will be in // each new hop such that, the final slice of hops will be in
// the forwards order. // the forwards order.
currentHop := &route.Hop{ currentHop := &route.Hop{
PubKeyBytes: edge.Node.PubKeyBytes, PubKeyBytes: edge.ToNodePubKey(),
ChannelID: edge.ChannelID, ChannelID: edge.ChannelID,
AmtToForward: amtToForward, AmtToForward: amtToForward,
OutgoingTimeLock: outgoingTimeLock, OutgoingTimeLock: outgoingTimeLock,
@@ -280,7 +280,7 @@ type graphParams struct {
// additionalEdges is an optional set of edges that should be // additionalEdges is an optional set of edges that should be
// considered during path finding, that is not already found in the // considered during path finding, that is not already found in the
// channel graph. // channel graph.
additionalEdges map[route.Vertex][]*channeldb.ChannelEdgePolicy additionalEdges map[route.Vertex][]*channeldb.CachedEdgePolicy
// bandwidthHints is an optional map from channels to bandwidths that // bandwidthHints is an optional map from channels to bandwidths that
// can be populated if the caller has a better estimate of the current // can be populated if the caller has a better estimate of the current
@@ -359,14 +359,12 @@ func getOutgoingBalance(node route.Vertex, outgoingChans map[uint64]struct{},
g routingGraph) (lnwire.MilliSatoshi, lnwire.MilliSatoshi, error) { g routingGraph) (lnwire.MilliSatoshi, lnwire.MilliSatoshi, error) {
var max, total lnwire.MilliSatoshi var max, total lnwire.MilliSatoshi
cb := func(edgeInfo *channeldb.ChannelEdgeInfo, outEdge, cb := func(channel *channeldb.DirectedChannel) error {
_ *channeldb.ChannelEdgePolicy) error { if !channel.OutPolicySet {
if outEdge == nil {
return nil return nil
} }
chanID := outEdge.ChannelID chanID := channel.ChannelID
// Enforce outgoing channel restriction. // Enforce outgoing channel restriction.
if outgoingChans != nil { if outgoingChans != nil {
@@ -381,9 +379,7 @@ func getOutgoingBalance(node route.Vertex, outgoingChans map[uint64]struct{},
// This can happen when a channel is added to the graph after // This can happen when a channel is added to the graph after
// we've already queried the bandwidth hints. // we've already queried the bandwidth hints.
if !ok { if !ok {
bandwidth = lnwire.NewMSatFromSatoshis( bandwidth = lnwire.NewMSatFromSatoshis(channel.Capacity)
edgeInfo.Capacity,
)
} }
if bandwidth > max { if bandwidth > max {
@@ -416,7 +412,7 @@ func getOutgoingBalance(node route.Vertex, outgoingChans map[uint64]struct{},
// available bandwidth. // available bandwidth.
func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig,
source, target route.Vertex, amt lnwire.MilliSatoshi, source, target route.Vertex, amt lnwire.MilliSatoshi,
finalHtlcExpiry int32) ([]*channeldb.ChannelEdgePolicy, error) { finalHtlcExpiry int32) ([]*channeldb.CachedEdgePolicy, error) {
// Pathfinding can be a significant portion of the total payment // Pathfinding can be a significant portion of the total payment
// latency, especially on low-powered devices. Log several metrics to // latency, especially on low-powered devices. Log several metrics to
@@ -523,7 +519,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig,
// Build reverse lookup to find incoming edges. Needed because // Build reverse lookup to find incoming edges. Needed because
// search is taken place from target to source. // search is taken place from target to source.
for _, outgoingEdgePolicy := range outgoingEdgePolicies { for _, outgoingEdgePolicy := range outgoingEdgePolicies {
toVertex := outgoingEdgePolicy.Node.PubKeyBytes toVertex := outgoingEdgePolicy.ToNodePubKey()
incomingEdgePolicy := &edgePolicyWithSource{ incomingEdgePolicy := &edgePolicyWithSource{
sourceNode: vertex, sourceNode: vertex,
edge: outgoingEdgePolicy, edge: outgoingEdgePolicy,
@@ -587,7 +583,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig,
// satisfy our specific requirements. // satisfy our specific requirements.
processEdge := func(fromVertex route.Vertex, processEdge := func(fromVertex route.Vertex,
fromFeatures *lnwire.FeatureVector, fromFeatures *lnwire.FeatureVector,
edge *channeldb.ChannelEdgePolicy, toNodeDist *nodeWithDist) { edge *channeldb.CachedEdgePolicy, toNodeDist *nodeWithDist) {
edgesExpanded++ edgesExpanded++
@@ -883,13 +879,14 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig,
// Use the distance map to unravel the forward path from source to // Use the distance map to unravel the forward path from source to
// target. // target.
var pathEdges []*channeldb.ChannelEdgePolicy var pathEdges []*channeldb.CachedEdgePolicy
currentNode := source currentNode := source
for { for {
// Determine the next hop forward using the next map. // Determine the next hop forward using the next map.
currentNodeWithDist, ok := distance[currentNode] currentNodeWithDist, ok := distance[currentNode]
if !ok { if !ok {
// If the node doesnt have a next hop it means we didn't find a path. // If the node doesn't have a next hop it means we
// didn't find a path.
return nil, errNoPathFound return nil, errNoPathFound
} }
@@ -897,7 +894,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig,
pathEdges = append(pathEdges, currentNodeWithDist.nextHop) pathEdges = append(pathEdges, currentNodeWithDist.nextHop)
// Advance current node. // Advance current node.
currentNode = currentNodeWithDist.nextHop.Node.PubKeyBytes currentNode = currentNodeWithDist.nextHop.ToNodePubKey()
// Check stop condition at the end of this loop. This prevents // Check stop condition at the end of this loop. This prevents
// breaking out too soon for self-payments that have target set // breaking out too soon for self-payments that have target set
@@ -918,7 +915,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig,
// route construction does not care where the features are actually // route construction does not care where the features are actually
// taken from. In the future we may wish to do route construction within // taken from. In the future we may wish to do route construction within
// findPath, and avoid using ChannelEdgePolicy altogether. // findPath, and avoid using ChannelEdgePolicy altogether.
pathEdges[len(pathEdges)-1].Node.Features = features pathEdges[len(pathEdges)-1].ToNodeFeatures = features
log.Debugf("Found route: probability=%v, hops=%v, fee=%v", log.Debugf("Found route: probability=%v, hops=%v, fee=%v",
distance[source].probability, len(pathEdges), distance[source].probability, len(pathEdges),

View File

@@ -23,6 +23,7 @@ import (
"github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcutil" "github.com/btcsuite/btcutil"
"github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/record" "github.com/lightningnetwork/lnd/record"
"github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/routing/route"
@@ -148,26 +149,36 @@ type testChan struct {
// makeTestGraph creates a new instance of a channeldb.ChannelGraph for testing // makeTestGraph creates a new instance of a channeldb.ChannelGraph for testing
// purposes. A callback which cleans up the created temporary directories is // purposes. A callback which cleans up the created temporary directories is
// also returned and intended to be executed after the test completes. // also returned and intended to be executed after the test completes.
func makeTestGraph() (*channeldb.ChannelGraph, func(), error) { func makeTestGraph() (*channeldb.ChannelGraph, kvdb.Backend, func(), error) {
// First, create a temporary directory to be used for the duration of // First, create a temporary directory to be used for the duration of
// this test. // this test.
tempDirName, err := ioutil.TempDir("", "channeldb") tempDirName, err := ioutil.TempDir("", "channeldb")
if err != nil { if err != nil {
return nil, nil, err return nil, nil, nil, err
} }
// Next, create channeldb for the first time. // Next, create channelgraph for the first time.
cdb, err := channeldb.Open(tempDirName) backend, backendCleanup, err := kvdb.GetTestBackend(tempDirName, "cgr")
if err != nil { if err != nil {
return nil, nil, err return nil, nil, nil, err
} }
cleanUp := func() { cleanUp := func() {
cdb.Close() backendCleanup()
os.RemoveAll(tempDirName) _ = os.RemoveAll(tempDirName)
} }
return cdb.ChannelGraph(), cleanUp, nil opts := channeldb.DefaultOptions()
graph, err := channeldb.NewChannelGraph(
backend, opts.RejectCacheSize, opts.ChannelCacheSize,
opts.BatchCommitInterval, opts.PreAllocCacheNumNodes,
)
if err != nil {
cleanUp()
return nil, nil, nil, err
}
return graph, backend, cleanUp, nil
} }
// parseTestGraph returns a fully populated ChannelGraph given a path to a JSON // parseTestGraph returns a fully populated ChannelGraph given a path to a JSON
@@ -197,7 +208,7 @@ func parseTestGraph(path string) (*testGraphInstance, error) {
testAddrs = append(testAddrs, testAddr) testAddrs = append(testAddrs, testAddr)
// Next, create a temporary graph database for usage within the test. // Next, create a temporary graph database for usage within the test.
graph, cleanUp, err := makeTestGraph() graph, graphBackend, cleanUp, err := makeTestGraph()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -293,6 +304,16 @@ func parseTestGraph(path string) (*testGraphInstance, error) {
} }
} }
aliasForNode := func(node route.Vertex) string {
for alias, pubKey := range aliasMap {
if pubKey == node {
return alias
}
}
return ""
}
// With all the vertexes inserted, we can now insert the edges into the // With all the vertexes inserted, we can now insert the edges into the
// test graph. // test graph.
for _, edge := range g.Edges { for _, edge := range g.Edges {
@@ -342,10 +363,17 @@ func parseTestGraph(path string) (*testGraphInstance, error) {
return nil, err return nil, err
} }
channelFlags := lnwire.ChanUpdateChanFlags(edge.ChannelFlags)
isUpdate1 := channelFlags&lnwire.ChanUpdateDirection == 0
targetNode := edgeInfo.NodeKey1Bytes
if isUpdate1 {
targetNode = edgeInfo.NodeKey2Bytes
}
edgePolicy := &channeldb.ChannelEdgePolicy{ edgePolicy := &channeldb.ChannelEdgePolicy{
SigBytes: testSig.Serialize(), SigBytes: testSig.Serialize(),
MessageFlags: lnwire.ChanUpdateMsgFlags(edge.MessageFlags), MessageFlags: lnwire.ChanUpdateMsgFlags(edge.MessageFlags),
ChannelFlags: lnwire.ChanUpdateChanFlags(edge.ChannelFlags), ChannelFlags: channelFlags,
ChannelID: edge.ChannelID, ChannelID: edge.ChannelID,
LastUpdate: testTime, LastUpdate: testTime,
TimeLockDelta: edge.Expiry, TimeLockDelta: edge.Expiry,
@@ -353,6 +381,10 @@ func parseTestGraph(path string) (*testGraphInstance, error) {
MaxHTLC: lnwire.MilliSatoshi(edge.MaxHTLC), MaxHTLC: lnwire.MilliSatoshi(edge.MaxHTLC),
FeeBaseMSat: lnwire.MilliSatoshi(edge.FeeBaseMsat), FeeBaseMSat: lnwire.MilliSatoshi(edge.FeeBaseMsat),
FeeProportionalMillionths: lnwire.MilliSatoshi(edge.FeeRate), FeeProportionalMillionths: lnwire.MilliSatoshi(edge.FeeRate),
Node: &channeldb.LightningNode{
Alias: aliasForNode(targetNode),
PubKeyBytes: targetNode,
},
} }
if err := graph.UpdateEdgePolicy(edgePolicy); err != nil { if err := graph.UpdateEdgePolicy(edgePolicy); err != nil {
return nil, err return nil, err
@@ -381,11 +413,12 @@ func parseTestGraph(path string) (*testGraphInstance, error) {
} }
return &testGraphInstance{ return &testGraphInstance{
graph: graph, graph: graph,
cleanUp: cleanUp, graphBackend: graphBackend,
aliasMap: aliasMap, cleanUp: cleanUp,
privKeyMap: privKeyMap, aliasMap: aliasMap,
channelIDs: channelIDs, privKeyMap: privKeyMap,
channelIDs: channelIDs,
}, nil }, nil
} }
@@ -447,8 +480,9 @@ type testChannel struct {
} }
type testGraphInstance struct { type testGraphInstance struct {
graph *channeldb.ChannelGraph graph *channeldb.ChannelGraph
cleanUp func() graphBackend kvdb.Backend
cleanUp func()
// aliasMap is a map from a node's alias to its public key. This type is // aliasMap is a map from a node's alias to its public key. This type is
// provided in order to allow easily look up from the human memorable alias // provided in order to allow easily look up from the human memorable alias
@@ -482,7 +516,7 @@ func createTestGraphFromChannels(testChannels []*testChannel, source string) (
testAddrs = append(testAddrs, testAddr) testAddrs = append(testAddrs, testAddr)
// Next, create a temporary graph database for usage within the test. // Next, create a temporary graph database for usage within the test.
graph, cleanUp, err := makeTestGraph() graph, graphBackend, cleanUp, err := makeTestGraph()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -622,6 +656,11 @@ func createTestGraphFromChannels(testChannels []*testChannel, source string) (
channelFlags |= lnwire.ChanUpdateDisabled channelFlags |= lnwire.ChanUpdateDisabled
} }
node2Features := lnwire.EmptyFeatureVector()
if node2.testChannelPolicy != nil {
node2Features = node2.Features
}
edgePolicy := &channeldb.ChannelEdgePolicy{ edgePolicy := &channeldb.ChannelEdgePolicy{
SigBytes: testSig.Serialize(), SigBytes: testSig.Serialize(),
MessageFlags: msgFlags, MessageFlags: msgFlags,
@@ -633,6 +672,11 @@ func createTestGraphFromChannels(testChannels []*testChannel, source string) (
MaxHTLC: node1.MaxHTLC, MaxHTLC: node1.MaxHTLC,
FeeBaseMSat: node1.FeeBaseMsat, FeeBaseMSat: node1.FeeBaseMsat,
FeeProportionalMillionths: node1.FeeRate, FeeProportionalMillionths: node1.FeeRate,
Node: &channeldb.LightningNode{
Alias: node2.Alias,
PubKeyBytes: node2Vertex,
Features: node2Features,
},
} }
if err := graph.UpdateEdgePolicy(edgePolicy); err != nil { if err := graph.UpdateEdgePolicy(edgePolicy); err != nil {
return nil, err return nil, err
@@ -650,6 +694,11 @@ func createTestGraphFromChannels(testChannels []*testChannel, source string) (
} }
channelFlags |= lnwire.ChanUpdateDirection channelFlags |= lnwire.ChanUpdateDirection
node1Features := lnwire.EmptyFeatureVector()
if node1.testChannelPolicy != nil {
node1Features = node1.Features
}
edgePolicy := &channeldb.ChannelEdgePolicy{ edgePolicy := &channeldb.ChannelEdgePolicy{
SigBytes: testSig.Serialize(), SigBytes: testSig.Serialize(),
MessageFlags: msgFlags, MessageFlags: msgFlags,
@@ -661,6 +710,11 @@ func createTestGraphFromChannels(testChannels []*testChannel, source string) (
MaxHTLC: node2.MaxHTLC, MaxHTLC: node2.MaxHTLC,
FeeBaseMSat: node2.FeeBaseMsat, FeeBaseMSat: node2.FeeBaseMsat,
FeeProportionalMillionths: node2.FeeRate, FeeProportionalMillionths: node2.FeeRate,
Node: &channeldb.LightningNode{
Alias: node1.Alias,
PubKeyBytes: node1Vertex,
Features: node1Features,
},
} }
if err := graph.UpdateEdgePolicy(edgePolicy); err != nil { if err := graph.UpdateEdgePolicy(edgePolicy); err != nil {
return nil, err return nil, err
@@ -671,10 +725,11 @@ func createTestGraphFromChannels(testChannels []*testChannel, source string) (
} }
return &testGraphInstance{ return &testGraphInstance{
graph: graph, graph: graph,
cleanUp: cleanUp, graphBackend: graphBackend,
aliasMap: aliasMap, cleanUp: cleanUp,
privKeyMap: privKeyMap, aliasMap: aliasMap,
privKeyMap: privKeyMap,
}, nil }, nil
} }
@@ -1044,20 +1099,23 @@ func TestPathFindingWithAdditionalEdges(t *testing.T) {
// Create the channel edge going from songoku to doge and include it in // Create the channel edge going from songoku to doge and include it in
// our map of additional edges. // our map of additional edges.
songokuToDoge := &channeldb.ChannelEdgePolicy{ songokuToDoge := &channeldb.CachedEdgePolicy{
Node: doge, ToNodePubKey: func() route.Vertex {
return doge.PubKeyBytes
},
ToNodeFeatures: lnwire.EmptyFeatureVector(),
ChannelID: 1337, ChannelID: 1337,
FeeBaseMSat: 1, FeeBaseMSat: 1,
FeeProportionalMillionths: 1000, FeeProportionalMillionths: 1000,
TimeLockDelta: 9, TimeLockDelta: 9,
} }
additionalEdges := map[route.Vertex][]*channeldb.ChannelEdgePolicy{ additionalEdges := map[route.Vertex][]*channeldb.CachedEdgePolicy{
graph.aliasMap["songoku"]: {songokuToDoge}, graph.aliasMap["songoku"]: {songokuToDoge},
} }
find := func(r *RestrictParams) ( find := func(r *RestrictParams) (
[]*channeldb.ChannelEdgePolicy, error) { []*channeldb.CachedEdgePolicy, error) {
return dbFindPath( return dbFindPath(
graph.graph, additionalEdges, nil, graph.graph, additionalEdges, nil,
@@ -1124,14 +1182,13 @@ func TestNewRoute(t *testing.T) {
createHop := func(baseFee lnwire.MilliSatoshi, createHop := func(baseFee lnwire.MilliSatoshi,
feeRate lnwire.MilliSatoshi, feeRate lnwire.MilliSatoshi,
bandwidth lnwire.MilliSatoshi, bandwidth lnwire.MilliSatoshi,
timeLockDelta uint16) *channeldb.ChannelEdgePolicy { timeLockDelta uint16) *channeldb.CachedEdgePolicy {
return &channeldb.ChannelEdgePolicy{ return &channeldb.CachedEdgePolicy{
Node: &channeldb.LightningNode{ ToNodePubKey: func() route.Vertex {
Features: lnwire.NewFeatureVector( return route.Vertex{}
nil, nil,
),
}, },
ToNodeFeatures: lnwire.NewFeatureVector(nil, nil),
FeeProportionalMillionths: feeRate, FeeProportionalMillionths: feeRate,
FeeBaseMSat: baseFee, FeeBaseMSat: baseFee,
TimeLockDelta: timeLockDelta, TimeLockDelta: timeLockDelta,
@@ -1144,7 +1201,7 @@ func TestNewRoute(t *testing.T) {
// hops is the list of hops (the route) that gets passed into // hops is the list of hops (the route) that gets passed into
// the call to newRoute. // the call to newRoute.
hops []*channeldb.ChannelEdgePolicy hops []*channeldb.CachedEdgePolicy
// paymentAmount is the amount that is send into the route // paymentAmount is the amount that is send into the route
// indicated by hops. // indicated by hops.
@@ -1193,7 +1250,7 @@ func TestNewRoute(t *testing.T) {
// For a single hop payment, no fees are expected to be paid. // For a single hop payment, no fees are expected to be paid.
name: "single hop", name: "single hop",
paymentAmount: 100000, paymentAmount: 100000,
hops: []*channeldb.ChannelEdgePolicy{ hops: []*channeldb.CachedEdgePolicy{
createHop(100, 1000, 1000000, 10), createHop(100, 1000, 1000000, 10),
}, },
expectedFees: []lnwire.MilliSatoshi{0}, expectedFees: []lnwire.MilliSatoshi{0},
@@ -1206,7 +1263,7 @@ func TestNewRoute(t *testing.T) {
// a fee to receive the payment. // a fee to receive the payment.
name: "two hop", name: "two hop",
paymentAmount: 100000, paymentAmount: 100000,
hops: []*channeldb.ChannelEdgePolicy{ hops: []*channeldb.CachedEdgePolicy{
createHop(0, 1000, 1000000, 10), createHop(0, 1000, 1000000, 10),
createHop(30, 1000, 1000000, 5), createHop(30, 1000, 1000000, 5),
}, },
@@ -1221,7 +1278,7 @@ func TestNewRoute(t *testing.T) {
name: "two hop tlv onion feature", name: "two hop tlv onion feature",
destFeatures: tlvFeatures, destFeatures: tlvFeatures,
paymentAmount: 100000, paymentAmount: 100000,
hops: []*channeldb.ChannelEdgePolicy{ hops: []*channeldb.CachedEdgePolicy{
createHop(0, 1000, 1000000, 10), createHop(0, 1000, 1000000, 10),
createHop(30, 1000, 1000000, 5), createHop(30, 1000, 1000000, 5),
}, },
@@ -1238,7 +1295,7 @@ func TestNewRoute(t *testing.T) {
destFeatures: tlvPayAddrFeatures, destFeatures: tlvPayAddrFeatures,
paymentAddr: &testPaymentAddr, paymentAddr: &testPaymentAddr,
paymentAmount: 100000, paymentAmount: 100000,
hops: []*channeldb.ChannelEdgePolicy{ hops: []*channeldb.CachedEdgePolicy{
createHop(0, 1000, 1000000, 10), createHop(0, 1000, 1000000, 10),
createHop(30, 1000, 1000000, 5), createHop(30, 1000, 1000000, 5),
}, },
@@ -1258,7 +1315,7 @@ func TestNewRoute(t *testing.T) {
// gets rounded down to 1. // gets rounded down to 1.
name: "three hop", name: "three hop",
paymentAmount: 100000, paymentAmount: 100000,
hops: []*channeldb.ChannelEdgePolicy{ hops: []*channeldb.CachedEdgePolicy{
createHop(0, 10, 1000000, 10), createHop(0, 10, 1000000, 10),
createHop(0, 10, 1000000, 5), createHop(0, 10, 1000000, 5),
createHop(0, 10, 1000000, 3), createHop(0, 10, 1000000, 3),
@@ -1273,7 +1330,7 @@ func TestNewRoute(t *testing.T) {
// because of the increase amount to forward. // because of the increase amount to forward.
name: "three hop with fee carry over", name: "three hop with fee carry over",
paymentAmount: 100000, paymentAmount: 100000,
hops: []*channeldb.ChannelEdgePolicy{ hops: []*channeldb.CachedEdgePolicy{
createHop(0, 10000, 1000000, 10), createHop(0, 10000, 1000000, 10),
createHop(0, 10000, 1000000, 5), createHop(0, 10000, 1000000, 5),
createHop(0, 10000, 1000000, 3), createHop(0, 10000, 1000000, 3),
@@ -1288,7 +1345,7 @@ func TestNewRoute(t *testing.T) {
// effect. // effect.
name: "three hop with minimal fees for carry over", name: "three hop with minimal fees for carry over",
paymentAmount: 100000, paymentAmount: 100000,
hops: []*channeldb.ChannelEdgePolicy{ hops: []*channeldb.CachedEdgePolicy{
createHop(0, 10000, 1000000, 10), createHop(0, 10000, 1000000, 10),
// First hop charges 0.1% so the second hop fee // First hop charges 0.1% so the second hop fee
@@ -1312,7 +1369,7 @@ func TestNewRoute(t *testing.T) {
// custom feature vector. // custom feature vector.
if testCase.destFeatures != nil { if testCase.destFeatures != nil {
finalHop := testCase.hops[len(testCase.hops)-1] finalHop := testCase.hops[len(testCase.hops)-1]
finalHop.Node.Features = testCase.destFeatures finalHop.ToNodeFeatures = testCase.destFeatures
} }
assertRoute := func(t *testing.T, route *route.Route) { assertRoute := func(t *testing.T, route *route.Route) {
@@ -1539,7 +1596,7 @@ func TestDestTLVGraphFallback(t *testing.T) {
} }
find := func(r *RestrictParams, find := func(r *RestrictParams,
target route.Vertex) ([]*channeldb.ChannelEdgePolicy, error) { target route.Vertex) ([]*channeldb.CachedEdgePolicy, error) {
return dbFindPath( return dbFindPath(
ctx.graph, nil, nil, ctx.graph, nil, nil,
@@ -2120,7 +2177,7 @@ func TestPathFindSpecExample(t *testing.T) {
// Carol, so we set "B" as the source node so path finding starts from // Carol, so we set "B" as the source node so path finding starts from
// Bob. // Bob.
bob := ctx.aliases["B"] bob := ctx.aliases["B"]
bobNode, err := ctx.graph.FetchLightningNode(nil, bob) bobNode, err := ctx.graph.FetchLightningNode(bob)
if err != nil { if err != nil {
t.Fatalf("unable to find bob: %v", err) t.Fatalf("unable to find bob: %v", err)
} }
@@ -2170,7 +2227,7 @@ func TestPathFindSpecExample(t *testing.T) {
// Next, we'll set A as the source node so we can assert that we create // Next, we'll set A as the source node so we can assert that we create
// the proper route for any queries starting with Alice. // the proper route for any queries starting with Alice.
alice := ctx.aliases["A"] alice := ctx.aliases["A"]
aliceNode, err := ctx.graph.FetchLightningNode(nil, alice) aliceNode, err := ctx.graph.FetchLightningNode(alice)
if err != nil { if err != nil {
t.Fatalf("unable to find alice: %v", err) t.Fatalf("unable to find alice: %v", err)
} }
@@ -2270,16 +2327,16 @@ func TestPathFindSpecExample(t *testing.T) {
} }
func assertExpectedPath(t *testing.T, aliasMap map[string]route.Vertex, func assertExpectedPath(t *testing.T, aliasMap map[string]route.Vertex,
path []*channeldb.ChannelEdgePolicy, nodeAliases ...string) { path []*channeldb.CachedEdgePolicy, nodeAliases ...string) {
if len(path) != len(nodeAliases) { if len(path) != len(nodeAliases) {
t.Fatal("number of hops and number of aliases do not match") t.Fatal("number of hops and number of aliases do not match")
} }
for i, hop := range path { for i, hop := range path {
if hop.Node.PubKeyBytes != aliasMap[nodeAliases[i]] { if hop.ToNodePubKey() != aliasMap[nodeAliases[i]] {
t.Fatalf("expected %v to be pos #%v in hop, instead "+ t.Fatalf("expected %v to be pos #%v in hop, instead "+
"%v was", nodeAliases[i], i, hop.Node.Alias) "%v was", nodeAliases[i], i, hop.ToNodePubKey())
} }
} }
} }
@@ -2930,7 +2987,7 @@ func (c *pathFindingTestContext) cleanup() {
} }
func (c *pathFindingTestContext) findPath(target route.Vertex, func (c *pathFindingTestContext) findPath(target route.Vertex,
amt lnwire.MilliSatoshi) ([]*channeldb.ChannelEdgePolicy, amt lnwire.MilliSatoshi) ([]*channeldb.CachedEdgePolicy,
error) { error) {
return dbFindPath( return dbFindPath(
@@ -2939,7 +2996,9 @@ func (c *pathFindingTestContext) findPath(target route.Vertex,
) )
} }
func (c *pathFindingTestContext) assertPath(path []*channeldb.ChannelEdgePolicy, expected []uint64) { func (c *pathFindingTestContext) assertPath(path []*channeldb.CachedEdgePolicy,
expected []uint64) {
if len(path) != len(expected) { if len(path) != len(expected) {
c.t.Fatalf("expected path of length %v, but got %v", c.t.Fatalf("expected path of length %v, but got %v",
len(expected), len(path)) len(expected), len(path))
@@ -2956,28 +3015,22 @@ func (c *pathFindingTestContext) assertPath(path []*channeldb.ChannelEdgePolicy,
// dbFindPath calls findPath after getting a db transaction from the database // dbFindPath calls findPath after getting a db transaction from the database
// graph. // graph.
func dbFindPath(graph *channeldb.ChannelGraph, func dbFindPath(graph *channeldb.ChannelGraph,
additionalEdges map[route.Vertex][]*channeldb.ChannelEdgePolicy, additionalEdges map[route.Vertex][]*channeldb.CachedEdgePolicy,
bandwidthHints map[uint64]lnwire.MilliSatoshi, bandwidthHints map[uint64]lnwire.MilliSatoshi,
r *RestrictParams, cfg *PathFindingConfig, r *RestrictParams, cfg *PathFindingConfig,
source, target route.Vertex, amt lnwire.MilliSatoshi, source, target route.Vertex, amt lnwire.MilliSatoshi,
finalHtlcExpiry int32) ([]*channeldb.ChannelEdgePolicy, error) { finalHtlcExpiry int32) ([]*channeldb.CachedEdgePolicy, error) {
routingTx, err := newDbRoutingTx(graph) routingGraph, err := NewCachedGraph(graph)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer func() {
err := routingTx.close()
if err != nil {
log.Errorf("Error closing db tx: %v", err)
}
}()
return findPath( return findPath(
&graphParams{ &graphParams{
additionalEdges: additionalEdges, additionalEdges: additionalEdges,
bandwidthHints: bandwidthHints, bandwidthHints: bandwidthHints,
graph: routingTx, graph: routingGraph,
}, },
r, cfg, source, target, amt, finalHtlcExpiry, r, cfg, source, target, amt, finalHtlcExpiry,
) )

View File

@@ -898,7 +898,7 @@ func (p *shardHandler) handleFailureMessage(rt *route.Route,
var ( var (
isAdditionalEdge bool isAdditionalEdge bool
policy *channeldb.ChannelEdgePolicy policy *channeldb.CachedEdgePolicy
) )
// Before we apply the channel update, we need to decide whether the // Before we apply the channel update, we need to decide whether the

View File

@@ -472,8 +472,8 @@ func testPaymentLifecycle(t *testing.T, test paymentLifecycleTestCase,
Payer: payer, Payer: payer,
ChannelPruneExpiry: time.Hour * 24, ChannelPruneExpiry: time.Hour * 24,
GraphPruneInterval: time.Hour * 2, GraphPruneInterval: time.Hour * 2,
QueryBandwidth: func(e *channeldb.ChannelEdgeInfo) lnwire.MilliSatoshi { QueryBandwidth: func(c *channeldb.DirectedChannel) lnwire.MilliSatoshi {
return lnwire.NewMSatFromSatoshis(e.Capacity) return lnwire.NewMSatFromSatoshis(c.Capacity)
}, },
NextPaymentID: func() (uint64, error) { NextPaymentID: func() (uint64, error) {
next := atomic.AddUint64(&uniquePaymentID, 1) next := atomic.AddUint64(&uniquePaymentID, 1)

View File

@@ -144,13 +144,13 @@ type PaymentSession interface {
// a boolean to indicate whether the update has been applied without // a boolean to indicate whether the update has been applied without
// error. // error.
UpdateAdditionalEdge(msg *lnwire.ChannelUpdate, pubKey *btcec.PublicKey, UpdateAdditionalEdge(msg *lnwire.ChannelUpdate, pubKey *btcec.PublicKey,
policy *channeldb.ChannelEdgePolicy) bool policy *channeldb.CachedEdgePolicy) bool
// GetAdditionalEdgePolicy uses the public key and channel ID to query // GetAdditionalEdgePolicy uses the public key and channel ID to query
// the ephemeral channel edge policy for additional edges. Returns a nil // the ephemeral channel edge policy for additional edges. Returns a nil
// if nothing found. // if nothing found.
GetAdditionalEdgePolicy(pubKey *btcec.PublicKey, GetAdditionalEdgePolicy(pubKey *btcec.PublicKey,
channelID uint64) *channeldb.ChannelEdgePolicy channelID uint64) *channeldb.CachedEdgePolicy
} }
// paymentSession is used during an HTLC routings session to prune the local // paymentSession is used during an HTLC routings session to prune the local
@@ -162,7 +162,7 @@ type PaymentSession interface {
// loop if payment attempts take long enough. An additional set of edges can // loop if payment attempts take long enough. An additional set of edges can
// also be provided to assist in reaching the payment's destination. // also be provided to assist in reaching the payment's destination.
type paymentSession struct { type paymentSession struct {
additionalEdges map[route.Vertex][]*channeldb.ChannelEdgePolicy additionalEdges map[route.Vertex][]*channeldb.CachedEdgePolicy
getBandwidthHints func() (map[uint64]lnwire.MilliSatoshi, error) getBandwidthHints func() (map[uint64]lnwire.MilliSatoshi, error)
@@ -172,7 +172,7 @@ type paymentSession struct {
pathFinder pathFinder pathFinder pathFinder
getRoutingGraph func() (routingGraph, func(), error) routingGraph routingGraph
// pathFindingConfig defines global parameters that control the // pathFindingConfig defines global parameters that control the
// trade-off in path finding between fees and probabiity. // trade-off in path finding between fees and probabiity.
@@ -193,7 +193,7 @@ type paymentSession struct {
// newPaymentSession instantiates a new payment session. // newPaymentSession instantiates a new payment session.
func newPaymentSession(p *LightningPayment, func newPaymentSession(p *LightningPayment,
getBandwidthHints func() (map[uint64]lnwire.MilliSatoshi, error), getBandwidthHints func() (map[uint64]lnwire.MilliSatoshi, error),
getRoutingGraph func() (routingGraph, func(), error), routingGraph routingGraph,
missionControl MissionController, pathFindingConfig PathFindingConfig) ( missionControl MissionController, pathFindingConfig PathFindingConfig) (
*paymentSession, error) { *paymentSession, error) {
@@ -209,7 +209,7 @@ func newPaymentSession(p *LightningPayment,
getBandwidthHints: getBandwidthHints, getBandwidthHints: getBandwidthHints,
payment: p, payment: p,
pathFinder: findPath, pathFinder: findPath,
getRoutingGraph: getRoutingGraph, routingGraph: routingGraph,
pathFindingConfig: pathFindingConfig, pathFindingConfig: pathFindingConfig,
missionControl: missionControl, missionControl: missionControl,
minShardAmt: DefaultShardMinAmt, minShardAmt: DefaultShardMinAmt,
@@ -287,29 +287,20 @@ func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi,
p.log.Debugf("pathfinding for amt=%v", maxAmt) p.log.Debugf("pathfinding for amt=%v", maxAmt)
// Get a routing graph. sourceVertex := p.routingGraph.sourceNode()
routingGraph, cleanup, err := p.getRoutingGraph()
if err != nil {
return nil, err
}
sourceVertex := routingGraph.sourceNode()
// Find a route for the current amount. // Find a route for the current amount.
path, err := p.pathFinder( path, err := p.pathFinder(
&graphParams{ &graphParams{
additionalEdges: p.additionalEdges, additionalEdges: p.additionalEdges,
bandwidthHints: bandwidthHints, bandwidthHints: bandwidthHints,
graph: routingGraph, graph: p.routingGraph,
}, },
restrictions, &p.pathFindingConfig, restrictions, &p.pathFindingConfig,
sourceVertex, p.payment.Target, sourceVertex, p.payment.Target,
maxAmt, finalHtlcExpiry, maxAmt, finalHtlcExpiry,
) )
// Close routing graph.
cleanup()
switch { switch {
case err == errNoPathFound: case err == errNoPathFound:
// Don't split if this is a legacy payment without mpp // Don't split if this is a legacy payment without mpp
@@ -403,7 +394,7 @@ func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi,
// updates to the supplied policy. It returns a boolean to indicate whether // updates to the supplied policy. It returns a boolean to indicate whether
// there's an error when applying the updates. // there's an error when applying the updates.
func (p *paymentSession) UpdateAdditionalEdge(msg *lnwire.ChannelUpdate, func (p *paymentSession) UpdateAdditionalEdge(msg *lnwire.ChannelUpdate,
pubKey *btcec.PublicKey, policy *channeldb.ChannelEdgePolicy) bool { pubKey *btcec.PublicKey, policy *channeldb.CachedEdgePolicy) bool {
// Validate the message signature. // Validate the message signature.
if err := VerifyChannelUpdateSignature(msg, pubKey); err != nil { if err := VerifyChannelUpdateSignature(msg, pubKey); err != nil {
@@ -428,7 +419,7 @@ func (p *paymentSession) UpdateAdditionalEdge(msg *lnwire.ChannelUpdate,
// ephemeral channel edge policy for additional edges. Returns a nil if nothing // ephemeral channel edge policy for additional edges. Returns a nil if nothing
// found. // found.
func (p *paymentSession) GetAdditionalEdgePolicy(pubKey *btcec.PublicKey, func (p *paymentSession) GetAdditionalEdgePolicy(pubKey *btcec.PublicKey,
channelID uint64) *channeldb.ChannelEdgePolicy { channelID uint64) *channeldb.CachedEdgePolicy {
target := route.NewVertex(pubKey) target := route.NewVertex(pubKey)

View File

@@ -17,14 +17,14 @@ var _ PaymentSessionSource = (*SessionSource)(nil)
type SessionSource struct { type SessionSource struct {
// Graph is the channel graph that will be used to gather metrics from // Graph is the channel graph that will be used to gather metrics from
// and also to carry out path finding queries. // and also to carry out path finding queries.
Graph *channeldb.ChannelGraph Graph routingGraph
// QueryBandwidth is a method that allows querying the lower link layer // QueryBandwidth is a method that allows querying the lower link layer
// to determine the up to date available bandwidth at a prospective link // to determine the up to date available bandwidth at a prospective link
// to be traversed. If the link isn't available, then a value of zero // to be traversed. If the link isn't available, then a value of zero
// should be returned. Otherwise, the current up to date knowledge of // should be returned. Otherwise, the current up to date knowledge of
// the available bandwidth of the link should be returned. // the available bandwidth of the link should be returned.
QueryBandwidth func(*channeldb.ChannelEdgeInfo) lnwire.MilliSatoshi QueryBandwidth func(*channeldb.DirectedChannel) lnwire.MilliSatoshi
// MissionControl is a shared memory of sorts that executions of payment // MissionControl is a shared memory of sorts that executions of payment
// path finding use in order to remember which vertexes/edges were // path finding use in order to remember which vertexes/edges were
@@ -40,21 +40,6 @@ type SessionSource struct {
PathFindingConfig PathFindingConfig PathFindingConfig PathFindingConfig
} }
// getRoutingGraph returns a routing graph and a clean-up function for
// pathfinding.
func (m *SessionSource) getRoutingGraph() (routingGraph, func(), error) {
routingTx, err := newDbRoutingTx(m.Graph)
if err != nil {
return nil, nil, err
}
return routingTx, func() {
err := routingTx.close()
if err != nil {
log.Errorf("Error closing db tx: %v", err)
}
}, nil
}
// NewPaymentSession creates a new payment session backed by the latest prune // NewPaymentSession creates a new payment session backed by the latest prune
// view from Mission Control. An optional set of routing hints can be provided // view from Mission Control. An optional set of routing hints can be provided
// in order to populate additional edges to explore when finding a path to the // in order to populate additional edges to explore when finding a path to the
@@ -62,19 +47,16 @@ func (m *SessionSource) getRoutingGraph() (routingGraph, func(), error) {
func (m *SessionSource) NewPaymentSession(p *LightningPayment) ( func (m *SessionSource) NewPaymentSession(p *LightningPayment) (
PaymentSession, error) { PaymentSession, error) {
sourceNode, err := m.Graph.SourceNode()
if err != nil {
return nil, err
}
getBandwidthHints := func() (map[uint64]lnwire.MilliSatoshi, getBandwidthHints := func() (map[uint64]lnwire.MilliSatoshi,
error) { error) {
return generateBandwidthHints(sourceNode, m.QueryBandwidth) return generateBandwidthHints(
m.Graph.sourceNode(), m.Graph, m.QueryBandwidth,
)
} }
session, err := newPaymentSession( session, err := newPaymentSession(
p, getBandwidthHints, m.getRoutingGraph, p, getBandwidthHints, m.Graph,
m.MissionControl, m.PathFindingConfig, m.MissionControl, m.PathFindingConfig,
) )
if err != nil { if err != nil {
@@ -96,9 +78,9 @@ func (m *SessionSource) NewPaymentSessionEmpty() PaymentSession {
// RouteHintsToEdges converts a list of invoice route hints to an edge map that // RouteHintsToEdges converts a list of invoice route hints to an edge map that
// can be passed into pathfinding. // can be passed into pathfinding.
func RouteHintsToEdges(routeHints [][]zpay32.HopHint, target route.Vertex) ( func RouteHintsToEdges(routeHints [][]zpay32.HopHint, target route.Vertex) (
map[route.Vertex][]*channeldb.ChannelEdgePolicy, error) { map[route.Vertex][]*channeldb.CachedEdgePolicy, error) {
edges := make(map[route.Vertex][]*channeldb.ChannelEdgePolicy) edges := make(map[route.Vertex][]*channeldb.CachedEdgePolicy)
// Traverse through all of the available hop hints and include them in // Traverse through all of the available hop hints and include them in
// our edges map, indexed by the public key of the channel's starting // our edges map, indexed by the public key of the channel's starting
@@ -128,9 +110,12 @@ func RouteHintsToEdges(routeHints [][]zpay32.HopHint, target route.Vertex) (
// Finally, create the channel edge from the hop hint // Finally, create the channel edge from the hop hint
// and add it to list of edges corresponding to the node // and add it to list of edges corresponding to the node
// at the start of the channel. // at the start of the channel.
edge := &channeldb.ChannelEdgePolicy{ edge := &channeldb.CachedEdgePolicy{
Node: endNode, ToNodePubKey: func() route.Vertex {
ChannelID: hopHint.ChannelID, return endNode.PubKeyBytes
},
ToNodeFeatures: lnwire.EmptyFeatureVector(),
ChannelID: hopHint.ChannelID,
FeeBaseMSat: lnwire.MilliSatoshi( FeeBaseMSat: lnwire.MilliSatoshi(
hopHint.FeeBaseMSat, hopHint.FeeBaseMSat,
), ),

View File

@@ -121,9 +121,7 @@ func TestUpdateAdditionalEdge(t *testing.T) {
return nil, nil return nil, nil
}, },
func() (routingGraph, func(), error) { &sessionGraph{},
return &sessionGraph{}, func() {}, nil
},
&MissionControl{}, &MissionControl{},
PathFindingConfig{}, PathFindingConfig{},
) )
@@ -203,9 +201,7 @@ func TestRequestRoute(t *testing.T) {
return nil, nil return nil, nil
}, },
func() (routingGraph, func(), error) { &sessionGraph{},
return &sessionGraph{}, func() {}, nil
},
&MissionControl{}, &MissionControl{},
PathFindingConfig{}, PathFindingConfig{},
) )
@@ -217,7 +213,7 @@ func TestRequestRoute(t *testing.T) {
session.pathFinder = func( session.pathFinder = func(
g *graphParams, r *RestrictParams, cfg *PathFindingConfig, g *graphParams, r *RestrictParams, cfg *PathFindingConfig,
source, target route.Vertex, amt lnwire.MilliSatoshi, source, target route.Vertex, amt lnwire.MilliSatoshi,
finalHtlcExpiry int32) ([]*channeldb.ChannelEdgePolicy, error) { finalHtlcExpiry int32) ([]*channeldb.CachedEdgePolicy, error) {
// We expect find path to receive a cltv limit excluding the // We expect find path to receive a cltv limit excluding the
// final cltv delta (including the block padding). // final cltv delta (including the block padding).
@@ -225,13 +221,14 @@ func TestRequestRoute(t *testing.T) {
t.Fatal("wrong cltv limit") t.Fatal("wrong cltv limit")
} }
path := []*channeldb.ChannelEdgePolicy{ path := []*channeldb.CachedEdgePolicy{
{ {
Node: &channeldb.LightningNode{ ToNodePubKey: func() route.Vertex {
Features: lnwire.NewFeatureVector( return route.Vertex{}
nil, nil,
),
}, },
ToNodeFeatures: lnwire.NewFeatureVector(
nil, nil,
),
}, },
} }

View File

@@ -339,7 +339,7 @@ type Config struct {
// a value of zero should be returned. Otherwise, the current up to // a value of zero should be returned. Otherwise, the current up to
// date knowledge of the available bandwidth of the link should be // date knowledge of the available bandwidth of the link should be
// returned. // returned.
QueryBandwidth func(edge *channeldb.ChannelEdgeInfo) lnwire.MilliSatoshi QueryBandwidth func(*channeldb.DirectedChannel) lnwire.MilliSatoshi
// NextPaymentID is a method that guarantees to return a new, unique ID // NextPaymentID is a method that guarantees to return a new, unique ID
// each time it is called. This is used by the router to generate a // each time it is called. This is used by the router to generate a
@@ -406,6 +406,10 @@ type ChannelRouter struct {
// when doing any path finding. // when doing any path finding.
selfNode *channeldb.LightningNode selfNode *channeldb.LightningNode
// cachedGraph is an instance of routingGraph that caches the source node as
// well as the channel graph itself in memory.
cachedGraph routingGraph
// newBlocks is a channel in which new blocks connected to the end of // newBlocks is a channel in which new blocks connected to the end of
// the main chain are sent over, and blocks updated after a call to // the main chain are sent over, and blocks updated after a call to
// UpdateFilter. // UpdateFilter.
@@ -460,14 +464,17 @@ var _ ChannelGraphSource = (*ChannelRouter)(nil)
// channel graph is a subset of the UTXO set) set, then the router will proceed // channel graph is a subset of the UTXO set) set, then the router will proceed
// to fully sync to the latest state of the UTXO set. // to fully sync to the latest state of the UTXO set.
func New(cfg Config) (*ChannelRouter, error) { func New(cfg Config) (*ChannelRouter, error) {
selfNode, err := cfg.Graph.SourceNode() selfNode, err := cfg.Graph.SourceNode()
if err != nil { if err != nil {
return nil, err return nil, err
} }
r := &ChannelRouter{ r := &ChannelRouter{
cfg: &cfg, cfg: &cfg,
cachedGraph: &CachedGraph{
graph: cfg.Graph,
source: selfNode.PubKeyBytes,
},
networkUpdates: make(chan *routingMsg), networkUpdates: make(chan *routingMsg),
topologyClients: make(map[uint64]*topologyClient), topologyClients: make(map[uint64]*topologyClient),
ntfnClientUpdates: make(chan *topologyClientUpdate), ntfnClientUpdates: make(chan *topologyClientUpdate),
@@ -1727,7 +1734,7 @@ type routingMsg struct {
func (r *ChannelRouter) FindRoute(source, target route.Vertex, func (r *ChannelRouter) FindRoute(source, target route.Vertex,
amt lnwire.MilliSatoshi, restrictions *RestrictParams, amt lnwire.MilliSatoshi, restrictions *RestrictParams,
destCustomRecords record.CustomSet, destCustomRecords record.CustomSet,
routeHints map[route.Vertex][]*channeldb.ChannelEdgePolicy, routeHints map[route.Vertex][]*channeldb.CachedEdgePolicy,
finalExpiry uint16) (*route.Route, error) { finalExpiry uint16) (*route.Route, error) {
log.Debugf("Searching for path to %v, sending %v", target, amt) log.Debugf("Searching for path to %v, sending %v", target, amt)
@@ -1735,7 +1742,7 @@ func (r *ChannelRouter) FindRoute(source, target route.Vertex,
// We'll attempt to obtain a set of bandwidth hints that can help us // We'll attempt to obtain a set of bandwidth hints that can help us
// eliminate certain routes early on in the path finding process. // eliminate certain routes early on in the path finding process.
bandwidthHints, err := generateBandwidthHints( bandwidthHints, err := generateBandwidthHints(
r.selfNode, r.cfg.QueryBandwidth, r.selfNode.PubKeyBytes, r.cachedGraph, r.cfg.QueryBandwidth,
) )
if err != nil { if err != nil {
return nil, err return nil, err
@@ -1752,22 +1759,11 @@ func (r *ChannelRouter) FindRoute(source, target route.Vertex,
// execute our path finding algorithm. // execute our path finding algorithm.
finalHtlcExpiry := currentHeight + int32(finalExpiry) finalHtlcExpiry := currentHeight + int32(finalExpiry)
routingTx, err := newDbRoutingTx(r.cfg.Graph)
if err != nil {
return nil, err
}
defer func() {
err := routingTx.close()
if err != nil {
log.Errorf("Error closing db tx: %v", err)
}
}()
path, err := findPath( path, err := findPath(
&graphParams{ &graphParams{
additionalEdges: routeHints, additionalEdges: routeHints,
bandwidthHints: bandwidthHints, bandwidthHints: bandwidthHints,
graph: routingTx, graph: r.cachedGraph,
}, },
restrictions, restrictions,
&r.cfg.PathFindingConfig, &r.cfg.PathFindingConfig,
@@ -2505,8 +2501,10 @@ func (r *ChannelRouter) GetChannelByID(chanID lnwire.ShortChannelID) (
// within the graph. // within the graph.
// //
// NOTE: This method is part of the ChannelGraphSource interface. // NOTE: This method is part of the ChannelGraphSource interface.
func (r *ChannelRouter) FetchLightningNode(node route.Vertex) (*channeldb.LightningNode, error) { func (r *ChannelRouter) FetchLightningNode(
return r.cfg.Graph.FetchLightningNode(nil, node) node route.Vertex) (*channeldb.LightningNode, error) {
return r.cfg.Graph.FetchLightningNode(node)
} }
// ForEachNode is used to iterate over every node in router topology. // ForEachNode is used to iterate over every node in router topology.
@@ -2661,19 +2659,19 @@ func (r *ChannelRouter) MarkEdgeLive(chanID lnwire.ShortChannelID) error {
// these hints allows us to reduce the number of extraneous attempts as we can // these hints allows us to reduce the number of extraneous attempts as we can
// skip channels that are inactive, or just don't have enough bandwidth to // skip channels that are inactive, or just don't have enough bandwidth to
// carry the payment. // carry the payment.
func generateBandwidthHints(sourceNode *channeldb.LightningNode, func generateBandwidthHints(sourceNode route.Vertex, graph routingGraph,
queryBandwidth func(*channeldb.ChannelEdgeInfo) lnwire.MilliSatoshi) (map[uint64]lnwire.MilliSatoshi, error) { queryBandwidth func(*channeldb.DirectedChannel) lnwire.MilliSatoshi) (
map[uint64]lnwire.MilliSatoshi, error) {
// First, we'll collect the set of outbound edges from the target // First, we'll collect the set of outbound edges from the target
// source node. // source node.
var localChans []*channeldb.ChannelEdgeInfo var localChans []*channeldb.DirectedChannel
err := sourceNode.ForEachChannel(nil, func(tx kvdb.RTx, err := graph.forEachNodeChannel(
edgeInfo *channeldb.ChannelEdgeInfo, sourceNode, func(channel *channeldb.DirectedChannel) error {
_, _ *channeldb.ChannelEdgePolicy) error { localChans = append(localChans, channel)
return nil
localChans = append(localChans, edgeInfo) },
return nil )
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -2726,7 +2724,7 @@ func (r *ChannelRouter) BuildRoute(amt *lnwire.MilliSatoshi,
// We'll attempt to obtain a set of bandwidth hints that helps us select // We'll attempt to obtain a set of bandwidth hints that helps us select
// the best outgoing channel to use in case no outgoing channel is set. // the best outgoing channel to use in case no outgoing channel is set.
bandwidthHints, err := generateBandwidthHints( bandwidthHints, err := generateBandwidthHints(
r.selfNode, r.cfg.QueryBandwidth, r.selfNode.PubKeyBytes, r.cachedGraph, r.cfg.QueryBandwidth,
) )
if err != nil { if err != nil {
return nil, err return nil, err
@@ -2756,18 +2754,6 @@ func (r *ChannelRouter) BuildRoute(amt *lnwire.MilliSatoshi,
runningAmt = *amt runningAmt = *amt
} }
// Open a transaction to execute the graph queries in.
routingTx, err := newDbRoutingTx(r.cfg.Graph)
if err != nil {
return nil, err
}
defer func() {
err := routingTx.close()
if err != nil {
log.Errorf("Error closing db tx: %v", err)
}
}()
// Traverse hops backwards to accumulate fees in the running amounts. // Traverse hops backwards to accumulate fees in the running amounts.
source := r.selfNode.PubKeyBytes source := r.selfNode.PubKeyBytes
for i := len(hops) - 1; i >= 0; i-- { for i := len(hops) - 1; i >= 0; i-- {
@@ -2786,7 +2772,7 @@ func (r *ChannelRouter) BuildRoute(amt *lnwire.MilliSatoshi,
// known in the graph. // known in the graph.
u := newUnifiedPolicies(source, toNode, outgoingChans) u := newUnifiedPolicies(source, toNode, outgoingChans)
err := u.addGraphPolicies(routingTx) err := u.addGraphPolicies(r.cachedGraph)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -2832,7 +2818,7 @@ func (r *ChannelRouter) BuildRoute(amt *lnwire.MilliSatoshi,
// total amount, we make a forward pass. Because the amount may have // total amount, we make a forward pass. Because the amount may have
// been increased in the backward pass, fees need to be recalculated and // been increased in the backward pass, fees need to be recalculated and
// amount ranges re-checked. // amount ranges re-checked.
var pathEdges []*channeldb.ChannelEdgePolicy var pathEdges []*channeldb.CachedEdgePolicy
receiverAmt := runningAmt receiverAmt := runningAmt
for i, edge := range edges { for i, edge := range edges {
policy := edge.getPolicy(receiverAmt, bandwidthHints) policy := edge.getPolicy(receiverAmt, bandwidthHints)

View File

@@ -125,17 +125,19 @@ func createTestCtxFromGraphInstanceAssumeValid(t *testing.T,
} }
mc, err := NewMissionControl( mc, err := NewMissionControl(
graphInstance.graph.Database(), route.Vertex{}, graphInstance.graphBackend, route.Vertex{}, mcConfig,
mcConfig,
) )
require.NoError(t, err, "failed to create missioncontrol") require.NoError(t, err, "failed to create missioncontrol")
sessionSource := &SessionSource{ cachedGraph, err := NewCachedGraph(graphInstance.graph)
Graph: graphInstance.graph, require.NoError(t, err)
QueryBandwidth: func(
e *channeldb.ChannelEdgeInfo) lnwire.MilliSatoshi {
return lnwire.NewMSatFromSatoshis(e.Capacity) sessionSource := &SessionSource{
Graph: cachedGraph,
QueryBandwidth: func(
c *channeldb.DirectedChannel) lnwire.MilliSatoshi {
return lnwire.NewMSatFromSatoshis(c.Capacity)
}, },
PathFindingConfig: pathFindingConfig, PathFindingConfig: pathFindingConfig,
MissionControl: mc, MissionControl: mc,
@@ -159,7 +161,7 @@ func createTestCtxFromGraphInstanceAssumeValid(t *testing.T,
ChannelPruneExpiry: time.Hour * 24, ChannelPruneExpiry: time.Hour * 24,
GraphPruneInterval: time.Hour * 2, GraphPruneInterval: time.Hour * 2,
QueryBandwidth: func( QueryBandwidth: func(
e *channeldb.ChannelEdgeInfo) lnwire.MilliSatoshi { e *channeldb.DirectedChannel) lnwire.MilliSatoshi {
return lnwire.NewMSatFromSatoshis(e.Capacity) return lnwire.NewMSatFromSatoshis(e.Capacity)
}, },
@@ -188,7 +190,6 @@ func createTestCtxFromGraphInstanceAssumeValid(t *testing.T,
cleanUp := func() { cleanUp := func() {
ctx.router.Stop() ctx.router.Stop()
graphInstance.cleanUp()
} }
return ctx, cleanUp return ctx, cleanUp
@@ -197,17 +198,10 @@ func createTestCtxFromGraphInstanceAssumeValid(t *testing.T,
func createTestCtxSingleNode(t *testing.T, func createTestCtxSingleNode(t *testing.T,
startingHeight uint32) (*testCtx, func()) { startingHeight uint32) (*testCtx, func()) {
var ( graph, graphBackend, cleanup, err := makeTestGraph()
graph *channeldb.ChannelGraph
sourceNode *channeldb.LightningNode
cleanup func()
err error
)
graph, cleanup, err = makeTestGraph()
require.NoError(t, err, "failed to make test graph") require.NoError(t, err, "failed to make test graph")
sourceNode, err = createTestNode() sourceNode, err := createTestNode()
require.NoError(t, err, "failed to create test node") require.NoError(t, err, "failed to create test node")
require.NoError(t, require.NoError(t,
@@ -215,8 +209,9 @@ func createTestCtxSingleNode(t *testing.T,
) )
graphInstance := &testGraphInstance{ graphInstance := &testGraphInstance{
graph: graph, graph: graph,
cleanUp: cleanup, graphBackend: graphBackend,
cleanUp: cleanup,
} }
return createTestCtxFromGraphInstance( return createTestCtxFromGraphInstance(
@@ -1401,6 +1396,9 @@ func TestAddEdgeUnknownVertexes(t *testing.T) {
MinHTLC: 1, MinHTLC: 1,
FeeBaseMSat: 10, FeeBaseMSat: 10,
FeeProportionalMillionths: 10000, FeeProportionalMillionths: 10000,
Node: &channeldb.LightningNode{
PubKeyBytes: edge.NodeKey2Bytes,
},
} }
edgePolicy.ChannelFlags = 0 edgePolicy.ChannelFlags = 0
@@ -1417,6 +1415,9 @@ func TestAddEdgeUnknownVertexes(t *testing.T) {
MinHTLC: 1, MinHTLC: 1,
FeeBaseMSat: 10, FeeBaseMSat: 10,
FeeProportionalMillionths: 10000, FeeProportionalMillionths: 10000,
Node: &channeldb.LightningNode{
PubKeyBytes: edge.NodeKey1Bytes,
},
} }
edgePolicy.ChannelFlags = 1 edgePolicy.ChannelFlags = 1
@@ -1498,6 +1499,9 @@ func TestAddEdgeUnknownVertexes(t *testing.T) {
MinHTLC: 1, MinHTLC: 1,
FeeBaseMSat: 10, FeeBaseMSat: 10,
FeeProportionalMillionths: 10000, FeeProportionalMillionths: 10000,
Node: &channeldb.LightningNode{
PubKeyBytes: edge.NodeKey2Bytes,
},
} }
edgePolicy.ChannelFlags = 0 edgePolicy.ChannelFlags = 0
@@ -1513,6 +1517,9 @@ func TestAddEdgeUnknownVertexes(t *testing.T) {
MinHTLC: 1, MinHTLC: 1,
FeeBaseMSat: 10, FeeBaseMSat: 10,
FeeProportionalMillionths: 10000, FeeProportionalMillionths: 10000,
Node: &channeldb.LightningNode{
PubKeyBytes: edge.NodeKey1Bytes,
},
} }
edgePolicy.ChannelFlags = 1 edgePolicy.ChannelFlags = 1
@@ -1577,7 +1584,7 @@ func TestAddEdgeUnknownVertexes(t *testing.T) {
t.Fatalf("unable to find any routes: %v", err) t.Fatalf("unable to find any routes: %v", err)
} }
copy1, err := ctx.graph.FetchLightningNode(nil, pub1) copy1, err := ctx.graph.FetchLightningNode(pub1)
if err != nil { if err != nil {
t.Fatalf("unable to fetch node: %v", err) t.Fatalf("unable to fetch node: %v", err)
} }
@@ -1586,7 +1593,7 @@ func TestAddEdgeUnknownVertexes(t *testing.T) {
t.Fatalf("fetched node not equal to original") t.Fatalf("fetched node not equal to original")
} }
copy2, err := ctx.graph.FetchLightningNode(nil, pub2) copy2, err := ctx.graph.FetchLightningNode(pub2)
if err != nil { if err != nil {
t.Fatalf("unable to fetch node: %v", err) t.Fatalf("unable to fetch node: %v", err)
} }
@@ -2474,8 +2481,8 @@ func TestFindPathFeeWeighting(t *testing.T) {
if len(path) != 1 { if len(path) != 1 {
t.Fatalf("expected path length of 1, instead was: %v", len(path)) t.Fatalf("expected path length of 1, instead was: %v", len(path))
} }
if path[0].Node.Alias != "luoji" { if path[0].ToNodePubKey() != ctx.aliases["luoji"] {
t.Fatalf("wrong node: %v", path[0].Node.Alias) t.Fatalf("wrong node: %v", path[0].ToNodePubKey())
} }
} }

View File

@@ -40,7 +40,7 @@ func newUnifiedPolicies(sourceNode, toNode route.Vertex,
// addPolicy adds a single channel policy. Capacity may be zero if unknown // addPolicy adds a single channel policy. Capacity may be zero if unknown
// (light clients). // (light clients).
func (u *unifiedPolicies) addPolicy(fromNode route.Vertex, func (u *unifiedPolicies) addPolicy(fromNode route.Vertex,
edge *channeldb.ChannelEdgePolicy, capacity btcutil.Amount) { edge *channeldb.CachedEdgePolicy, capacity btcutil.Amount) {
localChan := fromNode == u.sourceNode localChan := fromNode == u.sourceNode
@@ -69,24 +69,18 @@ func (u *unifiedPolicies) addPolicy(fromNode route.Vertex,
// addGraphPolicies adds all policies that are known for the toNode in the // addGraphPolicies adds all policies that are known for the toNode in the
// graph. // graph.
func (u *unifiedPolicies) addGraphPolicies(g routingGraph) error { func (u *unifiedPolicies) addGraphPolicies(g routingGraph) error {
cb := func(edgeInfo *channeldb.ChannelEdgeInfo, _, cb := func(channel *channeldb.DirectedChannel) error {
inEdge *channeldb.ChannelEdgePolicy) error {
// If there is no edge policy for this candidate node, skip. // If there is no edge policy for this candidate node, skip.
// Note that we are searching backwards so this node would have // Note that we are searching backwards so this node would have
// come prior to the pivot node in the route. // come prior to the pivot node in the route.
if inEdge == nil { if channel.InPolicy == nil {
return nil return nil
} }
// The node on the other end of this channel is the from node.
fromNode, err := edgeInfo.OtherNodeKeyBytes(u.toNode[:])
if err != nil {
return err
}
// Add this policy to the unified policies map. // Add this policy to the unified policies map.
u.addPolicy(fromNode, inEdge, edgeInfo.Capacity) u.addPolicy(
channel.OtherNode, channel.InPolicy, channel.Capacity,
)
return nil return nil
} }
@@ -98,7 +92,7 @@ func (u *unifiedPolicies) addGraphPolicies(g routingGraph) error {
// unifiedPolicyEdge is the individual channel data that is kept inside an // unifiedPolicyEdge is the individual channel data that is kept inside an
// unifiedPolicy object. // unifiedPolicy object.
type unifiedPolicyEdge struct { type unifiedPolicyEdge struct {
policy *channeldb.ChannelEdgePolicy policy *channeldb.CachedEdgePolicy
capacity btcutil.Amount capacity btcutil.Amount
} }
@@ -139,7 +133,7 @@ type unifiedPolicy struct {
// specific amount to send. It differentiates between local and network // specific amount to send. It differentiates between local and network
// channels. // channels.
func (u *unifiedPolicy) getPolicy(amt lnwire.MilliSatoshi, func (u *unifiedPolicy) getPolicy(amt lnwire.MilliSatoshi,
bandwidthHints map[uint64]lnwire.MilliSatoshi) *channeldb.ChannelEdgePolicy { bandwidthHints map[uint64]lnwire.MilliSatoshi) *channeldb.CachedEdgePolicy {
if u.localChan { if u.localChan {
return u.getPolicyLocal(amt, bandwidthHints) return u.getPolicyLocal(amt, bandwidthHints)
@@ -151,10 +145,10 @@ func (u *unifiedPolicy) getPolicy(amt lnwire.MilliSatoshi,
// getPolicyLocal returns the optimal policy to use for this local connection // getPolicyLocal returns the optimal policy to use for this local connection
// given a specific amount to send. // given a specific amount to send.
func (u *unifiedPolicy) getPolicyLocal(amt lnwire.MilliSatoshi, func (u *unifiedPolicy) getPolicyLocal(amt lnwire.MilliSatoshi,
bandwidthHints map[uint64]lnwire.MilliSatoshi) *channeldb.ChannelEdgePolicy { bandwidthHints map[uint64]lnwire.MilliSatoshi) *channeldb.CachedEdgePolicy {
var ( var (
bestPolicy *channeldb.ChannelEdgePolicy bestPolicy *channeldb.CachedEdgePolicy
maxBandwidth lnwire.MilliSatoshi maxBandwidth lnwire.MilliSatoshi
) )
@@ -206,10 +200,10 @@ func (u *unifiedPolicy) getPolicyLocal(amt lnwire.MilliSatoshi,
// a specific amount to send. The goal is to return a policy that maximizes the // a specific amount to send. The goal is to return a policy that maximizes the
// probability of a successful forward in a non-strict forwarding context. // probability of a successful forward in a non-strict forwarding context.
func (u *unifiedPolicy) getPolicyNetwork( func (u *unifiedPolicy) getPolicyNetwork(
amt lnwire.MilliSatoshi) *channeldb.ChannelEdgePolicy { amt lnwire.MilliSatoshi) *channeldb.CachedEdgePolicy {
var ( var (
bestPolicy *channeldb.ChannelEdgePolicy bestPolicy *channeldb.CachedEdgePolicy
maxFee lnwire.MilliSatoshi maxFee lnwire.MilliSatoshi
maxTimelock uint16 maxTimelock uint16
) )

View File

@@ -20,7 +20,7 @@ func TestUnifiedPolicies(t *testing.T) {
u := newUnifiedPolicies(source, toNode, nil) u := newUnifiedPolicies(source, toNode, nil)
// Add two channels between the pair of nodes. // Add two channels between the pair of nodes.
p1 := channeldb.ChannelEdgePolicy{ p1 := channeldb.CachedEdgePolicy{
FeeProportionalMillionths: 100000, FeeProportionalMillionths: 100000,
FeeBaseMSat: 30, FeeBaseMSat: 30,
TimeLockDelta: 60, TimeLockDelta: 60,
@@ -28,7 +28,7 @@ func TestUnifiedPolicies(t *testing.T) {
MaxHTLC: 500, MaxHTLC: 500,
MinHTLC: 100, MinHTLC: 100,
} }
p2 := channeldb.ChannelEdgePolicy{ p2 := channeldb.CachedEdgePolicy{
FeeProportionalMillionths: 190000, FeeProportionalMillionths: 190000,
FeeBaseMSat: 10, FeeBaseMSat: 10,
TimeLockDelta: 40, TimeLockDelta: 40,
@@ -39,7 +39,7 @@ func TestUnifiedPolicies(t *testing.T) {
u.addPolicy(fromNode, &p1, 7) u.addPolicy(fromNode, &p1, 7)
u.addPolicy(fromNode, &p2, 7) u.addPolicy(fromNode, &p2, 7)
checkPolicy := func(policy *channeldb.ChannelEdgePolicy, checkPolicy := func(policy *channeldb.CachedEdgePolicy,
feeBase lnwire.MilliSatoshi, feeRate lnwire.MilliSatoshi, feeBase lnwire.MilliSatoshi, feeRate lnwire.MilliSatoshi,
timeLockDelta uint16) { timeLockDelta uint16) {

View File

@@ -3989,7 +3989,7 @@ func (r *rpcServer) createRPCClosedChannel(
CloseInitiator: closeInitiator, CloseInitiator: closeInitiator,
} }
reports, err := r.server.chanStateDB.FetchChannelReports( reports, err := r.server.miscDB.FetchChannelReports(
*r.cfg.ActiveNetParams.GenesisHash, &dbChannel.ChanPoint, *r.cfg.ActiveNetParams.GenesisHash, &dbChannel.ChanPoint,
) )
switch err { switch err {
@@ -5152,7 +5152,7 @@ func (r *rpcServer) ListInvoices(ctx context.Context,
PendingOnly: req.PendingOnly, PendingOnly: req.PendingOnly,
Reversed: req.Reversed, Reversed: req.Reversed,
} }
invoiceSlice, err := r.server.chanStateDB.QueryInvoices(q) invoiceSlice, err := r.server.miscDB.QueryInvoices(q)
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to query invoices: %v", err) return nil, fmt.Errorf("unable to query invoices: %v", err)
} }
@@ -5549,7 +5549,7 @@ func (r *rpcServer) GetNodeInfo(ctx context.Context,
// With the public key decoded, attempt to fetch the node corresponding // With the public key decoded, attempt to fetch the node corresponding
// to this public key. If the node cannot be found, then an error will // to this public key. If the node cannot be found, then an error will
// be returned. // be returned.
node, err := graph.FetchLightningNode(nil, pubKey) node, err := graph.FetchLightningNode(pubKey)
switch { switch {
case err == channeldb.ErrGraphNodeNotFound: case err == channeldb.ErrGraphNodeNotFound:
return nil, status.Error(codes.NotFound, err.Error()) return nil, status.Error(codes.NotFound, err.Error())
@@ -5954,7 +5954,7 @@ func (r *rpcServer) ListPayments(ctx context.Context,
query.MaxPayments = math.MaxUint64 query.MaxPayments = math.MaxUint64
} }
paymentsQuerySlice, err := r.server.chanStateDB.QueryPayments(query) paymentsQuerySlice, err := r.server.miscDB.QueryPayments(query)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -5995,9 +5995,7 @@ func (r *rpcServer) DeletePayment(ctx context.Context,
rpcsLog.Infof("[DeletePayment] payment_identifier=%v, "+ rpcsLog.Infof("[DeletePayment] payment_identifier=%v, "+
"failed_htlcs_only=%v", hash, req.FailedHtlcsOnly) "failed_htlcs_only=%v", hash, req.FailedHtlcsOnly)
err = r.server.chanStateDB.DeletePayment( err = r.server.miscDB.DeletePayment(hash, req.FailedHtlcsOnly)
hash, req.FailedHtlcsOnly,
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -6014,7 +6012,7 @@ func (r *rpcServer) DeleteAllPayments(ctx context.Context,
"failed_htlcs_only=%v", req.FailedPaymentsOnly, "failed_htlcs_only=%v", req.FailedPaymentsOnly,
req.FailedHtlcsOnly) req.FailedHtlcsOnly)
err := r.server.chanStateDB.DeletePayments( err := r.server.miscDB.DeletePayments(
req.FailedPaymentsOnly, req.FailedHtlcsOnly, req.FailedPaymentsOnly, req.FailedHtlcsOnly,
) )
if err != nil { if err != nil {
@@ -6176,7 +6174,7 @@ func (r *rpcServer) FeeReport(ctx context.Context,
return nil, err return nil, err
} }
fwdEventLog := r.server.chanStateDB.ForwardingLog() fwdEventLog := r.server.miscDB.ForwardingLog()
// computeFeeSum is a helper function that computes the total fees for // computeFeeSum is a helper function that computes the total fees for
// a particular time slice described by a forwarding event query. // a particular time slice described by a forwarding event query.
@@ -6417,7 +6415,7 @@ func (r *rpcServer) ForwardingHistory(ctx context.Context,
IndexOffset: req.IndexOffset, IndexOffset: req.IndexOffset,
NumMaxEvents: numEvents, NumMaxEvents: numEvents,
} }
timeSlice, err := r.server.chanStateDB.ForwardingLog().Query(eventQuery) timeSlice, err := r.server.miscDB.ForwardingLog().Query(eventQuery)
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to query forwarding log: %v", err) return nil, fmt.Errorf("unable to query forwarding log: %v", err)
} }
@@ -6479,7 +6477,7 @@ func (r *rpcServer) ExportChannelBackup(ctx context.Context,
// the database. If this channel has been closed, or the outpoint is // the database. If this channel has been closed, or the outpoint is
// unknown, then we'll return an error // unknown, then we'll return an error
unpackedBackup, err := chanbackup.FetchBackupForChan( unpackedBackup, err := chanbackup.FetchBackupForChan(
chanPoint, r.server.chanStateDB, chanPoint, r.server.chanStateDB, r.server.addrSource,
) )
if err != nil { if err != nil {
return nil, err return nil, err
@@ -6649,7 +6647,7 @@ func (r *rpcServer) ExportAllChannelBackups(ctx context.Context,
// First, we'll attempt to read back ups for ALL currently opened // First, we'll attempt to read back ups for ALL currently opened
// channels from disk. // channels from disk.
allUnpackedBackups, err := chanbackup.FetchStaticChanBackups( allUnpackedBackups, err := chanbackup.FetchStaticChanBackups(
r.server.chanStateDB, r.server.chanStateDB, r.server.addrSource,
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to fetch all static chan "+ return nil, fmt.Errorf("unable to fetch all static chan "+
@@ -6776,7 +6774,7 @@ func (r *rpcServer) SubscribeChannelBackups(req *lnrpc.ChannelBackupSubscription
// we'll obtains the current set of single channel // we'll obtains the current set of single channel
// backups from disk. // backups from disk.
chanBackups, err := chanbackup.FetchStaticChanBackups( chanBackups, err := chanbackup.FetchStaticChanBackups(
r.server.chanStateDB, r.server.chanStateDB, r.server.addrSource,
) )
if err != nil { if err != nil {
return fmt.Errorf("unable to fetch all "+ return fmt.Errorf("unable to fetch all "+

View File

@@ -222,7 +222,13 @@ type server struct {
graphDB *channeldb.ChannelGraph graphDB *channeldb.ChannelGraph
chanStateDB *channeldb.DB chanStateDB *channeldb.ChannelStateDB
addrSource chanbackup.AddressSource
// miscDB is the DB that contains all "other" databases within the main
// channel DB that haven't been separated out yet.
miscDB *channeldb.DB
htlcSwitch *htlcswitch.Switch htlcSwitch *htlcswitch.Switch
@@ -432,14 +438,18 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
s := &server{ s := &server{
cfg: cfg, cfg: cfg,
graphDB: dbs.graphDB.ChannelGraph(), graphDB: dbs.graphDB.ChannelGraph(),
chanStateDB: dbs.chanStateDB, chanStateDB: dbs.chanStateDB.ChannelStateDB(),
addrSource: dbs.chanStateDB,
miscDB: dbs.chanStateDB,
cc: cc, cc: cc,
sigPool: lnwallet.NewSigPool(cfg.Workers.Sig, cc.Signer), sigPool: lnwallet.NewSigPool(cfg.Workers.Sig, cc.Signer),
writePool: writePool, writePool: writePool,
readPool: readPool, readPool: readPool,
chansToRestore: chansToRestore, chansToRestore: chansToRestore,
channelNotifier: channelnotifier.New(dbs.chanStateDB), channelNotifier: channelnotifier.New(
dbs.chanStateDB.ChannelStateDB(),
),
identityECDH: nodeKeyECDH, identityECDH: nodeKeyECDH,
nodeSigner: netann.NewNodeSigner(nodeKeySigner), nodeSigner: netann.NewNodeSigner(nodeKeySigner),
@@ -494,7 +504,9 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
thresholdMSats := lnwire.NewMSatFromSatoshis(thresholdSats) thresholdMSats := lnwire.NewMSatFromSatoshis(thresholdSats)
s.htlcSwitch, err = htlcswitch.New(htlcswitch.Config{ s.htlcSwitch, err = htlcswitch.New(htlcswitch.Config{
DB: dbs.chanStateDB, DB: dbs.chanStateDB,
FetchAllOpenChannels: s.chanStateDB.FetchAllOpenChannels,
FetchClosedChannels: s.chanStateDB.FetchClosedChannels,
LocalChannelClose: func(pubKey []byte, LocalChannelClose: func(pubKey []byte,
request *htlcswitch.ChanClose) { request *htlcswitch.ChanClose) {
@@ -537,7 +549,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
MessageSigner: s.nodeSigner, MessageSigner: s.nodeSigner,
IsChannelActive: s.htlcSwitch.HasActiveLink, IsChannelActive: s.htlcSwitch.HasActiveLink,
ApplyChannelUpdate: s.applyChannelUpdate, ApplyChannelUpdate: s.applyChannelUpdate,
DB: dbs.chanStateDB, DB: s.chanStateDB,
Graph: dbs.graphDB.ChannelGraph(), Graph: dbs.graphDB.ChannelGraph(),
} }
@@ -702,9 +714,9 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
return nil, err return nil, err
} }
queryBandwidth := func(edge *channeldb.ChannelEdgeInfo) lnwire.MilliSatoshi { queryBandwidth := func(c *channeldb.DirectedChannel) lnwire.MilliSatoshi {
cid := lnwire.NewChanIDFromOutPoint(&edge.ChannelPoint) cid := lnwire.NewShortChanIDFromInt(c.ChannelID)
link, err := s.htlcSwitch.GetLink(cid) link, err := s.htlcSwitch.GetLinkByShortID(cid)
if err != nil { if err != nil {
// If the link isn't online, then we'll report // If the link isn't online, then we'll report
// that it has zero bandwidth to the router. // that it has zero bandwidth to the router.
@@ -768,8 +780,12 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
MinProbability: routingConfig.MinRouteProbability, MinProbability: routingConfig.MinRouteProbability,
} }
cachedGraph, err := routing.NewCachedGraph(chanGraph)
if err != nil {
return nil, err
}
paymentSessionSource := &routing.SessionSource{ paymentSessionSource := &routing.SessionSource{
Graph: chanGraph, Graph: cachedGraph,
MissionControl: s.missionControl, MissionControl: s.missionControl,
QueryBandwidth: queryBandwidth, QueryBandwidth: queryBandwidth,
PathFindingConfig: pathFindingConfig, PathFindingConfig: pathFindingConfig,
@@ -805,11 +821,11 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
} }
chanSeries := discovery.NewChanSeries(s.graphDB) chanSeries := discovery.NewChanSeries(s.graphDB)
gossipMessageStore, err := discovery.NewMessageStore(s.chanStateDB) gossipMessageStore, err := discovery.NewMessageStore(dbs.chanStateDB)
if err != nil { if err != nil {
return nil, err return nil, err
} }
waitingProofStore, err := channeldb.NewWaitingProofStore(s.chanStateDB) waitingProofStore, err := channeldb.NewWaitingProofStore(dbs.chanStateDB)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -891,8 +907,8 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
s.utxoNursery = contractcourt.NewUtxoNursery(&contractcourt.NurseryConfig{ s.utxoNursery = contractcourt.NewUtxoNursery(&contractcourt.NurseryConfig{
ChainIO: cc.ChainIO, ChainIO: cc.ChainIO,
ConfDepth: 1, ConfDepth: 1,
FetchClosedChannels: dbs.chanStateDB.FetchClosedChannels, FetchClosedChannels: s.chanStateDB.FetchClosedChannels,
FetchClosedChannel: dbs.chanStateDB.FetchClosedChannel, FetchClosedChannel: s.chanStateDB.FetchClosedChannel,
Notifier: cc.ChainNotifier, Notifier: cc.ChainNotifier,
PublishTransaction: cc.Wallet.PublishTransaction, PublishTransaction: cc.Wallet.PublishTransaction,
Store: utxnStore, Store: utxnStore,
@@ -1018,7 +1034,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
s.breachArbiter = contractcourt.NewBreachArbiter(&contractcourt.BreachConfig{ s.breachArbiter = contractcourt.NewBreachArbiter(&contractcourt.BreachConfig{
CloseLink: closeLink, CloseLink: closeLink,
DB: dbs.chanStateDB, DB: s.chanStateDB,
Estimator: s.cc.FeeEstimator, Estimator: s.cc.FeeEstimator,
GenSweepScript: newSweepPkScriptGen(cc.Wallet), GenSweepScript: newSweepPkScriptGen(cc.Wallet),
Notifier: cc.ChainNotifier, Notifier: cc.ChainNotifier,
@@ -1075,7 +1091,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
FindChannel: func(chanID lnwire.ChannelID) ( FindChannel: func(chanID lnwire.ChannelID) (
*channeldb.OpenChannel, error) { *channeldb.OpenChannel, error) {
dbChannels, err := dbs.chanStateDB.FetchAllChannels() dbChannels, err := s.chanStateDB.FetchAllChannels()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1247,10 +1263,12 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
// static backup of the latest channel state. // static backup of the latest channel state.
chanNotifier := &channelNotifier{ chanNotifier := &channelNotifier{
chanNotifier: s.channelNotifier, chanNotifier: s.channelNotifier,
addrs: s.chanStateDB, addrs: dbs.chanStateDB,
} }
backupFile := chanbackup.NewMultiFile(cfg.BackupFilePath) backupFile := chanbackup.NewMultiFile(cfg.BackupFilePath)
startingChans, err := chanbackup.FetchStaticChanBackups(s.chanStateDB) startingChans, err := chanbackup.FetchStaticChanBackups(
s.chanStateDB, s.addrSource,
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1275,8 +1293,8 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
}, },
GetOpenChannels: s.chanStateDB.FetchAllOpenChannels, GetOpenChannels: s.chanStateDB.FetchAllOpenChannels,
Clock: clock.NewDefaultClock(), Clock: clock.NewDefaultClock(),
ReadFlapCount: s.chanStateDB.ReadFlapCount, ReadFlapCount: s.miscDB.ReadFlapCount,
WriteFlapCount: s.chanStateDB.WriteFlapCounts, WriteFlapCount: s.miscDB.WriteFlapCounts,
FlapCountTicker: ticker.New(chanfitness.FlapCountFlushRate), FlapCountTicker: ticker.New(chanfitness.FlapCountFlushRate),
}) })
@@ -2531,7 +2549,7 @@ func (s *server) establishPersistentConnections() error {
// Iterate through the list of LinkNodes to find addresses we should // Iterate through the list of LinkNodes to find addresses we should
// attempt to connect to based on our set of previous connections. Set // attempt to connect to based on our set of previous connections. Set
// the reconnection port to the default peer port. // the reconnection port to the default peer port.
linkNodes, err := s.chanStateDB.FetchAllLinkNodes() linkNodes, err := s.chanStateDB.LinkNodeDB().FetchAllLinkNodes()
if err != nil && err != channeldb.ErrLinkNodesNotFound { if err != nil && err != channeldb.ErrLinkNodesNotFound {
return err return err
} }
@@ -3911,7 +3929,7 @@ func (s *server) fetchNodeAdvertisedAddr(pub *btcec.PublicKey) (net.Addr, error)
return nil, err return nil, err
} }
node, err := s.graphDB.FetchLightningNode(nil, vertex) node, err := s.graphDB.FetchLightningNode(vertex)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -93,7 +93,7 @@ func (s *subRPCServerConfigs) PopulateDependencies(cfg *Config,
routerBackend *routerrpc.RouterBackend, routerBackend *routerrpc.RouterBackend,
nodeSigner *netann.NodeSigner, nodeSigner *netann.NodeSigner,
graphDB *channeldb.ChannelGraph, graphDB *channeldb.ChannelGraph,
chanStateDB *channeldb.DB, chanStateDB *channeldb.ChannelStateDB,
sweeper *sweep.UtxoSweeper, sweeper *sweep.UtxoSweeper,
tower *watchtower.Standalone, tower *watchtower.Standalone,
towerClient wtclient.Client, towerClient wtclient.Client,