diff --git a/channeldb/channel_test.go b/channeldb/channel_test.go index 6bd73d1f..656a885b 100644 --- a/channeldb/channel_test.go +++ b/channeldb/channel_test.go @@ -2,10 +2,8 @@ package channeldb import ( "bytes" - "io/ioutil" "math/rand" "net" - "os" "reflect" "runtime" "testing" @@ -86,40 +84,6 @@ var ( } ) -// makeTestDB creates a new instance of the ChannelDB for testing purposes. A -// callback which cleans up the created temporary directories is also returned -// and intended to be executed after the test completes. -func makeTestDB() (*DB, func(), error) { - // First, create a temporary directory to be used for the duration of - // this test. - tempDirName, err := ioutil.TempDir("", "channeldb") - if err != nil { - return nil, nil, err - } - - // Next, create channeldb for the first time. - backend, backendCleanup, err := kvdb.GetTestBackend(tempDirName, "cdb") - if err != nil { - backendCleanup() - return nil, nil, err - } - - cdb, err := CreateWithBackend(backend, OptionClock(testClock)) - if err != nil { - backendCleanup() - os.RemoveAll(tempDirName) - return nil, nil, err - } - - cleanUp := func() { - cdb.Close() - backendCleanup() - os.RemoveAll(tempDirName) - } - - return cdb, cleanUp, nil -} - // testChannelParams is a struct which details the specifics of how a channel // should be created. type testChannelParams struct { @@ -403,7 +367,7 @@ func createTestChannelState(t *testing.T, cdb *DB) *OpenChannel { func TestOpenChannelPutGetDelete(t *testing.T) { t.Parallel() - cdb, cleanUp, err := makeTestDB() + cdb, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } @@ -552,7 +516,7 @@ func TestOptionalShutdown(t *testing.T) { test := test t.Run(test.name, func(t *testing.T) { - cdb, cleanUp, err := makeTestDB() + cdb, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } @@ -609,7 +573,7 @@ func assertCommitmentEqual(t *testing.T, a, b *ChannelCommitment) { func TestChannelStateTransition(t *testing.T) { t.Parallel() - cdb, cleanUp, err := makeTestDB() + cdb, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } @@ -914,7 +878,7 @@ func TestChannelStateTransition(t *testing.T) { func TestFetchPendingChannels(t *testing.T) { t.Parallel() - cdb, cleanUp, err := makeTestDB() + cdb, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } @@ -993,7 +957,7 @@ func TestFetchPendingChannels(t *testing.T) { func TestFetchClosedChannels(t *testing.T) { t.Parallel() - cdb, cleanUp, err := makeTestDB() + cdb, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } @@ -1084,7 +1048,7 @@ func TestFetchWaitingCloseChannels(t *testing.T) { // We'll start by creating two channels within our test database. One of // them will have their funding transaction confirmed on-chain, while // the other one will remain unconfirmed. - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } @@ -1199,7 +1163,7 @@ func TestFetchWaitingCloseChannels(t *testing.T) { func TestRefreshShortChanID(t *testing.T) { t.Parallel() - cdb, cleanUp, err := makeTestDB() + cdb, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } @@ -1347,7 +1311,7 @@ func TestCloseInitiator(t *testing.T) { t.Run(test.name, func(t *testing.T) { t.Parallel() - cdb, cleanUp, err := makeTestDB() + cdb, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) @@ -1392,7 +1356,7 @@ func TestCloseInitiator(t *testing.T) { // TestCloseChannelStatus tests setting of a channel status on the historical // channel on channel close. func TestCloseChannelStatus(t *testing.T) { - cdb, cleanUp, err := makeTestDB() + cdb, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) @@ -1538,7 +1502,7 @@ func TestBalanceAtHeight(t *testing.T) { t.Run(test.name, func(t *testing.T) { t.Parallel() - cdb, cleanUp, err := makeTestDB() + cdb, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) diff --git a/channeldb/db.go b/channeldb/db.go index 06d90560..a4c5c516 100644 --- a/channeldb/db.go +++ b/channeldb/db.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/binary" "fmt" + "io/ioutil" "net" "os" @@ -1260,3 +1261,37 @@ func (db *DB) FetchHistoricalChannel(outPoint *wire.OutPoint) (*OpenChannel, err return channel, nil } + +// MakeTestDB creates a new instance of the ChannelDB for testing purposes. +// A callback which cleans up the created temporary directories is also +// returned and intended to be executed after the test completes. +func MakeTestDB(modifiers ...OptionModifier) (*DB, func(), error) { + // First, create a temporary directory to be used for the duration of + // this test. + tempDirName, err := ioutil.TempDir("", "channeldb") + if err != nil { + return nil, nil, err + } + + // Next, create channeldb for the first time. + backend, backendCleanup, err := kvdb.GetTestBackend(tempDirName, "cdb") + if err != nil { + backendCleanup() + return nil, nil, err + } + + cdb, err := CreateWithBackend(backend, modifiers...) + if err != nil { + backendCleanup() + os.RemoveAll(tempDirName) + return nil, nil, err + } + + cleanUp := func() { + cdb.Close() + backendCleanup() + os.RemoveAll(tempDirName) + } + + return cdb, cleanUp, nil +} diff --git a/channeldb/db_test.go b/channeldb/db_test.go index b05ac115..a86eee69 100644 --- a/channeldb/db_test.go +++ b/channeldb/db_test.go @@ -115,7 +115,7 @@ func TestFetchClosedChannelForID(t *testing.T) { const numChans = 101 - cdb, cleanUp, err := makeTestDB() + cdb, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } @@ -186,7 +186,7 @@ func TestFetchClosedChannelForID(t *testing.T) { func TestAddrsForNode(t *testing.T) { t.Parallel() - cdb, cleanUp, err := makeTestDB() + cdb, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } @@ -247,7 +247,7 @@ func TestAddrsForNode(t *testing.T) { func TestFetchChannel(t *testing.T) { t.Parallel() - cdb, cleanUp, err := makeTestDB() + cdb, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } @@ -351,7 +351,7 @@ func genRandomChannelShell() (*ChannelShell, error) { func TestRestoreChannelShells(t *testing.T) { t.Parallel() - cdb, cleanUp, err := makeTestDB() + cdb, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } @@ -445,7 +445,7 @@ func TestRestoreChannelShells(t *testing.T) { func TestAbandonChannel(t *testing.T) { t.Parallel() - cdb, cleanUp, err := makeTestDB() + cdb, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } @@ -618,7 +618,7 @@ func TestFetchChannels(t *testing.T) { t.Run(test.name, func(t *testing.T) { t.Parallel() - cdb, cleanUp, err := makeTestDB() + cdb, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test "+ "database: %v", err) @@ -687,7 +687,7 @@ func TestFetchChannels(t *testing.T) { // TestFetchHistoricalChannel tests lookup of historical channels. func TestFetchHistoricalChannel(t *testing.T) { - cdb, cleanUp, err := makeTestDB() + cdb, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } diff --git a/channeldb/forwarding_log_test.go b/channeldb/forwarding_log_test.go index cc06e886..07dfc902 100644 --- a/channeldb/forwarding_log_test.go +++ b/channeldb/forwarding_log_test.go @@ -19,7 +19,7 @@ func TestForwardingLogBasicStorageAndQuery(t *testing.T) { // First, we'll set up a test database, and use that to instantiate the // forwarding event log that we'll be using for the duration of the // test. - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB() defer cleanUp() if err != nil { t.Fatalf("unable to make test db: %v", err) @@ -91,7 +91,7 @@ func TestForwardingLogQueryOptions(t *testing.T) { // First, we'll set up a test database, and use that to instantiate the // forwarding event log that we'll be using for the duration of the // test. - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB() defer cleanUp() if err != nil { t.Fatalf("unable to make test db: %v", err) @@ -196,7 +196,7 @@ func TestForwardingLogQueryLimit(t *testing.T) { // First, we'll set up a test database, and use that to instantiate the // forwarding event log that we'll be using for the duration of the // test. - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB() defer cleanUp() if err != nil { t.Fatalf("unable to make test db: %v", err) diff --git a/channeldb/graph_test.go b/channeldb/graph_test.go index a6c1fb0d..71edc8f8 100644 --- a/channeldb/graph_test.go +++ b/channeldb/graph_test.go @@ -73,7 +73,7 @@ func createTestVertex(db *DB) (*LightningNode, error) { func TestNodeInsertionAndDeletion(t *testing.T) { t.Parallel() - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) @@ -139,7 +139,7 @@ func TestNodeInsertionAndDeletion(t *testing.T) { func TestPartialNode(t *testing.T) { t.Parallel() - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) @@ -201,7 +201,7 @@ func TestPartialNode(t *testing.T) { func TestAliasLookup(t *testing.T) { t.Parallel() - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) @@ -255,7 +255,7 @@ func TestAliasLookup(t *testing.T) { func TestSourceNode(t *testing.T) { t.Parallel() - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) @@ -296,7 +296,7 @@ func TestSourceNode(t *testing.T) { func TestEdgeInsertionDeletion(t *testing.T) { t.Parallel() - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) @@ -431,7 +431,7 @@ func createEdge(height, txIndex uint32, txPosition uint16, outPointIndex uint32, func TestDisconnectBlockAtHeight(t *testing.T) { t.Parallel() - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) @@ -718,7 +718,7 @@ func createChannelEdge(db *DB, node1, node2 *LightningNode) (*ChannelEdgeInfo, func TestEdgeInfoUpdates(t *testing.T) { t.Parallel() - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) @@ -848,7 +848,7 @@ func newEdgePolicy(chanID uint64, op wire.OutPoint, db *DB, func TestGraphTraversal(t *testing.T) { t.Parallel() - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) @@ -1109,7 +1109,7 @@ func assertChanViewEqualChanPoints(t *testing.T, a []EdgePoint, b []*wire.OutPoi func TestGraphPruning(t *testing.T) { t.Parallel() - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) @@ -1317,7 +1317,7 @@ func TestGraphPruning(t *testing.T) { func TestHighestChanID(t *testing.T) { t.Parallel() - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) @@ -1394,7 +1394,7 @@ func TestHighestChanID(t *testing.T) { func TestChanUpdatesInHorizon(t *testing.T) { t.Parallel() - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) @@ -1570,7 +1570,7 @@ func TestChanUpdatesInHorizon(t *testing.T) { func TestNodeUpdatesInHorizon(t *testing.T) { t.Parallel() - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) @@ -1693,7 +1693,7 @@ func TestNodeUpdatesInHorizon(t *testing.T) { func TestFilterKnownChanIDs(t *testing.T) { t.Parallel() - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) @@ -1810,7 +1810,7 @@ func TestFilterKnownChanIDs(t *testing.T) { func TestFilterChannelRange(t *testing.T) { t.Parallel() - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) @@ -1929,7 +1929,7 @@ func TestFilterChannelRange(t *testing.T) { func TestFetchChanInfos(t *testing.T) { t.Parallel() - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) @@ -2057,7 +2057,7 @@ func TestFetchChanInfos(t *testing.T) { func TestIncompleteChannelPolicies(t *testing.T) { t.Parallel() - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) @@ -2172,7 +2172,7 @@ func TestIncompleteChannelPolicies(t *testing.T) { func TestChannelEdgePruningUpdateIndexDeletion(t *testing.T) { t.Parallel() - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) @@ -2327,7 +2327,7 @@ func TestChannelEdgePruningUpdateIndexDeletion(t *testing.T) { func TestPruneGraphNodes(t *testing.T) { t.Parallel() - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) @@ -2411,7 +2411,7 @@ func TestPruneGraphNodes(t *testing.T) { func TestAddChannelEdgeShellNodes(t *testing.T) { t.Parallel() - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) @@ -2465,7 +2465,7 @@ func TestAddChannelEdgeShellNodes(t *testing.T) { func TestNodePruningUpdateIndexDeletion(t *testing.T) { t.Parallel() - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) @@ -2535,7 +2535,7 @@ func TestNodeIsPublic(t *testing.T) { // We'll need to create a separate database and channel graph for each // participant to replicate real-world scenarios (private edges being in // some graphs but not others, etc.). - aliceDB, cleanUp, err := makeTestDB() + aliceDB, cleanUp, err := MakeTestDB() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) @@ -2549,7 +2549,7 @@ func TestNodeIsPublic(t *testing.T) { t.Fatalf("unable to set source node: %v", err) } - bobDB, cleanUp, err := makeTestDB() + bobDB, cleanUp, err := MakeTestDB() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) @@ -2563,7 +2563,7 @@ func TestNodeIsPublic(t *testing.T) { t.Fatalf("unable to set source node: %v", err) } - carolDB, cleanUp, err := makeTestDB() + carolDB, cleanUp, err := MakeTestDB() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) @@ -2684,7 +2684,7 @@ func TestNodeIsPublic(t *testing.T) { func TestDisabledChannelIDs(t *testing.T) { t.Parallel() - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } @@ -2782,7 +2782,7 @@ func TestDisabledChannelIDs(t *testing.T) { func TestEdgePolicyMissingMaxHtcl(t *testing.T) { t.Parallel() - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB() defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) @@ -2962,7 +2962,7 @@ func TestGraphZombieIndex(t *testing.T) { t.Parallel() // We'll start by creating our test graph along with a test edge. - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB() defer cleanUp() if err != nil { t.Fatalf("unable to create test database: %v", err) @@ -3151,7 +3151,7 @@ func TestLightningNodeSigVerification(t *testing.T) { } // Create a LightningNode from the same private key. - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } diff --git a/channeldb/invoice_test.go b/channeldb/invoice_test.go index 10148917..64e2dbe6 100644 --- a/channeldb/invoice_test.go +++ b/channeldb/invoice_test.go @@ -136,7 +136,7 @@ func TestInvoiceWorkflow(t *testing.T) { } func testInvoiceWorkflow(t *testing.T, test invWorkflowTest) { - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB() defer cleanUp() if err != nil { t.Fatalf("unable to make test db: %v", err) @@ -290,7 +290,7 @@ func testInvoiceWorkflow(t *testing.T, test invWorkflowTest) { // TestAddDuplicatePayAddr asserts that the payment addresses of inserted // invoices are unique. func TestAddDuplicatePayAddr(t *testing.T) { - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB() defer cleanUp() require.NoError(t, err) @@ -317,7 +317,7 @@ func TestAddDuplicatePayAddr(t *testing.T) { // addresses to be inserted if they are blank to support JIT legacy keysend // invoices. func TestAddDuplicateKeysendPayAddr(t *testing.T) { - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB() defer cleanUp() require.NoError(t, err) @@ -358,7 +358,7 @@ func TestAddDuplicateKeysendPayAddr(t *testing.T) { // TestInvRefEquivocation asserts that retrieving or updating an invoice using // an equivocating InvoiceRef results in ErrInvRefEquivocation. func TestInvRefEquivocation(t *testing.T) { - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB() defer cleanUp() require.NoError(t, err) @@ -398,7 +398,7 @@ func TestInvRefEquivocation(t *testing.T) { func TestInvoiceCancelSingleHtlc(t *testing.T) { t.Parallel() - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB() defer cleanUp() if err != nil { t.Fatalf("unable to make test db: %v", err) @@ -472,7 +472,7 @@ func TestInvoiceCancelSingleHtlc(t *testing.T) { func TestInvoiceAddTimeSeries(t *testing.T) { t.Parallel() - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB(OptionClock(testClock)) defer cleanUp() if err != nil { t.Fatalf("unable to make test db: %v", err) @@ -627,7 +627,7 @@ func TestInvoiceAddTimeSeries(t *testing.T) { func TestFetchAllInvoicesWithPaymentHash(t *testing.T) { t.Parallel() - db, cleanup, err := makeTestDB() + db, cleanup, err := MakeTestDB() defer cleanup() if err != nil { t.Fatalf("unable to make test db: %v", err) @@ -732,7 +732,7 @@ func TestFetchAllInvoicesWithPaymentHash(t *testing.T) { func TestDuplicateSettleInvoice(t *testing.T) { t.Parallel() - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB(OptionClock(testClock)) defer cleanUp() if err != nil { t.Fatalf("unable to make test db: %v", err) @@ -797,7 +797,7 @@ func TestDuplicateSettleInvoice(t *testing.T) { func TestQueryInvoices(t *testing.T) { t.Parallel() - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB(OptionClock(testClock)) defer cleanUp() if err != nil { t.Fatalf("unable to make test db: %v", err) @@ -1112,7 +1112,7 @@ func getUpdateInvoice(amt lnwire.MilliSatoshi) InvoiceUpdateCallback { func TestCustomRecords(t *testing.T) { t.Parallel() - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB() defer cleanUp() if err != nil { t.Fatalf("unable to make test db: %v", err) diff --git a/channeldb/kvdb/etcd/bucket.go b/channeldb/kvdb/etcd/bucket.go index 3bc087db..8a1ff071 100644 --- a/channeldb/kvdb/etcd/bucket.go +++ b/channeldb/kvdb/etcd/bucket.go @@ -11,9 +11,9 @@ const ( ) var ( - bucketPrefix = []byte("b") - valuePrefix = []byte("v") - sequencePrefix = []byte("$") + valuePostfix = []byte{0x00} + bucketPostfix = []byte{0xFF} + sequencePrefix = []byte("$seq$") ) // makeBucketID returns a deterministic key for the passed byte slice. @@ -28,52 +28,65 @@ func isValidBucketID(s []byte) bool { return len(s) == bucketIDLength } -// makeKey concatenates prefix, parent and key into one byte slice. -// The prefix indicates the use of this key (whether bucket, value or sequence), -// while parentID refers to the parent bucket. -func makeKey(prefix, parent, key []byte) []byte { - keyBuf := make([]byte, len(prefix)+len(parent)+len(key)) - copy(keyBuf, prefix) - copy(keyBuf[len(prefix):], parent) - copy(keyBuf[len(prefix)+len(parent):], key) +// makeKey concatenates parent, key and postfix into one byte slice. +// The postfix indicates the use of this key (whether bucket or value), while +// parent refers to the parent bucket. +func makeKey(parent, key, postfix []byte) []byte { + keyBuf := make([]byte, len(parent)+len(key)+len(postfix)) + copy(keyBuf, parent) + copy(keyBuf[len(parent):], key) + copy(keyBuf[len(parent)+len(key):], postfix) return keyBuf } -// makePrefix concatenates prefix with parent into one byte slice. -func makePrefix(prefix []byte, parent []byte) []byte { - prefixBuf := make([]byte, len(prefix)+len(parent)) - copy(prefixBuf, prefix) - copy(prefixBuf[len(prefix):], parent) - - return prefixBuf -} - // makeBucketKey returns a bucket key from the passed parent bucket id and // the key. func makeBucketKey(parent []byte, key []byte) []byte { - return makeKey(bucketPrefix, parent, key) + return makeKey(parent, key, bucketPostfix) } // makeValueKey returns a value key from the passed parent bucket id and // the key. func makeValueKey(parent []byte, key []byte) []byte { - return makeKey(valuePrefix, parent, key) + return makeKey(parent, key, valuePostfix) } // makeSequenceKey returns a sequence key of the passed parent bucket id. func makeSequenceKey(parent []byte) []byte { - return makeKey(sequencePrefix, parent, nil) + keyBuf := make([]byte, len(sequencePrefix)+len(parent)) + copy(keyBuf, sequencePrefix) + copy(keyBuf[len(sequencePrefix):], parent) + return keyBuf } -// makeBucketPrefix returns the bucket prefix of the passed parent bucket id. -// This prefix is used for all sub buckets. -func makeBucketPrefix(parent []byte) []byte { - return makePrefix(bucketPrefix, parent) +// isBucketKey returns true if the passed key is a bucket key, meaning it +// keys a bucket name. +func isBucketKey(key string) bool { + if len(key) < bucketIDLength+1 { + return false + } + + return key[len(key)-1] == bucketPostfix[0] } -// makeValuePrefix returns the value prefix of the passed parent bucket id. -// This prefix is used for all key/values in the bucket. -func makeValuePrefix(parent []byte) []byte { - return makePrefix(valuePrefix, parent) +// getKey chops out the key from the raw key (by removing the bucket id +// prefixing the key and the postfix indicating whether it is a bucket or +// a value key) +func getKey(rawKey string) []byte { + return []byte(rawKey[bucketIDLength : len(rawKey)-1]) +} + +// getKeyVal chops out the key from the raw key (by removing the bucket id +// prefixing the key and the postfix indicating whether it is a bucket or +// a value key) and also returns the appropriate value for the key, which is +// nil in case of buckets (or the set value otherwise). +func getKeyVal(kv *KV) ([]byte, []byte) { + var val []byte + + if !isBucketKey(kv.key) { + val = []byte(kv.val) + } + + return getKey(kv.key), val } diff --git a/channeldb/kvdb/etcd/db_test.go b/channeldb/kvdb/etcd/db_test.go index 155d912e..c4332db8 100644 --- a/channeldb/kvdb/etcd/db_test.go +++ b/channeldb/kvdb/etcd/db_test.go @@ -8,7 +8,7 @@ import ( "testing" "github.com/btcsuite/btcwallet/walletdb" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestCopy(t *testing.T) { @@ -18,30 +18,30 @@ func TestCopy(t *testing.T) { defer f.Cleanup() db, err := newEtcdBackend(f.BackendConfig()) - assert.NoError(t, err) + require.NoError(t, err) err = db.Update(func(tx walletdb.ReadWriteTx) error { // "apple" apple, err := tx.CreateTopLevelBucket([]byte("apple")) - assert.NoError(t, err) - assert.NotNil(t, apple) + require.NoError(t, err) + require.NotNil(t, apple) - assert.NoError(t, apple.Put([]byte("key"), []byte("val"))) + require.NoError(t, apple.Put([]byte("key"), []byte("val"))) return nil }) // Expect non-zero copy. var buf bytes.Buffer - assert.NoError(t, db.Copy(&buf)) - assert.Greater(t, buf.Len(), 0) - assert.Nil(t, err) + require.NoError(t, db.Copy(&buf)) + require.Greater(t, buf.Len(), 0) + require.Nil(t, err) expected := map[string]string{ bkey("apple"): bval("apple"), vkey("key", "apple"): "val", } - assert.Equal(t, expected, f.Dump()) + require.Equal(t, expected, f.Dump()) } func TestAbortContext(t *testing.T) { @@ -57,19 +57,19 @@ func TestAbortContext(t *testing.T) { // Pass abort context and abort right away. db, err := newEtcdBackend(config) - assert.NoError(t, err) + require.NoError(t, err) cancel() // Expect that the update will fail. err = db.Update(func(tx walletdb.ReadWriteTx) error { _, err := tx.CreateTopLevelBucket([]byte("bucket")) - assert.NoError(t, err) + require.NoError(t, err) return nil }) - assert.Error(t, err, "context canceled") + require.Error(t, err, "context canceled") // No changes in the DB. - assert.Equal(t, map[string]string{}, f.Dump()) + require.Equal(t, map[string]string{}, f.Dump()) } diff --git a/channeldb/kvdb/etcd/driver_test.go b/channeldb/kvdb/etcd/driver_test.go index 365eda7a..ea4196ef 100644 --- a/channeldb/kvdb/etcd/driver_test.go +++ b/channeldb/kvdb/etcd/driver_test.go @@ -6,25 +6,25 @@ import ( "testing" "github.com/btcsuite/btcwallet/walletdb" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestOpenCreateFailure(t *testing.T) { t.Parallel() db, err := walletdb.Open(dbType) - assert.Error(t, err) - assert.Nil(t, db) + require.Error(t, err) + require.Nil(t, db) db, err = walletdb.Open(dbType, "wrong") - assert.Error(t, err) - assert.Nil(t, db) + require.Error(t, err) + require.Nil(t, db) db, err = walletdb.Create(dbType) - assert.Error(t, err) - assert.Nil(t, db) + require.Error(t, err) + require.Nil(t, db) db, err = walletdb.Create(dbType, "wrong") - assert.Error(t, err) - assert.Nil(t, db) + require.Error(t, err) + require.Nil(t, db) } diff --git a/channeldb/kvdb/etcd/readwrite_bucket.go b/channeldb/kvdb/etcd/readwrite_bucket.go index e60d2cec..dafab5ff 100644 --- a/channeldb/kvdb/etcd/readwrite_bucket.go +++ b/channeldb/kvdb/etcd/readwrite_bucket.go @@ -3,7 +3,6 @@ package etcd import ( - "bytes" "strconv" "github.com/btcsuite/btcwallet/walletdb" @@ -24,11 +23,6 @@ type readWriteBucket struct { // newReadWriteBucket creates a new rw bucket with the passed transaction // and bucket id. func newReadWriteBucket(tx *readWriteTx, key, id []byte) *readWriteBucket { - if !bytes.Equal(id, tx.rootBucketID[:]) { - // Add the bucket key/value to the lock set. - tx.lock(string(key), string(id)) - } - return &readWriteBucket{ id: id, tx: tx, @@ -46,44 +40,23 @@ func (b *readWriteBucket) NestedReadBucket(key []byte) walletdb.ReadBucket { // is nil, but it does not include the key/value pairs within those // nested buckets. func (b *readWriteBucket) ForEach(cb func(k, v []byte) error) error { - prefix := makeValuePrefix(b.id) - prefixLen := len(prefix) + prefix := string(b.id) // Get the first matching key that is in the bucket. - kv, err := b.tx.stm.First(string(prefix)) + kv, err := b.tx.stm.First(prefix) if err != nil { return err } for kv != nil { - if err := cb([]byte(kv.key[prefixLen:]), []byte(kv.val)); err != nil { + key, val := getKeyVal(kv) + + if err := cb(key, val); err != nil { return err } // Step to the next key. - kv, err = b.tx.stm.Next(string(prefix), kv.key) - if err != nil { - return err - } - } - - // Make a bucket prefix. This prefixes all sub buckets. - prefix = makeBucketPrefix(b.id) - prefixLen = len(prefix) - - // Get the first bucket. - kv, err = b.tx.stm.First(string(prefix)) - if err != nil { - return err - } - - for kv != nil { - if err := cb([]byte(kv.key[prefixLen:]), nil); err != nil { - return err - } - - // Step to the next bucket. - kv, err = b.tx.stm.Next(string(prefix), kv.key) + kv, err = b.tx.stm.Next(prefix, kv.key) if err != nil { return err } @@ -143,6 +116,20 @@ func (b *readWriteBucket) NestedReadWriteBucket(key []byte) walletdb.ReadWriteBu return newReadWriteBucket(b.tx, bucketKey, bucketVal) } +// assertNoValue checks if the value for the passed key exists. +func (b *readWriteBucket) assertNoValue(key []byte) error { + val, err := b.tx.stm.Get(string(makeValueKey(b.id, key))) + if err != nil { + return err + } + + if val != nil { + return walletdb.ErrIncompatibleValue + } + + return nil +} + // CreateBucket creates and returns a new nested bucket with the given // key. Returns ErrBucketExists if the bucket already exists, // ErrBucketNameRequired if the key is empty, or ErrIncompatibleValue @@ -168,11 +155,15 @@ func (b *readWriteBucket) CreateBucket(key []byte) ( return nil, walletdb.ErrBucketExists } + if err := b.assertNoValue(key); err != nil { + return nil, err + } + // Create a deterministic bucket id from the bucket key. newID := makeBucketID(bucketKey) // Create the bucket. - b.tx.put(string(bucketKey), string(newID[:])) + b.tx.stm.Put(string(bucketKey), string(newID[:])) return newReadWriteBucket(b.tx, bucketKey, newID[:]), nil } @@ -198,8 +189,12 @@ func (b *readWriteBucket) CreateBucketIfNotExists(key []byte) ( } if !isValidBucketID(bucketVal) { + if err := b.assertNoValue(key); err != nil { + return nil, err + } + newID := makeBucketID(bucketKey) - b.tx.put(string(bucketKey), string(newID[:])) + b.tx.stm.Put(string(bucketKey), string(newID[:])) return newReadWriteBucket(b.tx, bucketKey, newID[:]), nil } @@ -241,46 +236,31 @@ func (b *readWriteBucket) DeleteNestedBucket(key []byte) error { id := queue[0] queue = queue[1:] - // Delete values in the current bucket - valuePrefix := string(makeValuePrefix(id)) - - kv, err := b.tx.stm.First(valuePrefix) + kv, err := b.tx.stm.First(string(id)) if err != nil { return err } for kv != nil { - b.tx.del(kv.key) + b.tx.stm.Del(kv.key) - kv, err = b.tx.stm.Next(valuePrefix, kv.key) + if isBucketKey(kv.key) { + queue = append(queue, []byte(kv.val)) + } + + kv, err = b.tx.stm.Next(string(id), kv.key) if err != nil { return err } } - // Iterate sub buckets - bucketPrefix := string(makeBucketPrefix(id)) - - kv, err = b.tx.stm.First(bucketPrefix) - if err != nil { - return err - } - - for kv != nil { - // Delete sub bucket key. - b.tx.del(kv.key) - // Queue it for traversal. - queue = append(queue, []byte(kv.val)) - - kv, err = b.tx.stm.Next(bucketPrefix, kv.key) - if err != nil { - return err - } - } + // Finally delete the sequence key for the bucket. + b.tx.stm.Del(string(makeSequenceKey(id))) } - // Delete the top level bucket. - b.tx.del(bucketKey) + // Delete the top level bucket and sequence key. + b.tx.stm.Del(bucketKey) + b.tx.stm.Del(string(makeSequenceKey(bucketVal))) return nil } @@ -292,8 +272,17 @@ func (b *readWriteBucket) Put(key, value []byte) error { return walletdb.ErrKeyRequired } + val, err := b.tx.stm.Get(string(makeBucketKey(b.id, key))) + if err != nil { + return err + } + + if val != nil { + return walletdb.ErrIncompatibleValue + } + // Update the transaction with the new value. - b.tx.put(string(makeValueKey(b.id, key)), string(value)) + b.tx.stm.Put(string(makeValueKey(b.id, key)), string(value)) return nil } @@ -306,7 +295,7 @@ func (b *readWriteBucket) Delete(key []byte) error { } // Update the transaction to delete the key/value. - b.tx.del(string(makeValueKey(b.id, key))) + b.tx.stm.Del(string(makeValueKey(b.id, key))) return nil } @@ -336,7 +325,7 @@ func (b *readWriteBucket) SetSequence(v uint64) error { val := strconv.FormatUint(v, 10) // Update the transaction with the new value for the sequence key. - b.tx.put(string(makeSequenceKey(b.id)), val) + b.tx.stm.Put(string(makeSequenceKey(b.id)), val) return nil } diff --git a/channeldb/kvdb/etcd/readwrite_bucket_test.go b/channeldb/kvdb/etcd/readwrite_bucket_test.go index a3a5d620..2795dce3 100644 --- a/channeldb/kvdb/etcd/readwrite_bucket_test.go +++ b/channeldb/kvdb/etcd/readwrite_bucket_test.go @@ -8,7 +8,7 @@ import ( "testing" "github.com/btcsuite/btcwallet/walletdb" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestBucketCreation(t *testing.T) { @@ -18,70 +18,70 @@ func TestBucketCreation(t *testing.T) { defer f.Cleanup() db, err := newEtcdBackend(f.BackendConfig()) - assert.NoError(t, err) + require.NoError(t, err) err = db.Update(func(tx walletdb.ReadWriteTx) error { // empty bucket name b, err := tx.CreateTopLevelBucket(nil) - assert.Error(t, walletdb.ErrBucketNameRequired, err) - assert.Nil(t, b) + require.Error(t, walletdb.ErrBucketNameRequired, err) + require.Nil(t, b) // empty bucket name b, err = tx.CreateTopLevelBucket([]byte("")) - assert.Error(t, walletdb.ErrBucketNameRequired, err) - assert.Nil(t, b) + require.Error(t, walletdb.ErrBucketNameRequired, err) + require.Nil(t, b) // "apple" apple, err := tx.CreateTopLevelBucket([]byte("apple")) - assert.NoError(t, err) - assert.NotNil(t, apple) + require.NoError(t, err) + require.NotNil(t, apple) // Check bucket tx. - assert.Equal(t, tx, apple.Tx()) + require.Equal(t, tx, apple.Tx()) // "apple" already created b, err = tx.CreateTopLevelBucket([]byte("apple")) - assert.NoError(t, err) - assert.NotNil(t, b) + require.NoError(t, err) + require.NotNil(t, b) // "apple/banana" banana, err := apple.CreateBucket([]byte("banana")) - assert.NoError(t, err) - assert.NotNil(t, banana) + require.NoError(t, err) + require.NotNil(t, banana) banana, err = apple.CreateBucketIfNotExists([]byte("banana")) - assert.NoError(t, err) - assert.NotNil(t, banana) + require.NoError(t, err) + require.NotNil(t, banana) // Try creating "apple/banana" again b, err = apple.CreateBucket([]byte("banana")) - assert.Error(t, walletdb.ErrBucketExists, err) - assert.Nil(t, b) + require.Error(t, walletdb.ErrBucketExists, err) + require.Nil(t, b) // "apple/mango" mango, err := apple.CreateBucket([]byte("mango")) - assert.Nil(t, err) - assert.NotNil(t, mango) + require.Nil(t, err) + require.NotNil(t, mango) // "apple/banana/pear" pear, err := banana.CreateBucket([]byte("pear")) - assert.Nil(t, err) - assert.NotNil(t, pear) + require.Nil(t, err) + require.NotNil(t, pear) // empty bucket - assert.Nil(t, apple.NestedReadWriteBucket(nil)) - assert.Nil(t, apple.NestedReadWriteBucket([]byte(""))) + require.Nil(t, apple.NestedReadWriteBucket(nil)) + require.Nil(t, apple.NestedReadWriteBucket([]byte(""))) // "apple/pear" doesn't exist - assert.Nil(t, apple.NestedReadWriteBucket([]byte("pear"))) + require.Nil(t, apple.NestedReadWriteBucket([]byte("pear"))) // "apple/banana" exits - assert.NotNil(t, apple.NestedReadWriteBucket([]byte("banana"))) - assert.NotNil(t, apple.NestedReadBucket([]byte("banana"))) + require.NotNil(t, apple.NestedReadWriteBucket([]byte("banana"))) + require.NotNil(t, apple.NestedReadBucket([]byte("banana"))) return nil }) - assert.Nil(t, err) + require.Nil(t, err) expected := map[string]string{ bkey("apple"): bval("apple"), @@ -89,7 +89,7 @@ func TestBucketCreation(t *testing.T) { bkey("apple", "mango"): bval("apple", "mango"), bkey("apple", "banana", "pear"): bval("apple", "banana", "pear"), } - assert.Equal(t, expected, f.Dump()) + require.Equal(t, expected, f.Dump()) } func TestBucketDeletion(t *testing.T) { @@ -99,99 +99,99 @@ func TestBucketDeletion(t *testing.T) { defer f.Cleanup() db, err := newEtcdBackend(f.BackendConfig()) - assert.NoError(t, err) + require.NoError(t, err) err = db.Update(func(tx walletdb.ReadWriteTx) error { // "apple" apple, err := tx.CreateTopLevelBucket([]byte("apple")) - assert.Nil(t, err) - assert.NotNil(t, apple) + require.Nil(t, err) + require.NotNil(t, apple) // "apple/banana" banana, err := apple.CreateBucket([]byte("banana")) - assert.Nil(t, err) - assert.NotNil(t, banana) + require.Nil(t, err) + require.NotNil(t, banana) kvs := []KV{{"key1", "val1"}, {"key2", "val2"}, {"key3", "val3"}} for _, kv := range kvs { - assert.NoError(t, banana.Put([]byte(kv.key), []byte(kv.val))) - assert.Equal(t, []byte(kv.val), banana.Get([]byte(kv.key))) + require.NoError(t, banana.Put([]byte(kv.key), []byte(kv.val))) + require.Equal(t, []byte(kv.val), banana.Get([]byte(kv.key))) } // Delete a k/v from "apple/banana" - assert.NoError(t, banana.Delete([]byte("key2"))) + require.NoError(t, banana.Delete([]byte("key2"))) // Try getting/putting/deleting invalid k/v's. - assert.Nil(t, banana.Get(nil)) - assert.Error(t, walletdb.ErrKeyRequired, banana.Put(nil, []byte("val"))) - assert.Error(t, walletdb.ErrKeyRequired, banana.Delete(nil)) + require.Nil(t, banana.Get(nil)) + require.Error(t, walletdb.ErrKeyRequired, banana.Put(nil, []byte("val"))) + require.Error(t, walletdb.ErrKeyRequired, banana.Delete(nil)) // Try deleting a k/v that doesn't exist. - assert.NoError(t, banana.Delete([]byte("nokey"))) + require.NoError(t, banana.Delete([]byte("nokey"))) // "apple/pear" pear, err := apple.CreateBucket([]byte("pear")) - assert.Nil(t, err) - assert.NotNil(t, pear) + require.Nil(t, err) + require.NotNil(t, pear) // Put some values into "apple/pear" for _, kv := range kvs { - assert.Nil(t, pear.Put([]byte(kv.key), []byte(kv.val))) - assert.Equal(t, []byte(kv.val), pear.Get([]byte(kv.key))) + require.Nil(t, pear.Put([]byte(kv.key), []byte(kv.val))) + require.Equal(t, []byte(kv.val), pear.Get([]byte(kv.key))) } // Create nested bucket "apple/pear/cherry" cherry, err := pear.CreateBucket([]byte("cherry")) - assert.Nil(t, err) - assert.NotNil(t, cherry) + require.Nil(t, err) + require.NotNil(t, cherry) // Put some values into "apple/pear/cherry" for _, kv := range kvs { - assert.NoError(t, cherry.Put([]byte(kv.key), []byte(kv.val))) + require.NoError(t, cherry.Put([]byte(kv.key), []byte(kv.val))) } // Read back values in "apple/pear/cherry" trough a read bucket. cherryReadBucket := pear.NestedReadBucket([]byte("cherry")) for _, kv := range kvs { - assert.Equal( + require.Equal( t, []byte(kv.val), cherryReadBucket.Get([]byte(kv.key)), ) } // Try deleting some invalid buckets. - assert.Error(t, + require.Error(t, walletdb.ErrBucketNameRequired, apple.DeleteNestedBucket(nil), ) // Try deleting a non existing bucket. - assert.Error( + require.Error( t, walletdb.ErrBucketNotFound, apple.DeleteNestedBucket([]byte("missing")), ) // Delete "apple/pear" - assert.Nil(t, apple.DeleteNestedBucket([]byte("pear"))) + require.Nil(t, apple.DeleteNestedBucket([]byte("pear"))) // "apple/pear" deleted - assert.Nil(t, apple.NestedReadWriteBucket([]byte("pear"))) + require.Nil(t, apple.NestedReadWriteBucket([]byte("pear"))) // "apple/pear/cherry" deleted - assert.Nil(t, pear.NestedReadWriteBucket([]byte("cherry"))) + require.Nil(t, pear.NestedReadWriteBucket([]byte("cherry"))) // Values deleted too. for _, kv := range kvs { - assert.Nil(t, pear.Get([]byte(kv.key))) - assert.Nil(t, cherry.Get([]byte(kv.key))) + require.Nil(t, pear.Get([]byte(kv.key))) + require.Nil(t, cherry.Get([]byte(kv.key))) } // "aple/banana" exists - assert.NotNil(t, apple.NestedReadWriteBucket([]byte("banana"))) + require.NotNil(t, apple.NestedReadWriteBucket([]byte("banana"))) return nil }) - assert.Nil(t, err) + require.Nil(t, err) expected := map[string]string{ bkey("apple"): bval("apple"), @@ -199,7 +199,7 @@ func TestBucketDeletion(t *testing.T) { vkey("key1", "apple", "banana"): "val1", vkey("key3", "apple", "banana"): "val3", } - assert.Equal(t, expected, f.Dump()) + require.Equal(t, expected, f.Dump()) } func TestBucketForEach(t *testing.T) { @@ -209,28 +209,28 @@ func TestBucketForEach(t *testing.T) { defer f.Cleanup() db, err := newEtcdBackend(f.BackendConfig()) - assert.NoError(t, err) + require.NoError(t, err) err = db.Update(func(tx walletdb.ReadWriteTx) error { // "apple" apple, err := tx.CreateTopLevelBucket([]byte("apple")) - assert.Nil(t, err) - assert.NotNil(t, apple) + require.Nil(t, err) + require.NotNil(t, apple) // "apple/banana" banana, err := apple.CreateBucket([]byte("banana")) - assert.Nil(t, err) - assert.NotNil(t, banana) + require.Nil(t, err) + require.NotNil(t, banana) kvs := []KV{{"key1", "val1"}, {"key2", "val2"}, {"key3", "val3"}} // put some values into "apple" and "apple/banana" too for _, kv := range kvs { - assert.Nil(t, apple.Put([]byte(kv.key), []byte(kv.val))) - assert.Equal(t, []byte(kv.val), apple.Get([]byte(kv.key))) + require.Nil(t, apple.Put([]byte(kv.key), []byte(kv.val))) + require.Equal(t, []byte(kv.val), apple.Get([]byte(kv.key))) - assert.Nil(t, banana.Put([]byte(kv.key), []byte(kv.val))) - assert.Equal(t, []byte(kv.val), banana.Get([]byte(kv.key))) + require.Nil(t, banana.Put([]byte(kv.key), []byte(kv.val))) + require.Equal(t, []byte(kv.val), banana.Get([]byte(kv.key))) } got := make(map[string]string) @@ -246,8 +246,8 @@ func TestBucketForEach(t *testing.T) { "banana": "", } - assert.NoError(t, err) - assert.Equal(t, expected, got) + require.NoError(t, err) + require.Equal(t, expected, got) got = make(map[string]string) err = banana.ForEach(func(key, val []byte) error { @@ -255,15 +255,15 @@ func TestBucketForEach(t *testing.T) { return nil }) - assert.NoError(t, err) + require.NoError(t, err) // remove the sub-bucket key delete(expected, "banana") - assert.Equal(t, expected, got) + require.Equal(t, expected, got) return nil }) - assert.Nil(t, err) + require.Nil(t, err) expected := map[string]string{ bkey("apple"): bval("apple"), @@ -275,7 +275,7 @@ func TestBucketForEach(t *testing.T) { vkey("key2", "apple", "banana"): "val2", vkey("key3", "apple", "banana"): "val3", } - assert.Equal(t, expected, f.Dump()) + require.Equal(t, expected, f.Dump()) } func TestBucketForEachWithError(t *testing.T) { @@ -285,37 +285,37 @@ func TestBucketForEachWithError(t *testing.T) { defer f.Cleanup() db, err := newEtcdBackend(f.BackendConfig()) - assert.NoError(t, err) + require.NoError(t, err) err = db.Update(func(tx walletdb.ReadWriteTx) error { // "apple" apple, err := tx.CreateTopLevelBucket([]byte("apple")) - assert.Nil(t, err) - assert.NotNil(t, apple) + require.Nil(t, err) + require.NotNil(t, apple) // "apple/banana" banana, err := apple.CreateBucket([]byte("banana")) - assert.Nil(t, err) - assert.NotNil(t, banana) + require.Nil(t, err) + require.NotNil(t, banana) // "apple/pear" pear, err := apple.CreateBucket([]byte("pear")) - assert.Nil(t, err) - assert.NotNil(t, pear) + require.Nil(t, err) + require.NotNil(t, pear) kvs := []KV{{"key1", "val1"}, {"key2", "val2"}} // Put some values into "apple" and "apple/banana" too. for _, kv := range kvs { - assert.Nil(t, apple.Put([]byte(kv.key), []byte(kv.val))) - assert.Equal(t, []byte(kv.val), apple.Get([]byte(kv.key))) + require.Nil(t, apple.Put([]byte(kv.key), []byte(kv.val))) + require.Equal(t, []byte(kv.val), apple.Get([]byte(kv.key))) } got := make(map[string]string) i := 0 // Error while iterating value keys. err = apple.ForEach(func(key, val []byte) error { - if i == 1 { + if i == 2 { return fmt.Errorf("error") } @@ -325,11 +325,12 @@ func TestBucketForEachWithError(t *testing.T) { }) expected := map[string]string{ - "key1": "val1", + "banana": "", + "key1": "val1", } - assert.Equal(t, expected, got) - assert.Error(t, err) + require.Equal(t, expected, got) + require.Error(t, err) got = make(map[string]string) i = 0 @@ -345,17 +346,17 @@ func TestBucketForEachWithError(t *testing.T) { }) expected = map[string]string{ + "banana": "", "key1": "val1", "key2": "val2", - "banana": "", } - assert.Equal(t, expected, got) - assert.Error(t, err) + require.Equal(t, expected, got) + require.Error(t, err) return nil }) - assert.Nil(t, err) + require.Nil(t, err) expected := map[string]string{ bkey("apple"): bval("apple"), @@ -364,7 +365,7 @@ func TestBucketForEachWithError(t *testing.T) { vkey("key1", "apple"): "val1", vkey("key2", "apple"): "val2", } - assert.Equal(t, expected, f.Dump()) + require.Equal(t, expected, f.Dump()) } func TestBucketSequence(t *testing.T) { @@ -374,31 +375,149 @@ func TestBucketSequence(t *testing.T) { defer f.Cleanup() db, err := newEtcdBackend(f.BackendConfig()) - assert.NoError(t, err) + require.NoError(t, err) err = db.Update(func(tx walletdb.ReadWriteTx) error { apple, err := tx.CreateTopLevelBucket([]byte("apple")) - assert.Nil(t, err) - assert.NotNil(t, apple) + require.Nil(t, err) + require.NotNil(t, apple) banana, err := apple.CreateBucket([]byte("banana")) - assert.Nil(t, err) - assert.NotNil(t, banana) + require.Nil(t, err) + require.NotNil(t, banana) - assert.Equal(t, uint64(0), apple.Sequence()) - assert.Equal(t, uint64(0), banana.Sequence()) + require.Equal(t, uint64(0), apple.Sequence()) + require.Equal(t, uint64(0), banana.Sequence()) - assert.Nil(t, apple.SetSequence(math.MaxUint64)) - assert.Equal(t, uint64(math.MaxUint64), apple.Sequence()) + require.Nil(t, apple.SetSequence(math.MaxUint64)) + require.Equal(t, uint64(math.MaxUint64), apple.Sequence()) for i := uint64(0); i < uint64(5); i++ { s, err := apple.NextSequence() - assert.Nil(t, err) - assert.Equal(t, i, s) + require.Nil(t, err) + require.Equal(t, i, s) } return nil }) - assert.Nil(t, err) + require.Nil(t, err) +} + +// TestKeyClash tests that one cannot create a bucket if a value with the same +// key exists and the same is true in reverse: that a value cannot be put if +// a bucket with the same key exists. +func TestKeyClash(t *testing.T) { + t.Parallel() + + f := NewEtcdTestFixture(t) + defer f.Cleanup() + + db, err := newEtcdBackend(f.BackendConfig()) + require.NoError(t, err) + + // First: + // put: /apple/key -> val + // create bucket: /apple/banana + err = db.Update(func(tx walletdb.ReadWriteTx) error { + apple, err := tx.CreateTopLevelBucket([]byte("apple")) + require.Nil(t, err) + require.NotNil(t, apple) + + require.NoError(t, apple.Put([]byte("key"), []byte("val"))) + + banana, err := apple.CreateBucket([]byte("banana")) + require.Nil(t, err) + require.NotNil(t, banana) + + return nil + }) + + require.Nil(t, err) + + // Next try to: + // put: /apple/banana -> val => will fail (as /apple/banana is a bucket) + // create bucket: /apple/key => will fail (as /apple/key is a value) + err = db.Update(func(tx walletdb.ReadWriteTx) error { + apple, err := tx.CreateTopLevelBucket([]byte("apple")) + require.Nil(t, err) + require.NotNil(t, apple) + + require.Error(t, + walletdb.ErrIncompatibleValue, + apple.Put([]byte("banana"), []byte("val")), + ) + + b, err := apple.CreateBucket([]byte("key")) + require.Nil(t, b) + require.Error(t, walletdb.ErrIncompatibleValue, b) + + b, err = apple.CreateBucketIfNotExists([]byte("key")) + require.Nil(t, b) + require.Error(t, walletdb.ErrIncompatibleValue, b) + + return nil + }) + + require.Nil(t, err) + + // Except that the only existing items in the db are: + // bucket: /apple + // bucket: /apple/banana + // value: /apple/key -> val + expected := map[string]string{ + bkey("apple"): bval("apple"), + bkey("apple", "banana"): bval("apple", "banana"), + vkey("key", "apple"): "val", + } + require.Equal(t, expected, f.Dump()) + +} + +// TestBucketCreateDelete tests that creating then deleting then creating a +// bucket suceeds. +func TestBucketCreateDelete(t *testing.T) { + t.Parallel() + f := NewEtcdTestFixture(t) + defer f.Cleanup() + + db, err := newEtcdBackend(f.BackendConfig()) + require.NoError(t, err) + + err = db.Update(func(tx walletdb.ReadWriteTx) error { + apple, err := tx.CreateTopLevelBucket([]byte("apple")) + require.NoError(t, err) + require.NotNil(t, apple) + + banana, err := apple.CreateBucket([]byte("banana")) + require.NoError(t, err) + require.NotNil(t, banana) + + return nil + }) + require.NoError(t, err) + + err = db.Update(func(tx walletdb.ReadWriteTx) error { + apple := tx.ReadWriteBucket([]byte("apple")) + require.NotNil(t, apple) + require.NoError(t, apple.DeleteNestedBucket([]byte("banana"))) + + return nil + }) + require.NoError(t, err) + + err = db.Update(func(tx walletdb.ReadWriteTx) error { + apple := tx.ReadWriteBucket([]byte("apple")) + require.NotNil(t, apple) + require.NoError(t, apple.Put([]byte("banana"), []byte("value"))) + + return nil + }) + require.NoError(t, err) + + expected := map[string]string{ + vkey("banana", "apple"): "value", + bkey("apple"): bval("apple"), + } + require.Equal(t, expected, f.Dump()) } diff --git a/channeldb/kvdb/etcd/readwrite_cursor.go b/channeldb/kvdb/etcd/readwrite_cursor.go index 98965693..75c0456d 100644 --- a/channeldb/kvdb/etcd/readwrite_cursor.go +++ b/channeldb/kvdb/etcd/readwrite_cursor.go @@ -19,7 +19,7 @@ type readWriteCursor struct { func newReadWriteCursor(bucket *readWriteBucket) *readWriteCursor { return &readWriteCursor{ bucket: bucket, - prefix: string(makeValuePrefix(bucket.id)), + prefix: string(bucket.id), } } @@ -35,8 +35,7 @@ func (c *readWriteCursor) First() (key, value []byte) { if kv != nil { c.currKey = kv.key - // Chop the prefix and return the key/value. - return []byte(kv.key[len(c.prefix):]), []byte(kv.val) + return getKeyVal(kv) } return nil, nil @@ -53,8 +52,7 @@ func (c *readWriteCursor) Last() (key, value []byte) { if kv != nil { c.currKey = kv.key - // Chop the prefix and return the key/value. - return []byte(kv.key[len(c.prefix):]), []byte(kv.val) + return getKeyVal(kv) } return nil, nil @@ -71,8 +69,7 @@ func (c *readWriteCursor) Next() (key, value []byte) { if kv != nil { c.currKey = kv.key - // Chop the prefix and return the key/value. - return []byte(kv.key[len(c.prefix):]), []byte(kv.val) + return getKeyVal(kv) } return nil, nil @@ -89,8 +86,7 @@ func (c *readWriteCursor) Prev() (key, value []byte) { if kv != nil { c.currKey = kv.key - // Chop the prefix and return the key/value. - return []byte(kv.key[len(c.prefix):]), []byte(kv.val) + return getKeyVal(kv) } return nil, nil @@ -115,8 +111,7 @@ func (c *readWriteCursor) Seek(seek []byte) (key, value []byte) { if kv != nil { c.currKey = kv.key - // Chop the prefix and return the key/value. - return []byte(kv.key[len(c.prefix):]), []byte(kv.val) + return getKeyVal(kv) } return nil, nil @@ -133,11 +128,14 @@ func (c *readWriteCursor) Delete() error { return err } - // Delete the current key. - c.bucket.tx.stm.Del(c.currKey) + if isBucketKey(c.currKey) { + c.bucket.DeleteNestedBucket(getKey(c.currKey)) + } else { + c.bucket.Delete(getKey(c.currKey)) + } - // Set current key to the next one if possible. if nextKey != nil { + // Set current key to the next one. c.currKey = nextKey.key } diff --git a/channeldb/kvdb/etcd/readwrite_cursor_test.go b/channeldb/kvdb/etcd/readwrite_cursor_test.go index c14de7aa..216b47c4 100644 --- a/channeldb/kvdb/etcd/readwrite_cursor_test.go +++ b/channeldb/kvdb/etcd/readwrite_cursor_test.go @@ -6,7 +6,7 @@ import ( "testing" "github.com/btcsuite/btcwallet/walletdb" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestReadCursorEmptyInterval(t *testing.T) { @@ -16,41 +16,41 @@ func TestReadCursorEmptyInterval(t *testing.T) { defer f.Cleanup() db, err := newEtcdBackend(f.BackendConfig()) - assert.NoError(t, err) + require.NoError(t, err) err = db.Update(func(tx walletdb.ReadWriteTx) error { - b, err := tx.CreateTopLevelBucket([]byte("alma")) - assert.NoError(t, err) - assert.NotNil(t, b) + b, err := tx.CreateTopLevelBucket([]byte("apple")) + require.NoError(t, err) + require.NotNil(t, b) return nil }) - assert.NoError(t, err) + require.NoError(t, err) err = db.View(func(tx walletdb.ReadTx) error { - b := tx.ReadBucket([]byte("alma")) - assert.NotNil(t, b) + b := tx.ReadBucket([]byte("apple")) + require.NotNil(t, b) cursor := b.ReadCursor() k, v := cursor.First() - assert.Nil(t, k) - assert.Nil(t, v) + require.Nil(t, k) + require.Nil(t, v) k, v = cursor.Next() - assert.Nil(t, k) - assert.Nil(t, v) + require.Nil(t, k) + require.Nil(t, v) k, v = cursor.Last() - assert.Nil(t, k) - assert.Nil(t, v) + require.Nil(t, k) + require.Nil(t, v) k, v = cursor.Prev() - assert.Nil(t, k) - assert.Nil(t, v) + require.Nil(t, k) + require.Nil(t, v) return nil }) - assert.NoError(t, err) + require.NoError(t, err) } func TestReadCursorNonEmptyInterval(t *testing.T) { @@ -60,7 +60,7 @@ func TestReadCursorNonEmptyInterval(t *testing.T) { defer f.Cleanup() db, err := newEtcdBackend(f.BackendConfig()) - assert.NoError(t, err) + require.NoError(t, err) testKeyValues := []KV{ {"b", "1"}, @@ -70,21 +70,21 @@ func TestReadCursorNonEmptyInterval(t *testing.T) { } err = db.Update(func(tx walletdb.ReadWriteTx) error { - b, err := tx.CreateTopLevelBucket([]byte("alma")) - assert.NoError(t, err) - assert.NotNil(t, b) + b, err := tx.CreateTopLevelBucket([]byte("apple")) + require.NoError(t, err) + require.NotNil(t, b) for _, kv := range testKeyValues { - assert.NoError(t, b.Put([]byte(kv.key), []byte(kv.val))) + require.NoError(t, b.Put([]byte(kv.key), []byte(kv.val))) } return nil }) - assert.NoError(t, err) + require.NoError(t, err) err = db.View(func(tx walletdb.ReadTx) error { - b := tx.ReadBucket([]byte("alma")) - assert.NotNil(t, b) + b := tx.ReadBucket([]byte("apple")) + require.NotNil(t, b) // Iterate from the front. var kvs []KV @@ -95,7 +95,7 @@ func TestReadCursorNonEmptyInterval(t *testing.T) { kvs = append(kvs, KV{string(k), string(v)}) k, v = cursor.Next() } - assert.Equal(t, testKeyValues, kvs) + require.Equal(t, testKeyValues, kvs) // Iterate from the back. kvs = []KV{} @@ -105,29 +105,29 @@ func TestReadCursorNonEmptyInterval(t *testing.T) { kvs = append(kvs, KV{string(k), string(v)}) k, v = cursor.Prev() } - assert.Equal(t, reverseKVs(testKeyValues), kvs) + require.Equal(t, reverseKVs(testKeyValues), kvs) // Random access perm := []int{3, 0, 2, 1} for _, i := range perm { k, v := cursor.Seek([]byte(testKeyValues[i].key)) - assert.Equal(t, []byte(testKeyValues[i].key), k) - assert.Equal(t, []byte(testKeyValues[i].val), v) + require.Equal(t, []byte(testKeyValues[i].key), k) + require.Equal(t, []byte(testKeyValues[i].val), v) } // Seek to nonexisting key. k, v = cursor.Seek(nil) - assert.Nil(t, k) - assert.Nil(t, v) + require.Nil(t, k) + require.Nil(t, v) k, v = cursor.Seek([]byte("x")) - assert.Nil(t, k) - assert.Nil(t, v) + require.Nil(t, k) + require.Nil(t, v) return nil }) - assert.NoError(t, err) + require.NoError(t, err) } func TestReadWriteCursor(t *testing.T) { @@ -137,7 +137,7 @@ func TestReadWriteCursor(t *testing.T) { defer f.Cleanup() db, err := newEtcdBackend(f.BackendConfig()) - assert.NoError(t, err) + require.NoError(t, err) testKeyValues := []KV{ {"b", "1"}, @@ -149,24 +149,24 @@ func TestReadWriteCursor(t *testing.T) { count := len(testKeyValues) // Pre-store the first half of the interval. - assert.NoError(t, db.Update(func(tx walletdb.ReadWriteTx) error { + require.NoError(t, db.Update(func(tx walletdb.ReadWriteTx) error { b, err := tx.CreateTopLevelBucket([]byte("apple")) - assert.NoError(t, err) - assert.NotNil(t, b) + require.NoError(t, err) + require.NotNil(t, b) for i := 0; i < count/2; i++ { err = b.Put( []byte(testKeyValues[i].key), []byte(testKeyValues[i].val), ) - assert.NoError(t, err) + require.NoError(t, err) } return nil })) err = db.Update(func(tx walletdb.ReadWriteTx) error { b := tx.ReadWriteBucket([]byte("apple")) - assert.NotNil(t, b) + require.NotNil(t, b) // Store the second half of the interval. for i := count / 2; i < count; i++ { @@ -174,77 +174,77 @@ func TestReadWriteCursor(t *testing.T) { []byte(testKeyValues[i].key), []byte(testKeyValues[i].val), ) - assert.NoError(t, err) + require.NoError(t, err) } cursor := b.ReadWriteCursor() // First on valid interval. fk, fv := cursor.First() - assert.Equal(t, []byte("b"), fk) - assert.Equal(t, []byte("1"), fv) + require.Equal(t, []byte("b"), fk) + require.Equal(t, []byte("1"), fv) // Prev(First()) = nil k, v := cursor.Prev() - assert.Nil(t, k) - assert.Nil(t, v) + require.Nil(t, k) + require.Nil(t, v) // Last on valid interval. lk, lv := cursor.Last() - assert.Equal(t, []byte("e"), lk) - assert.Equal(t, []byte("4"), lv) + require.Equal(t, []byte("e"), lk) + require.Equal(t, []byte("4"), lv) // Next(Last()) = nil k, v = cursor.Next() - assert.Nil(t, k) - assert.Nil(t, v) + require.Nil(t, k) + require.Nil(t, v) // Delete first item, then add an item before the // deleted one. Check that First/Next will "jump" // over the deleted item and return the new first. _, _ = cursor.First() - assert.NoError(t, cursor.Delete()) - assert.NoError(t, b.Put([]byte("a"), []byte("0"))) + require.NoError(t, cursor.Delete()) + require.NoError(t, b.Put([]byte("a"), []byte("0"))) fk, fv = cursor.First() - assert.Equal(t, []byte("a"), fk) - assert.Equal(t, []byte("0"), fv) + require.Equal(t, []byte("a"), fk) + require.Equal(t, []byte("0"), fv) k, v = cursor.Next() - assert.Equal(t, []byte("c"), k) - assert.Equal(t, []byte("2"), v) + require.Equal(t, []byte("c"), k) + require.Equal(t, []byte("2"), v) // Similarly test that a new end is returned if // the old end is deleted first. _, _ = cursor.Last() - assert.NoError(t, cursor.Delete()) - assert.NoError(t, b.Put([]byte("f"), []byte("5"))) + require.NoError(t, cursor.Delete()) + require.NoError(t, b.Put([]byte("f"), []byte("5"))) lk, lv = cursor.Last() - assert.Equal(t, []byte("f"), lk) - assert.Equal(t, []byte("5"), lv) + require.Equal(t, []byte("f"), lk) + require.Equal(t, []byte("5"), lv) k, v = cursor.Prev() - assert.Equal(t, []byte("da"), k) - assert.Equal(t, []byte("3"), v) + require.Equal(t, []byte("da"), k) + require.Equal(t, []byte("3"), v) // Overwrite k/v in the middle of the interval. - assert.NoError(t, b.Put([]byte("c"), []byte("3"))) + require.NoError(t, b.Put([]byte("c"), []byte("3"))) k, v = cursor.Prev() - assert.Equal(t, []byte("c"), k) - assert.Equal(t, []byte("3"), v) + require.Equal(t, []byte("c"), k) + require.Equal(t, []byte("3"), v) // Insert new key/values. - assert.NoError(t, b.Put([]byte("cx"), []byte("x"))) - assert.NoError(t, b.Put([]byte("cy"), []byte("y"))) + require.NoError(t, b.Put([]byte("cx"), []byte("x"))) + require.NoError(t, b.Put([]byte("cy"), []byte("y"))) k, v = cursor.Next() - assert.Equal(t, []byte("cx"), k) - assert.Equal(t, []byte("x"), v) + require.Equal(t, []byte("cx"), k) + require.Equal(t, []byte("x"), v) k, v = cursor.Next() - assert.Equal(t, []byte("cy"), k) - assert.Equal(t, []byte("y"), v) + require.Equal(t, []byte("cy"), k) + require.Equal(t, []byte("y"), v) expected := []KV{ {"a", "0"}, @@ -263,7 +263,7 @@ func TestReadWriteCursor(t *testing.T) { kvs = append(kvs, KV{string(k), string(v)}) k, v = cursor.Next() } - assert.Equal(t, expected, kvs) + require.Equal(t, expected, kvs) // Iterate from the back. kvs = []KV{} @@ -273,12 +273,12 @@ func TestReadWriteCursor(t *testing.T) { kvs = append(kvs, KV{string(k), string(v)}) k, v = cursor.Prev() } - assert.Equal(t, reverseKVs(expected), kvs) + require.Equal(t, reverseKVs(expected), kvs) return nil }) - assert.NoError(t, err) + require.NoError(t, err) expected := map[string]string{ bkey("apple"): bval("apple"), @@ -289,5 +289,80 @@ func TestReadWriteCursor(t *testing.T) { vkey("da", "apple"): "3", vkey("f", "apple"): "5", } - assert.Equal(t, expected, f.Dump()) + require.Equal(t, expected, f.Dump()) +} + +// TestReadWriteCursorWithBucketAndValue tests that cursors are able to iterate +// over both bucket and value keys if both are present in the iterated bucket. +func TestReadWriteCursorWithBucketAndValue(t *testing.T) { + t.Parallel() + + f := NewEtcdTestFixture(t) + defer f.Cleanup() + + db, err := newEtcdBackend(f.BackendConfig()) + require.NoError(t, err) + + // Pre-store the first half of the interval. + require.NoError(t, db.Update(func(tx walletdb.ReadWriteTx) error { + b, err := tx.CreateTopLevelBucket([]byte("apple")) + require.NoError(t, err) + require.NotNil(t, b) + + require.NoError(t, b.Put([]byte("key"), []byte("val"))) + + b1, err := b.CreateBucket([]byte("banana")) + require.NoError(t, err) + require.NotNil(t, b1) + + b2, err := b.CreateBucket([]byte("pear")) + require.NoError(t, err) + require.NotNil(t, b2) + + return nil + })) + + err = db.View(func(tx walletdb.ReadTx) error { + b := tx.ReadBucket([]byte("apple")) + require.NotNil(t, b) + + cursor := b.ReadCursor() + + // First on valid interval. + k, v := cursor.First() + require.Equal(t, []byte("banana"), k) + require.Nil(t, v) + + k, v = cursor.Next() + require.Equal(t, []byte("key"), k) + require.Equal(t, []byte("val"), v) + + k, v = cursor.Last() + require.Equal(t, []byte("pear"), k) + require.Nil(t, v) + + k, v = cursor.Seek([]byte("k")) + require.Equal(t, []byte("key"), k) + require.Equal(t, []byte("val"), v) + + k, v = cursor.Seek([]byte("banana")) + require.Equal(t, []byte("banana"), k) + require.Nil(t, v) + + k, v = cursor.Next() + require.Equal(t, []byte("key"), k) + require.Equal(t, []byte("val"), v) + + return nil + }) + + require.NoError(t, err) + + expected := map[string]string{ + bkey("apple"): bval("apple"), + bkey("apple", "banana"): bval("apple", "banana"), + bkey("apple", "pear"): bval("apple", "pear"), + vkey("key", "apple"): "val", + } + require.Equal(t, expected, f.Dump()) } diff --git a/channeldb/kvdb/etcd/readwrite_tx.go b/channeldb/kvdb/etcd/readwrite_tx.go index 22d0ce42..81c27323 100644 --- a/channeldb/kvdb/etcd/readwrite_tx.go +++ b/channeldb/kvdb/etcd/readwrite_tx.go @@ -17,14 +17,6 @@ type readWriteTx struct { // active is true if the transaction hasn't been committed yet. active bool - - // dirty is true if we intent to update a value in this transaction. - dirty bool - - // lset holds key/value set that we want to lock on. If upon commit the - // transaction is dirty and the lset is not empty, we'll bump the mod - // version of these key/values. - lset map[string]string } // newReadWriteTx creates an rw transaction with the passed STM. @@ -33,7 +25,6 @@ func newReadWriteTx(stm STM, prefix string) *readWriteTx { stm: stm, active: true, rootBucketID: makeBucketID([]byte(prefix)), - lset: make(map[string]string), } } @@ -43,50 +34,6 @@ func rootBucket(tx *readWriteTx) *readWriteBucket { return newReadWriteBucket(tx, tx.rootBucketID[:], tx.rootBucketID[:]) } -// lock adds a key value to the lock set. -func (tx *readWriteTx) lock(key, val string) { - tx.stm.Lock(key) - if !tx.dirty { - tx.lset[key] = val - } else { - // Bump the mod version of the key, - // leaving the value intact. - tx.stm.Put(key, val) - } -} - -// put updates the passed key/value. -func (tx *readWriteTx) put(key, val string) { - tx.stm.Put(key, val) - tx.setDirty() -} - -// del marks the passed key deleted. -func (tx *readWriteTx) del(key string) { - tx.stm.Del(key) - tx.setDirty() -} - -// setDirty marks the transaction dirty and bumps -// mod version for the existing lock set if it is -// not empty. -func (tx *readWriteTx) setDirty() { - // Bump the lock set. - if !tx.dirty && len(tx.lset) > 0 { - for key, val := range tx.lset { - // Bump the mod version of the key, - // leaving the value intact. - tx.stm.Put(key, val) - } - - // Clear the lock set. - tx.lset = make(map[string]string) - } - - // Set dirty. - tx.dirty = true -} - // ReadBucket opens the root bucket for read only access. If the bucket // described by the key does not exist, nil is returned. func (tx *readWriteTx) ReadBucket(key []byte) walletdb.ReadBucket { diff --git a/channeldb/kvdb/etcd/readwrite_tx_test.go b/channeldb/kvdb/etcd/readwrite_tx_test.go index f65faa54..bab6967f 100644 --- a/channeldb/kvdb/etcd/readwrite_tx_test.go +++ b/channeldb/kvdb/etcd/readwrite_tx_test.go @@ -6,7 +6,7 @@ import ( "testing" "github.com/btcsuite/btcwallet/walletdb" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestTxManualCommit(t *testing.T) { @@ -16,11 +16,11 @@ func TestTxManualCommit(t *testing.T) { defer f.Cleanup() db, err := newEtcdBackend(f.BackendConfig()) - assert.NoError(t, err) + require.NoError(t, err) tx, err := db.BeginReadWriteTx() - assert.NoError(t, err) - assert.NotNil(t, tx) + require.NoError(t, err) + require.NotNil(t, tx) committed := false @@ -29,24 +29,24 @@ func TestTxManualCommit(t *testing.T) { }) apple, err := tx.CreateTopLevelBucket([]byte("apple")) - assert.NoError(t, err) - assert.NotNil(t, apple) - assert.NoError(t, apple.Put([]byte("testKey"), []byte("testVal"))) + require.NoError(t, err) + require.NotNil(t, apple) + require.NoError(t, apple.Put([]byte("testKey"), []byte("testVal"))) banana, err := tx.CreateTopLevelBucket([]byte("banana")) - assert.NoError(t, err) - assert.NotNil(t, banana) - assert.NoError(t, banana.Put([]byte("testKey"), []byte("testVal"))) - assert.NoError(t, tx.DeleteTopLevelBucket([]byte("banana"))) + require.NoError(t, err) + require.NotNil(t, banana) + require.NoError(t, banana.Put([]byte("testKey"), []byte("testVal"))) + require.NoError(t, tx.DeleteTopLevelBucket([]byte("banana"))) - assert.NoError(t, tx.Commit()) - assert.True(t, committed) + require.NoError(t, tx.Commit()) + require.True(t, committed) expected := map[string]string{ bkey("apple"): bval("apple"), vkey("testKey", "apple"): "testVal", } - assert.Equal(t, expected, f.Dump()) + require.Equal(t, expected, f.Dump()) } func TestTxRollback(t *testing.T) { @@ -56,21 +56,21 @@ func TestTxRollback(t *testing.T) { defer f.Cleanup() db, err := newEtcdBackend(f.BackendConfig()) - assert.NoError(t, err) + require.NoError(t, err) tx, err := db.BeginReadWriteTx() - assert.Nil(t, err) - assert.NotNil(t, tx) + require.Nil(t, err) + require.NotNil(t, tx) apple, err := tx.CreateTopLevelBucket([]byte("apple")) - assert.Nil(t, err) - assert.NotNil(t, apple) + require.Nil(t, err) + require.NotNil(t, apple) - assert.NoError(t, apple.Put([]byte("testKey"), []byte("testVal"))) + require.NoError(t, apple.Put([]byte("testKey"), []byte("testVal"))) - assert.NoError(t, tx.Rollback()) - assert.Error(t, walletdb.ErrTxClosed, tx.Commit()) - assert.Equal(t, map[string]string{}, f.Dump()) + require.NoError(t, tx.Rollback()) + require.Error(t, walletdb.ErrTxClosed, tx.Commit()) + require.Equal(t, map[string]string{}, f.Dump()) } func TestChangeDuringManualTx(t *testing.T) { @@ -80,24 +80,24 @@ func TestChangeDuringManualTx(t *testing.T) { defer f.Cleanup() db, err := newEtcdBackend(f.BackendConfig()) - assert.NoError(t, err) + require.NoError(t, err) tx, err := db.BeginReadWriteTx() - assert.Nil(t, err) - assert.NotNil(t, tx) + require.Nil(t, err) + require.NotNil(t, tx) apple, err := tx.CreateTopLevelBucket([]byte("apple")) - assert.Nil(t, err) - assert.NotNil(t, apple) + require.Nil(t, err) + require.NotNil(t, apple) - assert.NoError(t, apple.Put([]byte("testKey"), []byte("testVal"))) + require.NoError(t, apple.Put([]byte("testKey"), []byte("testVal"))) // Try overwriting the bucket key. f.Put(bkey("apple"), "banana") // TODO: translate error - assert.NotNil(t, tx.Commit()) - assert.Equal(t, map[string]string{ + require.NotNil(t, tx.Commit()) + require.Equal(t, map[string]string{ bkey("apple"): "banana", }, f.Dump()) } @@ -109,16 +109,16 @@ func TestChangeDuringUpdate(t *testing.T) { defer f.Cleanup() db, err := newEtcdBackend(f.BackendConfig()) - assert.NoError(t, err) + require.NoError(t, err) count := 0 err = db.Update(func(tx walletdb.ReadWriteTx) error { apple, err := tx.CreateTopLevelBucket([]byte("apple")) - assert.NoError(t, err) - assert.NotNil(t, apple) + require.NoError(t, err) + require.NotNil(t, apple) - assert.NoError(t, apple.Put([]byte("key"), []byte("value"))) + require.NoError(t, apple.Put([]byte("key"), []byte("value"))) if count == 0 { f.Put(vkey("key", "apple"), "new_value") @@ -127,30 +127,30 @@ func TestChangeDuringUpdate(t *testing.T) { cursor := apple.ReadCursor() k, v := cursor.First() - assert.Equal(t, []byte("key"), k) - assert.Equal(t, []byte("value"), v) - assert.Equal(t, v, apple.Get([]byte("key"))) + require.Equal(t, []byte("key"), k) + require.Equal(t, []byte("value"), v) + require.Equal(t, v, apple.Get([]byte("key"))) k, v = cursor.Next() if count == 0 { - assert.Nil(t, k) - assert.Nil(t, v) + require.Nil(t, k) + require.Nil(t, v) } else { - assert.Equal(t, []byte("key2"), k) - assert.Equal(t, []byte("value2"), v) + require.Equal(t, []byte("key2"), k) + require.Equal(t, []byte("value2"), v) } count++ return nil }) - assert.Nil(t, err) - assert.Equal(t, count, 2) + require.Nil(t, err) + require.Equal(t, count, 2) expected := map[string]string{ bkey("apple"): bval("apple"), vkey("key", "apple"): "value", vkey("key2", "apple"): "value2", } - assert.Equal(t, expected, f.Dump()) + require.Equal(t, expected, f.Dump()) } diff --git a/channeldb/kvdb/etcd/stm.go b/channeldb/kvdb/etcd/stm.go index 7a2f33b5..14bb9ca9 100644 --- a/channeldb/kvdb/etcd/stm.go +++ b/channeldb/kvdb/etcd/stm.go @@ -32,11 +32,6 @@ type STM interface { // set. Returns nil if there's no matching key, or the key is empty. Get(key string) ([]byte, error) - // Lock adds a key to the lock set. If the lock set is not empty, we'll - // only check for conflicts in the lock set and the write set, instead - // of all read keys plus the write set. - Lock(key string) - // Put adds a value for a key to the txn's write set. Put(key, val string) @@ -151,9 +146,6 @@ type stm struct { // wset holds overwritten keys and their values. wset writeSet - // lset holds keys we intent to lock on. - lset map[string]interface{} - // getOpts are the opts used for gets. getOpts []v3.OpOption @@ -247,19 +239,19 @@ loop: default: } - - // Apply the transaction closure and abort the STM if there was an - // application error. + // Apply the transaction closure and abort the STM if there was + // an application error. if err = apply(s); err != nil { break loop } stats, err = s.commit() - // Re-apply only upon commit error (meaning the database was changed). + // Retry the apply closure only upon commit error (meaning the + // database was changed). if _, ok := err.(CommitError); !ok { - // Anything that's not a CommitError - // aborts the STM run loop. + // Anything that's not a CommitError aborts the STM + // run loop. break loop } @@ -303,24 +295,14 @@ func (rs readSet) gets() []v3.Op { return ops } -// cmps returns a cmp list testing values in read set didn't change. -func (rs readSet) cmps(lset map[string]interface{}) []v3.Cmp { - if len(lset) > 0 { - cmps := make([]v3.Cmp, 0, len(lset)) - for key := range lset { - if getValue, ok := rs[key]; ok { - cmps = append( - cmps, - v3.Compare(v3.ModRevision(key), "=", getValue.rev), - ) - } - } - return cmps - } - +// cmps returns a compare list which will serve as a precondition testing that +// the values in the read set didn't change. +func (rs readSet) cmps() []v3.Cmp { cmps := make([]v3.Cmp, 0, len(rs)) for key, getValue := range rs { - cmps = append(cmps, v3.Compare(v3.ModRevision(key), "=", getValue.rev)) + cmps = append(cmps, v3.Compare( + v3.ModRevision(key), "=", getValue.rev, + )) } return cmps @@ -370,6 +352,15 @@ func (s *stm) fetch(key string, opts ...v3.OpOption) ([]KV, error) { } } + if len(resp.Kvs) == 0 { + // Add assertion to the read set which will extend our commit + // constraint such that the commit will fail if the key is + // present in the database. + s.rset[key] = stmGet{ + rev: 0, + } + } + var result []KV // Fill the read set with key/values returned. @@ -413,12 +404,22 @@ func (s *stm) Get(key string) ([]byte, error) { // the prefetch set. if getValue, ok := s.prefetch[key]; ok { delete(s.prefetch, key) - s.rset[key] = getValue + + // Use the prefetched value only if it is for + // an existing key. + if getValue.rev != 0 { + s.rset[key] = getValue + } } // Return value if alread in read set. - if getVal, ok := s.rset[key]; ok { - return []byte(getVal.val), nil + if getValue, ok := s.rset[key]; ok { + // Return the value if the rset contains an existing key. + if getValue.rev != 0 { + return []byte(getValue.val), nil + } else { + return nil, nil + } } // Fetch and return value. @@ -435,13 +436,6 @@ func (s *stm) Get(key string) ([]byte, error) { return nil, nil } -// Lock adds a key to the lock set. If the lock set is -// not empty, we'll only check conflicts for the keys -// in the lock set. -func (s *stm) Lock(key string) { - s.lset[key] = nil -} - // First returns the first key/value matching prefix. If there's no key starting // with prefix, Last will return nil. func (s *stm) First(prefix string) (*KV, error) { @@ -711,7 +705,7 @@ func (s *stm) OnCommit(cb func()) { // because the keys have changed return a CommitError, otherwise return a // DatabaseError. func (s *stm) commit() (CommitStats, error) { - rset := s.rset.cmps(s.lset) + rset := s.rset.cmps() wset := s.wset.cmps(s.revision + 1) stats := CommitStats{ @@ -775,7 +769,6 @@ func (s *stm) Commit() error { func (s *stm) Rollback() { s.rset = make(map[string]stmGet) s.wset = make(map[string]stmPut) - s.lset = make(map[string]interface{}) s.getOpts = nil s.revision = math.MaxInt64 - 1 } diff --git a/channeldb/kvdb/etcd/stm_test.go b/channeldb/kvdb/etcd/stm_test.go index 767963d4..6beffc28 100644 --- a/channeldb/kvdb/etcd/stm_test.go +++ b/channeldb/kvdb/etcd/stm_test.go @@ -6,7 +6,7 @@ import ( "errors" "testing" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func reverseKVs(a []KV) []KV { @@ -24,7 +24,7 @@ func TestPutToEmpty(t *testing.T) { defer f.Cleanup() db, err := newEtcdBackend(f.BackendConfig()) - assert.NoError(t, err) + require.NoError(t, err) apply := func(stm STM) error { stm.Put("123", "abc") @@ -32,9 +32,9 @@ func TestPutToEmpty(t *testing.T) { } err = RunSTM(db.cli, apply) - assert.NoError(t, err) + require.NoError(t, err) - assert.Equal(t, "abc", f.Get("123")) + require.Equal(t, "abc", f.Get("123")) } func TestGetPutDel(t *testing.T) { @@ -56,64 +56,64 @@ func TestGetPutDel(t *testing.T) { } db, err := newEtcdBackend(f.BackendConfig()) - assert.NoError(t, err) + require.NoError(t, err) apply := func(stm STM) error { // Get some non existing keys. v, err := stm.Get("") - assert.NoError(t, err) - assert.Nil(t, v) + require.NoError(t, err) + require.Nil(t, v) v, err = stm.Get("x") - assert.NoError(t, err) - assert.Nil(t, v) + require.NoError(t, err) + require.Nil(t, v) // Get all existing keys. for _, kv := range testKeyValues { v, err = stm.Get(kv.key) - assert.NoError(t, err) - assert.Equal(t, []byte(kv.val), v) + require.NoError(t, err) + require.Equal(t, []byte(kv.val), v) } // Overwrite, then delete an existing key. stm.Put("c", "6") v, err = stm.Get("c") - assert.NoError(t, err) - assert.Equal(t, []byte("6"), v) + require.NoError(t, err) + require.Equal(t, []byte("6"), v) stm.Del("c") v, err = stm.Get("c") - assert.NoError(t, err) - assert.Nil(t, v) + require.NoError(t, err) + require.Nil(t, v) // Re-add the deleted key. stm.Put("c", "7") v, err = stm.Get("c") - assert.NoError(t, err) - assert.Equal(t, []byte("7"), v) + require.NoError(t, err) + require.Equal(t, []byte("7"), v) // Add a new key. stm.Put("x", "x") v, err = stm.Get("x") - assert.NoError(t, err) - assert.Equal(t, []byte("x"), v) + require.NoError(t, err) + require.Equal(t, []byte("x"), v) return nil } err = RunSTM(db.cli, apply) - assert.NoError(t, err) + require.NoError(t, err) - assert.Equal(t, "1", f.Get("a")) - assert.Equal(t, "2", f.Get("b")) - assert.Equal(t, "7", f.Get("c")) - assert.Equal(t, "4", f.Get("d")) - assert.Equal(t, "5", f.Get("e")) - assert.Equal(t, "x", f.Get("x")) + require.Equal(t, "1", f.Get("a")) + require.Equal(t, "2", f.Get("b")) + require.Equal(t, "7", f.Get("c")) + require.Equal(t, "4", f.Get("d")) + require.Equal(t, "5", f.Get("e")) + require.Equal(t, "x", f.Get("x")) } func TestFirstLastNextPrev(t *testing.T) { @@ -134,44 +134,44 @@ func TestFirstLastNextPrev(t *testing.T) { } db, err := newEtcdBackend(f.BackendConfig()) - assert.NoError(t, err) + require.NoError(t, err) apply := func(stm STM) error { // First/Last on valid multi item interval. kv, err := stm.First("k") - assert.NoError(t, err) - assert.Equal(t, &KV{"kb", "1"}, kv) + require.NoError(t, err) + require.Equal(t, &KV{"kb", "1"}, kv) kv, err = stm.Last("k") - assert.NoError(t, err) - assert.Equal(t, &KV{"ke", "4"}, kv) + require.NoError(t, err) + require.Equal(t, &KV{"ke", "4"}, kv) // First/Last on single item interval. kv, err = stm.First("w") - assert.NoError(t, err) - assert.Equal(t, &KV{"w", "w"}, kv) + require.NoError(t, err) + require.Equal(t, &KV{"w", "w"}, kv) kv, err = stm.Last("w") - assert.NoError(t, err) - assert.Equal(t, &KV{"w", "w"}, kv) + require.NoError(t, err) + require.Equal(t, &KV{"w", "w"}, kv) // Next/Prev on start/end. kv, err = stm.Next("k", "ke") - assert.NoError(t, err) - assert.Nil(t, kv) + require.NoError(t, err) + require.Nil(t, kv) kv, err = stm.Prev("k", "kb") - assert.NoError(t, err) - assert.Nil(t, kv) + require.NoError(t, err) + require.Nil(t, kv) // Next/Prev in the middle. kv, err = stm.Next("k", "kc") - assert.NoError(t, err) - assert.Equal(t, &KV{"kda", "3"}, kv) + require.NoError(t, err) + require.Equal(t, &KV{"kda", "3"}, kv) kv, err = stm.Prev("k", "ke") - assert.NoError(t, err) - assert.Equal(t, &KV{"kda", "3"}, kv) + require.NoError(t, err) + require.Equal(t, &KV{"kda", "3"}, kv) // Delete first item, then add an item before the // deleted one. Check that First/Next will "jump" @@ -180,12 +180,12 @@ func TestFirstLastNextPrev(t *testing.T) { stm.Put("ka", "0") kv, err = stm.First("k") - assert.NoError(t, err) - assert.Equal(t, &KV{"ka", "0"}, kv) + require.NoError(t, err) + require.Equal(t, &KV{"ka", "0"}, kv) kv, err = stm.Prev("k", "kc") - assert.NoError(t, err) - assert.Equal(t, &KV{"ka", "0"}, kv) + require.NoError(t, err) + require.Equal(t, &KV{"ka", "0"}, kv) // Similarly test that a new end is returned if // the old end is deleted first. @@ -193,19 +193,19 @@ func TestFirstLastNextPrev(t *testing.T) { stm.Put("kf", "5") kv, err = stm.Last("k") - assert.NoError(t, err) - assert.Equal(t, &KV{"kf", "5"}, kv) + require.NoError(t, err) + require.Equal(t, &KV{"kf", "5"}, kv) kv, err = stm.Next("k", "kda") - assert.NoError(t, err) - assert.Equal(t, &KV{"kf", "5"}, kv) + require.NoError(t, err) + require.Equal(t, &KV{"kf", "5"}, kv) // Overwrite one in the middle. stm.Put("kda", "6") kv, err = stm.Next("k", "kc") - assert.NoError(t, err) - assert.Equal(t, &KV{"kda", "6"}, kv) + require.NoError(t, err) + require.Equal(t, &KV{"kda", "6"}, kv) // Add three in the middle, then delete one. stm.Put("kdb", "7") @@ -218,12 +218,12 @@ func TestFirstLastNextPrev(t *testing.T) { var kvs []KV curr, err := stm.First("k") - assert.NoError(t, err) + require.NoError(t, err) for curr != nil { kvs = append(kvs, *curr) curr, err = stm.Next("k", curr.key) - assert.NoError(t, err) + require.NoError(t, err) } expected := []KV{ @@ -234,37 +234,37 @@ func TestFirstLastNextPrev(t *testing.T) { {"kdd", "9"}, {"kf", "5"}, } - assert.Equal(t, expected, kvs) + require.Equal(t, expected, kvs) // Similarly check that stepping from last to first // returns the expected sequence. kvs = []KV{} curr, err = stm.Last("k") - assert.NoError(t, err) + require.NoError(t, err) for curr != nil { kvs = append(kvs, *curr) curr, err = stm.Prev("k", curr.key) - assert.NoError(t, err) + require.NoError(t, err) } expected = reverseKVs(expected) - assert.Equal(t, expected, kvs) + require.Equal(t, expected, kvs) return nil } err = RunSTM(db.cli, apply) - assert.NoError(t, err) + require.NoError(t, err) - assert.Equal(t, "0", f.Get("ka")) - assert.Equal(t, "2", f.Get("kc")) - assert.Equal(t, "6", f.Get("kda")) - assert.Equal(t, "7", f.Get("kdb")) - assert.Equal(t, "9", f.Get("kdd")) - assert.Equal(t, "5", f.Get("kf")) - assert.Equal(t, "w", f.Get("w")) + require.Equal(t, "0", f.Get("ka")) + require.Equal(t, "2", f.Get("kc")) + require.Equal(t, "6", f.Get("kda")) + require.Equal(t, "7", f.Get("kdb")) + require.Equal(t, "9", f.Get("kdd")) + require.Equal(t, "5", f.Get("kf")) + require.Equal(t, "w", f.Get("w")) } func TestCommitError(t *testing.T) { @@ -274,7 +274,7 @@ func TestCommitError(t *testing.T) { defer f.Cleanup() db, err := newEtcdBackend(f.BackendConfig()) - assert.NoError(t, err) + require.NoError(t, err) // Preset DB state. f.Put("123", "xyz") @@ -285,10 +285,10 @@ func TestCommitError(t *testing.T) { apply := func(stm STM) error { // STM must have the key/value. val, err := stm.Get("123") - assert.NoError(t, err) + require.NoError(t, err) if cnt == 0 { - assert.Equal(t, []byte("xyz"), val) + require.Equal(t, []byte("xyz"), val) // Put a conflicting key/value during the first apply. f.Put("123", "def") @@ -302,10 +302,10 @@ func TestCommitError(t *testing.T) { } err = RunSTM(db.cli, apply) - assert.NoError(t, err) - assert.Equal(t, 2, cnt) + require.NoError(t, err) + require.Equal(t, 2, cnt) - assert.Equal(t, "abc", f.Get("123")) + require.Equal(t, "abc", f.Get("123")) } func TestManualTxError(t *testing.T) { @@ -315,7 +315,7 @@ func TestManualTxError(t *testing.T) { defer f.Cleanup() db, err := newEtcdBackend(f.BackendConfig()) - assert.NoError(t, err) + require.NoError(t, err) // Preset DB state. f.Put("123", "xyz") @@ -323,22 +323,22 @@ func TestManualTxError(t *testing.T) { stm := NewSTM(db.cli) val, err := stm.Get("123") - assert.NoError(t, err) - assert.Equal(t, []byte("xyz"), val) + require.NoError(t, err) + require.Equal(t, []byte("xyz"), val) // Put a conflicting key/value. f.Put("123", "def") // Should still get the original version. val, err = stm.Get("123") - assert.NoError(t, err) - assert.Equal(t, []byte("xyz"), val) + require.NoError(t, err) + require.Equal(t, []byte("xyz"), val) // Commit will fail with CommitError. err = stm.Commit() var e CommitError - assert.True(t, errors.As(err, &e)) + require.True(t, errors.As(err, &e)) // We expect that the transacton indeed did not commit. - assert.Equal(t, "def", f.Get("123")) + require.Equal(t, "def", f.Get("123")) } diff --git a/channeldb/meta_test.go b/channeldb/meta_test.go index 956ffb5d..98e9c88a 100644 --- a/channeldb/meta_test.go +++ b/channeldb/meta_test.go @@ -15,7 +15,7 @@ import ( func applyMigration(t *testing.T, beforeMigration, afterMigration func(d *DB), migrationFunc migration, shouldFail bool, dryRun bool) { - cdb, cleanUp, err := makeTestDB() + cdb, cleanUp, err := MakeTestDB() defer cleanUp() if err != nil { t.Fatal(err) @@ -86,7 +86,7 @@ func applyMigration(t *testing.T, beforeMigration, afterMigration func(d *DB), func TestVersionFetchPut(t *testing.T) { t.Parallel() - db, cleanUp, err := makeTestDB() + db, cleanUp, err := MakeTestDB() defer cleanUp() if err != nil { t.Fatal(err) diff --git a/channeldb/nodes_test.go b/channeldb/nodes_test.go index 755177aa..0d649d43 100644 --- a/channeldb/nodes_test.go +++ b/channeldb/nodes_test.go @@ -13,7 +13,7 @@ import ( func TestLinkNodeEncodeDecode(t *testing.T) { t.Parallel() - cdb, cleanUp, err := makeTestDB() + cdb, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } @@ -110,7 +110,7 @@ func TestLinkNodeEncodeDecode(t *testing.T) { func TestDeleteLinkNode(t *testing.T) { t.Parallel() - cdb, cleanUp, err := makeTestDB() + cdb, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } diff --git a/channeldb/payment_control_test.go b/channeldb/payment_control_test.go index 147e5452..4f901462 100644 --- a/channeldb/payment_control_test.go +++ b/channeldb/payment_control_test.go @@ -56,7 +56,7 @@ func genInfo() (*PaymentCreationInfo, *HTLCAttemptInfo, func TestPaymentControlSwitchFail(t *testing.T) { t.Parallel() - db, cleanup, err := makeTestDB() + db, cleanup, err := MakeTestDB() defer cleanup() if err != nil { t.Fatalf("unable to init db: %v", err) @@ -203,7 +203,7 @@ func TestPaymentControlSwitchFail(t *testing.T) { func TestPaymentControlSwitchDoubleSend(t *testing.T) { t.Parallel() - db, cleanup, err := makeTestDB() + db, cleanup, err := MakeTestDB() defer cleanup() if err != nil { @@ -286,7 +286,7 @@ func TestPaymentControlSwitchDoubleSend(t *testing.T) { func TestPaymentControlSuccessesWithoutInFlight(t *testing.T) { t.Parallel() - db, cleanup, err := makeTestDB() + db, cleanup, err := MakeTestDB() defer cleanup() if err != nil { @@ -319,7 +319,7 @@ func TestPaymentControlSuccessesWithoutInFlight(t *testing.T) { func TestPaymentControlFailsWithoutInFlight(t *testing.T) { t.Parallel() - db, cleanup, err := makeTestDB() + db, cleanup, err := MakeTestDB() defer cleanup() if err != nil { @@ -347,7 +347,7 @@ func TestPaymentControlFailsWithoutInFlight(t *testing.T) { func TestPaymentControlDeleteNonInFligt(t *testing.T) { t.Parallel() - db, cleanup, err := makeTestDB() + db, cleanup, err := MakeTestDB() defer cleanup() if err != nil { @@ -530,7 +530,7 @@ func TestPaymentControlMultiShard(t *testing.T) { } runSubTest := func(t *testing.T, test testCase) { - db, cleanup, err := makeTestDB() + db, cleanup, err := MakeTestDB() defer cleanup() if err != nil { @@ -780,7 +780,7 @@ func TestPaymentControlMultiShard(t *testing.T) { func TestPaymentControlMPPRecordValidation(t *testing.T) { t.Parallel() - db, cleanup, err := makeTestDB() + db, cleanup, err := MakeTestDB() defer cleanup() if err != nil { diff --git a/channeldb/payments_test.go b/channeldb/payments_test.go index 9e790c3e..0dc05956 100644 --- a/channeldb/payments_test.go +++ b/channeldb/payments_test.go @@ -399,7 +399,7 @@ func TestQueryPayments(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - db, cleanup, err := makeTestDB() + db, cleanup, err := MakeTestDB() if err != nil { t.Fatalf("unable to init db: %v", err) } @@ -512,7 +512,7 @@ func TestQueryPayments(t *testing.T) { // case where a specific duplicate is not found and the duplicates bucket is not // present when we expect it to be. func TestFetchPaymentWithSequenceNumber(t *testing.T) { - db, cleanup, err := makeTestDB() + db, cleanup, err := MakeTestDB() require.NoError(t, err) defer cleanup() diff --git a/channeldb/reports_test.go b/channeldb/reports_test.go index 398d0e6d..a63fe42b 100644 --- a/channeldb/reports_test.go +++ b/channeldb/reports_test.go @@ -48,7 +48,7 @@ func TestPersistReport(t *testing.T) { test := test t.Run(test.name, func(t *testing.T) { - db, cleanup, err := makeTestDB() + db, cleanup, err := MakeTestDB() require.NoError(t, err) defer cleanup() @@ -85,7 +85,7 @@ func TestPersistReport(t *testing.T) { // channel, testing that the appropriate error is returned based on the state // of the existing bucket. func TestFetchChannelReadBucket(t *testing.T) { - db, cleanup, err := makeTestDB() + db, cleanup, err := MakeTestDB() require.NoError(t, err) defer cleanup() @@ -197,7 +197,7 @@ func TestFetchChannelWriteBucket(t *testing.T) { test := test t.Run(test.name, func(t *testing.T) { - db, cleanup, err := makeTestDB() + db, cleanup, err := MakeTestDB() require.NoError(t, err) defer cleanup() diff --git a/channeldb/waitingproof_test.go b/channeldb/waitingproof_test.go index fff52b92..12679b69 100644 --- a/channeldb/waitingproof_test.go +++ b/channeldb/waitingproof_test.go @@ -14,7 +14,7 @@ import ( func TestWaitingProofStore(t *testing.T) { t.Parallel() - db, cleanup, err := makeTestDB() + db, cleanup, err := MakeTestDB() if err != nil { t.Fatalf("failed to make test database: %s", err) } diff --git a/channeldb/witness_cache_test.go b/channeldb/witness_cache_test.go index 8ba1e835..fb6c9683 100644 --- a/channeldb/witness_cache_test.go +++ b/channeldb/witness_cache_test.go @@ -12,7 +12,7 @@ import ( func TestWitnessCacheSha256Retrieval(t *testing.T) { t.Parallel() - cdb, cleanUp, err := makeTestDB() + cdb, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } @@ -57,7 +57,7 @@ func TestWitnessCacheSha256Retrieval(t *testing.T) { func TestWitnessCacheSha256Deletion(t *testing.T) { t.Parallel() - cdb, cleanUp, err := makeTestDB() + cdb, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } @@ -108,7 +108,7 @@ func TestWitnessCacheSha256Deletion(t *testing.T) { func TestWitnessCacheUnknownWitness(t *testing.T) { t.Parallel() - cdb, cleanUp, err := makeTestDB() + cdb, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } @@ -127,7 +127,7 @@ func TestWitnessCacheUnknownWitness(t *testing.T) { // TestAddSha256Witnesses tests that insertion using AddSha256Witnesses behaves // identically to the insertion via the generalized interface. func TestAddSha256Witnesses(t *testing.T) { - cdb, cleanUp, err := makeTestDB() + cdb, cleanUp, err := MakeTestDB() if err != nil { t.Fatalf("unable to make test database: %v", err) } diff --git a/nursery_store_test.go b/nursery_store_test.go index af5e0a06..dee92701 100644 --- a/nursery_store_test.go +++ b/nursery_store_test.go @@ -3,8 +3,6 @@ package lnd import ( - "io/ioutil" - "os" "reflect" "testing" @@ -12,31 +10,6 @@ import ( "github.com/lightningnetwork/lnd/channeldb" ) -// makeTestDB creates a new instance of the ChannelDB for testing purposes. A -// callback which cleans up the created temporary directories is also returned -// and intended to be executed after the test completes. -func makeTestDB() (*channeldb.DB, func(), error) { - // First, create a temporary directory to be used for the duration of - // this test. - tempDirName, err := ioutil.TempDir("", "channeldb") - if err != nil { - return nil, nil, err - } - - // Next, create channeldb for the first time. - cdb, err := channeldb.Open(tempDirName) - if err != nil { - return nil, nil, err - } - - cleanUp := func() { - cdb.Close() - os.RemoveAll(tempDirName) - } - - return cdb, cleanUp, nil -} - type incubateTest struct { nOutputs int chanPoint *wire.OutPoint @@ -75,7 +48,7 @@ func initIncubateTests() { // TestNurseryStoreInit verifies basic properties of the nursery store before // any modifying calls are made. func TestNurseryStoreInit(t *testing.T) { - cdb, cleanUp, err := makeTestDB() + cdb, cleanUp, err := channeldb.MakeTestDB() if err != nil { t.Fatalf("unable to open channel db: %v", err) } @@ -95,7 +68,7 @@ func TestNurseryStoreInit(t *testing.T) { // outputs through the nursery store, verifying the properties of the // intermediate states. func TestNurseryStoreIncubate(t *testing.T) { - cdb, cleanUp, err := makeTestDB() + cdb, cleanUp, err := channeldb.MakeTestDB() if err != nil { t.Fatalf("unable to open channel db: %v", err) } @@ -336,7 +309,7 @@ func TestNurseryStoreIncubate(t *testing.T) { // populated entries from the height index as it is purged, and that the last // purged height is set appropriately. func TestNurseryStoreGraduate(t *testing.T) { - cdb, cleanUp, err := makeTestDB() + cdb, cleanUp, err := channeldb.MakeTestDB() if err != nil { t.Fatalf("unable to open channel db: %v", err) } diff --git a/sweep/store_test.go b/sweep/store_test.go index 3738f6c9..b27efb31 100644 --- a/sweep/store_test.go +++ b/sweep/store_test.go @@ -1,8 +1,6 @@ package sweep import ( - "io/ioutil" - "os" "testing" "github.com/btcsuite/btcd/chaincfg/chainhash" @@ -10,38 +8,13 @@ import ( "github.com/lightningnetwork/lnd/channeldb" ) -// makeTestDB creates a new instance of the ChannelDB for testing purposes. A -// callback which cleans up the created temporary directories is also returned -// and intended to be executed after the test completes. -func makeTestDB() (*channeldb.DB, func(), error) { - // First, create a temporary directory to be used for the duration of - // this test. - tempDirName, err := ioutil.TempDir("", "channeldb") - if err != nil { - return nil, nil, err - } - - // Next, create channeldb for the first time. - cdb, err := channeldb.Open(tempDirName) - if err != nil { - return nil, nil, err - } - - cleanUp := func() { - cdb.Close() - os.RemoveAll(tempDirName) - } - - return cdb, cleanUp, nil -} - // TestStore asserts that the store persists the presented data to disk and is // able to retrieve it again. func TestStore(t *testing.T) { t.Run("bolt", func(t *testing.T) { // Create new store. - cdb, cleanUp, err := makeTestDB() + cdb, cleanUp, err := channeldb.MakeTestDB() if err != nil { t.Fatalf("unable to open channel db: %v", err) } diff --git a/utxonursery_test.go b/utxonursery_test.go index 1a9323db..4f9fdab9 100644 --- a/utxonursery_test.go +++ b/utxonursery_test.go @@ -5,7 +5,6 @@ package lnd import ( "bytes" "fmt" - "io/ioutil" "math" "os" "reflect" @@ -407,6 +406,7 @@ type nurseryTestContext struct { sweeper *mockSweeper timeoutChan chan chan time.Time t *testing.T + dbCleanup func() } func createNurseryTestContext(t *testing.T, @@ -416,12 +416,7 @@ func createNurseryTestContext(t *testing.T, // alternative, mocking nurseryStore, is not chosen because there is // still considerable logic in the store. - tempDirName, err := ioutil.TempDir("", "channeldb") - if err != nil { - t.Fatalf("unable to create temp dir: %v", err) - } - - cdb, err := channeldb.Open(tempDirName) + cdb, cleanup, err := channeldb.MakeTestDB() if err != nil { t.Fatalf("unable to open channeldb: %v", err) } @@ -484,6 +479,7 @@ func createNurseryTestContext(t *testing.T, sweeper: sweeper, timeoutChan: timeoutChan, t: t, + dbCleanup: cleanup, } ctx.receiveTx = func() wire.MsgTx { @@ -531,6 +527,8 @@ func (ctx *nurseryTestContext) notifyEpoch(height int32) { } func (ctx *nurseryTestContext) finish() { + defer ctx.dbCleanup() + // Add a final restart point in this state ctx.restart()