Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle A and AAAA records TTL properly in dnscache #187

Draft
wants to merge 13 commits into
base: master
Choose a base branch
from
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ vet:

# Generate code
generate: controller-gen mockgen manifests
go generate ./...
$(CONTROLLER_GEN) object paths="./..."
go generate ./...

.PHONY: controller-gen
controller-gen: $(CONTROLLER_GEN)
Expand Down
8 changes: 8 additions & 0 deletions api/v1/clusterwidenetworkpolicy_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,14 @@ type FQDNSelector struct {
MatchPattern string `json:"matchPattern,omitempty"`
}

// IPSet stores set name association to IP addresses
// type IPSet struct {
// FQDN string `json:"fqdn,omitempty"`
// SetName string `json:"setName,omitempty"`
// IPs map[string]metav1.Time `json:"ips,omitempty"`
// Version IPVersion `json:"version,omitempty"`
// }

// IPSet stores set name association to IP addresses
type IPSet struct {
FQDN string `json:"fqdn,omitempty"`
Expand Down
137 changes: 71 additions & 66 deletions pkg/dns/dnscache.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,7 @@ import (
"crypto/md5" //nolint:gosec
"encoding/hex"
"fmt"
"math"
"net"
"regexp"
"sort"
"strings"
"sync"
"time"
Expand All @@ -16,7 +13,7 @@ import (
"github.com/go-logr/logr"
"github.com/google/nftables"
dnsgo "github.com/miekg/dns"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
// metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"

firewallv1 "github.com/metal-stack/firewall-controller/v2/api/v1"
)
Expand All @@ -42,31 +39,23 @@ type RenderIPSet struct {
}

type ipEntry struct {
ips []string
expirationTime time.Time
setName string
// ips is a map of the ip address and its expiration time which is the time of the DNS lookup + the TTL
ips map[string]time.Time
setName string
}

func newIPEntry(setName string, expirationTime time.Time) *ipEntry {
func newIPEntry(setName string) *ipEntry {
return &ipEntry{
expirationTime: expirationTime,
setName: setName,
setName: setName,
ips: map[string]time.Time{},
}
}

func (e *ipEntry) update(setName string, ips []net.IP, expirationTime time.Time, dtype nftables.SetDatatype) error {
newIPs, deletedIPs := e.getNewAndDeletedIPs(ips)
if !e.expirationTime.After(time.Now()) {
e.expirationTime = expirationTime
}
func (e *ipEntry) update(log logr.Logger, setName string, rrs []dnsgo.RR, lookupTime time.Time, dtype nftables.SetDatatype) error {
deletedIPs := e.expireIPs()
newIPs := e.addAndUpdateIPs(log, rrs, lookupTime)

if newIPs != nil || deletedIPs != nil {
e.ips = make([]string, len(ips))
for i, ip := range ips {
e.ips[i] = ip.String()
}
sort.Strings(e.ips)

if err := updateNftSet(newIPs, deletedIPs, setName, dtype); err != nil {
return fmt.Errorf("failed to update nft set: %w", err)
}
Expand All @@ -75,27 +64,32 @@ func (e *ipEntry) update(setName string, ips []net.IP, expirationTime time.Time,
return nil
}

func (e *ipEntry) getNewAndDeletedIPs(ips []net.IP) (newIPs, deletedIPs []nftables.SetElement) {
currentIps := make(map[string]bool, len(e.ips))
for _, ip := range e.ips {
currentIps[ip] = false
}

for _, ip := range ips {
s := ip.String()
if _, ok := currentIps[s]; ok {
currentIps[s] = true
} else {
newIPs = append(newIPs, nftables.SetElement{Key: ip})
func (e *ipEntry) expireIPs() (deletedIPs []nftables.SetElement) {
for ip, expirationTime := range e.ips {
if expirationTime.Before(time.Now()) {
deletedIPs = append(deletedIPs, nftables.SetElement{Key: []byte(ip)})
delete(e.ips, ip)
}
}
return
}

for ip, exists := range currentIps {
if !exists {
deletedIPs = append(deletedIPs, nftables.SetElement{Key: net.ParseIP(ip)})
func (e *ipEntry) addAndUpdateIPs(log logr.Logger, rrs []dnsgo.RR, lookupTime time.Time) (newIPs []nftables.SetElement) {
for _, rr := range rrs {
var s string
switch r := rr.(type) {
case *dnsgo.A:
s = r.A.String()
case *dnsgo.AAAA:
s = r.AAAA.String()
}
}
if _, ok := e.ips[s]; ok {
newIPs = append(newIPs, nftables.SetElement{Key: []byte(s)})
}
log.WithValues("ip", s, "rr header ttl", rr.Header().Ttl, "expiration time", lookupTime.Add(time.Duration(rr.Header().Ttl)*time.Second))
e.ips[s] = lookupTime.Add(time.Duration(rr.Header().Ttl) * time.Second)

}
return
}

Expand Down Expand Up @@ -197,9 +191,17 @@ func (c *DNSCache) restoreSets(fqdnSets []firewallv1.IPSet) {
}

ipe := &ipEntry{
ips: s.IPs,
expirationTime: s.ExpirationTime.Time,
setName: s.SetName,
setName: s.SetName,
}
for _, ip := range s.IPs {
ipa, _, _ := strings.Cut(ip, ",")
expirationTime := time.Now()
if _, ets, found := strings.Cut(ip, ": "); found {
if err := expirationTime.UnmarshalText([]byte(ets)); err != nil {
expirationTime = time.Now()
}
}
ipe.ips[ipa] = expirationTime
}
switch s.Version {
case firewallv1.IPv4:
Expand Down Expand Up @@ -311,10 +313,8 @@ func (c *DNSCache) Update(lookupTime time.Time, qname string, msg *dnsgo.Msg, fq
return true, fmt.Errorf("too many hops, fqdn chain: %s", strings.Join(fqdns, ","))
}

ipv4 := []net.IP{}
ipv6 := []net.IP{}
minIPv4TTL := uint32(math.MaxUint32)
minIPv6TTL := uint32(math.MaxUint32)
ipv4 := []dnsgo.RR{}
ipv6 := []dnsgo.RR{}
found := false

for _, ans := range msg.Answer {
Expand All @@ -326,17 +326,11 @@ func (c *DNSCache) Update(lookupTime time.Time, qname string, msg *dnsgo.Msg, fq

switch rr := ans.(type) {
case *dnsgo.A:
ipv4 = append(ipv4, rr.A)
if minIPv4TTL > rr.Hdr.Ttl {
minIPv4TTL = rr.Hdr.Ttl
}
ipv4 = append(ipv4, rr)
found = true
c.log.V(4).Info("DEBUG dnscache Update function A record found", "IPs", ipv4)
case *dnsgo.AAAA:
ipv6 = append(ipv6, rr.AAAA)
if minIPv6TTL > rr.Hdr.Ttl {
minIPv6TTL = rr.Hdr.Ttl
}
ipv6 = append(ipv6, rr)
found = true
c.log.V(4).Info("DEBUG dnscache Update function AAAA record found", "IPs", ipv6)
case *dnsgo.CNAME:
Expand All @@ -362,12 +356,12 @@ func (c *DNSCache) Update(lookupTime time.Time, qname string, msg *dnsgo.Msg, fq
for _, fqdn := range fqdns {
c.log.V(4).Info("DEBUG dnscache Update function Updating DNS cache for", "fqdn", fqdn, "ipv4", ipv4, "ipv6", ipv6)
if c.ipv4Enabled && len(ipv4) > 0 {
if err := c.updateIPEntry(fqdn, ipv4, lookupTime.Add(time.Duration(minIPv4TTL)), nftables.TypeIPAddr); err != nil {
if err := c.updateIPEntry(fqdn, ipv4, lookupTime, nftables.TypeIPAddr); err != nil {
return false, fmt.Errorf("failed to update IPv4 addresses: %w", err)
}
}
if c.ipv6Enabled && len(ipv6) > 0 {
if err := c.updateIPEntry(fqdn, ipv6, lookupTime.Add(time.Duration(minIPv6TTL)), nftables.TypeIP6Addr); err != nil {
if err := c.updateIPEntry(fqdn, ipv6, lookupTime, nftables.TypeIP6Addr); err != nil {
return false, fmt.Errorf("failed to update IPv6 addresses: %w", err)
}
}
Expand All @@ -376,10 +370,10 @@ func (c *DNSCache) Update(lookupTime time.Time, qname string, msg *dnsgo.Msg, fq
return found, nil
}

func (c *DNSCache) updateIPEntry(qname string, ips []net.IP, expirationTime time.Time, dtype nftables.SetDatatype) error {
func (c *DNSCache) updateIPEntry(qname string, rrs []dnsgo.RR, lookupTime time.Time, dtype nftables.SetDatatype) error {
scopedLog := c.log.WithValues(
"fqdn", qname,
"ip_len", len(ips),
"ip_len", len(rrs),
"dtype", dtype.Name,
)

Expand All @@ -396,21 +390,22 @@ func (c *DNSCache) updateIPEntry(qname string, ips []net.IP, expirationTime time
case nftables.TypeIPAddr:
if entry.ipv4 == nil {
setName := c.createSetName(qname, dtype.Name, 0)
ipe = newIPEntry(setName, expirationTime)
ipe = newIPEntry(setName)
entry.ipv4 = ipe
}
ipe = entry.ipv4
case nftables.TypeIP6Addr:
if entry.ipv6 == nil {
setName := c.createSetName(qname, dtype.Name, 0)
ipe = newIPEntry(setName, expirationTime)
ipe = newIPEntry(setName)
entry.ipv6 = ipe
}
ipe = entry.ipv6
}

setName := ipe.setName
if err := ipe.update(setName, ips, expirationTime, dtype); err != nil {
scopedLog.WithValues("set", setName, "lookupTime", lookupTime, "rrs", rrs).Info("updating ip entry")
if err := ipe.update(scopedLog, setName, rrs, lookupTime, dtype); err != nil {
return fmt.Errorf("failed to update ipEntry: %w", err)
}
c.fqdnToEntry[qname] = entry
Expand Down Expand Up @@ -478,19 +473,29 @@ func updateNftSet(
}

func createIPSetFromIPEntry(fqdn string, version firewallv1.IPVersion, entry *ipEntry) firewallv1.IPSet {
return firewallv1.IPSet{
FQDN: fqdn,
SetName: entry.setName,
IPs: entry.ips,
ExpirationTime: metav1.Time{Time: entry.expirationTime},
Version: version,
ips := firewallv1.IPSet{
FQDN: fqdn,
SetName: entry.setName,
IPs: []string{},
Version: version,
}
for ip, expirationTime := range entry.ips {
if et, err := expirationTime.MarshalText(); err == nil {
ip = ip + ", expiration time: " + string(et)
}
ips.IPs = append(ips.IPs, ip)
}
return ips
}

func createRenderIPSetFromIPEntry(version IPVersion, entry *ipEntry) RenderIPSet {
var ips []string
for ip, _ := range entry.ips {
ips = append(ips, ip)
}
return RenderIPSet{
SetName: entry.setName,
IPs: entry.ips,
IPs: ips,
Version: version,
}
}
11 changes: 8 additions & 3 deletions pkg/nftables/networkpolicy.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ func clusterwideNetworkPolicyEgressRules(
np firewallv1.ClusterwideNetworkPolicy,
logAcceptedConnections bool,
) (rules nftablesRules, updated firewallv1.ClusterwideNetworkPolicy) {
var fqdnState firewallv1.FQDNState
for _, e := range np.Spec.Egress {
tcpPorts, udpPorts := calculatePorts(e.Ports)
ruleBases := []ruleBase{}
Expand All @@ -95,9 +96,9 @@ func clusterwideNetworkPolicyEgressRules(
ruleBases = append(ruleBases, ruleBase{base: rb})
} else if len(e.ToFQDNs) > 0 && cache.IsInitialized() {
// Generate allow rules based on DNS selectors
rbs, u := clusterwideNetworkPolicyEgressToFQDNRules(cache, e)
np.Status.FQDNState = u
rbs, u := clusterwideNetworkPolicyEgressToFQDNRules(cache, fqdnState, e)
ruleBases = append(ruleBases, rbs...)
fqdnState = u
}

comment := fmt.Sprintf("accept traffic for np %s", np.ObjectMeta.Name)
Expand All @@ -111,6 +112,7 @@ func clusterwideNetworkPolicyEgressRules(
}
}

np.Status.FQDNState = fqdnState
return uniqueSorted(rules), np
}

Expand All @@ -125,9 +127,12 @@ func clusterwideNetworkPolicyEgressToRules(e firewallv1.EgressRule) (allow, exce

func clusterwideNetworkPolicyEgressToFQDNRules(
cache FQDNCache,
fqdnState firewallv1.FQDNState,
e firewallv1.EgressRule,
) (rules []ruleBase, updatedState firewallv1.FQDNState) {
fqdnState := firewallv1.FQDNState{}
if fqdnState == nil {
fqdnState = firewallv1.FQDNState{}
}

for _, fqdn := range e.ToFQDNs {
fqdnName := fqdn.MatchName
Expand Down
Loading