diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 21a2a29..964a59d 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -1,7 +1,7 @@ on: [push] jobs: - audit: + checks: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4.2.2 @@ -16,7 +16,7 @@ jobs: version: latest - name: Format check - run: go fmt ./... + run: make fmt-check - name: Static analysis run: go vet ./... diff --git a/Makefile b/Makefile index 433f318..75fccc3 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: audit build lint lintfix run +.PHONY: audit build lint lintfix fmt fmt-check run test test-verbose test-coverage check: lint go fmt ./... @@ -13,6 +13,15 @@ lint: lintfix: golangci-lint run --fix +fmt: + go fmt ./... + +fmt-check: + @if [ -n "$$(go fmt ./...)" ]; then \ + echo "Found unformatted Go files. Please run 'make fmt'"; \ + exit 1; \ + fi + run: build ./bin/main javascript diff --git a/cmd/launcher/main.go b/cmd/launcher/main.go index a1acd17..ef2df7d 100644 --- a/cmd/launcher/main.go +++ b/cmd/launcher/main.go @@ -40,19 +40,18 @@ func main() { errorreporting.Init(cfg.Sentry) defer errorreporting.Close() - srv := http.NewHealthCheckServer() + srv := http.NewHealthCheckServer(cfg.HealthCheckServerPort) go func() { - logs.Infof("Starting health check server at port %d", http.GetPort()) - if err := srv.ListenAndServe(); err != nil { errMsg := "Health check server failed to start" if opErr, ok := err.(*net.OpError); ok && opErr.Op == "listen" { - errMsg = fmt.Sprintf("%s: Port %d is already in use", errMsg, http.GetPort()) + errMsg = fmt.Sprintf("%s: Port %s is already in use", errMsg, srv.Addr) } else { errMsg = fmt.Sprintf("%s: %s", errMsg, err) } logs.Error(errMsg) } + logs.Infof("Started launcher's health check server at port %d", srv.Addr) }() cmd := &commands.LaunchCommand{} diff --git a/internal/config/config.go b/internal/config/config.go index 9dbe70e..241f4e6 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -17,6 +17,11 @@ var configPath = "/etc/n8n-task-runners.json" var cfg Config +const ( + // EnvVarHealthCheckPort is the env var for the port for the launcher's health check server. + EnvVarHealthCheckPort = "N8N_LAUNCHER_HEALTCHECK_PORT" +) + // Config holds the full configuration for the launcher. type Config struct { // LogLevel is the log level for the launcher. Default: `info`. @@ -33,6 +38,9 @@ type Config struct { // TaskBrokerURI is the URI of the task broker server. TaskBrokerURI string `env:"N8N_TASK_BROKER_URI, default=http://127.0.0.1:5679"` + // HealthCheckServerPort is the port for the launcher's health check server. + HealthCheckServerPort string `env:"N8N_LAUNCHER_HEALTCHECK_PORT, default=5680"` + // Runner is the runner config for the task runner, obtained from: // `/etc/n8n-task-runners.json`. Runner *RunnerConfig @@ -92,6 +100,10 @@ func LoadConfig(runnerType string, lookuper envconfig.Lookuper) (*Config, error) cfgErrs = append(cfgErrs, errs.ErrNegativeAutoShutdownTimeout) } + if port, err := strconv.Atoi(cfg.HealthCheckServerPort); err != nil || port <= 0 || port >= 65536 { + cfgErrs = append(cfgErrs, fmt.Errorf("%s must be a valid port number", EnvVarHealthCheckPort)) + } + // runner runnerCfg, err := readFileConfig(runnerType) diff --git a/internal/http/check_until_broker_ready_test.go b/internal/http/check_until_broker_ready_test.go index edb30b3..1772439 100644 --- a/internal/http/check_until_broker_ready_test.go +++ b/internal/http/check_until_broker_ready_test.go @@ -29,18 +29,18 @@ func TestCheckUntilBrokerReadyHappyPath(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { requestCount := 0 - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { requestCount++ tt.serverFn(w, r, requestCount) })) - defer server.Close() + defer srv.Close() ctx, cancel := context.WithTimeout(context.Background(), tt.timeout) defer cancel() done := make(chan error) go func() { - done <- CheckUntilBrokerReady(server.URL) + done <- CheckUntilBrokerReady(srv.URL) }() select { @@ -83,11 +83,11 @@ func TestCheckUntilBrokerReadyErrors(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(tt.handler)) + srv := httptest.NewServer(http.HandlerFunc(tt.handler)) if tt.name == "error - closed server" { - server.Close() + srv.Close() } else { - defer server.Close() + defer srv.Close() } // CheckUntilBrokerReady retries forever, so set up @@ -99,7 +99,7 @@ func TestCheckUntilBrokerReadyErrors(t *testing.T) { brokerUnexpectedlyReady := make(chan error) go func() { - brokerUnexpectedlyReady <- CheckUntilBrokerReady(server.URL) + brokerUnexpectedlyReady <- CheckUntilBrokerReady(srv.URL) }() select { @@ -137,7 +137,7 @@ func TestSendReadinessRequest(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { t.Errorf("expected GET request, got %s", r.Method) } @@ -146,9 +146,9 @@ func TestSendReadinessRequest(t *testing.T) { } w.WriteHeader(tt.serverResponse) })) - defer server.Close() + defer srv.Close() - resp, err := sendHealthRequest(server.URL) + resp, err := sendHealthRequest(srv.URL) if err == nil { defer resp.Body.Close() diff --git a/internal/http/fetch_grant_token.go b/internal/http/fetch_grant_token.go index fa68d56..90fe5e3 100644 --- a/internal/http/fetch_grant_token.go +++ b/internal/http/fetch_grant_token.go @@ -20,12 +20,12 @@ func sendGrantTokenRequest(taskBrokerServerURI, authToken string) (string, error payload := map[string]string{"token": authToken} payloadBytes, err := json.Marshal(payload) if err != nil { - return "", err + return "", fmt.Errorf("failed to marshal grant token request: %w", err) } req, err := http.NewRequest("POST", url, bytes.NewBuffer(payloadBytes)) if err != nil { - return "", err + return "", fmt.Errorf("failed to create grant token request: %w", err) } req.Header.Set("Content-Type", "application/json") @@ -48,9 +48,9 @@ func sendGrantTokenRequest(taskBrokerServerURI, authToken string) (string, error return tokenResp.Data.Token, nil } -// FetchGrantToken exchanges the launcher's auth token for a single-use -// grant token from the task broker. In case the task broker is -// temporarily unavailable, this exchange is retried a limited number of times. +// FetchGrantToken exchanges the launcher's auth token for a single-use grant +// token from the task broker. In case the task broker is temporarily +// unavailable, this exchange is retried a limited number of times. func FetchGrantToken(taskBrokerServerURI, authToken string) (string, error) { grantTokenFetch := func() (string, error) { token, err := sendGrantTokenRequest(taskBrokerServerURI, authToken) diff --git a/internal/http/fetch_grant_token_test.go b/internal/http/fetch_grant_token_test.go new file mode 100644 index 0000000..5430fb4 --- /dev/null +++ b/internal/http/fetch_grant_token_test.go @@ -0,0 +1,174 @@ +package http + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "task-runner-launcher/internal/retry" + "testing" + "time" +) + +func init() { + retry.DefaultMaxRetryTime = 50 * time.Millisecond + retry.DefaultMaxRetries = 3 + retry.DefaultWaitTimeBetweenRetries = 10 * time.Millisecond +} + +func TestFetchGrantToken(t *testing.T) { + tests := []struct { + name string + serverURL string + authToken string + serverFn func(w http.ResponseWriter, r *http.Request) + wantErr bool + errorContains string + }{ + { + name: "successful request", + authToken: "test-token", + serverFn: func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + if err := json.NewEncoder(w).Encode(map[string]interface{}{ + "data": map[string]string{ + "token": "test-grant-token", + }, + }); err != nil { + t.Errorf("Failed to encode response: %v", err) + } + }, + }, + { + name: "invalid response json", + authToken: "test-token", + serverFn: func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte("invalid json")); err != nil { + t.Errorf("Failed to write response: %v", err) + } + }, + wantErr: true, + errorContains: "failed to decode grant token response", + }, + { + name: "server error", + authToken: "test-token", + serverFn: func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + }, + wantErr: true, + errorContains: "status code 500", + }, + { + name: "verify request body", + authToken: "test-auth-token", + serverFn: func(w http.ResponseWriter, r *http.Request) { + var body struct { + Token string `json:"token"` + } + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + t.Errorf("Failed to decode request body: %v", err) + } + if body.Token != "test-auth-token" { + t.Errorf("Expected auth token 'test-auth-token', got %q", body.Token) + } + if r.Header.Get("Content-Type") != "application/json" { + t.Errorf("Expected Content-Type 'application/json', got %q", r.Header.Get("Content-Type")) + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + if err := json.NewEncoder(w).Encode(map[string]interface{}{ + "data": map[string]string{ + "token": "test-grant-token", + }, + }); err != nil { + t.Errorf("Failed to encode response: %v", err) + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(tt.serverFn)) + defer srv.Close() + + token, err := FetchGrantToken(srv.URL, tt.authToken) + hasErr := err != nil + + if hasErr != tt.wantErr { + t.Errorf("FetchGrantToken() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if hasErr && tt.wantErr && !strings.Contains(err.Error(), tt.errorContains) { + t.Errorf("Expected error containing %q, got %v", tt.errorContains, err) + } + + if !tt.wantErr && token == "" { + t.Error("Expected non-empty token for successful request") + } + }) + } +} + +func TestFetchGrantTokenInvalidURL(t *testing.T) { + _, err := FetchGrantToken("not-a-valid-url", "test-token") + if err == nil { + t.Error("Expected error for invalid URL, got nil") + } +} + +func TestFetchGrantTokenRetry(t *testing.T) { + tryCount := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + tryCount++ + if tryCount < 2 { + w.WriteHeader(http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + if err := json.NewEncoder(w).Encode(map[string]interface{}{ + "data": map[string]string{ + "token": "test-grant-token", + }, + }); err != nil { + t.Errorf("Failed to encode response: %v", err) + } + })) + defer srv.Close() + + token, err := FetchGrantToken(srv.URL, "test-token") + if err != nil { + t.Errorf("FetchGrantToken() unexpected error = %v", err) + } + if token == "" { + t.Error("Expected non-empty token after retry") + } + if tryCount != 2 { + t.Errorf("Expected 2 attempts, got %d", tryCount) + } +} + +func TestFetchGrantTokenConnectionFailure(t *testing.T) { + invalidServerURL := "http://localhost:1" + + token, err := FetchGrantToken(invalidServerURL, "test-token") + + if err == nil { + t.Error("Expected error for connection failure, got nil") + } + + if !strings.Contains(err.Error(), "connection refused") { + t.Errorf("Expected error containing 'connection refused', got %v", err) + } + + if token != "" { + t.Errorf("Expected empty token for failed connection, got %q", token) + } +} diff --git a/internal/http/healthcheck_server.go b/internal/http/healthcheck_server.go index 979b86a..6d9c52b 100644 --- a/internal/http/healthcheck_server.go +++ b/internal/http/healthcheck_server.go @@ -4,26 +4,22 @@ import ( "encoding/json" "fmt" "net/http" - "os" - "strconv" "task-runner-launcher/internal/logs" "time" ) const ( - defaultPort = 5680 - portEnvVar = "N8N_LAUNCHER_HEALTCHECK_PORT" healthCheckPath = "/healthz" readTimeout = 1 * time.Second writeTimeout = 1 * time.Second ) -func NewHealthCheckServer() *http.Server { +func NewHealthCheckServer(port string) *http.Server { mux := http.NewServeMux() mux.HandleFunc(healthCheckPath, handleHealthCheck) return &http.Server{ - Addr: fmt.Sprintf(":%d", GetPort()), + Addr: fmt.Sprintf(":%s", port), Handler: mux, ReadTimeout: readTimeout, WriteTimeout: writeTimeout, @@ -48,14 +44,3 @@ func handleHealthCheck(w http.ResponseWriter, r *http.Request) { return } } - -func GetPort() int { - if customPortStr := os.Getenv(portEnvVar); customPortStr != "" { - if customPort, err := strconv.Atoi(customPortStr); err == nil && customPort > 0 && customPort < 65536 { - return customPort - } - logs.Warnf("%s sets an invalid port, falling back to default port %d", portEnvVar, defaultPort) - } - - return defaultPort -} diff --git a/internal/http/healthcheck_server_test.go b/internal/http/healthcheck_server_test.go new file mode 100644 index 0000000..b41450d --- /dev/null +++ b/internal/http/healthcheck_server_test.go @@ -0,0 +1,114 @@ +package http + +import ( + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" +) + +func TestHealthCheckHandler(t *testing.T) { + tests := []struct { + name string + method string + expectedStatus int + wantBody bool + }{ + { + name: "GET request returns 200 and status ok", + method: http.MethodGet, + expectedStatus: http.StatusOK, + wantBody: true, + }, + { + name: "POST request returns 405 and status not allowed", + method: http.MethodPost, + expectedStatus: http.StatusMethodNotAllowed, + wantBody: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(tt.method, "/healthz", nil) + w := httptest.NewRecorder() + + handleHealthCheck(w, req) + + if got := w.Code; got != tt.expectedStatus { + t.Errorf("handleHealthCheck() status = %v, want %v", got, tt.expectedStatus) + } + + if tt.wantBody { + var response struct { + Status string `json:"status"` + } + + if err := json.NewDecoder(w.Body).Decode(&response); err != nil { + t.Errorf("failed to decode response body: %v", err) + } + + if response.Status != "ok" { + t.Errorf("handleHealthCheck() status = %v, want %v", response.Status, "ok") + } + + if contentType := w.Header().Get("Content-Type"); contentType != "application/json" { + t.Errorf("handleHealthCheck() Content-Type = %v, want %v", contentType, "application/json") + } + } + }) + } +} + +func TestHealthCheckHandlerEncodingError(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/healthz", nil) + + failingWriter := &failingWriter{ + headers: http.Header{}, + } + handleHealthCheck(failingWriter, req) + + if failingWriter.statusCode != http.StatusInternalServerError { + t.Errorf("handleHealthCheck() with encoding error, status = %v, want %v", + failingWriter.statusCode, http.StatusInternalServerError) + } +} + +type failingWriter struct { + statusCode int + headers http.Header +} + +func (w *failingWriter) Header() http.Header { + return w.headers +} + +func (w *failingWriter) Write([]byte) (int, error) { + return 0, fmt.Errorf("encoding error") +} + +func (w *failingWriter) WriteHeader(statusCode int) { + w.statusCode = statusCode +} + +func TestNewHealthCheckServer(t *testing.T) { + server := NewHealthCheckServer("5680") + + if server == nil { + t.Fatal("NewHealthCheckServer() returned nil") + return + } + + if server.Addr != ":5680" { + t.Errorf("NewHealthCheckServer() addr = %v, want %v", server.Addr, ":5680") + } + + if server.ReadTimeout != readTimeout { + t.Errorf("NewHealthCheckServer() readTimeout = %v, want %v", server.ReadTimeout, readTimeout) + } + + if server.WriteTimeout != writeTimeout { + t.Errorf("NewHealthCheckServer() writeTimeout = %v, want %v", server.WriteTimeout, writeTimeout) + } +} diff --git a/internal/retry/retry.go b/internal/retry/retry.go index a15c96a..be6b427 100644 --- a/internal/retry/retry.go +++ b/internal/retry/retry.go @@ -7,9 +7,9 @@ import ( ) var ( - defaultMaxRetryTime = 60 * time.Second - defaultMaxRetries = 100 - defaultWaitTimeBetweenRetries = 5 * time.Second + DefaultMaxRetryTime = 60 * time.Second + DefaultMaxRetries = 100 + DefaultWaitTimeBetweenRetries = 5 * time.Second ) type retryConfig struct { @@ -68,15 +68,15 @@ func UnlimitedRetry[T any](operationName string, operationFn func() (T, error)) return retry(operationName, operationFn, retryConfig{ MaxRetryTime: 0, MaxAttempts: 0, - WaitTimeBetweenRetries: defaultWaitTimeBetweenRetries, + WaitTimeBetweenRetries: DefaultWaitTimeBetweenRetries, }) } // LimitedRetry retries an operation until max retry time or until max attempts. func LimitedRetry[T any](operationName string, operationFn func() (T, error)) (T, error) { return retry(operationName, operationFn, retryConfig{ - MaxRetryTime: defaultMaxRetryTime, - MaxAttempts: defaultMaxRetries, - WaitTimeBetweenRetries: defaultWaitTimeBetweenRetries, + MaxRetryTime: DefaultMaxRetryTime, + MaxAttempts: DefaultMaxRetries, + WaitTimeBetweenRetries: DefaultWaitTimeBetweenRetries, }) } diff --git a/internal/retry/retry_test.go b/internal/retry/retry_test.go index 89e264c..bbad757 100644 --- a/internal/retry/retry_test.go +++ b/internal/retry/retry_test.go @@ -8,18 +8,18 @@ import ( func setRetryTimings(t *testing.T) func() { t.Helper() - origMaxRetryTime := defaultMaxRetryTime - origMaxRetries := defaultMaxRetries - origWaitTime := defaultWaitTimeBetweenRetries + origMaxRetryTime := DefaultMaxRetryTime + origMaxRetries := DefaultMaxRetries + origWaitTime := DefaultWaitTimeBetweenRetries - defaultMaxRetryTime = 100 * time.Millisecond - defaultMaxRetries = 3 - defaultWaitTimeBetweenRetries = 10 * time.Millisecond + DefaultMaxRetryTime = 100 * time.Millisecond + DefaultMaxRetries = 3 + DefaultWaitTimeBetweenRetries = 10 * time.Millisecond return func() { - defaultMaxRetryTime = origMaxRetryTime - defaultMaxRetries = origMaxRetries - defaultWaitTimeBetweenRetries = origWaitTime + DefaultMaxRetryTime = origMaxRetryTime + DefaultMaxRetries = origMaxRetries + DefaultWaitTimeBetweenRetries = origWaitTime } } @@ -131,14 +131,14 @@ func TestLimitedRetry(t *testing.T) { operationFn: func() (string, error) { return "", errors.New("persistent error") }, - expectedCalls: defaultMaxRetries, + expectedCalls: DefaultMaxRetries, expectError: true, expectedValue: "", }, { name: "fails after max retry time", operationFn: func() (string, error) { - time.Sleep(defaultMaxRetryTime + time.Second) + time.Sleep(DefaultMaxRetryTime + time.Second) return "", errors.New("timeout error") }, expectedCalls: 1,