Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: new cache implementation #82

Merged
merged 7 commits into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 99 additions & 0 deletions cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"regexp"
"strings"
"sync"
"sync/atomic"
"time"

"github.com/miekg/dns"
Expand Down Expand Up @@ -57,6 +58,104 @@ type Cache interface {
Exists(key string) bool
Remove(key string)
Length() int
Full() bool
}

type lengCache struct {
backend sync.Map // of string -> *Mesg
size atomic.Int64
full bool
maxSize int64
}

func NewCache(maxSize int64) Cache {
return &lengCache{
backend: sync.Map{},
size: atomic.Int64{},
maxSize: maxSize,
}
}

func (c *lengCache) Get(key string) (Msg *dns.Msg, blocked bool, err error) {
key = strings.ToLower(key)

//Truncate time to the second, so that subsecond queries won't keep moving
//forward the last update time without touching the TTL
now := WallClock.Now().Truncate(time.Second)
expired := false
existing, ok := c.backend.Load(key)
mesg := existing.(*Mesg)
if ok && mesg.Msg == nil {
ok = false
logger.Warningf("Cache: key %s returned nil entry", key)
c.Remove(key)
}

if ok {
elapsed := uint32(now.Sub(mesg.LastUpdateTime).Seconds())
for _, answer := range mesg.Msg.Answer {
if elapsed > answer.Header().Ttl {
logger.Debugf("Cache: Key expired %s", key)
c.Remove(key)
expired = true
}
answer.Header().Ttl -= elapsed
}
}

if !ok {
logger.Debugf("Cache: Cannot find key %s\n", key)
return nil, false, KeyNotFound{key}
}

if expired {
return nil, false, KeyExpired{key}
}

mesg.LastUpdateTime = now

return mesg.Msg, mesg.Blocked, nil
}

func (c *lengCache) Set(key string, msg *dns.Msg, blocked bool) error {
key = strings.ToLower(key)

if c.Full() && !c.Exists(key) {
return CacheIsFull{}
}
if msg == nil {
logger.Debugf("Setting an empty value for key %s", key)
}
c.backend.Store(key, &Mesg{msg, blocked, WallClock.Now().Truncate(time.Second)})
return nil
}

func (c *lengCache) Exists(key string) bool {
_, ok := c.backend.Load(key)
return ok
}

func (c *lengCache) Remove(key string) {
_, loaded := c.backend.LoadAndDelete(key)
if loaded {
newSize := c.size.Add(-1)
if newSize < c.maxSize {
c.full = false
}
}
}

func (c *lengCache) Length() int {
size := c.size.Load()
c.full = size > c.maxSize
return int(size)
}

func (c *lengCache) Full() bool {
if c.maxSize > 0 {
return c.full
}
return false
}

// MemoryCache type
Expand Down
13 changes: 9 additions & 4 deletions cache_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package main

import (
"errors"
"fmt"
"net"
"regexp"
"strings"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -250,8 +252,8 @@ func TestCacheTtlFrequentPolling(t *testing.T) {

}

/*
func TestExpirationRace(t *testing.T) {
t.Skip()
cache := makeCache()
fakeClock := clockwork.NewFakeClock()
WallClock = fakeClock
Expand Down Expand Up @@ -279,22 +281,25 @@ func TestExpirationRace(t *testing.T) {
}

for i := 0; i < 1000; i++ {
wg := &sync.WaitGroup{}
wg.Add(2)
fakeClock.Advance(time.Duration(100) * time.Millisecond)
go func() {
_, _, err := cache.Get(testDomain)
if err != nil {
if err != nil && !errors.Is(err, &KeyNotFound{}) {
t.Error(err)
}
wg.Done()
}()
go func() {
err := cache.Set(testDomain, m, true)
if err != nil {
t.Error(err)
}
wg.Done()
}()
}
}
*/

func BenchmarkSetCache(b *testing.B) {
cache := makeCache()
Expand All @@ -311,7 +316,7 @@ func BenchmarkSetCache(b *testing.B) {
}
}

func BenchmarkGetCache(b *testing.B) {
func BenchmarkGetCacheSingleDomain(b *testing.B) {
cache := makeCache()

m := new(dns.Msg)
Expand Down
147 changes: 147 additions & 0 deletions lcache/cache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
package lcache

import (
"github.com/jonboulle/clockwork"
"github.com/miekg/dns"
"github.com/op/go-logging"
"math"
"strings"
"sync"
"sync/atomic"
"time"
)

var logger = logging.MustGetLogger("test")

// wallClock is the wall clock
var wallClock = clockwork.NewRealClock()

// entry represents a cache entry
type entry struct {
Msg *dns.Msg
Blocked bool
expiresAt time.Time
mu sync.Mutex
}

// Cache interface
type Cache interface {
Get(key string) (Msg *dns.Msg, blocked bool, err error)
Set(key string, Msg *dns.Msg, blocked bool) error
Exists(key string) bool
Remove(key string)
Length() int
Full() bool
}

type lengCache struct {
backend sync.Map // of string -> *entry
size atomic.Int64
full bool
maxSize int64
}

func New(maxSize int64) Cache {
return &lengCache{
backend: sync.Map{},
size: atomic.Int64{},
maxSize: maxSize,
}
}

func (c *lengCache) Get(key string) (Msg *dns.Msg, blocked bool, err error) {
key = strings.ToLower(key)

existing, ok := c.backend.Load(key)
if !ok {
logger.Debugf("Cache: Cannot find key %s\n", key)
return nil, false, KeyNotFound{key}
}
mesg := existing.(*entry)
if mesg.Msg == nil {
return nil, mesg.Blocked, nil
}
mesg.mu.Lock()
defer mesg.mu.Unlock()
now := wallClock.Now()

// entry expired!
if now.After(mesg.expiresAt) {
c.Remove(key)
return nil, false, KeyExpired{key}
}
newTtl := uint32(mesg.expiresAt.Sub(now).Truncate(time.Second).Seconds())

for _, answer := range mesg.Msg.Answer {
// this can happen concurrently (and it is a concurrent write of shared memory),
// but it's ok because two concurrent modifications usually have the same result
// when rounded to the second
answer.Header().Ttl = newTtl
}

return mesg.Msg, mesg.Blocked, nil
}

func minTtlFor(msg *dns.Msg) time.Duration {
if msg == nil {
return 0
}
// find smallest ttl
minTtl := uint32(math.MaxUint32)
for _, answer := range msg.Answer {
msgTtl := answer.Header().Ttl
if minTtl > msgTtl {
minTtl = msgTtl
}
}
return time.Duration(minTtl) * time.Second
}

func (c *lengCache) Set(key string, msg *dns.Msg, blocked bool) error {
key = strings.ToLower(key)

if c.Full() && !c.Exists(key) {
return CacheIsFull{}
}
if msg == nil {
logger.Debugf("Setting an empty value for key %s", key)
}

now := wallClock.Now()
e := entry{
Msg: msg,
Blocked: blocked,
expiresAt: now.Add(minTtlFor(msg)),
}
c.backend.Store(key, &e)
return nil
}

func (c *lengCache) Exists(key string) bool {
key = strings.ToLower(key)
_, ok := c.backend.Load(key)
return ok
}

func (c *lengCache) Remove(key string) {
_, loaded := c.backend.LoadAndDelete(key)
if loaded {
newSize := c.size.Add(-1)
if newSize < c.maxSize {
c.full = false
}
}
}

func (c *lengCache) Length() int {
size := c.size.Load()
c.full = size > c.maxSize
return int(size)
}

func (c *lengCache) Full() bool {
if c.maxSize > 0 {
return c.full
}
return false
}
Loading
Loading