Skip to content

Commit

Permalink
Add the ability to set a custom not-found handler.
Browse files Browse the repository at this point in the history
The handler allows setting a custom response body for unregistered (not found) resources and also allows setting a custom status code.
  • Loading branch information
umputun committed Oct 11, 2024
1 parent 80c3d19 commit 7b380fb
Show file tree
Hide file tree
Showing 3 changed files with 168 additions and 0 deletions.
11 changes: 11 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,17 @@ func main() {
group := routegroup.New(mux)
}
```
** Setting optional `NotFoundHandler`**

It is possible to set a custom `NotFoundHandler` for the group. This handler will be called when no other route matches the request:

```go
group.NotFoundHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "404 page not found, something is wrong!", http.StatusNotFound)
})
```

If the custom `NotFoundHandler` is not set, `routegroup` will automatically use a default handler from stdlib (`http.NotFoundHandler()`).

**Adding Routes with Middleware**

Expand Down
10 changes: 10 additions & 0 deletions group.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ type Bundle struct {
once sync.Once // used to register a not found handler for the root path if no / route is registered
set bool // true if the root path is registered in the mux
disableRootNotFoundHandler bool // if true, the not found handler for the root path is not registered automatically
notFound http.HandlerFunc
}
}

Expand All @@ -39,6 +40,9 @@ func (b *Bundle) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// register a not found handler for the root path if no / route is registered
// this is needed to be able to use middleware on all routes, for example logging
notFoundHandler := http.NotFoundHandler()
if b.rootRegistered.notFound != nil {
notFoundHandler = b.rootRegistered.notFound
}
b.register("/", notFoundHandler.ServeHTTP)
}
})
Expand Down Expand Up @@ -108,6 +112,11 @@ func (b *Bundle) Handler(r *http.Request) (h http.Handler, pattern string) {
// DisableNotFoundHandler disables the automatic registration of a not found handler for the root path.
func (b *Bundle) DisableNotFoundHandler() { b.rootRegistered.disableRootNotFoundHandler = true }

// NotFoundHandler sets a custom handler for the root path if no / route is registered.
func (b *Bundle) NotFoundHandler(handler http.HandlerFunc) {
b.rootRegistered.notFound = handler
}

// Matches non-space characters, spaces, then anything, i.e. "GET /path/to/resource"
var reGo122 = regexp.MustCompile(`^(\S*)\s+(.*)$`)

Expand All @@ -119,6 +128,7 @@ func (b *Bundle) register(pattern string, handler http.HandlerFunc) {
pattern = b.basePath + pattern
}

// check if the root path is registered
if pattern == "/" || b.basePath+pattern == "/" {
b.rootRegistered.set = true
}
Expand Down
147 changes: 147 additions & 0 deletions group_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -800,6 +800,153 @@ func TestHTTPServerWithDerived(t *testing.T) {
})
}

func TestHTTPServerWithCustomNotFound(t *testing.T) {
group := routegroup.New(http.NewServeMux())
group.Use(testMiddleware)
group.NotFoundHandler(func(w http.ResponseWriter, _ *http.Request) {
http.Error(w, "Custom 404: Page not found!", http.StatusNotFound)
})

apiGroup := group.Mount("/api")
apiGroup.HandleFunc("GET /test", func(w http.ResponseWriter, _ *http.Request) {
_, _ = w.Write([]byte("test handler"))
})

testServer := httptest.NewServer(group)
defer testServer.Close()

t.Run("GET /api/test", func(t *testing.T) {
resp, err := http.Get(testServer.URL + "/api/test")
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status code %d, got %d", http.StatusOK, resp.StatusCode)
}

body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
if string(body) != "test handler" {
t.Errorf("Expected body 'test handler', got '%s'", string(body))
}
if header := resp.Header.Get("X-Test-Middleware"); header != "true" {
t.Errorf("Expected header X-Test-Middleware to be 'true', got '%s'", header)
}
})

t.Run("GET /api/not-found", func(t *testing.T) {
resp, err := http.Get(testServer.URL + "/api/not-found")
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
t.Logf("body: %s", body)

if resp.StatusCode != http.StatusNotFound {
t.Errorf("Expected status code %d, got %d", http.StatusNotFound, resp.StatusCode)
}

if header := resp.Header.Get("X-Test-Middleware"); header != "true" {
t.Errorf("Expected header X-Test-Middleware to be 'true', got '%s'", header)
}
if string(body) != "Custom 404: Page not found!\n" {
t.Errorf("Expected body 'Custom 404: Page not found!', got '%s'", string(body))
}
})

t.Run("GET /not-found", func(t *testing.T) {
resp, err := http.Get(testServer.URL + "/not-found")
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
t.Logf("body: %s", body)

if resp.StatusCode != http.StatusNotFound {
t.Errorf("Expected status code %d, got %d", http.StatusNotFound, resp.StatusCode)
}
if header := resp.Header.Get("X-Test-Middleware"); header != "true" {
t.Errorf("Expected header X-Test-Middleware to be 'true', got '%s'", header)
}
if string(body) != "Custom 404: Page not found!\n" {
t.Errorf("Expected body 'Custom 404: Page not found!', got '%s'", string(body))
}
})
}

func TestHTTPServerWithCustomNotFoundNon404Status(t *testing.T) {
group := routegroup.New(http.NewServeMux())
group.Use(testMiddleware)
group.NotFoundHandler(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusServiceUnavailable)
_, _ = w.Write([]byte("Custom 404: Page not found!\n"))
})

apiGroup := group.Mount("/api")
apiGroup.HandleFunc("GET /test", func(w http.ResponseWriter, _ *http.Request) {
_, _ = w.Write([]byte("test handler"))
})

testServer := httptest.NewServer(group)
defer testServer.Close()

t.Run("GET /api/test", func(t *testing.T) {
resp, err := http.Get(testServer.URL + "/api/test")
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status code %d, got %d", http.StatusOK, resp.StatusCode)
}
if string(body) != "test handler" {
t.Errorf("Expected body 'test handler', got '%s'", string(body))
}
if header := resp.Header.Get("X-Test-Middleware"); header != "true" {
t.Errorf("Expected header X-Test-Middleware to be 'true', got '%s'", header)
}
})

t.Run("GET /api/not-found", func(t *testing.T) {
resp, err := http.Get(testServer.URL + "/api/not-found")
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
t.Logf("body: %s", body)

if resp.StatusCode != http.StatusServiceUnavailable {
t.Errorf("Expected status code %d, got %d", http.StatusServiceUnavailable, resp.StatusCode)
}
if header := resp.Header.Get("X-Test-Middleware"); header != "true" {
t.Errorf("Expected header X-Test-Middleware to be 'true', got '%s'", header)
}
if string(body) != "Custom 404: Page not found!\n" {
t.Errorf("Expected body 'Custom 404: Page not found!', got '%s'", string(body))
}
})

}

func ExampleNew() {
group := routegroup.New(http.NewServeMux())

Expand Down

0 comments on commit 7b380fb

Please sign in to comment.