diff --git a/jwt/config.go b/jwt/config.go index 4fe968da..53b3626a 100644 --- a/jwt/config.go +++ b/jwt/config.go @@ -60,6 +60,11 @@ type Config struct { // - "cookie:" TokenLookup string + // TokenDeobfuscatorFunc defines a function to deobfuscate the founded token with [TokenLookup]. + // This help to implement a Token obfuscation algoritm to prevent information disclosure. + // Optional. Default: nil + TokenDeobfuscatorFunc func(obfuscatedToken string) (string, error) + // AuthScheme to be used in the Authorization header. // Optional. Default: "Bearer". AuthScheme string diff --git a/jwt/jwt.go b/jwt/jwt.go index 94139704..5c4342af 100644 --- a/jwt/jwt.go +++ b/jwt/jwt.go @@ -40,6 +40,14 @@ func New(config ...Config) fiber.Handler { if err != nil { return cfg.ErrorHandler(c, err) } + + if cfg.TokenDeobfuscatorFunc != nil { + auth, err = cfg.TokenDeobfuscatorFunc(auth) + if err != nil { + return cfg.ErrorHandler(c, err) + } + } + var token *jwt.Token if _, ok := cfg.Claims.(jwt.MapClaims); ok { diff --git a/jwt/jwt_test.go b/jwt/jwt_test.go index f380bea9..770665b5 100644 --- a/jwt/jwt_test.go +++ b/jwt/jwt_test.go @@ -1,6 +1,7 @@ package jwtware_test import ( + "encoding/hex" "fmt" "net/http" "net/http/httptest" @@ -104,6 +105,47 @@ const ( ` ) +func TestJwtDeobfuscation(t *testing.T) { + t.Parallel() + + defer func() { + // Assert + if err := recover(); err != nil { + t.Fatalf("Middleware should not panic") + } + }() + + for _, test := range hamac { + // Arrange + app := fiber.New() + + app.Use(jwtware.New(jwtware.Config{ + SigningKey: jwtware.SigningKey{ + JWTAlg: test.SigningMethod, + Key: []byte(defaultSigningKey), + }, + TokenDeobfuscatorFunc: func(obfuscatedToken string) (string, error) { + token, err := hex.DecodeString(obfuscatedToken) + return string(token), err + }, + })) + + app.Get("/ok", func(c *fiber.Ctx) error { + return c.SendString("OK") + }) + + req := httptest.NewRequest("GET", "/ok", nil) + req.Header.Add("Authorization", "Bearer "+hex.EncodeToString([]byte(test.Token))) + + // Act + resp, err := app.Test(req) + + // Assert + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, 200, resp.StatusCode) + } +} + func TestJwtFromHeader(t *testing.T) { t.Parallel()