-
-
Notifications
You must be signed in to change notification settings - Fork 240
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added support for custom JWT types #441
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -35,7 +35,9 @@ def _verify_token_is_fresh(jwt_header, jwt_data): | |
raise FreshTokenRequired("Fresh token required", jwt_header, jwt_data) | ||
|
||
|
||
def verify_jwt_in_request(optional=False, fresh=False, refresh=False, locations=None): | ||
def verify_jwt_in_request( | ||
optional=False, fresh=False, refresh=False, locations=None, token_type="access" | ||
): | ||
Comment on lines
+38
to
+40
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure I like having |
||
""" | ||
Verify that a valid JWT is present in the request, unless ``optional=True`` in | ||
which case no JWT is also considered valid. | ||
|
@@ -49,25 +51,31 @@ def verify_jwt_in_request(optional=False, fresh=False, refresh=False, locations= | |
Defaults to ``False``. | ||
|
||
:param refresh: | ||
If ``True``, require a refresh JWT to be verified. | ||
If ``True``, require a refresh JWT to be verified. If ``False``, compare | ||
the JWT type to ``token_type``. Defaults to ``False``. | ||
|
||
:param locations: | ||
A location or list of locations to look for the JWT in this request, for | ||
example ``'headers'`` or ``['headers', 'cookies']``. Defaluts to ``None`` | ||
which indicates that JWTs will be looked for in the locations defined by the | ||
``JWT_TOKEN_LOCATION`` configuration option. | ||
|
||
:param token_type: | ||
If ``refresh`` is ``False``, then the ``type`` claim on the JWT must exactly match | ||
this string in order to be verified. Defaults to ``"access"``. | ||
""" | ||
if request.method in config.exempt_methods: | ||
return | ||
|
||
try: | ||
if refresh: | ||
token_type = "refresh" | ||
jwt_data, jwt_header, jwt_location = _decode_jwt_from_request( | ||
locations, fresh, refresh=True | ||
locations, fresh, refresh=True, token_type=token_type | ||
) | ||
else: | ||
jwt_data, jwt_header, jwt_location = _decode_jwt_from_request( | ||
locations, fresh | ||
locations, fresh, refresh=False, token_type=token_type | ||
) | ||
except NoAuthorizationError: | ||
if not optional: | ||
|
@@ -88,7 +96,9 @@ def verify_jwt_in_request(optional=False, fresh=False, refresh=False, locations= | |
return jwt_header, jwt_data | ||
|
||
|
||
def jwt_required(optional=False, fresh=False, refresh=False, locations=None): | ||
def jwt_required( | ||
optional=False, fresh=False, refresh=False, locations=None, token_type="access" | ||
): | ||
""" | ||
A decorator to protect a Flask endpoint with JSON Web Tokens. | ||
|
||
|
@@ -106,19 +116,24 @@ def jwt_required(optional=False, fresh=False, refresh=False, locations=None): | |
|
||
:param refresh: | ||
If ``True``, requires a refresh JWT to access this endpoint. If ``False``, | ||
requires an access JWT to access this endpoint. Defaults to ``False``. | ||
requires a JWT specified by ``token_type`` to access | ||
this endpoint. Defaults to ``False``. | ||
|
||
:param locations: | ||
A location or list of locations to look for the JWT in this request, for | ||
example ``'headers'`` or ``['headers', 'cookies']``. Defaluts to ``None`` | ||
which indicates that JWTs will be looked for in the locations defined by the | ||
``JWT_TOKEN_LOCATION`` configuration option. | ||
|
||
:param token_type: | ||
If ``refresh`` is ``False``, then the ``type`` claim on the JWT must exactly match | ||
this string to access this endpoint. Defaults to ``"access"``. | ||
""" | ||
|
||
def wrapper(fn): | ||
@wraps(fn) | ||
def decorator(*args, **kwargs): | ||
verify_jwt_in_request(optional, fresh, refresh, locations) | ||
verify_jwt_in_request(optional, fresh, refresh, locations, token_type) | ||
|
||
# Compatibility with flask < 2.0 | ||
try: | ||
|
@@ -253,7 +268,7 @@ def _decode_jwt_from_json(refresh): | |
return encoded_token, None | ||
|
||
|
||
def _decode_jwt_from_request(locations, fresh, refresh=False): | ||
def _decode_jwt_from_request(locations, fresh, refresh=False, token_type="access"): | ||
# Figure out what locations to look for the JWT in this request | ||
if isinstance(locations, str): | ||
locations = [locations] | ||
|
@@ -312,7 +327,7 @@ def _decode_jwt_from_request(locations, fresh, refresh=False): | |
raise NoAuthorizationError(errors[0]) | ||
|
||
# Additional verifications provided by this extension | ||
verify_token_type(decoded_token, refresh) | ||
verify_token_type(decoded_token, refresh, token_type) | ||
if fresh: | ||
_verify_token_is_fresh(jwt_header, decoded_token) | ||
verify_token_not_blocklisted(jwt_header, decoded_token) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,6 +6,7 @@ | |
from flask import jsonify | ||
|
||
from flask_jwt_extended import create_access_token | ||
from flask_jwt_extended import create_custom_token | ||
from flask_jwt_extended import create_refresh_token | ||
from flask_jwt_extended import decode_token | ||
from flask_jwt_extended import get_jwt_identity | ||
|
@@ -46,6 +47,11 @@ def optional_protected(): | |
else: | ||
return jsonify(foo="bar") | ||
|
||
@app.route("/custom_protected", methods=["GET"]) | ||
@jwt_required(token_type="custom") | ||
def custom_protected(): | ||
return jsonify(foo="bar") | ||
|
||
return app | ||
|
||
|
||
|
@@ -72,7 +78,7 @@ def test_jwt_required(app): | |
# Test refresh token access to jwt_required | ||
response = test_client.get(url, headers=make_headers(refresh_token)) | ||
assert response.status_code == 422 | ||
assert response.get_json() == {"msg": "Only non-refresh tokens are allowed"} | ||
assert response.get_json() == {"msg": "Token of type refresh is not allowed"} | ||
|
||
|
||
def test_fresh_jwt_required(app): | ||
|
@@ -113,7 +119,7 @@ def test_fresh_jwt_required(app): | |
|
||
response = test_client.get(url, headers=make_headers(refresh_token)) | ||
assert response.status_code == 422 | ||
assert response.get_json() == {"msg": "Only non-refresh tokens are allowed"} | ||
assert response.get_json() == {"msg": "Token of type refresh is not allowed"} | ||
|
||
# Test with custom response | ||
@jwtM.needs_fresh_token_loader | ||
|
@@ -176,7 +182,7 @@ def test_jwt_optional(app, delta_func): | |
|
||
response = test_client.get(url, headers=make_headers(refresh_token)) | ||
assert response.status_code == 422 | ||
assert response.get_json() == {"msg": "Only non-refresh tokens are allowed"} | ||
assert response.get_json() == {"msg": "Token of type refresh is not allowed"} | ||
|
||
response = test_client.get(url, headers=make_headers(expired_token)) | ||
assert response.status_code == 401 | ||
|
@@ -229,6 +235,42 @@ def test_jwt_optional_with_no_valid_jwt(app): | |
assert response.get_json() == {"msg": "Not enough segments"} | ||
|
||
|
||
def test_custom_jwt_required(app): | ||
url = "/custom_protected" | ||
|
||
test_client = app.test_client() | ||
with app.test_request_context(): | ||
custom_token = create_custom_token("username", token_type="custom") | ||
fresh_custom_token = create_custom_token( | ||
"username", token_type="custom", fresh=True | ||
) | ||
refresh_token = create_refresh_token("username") | ||
incorrect_custom_token = create_custom_token( | ||
"username", token_type="other_custom" | ||
) | ||
|
||
# Access and fresh access should be able to access this | ||
for token in (custom_token, fresh_custom_token): | ||
response = test_client.get(url, headers=make_headers(token)) | ||
assert response.status_code == 200 | ||
assert response.get_json() == {"foo": "bar"} | ||
|
||
# Test accessing jwt_required with no jwt in the request | ||
response = test_client.get(url, headers=None) | ||
assert response.status_code == 401 | ||
assert response.get_json() == {"msg": "Missing Authorization Header"} | ||
|
||
# Test refresh token access to jwt_required | ||
response = test_client.get(url, headers=make_headers(refresh_token)) | ||
assert response.status_code == 422 | ||
assert response.get_json() == {"msg": "Token of type refresh is not allowed"} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not a big deal, but I wonder if the error message should say what token type is allowed instead of this specific token is now allowed? Might make for more discoverable errors? |
||
|
||
# Test refresh token access to jwt_required | ||
response = test_client.get(url, headers=make_headers(incorrect_custom_token)) | ||
assert response.status_code == 422 | ||
assert response.get_json() == {"msg": "Token of type other_custom is not allowed"} | ||
|
||
|
||
def test_override_jwt_location(app): | ||
app.config["JWT_TOKEN_LOCATION"] = ["cookies"] | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this might be a breaking change. I believe some people are using JWTs from other sources that set the
type
to other fields, and we implictly treat those as access tokens right now.I haven't fully thought this out, but I wonder if we instead set
token_type
tonull
in the view decorators andverify_jwt_in_request
, and then we preserve the old logic here unless atoken_type
is actually specified?