From 98eb70e7a1ed0167d71039842854fb296c15f87b Mon Sep 17 00:00:00 2001 From: Johan Stenstam Date: Mon, 27 May 2024 18:11:43 +0200 Subject: [PATCH] foo --- apihandler.go | 59 +++++++++++++++++++++++++------- bootstrap.go | 6 ++-- dnshandler.go | 89 +++++++++++++++++++++++++++++++++++++----------- logging.go | 10 ++++-- main.go | 33 ++++++++++-------- policy.go | 2 ++ reaper.go | 4 +-- refreshengine.go | 42 ++++++++++++++++------- rpz.go | 32 ++++++++--------- sources.go | 8 ++--- xfr.go | 23 +++++++++---- 11 files changed, 212 insertions(+), 96 deletions(-) diff --git a/apihandler.go b/apihandler.go index 41dd71d..9b1416a 100644 --- a/apihandler.go +++ b/apihandler.go @@ -61,7 +61,11 @@ func APIcommand(conf *Config) func(w http.ResponseWriter, r *http.Request) { defer func() { // log.Printf("defer: resp: %v", resp) w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(resp) + err := json.NewEncoder(w).Encode(resp) + if err != nil { + log.Printf("Error from json encoder: %v", err) + log.Printf("resp: %v", resp) + } }() switch cp.Command { @@ -85,15 +89,27 @@ func APIcommand(conf *Config) func(w http.ResponseWriter, r *http.Request) { } case "mqtt-start": - conf.TemData.MqttEngine.StartEngine() + _, _, _, err := conf.TemData.MqttEngine.StartEngine() + if err != nil { + resp.Error = true + resp.ErrorMsg = err.Error() + } resp.Msg = "MQTT engine started" case "mqtt-stop": - conf.TemData.MqttEngine.StopEngine() + _, err := conf.TemData.MqttEngine.StopEngine() + if err != nil { + resp.Error = true + resp.ErrorMsg = err.Error() + } resp.Msg = "MQTT engine stopped" case "mqtt-restart": - conf.TemData.MqttEngine.RestartEngine() + _, err := conf.TemData.MqttEngine.RestartEngine() + if err != nil { + resp.Error = true + resp.ErrorMsg = err.Error() + } resp.Msg = "MQTT engine restarted" case "rpz-add": @@ -508,14 +524,18 @@ func APIdispatcher(conf *Config, done <-chan struct{}) { } tlsServer := &http.Server{ - Addr: tlsaddress, - Handler: router, - TLSConfig: tlsConfig, + Addr: tlsaddress, + Handler: router, + TLSConfig: tlsConfig, + ReadTimeout: 10 * time.Second, + WriteTimeout: 10 * time.Second, } bootstrapTlsServer := &http.Server{ - Addr: bootstraptlsaddress, - Handler: bootstraprouter, - TLSConfig: tlsConfig, + Addr: bootstraptlsaddress, + Handler: bootstraprouter, + TLSConfig: tlsConfig, + ReadTimeout: 10 * time.Second, + WriteTimeout: 10 * time.Second, } var wg sync.WaitGroup @@ -525,9 +545,16 @@ func APIdispatcher(conf *Config, done <-chan struct{}) { if address != "" { wg.Add(1) go func(wg *sync.WaitGroup) { + apiServer := &http.Server{ + Addr: address, + Handler: router, + ReadTimeout: 10 * time.Second, + WriteTimeout: 10 * time.Second, + } + log.Println("*** API: Starting API dispatcher #1. Listening on", address) wg.Done() - TEMExiter(http.ListenAndServe(address, router)) + TEMExiter(apiServer.ListenAndServe()) }(&wg) } @@ -547,9 +574,15 @@ func APIdispatcher(conf *Config, done <-chan struct{}) { if bootstrapaddress != "" { wg.Add(1) go func(wg *sync.WaitGroup) { + apiServer := &http.Server{ + Addr: bootstrapaddress, + Handler: bootstraprouter, + ReadTimeout: 10 * time.Second, + WriteTimeout: 10 * time.Second, + } log.Println("*** API: Starting Bootstrap API dispatcher #1. Listening on", bootstrapaddress) wg.Done() - TEMExiter(http.ListenAndServe(bootstrapaddress, bootstraprouter)) + TEMExiter(apiServer.ListenAndServe()) }(&wg) } else { log.Println("*** API: No bootstrap address specified") @@ -587,7 +620,7 @@ func BumpSerial(conf *Config, zone string) (string, error) { if resp.Error { log.Printf("BumpSerial: Error from RefreshEngine: %s", resp.ErrorMsg) return fmt.Sprintf("Zone %s: error bumping SOA serial: %s", zone, resp.ErrorMsg), - fmt.Errorf("Zone %s: error bumping SOA serial and epoch: %v", zone, resp.ErrorMsg) + fmt.Errorf("zone %s: error bumping SOA serial and epoch: %v", zone, resp.ErrorMsg) } if resp.Msg == "" { diff --git a/bootstrap.go b/bootstrap.go index 4ba7262..42a53af 100644 --- a/bootstrap.go +++ b/bootstrap.go @@ -39,7 +39,7 @@ func (td *TemData) BootstrapMqttSource(s *tapir.WBGlist, src SourceConf) (*tapir tlsConfig.InsecureSkipVerify = true err = api.SetupTLS(tlsConfig) if err != nil { - return nil, fmt.Errorf("Error setting up TLS for the API client: %v", err) + return nil, fmt.Errorf("error setting up TLS for the API client: %v", err) } // Iterate over the bootstrap servers @@ -53,7 +53,7 @@ func (td *TemData) BootstrapMqttSource(s *tapir.WBGlist, src SourceConf) (*tapir continue } - uptime := time.Now().Sub(pr.BootTime).Round(time.Second) + uptime := time.Since(pr.BootTime).Round(time.Second) td.Logger.Printf("MQTT bootstrap server %s uptime: %v. It has processed %d MQTT messages", server, uptime, 17) status, buf, err := api.RequestNG(http.MethodPost, "/bootstrap", tapir.BootstrapPost{ @@ -108,5 +108,5 @@ func (td *TemData) BootstrapMqttSource(s *tapir.WBGlist, src SourceConf) (*tapir } // If no bootstrap server succeeded - return nil, fmt.Errorf("All bootstrap servers failed") + return nil, fmt.Errorf("all bootstrap servers failed") } diff --git a/dnshandler.go b/dnshandler.go index 10bd2fe..883c881 100644 --- a/dnshandler.go +++ b/dnshandler.go @@ -65,7 +65,10 @@ func createHandler(conf *Config) func(w dns.ResponseWriter, r *dns.Msg) { // send NOERROR response m := new(dns.Msg) m.SetReply(r) - w.WriteMsg(m) + err := w.WriteMsg(m) + if err != nil { + lg.Printf("Error from WriteMsg(): %v", err) + } if _, ok := td.RpzSources[qname]; ok { lg.Printf("Received Notify for known zone %s. Fetching from upstream", qname) @@ -82,14 +85,20 @@ func createHandler(conf *Config) func(w dns.ResponseWriter, r *dns.Msg) { qtype := r.Question[0].Qtype lg.Printf("Zone %s %s request from %s", qname, dns.TypeToString[qtype], w.RemoteAddr()) if qname == td.Rpz.ZoneName { - td.RpzResponder(w, r, qtype, lg) + err := td.RpzResponder(w, r, qtype, lg) + if err != nil { + lg.Printf("Error from RpzResponder(): %v", err) + } } else if zd, ok := td.RpzSources[qname]; ok { // The qname is equal to the name of a zone we have - ApexResponder(w, r, zd, qname, qtype, lg) + err := ApexResponder(w, r, zd, qname, qtype, lg) + if err != nil { + lg.Printf("Error from ApexResponder(): %v", err) + } } else { lg.Printf("DnsHandler: Qname is '%s', which is not a known zone.", qname) known_zones := []string{td.Rpz.ZoneName} - for z, _ := range td.RpzSources { + for z := range td.RpzSources { known_zones = append(known_zones, z) } lg.Printf("DnsHandler: Known zones are: %v", known_zones) @@ -98,7 +107,10 @@ func createHandler(conf *Config) func(w dns.ResponseWriter, r *dns.Msg) { if strings.HasSuffix(qname, td.Rpz.ZoneName) { lg.Printf("Query for qname %s belongs in our own RPZ \"%s\"", qname, td.Rpz.ZoneName) - td.QueryResponder(w, r, qname, qtype, lg) + err := td.QueryResponder(w, r, qname, qtype, lg) + if err != nil { + lg.Printf("Error from QueryResponder(): %v", err) + } return } zd := td.FindZone(qname) @@ -106,18 +118,27 @@ func createHandler(conf *Config) func(w dns.ResponseWriter, r *dns.Msg) { lg.Printf("After FindZone: zd==nil") m := new(dns.Msg) m.SetRcode(r, dns.RcodeRefused) - w.WriteMsg(m) + err := w.WriteMsg(m) + if err != nil { + lg.Printf("Error from WriteMsg(): %v", err) + } return // didn't find any zone for that qname or found zone, but it is an XFR zone only } lg.Printf("After FindZone: zd: zd.ZoneType: %v", zd.ZoneType) if zd.ZoneType == tapir.XfrZone { m := new(dns.Msg) m.SetRcode(r, dns.RcodeRefused) - w.WriteMsg(m) + err := w.WriteMsg(m) + if err != nil { + lg.Printf("Error from WriteMsg(): %v", err) + } return // didn't find any zone for that qname or found zone, but it is an XFR zone only } lg.Printf("Found matching full zone for qname %s: %s", qname, zd.ZoneName) - QueryResponder(w, r, zd, qname, qtype, lg) + err := QueryResponder(w, r, zd, qname, qtype, lg) + if err != nil { + lg.Printf("Error from QueryResponder(): %v", err) + } return } return @@ -189,7 +210,10 @@ func (td *TemData) RpzResponder(w dns.ResponseWriter, r *dns.Msg, qtype uint16, m.MsgHdr.Rcode = dns.RcodeRefused m.Ns = append(m.Ns, zd.NSrrs...) } - w.WriteMsg(m) + err = w.WriteMsg(m) + if err != nil { + lg.Printf("Error from WriteMsg(): %v", err) + } return nil } @@ -230,7 +254,10 @@ func ApexResponder(w dns.ResponseWriter, r *dns.Msg, zd *tapir.ZoneData, m.MsgHdr.Rcode = dns.RcodeRefused m.Ns = append(m.Ns, zd.NSrrs...) } - w.WriteMsg(m) + err := w.WriteMsg(m) + if err != nil { + lg.Printf("Error from WriteMsg(): %v", err) + } return nil } @@ -260,8 +287,10 @@ func QueryResponder(w dns.ResponseWriter, r *dns.Msg, zd *tapir.ZoneData, qname // return NXDOMAIN m.MsgHdr.Rcode = dns.RcodeNameError m.Ns = append(m.Ns, apex.RRtypes[dns.TypeSOA].RRs...) - w.WriteMsg(m) - return + err := w.WriteMsg(m) + if err != nil { + lg.Printf("Error from WriteMsg(): %v", err) + } } // log.Printf("Zone %s Data: %v", zd.ZoneName, zd.Data) @@ -281,7 +310,10 @@ func QueryResponder(w dns.ResponseWriter, r *dns.Msg, zd *tapir.ZoneData, qname // return NXDOMAIN m.MsgHdr.Rcode = dns.RcodeNameError m.Ns = append(m.Ns, apex.RRtypes[dns.TypeSOA].RRs...) - w.WriteMsg(m) + err := w.WriteMsg(m) + if err != nil { + lg.Printf("Error from WriteMsg(): %v", err) + } return nil } @@ -296,7 +328,10 @@ func QueryResponder(w dns.ResponseWriter, r *dns.Msg, zd *tapir.ZoneData, qname if len(owner.RRtypes) == 0 { m.MsgHdr.Rcode = dns.RcodeNameError m.Ns = append(m.Ns, apex.RRtypes[dns.TypeSOA].RRs...) - w.WriteMsg(m) + err := w.WriteMsg(m) + if err != nil { + lg.Printf("Error from WriteMsg(): %v", err) + } return nil } @@ -318,7 +353,10 @@ func QueryResponder(w dns.ResponseWriter, r *dns.Msg, zd *tapir.ZoneData, qname glue = zd.FindGlue(apex.RRtypes[dns.TypeNS]) m.Extra = append(m.Extra, glue.RRs...) } - w.WriteMsg(m) + err := w.WriteMsg(m) + if err != nil { + lg.Printf("Error from WriteMsg(): %v", err) + } return nil } } @@ -346,7 +384,10 @@ func QueryResponder(w dns.ResponseWriter, r *dns.Msg, zd *tapir.ZoneData, qname } else { m.Ns = append(m.Ns, apex.RRtypes[dns.TypeSOA].RRs...) } - w.WriteMsg(m) + err := w.WriteMsg(m) + if err != nil { + lg.Printf("Error from WriteMsg(): %v", err) + } return nil default: @@ -355,7 +396,10 @@ func QueryResponder(w dns.ResponseWriter, r *dns.Msg, zd *tapir.ZoneData, qname m.Ns = append(m.Ns, apex.RRtypes[dns.TypeNS].RRs...) glue = zd.FindGlue(apex.RRtypes[dns.TypeNS]) m.Extra = append(m.Extra, glue.RRs...) - w.WriteMsg(m) + err := w.WriteMsg(m) + if err != nil { + lg.Printf("Error from WriteMsg(): %v", err) + } } return nil } @@ -371,8 +415,10 @@ func (td *TemData) QueryResponder(w dns.ResponseWriter, r *dns.Msg, qname string m.MsgHdr.Rcode = dns.RcodeNameError // m.Ns = append(m.Ns, apex.RRtypes[dns.TypeSOA].RRs...) m.Ns = append(m.Ns, dns.RR(&td.Rpz.Axfr.SOA)) - w.WriteMsg(m) - return + err := w.WriteMsg(m) + if err != nil { + lg.Printf("Error from WriteMsg(): %v", err) + } } // log.Printf("Zone %s Data: %v", zd.ZoneName, zd.Data) @@ -390,7 +436,10 @@ func (td *TemData) QueryResponder(w dns.ResponseWriter, r *dns.Msg, qname string default: m.Ns = append(m.Ns, dns.RR(&td.Rpz.Axfr.SOA)) } - w.WriteMsg(m) + err := w.WriteMsg(m) + if err != nil { + lg.Printf("Error from WriteMsg(): %v", err) + } return nil } returnNXDOMAIN() diff --git a/logging.go b/logging.go index 81465cb..bd1da57 100644 --- a/logging.go +++ b/logging.go @@ -8,6 +8,7 @@ import ( "fmt" "log" "os" + "path/filepath" "github.com/spf13/viper" "gopkg.in/natefinch/lumberjack.v2" @@ -30,7 +31,8 @@ func SetupLogging(conf *Config) { logfile = viper.GetString("policy.logfile") if logfile != "" { - f, err := os.OpenFile(logfile, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0666) + logfile = filepath.Clean(logfile) + f, err := os.OpenFile(logfile, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0644) if err != nil { TEMExiter("error opening TEM policy logfile '%s': %v", logfile, err) } @@ -50,7 +52,8 @@ func SetupLogging(conf *Config) { logfile = viper.GetString("dnsengine.logfile") if logfile != "" { - f, err := os.OpenFile(logfile, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0666) + logfile = filepath.Clean(logfile) + f, err := os.OpenFile(logfile, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0644) if err != nil { TEMExiter("error opening TEM dnsengine logfile '%s': %v", logfile, err) } @@ -70,7 +73,8 @@ func SetupLogging(conf *Config) { logfile = viper.GetString("mqtt.logfile") if logfile != "" { - f, err := os.OpenFile(logfile, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0666) + logfile = filepath.Clean(logfile) + f, err := os.OpenFile(logfile, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0644) if err != nil { TEMExiter("error opening TEM MQTT logfile '%s': %v", logfile, err) } diff --git a/main.go b/main.go index 7dc7f79..0778c9c 100644 --- a/main.go +++ b/main.go @@ -5,7 +5,6 @@ package main import ( - "flag" "fmt" "log" @@ -22,10 +21,6 @@ import ( "github.com/dnstapir/tapir" ) -var ( - soreuseport = flag.Int("soreuseport", 0, "use SO_REUSE_PORT") -) - var TEMExiter = func(args ...interface{}) { log.Printf("TEMExiter: [placeholderfunction w/o real cleanup]") log.Printf("TEMExiter: Exit message: %s", fmt.Sprintf(args[0].(string), args[1:]...)) @@ -61,7 +56,10 @@ func mainloop(conf *Config, configfile *string, td *TemData) { var msg string log.Printf("TEMExiter: will try to clean up.") - td.SaveRpzSerial() + err := td.SaveRpzSerial() + if err != nil { + log.Printf("Error saving RPZ serial: %v", err) + } switch args[0].(type) { case string: @@ -75,10 +73,8 @@ func mainloop(conf *Config, configfile *string, td *TemData) { } fmt.Println(msg) - log.Printf(msg) + log.Println(msg) - // var done struct{} - // apistopper <- done os.Exit(1) } @@ -91,7 +87,10 @@ func mainloop(conf *Config, configfile *string, td *TemData) { select { case <-exit: log.Println("mainloop: Exit signal received. Cleaning up.") - td.SaveRpzSerial() + err := td.SaveRpzSerial() + if err != nil { + log.Printf("Error saving RPZ serial: %v", err) + } // do whatever we need to do to wrap up nicely wg.Done() case <-hupper: @@ -107,7 +106,10 @@ func mainloop(conf *Config, configfile *string, td *TemData) { conf.TemData.RpzRefreshCh <- RpzRefresh{Name: ""} case <-conf.Internal.APIStopCh: log.Printf("mainloop: API instruction to stop\n") - td.SaveRpzSerial() + err := td.SaveRpzSerial() + if err != nil { + log.Printf("Error saving RPZ serial: %v", err) + } wg.Done() } } @@ -163,14 +165,17 @@ func main() { SetupLogging(&conf) fmt.Printf("Policy Logging to logger: %v\n", conf.Loggers.Policy) - ValidateConfig(nil, cfgFileUsed) // will terminate on error + err := ValidateConfig(nil, cfgFileUsed) // will terminate on error + if err != nil { + TEMExiter("Error validating config: %v", err) + } - err := viper.Unmarshal(&conf) + err = viper.Unmarshal(&conf) if err != nil { TEMExiter("Error unmarshalling config into struct: %v", err) } - fmt.Printf("TEM (TAPIR Edge Manager) version %s starting.\n", appVersion) + fmt.Printf("%s (TAPIR Edge Manager) version %s (%s) starting.\n", appName, appVersion, appDate) var stopch = make(chan struct{}, 10) diff --git a/policy.go b/policy.go index d1bb12e..96a48ad 100644 --- a/policy.go +++ b/policy.go @@ -8,6 +8,7 @@ import ( "log" "net" "os" + "path/filepath" "strconv" "strings" @@ -76,6 +77,7 @@ func (td *TemData) ParseOutputs() error { serialFile := viper.GetString("output.rpz.serialcache") if serialFile != "" { + serialFile = filepath.Clean(serialFile) serialData, err := os.ReadFile(serialFile) if err != nil { td.Logger.Printf("Error reading serial from file %s: %v", serialFile, err) diff --git a/reaper.go b/reaper.go index f2c8e12..0537c62 100644 --- a/reaper.go +++ b/reaper.go @@ -35,7 +35,7 @@ func (td *TemData) Reaper(full bool) error { if _, exist := wbgl.ReaperData[timekey]; !exist { wbgl.ReaperData[timekey] = map[string]bool{} } - for name, _ := range d { + for name := range d { wbgl.ReaperData[timekey][name] = true } // wbgl.ReaperData[timekey] = d @@ -48,7 +48,7 @@ func (td *TemData) Reaper(full bool) error { td.Logger.Printf("Reaper: list [%s][%s] has %d timekeys stored", listtype, listname, len(wbgl.ReaperData[timekey])) td.mu.Lock() - for name, _ := range wbgl.ReaperData[timekey] { + for name := range wbgl.ReaperData[timekey] { td.Logger.Printf("Reaper: removing %s from %s %s", name, listtype, listname) delete(td.Lists[listtype][listname].Names, name) delete(wbgl.ReaperData[timekey], name) diff --git a/refreshengine.go b/refreshengine.go index b30d33d..7e33745 100644 --- a/refreshengine.go +++ b/refreshengine.go @@ -57,17 +57,15 @@ func (td *TemData) RefreshEngine(conf *Config, stopch chan struct{}) { reaperTicker := time.NewTicker(td.ReaperInterval) go func() { - time.Sleep(reaperStart.Sub(time.Now())) + time.Sleep(time.Until(reaperStart)) reaperTicker.Reset(td.ReaperInterval) }() if !viper.GetBool("service.refresh.active") { log.Printf("Refresh Engine is NOT active. Zones will only be updated on receipt on Notifies.") - for { - select { - case <-zonerefch: // ensure that we keep reading to keep the - continue // channel open - } + for range zonerefch { + // ensure that we keep reading to keep the channel open + continue } } else { log.Printf("RefreshEngine: Starting") @@ -94,7 +92,10 @@ func (td *TemData) RefreshEngine(conf *Config, stopch chan struct{}) { case "intel-update", "observation": log.Printf("RefreshEngine: Tapir Observation update: (src: %s) %d additions and %d removals\n", tpkg.Data.SrcName, len(tpkg.Data.Added), len(tpkg.Data.Removed)) - td.ProcessTapirUpdate(tpkg) + _, err := td.ProcessTapirUpdate(tpkg) + if err != nil { + log.Printf("RefreshEngine: Error from ProcessTapirUpdate(): %v", err) + } log.Printf("RefreshEngine: Tapir Observation update evaluated.") case "global-config": @@ -150,7 +151,10 @@ func (td *TemData) RefreshEngine(conf *Config, stopch chan struct{}) { log.Printf("RefreshEngine: %s updated from upstream. Resetting serial to unixtime: %d", zone, td.RpzSources[zone].SOA.Serial) } - td.NotifyDownstreams() + err := td.NotifyDownstreams() + if err != nil { + log.Printf("RefreshEngine: Error notifying downstreams: %v", err) + } } // showing some apex details: log.Printf("Showing some details for zone %s: ", zone) @@ -254,7 +258,10 @@ func (td *TemData) RefreshEngine(conf *Config, stopch chan struct{}) { } } if updated { - td.NotifyDownstreams() + err := td.NotifyDownstreams() + if err != nil { + log.Printf("RefreshEngine: Error notifying downstreams: %v", err) + } } } } @@ -282,7 +289,11 @@ func (td *TemData) RefreshEngine(conf *Config, stopch chan struct{}) { zd.SOA.Serial = uint32(time.Now().Unix()) resp.NewSerial = zd.SOA.Serial rc = refreshCounters[zone] - td.NotifyDownstreams() + err := td.NotifyDownstreams() + if err != nil { + resp.Error = true + resp.ErrorMsg = fmt.Sprintf("Error notifying downstreams: %v", err) + } resp.Msg = fmt.Sprintf("Zone %s: bumped serial from %d to %d. Notified downstreams: %v", zone, resp.OldSerial, resp.NewSerial, rc.Downstreams) log.Printf(resp.Msg) @@ -316,9 +327,14 @@ func (td *TemData) RefreshEngine(conf *Config, stopch chan struct{}) { // if the name isn't either whitelisted or blacklisted if cmd.ListType == "greylist" { - td.GreylistAdd(cmd.Domain, cmd.Policy, cmd.RpzSource) - resp.Msg = fmt.Sprintf("Domain name \"%s\" (policy %s) added to greylisting DB.", - cmd.Domain, cmd.Policy) + _, err := td.GreylistAdd(cmd.Domain, cmd.Policy, cmd.RpzSource) + if err != nil { + resp.Error = true + resp.ErrorMsg = fmt.Sprintf("Error adding domain name \"%s\" to greylisting DB: %v", cmd.Domain, err) + } else { + resp.Msg = fmt.Sprintf("Domain name \"%s\" (policy %s) added to greylisting DB.", + cmd.Domain, cmd.Policy) + } cmd.Result <- resp continue } diff --git a/rpz.go b/rpz.go index 521e062..505a692 100644 --- a/rpz.go +++ b/rpz.go @@ -37,17 +37,17 @@ func (td *TemData) GenerateRpzAxfr() error { case "dawg": td.Logger.Printf("Cannot list DAWG lists. Ignoring blacklist %s.", bname) case "map": - for k, _ := range blist.Names { - if tapir.GlobalCF.Debug { - // td.Logger.Printf("Adding name %s from blacklist %s to tentative output.", - // k, bname) - } - if td.Whitelisted(k) { - // td.Logger.Printf("Blacklisted name %s is also whitelisted. Dropped from output.", k) - } else { - // td.Logger.Printf("Blacklisted name %s is not whitelisted. Added to output.", k) - black[k] = true - } + for k := range blist.Names { + // if tapir.GlobalCF.Debug { + // td.Logger.Printf("Adding name %s from blacklist %s to tentative output.", + // k, bname) + // } + // if td.Whitelisted(k) { + // td.Logger.Printf("Blacklisted name %s is also whitelisted. Dropped from output.", k) + // } else { + // td.Logger.Printf("Blacklisted name %s is not whitelisted. Added to output.", k) + black[k] = true + // } } } } @@ -92,9 +92,9 @@ func (td *TemData) GenerateRpzAxfr() error { td.GreylistedNames = grey td.Logger.Printf("GenRpzAxfr: There are a total of %d greylisted names in the sources", len(grey)) - newaxfrdata := []*tapir.RpzName{} + // newaxfrdata := []*tapir.RpzName{} // td.Rpz.RpzMap = map[string]*tapir.RpzName{} - for name, _ := range td.BlacklistedNames { + for name := range td.BlacklistedNames { cname := new(dns.CNAME) cname.Hdr = dns.RR_Header{ Name: name + td.Rpz.ZoneName, @@ -110,7 +110,7 @@ func (td *TemData) GenerateRpzAxfr() error { RR: &rr, Action: td.Policy.BlacklistAction, } - newaxfrdata = append(newaxfrdata, &rpzn) + // newaxfrdata = append(newaxfrdata, &rpzn) // td.Rpz.RpzMap[nname+td.Rpz.ZoneName] = &rpzn td.mu.Lock() td.Rpz.Axfr.Data[name+td.Rpz.ZoneName] = &rpzn @@ -137,7 +137,7 @@ func (td *TemData) GenerateRpzAxfr() error { RR: &rr, Action: td.Policy.BlacklistAction, // XXX: naa } - newaxfrdata = append(newaxfrdata, &rpzn) + // newaxfrdata = append(newaxfrdata, &rpzn) // td.Rpz.RpzMap[name+td.Rpz.ZoneName] = &rpzn td.mu.Lock() td.Rpz.Axfr.Data[name+td.Rpz.ZoneName] = &rpzn @@ -248,8 +248,6 @@ func (td *TemData) GenerateRpzIxfr(data *tapir.TapirMsg) (RpzIxfr, error) { tn.Name, tapir.ActionToString[newAction], tapir.ActionToString[cur.Action]) } - } else { - // no change, do nothing } } } else { diff --git a/sources.go b/sources.go index 2a8909b..70c410f 100644 --- a/sources.go +++ b/sources.go @@ -7,6 +7,7 @@ import ( "fmt" "log" "os" + "path/filepath" "strings" "time" @@ -111,7 +112,7 @@ func NewTemData(conf *Config, lg *log.Logger) (*TemData, error) { func (td *TemData) ParseSourcesNG() error { var srcfoo SrcFoo - configFile := tapir.TemSourcesCfgFile + configFile := filepath.Clean(tapir.TemSourcesCfgFile) data, err := os.ReadFile(configFile) if err != nil { return fmt.Errorf("error reading config file: %v", err) @@ -230,7 +231,7 @@ func (td *TemData) ParseSourcesNG() error { err = td.ParseLocalFile(name, &newsource, rptchan) case "xfr": err = td.ParseRpzFeed(name, &newsource, rptchan) - td.Logger.Printf("Thread %d: source \"%s\" (%s) now returned from ParseRpzFeed(). %d remaining", thread, name, threads) + td.Logger.Printf("Thread %d: source \"%s\" now returned from ParseRpzFeed(). %d remaining", thread, name, threads) default: td.Logger.Printf("*** ParseSourcesNG: Error: unhandled source type %s", src.Source) } @@ -339,8 +340,7 @@ func (td *TemData) ParseRpzFeed(sourceid string, s *tapir.WBGlist, rpt chan stri upstream := viper.GetString(fmt.Sprintf("sources.%s.upstream", sourceid)) if upstream == "" { - return fmt.Errorf("Unable to load RPZ source %s, upstream address not specified.", - sourceid) + return fmt.Errorf("unable to load RPZ source %s, upstream address not specified", sourceid) } s.Names = map[string]tapir.TapirName{} // must initialize diff --git a/xfr.go b/xfr.go index eb378e7..5c29020 100644 --- a/xfr.go +++ b/xfr.go @@ -112,8 +112,11 @@ func (td *TemData) RpzAxfrOut(w dns.ResponseWriter, r *dns.Msg) (uint32, int, er outbound_xfr <- &env close(outbound_xfr) - wg.Wait() // wait until everything is written out - w.Close() // close connection + wg.Wait() // wait until everything is written out + err := w.Close() // close connection + if err != nil { + td.Logger.Printf("RpzAxfrOut: Error from Close(): %v", err) + } td.Logger.Printf("ZoneTransferOut: %s: Sent %d RRs (including SOA twice).", zone, total_sent) @@ -152,9 +155,9 @@ func (td *TemData) RpzIxfrOut(w dns.ResponseWriter, r *dns.Msg) (uint32, int, er if len(r.Ns) > 0 { for _, rr := range r.Ns { - switch rr.(type) { + switch rr := rr.(type) { case *dns.SOA: - curserial = rr.(*dns.SOA).Serial + curserial = rr.Serial default: td.Logger.Printf("RpzIxfrOut: unexpected RR in IXFR request Authority section:\n%s\n", rr.String()) @@ -197,7 +200,10 @@ func (td *TemData) RpzIxfrOut(w dns.ResponseWriter, r *dns.Msg) (uint32, int, er wg.Add(1) go func() { - tr.Out(w, r, outbound_xfr) + err := tr.Out(w, r, outbound_xfr) + if err != nil { + td.Logger.Printf("Error from transfer.Out(): %v", err) + } wg.Done() }() @@ -286,8 +292,11 @@ func (td *TemData) RpzIxfrOut(w dns.ResponseWriter, r *dns.Msg) (uint32, int, er outbound_xfr <- &env close(outbound_xfr) - wg.Wait() // wait until everything is written out - w.Close() // close connection + wg.Wait() // wait until everything is written out + err = w.Close() // close connection + if err != nil { + td.Logger.Printf("RpzIxfrOut: Error from Close(): %v", err) + } td.Logger.Printf("RpzIxfrOut: %s: Sent %d RRs (including SOA twice).", zone, total_sent) err = td.PruneRpzIxfrChain()