From fa6cc9ae2112e3a3483c268e541098b3c3e122be Mon Sep 17 00:00:00 2001 From: Mike Scornavacca Date: Wed, 28 Jul 2021 12:20:30 -0400 Subject: [PATCH 1/3] Swapped default for non-refresh token expiry time --- flask_jwt_extended/jwt_manager.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/flask_jwt_extended/jwt_manager.py b/flask_jwt_extended/jwt_manager.py index b39d2c93..50b6e440 100644 --- a/flask_jwt_extended/jwt_manager.py +++ b/flask_jwt_extended/jwt_manager.py @@ -487,10 +487,10 @@ def _encode_jwt_from_config( claim_overrides.update(claims) if expires_delta is None: - if token_type == "access": - expires_delta = config.access_expires - else: + if token_type == "refresh": expires_delta = config.refresh_expires + else: + expires_delta = config.access_expires return _encode_jwt( algorithm=config.algorithm, From 41da071bb174aac0a21a8a445e2d3968f8bc6d7f Mon Sep 17 00:00:00 2001 From: Mike Scornavacca Date: Wed, 28 Jul 2021 13:03:06 -0400 Subject: [PATCH 2/3] Added support for custom token types --- flask_jwt_extended/internal_utils.py | 8 ++--- flask_jwt_extended/utils.py | 51 +++++++++++++++++++++++++++ flask_jwt_extended/view_decorators.py | 30 +++++++++++----- 3 files changed, 77 insertions(+), 12 deletions(-) diff --git a/flask_jwt_extended/internal_utils.py b/flask_jwt_extended/internal_utils.py index c7ee8c21..7ba129b6 100644 --- a/flask_jwt_extended/internal_utils.py +++ b/flask_jwt_extended/internal_utils.py @@ -25,11 +25,11 @@ def user_lookup(*args, **kwargs): return jwt_manager._user_lookup_callback(*args, **kwargs) -def verify_token_type(decoded_token, refresh): - if not refresh and decoded_token["type"] == "refresh": - raise WrongTokenError("Only non-refresh tokens are allowed") - elif refresh and decoded_token["type"] != "refresh": +def verify_token_type(decoded_token, refresh, token_type): + if refresh and decoded_token["type"] != "refresh": raise WrongTokenError("Only refresh tokens are allowed") + elif not refresh and decoded_token["type"] != token_type: + raise WrongTokenError(f"Token is not of type {token_type}") def verify_token_not_blocklisted(jwt_header, jwt_data): diff --git a/flask_jwt_extended/utils.py b/flask_jwt_extended/utils.py index 48ae99ba..c10477fa 100644 --- a/flask_jwt_extended/utils.py +++ b/flask_jwt_extended/utils.py @@ -219,6 +219,57 @@ def create_refresh_token( ) +def create_custom_token( + identity, + token_type, + expires_delta=None, + additional_claims=None, + additional_headers=None, +): + """ + Create a new custom token, with a manually specified type. + + :param identity: + The identity of this token. It can be any data that is json serializable. + You can use :meth:`~flask_jwt_extended.JWTManager.user_identity_loader` + to define a callback function to convert any object passed in into a json + serializable format. + + :param token_type: + The type of this token. A string such as "refresh" or "access" that specifies + the type and purpose of this token. + + :param expires_delta: + A ``datetime.timedelta`` for how long this token should last before it expires. + Set to False to disable expiration. If this is None, it will use the + ``JWT_REFRESH_TOKEN_EXPIRES`` config value (see :ref:`Configuration Options`) + + :param additional_claims: + Optional. A hash of claims to include in the refresh token. These claims are + merged into the default claims (exp, iat, etc) and claims returned from the + :meth:`~flask_jwt_extended.JWTManager.additional_claims_loader` callback. + On conflict, these claims take presidence. + + :param headers: + Optional. A hash of headers to include in the refresh token. These headers + are merged into the default headers (alg, typ) and headers returned from the + :meth:`~flask_jwt_extended.JWTManager.additional_headers_loader` callback. + On conflict, these headers take presidence. + + :return: + An encoded refresh token + """ + jwt_manager = get_jwt_manager() + return jwt_manager._encode_jwt_from_config( + claims=additional_claims, + expires_delta=expires_delta, + fresh=False, + headers=additional_headers, + identity=identity, + token_type=token_type, + ) + + def get_unverified_jwt_headers(encoded_token): """ Returns the Headers of an encoded JWT without verifying the signature of the JWT. diff --git a/flask_jwt_extended/view_decorators.py b/flask_jwt_extended/view_decorators.py index f3ba655e..242f63c8 100644 --- a/flask_jwt_extended/view_decorators.py +++ b/flask_jwt_extended/view_decorators.py @@ -35,7 +35,9 @@ def _verify_token_is_fresh(jwt_header, jwt_data): raise FreshTokenRequired("Fresh token required", jwt_header, jwt_data) -def verify_jwt_in_request(optional=False, fresh=False, refresh=False, locations=None): +def verify_jwt_in_request( + optional=False, fresh=False, refresh=False, locations=None, token_type="access" +): """ Verify that a valid JWT is present in the request, unless ``optional=True`` in which case no JWT is also considered valid. @@ -49,13 +51,18 @@ def verify_jwt_in_request(optional=False, fresh=False, refresh=False, locations= Defaults to ``False``. :param refresh: - If ``True``, require a refresh JWT to be verified. + If ``True``, require a refresh JWT to be verified. If ``False``, compare + the JWT type to ``token_type``. Defaults to ``False``. :param locations: A location or list of locations to look for the JWT in this request, for example ``'headers'`` or ``['headers', 'cookies']``. Defaluts to ``None`` which indicates that JWTs will be looked for in the locations defined by the ``JWT_TOKEN_LOCATION`` configuration option. + + :param token_type: + If ``refresh`` is ``False``, then the ``type`` claim on the JWT must exactly match + this string in order to be verified. Defaults to ``"access"``. """ if request.method in config.exempt_methods: return @@ -67,7 +74,7 @@ def verify_jwt_in_request(optional=False, fresh=False, refresh=False, locations= ) else: jwt_data, jwt_header, jwt_location = _decode_jwt_from_request( - locations, fresh + locations, fresh, refresh=False, token_type=token_type ) except NoAuthorizationError: if not optional: @@ -88,7 +95,9 @@ def verify_jwt_in_request(optional=False, fresh=False, refresh=False, locations= return jwt_header, jwt_data -def jwt_required(optional=False, fresh=False, refresh=False, locations=None): +def jwt_required( + optional=False, fresh=False, refresh=False, locations=None, token_type="access" +): """ A decorator to protect a Flask endpoint with JSON Web Tokens. @@ -106,19 +115,24 @@ def jwt_required(optional=False, fresh=False, refresh=False, locations=None): :param refresh: If ``True``, requires a refresh JWT to access this endpoint. If ``False``, - requires an access JWT to access this endpoint. Defaults to ``False``. + requires a JWT specified by ``token_type`` to access + this endpoint. Defaults to ``False``. :param locations: A location or list of locations to look for the JWT in this request, for example ``'headers'`` or ``['headers', 'cookies']``. Defaluts to ``None`` which indicates that JWTs will be looked for in the locations defined by the ``JWT_TOKEN_LOCATION`` configuration option. + + :param token_type: + If ``refresh`` is ``False``, then the ``type`` claim on the JWT must exactly match + this string to access this endpoint. Defaults to ``"access"``. """ def wrapper(fn): @wraps(fn) def decorator(*args, **kwargs): - verify_jwt_in_request(optional, fresh, refresh, locations) + verify_jwt_in_request(optional, fresh, refresh, locations, token_type) # Compatibility with flask < 2.0 try: @@ -253,7 +267,7 @@ def _decode_jwt_from_json(refresh): return encoded_token, None -def _decode_jwt_from_request(locations, fresh, refresh=False): +def _decode_jwt_from_request(locations, fresh, refresh=False, token_type="access"): # Figure out what locations to look for the JWT in this request if isinstance(locations, str): locations = [locations] @@ -312,7 +326,7 @@ def _decode_jwt_from_request(locations, fresh, refresh=False): raise NoAuthorizationError(errors[0]) # Additional verifications provided by this extension - verify_token_type(decoded_token, refresh) + verify_token_type(decoded_token, refresh, token_type) if fresh: _verify_token_is_fresh(jwt_header, decoded_token) verify_token_not_blocklisted(jwt_header, decoded_token) From 8d9fa3278a2968f11f69ce816b49715b6a6a6432 Mon Sep 17 00:00:00 2001 From: Mike Scornavacca Date: Wed, 28 Jul 2021 14:06:39 -0400 Subject: [PATCH 3/3] Added additional tests for custom JWT types --- flask_jwt_extended/__init__.py | 1 + flask_jwt_extended/internal_utils.py | 2 +- flask_jwt_extended/utils.py | 3 +- flask_jwt_extended/view_decorators.py | 3 +- tests/test_view_decorators.py | 48 +++++++++++++++++++++++++-- 5 files changed, 51 insertions(+), 6 deletions(-) diff --git a/flask_jwt_extended/__init__.py b/flask_jwt_extended/__init__.py index 5d63b63c..fbbc042b 100644 --- a/flask_jwt_extended/__init__.py +++ b/flask_jwt_extended/__init__.py @@ -1,5 +1,6 @@ from .jwt_manager import JWTManager from .utils import create_access_token +from .utils import create_custom_token from .utils import create_refresh_token from .utils import current_user from .utils import decode_token diff --git a/flask_jwt_extended/internal_utils.py b/flask_jwt_extended/internal_utils.py index 7ba129b6..d8f90b4d 100644 --- a/flask_jwt_extended/internal_utils.py +++ b/flask_jwt_extended/internal_utils.py @@ -29,7 +29,7 @@ def verify_token_type(decoded_token, refresh, token_type): if refresh and decoded_token["type"] != "refresh": raise WrongTokenError("Only refresh tokens are allowed") elif not refresh and decoded_token["type"] != token_type: - raise WrongTokenError(f"Token is not of type {token_type}") + raise WrongTokenError(f"Token of type { decoded_token['type'] } is not allowed") def verify_token_not_blocklisted(jwt_header, jwt_data): diff --git a/flask_jwt_extended/utils.py b/flask_jwt_extended/utils.py index c10477fa..36fafed1 100644 --- a/flask_jwt_extended/utils.py +++ b/flask_jwt_extended/utils.py @@ -222,6 +222,7 @@ def create_refresh_token( def create_custom_token( identity, token_type, + fresh=False, expires_delta=None, additional_claims=None, additional_headers=None, @@ -263,7 +264,7 @@ def create_custom_token( return jwt_manager._encode_jwt_from_config( claims=additional_claims, expires_delta=expires_delta, - fresh=False, + fresh=fresh, headers=additional_headers, identity=identity, token_type=token_type, diff --git a/flask_jwt_extended/view_decorators.py b/flask_jwt_extended/view_decorators.py index 242f63c8..4fffd39d 100644 --- a/flask_jwt_extended/view_decorators.py +++ b/flask_jwt_extended/view_decorators.py @@ -69,8 +69,9 @@ def verify_jwt_in_request( try: if refresh: + token_type = "refresh" jwt_data, jwt_header, jwt_location = _decode_jwt_from_request( - locations, fresh, refresh=True + locations, fresh, refresh=True, token_type=token_type ) else: jwt_data, jwt_header, jwt_location = _decode_jwt_from_request( diff --git a/tests/test_view_decorators.py b/tests/test_view_decorators.py index 6a34c0bb..b1ba7fc6 100644 --- a/tests/test_view_decorators.py +++ b/tests/test_view_decorators.py @@ -6,6 +6,7 @@ from flask import jsonify from flask_jwt_extended import create_access_token +from flask_jwt_extended import create_custom_token from flask_jwt_extended import create_refresh_token from flask_jwt_extended import decode_token from flask_jwt_extended import get_jwt_identity @@ -46,6 +47,11 @@ def optional_protected(): else: return jsonify(foo="bar") + @app.route("/custom_protected", methods=["GET"]) + @jwt_required(token_type="custom") + def custom_protected(): + return jsonify(foo="bar") + return app @@ -72,7 +78,7 @@ def test_jwt_required(app): # Test refresh token access to jwt_required response = test_client.get(url, headers=make_headers(refresh_token)) assert response.status_code == 422 - assert response.get_json() == {"msg": "Only non-refresh tokens are allowed"} + assert response.get_json() == {"msg": "Token of type refresh is not allowed"} def test_fresh_jwt_required(app): @@ -113,7 +119,7 @@ def test_fresh_jwt_required(app): response = test_client.get(url, headers=make_headers(refresh_token)) assert response.status_code == 422 - assert response.get_json() == {"msg": "Only non-refresh tokens are allowed"} + assert response.get_json() == {"msg": "Token of type refresh is not allowed"} # Test with custom response @jwtM.needs_fresh_token_loader @@ -176,7 +182,7 @@ def test_jwt_optional(app, delta_func): response = test_client.get(url, headers=make_headers(refresh_token)) assert response.status_code == 422 - assert response.get_json() == {"msg": "Only non-refresh tokens are allowed"} + assert response.get_json() == {"msg": "Token of type refresh is not allowed"} response = test_client.get(url, headers=make_headers(expired_token)) assert response.status_code == 401 @@ -229,6 +235,42 @@ def test_jwt_optional_with_no_valid_jwt(app): assert response.get_json() == {"msg": "Not enough segments"} +def test_custom_jwt_required(app): + url = "/custom_protected" + + test_client = app.test_client() + with app.test_request_context(): + custom_token = create_custom_token("username", token_type="custom") + fresh_custom_token = create_custom_token( + "username", token_type="custom", fresh=True + ) + refresh_token = create_refresh_token("username") + incorrect_custom_token = create_custom_token( + "username", token_type="other_custom" + ) + + # Access and fresh access should be able to access this + for token in (custom_token, fresh_custom_token): + response = test_client.get(url, headers=make_headers(token)) + assert response.status_code == 200 + assert response.get_json() == {"foo": "bar"} + + # Test accessing jwt_required with no jwt in the request + response = test_client.get(url, headers=None) + assert response.status_code == 401 + assert response.get_json() == {"msg": "Missing Authorization Header"} + + # Test refresh token access to jwt_required + response = test_client.get(url, headers=make_headers(refresh_token)) + assert response.status_code == 422 + assert response.get_json() == {"msg": "Token of type refresh is not allowed"} + + # Test refresh token access to jwt_required + response = test_client.get(url, headers=make_headers(incorrect_custom_token)) + assert response.status_code == 422 + assert response.get_json() == {"msg": "Token of type other_custom is not allowed"} + + def test_override_jwt_location(app): app.config["JWT_TOKEN_LOCATION"] = ["cookies"]