diff --git a/src/server/middleware/csrf/csrf.go b/src/server/middleware/csrf/csrf.go index 9e0341aef32..82b1755b43c 100644 --- a/src/server/middleware/csrf/csrf.go +++ b/src/server/middleware/csrf/csrf.go @@ -72,8 +72,8 @@ func Middleware() func(handler http.Handler) http.Handler { } else if len(key) != 32 { log.Errorf("Invalid CSRF key length from the environment: %s. Please ensure the key length is 32 characters.", key) protect = func(_ http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - lib_http.SendError(w, errors.New("Invalid CSRF key length from the environment. Please ensure the key length is 32 characters.")) + return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + lib_http.SendError(w, errors.New("invalid CSRF key length from the environment. Please ensure the key length is 32 characters")) }) } return diff --git a/src/server/middleware/csrf/csrf_test.go b/src/server/middleware/csrf/csrf_test.go index 054baf91925..1b9e75512a0 100644 --- a/src/server/middleware/csrf/csrf_test.go +++ b/src/server/middleware/csrf/csrf_test.go @@ -4,17 +4,21 @@ import ( "net/http" "net/http/httptest" "os" + "sync" "testing" "github.com/stretchr/testify/assert" "github.com/goharbor/harbor/src/common" - "github.com/goharbor/harbor/src/common/utils" "github.com/goharbor/harbor/src/common/utils/test" "github.com/goharbor/harbor/src/lib/config" _ "github.com/goharbor/harbor/src/pkg/config/inmemory" ) +func resetMiddleware() { + once = sync.Once{} +} + func TestMain(m *testing.M) { test.InitDatabaseFromEnv() conf := map[string]interface{}{} @@ -37,45 +41,29 @@ func TestMiddleware(t *testing.T) { req *http.Request statusCode int returnToken bool - validKey bool }{ { req: httptest.NewRequest(http.MethodGet, "/", nil), statusCode: http.StatusOK, returnToken: true, - validKey: true, }, { req: httptest.NewRequest(http.MethodDelete, "/", nil), statusCode: http.StatusForbidden, returnToken: true, - validKey: true, }, { req: httptest.NewRequest(http.MethodGet, "/api/2.0/projects", nil), // should be skipped statusCode: http.StatusOK, returnToken: false, - validKey: true, }, { req: httptest.NewRequest(http.MethodDelete, "/v2/library/hello-world/manifests/latest", nil), // should be skipped statusCode: http.StatusOK, returnToken: false, - validKey: true, - }, - { - req: httptest.NewRequest(http.MethodGet, "/", nil), - statusCode: http.StatusInternalServerError, - returnToken: false, - validKey: false, }, } for _, c := range cases { - if c.validKey { - os.Setenv(csrfKeyEnv, utils.GenerateRandomStringWithLen(32)) - } else { - os.Setenv(csrfKeyEnv, utils.GenerateRandomStringWithLen(10)) - } srv := Middleware()(&handler{}) rec := httptest.NewRecorder() srv.ServeHTTP(rec, c.req) @@ -84,6 +72,26 @@ func TestMiddleware(t *testing.T) { } } +func TestMiddlewareInvalidKey(t *testing.T) { + originalEnv := os.Getenv(csrfKeyEnv) + defer os.Setenv(csrfKeyEnv, originalEnv) + + t.Run("invalid CSRF key", func(t *testing.T) { + os.Setenv(csrfKeyEnv, "invalidkey") + resetMiddleware() + middleware := Middleware() + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Error("handler should not be reached when CSRF key is invalid") + }) + + handler := middleware(testHandler) + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusInternalServerError, rec.Code) + }) +} + func TestSecureCookie(t *testing.T) { assert.True(t, secureCookie()) conf := map[string]interface{}{