Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support for custom JWT types #441

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions flask_jwt_extended/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .jwt_manager import JWTManager
from .utils import create_access_token
from .utils import create_custom_token
from .utils import create_refresh_token
from .utils import current_user
from .utils import decode_token
Expand Down
8 changes: 4 additions & 4 deletions flask_jwt_extended/internal_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@ def user_lookup(*args, **kwargs):
return jwt_manager._user_lookup_callback(*args, **kwargs)


def verify_token_type(decoded_token, refresh):
if not refresh and decoded_token["type"] == "refresh":
raise WrongTokenError("Only non-refresh tokens are allowed")
elif refresh and decoded_token["type"] != "refresh":
def verify_token_type(decoded_token, refresh, token_type):
if refresh and decoded_token["type"] != "refresh":
raise WrongTokenError("Only refresh tokens are allowed")
elif not refresh and decoded_token["type"] != token_type:
raise WrongTokenError(f"Token of type { decoded_token['type'] } is not allowed")
Comment on lines +31 to +32
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this might be a breaking change. I believe some people are using JWTs from other sources that set the type to other fields, and we implictly treat those as access tokens right now.

I haven't fully thought this out, but I wonder if we instead set token_type to null in the view decorators and verify_jwt_in_request, and then we preserve the old logic here unless a token_type is actually specified?



def verify_token_not_blocklisted(jwt_header, jwt_data):
Expand Down
6 changes: 3 additions & 3 deletions flask_jwt_extended/jwt_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,10 +487,10 @@ def _encode_jwt_from_config(
claim_overrides.update(claims)

if expires_delta is None:
if token_type == "access":
expires_delta = config.access_expires
else:
if token_type == "refresh":
expires_delta = config.refresh_expires
else:
expires_delta = config.access_expires

return _encode_jwt(
algorithm=config.algorithm,
Expand Down
52 changes: 52 additions & 0 deletions flask_jwt_extended/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,58 @@ def create_refresh_token(
)


def create_custom_token(
identity,
token_type,
fresh=False,
expires_delta=None,
additional_claims=None,
additional_headers=None,
):
"""
Create a new custom token, with a manually specified type.

:param identity:
The identity of this token. It can be any data that is json serializable.
You can use :meth:`~flask_jwt_extended.JWTManager.user_identity_loader`
to define a callback function to convert any object passed in into a json
serializable format.

:param token_type:
The type of this token. A string such as "refresh" or "access" that specifies
the type and purpose of this token.

:param expires_delta:
A ``datetime.timedelta`` for how long this token should last before it expires.
Set to False to disable expiration. If this is None, it will use the
``JWT_REFRESH_TOKEN_EXPIRES`` config value (see :ref:`Configuration Options`)

:param additional_claims:
Optional. A hash of claims to include in the refresh token. These claims are
merged into the default claims (exp, iat, etc) and claims returned from the
:meth:`~flask_jwt_extended.JWTManager.additional_claims_loader` callback.
On conflict, these claims take presidence.

:param headers:
Optional. A hash of headers to include in the refresh token. These headers
are merged into the default headers (alg, typ) and headers returned from the
:meth:`~flask_jwt_extended.JWTManager.additional_headers_loader` callback.
On conflict, these headers take presidence.

:return:
An encoded refresh token
"""
jwt_manager = get_jwt_manager()
return jwt_manager._encode_jwt_from_config(
claims=additional_claims,
expires_delta=expires_delta,
fresh=fresh,
headers=additional_headers,
identity=identity,
token_type=token_type,
)


def get_unverified_jwt_headers(encoded_token):
"""
Returns the Headers of an encoded JWT without verifying the signature of the JWT.
Expand Down
33 changes: 24 additions & 9 deletions flask_jwt_extended/view_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ def _verify_token_is_fresh(jwt_header, jwt_data):
raise FreshTokenRequired("Fresh token required", jwt_header, jwt_data)


def verify_jwt_in_request(optional=False, fresh=False, refresh=False, locations=None):
def verify_jwt_in_request(
optional=False, fresh=False, refresh=False, locations=None, token_type="access"
):
Comment on lines +38 to +40
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure I like having refresh and token_type in here, that seems confusing to understand from a top level API point of view. I'm not entirely sue how that could be handled without a breaking change though. Hmm... 🤔 🤔 🤔

"""
Verify that a valid JWT is present in the request, unless ``optional=True`` in
which case no JWT is also considered valid.
Expand All @@ -49,25 +51,31 @@ def verify_jwt_in_request(optional=False, fresh=False, refresh=False, locations=
Defaults to ``False``.

:param refresh:
If ``True``, require a refresh JWT to be verified.
If ``True``, require a refresh JWT to be verified. If ``False``, compare
the JWT type to ``token_type``. Defaults to ``False``.

:param locations:
A location or list of locations to look for the JWT in this request, for
example ``'headers'`` or ``['headers', 'cookies']``. Defaluts to ``None``
which indicates that JWTs will be looked for in the locations defined by the
``JWT_TOKEN_LOCATION`` configuration option.

:param token_type:
If ``refresh`` is ``False``, then the ``type`` claim on the JWT must exactly match
this string in order to be verified. Defaults to ``"access"``.
"""
if request.method in config.exempt_methods:
return

try:
if refresh:
token_type = "refresh"
jwt_data, jwt_header, jwt_location = _decode_jwt_from_request(
locations, fresh, refresh=True
locations, fresh, refresh=True, token_type=token_type
)
else:
jwt_data, jwt_header, jwt_location = _decode_jwt_from_request(
locations, fresh
locations, fresh, refresh=False, token_type=token_type
)
except NoAuthorizationError:
if not optional:
Expand All @@ -88,7 +96,9 @@ def verify_jwt_in_request(optional=False, fresh=False, refresh=False, locations=
return jwt_header, jwt_data


def jwt_required(optional=False, fresh=False, refresh=False, locations=None):
def jwt_required(
optional=False, fresh=False, refresh=False, locations=None, token_type="access"
):
"""
A decorator to protect a Flask endpoint with JSON Web Tokens.

Expand All @@ -106,19 +116,24 @@ def jwt_required(optional=False, fresh=False, refresh=False, locations=None):

:param refresh:
If ``True``, requires a refresh JWT to access this endpoint. If ``False``,
requires an access JWT to access this endpoint. Defaults to ``False``.
requires a JWT specified by ``token_type`` to access
this endpoint. Defaults to ``False``.

:param locations:
A location or list of locations to look for the JWT in this request, for
example ``'headers'`` or ``['headers', 'cookies']``. Defaluts to ``None``
which indicates that JWTs will be looked for in the locations defined by the
``JWT_TOKEN_LOCATION`` configuration option.

:param token_type:
If ``refresh`` is ``False``, then the ``type`` claim on the JWT must exactly match
this string to access this endpoint. Defaults to ``"access"``.
"""

def wrapper(fn):
@wraps(fn)
def decorator(*args, **kwargs):
verify_jwt_in_request(optional, fresh, refresh, locations)
verify_jwt_in_request(optional, fresh, refresh, locations, token_type)

# Compatibility with flask < 2.0
try:
Expand Down Expand Up @@ -253,7 +268,7 @@ def _decode_jwt_from_json(refresh):
return encoded_token, None


def _decode_jwt_from_request(locations, fresh, refresh=False):
def _decode_jwt_from_request(locations, fresh, refresh=False, token_type="access"):
# Figure out what locations to look for the JWT in this request
if isinstance(locations, str):
locations = [locations]
Expand Down Expand Up @@ -312,7 +327,7 @@ def _decode_jwt_from_request(locations, fresh, refresh=False):
raise NoAuthorizationError(errors[0])

# Additional verifications provided by this extension
verify_token_type(decoded_token, refresh)
verify_token_type(decoded_token, refresh, token_type)
if fresh:
_verify_token_is_fresh(jwt_header, decoded_token)
verify_token_not_blocklisted(jwt_header, decoded_token)
Expand Down
48 changes: 45 additions & 3 deletions tests/test_view_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from flask import jsonify

from flask_jwt_extended import create_access_token
from flask_jwt_extended import create_custom_token
from flask_jwt_extended import create_refresh_token
from flask_jwt_extended import decode_token
from flask_jwt_extended import get_jwt_identity
Expand Down Expand Up @@ -46,6 +47,11 @@ def optional_protected():
else:
return jsonify(foo="bar")

@app.route("/custom_protected", methods=["GET"])
@jwt_required(token_type="custom")
def custom_protected():
return jsonify(foo="bar")

return app


Expand All @@ -72,7 +78,7 @@ def test_jwt_required(app):
# Test refresh token access to jwt_required
response = test_client.get(url, headers=make_headers(refresh_token))
assert response.status_code == 422
assert response.get_json() == {"msg": "Only non-refresh tokens are allowed"}
assert response.get_json() == {"msg": "Token of type refresh is not allowed"}


def test_fresh_jwt_required(app):
Expand Down Expand Up @@ -113,7 +119,7 @@ def test_fresh_jwt_required(app):

response = test_client.get(url, headers=make_headers(refresh_token))
assert response.status_code == 422
assert response.get_json() == {"msg": "Only non-refresh tokens are allowed"}
assert response.get_json() == {"msg": "Token of type refresh is not allowed"}

# Test with custom response
@jwtM.needs_fresh_token_loader
Expand Down Expand Up @@ -176,7 +182,7 @@ def test_jwt_optional(app, delta_func):

response = test_client.get(url, headers=make_headers(refresh_token))
assert response.status_code == 422
assert response.get_json() == {"msg": "Only non-refresh tokens are allowed"}
assert response.get_json() == {"msg": "Token of type refresh is not allowed"}

response = test_client.get(url, headers=make_headers(expired_token))
assert response.status_code == 401
Expand Down Expand Up @@ -229,6 +235,42 @@ def test_jwt_optional_with_no_valid_jwt(app):
assert response.get_json() == {"msg": "Not enough segments"}


def test_custom_jwt_required(app):
url = "/custom_protected"

test_client = app.test_client()
with app.test_request_context():
custom_token = create_custom_token("username", token_type="custom")
fresh_custom_token = create_custom_token(
"username", token_type="custom", fresh=True
)
refresh_token = create_refresh_token("username")
incorrect_custom_token = create_custom_token(
"username", token_type="other_custom"
)

# Access and fresh access should be able to access this
for token in (custom_token, fresh_custom_token):
response = test_client.get(url, headers=make_headers(token))
assert response.status_code == 200
assert response.get_json() == {"foo": "bar"}

# Test accessing jwt_required with no jwt in the request
response = test_client.get(url, headers=None)
assert response.status_code == 401
assert response.get_json() == {"msg": "Missing Authorization Header"}

# Test refresh token access to jwt_required
response = test_client.get(url, headers=make_headers(refresh_token))
assert response.status_code == 422
assert response.get_json() == {"msg": "Token of type refresh is not allowed"}
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not a big deal, but I wonder if the error message should say what token type is allowed instead of this specific token is now allowed? Might make for more discoverable errors?


# Test refresh token access to jwt_required
response = test_client.get(url, headers=make_headers(incorrect_custom_token))
assert response.status_code == 422
assert response.get_json() == {"msg": "Token of type other_custom is not allowed"}


def test_override_jwt_location(app):
app.config["JWT_TOKEN_LOCATION"] = ["cookies"]

Expand Down