diff --git a/kvdb/etcd/db.go b/kvdb/etcd/db.go index 12b89868..49105643 100644 --- a/kvdb/etcd/db.go +++ b/kvdb/etcd/db.go @@ -219,7 +219,8 @@ func (db *db) View(f func(tx walletdb.ReadTx) error, reset func()) error { return f(newReadWriteTx(stm, etcdDefaultRootBucketId, nil)) } - return RunSTM(db.cli, apply, db.txQueue, db.getSTMOptions()...) + _, err := RunSTM(db.cli, apply, db.txQueue, db.getSTMOptions()...) + return err } // Update opens a database read/write transaction and executes the function f @@ -240,7 +241,8 @@ func (db *db) Update(f func(tx walletdb.ReadWriteTx) error, reset func()) error return f(newReadWriteTx(stm, etcdDefaultRootBucketId, nil)) } - return RunSTM(db.cli, apply, db.txQueue, db.getSTMOptions()...) + _, err := RunSTM(db.cli, apply, db.txQueue, db.getSTMOptions()...) + return err } // PrintStats returns all collected stats pretty printed into a string. diff --git a/kvdb/etcd/stm.go b/kvdb/etcd/stm.go index 9d1e50e9..7ef9776b 100644 --- a/kvdb/etcd/stm.go +++ b/kvdb/etcd/stm.go @@ -8,6 +8,8 @@ import ( "math" "strings" + "github.com/google/btree" + pb "go.etcd.io/etcd/api/v3/etcdserverpb" v3 "go.etcd.io/etcd/client/v3" ) @@ -71,9 +73,13 @@ type STM interface { // Commit may return CommitError if transaction is outdated and needs retry. Commit() error - // Rollback emties the read and write sets such that a subsequent commit + // Rollback entries the read and write sets such that a subsequent commit // won't alter the database. Rollback() + + // Prefetch prefetches the passed keys and prefixes. For prefixes it'll + // fetch the whole range. + Prefetch(keys []string, prefix []string) } // CommitError is used to check if there was an error @@ -104,15 +110,26 @@ func (e DatabaseError) Error() string { return fmt.Sprintf("etcd error: %v - %v", e.msg, e.err) } -// stmGet is the result of a read operation, -// a value and the mod revision of the key/value. +// stmGet is the result of a read operation, a value and the mod revision of the +// key/value. type stmGet struct { - val string + KV rev int64 } +// Less implements less operator for btree.BTree. +func (c *stmGet) Less(than btree.Item) bool { + return c.key < than.(*stmGet).key +} + // readSet stores all reads done in an STM. -type readSet map[string]stmGet +type readSet struct { + // tree stores the items in the read set. + tree *btree.BTree + + // fullRanges stores full range prefixes. + fullRanges map[string]struct{} +} // stmPut stores a value and an operation (put/delete). type stmPut struct { @@ -141,11 +158,8 @@ type stm struct { // options stores optional settings passed by the user. options *STMOptions - // prefetch hold prefetched key values and revisions. - prefetch readSet - // rset holds read key values and revisions. - rset readSet + rset *readSet // wset holds overwritten keys and their values. wset writeSet @@ -158,6 +172,9 @@ type stm struct { // onCommit gets called upon commit. onCommit func() + + // callCount tracks the number of times we called into etcd. + callCount int } // STMOptions can be used to pass optional settings @@ -188,9 +205,12 @@ func WithCommitStatsCallback(cb func(bool, CommitStats)) STMOptionFunc { // RunSTM runs the apply function by creating an STM using serializable snapshot // isolation, passing it to the apply and handling commit errors and retries. func RunSTM(cli *v3.Client, apply func(STM) error, txQueue *commitQueue, - so ...STMOptionFunc) error { + so ...STMOptionFunc) (int, error) { - return runSTM(makeSTM(cli, false, txQueue, so...), apply) + stm := makeSTM(cli, false, txQueue, so...) + err := runSTM(stm, apply) + + return stm.callCount, err } // NewSTM creates a new STM instance, using serializable snapshot isolation. @@ -213,15 +233,15 @@ func makeSTM(cli *v3.Client, manual bool, txQueue *commitQueue, } s := &stm{ - client: cli, - manual: manual, - txQueue: txQueue, - options: opts, - prefetch: make(map[string]stmGet), + client: cli, + manual: manual, + txQueue: txQueue, + options: opts, + rset: newReadSet(), } // Reset read and write set. - s.Rollback() + s.rollback(true) return s } @@ -262,8 +282,11 @@ func runSTM(s *stm, apply func(STM) error) error { return } - // Rollback before trying to re-apply. - s.Rollback() + // Rollback the write set before trying to re-apply. + // Upon commit we retrieved the latest version of all + // previously fetched keys and ranges so we don't need + // to rollback the read set. + s.rollback(false) retries++ // Re-apply the transaction closure. @@ -287,14 +310,16 @@ func runSTM(s *stm, apply func(STM) error) error { // result in queueing up transactions and contending DB access. // Copying these strings is cheap due to Go's immutable string which is // always a reference. - rkeys := make([]string, len(s.rset)) + rkeys := make([]string, s.rset.tree.Len()) wkeys := make([]string, len(s.wset)) i := 0 - for key := range s.rset { - rkeys[i] = key + s.rset.tree.Ascend(func(item btree.Item) bool { + rkeys[i] = item.(*stmGet).key i++ - } + + return true + }) i = 0 for key := range s.wset { @@ -320,42 +345,225 @@ func runSTM(s *stm, apply func(STM) error) error { return executeErr } -// add inserts a txn response to the read set. This is useful when the txn -// fails due to conflict where the txn response can be used to prefetch -// key/values. -func (rs readSet) add(txnResp *v3.TxnResponse) { - for _, resp := range txnResp.Responses { - getResp := (*v3.GetResponse)(resp.GetResponseRange()) +func newReadSet() *readSet { + return &readSet{ + tree: btree.New(5), + fullRanges: make(map[string]struct{}), + } +} + +// add inserts key/values to to read set. +func (rs *readSet) add(responses []*pb.ResponseOp) { + for _, resp := range responses { + getResp := resp.GetResponseRange() for _, kv := range getResp.Kvs { - rs[string(kv.Key)] = stmGet{ - val: string(kv.Value), - rev: kv.ModRevision, - } + rs.addItem( + string(kv.Key), string(kv.Value), kv.ModRevision, + ) } } } -// gets is a helper to create an op slice for transaction -// construction. -func (rs readSet) gets() []v3.Op { - ops := make([]v3.Op, 0, len(rs)) +// addFullRange adds all full ranges to the read set. +func (rs *readSet) addFullRange(prefixes []string, responses []*pb.ResponseOp) { + for i, resp := range responses { + getResp := resp.GetResponseRange() + for _, kv := range getResp.Kvs { + rs.addItem( + string(kv.Key), string(kv.Value), kv.ModRevision, + ) + } - for k := range rs { - ops = append(ops, v3.OpGet(k)) + rs.fullRanges[prefixes[i]] = struct{}{} + } +} + +// presetItem presets a key to zero revision if not already present in the read +// set. +func (rs *readSet) presetItem(key string) { + item := &stmGet{ + KV: KV{ + key: key, + }, + rev: 0, + } + + if !rs.tree.Has(item) { + rs.tree.ReplaceOrInsert(item) + } +} + +// addItem adds a single new key/value to the read set (if not already present). +func (rs *readSet) addItem(key, val string, modRevision int64) { + item := &stmGet{ + KV: KV{ + key: key, + val: val, + }, + rev: modRevision, + } + + rs.tree.ReplaceOrInsert(item) +} + +// hasFullRange checks if the read set has a full range prefetched. +func (rs *readSet) hasFullRange(prefix string) bool { + _, ok := rs.fullRanges[prefix] + return ok +} + +// next returns the pre-fetched next value of the prefix. If matchKey is true, +// it'll simply return the key/value that matches the passed key. +func (rs *readSet) next(prefix, key string, matchKey bool) (*stmGet, bool) { + pivot := &stmGet{ + KV: KV{ + key: key, + }, + } + + var result *stmGet + rs.tree.AscendGreaterOrEqual( + pivot, + func(item btree.Item) bool { + next := item.(*stmGet) + if (!matchKey && next.key == key) || next.rev == 0 { + return true + } + + if strings.HasPrefix(next.key, prefix) { + result = next + } + + return false + }, + ) + + return result, result != nil +} + +// prev returns the pre-fetched prev key/value of the prefix from key. +func (rs *readSet) prev(prefix, key string) (*stmGet, bool) { + pivot := &stmGet{ + KV: KV{ + key: key, + }, + } + + var result *stmGet + + rs.tree.DescendLessOrEqual( + pivot, func(item btree.Item) bool { + prev := item.(*stmGet) + if prev.key == key || prev.rev == 0 { + return true + } + + if strings.HasPrefix(prev.key, prefix) { + result = prev + } + + return false + }, + ) + + return result, result != nil +} + +// last returns the last key/value of the passed range (if prefetched). +func (rs *readSet) last(prefix string) (*stmGet, bool) { + // We create an artificial key here that is just one step away from the + // prefix. This way when we try to get the first item with our prefix + // before this newly crafted key we'll make sure it's the last element + // of our range. + key := []byte(prefix) + key[len(key)-1] += 1 + + return rs.prev(prefix, string(key)) +} + +// clear completely clears the readset. +func (rs *readSet) clear() { + rs.tree.Clear(false) + rs.fullRanges = make(map[string]struct{}) +} + +// getItem returns the matching key/value from the readset. +func (rs *readSet) getItem(key string) (*stmGet, bool) { + pivot := &stmGet{ + KV: KV{ + key: key, + }, + rev: 0, + } + item := rs.tree.Get(pivot) + if item != nil { + return item.(*stmGet), true + } + + // It's possible that although this key isn't in the read set, we + // fetched a full range the key is prefixed with. In this case we'll + // insert the key with zero revision. + for prefix := range rs.fullRanges { + if strings.HasPrefix(key, prefix) { + rs.tree.ReplaceOrInsert(pivot) + return pivot, true + } + } + + return nil, false +} + +// prefetchSet is a helper to create an op slice of all OpGet's that represent +// fetched keys appended with a slice of all OpGet's representing all prefetched +// full ranges. +func (rs *readSet) prefetchSet() []v3.Op { + ops := make([]v3.Op, 0, rs.tree.Len()) + + rs.tree.Ascend(func(item btree.Item) bool { + key := item.(*stmGet).key + for prefix := range rs.fullRanges { + // Do not add the key if it has been prefetched in a + // full range. + if strings.HasPrefix(key, prefix) { + return true + } + } + + ops = append(ops, v3.OpGet(key)) + return true + }) + + for prefix := range rs.fullRanges { + ops = append(ops, v3.OpGet(prefix, v3.WithPrefix())) } return ops } +// getFullRanges returns all prefixes that we prefetched. +func (rs *readSet) getFullRanges() []string { + prefixes := make([]string, 0, len(rs.fullRanges)) + + for prefix := range rs.fullRanges { + prefixes = append(prefixes, prefix) + } + + return prefixes +} + // cmps returns a compare list which will serve as a precondition testing that // the values in the read set didn't change. -func (rs readSet) cmps() []v3.Cmp { - cmps := make([]v3.Cmp, 0, len(rs)) - for key, getValue := range rs { - cmps = append(cmps, v3.Compare( - v3.ModRevision(key), "=", getValue.rev, - )) - } +func (rs *readSet) cmps() []v3.Cmp { + cmps := make([]v3.Cmp, 0, rs.tree.Len()) + + rs.tree.Ascend(func(item btree.Item) bool { + get := item.(*stmGet) + cmps = append( + cmps, v3.Compare(v3.ModRevision(get.key), "=", get.rev), + ) + + return true + }) return cmps } @@ -384,6 +592,7 @@ func (ws writeSet) puts() []v3.Op { // then fetch will try to fix the STM's snapshot revision (if not already set). // We'll also cache the returned key/value in the read set. func (s *stm) fetch(key string, opts ...v3.OpOption) ([]KV, error) { + s.callCount++ resp, err := s.client.Get( s.options.ctx, key, append(opts, s.getOpts...)..., ) @@ -394,7 +603,7 @@ func (s *stm) fetch(key string, opts ...v3.OpOption) ([]KV, error) { } } - // Set revison and serializable options upon first fetch + // Set revision and serializable options upon first fetch // for any subsequent fetches. if s.getOpts == nil { s.revision = resp.Header.Revision @@ -408,26 +617,18 @@ func (s *stm) fetch(key string, opts ...v3.OpOption) ([]KV, error) { // Add assertion to the read set which will extend our commit // constraint such that the commit will fail if the key is // present in the database. - s.rset[key] = stmGet{ - rev: 0, - } + s.rset.addItem(key, "", 0) } var result []KV // Fill the read set with key/values returned. for _, kv := range resp.Kvs { - // Remove from prefetch. key := string(kv.Key) val := string(kv.Value) - delete(s.prefetch, key) - // Add to read set. - s.rset[key] = stmGet{ - val: val, - rev: kv.ModRevision, - } + s.rset.addItem(key, val, kv.ModRevision) result = append(result, KV{key, val}) } @@ -452,20 +653,8 @@ func (s *stm) Get(key string) ([]byte, error) { return []byte(put.val), nil } - // Populate read set if key is present in - // the prefetch set. - if getValue, ok := s.prefetch[key]; ok { - delete(s.prefetch, key) - - // Use the prefetched value only if it is for - // an existing key. - if getValue.rev != 0 { - s.rset[key] = getValue - } - } - // Return value if alread in read set. - if getValue, ok := s.rset[key]; ok { + if getValue, ok := s.rset.getItem(key); ok { // Return the value if the rset contains an existing key. if getValue.rev != 0 { return []byte(getValue.val), nil @@ -497,21 +686,28 @@ func (s *stm) First(prefix string) (*KV, error) { // Last returns the last key/value with prefix. If there's no key starting with // prefix, Last will return nil. func (s *stm) Last(prefix string) (*KV, error) { - // As we don't know the full range, fetch the last - // key/value with this prefix first. - resp, err := s.fetch(prefix, v3.WithLastKey()...) - if err != nil { - return nil, err - } - var ( kv KV found bool ) - if len(resp) > 0 { - kv = resp[0] - found = true + if s.rset.hasFullRange(prefix) { + if item, ok := s.rset.last(prefix); ok { + kv = item.KV + found = true + } + } else { + // As we don't know the full range, fetch the last + // key/value with this prefix first. + resp, err := s.fetch(prefix, v3.WithLastKey()...) + if err != nil { + return nil, err + } + + if len(resp) > 0 { + kv = resp[0] + found = true + } } // Now make sure there's nothing in the write set @@ -539,32 +735,41 @@ func (s *stm) Last(prefix string) (*KV, error) { } // Prev returns the prior key/value before key (with prefix). If there's no such -// key Next will return nil. +// key Prev will return nil. func (s *stm) Prev(prefix, startKey string) (*KV, error) { - var result KV + var kv, result KV fetchKey := startKey matchFound := false for { - // Ask etcd to retrieve one key that is a - // match in descending order from the passed key. - opts := []v3.OpOption{ - v3.WithRange(fetchKey), - v3.WithSort(v3.SortByKey, v3.SortDescend), - v3.WithLimit(1), - } + if s.rset.hasFullRange(prefix) { + if item, ok := s.rset.prev(prefix, fetchKey); ok { + kv = item.KV + } else { + break + } + } else { - kvs, err := s.fetch(prefix, opts...) - if err != nil { - return nil, err - } + // Ask etcd to retrieve one key that is a + // match in descending order from the passed key. + opts := []v3.OpOption{ + v3.WithRange(fetchKey), + v3.WithSort(v3.SortByKey, v3.SortDescend), + v3.WithLimit(1), + } - if len(kvs) == 0 { - break - } + kvs, err := s.fetch(prefix, opts...) + if err != nil { + return nil, err + } - kv := &kvs[0] + if len(kvs) == 0 { + break + } + + kv = kvs[0] + } // WithRange and WithPrefix can't be used // together, so check prefix here. If the @@ -580,13 +785,13 @@ func (s *stm) Prev(prefix, startKey string) (*KV, error) { continue } - result = *kv + result = kv matchFound = true break } - // Closre holding all checks to find a possibly + // Closure holding all checks to find a possibly // better match. matches := func(key string) bool { if !strings.HasPrefix(key, prefix) { @@ -635,47 +840,60 @@ func (s *stm) Seek(prefix, key string) (*KV, error) { // passed startKey. If includeStartKey is set to true, it'll return the value // of startKey (essentially implementing seek). func (s *stm) next(prefix, startKey string, includeStartKey bool) (*KV, error) { - var result KV + var kv, result KV fetchKey := startKey firstFetch := true matchFound := false for { - // Ask etcd to retrieve one key that is a - // match in ascending order from the passed key. - opts := []v3.OpOption{ - v3.WithFromKey(), - v3.WithSort(v3.SortByKey, v3.SortAscend), - v3.WithLimit(1), - } - - // By default we include the start key too - // if it is a full match. - if includeStartKey && firstFetch { + if s.rset.hasFullRange(prefix) { + matchKey := includeStartKey && firstFetch firstFetch = false + if item, ok := s.rset.next( + prefix, fetchKey, matchKey, + ); ok { + kv = item.KV + } else { + break + } } else { - // If we'd like to retrieve the first key - // after the start key. - fetchKey += "\x00" - } + // Ask etcd to retrieve one key that is a + // match in ascending order from the passed key. + opts := []v3.OpOption{ + v3.WithFromKey(), + v3.WithSort(v3.SortByKey, v3.SortAscend), + v3.WithLimit(1), + } - kvs, err := s.fetch(fetchKey, opts...) - if err != nil { - return nil, err - } + // By default we include the start key too + // if it is a full match. + if includeStartKey && firstFetch { + firstFetch = false + } else { + // If we'd like to retrieve the first key + // after the start key. + fetchKey += "\x00" + } - if len(kvs) == 0 { - break - } + kvs, err := s.fetch(fetchKey, opts...) + if err != nil { + return nil, err + } - kv := &kvs[0] - // WithRange and WithPrefix can't be used - // together, so check prefix here. If the - // returned key no longer has the prefix, - // then break the fetch loop. - if !strings.HasPrefix(kv.key, prefix) { - break + if len(kvs) == 0 { + break + } + + kv = kvs[0] + + // WithRange and WithPrefix can't be used + // together, so check prefix here. If the + // returned key no longer has the prefix, + // then break the fetch loop. + if !strings.HasPrefix(kv.key, prefix) { + break + } } // Move on to fetch starting with the next @@ -685,7 +903,7 @@ func (s *stm) next(prefix, startKey string, includeStartKey bool) (*KV, error) { continue } - result = *kv + result = kv matchFound = true break @@ -753,6 +971,72 @@ func (s *stm) OnCommit(cb func()) { s.onCommit = cb } +// Prefetch will prefetch the passed keys and prefixes in one transaction. +// Keys and prefixes that we already have will be skipped. +func (s *stm) Prefetch(keys []string, prefixes []string) { + fetchKeys := make([]string, 0, len(keys)) + for _, key := range keys { + if _, ok := s.rset.getItem(key); !ok { + fetchKeys = append(fetchKeys, key) + } + } + + fetchPrefixes := make([]string, 0, len(prefixes)) + for _, prefix := range prefixes { + if s.rset.hasFullRange(prefix) { + continue + } + fetchPrefixes = append(fetchPrefixes, prefix) + } + + if len(fetchKeys) == 0 && len(fetchPrefixes) == 0 { + return + } + + prefixOpts := append( + []v3.OpOption{v3.WithPrefix()}, s.getOpts..., + ) + + txn := s.client.Txn(s.options.ctx) + ops := make([]v3.Op, 0, len(fetchKeys)+len(fetchPrefixes)) + + for _, key := range fetchKeys { + ops = append(ops, v3.OpGet(key, s.getOpts...)) + } + for _, key := range fetchPrefixes { + ops = append(ops, v3.OpGet(key, prefixOpts...)) + } + + txn.Then(ops...) + txnresp, err := txn.Commit() + s.callCount++ + + if err != nil { + return + } + + // Set revision and serializable options upon first fetch for any + // subsequent fetches. + if s.getOpts == nil { + s.revision = txnresp.Header.Revision + s.getOpts = []v3.OpOption{ + v3.WithRev(s.revision), + v3.WithSerializable(), + } + } + + // Preset keys to "not-present" (revision set to zero). + for _, key := range fetchKeys { + s.rset.presetItem(key) + } + + // Set prefetched keys. + s.rset.add(txnresp.Responses[:len(fetchKeys)]) + + // Set prefetched ranges. + s.rset.addFullRange(fetchPrefixes, txnresp.Responses[len(fetchKeys):]) +} + // commit builds the final transaction and tries to execute it. If commit fails // because the keys have changed return a CommitError, otherwise return a // DatabaseError. @@ -774,10 +1058,11 @@ func (s *stm) commit() (CommitStats, error) { txn = txn.If(cmps...) txn = txn.Then(s.wset.puts()...) - // Prefetch keys in case of conflict to save - // a round trip to etcd. - txn = txn.Else(s.rset.gets()...) + // Prefetch keys and ranges in case of conflict to save as many + // round-trips as possible. + txn = txn.Else(s.rset.prefetchSet()...) + s.callCount++ txnresp, err := txn.Commit() if err != nil { return stats, DatabaseError{ @@ -786,8 +1071,7 @@ func (s *stm) commit() (CommitStats, error) { } } - // Call the commit callback if the transaction - // was successful. + // Call the commit callback if the transaction was successful. if txnresp.Succeeded { if s.onCommit != nil { s.onCommit() @@ -796,12 +1080,23 @@ func (s *stm) commit() (CommitStats, error) { return stats, nil } - // Load prefetch before if commit failed. - s.rset.add(txnresp) - s.prefetch = s.rset + // Determine where our fetched full ranges begin in the response. + prefixes := s.rset.getFullRanges() + firstPrefixResp := len(txnresp.Responses) - len(prefixes) - // Return CommitError indicating that the transaction - // can be retried. + // Clear reload and preload it with the prefetched keys and ranges. + s.rset.clear() + s.rset.add(txnresp.Responses[:firstPrefixResp]) + s.rset.addFullRange(prefixes, txnresp.Responses[firstPrefixResp:]) + + // Set our revision boundary. + s.revision = txnresp.Header.Revision + s.getOpts = []v3.OpOption{ + v3.WithRev(s.revision), + v3.WithSerializable(), + } + + // Return CommitError indicating that the transaction can be retried. return stats, CommitError{} } @@ -819,8 +1114,17 @@ func (s *stm) Commit() error { // Rollback resets the STM. This is useful for uncommitted transaction rollback // and also used in the STM main loop to reset state if commit fails. func (s *stm) Rollback() { - s.rset = make(map[string]stmGet) - s.wset = make(map[string]stmPut) - s.getOpts = nil - s.revision = math.MaxInt64 - 1 + s.rollback(true) +} + +// rollback will reset the read and write sets. If clearReadSet is false we'll +// only reset the the write set. +func (s *stm) rollback(clearReadSet bool) { + if clearReadSet { + s.rset.clear() + s.revision = math.MaxInt64 - 1 + s.getOpts = nil + } + + s.wset = make(map[string]stmPut) } diff --git a/kvdb/etcd/stm_test.go b/kvdb/etcd/stm_test.go index e4a3810c..5f311c8f 100644 --- a/kvdb/etcd/stm_test.go +++ b/kvdb/etcd/stm_test.go @@ -39,8 +39,9 @@ func TestPutToEmpty(t *testing.T) { return nil } - err = RunSTM(db.cli, apply, txQueue) + callCount, err := RunSTM(db.cli, apply, txQueue) require.NoError(t, err) + require.Equal(t, 1, callCount) require.Equal(t, "abc", f.Get("123")) } @@ -66,6 +67,9 @@ func TestGetPutDel(t *testing.T) { {"e", "5"}, } + // Extra 2 => Get(x), Commit() + expectedCallCount := len(testKeyValues) + 2 + for _, kv := range testKeyValues { f.Put(kv.key, kv.val) } @@ -79,11 +83,12 @@ func TestGetPutDel(t *testing.T) { require.NoError(t, err) require.Nil(t, v) + // Fetches: 1. v, err = stm.Get("x") require.NoError(t, err) require.Nil(t, v) - // Get all existing keys. + // Get all existing keys. Fetches: len(testKeyValues) for _, kv := range testKeyValues { v, err = stm.Get(kv.key) require.NoError(t, err) @@ -120,8 +125,9 @@ func TestGetPutDel(t *testing.T) { return nil } - err = RunSTM(db.cli, apply, txQueue) + callCount, err := RunSTM(db.cli, apply, txQueue) require.NoError(t, err) + require.Equal(t, expectedCallCount, callCount) require.Equal(t, "1", f.Get("a")) require.Equal(t, "2", f.Get("b")) @@ -134,6 +140,17 @@ func TestGetPutDel(t *testing.T) { func TestFirstLastNextPrev(t *testing.T) { t.Parallel() + testFirstLastNextPrev(t, nil, nil, 41) + testFirstLastNextPrev(t, nil, []string{"k"}, 4) + testFirstLastNextPrev(t, nil, []string{"k", "w"}, 2) + testFirstLastNextPrev(t, []string{"kb"}, nil, 42) + testFirstLastNextPrev(t, []string{"kb", "ke"}, nil, 42) + testFirstLastNextPrev(t, []string{"kb", "ke", "w"}, []string{"k", "w"}, 2) +} + +func testFirstLastNextPrev(t *testing.T, prefetchKeys []string, + prefetchRange []string, expectedCallCount int) { + f := NewEtcdTestFixture(t) ctx, cancel := context.WithCancel(context.Background()) @@ -159,6 +176,8 @@ func TestFirstLastNextPrev(t *testing.T) { require.NoError(t, err) apply := func(stm STM) error { + stm.Prefetch(prefetchKeys, prefetchRange) + // First/Last on valid multi item interval. kv, err := stm.First("k") require.NoError(t, err) @@ -177,11 +196,25 @@ func TestFirstLastNextPrev(t *testing.T) { require.NoError(t, err) require.Equal(t, &KV{"w", "w"}, kv) + // Non existing. + val, err := stm.Get("ke1") + require.Nil(t, val) + require.Nil(t, err) + + val, err = stm.Get("ke2") + require.Nil(t, val) + require.Nil(t, err) + // Next/Prev on start/end. kv, err = stm.Next("k", "ke") require.NoError(t, err) require.Nil(t, kv) + // Non existing. + val, err = stm.Get("ka") + require.Nil(t, val) + require.Nil(t, err) + kv, err = stm.Prev("k", "kb") require.NoError(t, err) require.Nil(t, kv) @@ -277,8 +310,9 @@ func TestFirstLastNextPrev(t *testing.T) { return nil } - err = RunSTM(db.cli, apply, txQueue) + callCount, err := RunSTM(db.cli, apply, txQueue) require.NoError(t, err) + require.Equal(t, expectedCallCount, callCount) require.Equal(t, "0", f.Get("ka")) require.Equal(t, "2", f.Get("kc")) @@ -330,9 +364,11 @@ func TestCommitError(t *testing.T) { return nil } - err = RunSTM(db.cli, apply, txQueue) + callCount, err := RunSTM(db.cli, apply, txQueue) require.NoError(t, err) require.Equal(t, 2, cnt) + // Get() + 2 * Commit(). + require.Equal(t, 3, callCount) require.Equal(t, "abc", f.Get("123")) } diff --git a/kvdb/go.mod b/kvdb/go.mod index c5015214..ab004881 100644 --- a/kvdb/go.mod +++ b/kvdb/go.mod @@ -3,9 +3,11 @@ module github.com/lightningnetwork/lnd/kvdb require ( github.com/btcsuite/btclog v0.0.0-20170628155309-84c8d2346e9f github.com/btcsuite/btcwallet/walletdb v1.3.6-0.20210803004036-eebed51155ec + github.com/google/btree v1.0.1 github.com/lightningnetwork/lnd/healthcheck v1.0.0 github.com/stretchr/testify v1.7.0 go.etcd.io/bbolt v1.3.6 + go.etcd.io/etcd/api/v3 v3.5.0 go.etcd.io/etcd/client/pkg/v3 v3.5.0 go.etcd.io/etcd/client/v3 v3.5.0 go.etcd.io/etcd/server/v3 v3.5.0