Skip to content

Commit

Permalink
Use fuego.HTTPError variants in the Security package (#396)
Browse files Browse the repository at this point in the history
  • Loading branch information
EwenQuim authored Feb 12, 2025
1 parent 784b931 commit 84d8aef
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 17 deletions.
42 changes: 30 additions & 12 deletions security.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,26 @@ import (
)

var (
ErrUnauthorized = errors.New("unauthorized")
ErrTokenNotFound = errors.New("token not found")
// ErrUnauthorized is used for authorization errors
//
// Deprecated: Use [UnauthorizedError] instead.
ErrUnauthorized = errors.New("unauthorized")
// ErrTokenNotFound is used when the token is not found
//
// Deprecated: Use [UnauthorizedError] instead.
ErrTokenNotFound = errors.New("token not found")
// ErrInvalidTokenType is used when the token type is invalid
//
// Deprecated: Use [UnauthorizedError] instead.
ErrInvalidTokenType = errors.New("invalid token type")
// ErrInvalidRolesType is used when the roles type is invalid
//
// Deprecated: Use [UnauthorizedError] instead.
ErrInvalidRolesType = errors.New("invalid role type. Must be []string")
ErrExpired = errors.New("token is expired")
// ErrExpired is used when the token is expired
//
// Deprecated: Use [ForbiddenError] instead.
ErrExpired = errors.New("token is expired")
)

// Security holds the key to sign the JWT tokens, and configuration information.
Expand Down Expand Up @@ -92,8 +107,11 @@ func (security Security) ValidateToken(token string) (*jwt.Token, error) {
}

iat, err := t.Claims.GetIssuedAt()
if err != nil || iat == nil || float64(iat.Unix())+security.ExpiresInterval.Seconds() < float64(security.Now().Unix()) {
return nil, ErrExpired
if err != nil {
return nil, UnauthorizedError{Title: "No Issued date found", Err: err}
}
if iat == nil || float64(iat.Unix())+security.ExpiresInterval.Seconds() < float64(security.Now().Unix()) {
return nil, UnauthorizedError{Title: "Token expired"}
}

return t, nil
Expand Down Expand Up @@ -123,11 +141,11 @@ func WithValue(ctx context.Context, val any) context.Context {
func TokenFromContext(ctx context.Context) (jwt.Claims, error) {
value := ctx.Value(contextKeyJWT)
if value == nil {
return nil, ErrTokenNotFound
return nil, UnauthorizedError{Title: "Could not find token in context"}
}
claims, ok := value.(jwt.MapClaims)
if !ok {
return nil, ErrInvalidTokenType
return nil, UnauthorizedError{Title: "Invalid token type"}
}

return claims, nil
Expand All @@ -147,7 +165,7 @@ func GetToken[T any](ctx context.Context) (T, error) {

myClaims, ok := claims.(T)
if !ok {
return myClaims, ErrInvalidTokenType
return myClaims, UnauthorizedError{Title: "Invalid token type"}
}

return myClaims, nil
Expand Down Expand Up @@ -306,20 +324,20 @@ func authWall(authorizeFunc func(userRoles ...string) bool) func(next http.Handl
// Get the authorizationHeader from the context (set by TokenToContext)
claims, err := TokenFromContext(r.Context())
if err != nil {
SendJSONError(w, nil, ErrUnauthorized)
SendJSONError(w, nil, UnauthorizedError{Title: "Unauthorized"})
return
}

// Get the subject and userRoles from the claims
userRoles, ok := claims.(jwt.MapClaims)["roles"].([]string)
if !ok {
SendJSONError(w, nil, ErrInvalidTokenType)
SendJSONError(w, nil, UnauthorizedError{Title: "Could not find roles in token"})
return
}

// Check if the user is authorized
if !authorizeFunc(userRoles...) {
SendJSONError(w, nil, ErrUnauthorized)
SendJSONError(w, nil, ForbiddenError{Title: "Access denied"})
return
}

Expand Down Expand Up @@ -464,7 +482,7 @@ func (security Security) LoginHandler(verifyUserInfo func(user, password string)
func (security Security) RefreshHandler(w http.ResponseWriter, r *http.Request) {
claims, err := TokenFromContext(r.Context())
if err != nil {
SendJSONError(w, nil, ErrUnauthorized)
SendJSONError(w, nil, UnauthorizedError{Title: "Could not find token in context"})
return
}

Expand Down
12 changes: 7 additions & 5 deletions security_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package fuego

import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"regexp"
Expand Down Expand Up @@ -54,7 +55,8 @@ func TestSecurity(t *testing.T) {
security.Now = func() time.Time { return now.Add(15 * time.Minute) }
decoded, err := security.ValidateToken(s)
require.Error(t, err)
require.ErrorIs(t, err, ErrExpired)
fmt.Printf("error: %v\n", err)
require.ErrorAs(t, err, &UnauthorizedError{})
require.Empty(t, decoded)
})
}
Expand Down Expand Up @@ -184,7 +186,7 @@ func TestAuthWall(t *testing.T) {
w := httptest.NewRecorder()

authWall(h).ServeHTTP(w, r)
require.Equal(t, http.StatusInternalServerError, w.Code)
require.Equal(t, http.StatusForbidden, w.Code)
})

t.Run("with token", func(t *testing.T) {
Expand Down Expand Up @@ -214,7 +216,7 @@ func TestAuthWall(t *testing.T) {
w := httptest.NewRecorder()

authWall(h).ServeHTTP(w, r)
require.Equal(t, http.StatusInternalServerError, w.Code)
require.Equal(t, http.StatusForbidden, w.Code)
})

t.Run("with token", func(t *testing.T) {
Expand Down Expand Up @@ -373,7 +375,7 @@ func TestSecurity_StdLoginHandler(t *testing.T) {
security := NewSecurity()
v := func(r *http.Request) (jwt.Claims, error) {
if r.FormValue("user") != "test" || r.FormValue("password") != "test" {
return nil, ErrUnauthorized
return nil, UnauthorizedError{}
}
return jwt.MapClaims{"sub": "123"}, nil
}
Expand Down Expand Up @@ -405,7 +407,7 @@ func TestSecurity_LoginHandler(t *testing.T) {
security := NewSecurity()
v := func(user, password string) (jwt.Claims, error) {
if user != "test" || password != "test" {
return nil, ErrUnauthorized
return nil, UnauthorizedError{}
}
return jwt.MapClaims{"sub": "123"}, nil
}
Expand Down

0 comments on commit 84d8aef

Please sign in to comment.