diff --git a/cmd/stayrtr/stayrtr.go b/cmd/stayrtr/stayrtr.go index ce030f7..cc72015 100644 --- a/cmd/stayrtr/stayrtr.go +++ b/cmd/stayrtr/stayrtr.go @@ -49,17 +49,17 @@ var ( MetricsAddr = flag.String("metrics.addr", ":9847", "Metrics address") MetricsPath = flag.String("metrics.path", "/metrics", "Metrics path") - ExportPath = flag.String("export.path", "/rpki.json", "Export path") - - RTRVersion = flag.Int("protocol", 1, "RTR protocol version. Default is version 1 (RFC 8210)") - RefreshRTR = flag.Int("rtr.refresh", 3600, "Refresh interval") - RetryRTR = flag.Int("rtr.retry", 600, "Retry interval") - ExpireRTR = flag.Int("rtr.expire", 7200, "Expire interval") - SendNotifs = flag.Bool("notifications", true, "Send notifications to clients (disable with -notifications=false)") - EnforceVersion = flag.Bool("enforce.version", false, "Disable version negotiation") - DisableBGPSec = flag.Bool("disable.bgpsec", false, "Disable sending out BGPSEC Router Keys") - EnableNODELAY = flag.Bool("enable.nodelay", false, "Force enable TCP NODELAY (Likely increases CPU)") - + ExportPath = flag.String("export.path", "/rpki.json", "Export path") + EnableUpdateEndpoint = flag.Bool("update.endpoint", false, "Enable HTTP endpoint that expedites the next fetch") + + RTRVersion = flag.Int("protocol", 1, "RTR protocol version. Default is version 1 (RFC 8210)") + RefreshRTR = flag.Int("rtr.refresh", 3600, "Refresh interval") + RetryRTR = flag.Int("rtr.retry", 600, "Retry interval") + ExpireRTR = flag.Int("rtr.expire", 7200, "Expire interval") + SendNotifs = flag.Bool("notifications", true, "Send notifications to clients (disable with -notifications=false)") + EnforceVersion = flag.Bool("enforce.version", false, "Disable version negotiation") + DisableBGPSec = flag.Bool("disable.bgpsec", false, "Disable sending out BGPSEC Router Keys") + EnableNODELAY = flag.Bool("enable.nodelay", false, "Force enable TCP NODELAY (Likely increases CPU)") Bind = flag.String("bind", ":8282", "Bind address") @@ -169,7 +169,7 @@ func initMetrics() { prometheus.MustRegister(CurrentSerial) } -func metricHTTP() { +func serveHTTP() { http.Handle(*MetricsPath, promhttp.Handler()) log.Fatal(http.ListenAndServe(*MetricsAddr, nil)) } @@ -411,7 +411,7 @@ func (s *state) applyUpdateFromNewState(vrps []rtr.VRP, brks []rtr.BgpsecKey, vrpsjson []prefixfile.VRPJson, brksjson []prefixfile.BgpSecKeyJson, countv4 int, countv6 int) error { - SDs := make([]rtr.SendableData, 0, len(vrps) + len(brks)) + SDs := make([]rtr.SendableData, 0, len(vrps)+len(brks)) for _, v := range vrps { SDs = append(SDs, v.Copy()) } @@ -528,6 +528,14 @@ func (s *state) updateSlurm(file string) (bool, error) { return true, nil } +func (s *state) updateDelay(delay *time.Ticker, interval int) { + if s.lastchange.IsZero() { + delay.Reset(30 * time.Second) + } else { + delay.Reset(time.Duration(interval) * time.Second) + } +} + func (s *state) routineUpdate(file string, interval int, slurmFile string) { log.Debugf("Starting refresh routine (file: %v, interval: %vs, slurm: %v)", file, interval, slurmFile) signals := make(chan os.Signal, 1) @@ -548,11 +556,10 @@ func (s *state) routineUpdate(file string, interval int, slurmFile string) { case <-delay.C: case <-signals: log.Debug("Received HUP signal") - if s.lastchange.IsZero() { - delay.Reset(30 * time.Second) - } else { - delay.Reset(time.Duration(interval) * time.Second) - } + s.updateDelay(delay, interval) + case <-s.triggerUpdate: + log.Debug("Received triggered update") + s.updateDelay(delay, interval) } slurmNotPresentOrUpdated := false @@ -629,6 +636,34 @@ func (s *state) exporter(wr http.ResponseWriter, r *http.Request) { enc.Encode(toExport) } +func (s *state) updateNow(wr http.ResponseWriter, r *http.Request) { + wr.Header().Set("Content-Type", "application/json") + + response := make(map[string]interface{}) + if s.TriggerUpdate() { + response["status"] = "success" + response["message"] = "Update triggered successfully" + wr.WriteHeader(http.StatusOK) + } else { + response["status"] = "error" + response["message"] = "Update not triggered. Queue is full or not ready" + wr.WriteHeader(http.StatusInternalServerError) + } + + json.NewEncoder(wr).Encode(response) +} + +func (s *state) TriggerUpdate() bool { + select { + case s.triggerUpdate <- struct{}{}: + return true + default: + // Channel is full or not ready, log a warning + log.Warn("Update trigger skipped: update ongoing or not ready") + return false + } +} + type state struct { lastdata *prefixfile.RPKIList lasthashCache []byte @@ -649,6 +684,8 @@ type state struct { slurm *prefixfile.SlurmConfig checktime bool + + triggerUpdate chan struct{} } type metricsEvent struct { @@ -725,9 +762,9 @@ func run() error { RetryInterval: uint32(*RetryRTR), ExpireInterval: uint32(*ExpireRTR), - EnforceVersion: *EnforceVersion, - DisableBGPSec: *DisableBGPSec, - EnableNODELAY: *EnableNODELAY, + EnforceVersion: *EnforceVersion, + DisableBGPSec: *DisableBGPSec, + EnableNODELAY: *EnableNODELAY, } var me *metricsEvent @@ -750,6 +787,8 @@ func run() error { lockJson: &sync.RWMutex{}, fetchConfig: utils.NewFetchConfig(), + + triggerUpdate: make(chan struct{}, 1), // limit the number of queued updates. Downside: HTTP call to endpoint may fail } s.fetchConfig.UserAgent = *UserAgent s.fetchConfig.Mime = *Mime @@ -760,7 +799,10 @@ func run() error { if *ExportPath != "" { http.HandleFunc(*ExportPath, s.exporter) } - go metricHTTP() + if *EnableUpdateEndpoint { + http.HandleFunc("/api/update", s.updateNow) + } + go serveHTTP() } if *Bind == "" && *BindTLS == "" && *BindSSH == "" {