diff --git a/cache.go b/cache.go index 3457d7c..4050281 100644 --- a/cache.go +++ b/cache.go @@ -1,6 +1,7 @@ package main import ( + "encoding/json" "time" "github.com/dgraph-io/badger" @@ -18,13 +19,38 @@ type Cache struct { expiringKeys map[string]time.Time } -func (c *Cache) initialize() func() error { +func (c *Cache) initialize() func() { db, err := badger.Open(badger.DefaultOptions("/tmp/njump-cache")) if err != nil { log.Fatal().Err(err).Msg("failed to open badger at /tmp/njump-cache") } c.DB = db + // load expiringKeys + err = c.DB.View(func(txn *badger.Txn) error { + j, err := txn.Get([]byte("_expirations")) + if err != nil { + return err + } + + expirations := make(map[string]int64) + err = j.Value(func(val []byte) error { + return json.Unmarshal(val, &expirations) + }) + if err != nil { + return err + } + + for key, iwhen := range expirations { + c.expiringKeys[key] = time.Unix(iwhen, 0) + } + + return nil + }) + if err != nil && err != badger.ErrKeyNotFound { + panic(err) + } + go func() { // key expiration routine endOfTime := time.Unix(9999999999, 0) @@ -54,14 +80,29 @@ func (c *Cache) initialize() func() error { return nil }) if err != nil { - panic(err) + log.Fatal().Err(err).Msg("") } case <-c.refreshTimers: } } }() - return db.Close + // this is to be executed when the program ends + return func() { + // persist expiration times + expirations := make(map[string]int64, len(c.expiringKeys)) + for key, when := range c.expiringKeys { + expirations[key] = when.Unix() + } + j, _ := json.Marshal(expirations) + err := c.DB.Update(func(txn *badger.Txn) error { + return txn.Set([]byte("_expirations"), j) + }) + if err != nil { + panic(err) + } + db.Close() + } } func (c *Cache) Get(key string) ([]byte, bool) { @@ -80,7 +121,7 @@ func (c *Cache) Get(key string) ([]byte, bool) { return nil, false } if err != nil { - panic(err) + log.Fatal().Err(err).Msg("") } return val, true @@ -91,7 +132,7 @@ func (c *Cache) Set(key string, value []byte) { return txn.Set([]byte(key), value) }) if err != nil { - panic(err) + log.Fatal().Err(err).Msg("") } } @@ -100,7 +141,7 @@ func (c *Cache) SetWithTTL(key string, value []byte, ttl time.Duration) { return txn.Set([]byte(key), value) }) if err != nil { - panic(err) + log.Fatal().Err(err).Msg("") } c.expiringKeys[key] = time.Now().Add(ttl) c.refreshTimers <- struct{}{}