Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add default-disabled endpoint to trigger update over HTTP #135

Merged
merged 1 commit into from
Feb 22, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 64 additions & 22 deletions cmd/stayrtr/stayrtr.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,17 +49,17 @@ var (
MetricsAddr = flag.String("metrics.addr", ":9847", "Metrics address")
MetricsPath = flag.String("metrics.path", "/metrics", "Metrics path")

ExportPath = flag.String("export.path", "/rpki.json", "Export path")

RTRVersion = flag.Int("protocol", 1, "RTR protocol version. Default is version 1 (RFC 8210)")
RefreshRTR = flag.Int("rtr.refresh", 3600, "Refresh interval")
RetryRTR = flag.Int("rtr.retry", 600, "Retry interval")
ExpireRTR = flag.Int("rtr.expire", 7200, "Expire interval")
SendNotifs = flag.Bool("notifications", true, "Send notifications to clients (disable with -notifications=false)")
EnforceVersion = flag.Bool("enforce.version", false, "Disable version negotiation")
DisableBGPSec = flag.Bool("disable.bgpsec", false, "Disable sending out BGPSEC Router Keys")
EnableNODELAY = flag.Bool("enable.nodelay", false, "Force enable TCP NODELAY (Likely increases CPU)")

ExportPath = flag.String("export.path", "/rpki.json", "Export path")
EnableUpdateEndpoint = flag.Bool("update.endpoint", false, "Enable HTTP endpoint that expedites the next fetch")

RTRVersion = flag.Int("protocol", 1, "RTR protocol version. Default is version 1 (RFC 8210)")
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps we should add gofmt in ci/cd to check that merged code is formatted already

RefreshRTR = flag.Int("rtr.refresh", 3600, "Refresh interval")
RetryRTR = flag.Int("rtr.retry", 600, "Retry interval")
ExpireRTR = flag.Int("rtr.expire", 7200, "Expire interval")
SendNotifs = flag.Bool("notifications", true, "Send notifications to clients (disable with -notifications=false)")
EnforceVersion = flag.Bool("enforce.version", false, "Disable version negotiation")
DisableBGPSec = flag.Bool("disable.bgpsec", false, "Disable sending out BGPSEC Router Keys")
EnableNODELAY = flag.Bool("enable.nodelay", false, "Force enable TCP NODELAY (Likely increases CPU)")

Bind = flag.String("bind", ":8282", "Bind address")

Expand Down Expand Up @@ -169,7 +169,7 @@ func initMetrics() {
prometheus.MustRegister(CurrentSerial)
}

func metricHTTP() {
func serveHTTP() {
http.Handle(*MetricsPath, promhttp.Handler())
log.Fatal(http.ListenAndServe(*MetricsAddr, nil))
}
Expand Down Expand Up @@ -411,7 +411,7 @@ func (s *state) applyUpdateFromNewState(vrps []rtr.VRP, brks []rtr.BgpsecKey,
vrpsjson []prefixfile.VRPJson, brksjson []prefixfile.BgpSecKeyJson,
countv4 int, countv6 int) error {

SDs := make([]rtr.SendableData, 0, len(vrps) + len(brks))
SDs := make([]rtr.SendableData, 0, len(vrps)+len(brks))
for _, v := range vrps {
SDs = append(SDs, v.Copy())
}
Expand Down Expand Up @@ -528,6 +528,14 @@ func (s *state) updateSlurm(file string) (bool, error) {
return true, nil
}

func (s *state) updateDelay(delay *time.Ticker, interval int) {
if s.lastchange.IsZero() {
delay.Reset(30 * time.Second)
} else {
delay.Reset(time.Duration(interval) * time.Second)
}
}

func (s *state) routineUpdate(file string, interval int, slurmFile string) {
log.Debugf("Starting refresh routine (file: %v, interval: %vs, slurm: %v)", file, interval, slurmFile)
signals := make(chan os.Signal, 1)
Expand All @@ -548,11 +556,10 @@ func (s *state) routineUpdate(file string, interval int, slurmFile string) {
case <-delay.C:
case <-signals:
log.Debug("Received HUP signal")
if s.lastchange.IsZero() {
delay.Reset(30 * time.Second)
} else {
delay.Reset(time.Duration(interval) * time.Second)
}
s.updateDelay(delay, interval)
case <-s.triggerUpdate:
log.Debug("Received triggered update")
s.updateDelay(delay, interval)
}
slurmNotPresentOrUpdated := false

Expand Down Expand Up @@ -629,6 +636,34 @@ func (s *state) exporter(wr http.ResponseWriter, r *http.Request) {
enc.Encode(toExport)
}

func (s *state) updateNow(wr http.ResponseWriter, r *http.Request) {
wr.Header().Set("Content-Type", "application/json")

response := make(map[string]interface{})
if s.TriggerUpdate() {
response["status"] = "success"
response["message"] = "Update triggered successfully"
wr.WriteHeader(http.StatusOK)
} else {
response["status"] = "error"
response["message"] = "Update not triggered. Queue is full or not ready"
wr.WriteHeader(http.StatusInternalServerError)
}

json.NewEncoder(wr).Encode(response)
}

func (s *state) TriggerUpdate() bool {
select {
case s.triggerUpdate <- struct{}{}:
return true
default:
// Channel is full or not ready, log a warning
log.Warn("Update trigger skipped: update ongoing or not ready")
return false
}
}

type state struct {
lastdata *prefixfile.RPKIList
lasthashCache []byte
Expand All @@ -649,6 +684,8 @@ type state struct {
slurm *prefixfile.SlurmConfig

checktime bool

triggerUpdate chan struct{}
}

type metricsEvent struct {
Expand Down Expand Up @@ -725,9 +762,9 @@ func run() error {
RetryInterval: uint32(*RetryRTR),
ExpireInterval: uint32(*ExpireRTR),

EnforceVersion: *EnforceVersion,
DisableBGPSec: *DisableBGPSec,
EnableNODELAY: *EnableNODELAY,
EnforceVersion: *EnforceVersion,
DisableBGPSec: *DisableBGPSec,
EnableNODELAY: *EnableNODELAY,
}

var me *metricsEvent
Expand All @@ -750,6 +787,8 @@ func run() error {
lockJson: &sync.RWMutex{},

fetchConfig: utils.NewFetchConfig(),

triggerUpdate: make(chan struct{}, 1), // limit the number of queued updates. Downside: HTTP call to endpoint may fail
}
s.fetchConfig.UserAgent = *UserAgent
s.fetchConfig.Mime = *Mime
Expand All @@ -760,7 +799,10 @@ func run() error {
if *ExportPath != "" {
http.HandleFunc(*ExportPath, s.exporter)
}
go metricHTTP()
if *EnableUpdateEndpoint {
http.HandleFunc("/api/update", s.updateNow)
}
go serveHTTP()
}

if *Bind == "" && *BindTLS == "" && *BindSSH == "" {
Expand Down
Loading