diff --git a/api/config.go b/api/config.go index a65340214..96443818d 100644 --- a/api/config.go +++ b/api/config.go @@ -9,6 +9,7 @@ type Config struct { GraphiteLocalMetricTTL time.Duration GraphiteRemoteMetricTTL time.Duration PrometheusRemoteMetricTTL time.Duration + Flags FeatureFlags } // WebConfig is container for web ui configuration parameters. @@ -33,4 +34,5 @@ type FeatureFlags struct { IsPlottingDefaultOn bool `json:"isPlottingDefaultOn"` IsPlottingAvailable bool `json:"isPlottingAvailable"` IsSubscriptionToAllTagsAvailable bool `json:"isSubscriptionToAllTagsAvailable"` + IsReadonlyEnabled bool `json:"isReadonlyEnabled"` } diff --git a/api/handler/handler.go b/api/handler/handler.go index 332010a30..0ceb343fd 100644 --- a/api/handler/handler.go +++ b/api/handler/handler.go @@ -24,7 +24,14 @@ const contactKey moiramiddle.ContextKey = "contact" const subscriptionKey moiramiddle.ContextKey = "subscription" // NewHandler creates new api handler request uris based on github.com/go-chi/chi -func NewHandler(db moira.Database, log moira.Logger, index moira.Searcher, config *api.Config, metricSourceProvider *metricSource.SourceProvider, webConfigContent []byte) http.Handler { +func NewHandler( + db moira.Database, + log moira.Logger, + index moira.Searcher, + config *api.Config, + metricSourceProvider *metricSource.SourceProvider, + webConfigContent []byte, +) http.Handler { database = db searchIndex = index router := chi.NewRouter() @@ -85,23 +92,30 @@ func NewHandler(db moira.Database, log moira.Logger, index moira.Searcher, confi // @tag.description APIs for interacting with Moira users router.Route("/api", func(router chi.Router) { router.Use(moiramiddle.DatabaseContext(database)) - router.Get("/config", getWebConfig(webConfigContent)) - router.Route("/user", user) - router.With(moiramiddle.Triggers(config.GraphiteLocalMetricTTL, config.GraphiteRemoteMetricTTL, config.PrometheusRemoteMetricTTL)).Route("/trigger", triggers(metricSourceProvider, searchIndex)) - router.Route("/tag", tag) - router.Route("/pattern", pattern) - router.Route("/event", event) - router.Route("/subscription", subscription) - router.Route("/notification", notification) router.Route("/health", health) - router.Route("/teams", teams) - router.Route("/contact", func(router chi.Router) { - contact(router) - contactEvents(router) + router.Route("/", func(router chi.Router) { + router.Use(moiramiddle.ReadOnlyMiddleware(config)) + router.Get("/config", getWebConfig(webConfigContent)) + router.Route("/user", user) + router.With(moiramiddle.Triggers( + config.GraphiteLocalMetricTTL, + config.GraphiteRemoteMetricTTL, + config.PrometheusRemoteMetricTTL, + )).Route("/trigger", triggers(metricSourceProvider, searchIndex)) + router.Route("/tag", tag) + router.Route("/pattern", pattern) + router.Route("/event", event) + router.Route("/subscription", subscription) + router.Route("/notification", notification) + router.Route("/teams", teams) + router.Route("/contact", func(router chi.Router) { + contact(router) + contactEvents(router) + }) + router.Get("/swagger/*", httpSwagger.Handler( + httpSwagger.URL("/api/swagger/doc.json"), + )) }) - router.Get("/swagger/*", httpSwagger.Handler( - httpSwagger.URL("/api/swagger/doc.json"), - )) }) if config.EnableCORS { diff --git a/api/handler/handler_test.go b/api/handler/handler_test.go new file mode 100644 index 000000000..5512f3448 --- /dev/null +++ b/api/handler/handler_test.go @@ -0,0 +1,101 @@ +package handler + +import ( + "bytes" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/golang/mock/gomock" + "github.com/moira-alert/moira/api" + "github.com/moira-alert/moira/api/dto" + "github.com/moira-alert/moira/logging/zerolog_adapter" + mock_moira_alert "github.com/moira-alert/moira/mock/moira-alert" + . "github.com/smartystreets/goconvey/convey" +) + +func TestReadonlyMode(t *testing.T) { + Convey("Test readonly mode enabled", t, func() { + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + + responseWriter := httptest.NewRecorder() + mockDb := mock_moira_alert.NewMockDatabase(mockCtrl) + database = mockDb + + logger, _ := zerolog_adapter.GetLogger("Test") + config := &api.Config{Flags: api.FeatureFlags{IsReadonlyEnabled: true}} + expectedConfig := []byte("Expected config") + handler := NewHandler(mockDb, logger, nil, config, nil, expectedConfig) + + Convey("Get notifier health", func() { + mockDb.EXPECT().GetNotifierState().Return("OK", nil).Times(1) + + expected := &dto.NotifierState{ + State: "OK", + } + + testRequest := httptest.NewRequest(http.MethodGet, "/api/health/notifier", nil) + + handler.ServeHTTP(responseWriter, testRequest) + + response := responseWriter.Result() + defer response.Body.Close() + content, _ := io.ReadAll(response.Body) + actual := &dto.NotifierState{} + err := json.Unmarshal(content, actual) + So(err, ShouldBeNil) + + So(actual, ShouldResemble, expected) + So(response.StatusCode, ShouldEqual, http.StatusOK) + }) + + Convey("Put notifier health", func() { + mockDb.EXPECT().SetNotifierState("OK").Return(nil).Times(1) + + state := &dto.NotifierState{ + State: "OK", + } + + stateBytes, err := json.Marshal(state) + So(err, ShouldBeNil) + + testRequest := httptest.NewRequest(http.MethodPut, "/api/health/notifier", bytes.NewReader(stateBytes)) + + handler.ServeHTTP(responseWriter, testRequest) + + response := responseWriter.Result() + defer response.Body.Close() + So(response.StatusCode, ShouldEqual, http.StatusOK) + }) + + Convey("Put new trigger", func() { + trigger := &dto.Trigger{} + triggerBytes, err := json.Marshal(trigger) + So(err, ShouldBeNil) + + testRequest := httptest.NewRequest(http.MethodPut, "/api/trigger", bytes.NewReader(triggerBytes)) + + handler.ServeHTTP(responseWriter, testRequest) + + response := responseWriter.Result() + defer response.Body.Close() + So(response.StatusCode, ShouldEqual, http.StatusForbidden) + }) + + Convey("Get contact", func() { + testRequest := httptest.NewRequest(http.MethodGet, "/api/config", nil) + + handler.ServeHTTP(responseWriter, testRequest) + + response := responseWriter.Result() + defer response.Body.Close() + actual, _ := io.ReadAll(response.Body) + + So(response.StatusCode, ShouldEqual, http.StatusOK) + So(actual, ShouldResemble, expectedConfig) + }) + }) +} diff --git a/api/middleware/context_test.go b/api/middleware/context_test.go index 5535c6c70..1c4e9e7c9 100644 --- a/api/middleware/context_test.go +++ b/api/middleware/context_test.go @@ -1,4 +1,4 @@ -package middleware_test +package middleware import ( "io" @@ -6,7 +6,6 @@ import ( "net/http/httptest" "testing" - "github.com/moira-alert/moira/api/middleware" . "github.com/smartystreets/goconvey/convey" ) @@ -26,7 +25,7 @@ func TestPaginateMiddleware(t *testing.T) { testRequest := httptest.NewRequest(http.MethodGet, "/test?"+param, nil) handler := func(w http.ResponseWriter, r *http.Request) {} - middlewareFunc := middleware.Paginate(defaultPage, defaultSize) + middlewareFunc := Paginate(defaultPage, defaultSize) wrappedHandler := middlewareFunc(http.HandlerFunc(handler)) wrappedHandler.ServeHTTP(responseWriter, testRequest) @@ -41,7 +40,7 @@ func TestPaginateMiddleware(t *testing.T) { testRequest := httptest.NewRequest(http.MethodGet, "/test?p=0%&size=100", nil) handler := func(w http.ResponseWriter, r *http.Request) {} - middlewareFunc := middleware.Paginate(defaultPage, defaultSize) + middlewareFunc := Paginate(defaultPage, defaultSize) wrappedHandler := middlewareFunc(http.HandlerFunc(handler)) wrappedHandler.ServeHTTP(responseWriter, testRequest) @@ -69,7 +68,7 @@ func TestPagerMiddleware(t *testing.T) { testRequest := httptest.NewRequest(http.MethodGet, "/test?"+param, nil) handler := func(w http.ResponseWriter, r *http.Request) {} - middlewareFunc := middleware.Pager(defaultCreatePager, defaultPagerID) + middlewareFunc := Pager(defaultCreatePager, defaultPagerID) wrappedHandler := middlewareFunc(http.HandlerFunc(handler)) wrappedHandler.ServeHTTP(responseWriter, testRequest) @@ -84,7 +83,7 @@ func TestPagerMiddleware(t *testing.T) { testRequest := httptest.NewRequest(http.MethodGet, "/test?pagerID=test%&createPager=true", nil) handler := func(w http.ResponseWriter, r *http.Request) {} - middlewareFunc := middleware.Pager(defaultCreatePager, defaultPagerID) + middlewareFunc := Pager(defaultCreatePager, defaultPagerID) wrappedHandler := middlewareFunc(http.HandlerFunc(handler)) wrappedHandler.ServeHTTP(responseWriter, testRequest) @@ -108,7 +107,7 @@ func TestPopulateMiddleware(t *testing.T) { testRequest := httptest.NewRequest(http.MethodGet, "/test?populated=true", nil) handler := func(w http.ResponseWriter, r *http.Request) {} - middlewareFunc := middleware.Populate(defaultPopulated) + middlewareFunc := Populate(defaultPopulated) wrappedHandler := middlewareFunc(http.HandlerFunc(handler)) wrappedHandler.ServeHTTP(responseWriter, testRequest) @@ -122,7 +121,7 @@ func TestPopulateMiddleware(t *testing.T) { testRequest := httptest.NewRequest(http.MethodGet, "/test?populated%=true", nil) handler := func(w http.ResponseWriter, r *http.Request) {} - middlewareFunc := middleware.Populate(defaultPopulated) + middlewareFunc := Populate(defaultPopulated) wrappedHandler := middlewareFunc(http.HandlerFunc(handler)) wrappedHandler.ServeHTTP(responseWriter, testRequest) @@ -150,7 +149,7 @@ func TestDateRangeMiddleware(t *testing.T) { testRequest := httptest.NewRequest(http.MethodGet, "/test?"+param, nil) handler := func(w http.ResponseWriter, r *http.Request) {} - middlewareFunc := middleware.DateRange(defaultFrom, defaultTo) + middlewareFunc := DateRange(defaultFrom, defaultTo) wrappedHandler := middlewareFunc(http.HandlerFunc(handler)) wrappedHandler.ServeHTTP(responseWriter, testRequest) @@ -165,7 +164,7 @@ func TestDateRangeMiddleware(t *testing.T) { testRequest := httptest.NewRequest(http.MethodGet, "/test?from=-2hours%&to=now", nil) handler := func(w http.ResponseWriter, r *http.Request) {} - middlewareFunc := middleware.DateRange(defaultFrom, defaultTo) + middlewareFunc := DateRange(defaultFrom, defaultTo) wrappedHandler := middlewareFunc(http.HandlerFunc(handler)) wrappedHandler.ServeHTTP(responseWriter, testRequest) @@ -189,7 +188,7 @@ func TestTargetNameMiddleware(t *testing.T) { testRequest := httptest.NewRequest(http.MethodGet, "/test?target=test", nil) handler := func(w http.ResponseWriter, r *http.Request) {} - middlewareFunc := middleware.TargetName(defaultTargetName) + middlewareFunc := TargetName(defaultTargetName) wrappedHandler := middlewareFunc(http.HandlerFunc(handler)) wrappedHandler.ServeHTTP(responseWriter, testRequest) @@ -203,7 +202,7 @@ func TestTargetNameMiddleware(t *testing.T) { testRequest := httptest.NewRequest(http.MethodGet, "/test?target%=test", nil) handler := func(w http.ResponseWriter, r *http.Request) {} - middlewareFunc := middleware.TargetName(defaultTargetName) + middlewareFunc := TargetName(defaultTargetName) wrappedHandler := middlewareFunc(http.HandlerFunc(handler)) wrappedHandler.ServeHTTP(responseWriter, testRequest) diff --git a/api/middleware/readonly_mode.go b/api/middleware/readonly_mode.go new file mode 100644 index 000000000..add187106 --- /dev/null +++ b/api/middleware/readonly_mode.go @@ -0,0 +1,29 @@ +package middleware + +import ( + "net/http" + + "github.com/go-chi/render" + "github.com/moira-alert/moira/api" +) + +// ReadOnlyMiddleware returns 403 for mutating queries if readonly mode is enabled +func ReadOnlyMiddleware(config *api.Config) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + if config.Flags.IsReadonlyEnabled && isMutatingMethod(r.Method) { + render.Render(w, r, api.ErrorForbidden("Moira is currently in read-only mode")) //nolint:errcheck + return + } + next.ServeHTTP(w, r) + } + return http.HandlerFunc(fn) + } +} + +func isMutatingMethod(method string) bool { + return method == http.MethodPut || + method == http.MethodPost || + method == http.MethodPatch || + method == http.MethodDelete +} diff --git a/api/middleware/readonly_mode_test.go b/api/middleware/readonly_mode_test.go new file mode 100644 index 000000000..36a601a2e --- /dev/null +++ b/api/middleware/readonly_mode_test.go @@ -0,0 +1,73 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/moira-alert/moira/api" + . "github.com/smartystreets/goconvey/convey" +) + +func TestReadonlyModeMiddleware(t *testing.T) { + Convey("Given readonly mode is disabled", t, func() { + config := &api.Config{Flags: api.FeatureFlags{IsReadonlyEnabled: false}} + + Convey("Performing get request", func() { + actual := PerformRequestWithReadonlyModeMiddleware(config, http.MethodGet) + + So(actual, ShouldEqual, http.StatusOK) + }) + Convey("Performing put request", func() { + actual := PerformRequestWithReadonlyModeMiddleware(config, http.MethodPut) + + So(actual, ShouldEqual, http.StatusOK) + }) + }) + + Convey("Given readonly mode is enabled", t, func() { + config := &api.Config{Flags: api.FeatureFlags{IsReadonlyEnabled: true}} + + Convey("Performing get request", func() { + actual := PerformRequestWithReadonlyModeMiddleware(config, http.MethodGet) + + So(actual, ShouldEqual, http.StatusOK) + }) + Convey("Performing put request", func() { + actual := PerformRequestWithReadonlyModeMiddleware(config, http.MethodPut) + + So(actual, ShouldEqual, http.StatusForbidden) + }) + Convey("Performing post request", func() { + actual := PerformRequestWithReadonlyModeMiddleware(config, http.MethodPost) + + So(actual, ShouldEqual, http.StatusForbidden) + }) + Convey("Performing patch request", func() { + actual := PerformRequestWithReadonlyModeMiddleware(config, http.MethodPatch) + + So(actual, ShouldEqual, http.StatusForbidden) + }) + Convey("Performing delete request", func() { + actual := PerformRequestWithReadonlyModeMiddleware(config, http.MethodDelete) + + So(actual, ShouldEqual, http.StatusForbidden) + }) + }) +} + +func PerformRequestWithReadonlyModeMiddleware(config *api.Config, method string) int { + responseWriter := httptest.NewRecorder() + + testRequest := httptest.NewRequest(method, "/test", nil) + + handler := func(w http.ResponseWriter, r *http.Request) {} + middlewareFunc := ReadOnlyMiddleware(config) + wrappedHandler := middlewareFunc(http.HandlerFunc(handler)) + + wrappedHandler.ServeHTTP(responseWriter, testRequest) + response := responseWriter.Result() + defer response.Body.Close() + + return response.StatusCode +} diff --git a/cmd/api/config.go b/cmd/api/config.go index 50ec41623..0a4995735 100644 --- a/cmd/api/config.go +++ b/cmd/api/config.go @@ -59,14 +59,19 @@ type featureFlags struct { IsPlottingDefaultOn bool `yaml:"is_plotting_default_on"` IsPlottingAvailable bool `yaml:"is_plotting_available"` IsSubscriptionToAllTagsAvailable bool `yaml:"is_subscription_to_all_tags_available"` + IsReadonlyEnabled bool `yaml:"is_readonly_enabled"` } -func (config *apiConfig) getSettings(localMetricTTL, remoteMetricTTL string) *api.Config { +func (config *apiConfig) getSettings( + localMetricTTL, remoteMetricTTL string, + flags api.FeatureFlags, +) *api.Config { return &api.Config{ EnableCORS: config.EnableCORS, Listen: config.Listen, GraphiteLocalMetricTTL: to.Duration(localMetricTTL), GraphiteRemoteMetricTTL: to.Duration(remoteMetricTTL), + Flags: flags, } } @@ -99,6 +104,7 @@ func (config *webConfig) getFeatureFlags() api.FeatureFlags { IsPlottingDefaultOn: config.FeatureFlags.IsPlottingDefaultOn, IsPlottingAvailable: config.FeatureFlags.IsPlottingAvailable, IsSubscriptionToAllTagsAvailable: config.FeatureFlags.IsSubscriptionToAllTagsAvailable, + IsReadonlyEnabled: config.FeatureFlags.IsReadonlyEnabled, } } diff --git a/cmd/api/config_test.go b/cmd/api/config_test.go index 75cd644ab..0b24df01f 100644 --- a/cmd/api/config_test.go +++ b/cmd/api/config_test.go @@ -23,9 +23,10 @@ func Test_apiConfig_getSettings(t *testing.T) { Listen: "0000", GraphiteLocalMetricTTL: time.Hour, GraphiteRemoteMetricTTL: 24 * time.Hour, + Flags: api.FeatureFlags{IsReadonlyEnabled: true}, } - result := apiConf.getSettings("1h", "24h") + result := apiConf.getSettings("1h", "24h", api.FeatureFlags{IsReadonlyEnabled: true}) So(result, ShouldResemble, expectedResult) }) } @@ -115,7 +116,7 @@ func Test_webConfig_getSettings(t *testing.T) { result, err := wC.getSettings(true) So(err, ShouldBeEmpty) - So(string(result), ShouldResemble, "{\"remoteAllowed\":true,\"contacts\":[],\"featureFlags\":{\"isPlottingDefaultOn\":false,\"isPlottingAvailable\":false,\"isSubscriptionToAllTagsAvailable\":false}}") + So(string(result), ShouldResemble, `{"remoteAllowed":true,"contacts":[],"featureFlags":{"isPlottingDefaultOn":false,"isPlottingAvailable":false,"isSubscriptionToAllTagsAvailable":false,"isReadonlyEnabled":false}}`) }) Convey("Default config, fill it", t, func() { @@ -123,7 +124,7 @@ func Test_webConfig_getSettings(t *testing.T) { result, err := config.Web.getSettings(true) So(err, ShouldBeEmpty) - So(string(result), ShouldResemble, "{\"remoteAllowed\":true,\"contacts\":[],\"featureFlags\":{\"isPlottingDefaultOn\":true,\"isPlottingAvailable\":true,\"isSubscriptionToAllTagsAvailable\":true}}") + So(string(result), ShouldResemble, `{"remoteAllowed":true,"contacts":[],"featureFlags":{"isPlottingDefaultOn":true,"isPlottingAvailable":true,"isSubscriptionToAllTagsAvailable":true,"isReadonlyEnabled":false}}`) }) Convey("Not empty config, fill it", t, func() { @@ -148,6 +149,6 @@ func Test_webConfig_getSettings(t *testing.T) { result, err := wC.getSettings(true) So(err, ShouldBeEmpty) - So(string(result), ShouldResemble, "{\"supportEmail\":\"lalal@mail.la\",\"remoteAllowed\":true,\"contacts\":[{\"type\":\"slack\",\"label\":\"label\",\"validation\":\"t(\\\\d+)\",\"help\":\"help\"}],\"featureFlags\":{\"isPlottingDefaultOn\":true,\"isPlottingAvailable\":false,\"isSubscriptionToAllTagsAvailable\":true}}") + So(string(result), ShouldResemble, `{"supportEmail":"lalal@mail.la","remoteAllowed":true,"contacts":[{"type":"slack","label":"label","validation":"t(\\d+)","help":"help"}],"featureFlags":{"isPlottingDefaultOn":true,"isPlottingAvailable":false,"isSubscriptionToAllTagsAvailable":true,"isReadonlyEnabled":false}}`) }) } diff --git a/cmd/api/main.go b/cmd/api/main.go index bbf76fffd..c044c89fc 100644 --- a/cmd/api/main.go +++ b/cmd/api/main.go @@ -61,7 +61,11 @@ func main() { os.Exit(1) } - apiConfig := applicationConfig.API.getSettings(applicationConfig.Redis.MetricsTTL, applicationConfig.Remote.MetricsTTL) + apiConfig := applicationConfig.API.getSettings( + applicationConfig.Redis.MetricsTTL, + applicationConfig.Remote.MetricsTTL, + applicationConfig.Web.getFeatureFlags(), + ) logger, err := logging.ConfigureLog(applicationConfig.Logger.LogFile, applicationConfig.Logger.LogLevel, serviceName, applicationConfig.Logger.LogPrettyFormat) @@ -144,7 +148,15 @@ func main() { Msg("Failed to get web applicationConfig content ") } - httpHandler := handler.NewHandler(database, logger, searchIndex, apiConfig, metricSourceProvider, webConfigContent) + httpHandler := handler.NewHandler( + database, + logger, + searchIndex, + apiConfig, + metricSourceProvider, + webConfigContent, + ) + server := &http.Server{ Handler: httpHandler, } diff --git a/local/api.yml b/local/api.yml index c0b4939db..1b5b9834e 100644 --- a/local/api.yml +++ b/local/api.yml @@ -50,6 +50,7 @@ web: is_plotting_available: true is_plotting_default_on: true is_subscription_to_all_tags_available: true + is_readonly_enabled: true notification_history: ttl: 48h query_limit: 10000