diff --git a/cache/cache.go b/cache/cache.go
index 3e68f733..beb6c583 100644
--- a/cache/cache.go
+++ b/cache/cache.go
@@ -17,7 +17,13 @@ limitations under the License.
 package cache
 
 import (
+	"bufio"
+	"encoding/binary"
+	"encoding/json"
+	"errors"
 	"fmt"
+	"io"
+	"os"
 	"slices"
 	"sort"
 	"sync"
@@ -57,20 +63,24 @@ type cache[T any] struct {
 	index map[string]*item[T]
 	// items is the store of elements in the cache.
 	items []*item[T]
+
+	// capacity is the maximum number of items the cache can hold.
+	capacity     int
+	metrics      *cacheMetrics
+	labelsFunc   GetLvsFunc[T]
+	janitor      *janitor[T]
+	snapshotPath string
+	buf          buffer
 	// sorted indicates whether the items are sorted by expiration time.
 	// It is initially true, and set to false when the items are not sorted.
 	sorted bool
-	// capacity is the maximum number of index the cache can hold.
-	capacity   int
-	metrics    *cacheMetrics
-	labelsFunc GetLvsFunc[T]
-	janitor    *janitor[T]
-	closed     bool
+	closed bool
 
 	mu sync.RWMutex
 }
 
 var _ Expirable[any] = &Cache[any]{}
+var _ Persistable = &Cache[any]{}
 
 // New creates a new cache with the given configuration.
 func New[T any](capacity int, keyFunc KeyFunc[T], opts ...Options[T]) (*Cache[T], error) {
@@ -80,11 +90,12 @@ func New[T any](capacity int, keyFunc KeyFunc[T], opts ...Options[T]
 	}
 
 	c := &cache[T]{
-		index:      make(map[string]*item[T]),
-		items:      make([]*item[T], 0, capacity),
-		sorted:     true,
-		capacity:   capacity,
-		labelsFunc: opt.labelsFunc,
+		index:        make(map[string]*item[T]),
+		items:        make([]*item[T], 0, capacity),
+		sorted:       true,
+		capacity:     capacity,
+		snapshotPath: opt.snapshotPath,
+		labelsFunc:   opt.labelsFunc,
 		janitor: &janitor[T]{
 			interval: opt.interval,
 			stop:     make(chan bool),
@@ -97,6 +108,16 @@ func New[T any](capacity int, keyFunc KeyFunc[T], opts ...Options[T]
 
 	C := &Cache[T]{cache: c, keyFunc: keyFunc}
 
+	if c.snapshotPath != "" {
+		// load the cache from the file if it exists
+		if _, err := os.Stat(c.snapshotPath); err == nil {
+			err = c.load()
+			if err != nil {
+				return nil, err
+			}
+		}
+	}
+
 	if opt.interval > 0 {
 		go c.janitor.run(c)
 	}
@@ -498,3 +519,197 @@ func (j *janitor[T]) run(c *cache[T]) {
 		}
 	}
 }
+
+// buffer is a helper type used to write data to a byte slice
+type buffer []byte
+
+// clear clears the buffer
+func (s *buffer) clear() {
+	*s = (*s)[:0]
+}
+
+// writeByteSlice writes a byte slice to the buffer
+func (s *buffer) writeByteSlice(v []byte) {
+	*s = append(*s, v...)
+}
+
+// writeUint64 writes a uint64 to the buffer
+// it is written in little endian format
+func (s *buffer) writeUint64(v uint64) {
+	var buf [8]byte
+	binary.LittleEndian.PutUint64(buf[:], v)
+	*s = append(*s, buf[:]...)
+}
+
+// writeBuf writes the buffer to the file and syncs it to disk
+func (c *cache[T]) writeBuf(file *os.File) error {
+	if _, err := file.Write(c.buf); err != nil {
+		return err
+	}
+	// sync the file to disk straight away and surface any error
+	return file.Sync()
+}
+
+// Persist writes the cache to disk.
+// The cache is written to a temporary file first
+// and then renamed to the final file name to atomically
+// update the cache file. This is done to avoid corrupting
+// the cache file in case of a crash while writing to the file. If a file
+// with the same name exists, it is overwritten.
+// The cache file is written in the following format:
+// key length, key, expiration, data length, data // repeat for each item
+// The key length and data length are written as uint64 in little endian format
+// The expiration is written as a Unix timestamp in nanoseconds as uint64 in little endian format
+// The key is written as a byte slice
+// The data is written as a json encoded byte slice
+func (c *cache[T]) Persist() error {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	if err := c.writeToBuf(); err != nil {
+		return err
+	}
+
+	// create a new temp file
+	newFile, err := os.Create(fmt.Sprintf("%s.tmp", c.snapshotPath))
+	if err != nil {
+		return fmt.Errorf("failed to create temp file: %w", err)
+	}
+
+	if err := c.writeBuf(newFile); err != nil {
+		errf := os.Remove(fmt.Sprintf("%s.tmp", c.snapshotPath))
+		return errors.Join(err, errf)
+	}
+
+	// close the file
+	if err := newFile.Close(); err != nil {
+		errf := os.Remove(fmt.Sprintf("%s.tmp", c.snapshotPath))
+		return errors.Join(err, errf)
+	}
+
+	if err := os.Rename(fmt.Sprintf("%s.tmp", c.snapshotPath), c.snapshotPath); err != nil {
+		return fmt.Errorf("failed to rename file: %w", err)
+	}
+
+	return nil
+}
+
+// writeToBuf writes the cache items to the internal buffer.
+// No locks are taken; the caller should ensure that the cache
+// is not being modified while this function is called.
+func (c *cache[T]) writeToBuf() error {
+	c.buf.clear()
+	for _, item := range c.items {
+		data, err := json.Marshal(item.object)
+		if err != nil {
+			return err
+		}
+
+		// write the key, expiration and data to the buffer
+		// format: key length, key, expiration, data length, data
+		// doing it this way gives us the ability to read the file
+		// without having to read the entire file into memory. This is
+		// done for possible future use cases e.g. where the cache file
+		// could be very large or for range queries.
+		c.buf.writeUint64(uint64(len(item.key)))
+		c.buf.writeByteSlice([]byte(item.key))
+		// we write the expiration time in nanoseconds as uint64
+		// instead of using item.expiresAt.MarshalBinary() because we are only
+		// interested in the nanosecond-precision Unix time,
+		// everything else can be discarded.
+		c.buf.writeUint64(uint64(item.expiresAt.UnixNano()))
+		c.buf.writeUint64(uint64(len(data)))
+		c.buf.writeByteSlice(data)
+	}
+	return nil
+}
+
+// load reads the cache from disk.
+// The cache file is read in the following format:
+// key length, key, expiration, data length, data // repeat for each item
+// This function cannot be called concurrently, and should be called
+// before the cache is used.
+func (c *cache[T]) load() error {
+	file, err := os.Open(c.snapshotPath)
+	if err != nil {
+		return err
+	}
+	defer file.Close()
+
+	rd := bufio.NewReader(file)
+	items, err := c.readFrom(rd)
+	if err != nil {
+		return err
+	}
+
+	for _, item := range items {
+		if len(c.items) >= c.capacity {
+			break
+		}
+		c.items = append(c.items, item)
+		c.index[item.key] = item
+	}
+
+	if len(c.items) > 0 {
+		c.metrics.setCachedItems(float64(len(c.items)))
+		c.sorted = false
+	}
+	return nil
+}
+
+func (c *cache[T]) readFrom(rd io.Reader) ([]*item[T], error) {
+	items := make([]*item[T], 0)
+	for {
+		// read until EOF
+		item, err := c.readItem(rd)
+		if err != nil {
+			if err == io.EOF {
+				break
+			}
+			return nil, err
+		}
+		items = append(items, item)
+	}
+	return items, nil
+}
+
+func (c *cache[T]) readItem(rd io.Reader) (*item[T], error) {
+	var (
+		buf  = make([]byte, 8)
+		item item[T]
+	)
+	if _, err := io.ReadFull(rd, buf); err != nil {
+		// an io.EOF here means there are no more items to read
+		return nil, err
+	}
+	keyLen := binary.LittleEndian.Uint64(buf)
+	key := make([]byte, keyLen)
+	if _, err := io.ReadFull(rd, key); err != nil {
+		return nil, err
+	}
+	item.key = string(key)
+
+	if _, err := io.ReadFull(rd, buf); err != nil {
+		return nil, err
+	}
+	// the expiration was written as a Unix timestamp in nanoseconds
+	item.expiresAt = time.Unix(0, int64(binary.LittleEndian.Uint64(buf)))
+
+	if _, err := io.ReadFull(rd, buf); err != nil {
+		return nil, err
+	}
+	dataLen := binary.LittleEndian.Uint64(buf)
+	data := make([]byte, dataLen)
+	if _, err := io.ReadFull(rd, data); err != nil {
+		return nil, err
+	}
+
+	if err := json.Unmarshal(data, &item.object); err != nil {
+		return nil, err
+	}
+
+	return &item, nil
+}
diff --git a/cache/cache_test.go b/cache/cache_test.go
index 62dfe393..4b1e0495 100644
--- a/cache/cache_test.go
+++ b/cache/cache_test.go
@@ -17,6 +17,7 @@ limitations under the License.
 package cache
 
 import (
+	"bytes"
 	"fmt"
 	"math/rand/v2"
 	"sync"
@@ -582,3 +583,128 @@ func createObjectMap(num int) map[int]IdentifiableObject {
 	}
 	return objMap
 }
+
+type nameTag struct {
+	Name string `json:"name"`
+	Tag  string `json:"tag"`
+}
+
+func TestCache_WriteToBuf(t *testing.T) {
+	testCases := []struct {
+		name     string
+		input    []nameTag
+		expected []*item[StoreObject[nameTag]]
+	}{
+		{
+			name:     "empty",
+			input:    []nameTag{},
+			expected: []*item[StoreObject[nameTag]]{},
+		},
+		{
+			name: "single item",
+			input: []nameTag{
+				{
+					Name: "test",
+					Tag:  "latest",
+				},
+			},
+			expected: []*item[StoreObject[nameTag]]{
+				{
+					key: "test",
+					object: StoreObject[nameTag]{
+						Key: "test",
+						Object: nameTag{
+							Name: "test",
+							Tag:  "latest",
+						},
+					},
+				},
+			},
+		},
+		{
+			name: "multiple items",
+			input: []nameTag{
+				{
+					Name: "test",
+					Tag:  "latest",
+				},
+				{
+					Name: "test2",
+					Tag:  "latest",
+				},
+			},
+			expected: []*item[StoreObject[nameTag]]{
+				{
+					key: "test",
+					object: StoreObject[nameTag]{
+						Key: "test",
+						Object: nameTag{
+							Name: "test",
+							Tag:  "latest",
+						},
+					},
+				},
+				{
+					key: "test2",
+					object: StoreObject[nameTag]{
+						Key: "test2",
+						Object: nameTag{
+							Name: "test2",
+							Tag:  "latest",
+						},
+					},
+				},
+			},
+		},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			g := NewWithT(t)
+
+			c, err := New[StoreObject[nameTag]](5, StoreObjectKeyFunc,
+				WithMetricsRegisterer[StoreObject[nameTag]](prometheus.NewPedanticRegistry()),
+				WithCleanupInterval[StoreObject[nameTag]](1*time.Second))
+			g.Expect(err).ToNot(HaveOccurred())
+
+			for _, item := range tc.input {
+				obj := StoreObject[nameTag]{Key: item.Name, Object: item}
+				err = c.Set(obj)
+				g.Expect(err).ToNot(HaveOccurred())
+			}
+
+			err = c.writeToBuf()
+			g.Expect(err).ToNot(HaveOccurred())
+
+			items, err := c.readFrom(bytes.NewReader(c.buf))
+			g.Expect(err).ToNot(HaveOccurred())
+			g.Expect(items).To(HaveLen(len(tc.expected)))
+			for i, item := range items {
+				g.Expect(item.key).To(Equal(tc.expected[i].key))
+				g.Expect(item.object).To(Equal(tc.expected[i].object))
+			}
+		})
+	}
+}
+
+func TestCache_Load(t *testing.T) {
+	path := "./testdata/cache.json"
+	g := NewWithT(t)
+
+	reg := prometheus.NewPedanticRegistry()
+	c, err := New[StoreObject[nameTag]](5, StoreObjectKeyFunc,
+		WithMetricsRegisterer[StoreObject[nameTag]](reg),
+		WithCleanupInterval[StoreObject[nameTag]](1*time.Second),
+		WithSnapshotPath[StoreObject[nameTag]](path))
+	g.Expect(err).ToNot(HaveOccurred())
+
+	g.Expect(c.items).To(HaveLen(1))
+	g.Expect(c.items[0].key).To(Equal("test"))
+
+	validateMetrics(reg, `
+	# HELP gotk_cache_evictions_total Total number of cache evictions.
+	# TYPE gotk_cache_evictions_total counter
+	gotk_cache_evictions_total 0
+	# HELP gotk_cached_items Total number of items in the cache.
+	# TYPE gotk_cached_items gauge
+	gotk_cached_items 1
+`, t)
+}
diff --git a/cache/store.go b/cache/store.go
index ba80e232..250bca59 100644
--- a/cache/store.go
+++ b/cache/store.go
@@ -45,6 +45,8 @@ type Store[T any] interface {
 }
 
 // Expirable is an interface for a cache store that supports expiration.
+// It extends the Store interface.
+// Disk persistence is provided separately by the Persistable interface.
 type Expirable[T any] interface {
 	Store[T]
 	// SetExpiration sets the expiration time for the object.
@@ -55,11 +57,18 @@ type Expirable[T any] interface {
 	HasExpired(object T) (bool, error)
 }
 
+// Persistable is an interface for a cache store that supports disk persistence.
+type Persistable interface {
+	// Persist persists the cache to disk.
+	Persist() error
+}
+
 type storeOptions[T any] struct {
-	interval    time.Duration
-	registerer  prometheus.Registerer
-	extraLabels []string
-	labelsFunc  GetLvsFunc[T]
+	interval     time.Duration
+	registerer   prometheus.Registerer
+	extraLabels  []string
+	labelsFunc   GetLvsFunc[T]
+	snapshotPath string
 }
 
 // Options is a function that sets the store options.
@@ -93,6 +102,13 @@ func WithMetricsRegisterer[T any](r prometheus.Registerer) Options[T] {
 	}
 }
 
+// WithSnapshotPath sets the path of the file used to persist and
+// restore the cache snapshot.
+func WithSnapshotPath[T any](path string) Options[T] {
+	return func(o *storeOptions[T]) error {
+		o.snapshotPath = path
+		return nil
+	}
+}
+
 // KeyFunc knows how to make a key from an object. Implementations should be deterministic.
 type KeyFunc[T any] func(object T) (string, error)
diff --git a/cache/testdata/cache.json b/cache/testdata/cache.json
new file mode 100644
index 00000000..56b26a5e
Binary files /dev/null and b/cache/testdata/cache.json differ
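
Usage sketch (not part of the patch): a minimal consumer of the snapshot support added above. The module import path and the snapshot location are assumptions; New, Set, StoreObject, StoreObjectKeyFunc, WithCleanupInterval and WithMetricsRegisterer are existing package APIs mirrored from the tests, while WithSnapshotPath and Persist are introduced by this change.

// Hypothetical example; the import path and the snapshot path are illustrative only.
package main

import (
	"time"

	"github.com/fluxcd/pkg/cache" // assumed module path for the cache package in this diff
	"github.com/prometheus/client_golang/prometheus"
)

type nameTag struct {
	Name string `json:"name"`
	Tag  string `json:"tag"`
}

func main() {
	// A cache built with WithSnapshotPath restores its items from the snapshot
	// file in New (if the file exists) and can write them back with Persist.
	c, err := cache.New[cache.StoreObject[nameTag]](5, cache.StoreObjectKeyFunc,
		cache.WithMetricsRegisterer[cache.StoreObject[nameTag]](prometheus.NewPedanticRegistry()),
		cache.WithCleanupInterval[cache.StoreObject[nameTag]](1*time.Second),
		cache.WithSnapshotPath[cache.StoreObject[nameTag]]("/tmp/cache.snapshot")) // illustrative path
	if err != nil {
		panic(err)
	}

	// Store an item, then persist the cache: items are serialized as
	// (key length, key, expiration, data length, data) records and the
	// snapshot file is replaced atomically via a .tmp file and a rename.
	obj := cache.StoreObject[nameTag]{Key: "test", Object: nameTag{Name: "test", Tag: "latest"}}
	if err := c.Set(obj); err != nil {
		panic(err)
	}
	if err := c.Persist(); err != nil {
		panic(err)
	}
}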