diff --git a/handler.go b/handler.go index 83665cf..377ca76 100644 --- a/handler.go +++ b/handler.go @@ -2,6 +2,7 @@ package main import ( "github.com/cottand/leng/internal/metric" + "github.com/cottand/leng/lcache" "github.com/miekg/dns" "net" "slices" @@ -30,9 +31,9 @@ func (q *Question) String() string { type EventLoop struct { requestChannel chan DNSOperationData resolver *Resolver - cache Cache + cache lcache.Cache[lcache.DefaultEntry] // negCache caches failures - negCache Cache + negCache lcache.Cache[lcache.DefaultEntry] active bool muActive sync.RWMutex config *Config @@ -52,17 +53,12 @@ func NewEventLoop(config *Config, blockCache *MemoryBlockCache) *EventLoop { var ( clientConfig *dns.ClientConfig resolver *Resolver - cache Cache - negCache Cache ) resolver = &Resolver{clientConfig} - //cache = lcache.NewDefault(int64(config.Upstream.Maxcount)) - negCache = &MemoryCache{ - Backend: make(map[string]*Mesg), - Maxcount: config.Upstream.Maxcount, - } + cache := lcache.NewDefault(config.Upstream.Maxcount) + negCache := lcache.NewDefault(config.Upstream.Maxcount) handler := &EventLoop{ requestChannel: make(chan DNSOperationData), @@ -91,7 +87,7 @@ func (h *EventLoop) do() { } // responseFor has side-effects, like writing to h's caches, so avoid calling it concurrently -func (h *EventLoop) responseFor(Net string, req *dns.Msg, _local net.Addr, _remote net.Addr) (_ *dns.Msg, success bool) { +func (h *EventLoop) responseFor(Net string, req *dns.Msg, _local net.Addr, _remote net.Addr) (resp *dns.Msg, success bool, blocked bool, cached bool) { var remote net.IP if Net == "tcp" || Net == "http" { @@ -102,118 +98,76 @@ func (h *EventLoop) responseFor(Net string, req *dns.Msg, _local net.Addr, _remo // first of all, check custom DNS. No need to cache it because it is already in-mem and precedes the blocking if custom := h.customDns.Resolve(req, _local, _remote); custom != nil { - return custom, true + return custom, true, false, true } + // does not include custom DNS + defer metric.ReportDNSRespond(remote, resp, blocked, cached) + q := req.Question[0] Q := Question{UnFqdn(q.Name), dns.TypeToString[q.Qtype], dns.ClassToString[q.Qclass]} logger.Infof("%s lookup %s\n", remote, Q.String()) IPQuery := h.isIPQuery(q) + blocked = IPQuery > 0 && lengActive && h.blockCache.Exists(Q.Qname) + if blocked { + resp = h.blockedResponseFor(req, IPQuery) + + logger.Noticef("%s found in blocklist\n", Q.Qname) + return resp, true, blocked, false + } + // Only query cache when qtype == 'A'|'AAAA' , qclass == 'IN' key := KeyGen(Q) if IPQuery > 0 { - mesg, blocked, err := h.cache.Get(key) + mesg, err := h.cache.Get(key) if err != nil { - if mesg, blocked, err = h.negCache.Get(key); err != nil { + if mesg, err = h.negCache.Get(key); err != nil { logger.Debugf("%s didn't hit cache\n", Q.String()) } else { logger.Debugf("%s hit negative cache\n", Q.String()) - return nil, false + return nil, false, true, false } } else { - if blocked && !lengActive { - logger.Debugf("%s hit cache and was blocked: forwarding request\n", Q.String()) - } else { - logger.Debugf("%s hit cache\n", Q.String()) + cached = true + logger.Debugf("%s hit cache\n", Q.String()) - // we need this copy against concurrent modification of ID - msg := *mesg - msg.Id = req.Id + // we need this copy against concurrent modification of ID + msg := *mesg + msg.Id = req.Id - defer metric.ReportDNSRespond(remote, &msg, blocked, true) - return &msg, true - } + return &msg.Msg, true, blocked, cached } } - // Check blocklist - var blacklisted = false - - if IPQuery > 0 { - blacklisted = h.blockCache.Exists(Q.Qname) + cached = false - if lengActive && blacklisted { - m := new(dns.Msg) - m.SetReply(req) - - if h.config.Blocking.NXDomain { - m.SetRcode(req, dns.RcodeNameError) - } else { - nullroute := net.ParseIP(h.config.Blocking.Nullroute) - nullroutev6 := net.ParseIP(h.config.Blocking.Nullroutev6) - - switch IPQuery { - case _IP4Query: - rrHeader := dns.RR_Header{ - Name: q.Name, - Rrtype: dns.TypeA, - Class: dns.ClassINET, - Ttl: h.config.TTL, - } - a := &dns.A{Hdr: rrHeader, A: nullroute} - m.Answer = append(m.Answer, a) - case _IP6Query: - rrHeader := dns.RR_Header{ - Name: q.Name, - Rrtype: dns.TypeAAAA, - Class: dns.ClassINET, - Ttl: h.config.TTL, - } - a := &dns.AAAA{Hdr: rrHeader, AAAA: nullroutev6} - m.Answer = append(m.Answer, a) - } - } - - defer metric.ReportDNSRespond(remote, m, true, false) - - logger.Noticef("%s found in blocklist\n", Q.Qname) - - // cache the block; we don't know the true TTL for blocked entries: we just enforce our config - err := h.cache.Set(key, m, true) - if err != nil { - logger.Errorf("Set %s block cache failed: %s\n", Q.String(), err.Error()) - } - - return m, true - } - logger.Debugf("%s not found in blocklist\n", Q.Qname) - } - - mesg, err := h.resolver.Lookup(Net, req, h.config.Timeout, h.config.Interval, h.config.Upstream.Nameservers, h.config.Upstream.DoH) + resp, err := h.resolver.Lookup(Net, req, h.config.Timeout, h.config.Interval, h.config.Upstream.Nameservers, h.config.Upstream.DoH) if err != nil { logger.Errorf("resolve query error %s\n", err) // cache the failure, too! - if err = h.negCache.Set(key, nil, false); err != nil { + // TODO set TTL for failed errors + if err = h.negCache.Set(key, &lcache.DefaultEntry{}); err != nil { logger.Errorf("set %s negative cache failed: %v\n", Q.String(), err) } - return nil, false + return nil, false, blocked, cached } // if we were doing DNS over UDP, and we got a truncated response, // we retry in TCP in hopes that we do not get a truncated one again. - if mesg.Truncated && Net == "udp" { - mesg, err = h.resolver.Lookup("tcp", req, h.config.Timeout, h.config.Interval, h.config.Upstream.Nameservers, h.config.Upstream.DoH) + if resp.Truncated && Net == "udp" { + resp, err = h.resolver.Lookup("tcp", req, h.config.Timeout, h.config.Interval, h.config.Upstream.Nameservers, h.config.Upstream.DoH) if err != nil { logger.Errorf("resolve tcp query error %s\n", err) // cache the failure, too! - if err = h.negCache.Set(key, nil, false); err != nil { + // TODO set TTL for failed errors + if err = h.negCache.Set(key, &lcache.DefaultEntry{}); err != nil { logger.Errorf("set %s negative cache failed: %v\n", Q.String(), err) } - return nil, false + return nil, false, blocked, cached } } @@ -221,30 +175,27 @@ func (h *EventLoop) responseFor(Net string, req *dns.Msg, _local net.Addr, _remo ttl := h.config.Upstream.Expire var candidateTTL uint32 - for index, answer := range mesg.Answer { + for index, answer := range resp.Answer { logger.Debugf("Answer %d - %s\n", index, answer.String()) candidateTTL = answer.Header().Ttl + // TODO is a zero TTL a forever TTL?? if candidateTTL > 0 && candidateTTL < ttl { ttl = candidateTTL } } - defer metric.ReportDNSRespond(remote, mesg, false, false) - - if IPQuery > 0 && len(mesg.Answer) > 0 { - if !lengActive && blacklisted { - logger.Debugf("%s is blacklisted and leng not active: not caching\n", Q.String()) - } else { - err = h.cache.Set(key, mesg, false) + if IPQuery > 0 && len(resp.Answer) > 0 { + go func() { + err := h.cache.Set(key, &lcache.DefaultEntry{Msg: *resp}) if err != nil { - logger.Errorf("set %s cache failed: %s\n", Q.String(), err.Error()) + logger.Warningf("set %s cache failed: %v\n", Q.String(), err) } logger.Debugf("insert %s into cache with ttl %d\n", Q.String(), ttl) - } + }() } - return mesg, true + return resp, true, blocked, cached } func (h *EventLoop) doRequest(Net string, w dns.ResponseWriter, req *dns.Msg) { @@ -252,7 +203,7 @@ func (h *EventLoop) doRequest(Net string, w dns.ResponseWriter, req *dns.Msg) { _ = w.Close() }(w) - resp, ok := h.responseFor(Net, req, w.LocalAddr(), w.RemoteAddr()) + resp, ok, _, _ := h.responseFor(Net, req, w.LocalAddr(), w.RemoteAddr()) if !ok { m := new(dns.Msg) @@ -272,7 +223,7 @@ func (h *EventLoop) doRequest(Net string, w dns.ResponseWriter, req *dns.Msg) { for _, cname := range cnames { r := dns.Msg{} r.SetQuestion(cname.Target, req.Question[0].Qtype) - followed, ok := h.responseFor(Net, &r, w.LocalAddr(), w.RemoteAddr()) + followed, ok, _, _ := h.responseFor(Net, &r, w.LocalAddr(), w.RemoteAddr()) for _, fAnswer := range followed.Answer { containsNewAnswer := func(rr dns.RR) bool { return rr.String() == fAnswer.String() @@ -372,6 +323,40 @@ func (h *EventLoop) isIPQuery(q dns.Question) int { return notIPQuery } } +func (h *EventLoop) blockedResponseFor(req *dns.Msg, IPQuery int) *dns.Msg { + m := new(dns.Msg) + m.SetReply(req) + q := req.Question[0] + + if h.config.Blocking.NXDomain { + m.SetRcode(req, dns.RcodeNameError) + } else { + nullroute := net.ParseIP(h.config.Blocking.Nullroute) + nullroutev6 := net.ParseIP(h.config.Blocking.Nullroutev6) + + switch IPQuery { + case _IP4Query: + rrHeader := dns.RR_Header{ + Name: q.Name, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: h.config.TTL, + } + a := &dns.A{Hdr: rrHeader, A: nullroute} + m.Answer = append(m.Answer, a) + case _IP6Query: + rrHeader := dns.RR_Header{ + Name: q.Name, + Rrtype: dns.TypeAAAA, + Class: dns.ClassINET, + Ttl: h.config.TTL, + } + a := &dns.AAAA{Hdr: rrHeader, AAAA: nullroutev6} + m.Answer = append(m.Answer, a) + } + } + return m +} // UnFqdn function func UnFqdn(s string) string { diff --git a/lcache/cache.go b/lcache/cache.go index 9f6bce3..11f5900 100644 --- a/lcache/cache.go +++ b/lcache/cache.go @@ -45,11 +45,11 @@ type lengCache[E Entry] struct { // NewGeneric creates a new Cache // maxSize <= 0 means the cache is unbounded -func NewGeneric[E Entry](maxSize int64) Cache[E] { +func NewGeneric[E Entry](maxSize int) Cache[E] { return &lengCache[E]{ backend: sync.Map{}, size: atomic.Int64{}, - maxSize: maxSize, + maxSize: int64(maxSize), } } diff --git a/lcache/default.go b/lcache/default.go index dac5093..fc5d204 100644 --- a/lcache/default.go +++ b/lcache/default.go @@ -13,6 +13,6 @@ func (dnsEntry DefaultEntry) RRs() []dns.RR { return dnsEntry.Answer } -func NewDefault(maxSize int64) Cache[DefaultEntry] { +func NewDefault(maxSize int) Cache[DefaultEntry] { return NewGeneric[DefaultEntry](maxSize) }