From 7b380fb277adc08919c5e401d6d4aa9c5f01d908 Mon Sep 17 00:00:00 2001 From: Umputun Date: Fri, 11 Oct 2024 02:50:36 -0500 Subject: [PATCH] Add the ability to set a custom not-found handler. The handler allows setting a custom response body for unregistered (not found) resources and also allows setting a custom status code. --- README.md | 11 ++++ group.go | 10 ++++ group_test.go | 147 ++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 168 insertions(+) diff --git a/README.md b/README.md index 16803b8..919bfa3 100644 --- a/README.md +++ b/README.md @@ -28,6 +28,17 @@ func main() { group := routegroup.New(mux) } ``` +** Setting optional `NotFoundHandler`** + +It is possible to set a custom `NotFoundHandler` for the group. This handler will be called when no other route matches the request: + +```go + group.NotFoundHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "404 page not found, something is wrong!", http.StatusNotFound) + }) +``` + +If the custom `NotFoundHandler` is not set, `routegroup` will automatically use a default handler from stdlib (`http.NotFoundHandler()`). **Adding Routes with Middleware** diff --git a/group.go b/group.go index 2346a6c..d90bc9f 100644 --- a/group.go +++ b/group.go @@ -17,6 +17,7 @@ type Bundle struct { once sync.Once // used to register a not found handler for the root path if no / route is registered set bool // true if the root path is registered in the mux disableRootNotFoundHandler bool // if true, the not found handler for the root path is not registered automatically + notFound http.HandlerFunc } } @@ -39,6 +40,9 @@ func (b *Bundle) ServeHTTP(w http.ResponseWriter, r *http.Request) { // register a not found handler for the root path if no / route is registered // this is needed to be able to use middleware on all routes, for example logging notFoundHandler := http.NotFoundHandler() + if b.rootRegistered.notFound != nil { + notFoundHandler = b.rootRegistered.notFound + } b.register("/", notFoundHandler.ServeHTTP) } }) @@ -108,6 +112,11 @@ func (b *Bundle) Handler(r *http.Request) (h http.Handler, pattern string) { // DisableNotFoundHandler disables the automatic registration of a not found handler for the root path. func (b *Bundle) DisableNotFoundHandler() { b.rootRegistered.disableRootNotFoundHandler = true } +// NotFoundHandler sets a custom handler for the root path if no / route is registered. +func (b *Bundle) NotFoundHandler(handler http.HandlerFunc) { + b.rootRegistered.notFound = handler +} + // Matches non-space characters, spaces, then anything, i.e. "GET /path/to/resource" var reGo122 = regexp.MustCompile(`^(\S*)\s+(.*)$`) @@ -119,6 +128,7 @@ func (b *Bundle) register(pattern string, handler http.HandlerFunc) { pattern = b.basePath + pattern } + // check if the root path is registered if pattern == "/" || b.basePath+pattern == "/" { b.rootRegistered.set = true } diff --git a/group_test.go b/group_test.go index 57404de..3feadb3 100644 --- a/group_test.go +++ b/group_test.go @@ -800,6 +800,153 @@ func TestHTTPServerWithDerived(t *testing.T) { }) } +func TestHTTPServerWithCustomNotFound(t *testing.T) { + group := routegroup.New(http.NewServeMux()) + group.Use(testMiddleware) + group.NotFoundHandler(func(w http.ResponseWriter, _ *http.Request) { + http.Error(w, "Custom 404: Page not found!", http.StatusNotFound) + }) + + apiGroup := group.Mount("/api") + apiGroup.HandleFunc("GET /test", func(w http.ResponseWriter, _ *http.Request) { + _, _ = w.Write([]byte("test handler")) + }) + + testServer := httptest.NewServer(group) + defer testServer.Close() + + t.Run("GET /api/test", func(t *testing.T) { + resp, err := http.Get(testServer.URL + "/api/test") + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected status code %d, got %d", http.StatusOK, resp.StatusCode) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + if string(body) != "test handler" { + t.Errorf("Expected body 'test handler', got '%s'", string(body)) + } + if header := resp.Header.Get("X-Test-Middleware"); header != "true" { + t.Errorf("Expected header X-Test-Middleware to be 'true', got '%s'", header) + } + }) + + t.Run("GET /api/not-found", func(t *testing.T) { + resp, err := http.Get(testServer.URL + "/api/not-found") + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + t.Logf("body: %s", body) + + if resp.StatusCode != http.StatusNotFound { + t.Errorf("Expected status code %d, got %d", http.StatusNotFound, resp.StatusCode) + } + + if header := resp.Header.Get("X-Test-Middleware"); header != "true" { + t.Errorf("Expected header X-Test-Middleware to be 'true', got '%s'", header) + } + if string(body) != "Custom 404: Page not found!\n" { + t.Errorf("Expected body 'Custom 404: Page not found!', got '%s'", string(body)) + } + }) + + t.Run("GET /not-found", func(t *testing.T) { + resp, err := http.Get(testServer.URL + "/not-found") + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + t.Logf("body: %s", body) + + if resp.StatusCode != http.StatusNotFound { + t.Errorf("Expected status code %d, got %d", http.StatusNotFound, resp.StatusCode) + } + if header := resp.Header.Get("X-Test-Middleware"); header != "true" { + t.Errorf("Expected header X-Test-Middleware to be 'true', got '%s'", header) + } + if string(body) != "Custom 404: Page not found!\n" { + t.Errorf("Expected body 'Custom 404: Page not found!', got '%s'", string(body)) + } + }) +} + +func TestHTTPServerWithCustomNotFoundNon404Status(t *testing.T) { + group := routegroup.New(http.NewServeMux()) + group.Use(testMiddleware) + group.NotFoundHandler(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusServiceUnavailable) + _, _ = w.Write([]byte("Custom 404: Page not found!\n")) + }) + + apiGroup := group.Mount("/api") + apiGroup.HandleFunc("GET /test", func(w http.ResponseWriter, _ *http.Request) { + _, _ = w.Write([]byte("test handler")) + }) + + testServer := httptest.NewServer(group) + defer testServer.Close() + + t.Run("GET /api/test", func(t *testing.T) { + resp, err := http.Get(testServer.URL + "/api/test") + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected status code %d, got %d", http.StatusOK, resp.StatusCode) + } + if string(body) != "test handler" { + t.Errorf("Expected body 'test handler', got '%s'", string(body)) + } + if header := resp.Header.Get("X-Test-Middleware"); header != "true" { + t.Errorf("Expected header X-Test-Middleware to be 'true', got '%s'", header) + } + }) + + t.Run("GET /api/not-found", func(t *testing.T) { + resp, err := http.Get(testServer.URL + "/api/not-found") + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + t.Logf("body: %s", body) + + if resp.StatusCode != http.StatusServiceUnavailable { + t.Errorf("Expected status code %d, got %d", http.StatusServiceUnavailable, resp.StatusCode) + } + if header := resp.Header.Get("X-Test-Middleware"); header != "true" { + t.Errorf("Expected header X-Test-Middleware to be 'true', got '%s'", header) + } + if string(body) != "Custom 404: Page not found!\n" { + t.Errorf("Expected body 'Custom 404: Page not found!', got '%s'", string(body)) + } + }) + +} + func ExampleNew() { group := routegroup.New(http.NewServeMux())