From 39e9ed2dd78b68dbcd3b9a5dca427313cd2652ef Mon Sep 17 00:00:00 2001 From: Hsiaoming Yang Date: Sat, 14 Nov 2020 14:37:25 +0900 Subject: [PATCH 01/26] Refactor whole key design. --- authlib/jose/jwk.py | 3 +- authlib/jose/rfc7517/__init__.py | 6 +- authlib/jose/rfc7517/asymmetric_key.py | 192 ++++++++++++++++ authlib/jose/rfc7517/base_key.py | 110 ++++++++++ authlib/jose/rfc7517/jwk.py | 2 +- authlib/jose/rfc7517/key_set.py | 29 +++ authlib/jose/rfc7517/models.py | 156 ------------- authlib/jose/rfc7518/__init__.py | 3 - authlib/jose/rfc7518/ec_key.py | 75 +++---- authlib/jose/rfc7518/jws_algs.py | 2 +- authlib/jose/rfc7518/key_util.py | 78 ------- authlib/jose/rfc7518/oct_key.py | 56 +++-- authlib/jose/rfc7518/rsa_key.py | 90 ++++---- authlib/jose/rfc8037/okp_key.py | 76 ++----- tests/core/test_jose/test_jwk.py | 244 ++++++++++----------- tests/flask/test_client/test_user_mixin.py | 4 +- 16 files changed, 596 insertions(+), 530 deletions(-) create mode 100644 authlib/jose/rfc7517/asymmetric_key.py create mode 100644 authlib/jose/rfc7517/base_key.py create mode 100644 authlib/jose/rfc7517/key_set.py delete mode 100644 authlib/jose/rfc7517/models.py delete mode 100644 authlib/jose/rfc7518/key_util.py diff --git a/authlib/jose/jwk.py b/authlib/jose/jwk.py index 02dbbabe..2e3efb6b 100644 --- a/authlib/jose/jwk.py +++ b/authlib/jose/jwk.py @@ -15,5 +15,4 @@ def dumps(key, kty=None, **params): params['kty'] = kty key = JsonWebKey.import_key(key, params) - data = key.as_dict() - return data + return dict(key) diff --git a/authlib/jose/rfc7517/__init__.py b/authlib/jose/rfc7517/__init__.py index e2f1595e..d3fbbb2d 100644 --- a/authlib/jose/rfc7517/__init__.py +++ b/authlib/jose/rfc7517/__init__.py @@ -7,9 +7,11 @@ https://tools.ietf.org/html/rfc7517 """ -from .models import Key, KeySet from ._cryptography_key import load_pem_key +from .base_key import Key +from .asymmetric_key import AsymmetricKey +from .key_set import KeySet from .jwk import JsonWebKey -__all__ = ['Key', 'KeySet', 'JsonWebKey', 'load_pem_key'] +__all__ = ['Key', 'AsymmetricKey', 'KeySet', 'JsonWebKey', 'load_pem_key'] diff --git a/authlib/jose/rfc7517/asymmetric_key.py b/authlib/jose/rfc7517/asymmetric_key.py new file mode 100644 index 00000000..aaa36c65 --- /dev/null +++ b/authlib/jose/rfc7517/asymmetric_key.py @@ -0,0 +1,192 @@ +from authlib.common.encoding import ( + json_dumps, + to_bytes, +) +from cryptography.hazmat.primitives.serialization import ( + Encoding, PrivateFormat, PublicFormat, + BestAvailableEncryption, NoEncryption, +) +from ._cryptography_key import load_pem_key +from .base_key import Key + + +class AsymmetricKey(Key): + """This is the base class for a JSON Web Key.""" + PUBLIC_KEY_FIELDS = [] + PRIVATE_KEY_FIELDS = [] + PRIVATE_KEY_CLS = bytes + PUBLIC_KEY_CLS = bytes + SSH_PUBLIC_PREFIX = b'' + + def __init__(self, private_key=None, public_key=None, options=None): + super(AsymmetricKey, self).__init__(options) + self.private_key = private_key + self.public_key = public_key + + @property + def public_only(self): + if self.private_key: + return False + if 'd' in self.tokens: + return False + return True + + def get_op_key(self, operation): + """Get the raw key for the given key_op. This method will also + check if the given key_op is supported by this key. + + :param operation: key operation value, such as "sign", "encrypt". + :return: raw key + """ + self.check_key_op(operation) + if operation in self.PUBLIC_KEY_OPS: + return self.get_public_key() + return self.get_private_key() + + def get_public_key(self): + if self.public_key: + return self.public_key + + private_key = self.get_private_key() + if private_key: + return private_key.public_key() + + return self.public_key + + def get_private_key(self): + if self.private_key: + return self.private_key + + if self.tokens: + self.load_raw_key() + return self.private_key + + def load_raw_key(self): + if 'd' in self.tokens: + self.private_key = self.load_private_key() + else: + self.public_key = self.load_public_key() + + def load_dict_key(self): + if self.private_key: + self._dict_data.update(self.dumps_private_key()) + else: + self._dict_data.update(self.dumps_public_key()) + + def dumps_private_key(self): + raise NotImplementedError() + + def dumps_public_key(self): + raise NotImplementedError() + + def load_private_key(self): + raise NotImplementedError() + + def load_public_key(self): + raise NotImplementedError() + + def as_dict(self, is_private=False): + """Represent this key as a dict of the JSON Web Key.""" + tokens = self.tokens + if is_private and 'd' not in tokens: + raise ValueError('This is a public key') + + kid = tokens.get('kid') + if 'd' in tokens and not is_private: + # filter out private fields + tokens = {k: tokens[k] for k in tokens if k in self.PUBLIC_KEY_FIELDS} + if kid: + tokens['kid'] = kid + + if not kid: + tokens['kid'] = self.thumbprint() + return tokens + + def as_key(self, is_private=False): + """Represent this key as raw key.""" + if is_private: + return self.get_private_key() + return self.get_public_key() + + def as_json(self, is_private=False): + """Represent this key as a JSON string.""" + obj = self.as_dict(is_private) + return json_dumps(obj) + + def as_bytes(self, encoding=None, is_private=False, password=None): + """Export key into PEM/DER format bytes. + + :param encoding: "PEM" or "DER" + :param is_private: export private key or public key + :param password: encrypt private key with password + :return: bytes + """ + + if encoding is None or encoding == 'PEM': + encoding = Encoding.PEM + elif encoding == 'DER': + encoding = Encoding.DER + else: + raise ValueError('Invalid encoding: {!r}'.format(encoding)) + + raw_key = self.as_key(is_private) + if is_private: + if not raw_key: + raise ValueError('This is a public key') + if password is None: + encryption_algorithm = NoEncryption() + else: + encryption_algorithm = BestAvailableEncryption(to_bytes(password)) + return raw_key.private_bytes( + encoding=encoding, + format=PrivateFormat.PKCS8, + encryption_algorithm=encryption_algorithm, + ) + return raw_key.public_bytes( + encoding=encoding, + format=PublicFormat.SubjectPublicKeyInfo, + ) + + def as_pem(self, is_private=False, password=None): + return self.as_bytes(is_private=is_private, password=password) + + def as_der(self, is_private=False, password=None): + return self.as_bytes(encoding='DER', is_private=is_private, password=password) + + @classmethod + def import_dict_key(cls, raw, options=None): + cls.check_required_fields(raw) + key = cls(options=options) + key._dict_data = raw + return key + + @classmethod + def import_key(cls, raw, options=None): + if isinstance(raw, cls): + if options is not None: + raw.options.update(options) + return raw + + if isinstance(raw, cls.PUBLIC_KEY_CLS): + key = cls(public_key=raw, options=options) + elif isinstance(raw, cls.PRIVATE_KEY_CLS): + key = cls(private_key=raw, options=options) + elif isinstance(raw, dict): + key = cls.import_dict_key(raw, options) + else: + if options is not None: + password = options.pop('password', None) + else: + password = None + raw_key = load_pem_key(raw, cls.SSH_PUBLIC_PREFIX, password=password) + if isinstance(raw_key, cls.PUBLIC_KEY_CLS): + key = cls(public_key=raw_key, options=options) + elif isinstance(raw_key, cls.PRIVATE_KEY_CLS): + key = cls(private_key=raw_key, options=options) + else: + raise ValueError('Invalid data for importing key') + return key + + @classmethod + def generate_key(cls, crv_or_size, options=None, is_private=False): + raise NotImplementedError() diff --git a/authlib/jose/rfc7517/base_key.py b/authlib/jose/rfc7517/base_key.py new file mode 100644 index 00000000..c89c41e0 --- /dev/null +++ b/authlib/jose/rfc7517/base_key.py @@ -0,0 +1,110 @@ +import hashlib +from collections import OrderedDict +from authlib.common.encoding import ( + json_dumps, + to_bytes, + to_unicode, + urlsafe_b64encode, +) +from ..errors import InvalidUseError + + +class Key(object): + """This is the base class for a JSON Web Key.""" + kty = '_' + + ALLOWED_PARAMS = [ + 'use', 'key_ops', 'alg', 'kid', + 'x5u', 'x5c', 'x5t', 'x5t#S256' + ] + + PRIVATE_KEY_OPS = [ + 'sign', 'decrypt', 'unwrapKey', + ] + PUBLIC_KEY_OPS = [ + 'verify', 'encrypt', 'wrapKey', + ] + + REQUIRED_JSON_FIELDS = [] + + def __init__(self, options=None): + self.options = options or {} + self._dict_data = {} + + @property + def tokens(self): + if not self._dict_data: + self.load_dict_key() + + rv = dict(self._dict_data) + rv['kty'] = self.kty + for k in self.ALLOWED_PARAMS: + if k not in rv and k in self.options: + rv[k] = self.options[k] + return rv + + def keys(self): + return self.tokens.keys() + + def __getitem__(self, item): + return self.tokens[item] + + @property + def public_only(self): + raise NotImplementedError() + + def load_raw_key(self): + raise NotImplementedError() + + def load_dict_key(self): + raise NotImplementedError() + + def check_key_op(self, operation): + """Check if the given key_op is supported by this key. + + :param operation: key operation value, such as "sign", "encrypt". + :raise: ValueError + """ + key_ops = self.tokens.get('key_ops') + if key_ops is not None and operation not in key_ops: + raise ValueError('Unsupported key_op "{}"'.format(operation)) + + if operation in self.PRIVATE_KEY_OPS and self.public_only: + raise ValueError('Invalid key_op "{}" for public key'.format(operation)) + + use = self.tokens.get('use') + if use: + if operation in ['sign', 'verify']: + if use != 'sig': + raise InvalidUseError() + elif operation in ['decrypt', 'encrypt', 'wrapKey', 'unwrapKey']: + if use != 'enc': + raise InvalidUseError() + + def as_dict(self, is_private=False): + raise NotImplementedError() + + def as_json(self, is_private=False): + """Represent this key as a JSON string.""" + obj = self.as_dict(is_private) + return json_dumps(obj) + + def thumbprint(self): + """Implementation of RFC7638 JSON Web Key (JWK) Thumbprint.""" + fields = list(self.REQUIRED_JSON_FIELDS) + fields.append('kty') + fields.sort() + data = OrderedDict() + + for k in fields: + data[k] = self.tokens[k] + + json_data = json_dumps(data) + digest_data = hashlib.sha256(to_bytes(json_data)).digest() + return to_unicode(urlsafe_b64encode(digest_data)) + + @classmethod + def check_required_fields(cls, data): + for k in cls.REQUIRED_JSON_FIELDS: + if k not in data: + raise ValueError('Missing required field: "{}"'.format(k)) diff --git a/authlib/jose/rfc7517/jwk.py b/authlib/jose/rfc7517/jwk.py index 99c7a59c..576c4e83 100644 --- a/authlib/jose/rfc7517/jwk.py +++ b/authlib/jose/rfc7517/jwk.py @@ -1,6 +1,6 @@ from authlib.common.encoding import json_loads +from .key_set import KeySet from ._cryptography_key import load_pem_key -from .models import KeySet class JsonWebKey(object): diff --git a/authlib/jose/rfc7517/key_set.py b/authlib/jose/rfc7517/key_set.py new file mode 100644 index 00000000..d7cb2a88 --- /dev/null +++ b/authlib/jose/rfc7517/key_set.py @@ -0,0 +1,29 @@ +from authlib.common.encoding import json_dumps + + +class KeySet(object): + """This class represents a JSON Web Key Set.""" + + def __init__(self, keys): + self.keys = keys + + def as_dict(self, is_private=False): + """Represent this key as a dict of the JSON Web Key Set.""" + return {'keys': [k.as_dict(is_private) for k in self.keys]} + + def as_json(self, is_private=False): + """Represent this key set as a JSON string.""" + obj = self.as_dict(is_private) + return json_dumps(obj) + + def find_by_kid(self, kid): + """Find the key matches the given kid value. + + :param kid: A string of kid + :return: Key instance + :raise: ValueError + """ + for k in self.keys: + if k.tokens.get('kid') == kid: + return k + raise ValueError('Invalid JSON Web Key Set') diff --git a/authlib/jose/rfc7517/models.py b/authlib/jose/rfc7517/models.py deleted file mode 100644 index b3b24f32..00000000 --- a/authlib/jose/rfc7517/models.py +++ /dev/null @@ -1,156 +0,0 @@ -import hashlib -from collections import OrderedDict -from authlib.common.encoding import ( - json_dumps, - to_bytes, - to_unicode, - urlsafe_b64encode, -) -from ..errors import InvalidUseError - - -class Key(dict): - """This is the base class for a JSON Web Key.""" - kty = '_' - - ALLOWED_PARAMS = [ - 'use', 'key_ops', 'alg', 'kid', - 'x5u', 'x5c', 'x5t', 'x5t#S256' - ] - - PRIVATE_KEY_OPS = [ - 'sign', 'decrypt', 'unwrapKey', - ] - PUBLIC_KEY_OPS = [ - 'verify', 'encrypt', 'wrapKey', - ] - - REQUIRED_JSON_FIELDS = [] - RAW_KEY_CLS = bytes - - def __init__(self, payload): - super(Key, self).__init__(payload) - - self.key_type = 'secret' - self.raw_key = None - - def get_op_key(self, operation): - """Get the raw key for the given key_op. This method will also - check if the given key_op is supported by this key. - - :param operation: key operation value, such as "sign", "encrypt". - :return: raw key - """ - self.check_key_op(operation) - if operation in self.PUBLIC_KEY_OPS: - return self.get_public_key() - return self.get_private_key() - - def get_public_key(self): - if self.key_type == 'private': - return self.raw_key.public_key() - return self.raw_key - - def get_private_key(self): - if self.key_type == 'private': - return self.raw_key - - def check_key_op(self, operation): - """Check if the given key_op is supported by this key. - - :param operation: key operation value, such as "sign", "encrypt". - :raise: ValueError - """ - key_ops = self.get('key_ops') - if key_ops is not None and operation not in key_ops: - raise ValueError('Unsupported key_op "{}"'.format(operation)) - - if operation in self.PRIVATE_KEY_OPS and self.key_type == 'public': - raise ValueError('Invalid key_op "{}" for public key'.format(operation)) - - use = self.get('use') - if use: - if operation in ['sign', 'verify']: - if use != 'sig': - raise InvalidUseError() - elif operation in ['decrypt', 'encrypt', 'wrapKey', 'unwrapKey']: - if use != 'enc': - raise InvalidUseError() - - def as_key(self): - """Represent this key as raw key.""" - return self.raw_key - - def as_dict(self, add_kid=False): - """Represent this key as a dict of the JSON Web Key.""" - obj = dict(self) - obj['kty'] = self.kty - if add_kid and 'kid' not in obj: - obj['kid'] = self.thumbprint() - return obj - - def as_json(self): - """Represent this key as a JSON string.""" - obj = self.as_dict() - return json_dumps(obj) - - def as_pem(self): - """Represent this key as string in PEM format.""" - raise RuntimeError('Not supported') - - def thumbprint(self): - """Implementation of RFC7638 JSON Web Key (JWK) Thumbprint.""" - fields = list(self.REQUIRED_JSON_FIELDS) - fields.append('kty') - fields.sort() - data = OrderedDict() - - obj = self.as_dict() - for k in fields: - data[k] = obj[k] - - json_data = json_dumps(data) - digest_data = hashlib.sha256(to_bytes(json_data)).digest() - return to_unicode(urlsafe_b64encode(digest_data)) - - @classmethod - def check_required_fields(cls, data): - for k in cls.REQUIRED_JSON_FIELDS: - if k not in data: - raise ValueError('Missing required field: "{}"'.format(k)) - - @classmethod - def generate_key(cls, crv_or_size, options=None, is_private=False): - raise NotImplementedError() - - @classmethod - def import_key(cls, raw, options=None): - raise NotImplementedError() - - -class KeySet(object): - """This class represents a JSON Web Key Set.""" - - def __init__(self, keys): - self.keys = keys - - def as_dict(self): - """Represent this key as a dict of the JSON Web Key Set.""" - return {'keys': [k.as_dict(True) for k in self.keys]} - - def as_json(self): - """Represent this key set as a JSON string.""" - obj = self.as_dict() - return json_dumps(obj) - - def find_by_kid(self, kid): - """Find the key matches the given kid value. - - :param kid: A string of kid - :return: Key instance - :raise: ValueError - """ - for k in self.keys: - if k.get('kid') == kid: - return k - raise ValueError('Invalid JSON Web Key Set') diff --git a/authlib/jose/rfc7518/__init__.py b/authlib/jose/rfc7518/__init__.py index 35c80845..4ffd514e 100644 --- a/authlib/jose/rfc7518/__init__.py +++ b/authlib/jose/rfc7518/__init__.py @@ -1,7 +1,6 @@ from .oct_key import OctKey from .rsa_key import RSAKey from .ec_key import ECKey -from .key_util import import_key, export_key from .jws_algs import JWS_ALGORITHMS from .jwe_algs import JWE_ALG_ALGORITHMS, ECDHAlgorithm from .jwe_encs import JWE_ENC_ALGORITHMS @@ -30,6 +29,4 @@ def register_jwe_rfc7518(cls): 'RSAKey', 'ECKey', 'ECDHAlgorithm', - 'import_key', - 'export_key', ] diff --git a/authlib/jose/rfc7518/ec_key.py b/authlib/jose/rfc7518/ec_key.py index 61fb46cd..d0b11540 100644 --- a/authlib/jose/rfc7518/ec_key.py +++ b/authlib/jose/rfc7518/ec_key.py @@ -6,11 +6,10 @@ ) from cryptography.hazmat.backends import default_backend from authlib.common.encoding import base64_to_int, int_to_base64 -from .key_util import export_key, import_key -from ..rfc7517 import Key +from ..rfc7517 import AsymmetricKey -class ECKey(Key): +class ECKey(AsymmetricKey): """Key class of the ``EC`` key type.""" kty = 'EC' @@ -28,83 +27,67 @@ class ECKey(Key): SECP256K1.name: 'secp256k1', } REQUIRED_JSON_FIELDS = ['crv', 'x', 'y'] - RAW_KEY_CLS = (EllipticCurvePublicKey, EllipticCurvePrivateKeyWithSerialization) - def as_pem(self, is_private=False, password=None): - """Export key into PEM format bytes. + PUBLIC_KEY_FIELDS = REQUIRED_JSON_FIELDS + PRIVATE_KEY_FIELDS = ['crv', 'd', 'x', 'y'] - :param is_private: export private key or public key - :param password: encrypt private key with password - :return: bytes - """ - return export_key(self, is_private=is_private, password=password) + PUBLIC_KEY_CLS = EllipticCurvePublicKey + PRIVATE_KEY_CLS = EllipticCurvePrivateKeyWithSerialization + SSH_PUBLIC_PREFIX = b'ecdsa-sha2-' def exchange_shared_key(self, pubkey): # # used in ECDHAlgorithm - if isinstance(self.raw_key, EllipticCurvePrivateKeyWithSerialization): - return self.raw_key.exchange(ec.ECDH(), pubkey) + private_key = self.get_private_key() + if private_key: + return private_key.exchange(ec.ECDH(), pubkey) raise ValueError('Invalid key for exchanging shared key') - @property - def curve_name(self): - return self.CURVES_DSS[self.raw_key.curve.name] - @property def curve_key_size(self): - return self.raw_key.curve.key_size + raw_key = self.get_private_key() + if not raw_key: + raw_key = self.public_key + return raw_key.curve.key_size - @classmethod - def loads_private_key(cls, obj): - curve = cls.DSS_CURVES[obj['crv']]() + def load_private_key(self): + curve = self.DSS_CURVES[self._dict_data['crv']]() public_numbers = EllipticCurvePublicNumbers( - base64_to_int(obj['x']), - base64_to_int(obj['y']), + base64_to_int(self._dict_data['x']), + base64_to_int(self._dict_data['y']), curve, ) private_numbers = EllipticCurvePrivateNumbers( - base64_to_int(obj['d']), + base64_to_int(self.tokens['d']), public_numbers ) return private_numbers.private_key(default_backend()) - @classmethod - def loads_public_key(cls, obj): - curve = cls.DSS_CURVES[obj['crv']]() + def load_public_key(self): + curve = self.DSS_CURVES[self._dict_data['crv']]() public_numbers = EllipticCurvePublicNumbers( - base64_to_int(obj['x']), - base64_to_int(obj['y']), + base64_to_int(self._dict_data['x']), + base64_to_int(self._dict_data['y']), curve, ) return public_numbers.public_key(default_backend()) - @classmethod - def dumps_private_key(cls, raw_key): - numbers = raw_key.private_numbers() + def dumps_private_key(self): + numbers = self.private_key.private_numbers() return { - 'crv': cls.CURVES_DSS[raw_key.curve.name], + 'crv': self.CURVES_DSS[self.private_key.curve.name], 'x': int_to_base64(numbers.public_numbers.x), 'y': int_to_base64(numbers.public_numbers.y), 'd': int_to_base64(numbers.private_value), } - @classmethod - def dumps_public_key(cls, raw_key): - numbers = raw_key.public_numbers() + def dumps_public_key(self): + numbers = self.public_key.public_numbers() return { - 'crv': cls.CURVES_DSS[numbers.curve.name], + 'crv': self.CURVES_DSS[numbers.curve.name], 'x': int_to_base64(numbers.x), 'y': int_to_base64(numbers.y) } - @classmethod - def import_key(cls, raw, options=None) -> 'ECKey': - """Import a key from PEM or dict data.""" - return import_key( - cls, raw, - EllipticCurvePublicKey, EllipticCurvePrivateKeyWithSerialization, - b'ecdsa-sha2-', options - ) - @classmethod def generate_key(cls, crv='P-256', options=None, is_private=False) -> 'ECKey': if crv not in cls.DSS_CURVES: diff --git a/authlib/jose/rfc7518/jws_algs.py b/authlib/jose/rfc7518/jws_algs.py index d2749520..eae8a9d6 100644 --- a/authlib/jose/rfc7518/jws_algs.py +++ b/authlib/jose/rfc7518/jws_algs.py @@ -120,7 +120,7 @@ def __init__(self, name, curve, sha_type): def prepare_key(self, raw_data): key = ECKey.import_key(raw_data) - if key.curve_name != self.curve: + if key['crv'] != self.curve: raise ValueError(f'Key for "{self.name}" not supported, only "{self.curve}" allowed') return key diff --git a/authlib/jose/rfc7518/key_util.py b/authlib/jose/rfc7518/key_util.py deleted file mode 100644 index a53f42d3..00000000 --- a/authlib/jose/rfc7518/key_util.py +++ /dev/null @@ -1,78 +0,0 @@ -from cryptography.hazmat.primitives.serialization import ( - Encoding, PrivateFormat, PublicFormat, - BestAvailableEncryption, NoEncryption, -) -from authlib.common.encoding import to_bytes -from ..rfc7517 import load_pem_key - - -def import_key(cls, raw, public_key_cls, private_key_cls, ssh_type=None, options=None): - if isinstance(raw, cls): - if options is not None: - raw.update(options) - return raw - - payload = None - if isinstance(raw, (public_key_cls, private_key_cls)): - raw_key = raw - elif isinstance(raw, dict): - cls.check_required_fields(raw) - payload = raw - if 'd' in payload: - raw_key = cls.loads_private_key(payload) - else: - raw_key = cls.loads_public_key(payload) - else: - if options is not None: - password = options.get('password') - else: - password = None - raw_key = load_pem_key(raw, ssh_type, password=password) - - if isinstance(raw_key, private_key_cls): - if payload is None: - payload = cls.dumps_private_key(raw_key) - key_type = 'private' - elif isinstance(raw_key, public_key_cls): - if payload is None: - payload = cls.dumps_public_key(raw_key) - key_type = 'public' - else: - raise ValueError('Invalid data for importing key') - - obj = cls(payload) - obj.raw_key = raw_key - obj.key_type = key_type - return obj - - -def export_key(key, encoding=None, is_private=False, password=None): - if encoding is None or encoding == 'PEM': - encoding = Encoding.PEM - elif encoding == 'DER': - encoding = Encoding.DER - else: - raise ValueError('Invalid encoding: {!r}'.format(encoding)) - - if is_private: - if key.key_type == 'private': - if password is None: - encryption_algorithm = NoEncryption() - else: - encryption_algorithm = BestAvailableEncryption(to_bytes(password)) - return key.raw_key.private_bytes( - encoding=encoding, - format=PrivateFormat.PKCS8, - encryption_algorithm=encryption_algorithm, - ) - raise ValueError('This is a public key') - - if key.key_type == 'private': - raw_key = key.raw_key.public_key() - else: - raw_key = key.raw_key - - return raw_key.public_bytes( - encoding=encoding, - format=PublicFormat.SubjectPublicKeyInfo, - ) diff --git a/authlib/jose/rfc7518/oct_key.py b/authlib/jose/rfc7518/oct_key.py index a095ada4..12c5415d 100644 --- a/authlib/jose/rfc7518/oct_key.py +++ b/authlib/jose/rfc7518/oct_key.py @@ -3,7 +3,7 @@ urlsafe_b64encode, urlsafe_b64decode, ) from authlib.common.security import generate_token -from authlib.jose.rfc7517 import Key +from ..rfc7517 import Key class OctKey(Key): @@ -12,29 +12,55 @@ class OctKey(Key): kty = 'oct' REQUIRED_JSON_FIELDS = ['k'] - def get_op_key(self, key_op): - self.check_key_op(key_op) + def __init__(self, raw_key=None, options=None): + super(OctKey, self).__init__(options) + self.raw_key = raw_key + + @property + def public_only(self): + return False + + def get_op_key(self, operation): + """Get the raw key for the given key_op. This method will also + check if the given key_op is supported by this key. + + :param operation: key operation value, such as "sign", "encrypt". + :return: raw key + """ + self.check_key_op(operation) + if not self.raw_key: + self.load_raw_key() return self.raw_key + def load_raw_key(self): + self.raw_key = urlsafe_b64decode(to_bytes(self.tokens['k'])) + + def load_dict_key(self): + k = to_unicode(urlsafe_b64encode(self.raw_key)) + self._dict_data = {'kty': self.kty, 'k': k} + + def as_dict(self, is_private=False): + tokens = self.tokens + if 'kid' not in tokens: + tokens['kid'] = self.thumbprint() + return tokens + @classmethod def import_key(cls, raw, options=None): """Import a key from bytes, string, or dict data.""" + if isinstance(raw, cls): + if options is not None: + raw.options.update(options) + return raw + if isinstance(raw, dict): cls.check_required_fields(raw) - payload = raw - raw_key = urlsafe_b64decode(to_bytes(payload['k'])) + key = cls(options=options) + key._dict_data = raw else: raw_key = to_bytes(raw) - k = to_unicode(urlsafe_b64encode(raw_key)) - payload = {'k': k} - - if options is not None: - payload.update(options) - - obj = cls(payload) - obj.raw_key = raw_key - obj.key_type = 'secret' - return obj + key = cls(raw_key=raw_key, options=options) + return key @classmethod def generate_key(cls, key_size=256, options=None, is_private=False): diff --git a/authlib/jose/rfc7518/rsa_key.py b/authlib/jose/rfc7518/rsa_key.py index 4e9bcc74..53bd9958 100644 --- a/authlib/jose/rfc7518/rsa_key.py +++ b/authlib/jose/rfc7518/rsa_key.py @@ -6,29 +6,23 @@ ) from cryptography.hazmat.backends import default_backend from authlib.common.encoding import base64_to_int, int_to_base64 -from .key_util import export_key, import_key -from ..rfc7517 import Key +from ..rfc7517 import AsymmetricKey -class RSAKey(Key): +class RSAKey(AsymmetricKey): """Key class of the ``RSA`` key type.""" kty = 'RSA' - RAW_KEY_CLS = (RSAPublicKey, RSAPrivateKeyWithSerialization) - REQUIRED_JSON_FIELDS = ['e', 'n'] - - def as_pem(self, is_private=False, password=None): - """Export key into PEM format bytes. + PUBLIC_KEY_CLS = RSAPublicKey + PRIVATE_KEY_CLS = RSAPrivateKeyWithSerialization - :param is_private: export private key or public key - :param password: encrypt private key with password - :return: bytes - """ - return export_key(self, is_private=is_private, password=password) + PUBLIC_KEY_FIELDS = ['e', 'n'] + PRIVATE_KEY_FIELDS = ['d', 'dp', 'dq', 'e', 'n', 'p', 'q', 'qi'] + REQUIRED_JSON_FIELDS = ['e', 'n'] + SSH_PUBLIC_PREFIX = b'ssh-rsa' - @staticmethod - def dumps_private_key(raw_key): - numbers = raw_key.private_numbers() + def dumps_private_key(self): + numbers = self.private_key.private_numbers() return { 'n': int_to_base64(numbers.public_numbers.n), 'e': int_to_base64(numbers.public_numbers.e), @@ -40,33 +34,24 @@ def dumps_private_key(raw_key): 'qi': int_to_base64(numbers.iqmp) } - @staticmethod - def dumps_public_key(raw_key): - numbers = raw_key.public_numbers() + def dumps_public_key(self): + numbers = self.public_key.public_numbers() return { 'n': int_to_base64(numbers.n), 'e': int_to_base64(numbers.e) } - @staticmethod - def loads_private_key(obj): + def load_private_key(self): + obj = self._dict_data + if 'oth' in obj: # pragma: no cover # https://tools.ietf.org/html/rfc7518#section-6.3.2.7 raise ValueError('"oth" is not supported yet') - props = ['p', 'q', 'dp', 'dq', 'qi'] - props_found = [prop in obj for prop in props] - any_props_found = any(props_found) - - if any_props_found and not all(props_found): - raise ValueError( - 'RSA key must include all parameters ' - 'if any are present besides d') - public_numbers = RSAPublicNumbers( base64_to_int(obj['e']), base64_to_int(obj['n'])) - if any_props_found: + if has_all_prime_factors(obj): numbers = RSAPrivateNumbers( d=base64_to_int(obj['d']), p=base64_to_int(obj['p']), @@ -90,25 +75,15 @@ def loads_private_key(obj): return numbers.private_key(default_backend()) - @staticmethod - def loads_public_key(obj): + def load_public_key(self): numbers = RSAPublicNumbers( - base64_to_int(obj['e']), - base64_to_int(obj['n']) + base64_to_int(self._dict_data['e']), + base64_to_int(self._dict_data['n']) ) return numbers.public_key(default_backend()) @classmethod - def import_key(cls, raw, options=None): - """Import a key from PEM or dict data.""" - return import_key( - cls, raw, - RSAPublicKey, RSAPrivateKeyWithSerialization, - b'ssh-rsa', options - ) - - @classmethod - def generate_key(cls, key_size=2048, options=None, is_private=False): + def generate_key(cls, key_size=2048, options=None, is_private=False) -> 'RSAKey': if key_size < 512: raise ValueError('key_size must not be less than 512') if key_size % 8 != 0: @@ -121,3 +96,28 @@ def generate_key(cls, key_size=2048, options=None, is_private=False): if not is_private: raw_key = raw_key.public_key() return cls.import_key(raw_key, options=options) + + @classmethod + def import_dict_key(cls, raw, options=None): + cls.check_required_fields(raw) + key = cls(options=options) + key._dict_data = raw + if 'd' in raw and not has_all_prime_factors(raw): + # reload dict key + key.load_raw_key() + key.load_dict_key() + return key + + +def has_all_prime_factors(obj): + props = ['p', 'q', 'dp', 'dq', 'qi'] + props_found = [prop in obj for prop in props] + if all(props_found): + return True + + if any(props_found): + raise ValueError( + 'RSA key must include all parameters ' + 'if any are present besides d') + + return False diff --git a/authlib/jose/rfc8037/okp_key.py b/authlib/jose/rfc8037/okp_key.py index d8438b3b..1a70c6d9 100644 --- a/authlib/jose/rfc8037/okp_key.py +++ b/authlib/jose/rfc8037/okp_key.py @@ -17,8 +17,7 @@ to_unicode, to_bytes, urlsafe_b64decode, urlsafe_b64encode, ) -from authlib.jose.rfc7517 import Key -from ..rfc7518 import import_key, export_key +from ..rfc7517 import AsymmetricKey PUBLIC_KEYS_MAP = { @@ -33,41 +32,25 @@ 'X25519': X25519PrivateKey, 'X448': X448PrivateKey, } -PUBLIC_KEY_TUPLE = tuple(PUBLIC_KEYS_MAP.values()) -PRIVATE_KEY_TUPLE = tuple(PRIVATE_KEYS_MAP.values()) -class OKPKey(Key): +class OKPKey(AsymmetricKey): """Key class of the ``OKP`` key type.""" kty = 'OKP' REQUIRED_JSON_FIELDS = ['crv', 'x'] - RAW_KEY_CLS = ( - Ed25519PublicKey, Ed25519PrivateKey, - Ed448PublicKey, Ed448PrivateKey, - X25519PublicKey, X25519PrivateKey, - X448PublicKey, X448PrivateKey, - ) - - def as_pem(self, is_private=False, password=None): - """Export key into PEM format bytes. - - :param is_private: export private key or public key - :param password: encrypt private key with password - :return: bytes - """ - return export_key(self, is_private=is_private, password=password) + PUBLIC_KEY_FIELDS = REQUIRED_JSON_FIELDS + PRIVATE_KEY_FIELDS = ['crv', 'd'] + PUBLIC_KEY_CLS = tuple(PUBLIC_KEYS_MAP.values()) + PRIVATE_KEY_CLS = tuple(PRIVATE_KEYS_MAP.values()) + SSH_PUBLIC_PREFIX = b'ssh-ed25519' def exchange_shared_key(self, pubkey): # used in ECDHAlgorithm - if isinstance(self.raw_key, (X25519PrivateKey, X448PrivateKey)): - return self.raw_key.exchange(pubkey) + if self.private_key and isinstance(self.private_key, (X25519PrivateKey, X448PrivateKey)): + return self.private_key.exchange(pubkey) raise ValueError('Invalid key for exchanging shared key') - @property - def curve_key_size(self): - raise NotImplementedError() - @staticmethod def get_key_curve(key): if isinstance(key, (Ed25519PublicKey, Ed25519PrivateKey)): @@ -79,22 +62,19 @@ def get_key_curve(key): elif isinstance(key, (X448PublicKey, X448PrivateKey)): return 'X448' - @staticmethod - def loads_private_key(obj): - crv_key = PRIVATE_KEYS_MAP[obj['crv']] - d_bytes = urlsafe_b64decode(to_bytes(obj['d'])) + def load_private_key(self): + crv_key = PRIVATE_KEYS_MAP[self._dict_data['crv']] + d_bytes = urlsafe_b64decode(to_bytes(self._dict_data['d'])) return crv_key.from_private_bytes(d_bytes) - @staticmethod - def loads_public_key(obj): - crv_key = PUBLIC_KEYS_MAP[obj['crv']] - x_bytes = urlsafe_b64decode(to_bytes(obj['x'])) + def load_public_key(self): + crv_key = PUBLIC_KEYS_MAP[self._dict_data['crv']] + x_bytes = urlsafe_b64decode(to_bytes(self._dict_data['x'])) return crv_key.from_public_bytes(x_bytes) - @staticmethod - def dumps_private_key(raw_key): - obj = OKPKey.dumps_public_key(raw_key.public_key()) - d_bytes = raw_key.private_bytes( + def dumps_private_key(self): + obj = self.dumps_public_key(self.private_key.public_key()) + d_bytes = self.private_key.private_bytes( Encoding.Raw, PrivateFormat.Raw, NoEncryption() @@ -102,25 +82,17 @@ def dumps_private_key(raw_key): obj['d'] = to_unicode(urlsafe_b64encode(d_bytes)) return obj - @staticmethod - def dumps_public_key(raw_key): - x_bytes = raw_key.public_bytes(Encoding.Raw, PublicFormat.Raw) + def dumps_public_key(self, public_key=None): + if public_key is None: + public_key = self.public_key + x_bytes = public_key.public_bytes(Encoding.Raw, PublicFormat.Raw) return { - 'crv': OKPKey.get_key_curve(raw_key), + 'crv': self.get_key_curve(public_key), 'x': to_unicode(urlsafe_b64encode(x_bytes)), } @classmethod - def import_key(cls, raw, options=None): - """Import a key from PEM or dict data.""" - return import_key( - cls, raw, - PUBLIC_KEY_TUPLE, PRIVATE_KEY_TUPLE, - b'ssh-ed25519', options - ) - - @classmethod - def generate_key(cls, crv='Ed25519', options=None, is_private=False): + def generate_key(cls, crv='Ed25519', options=None, is_private=False) -> 'OKPKey': if crv not in PRIVATE_KEYS_MAP: raise ValueError('Invalid crv value: "{}"'.format(crv)) private_key_cls = PRIVATE_KEYS_MAP[crv] diff --git a/tests/core/test_jose/test_jwk.py b/tests/core/test_jose/test_jwk.py index 496d06a9..629e9ebb 100644 --- a/tests/core/test_jose/test_jwk.py +++ b/tests/core/test_jose/test_jwk.py @@ -1,53 +1,54 @@ import unittest -from authlib.jose import jwk, JsonWebKey, KeySet -from authlib.jose import RSAKey, ECKey, OKPKey +from authlib.jose import JsonWebKey, KeySet +from authlib.jose import OctKey, RSAKey, ECKey, OKPKey from authlib.common.encoding import base64_to_int from tests.util import read_file_path -RSA_PRIVATE_KEY = read_file_path('jwk_private.json') - -class JWKTest(unittest.TestCase): +class BaseTest(unittest.TestCase): def assertBase64IntEqual(self, x, y): self.assertEqual(base64_to_int(x), base64_to_int(y)) - def test_ec_public_key(self): - # https://tools.ietf.org/html/rfc7520#section-3.1 - obj = read_file_path('secp521r1-public.json') - key = jwk.loads(obj) - new_obj = jwk.dumps(key) - self.assertEqual(new_obj['crv'], obj['crv']) - self.assertBase64IntEqual(new_obj['x'], obj['x']) - self.assertBase64IntEqual(new_obj['y'], obj['y']) - self.assertEqual(key.as_json()[0], '{') - def test_ec_private_key(self): - # https://tools.ietf.org/html/rfc7520#section-3.2 - obj = read_file_path('secp521r1-private.json') - key = jwk.loads(obj) - new_obj = jwk.dumps(key, 'EC') - self.assertEqual(new_obj['crv'], obj['crv']) - self.assertBase64IntEqual(new_obj['x'], obj['x']) - self.assertBase64IntEqual(new_obj['y'], obj['y']) - self.assertBase64IntEqual(new_obj['d'], obj['d']) +class OctKeyTest(BaseTest): + def test_import_oct_key(self): + # https://tools.ietf.org/html/rfc7520#section-3.5 + obj = { + "kty": "oct", + "kid": "018c0ae5-4d9b-471b-bfd6-eef314bc7037", + "use": "sig", + "alg": "HS256", + "k": "hJtXIZ2uSN5kbQfbtTNWbpdmhkV8FJG-Onbc6mxCcYg" + } + key = OctKey.import_key(obj) + new_obj = key.as_dict() + self.assertEqual(obj['k'], new_obj['k']) + self.assertIn('use', new_obj) - def test_invalid_ec(self): - self.assertRaises(ValueError, jwk.loads, {'kty': 'EC'}) - self.assertRaises(ValueError, jwk.dumps, '', 'EC') + def test_invalid_oct_key(self): + self.assertRaises(ValueError, OctKey.import_key, {}) + + +class RSAKeyTest(BaseTest): + def test_import_ssh_pem(self): + raw = read_file_path('ssh_public.pem') + key = RSAKey.import_key(raw) + obj = key.as_dict() + self.assertEqual(obj['kty'], 'RSA') def test_rsa_public_key(self): # https://tools.ietf.org/html/rfc7520#section-3.3 obj = read_file_path('jwk_public.json') - key = jwk.loads(obj) - new_obj = jwk.dumps(key) + key = RSAKey.import_key(obj) + new_obj = key.as_dict() self.assertBase64IntEqual(new_obj['n'], obj['n']) self.assertBase64IntEqual(new_obj['e'], obj['e']) def test_rsa_private_key(self): # https://tools.ietf.org/html/rfc7520#section-3.4 - obj = RSA_PRIVATE_KEY - key = jwk.loads(obj) - new_obj = jwk.dumps(key, 'RSA') + obj = read_file_path('jwk_private.json') + key = RSAKey.import_key(obj) + new_obj = key.as_dict(is_private=True) self.assertBase64IntEqual(new_obj['n'], obj['n']) self.assertBase64IntEqual(new_obj['e'], obj['e']) self.assertBase64IntEqual(new_obj['d'], obj['d']) @@ -58,65 +59,109 @@ def test_rsa_private_key(self): self.assertBase64IntEqual(new_obj['qi'], obj['qi']) def test_rsa_private_key2(self): + rsa_obj = read_file_path('jwk_private.json') obj = { "kty": "RSA", "kid": "bilbo.baggins@hobbiton.example", "use": "sig", - "n": RSA_PRIVATE_KEY['n'], - 'd': RSA_PRIVATE_KEY['d'], + "n": rsa_obj['n'], + 'd': rsa_obj['d'], "e": "AQAB" } - key = jwk.loads(obj) - new_obj = jwk.dumps(key.raw_key, 'RSA') + key = RSAKey.import_key(obj) + new_obj = key.as_dict(is_private=True) self.assertBase64IntEqual(new_obj['n'], obj['n']) self.assertBase64IntEqual(new_obj['e'], obj['e']) self.assertBase64IntEqual(new_obj['d'], obj['d']) - self.assertBase64IntEqual(new_obj['p'], RSA_PRIVATE_KEY['p']) - self.assertBase64IntEqual(new_obj['q'], RSA_PRIVATE_KEY['q']) - self.assertBase64IntEqual(new_obj['dp'], RSA_PRIVATE_KEY['dp']) - self.assertBase64IntEqual(new_obj['dq'], RSA_PRIVATE_KEY['dq']) - self.assertBase64IntEqual(new_obj['qi'], RSA_PRIVATE_KEY['qi']) + self.assertBase64IntEqual(new_obj['p'], rsa_obj['p']) + self.assertBase64IntEqual(new_obj['q'], rsa_obj['q']) + self.assertBase64IntEqual(new_obj['dp'], rsa_obj['dp']) + self.assertBase64IntEqual(new_obj['dq'], rsa_obj['dq']) + self.assertBase64IntEqual(new_obj['qi'], rsa_obj['qi']) def test_invalid_rsa(self): + self.assertRaises(ValueError, RSAKey.import_key, {'kty': 'RSA'}) + rsa_obj = read_file_path('jwk_private.json') obj = { "kty": "RSA", "kid": "bilbo.baggins@hobbiton.example", "use": "sig", - "n": RSA_PRIVATE_KEY['n'], - 'd': RSA_PRIVATE_KEY['d'], - 'p': RSA_PRIVATE_KEY['p'], + "n": rsa_obj['n'], + 'd': rsa_obj['d'], + 'p': rsa_obj['p'], "e": "AQAB" } - self.assertRaises(ValueError, jwk.loads, obj) - self.assertRaises(ValueError, jwk.loads, {'kty': 'RSA'}) - self.assertRaises(ValueError, jwk.dumps, '', 'RSA') + self.assertRaises(ValueError, RSAKey.import_key, obj) - def test_dumps_okp_public_key(self): - key = read_file_path('ed25519-ssh.pub') - self.assertRaises(ValueError, jwk.dumps, key) + def test_rsa_key_generate(self): + self.assertRaises(ValueError, RSAKey.generate_key, 256) + self.assertRaises(ValueError, RSAKey.generate_key, 2001) - obj = jwk.dumps(key, 'OKP') - self.assertEqual(obj['kty'], 'OKP') - self.assertEqual(obj['crv'], 'Ed25519') + key1 = RSAKey.generate_key(is_private=True) + self.assertIn(b'PRIVATE', key1.as_pem(is_private=True)) + self.assertIn(b'PUBLIC', key1.as_pem(is_private=False)) + + key2 = RSAKey.generate_key(is_private=False) + self.assertRaises(ValueError, key2.as_pem, True) + self.assertIn(b'PUBLIC', key2.as_pem(is_private=False)) + + +class ECKeyTest(BaseTest): + def test_ec_public_key(self): + # https://tools.ietf.org/html/rfc7520#section-3.1 + obj = read_file_path('secp521r1-public.json') + key = ECKey.import_key(obj) + new_obj = key.as_dict() + self.assertEqual(new_obj['crv'], obj['crv']) + self.assertBase64IntEqual(new_obj['x'], obj['x']) + self.assertBase64IntEqual(new_obj['y'], obj['y']) + self.assertEqual(key.as_json()[0], '{') + + def test_ec_private_key(self): + # https://tools.ietf.org/html/rfc7520#section-3.2 + obj = read_file_path('secp521r1-private.json') + key = ECKey.import_key(obj) + new_obj = key.as_dict(is_private=True) + self.assertEqual(new_obj['crv'], obj['crv']) + self.assertBase64IntEqual(new_obj['x'], obj['x']) + self.assertBase64IntEqual(new_obj['y'], obj['y']) + self.assertBase64IntEqual(new_obj['d'], obj['d']) + + def test_invalid_ec(self): + self.assertRaises(ValueError, ECKey.import_key, {'kty': 'EC'}) + + def test_ec_key_generate(self): + key1 = ECKey.generate_key('P-384', is_private=True) + self.assertIn(b'PRIVATE', key1.as_pem(is_private=True)) + self.assertIn(b'PUBLIC', key1.as_pem(is_private=False)) - key = read_file_path('ed25519-pub.pem') - obj = jwk.dumps(key, 'OKP') + key2 = ECKey.generate_key('P-256', is_private=False) + self.assertRaises(ValueError, key2.as_pem, True) + self.assertIn(b'PUBLIC', key2.as_pem(is_private=False)) + + +class OKPKeyTest(BaseTest): + def test_import_okp_ssh_key(self): + raw = read_file_path('ed25519-ssh.pub') + key = OKPKey.import_key(raw) + obj = key.as_dict() self.assertEqual(obj['kty'], 'OKP') self.assertEqual(obj['crv'], 'Ed25519') - def test_loads_okp_public_key(self): + def test_import_okp_public_key(self): obj = { "x": "AD9E0JYnpV-OxZbd8aN1t4z71Vtf6JcJC7TYHT0HDbg", "crv": "Ed25519", "kty": "OKP" } - key = jwk.loads(obj) - new_obj = jwk.dumps(key) + key = OKPKey.import_key(obj) + new_obj = key.as_dict() self.assertEqual(obj['x'], new_obj['x']) - def test_dumps_okp_private_key(self): - key = read_file_path('ed25519-pkcs8.pem') - obj = jwk.dumps(key, 'OKP') + def test_import_okp_private_pem(self): + raw = read_file_path('ed25519-pkcs8.pem') + key = OKPKey.import_key(raw) + obj = key.as_dict(is_private=True) self.assertEqual(obj['kty'], 'OKP') self.assertEqual(obj['crv'], 'Ed25519') self.assertIn('d', obj) @@ -128,44 +173,25 @@ def test_loads_okp_private_key(self): 'crv': 'Ed25519', 'kty': 'OKP' } - key = jwk.loads(obj) - new_obj = jwk.dumps(key) + key = OKPKey.import_key(obj) + new_obj = key.as_dict(is_private=True) self.assertEqual(obj['d'], new_obj['d']) - def test_mac_computation(self): - # https://tools.ietf.org/html/rfc7520#section-3.5 - obj = { - "kty": "oct", - "kid": "018c0ae5-4d9b-471b-bfd6-eef314bc7037", - "use": "sig", - "alg": "HS256", - "k": "hJtXIZ2uSN5kbQfbtTNWbpdmhkV8FJG-Onbc6mxCcYg" - } - key = jwk.loads(obj) - new_obj = jwk.dumps(key) - self.assertEqual(obj['k'], new_obj['k']) - self.assertIn('use', new_obj) + def test_okp_key_generate_pem(self): + self.assertRaises(ValueError, OKPKey.generate_key, 'invalid') - new_obj = jwk.dumps(key, use='sig') - self.assertEqual(new_obj['use'], 'sig') + key1 = OKPKey.generate_key('Ed25519', is_private=True) + self.assertIn(b'PRIVATE', key1.as_pem(is_private=True)) + self.assertIn(b'PUBLIC', key1.as_pem(is_private=False)) - def test_jwk_loads(self): - self.assertRaises(ValueError, jwk.loads, {}) - self.assertRaises(ValueError, jwk.loads, {}, 'k') + key2 = OKPKey.generate_key('X25519', is_private=False) + self.assertRaises(ValueError, key2.as_pem, True) + self.assertIn(b'PUBLIC', key2.as_pem(is_private=False)) - obj = { - "kty": "oct", - "kid": "018c0ae5-4d9b-471b-bfd6-eef314bc7037", - "use": "sig", - "alg": "HS256", - "k": "hJtXIZ2uSN5kbQfbtTNWbpdmhkV8FJG-Onbc6mxCcYg" - } - self.assertRaises(ValueError, jwk.loads, [obj], 'invalid-kid') - def test_jwk_dumps_ssh(self): - key = read_file_path('ssh_public.pem') - obj = jwk.dumps(key, kty='RSA') - self.assertEqual(obj['kty'], 'RSA') +class JWKTest(BaseTest): + def test_import_keys(self): + pass def test_thumbprint(self): # https://tools.ietf.org/html/rfc7638#section-3.1 @@ -180,37 +206,3 @@ def test_key_set(self): obj = key_set.as_dict()['keys'][0] self.assertIn('kid', obj) self.assertEqual(key_set.as_json()[0], '{') - - def test_rsa_key_generate_pem(self): - self.assertRaises(ValueError, RSAKey.generate_key, 256) - self.assertRaises(ValueError, RSAKey.generate_key, 2001) - - key1 = RSAKey.generate_key(is_private=True) - self.assertIn(b'PRIVATE', key1.as_pem(is_private=True)) - self.assertIn(b'PUBLIC', key1.as_pem(is_private=False)) - - key2 = RSAKey.generate_key(is_private=False) - self.assertRaises(ValueError, key2.as_pem, True) - self.assertIn(b'PUBLIC', key2.as_pem(is_private=False)) - - def test_ec_key_generate_pem(self): - self.assertRaises(ValueError, ECKey.generate_key, 'invalid') - - key1 = ECKey.generate_key('P-384', is_private=True) - self.assertIn(b'PRIVATE', key1.as_pem(is_private=True)) - self.assertIn(b'PUBLIC', key1.as_pem(is_private=False)) - - key2 = ECKey.generate_key('P-256', is_private=False) - self.assertRaises(ValueError, key2.as_pem, True) - self.assertIn(b'PUBLIC', key2.as_pem(is_private=False)) - - def test_okp_key_generate_pem(self): - self.assertRaises(ValueError, OKPKey.generate_key, 'invalid') - - key1 = OKPKey.generate_key('Ed25519', is_private=True) - self.assertIn(b'PRIVATE', key1.as_pem(is_private=True)) - self.assertIn(b'PUBLIC', key1.as_pem(is_private=False)) - - key2 = OKPKey.generate_key('X25519', is_private=False) - self.assertRaises(ValueError, key2.as_pem, True) - self.assertIn(b'PUBLIC', key2.as_pem(is_private=False)) diff --git a/tests/flask/test_client/test_user_mixin.py b/tests/flask/test_client/test_user_mixin.py index 7b6d25f2..919b145c 100644 --- a/tests/flask/test_client/test_user_mixin.py +++ b/tests/flask/test_client/test_user_mixin.py @@ -6,9 +6,7 @@ from authlib.integrations.flask_client import OAuth from authlib.oidc.core.grants.util import generate_id_token from tests.util import read_file_path -from tests.client_base import ( - get_bearer_token, -) +from tests.client_base import get_bearer_token class FlaskUserMixinTest(TestCase): From 2411c22ba4fb7cbd5ea806aeeabca02f8232f45a Mon Sep 17 00:00:00 2001 From: Hsiaoming Yang Date: Sat, 14 Nov 2020 14:42:35 +0900 Subject: [PATCH 02/26] Remove compatible imports for jose --- authlib/jose/__init__.py | 13 ------------- tests/flask/test_oauth2/test_openid_hybrid_grant.py | 4 ++-- .../flask/test_oauth2/test_openid_implict_grant.py | 4 ++-- 3 files changed, 4 insertions(+), 17 deletions(-) diff --git a/authlib/jose/__init__.py b/authlib/jose/__init__.py index d0ce6233..ec6cfb4c 100644 --- a/authlib/jose/__init__.py +++ b/authlib/jose/__init__.py @@ -44,19 +44,6 @@ OKPKey.kty: OKPKey, } -# compatible constants -JWS_ALGORITHMS = list(JsonWebSignature.ALGORITHMS_REGISTRY.keys()) -JWE_ALG_ALGORITHMS = list(JsonWebEncryption.ALG_REGISTRY.keys()) -JWE_ENC_ALGORITHMS = list(JsonWebEncryption.ENC_REGISTRY.keys()) -JWE_ZIP_ALGORITHMS = list(JsonWebEncryption.ZIP_REGISTRY.keys()) -JWE_ALGORITHMS = JWE_ALG_ALGORITHMS + JWE_ENC_ALGORITHMS + JWE_ZIP_ALGORITHMS - -# compatible imports -JWS = JsonWebSignature -JWE = JsonWebEncryption -JWK = JsonWebKey -JWT = JsonWebToken - jwt = JsonWebToken() diff --git a/tests/flask/test_oauth2/test_openid_hybrid_grant.py b/tests/flask/test_oauth2/test_openid_hybrid_grant.py index e596c4d4..4f274bd8 100644 --- a/tests/flask/test_oauth2/test_openid_hybrid_grant.py +++ b/tests/flask/test_oauth2/test_openid_hybrid_grant.py @@ -1,6 +1,6 @@ from flask import json from authlib.common.urls import urlparse, url_decode -from authlib.jose import JWT +from authlib.jose import JsonWebToken from authlib.oidc.core import HybridIDToken from authlib.oidc.core.grants import ( OpenIDCode as _OpenIDCode, @@ -72,7 +72,7 @@ def prepare_data(self): db.session.commit() def validate_claims(self, id_token, params): - jwt = JWT() + jwt = JsonWebToken() claims = jwt.decode( id_token, 'secret', claims_cls=HybridIDToken, diff --git a/tests/flask/test_oauth2/test_openid_implict_grant.py b/tests/flask/test_oauth2/test_openid_implict_grant.py index 6b66086b..af3673a7 100644 --- a/tests/flask/test_oauth2/test_openid_implict_grant.py +++ b/tests/flask/test_oauth2/test_openid_implict_grant.py @@ -1,4 +1,4 @@ -from authlib.jose import JWT +from authlib.jose import JsonWebToken from authlib.oidc.core import ImplicitIDToken from authlib.oidc.core.grants import ( OpenIDImplicitGrant as _OpenIDImplicitGrant @@ -47,7 +47,7 @@ def prepare_data(self): db.session.commit() def validate_claims(self, id_token, params): - jwt = JWT(['HS256']) + jwt = JsonWebToken(['HS256']) claims = jwt.decode( id_token, 'secret', claims_cls=ImplicitIDToken, From 0ac11c81f0707197f3340efc2ff95b5e24bfa2a3 Mon Sep 17 00:00:00 2001 From: Hsiaoming Yang Date: Sat, 14 Nov 2020 15:23:01 +0900 Subject: [PATCH 03/26] Fix JsonWebKey generate and import keys --- authlib/jose/rfc7517/asymmetric_key.py | 4 +++ authlib/jose/rfc7517/base_key.py | 4 +++ authlib/jose/rfc7517/jwk.py | 2 +- authlib/jose/rfc7518/oct_key.py | 6 +++- tests/core/test_jose/test_jwk.py | 43 ++++++++++++++++++++++++-- 5 files changed, 55 insertions(+), 4 deletions(-) diff --git a/authlib/jose/rfc7517/asymmetric_key.py b/authlib/jose/rfc7517/asymmetric_key.py index aaa36c65..0901a453 100644 --- a/authlib/jose/rfc7517/asymmetric_key.py +++ b/authlib/jose/rfc7517/asymmetric_key.py @@ -187,6 +187,10 @@ def import_key(cls, raw, options=None): raise ValueError('Invalid data for importing key') return key + @classmethod + def validate_raw_key(cls, key): + return isinstance(key, cls.PUBLIC_KEY_CLS) or isinstance(key, cls.PRIVATE_KEY_CLS) + @classmethod def generate_key(cls, crv_or_size, options=None, is_private=False): raise NotImplementedError() diff --git a/authlib/jose/rfc7517/base_key.py b/authlib/jose/rfc7517/base_key.py index c89c41e0..7c80284a 100644 --- a/authlib/jose/rfc7517/base_key.py +++ b/authlib/jose/rfc7517/base_key.py @@ -108,3 +108,7 @@ def check_required_fields(cls, data): for k in cls.REQUIRED_JSON_FIELDS: if k not in data: raise ValueError('Missing required field: "{}"'.format(k)) + + @classmethod + def validate_raw_key(cls, key): + raise NotImplementedError() diff --git a/authlib/jose/rfc7517/jwk.py b/authlib/jose/rfc7517/jwk.py index 576c4e83..c0d47e62 100644 --- a/authlib/jose/rfc7517/jwk.py +++ b/authlib/jose/rfc7517/jwk.py @@ -36,7 +36,7 @@ def import_key(cls, raw, options=None): raw_key = load_pem_key(raw) for _kty in cls.JWK_KEY_CLS: key_cls = cls.JWK_KEY_CLS[_kty] - if isinstance(raw_key, key_cls.RAW_KEY_CLS): + if key_cls.validate_raw_key(raw_key): return key_cls.import_key(raw_key, options) key_cls = cls.JWK_KEY_CLS[kty] diff --git a/authlib/jose/rfc7518/oct_key.py b/authlib/jose/rfc7518/oct_key.py index 12c5415d..8c6537d7 100644 --- a/authlib/jose/rfc7518/oct_key.py +++ b/authlib/jose/rfc7518/oct_key.py @@ -45,6 +45,10 @@ def as_dict(self, is_private=False): tokens['kid'] = self.thumbprint() return tokens + @classmethod + def validate_raw_key(cls, key): + return isinstance(key, bytes) + @classmethod def import_key(cls, raw, options=None): """Import a key from bytes, string, or dict data.""" @@ -63,7 +67,7 @@ def import_key(cls, raw, options=None): return key @classmethod - def generate_key(cls, key_size=256, options=None, is_private=False): + def generate_key(cls, key_size=256, options=None, is_private=True): """Generate a ``OctKey`` with the given bit size.""" if not is_private: raise ValueError('oct key can not be generated as public') diff --git a/tests/core/test_jose/test_jwk.py b/tests/core/test_jose/test_jwk.py index 629e9ebb..171280b1 100644 --- a/tests/core/test_jose/test_jwk.py +++ b/tests/core/test_jose/test_jwk.py @@ -28,6 +28,21 @@ def test_import_oct_key(self): def test_invalid_oct_key(self): self.assertRaises(ValueError, OctKey.import_key, {}) + def test_generate_oct_key(self): + self.assertRaises(ValueError, OctKey.generate_key, 251) + + with self.assertRaises(ValueError) as cm: + OctKey.generate_key(is_private=False) + + self.assertEqual(str(cm.exception), 'oct key can not be generated as public') + + key = OctKey.generate_key() + self.assertIn('kid', key.as_dict()) + self.assertNotIn('use', key.as_dict()) + + key2 = OctKey.import_key(key, {'use': 'sig'}) + self.assertIn('use', key2.as_dict()) + class RSAKeyTest(BaseTest): def test_import_ssh_pem(self): @@ -131,6 +146,8 @@ def test_invalid_ec(self): self.assertRaises(ValueError, ECKey.import_key, {'kty': 'EC'}) def test_ec_key_generate(self): + self.assertRaises(ValueError, ECKey.generate_key, 'Invalid') + key1 = ECKey.generate_key('P-384', is_private=True) self.assertIn(b'PRIVATE', key1.as_pem(is_private=True)) self.assertIn(b'PUBLIC', key1.as_pem(is_private=False)) @@ -166,7 +183,7 @@ def test_import_okp_private_pem(self): self.assertEqual(obj['crv'], 'Ed25519') self.assertIn('d', obj) - def test_loads_okp_private_key(self): + def test_import_okp_private_dict(self): obj = { 'x': '11qYAYKxCrfVS_7TyWQHOg7hcvPapiMlrwIaaPcHURo', 'd': 'nWGxne_9WmC6hEr0kuwsxERJxWl7MmkZcDusAxyuf2A', @@ -190,8 +207,30 @@ def test_okp_key_generate_pem(self): class JWKTest(BaseTest): + def test_generate_keys(self): + key = JsonWebKey.generate_key(kty='oct', crv_or_size=256, is_private=True) + self.assertEqual(key['kty'], 'oct') + + key = JsonWebKey.generate_key(kty='EC', crv_or_size='P-256') + self.assertEqual(key['kty'], 'EC') + + key = JsonWebKey.generate_key(kty='RSA', crv_or_size=2048) + self.assertEqual(key['kty'], 'RSA') + + key = JsonWebKey.generate_key(kty='OKP', crv_or_size='Ed25519') + self.assertEqual(key['kty'], 'OKP') + def test_import_keys(self): - pass + rsa_pub_pem = read_file_path('rsa_public.pem') + self.assertRaises(ValueError, JsonWebKey.import_key, rsa_pub_pem, {'kty': 'EC'}) + + key = JsonWebKey.import_key(raw=rsa_pub_pem, options={'kty': 'RSA'}) + self.assertIn('e', dict(key)) + self.assertIn('n', dict(key)) + + key = JsonWebKey.import_key(raw=rsa_pub_pem) + self.assertIn('e', dict(key)) + self.assertIn('n', dict(key)) def test_thumbprint(self): # https://tools.ietf.org/html/rfc7638#section-3.1 From 1f6586bfaa565faa7bdc39def5ccac3590810bb5 Mon Sep 17 00:00:00 2001 From: Hsiaoming Yang Date: Sun, 15 Nov 2020 11:18:19 +0900 Subject: [PATCH 04/26] Add params to export JWK data --- authlib/jose/rfc7517/asymmetric_key.py | 10 ++++------ authlib/jose/rfc7517/base_key.py | 6 +++--- authlib/jose/rfc7517/key_set.py | 8 ++++---- authlib/jose/rfc7518/oct_key.py | 4 +++- 4 files changed, 14 insertions(+), 14 deletions(-) diff --git a/authlib/jose/rfc7517/asymmetric_key.py b/authlib/jose/rfc7517/asymmetric_key.py index 0901a453..83094bc9 100644 --- a/authlib/jose/rfc7517/asymmetric_key.py +++ b/authlib/jose/rfc7517/asymmetric_key.py @@ -85,7 +85,7 @@ def load_private_key(self): def load_public_key(self): raise NotImplementedError() - def as_dict(self, is_private=False): + def as_dict(self, is_private=False, **params): """Represent this key as a dict of the JSON Web Key.""" tokens = self.tokens if is_private and 'd' not in tokens: @@ -95,11 +95,14 @@ def as_dict(self, is_private=False): if 'd' in tokens and not is_private: # filter out private fields tokens = {k: tokens[k] for k in tokens if k in self.PUBLIC_KEY_FIELDS} + tokens['kty'] = self.kty if kid: tokens['kid'] = kid if not kid: tokens['kid'] = self.thumbprint() + + tokens.update(params) return tokens def as_key(self, is_private=False): @@ -108,11 +111,6 @@ def as_key(self, is_private=False): return self.get_private_key() return self.get_public_key() - def as_json(self, is_private=False): - """Represent this key as a JSON string.""" - obj = self.as_dict(is_private) - return json_dumps(obj) - def as_bytes(self, encoding=None, is_private=False, password=None): """Export key into PEM/DER format bytes. diff --git a/authlib/jose/rfc7517/base_key.py b/authlib/jose/rfc7517/base_key.py index 7c80284a..f8fe7b4a 100644 --- a/authlib/jose/rfc7517/base_key.py +++ b/authlib/jose/rfc7517/base_key.py @@ -81,12 +81,12 @@ def check_key_op(self, operation): if use != 'enc': raise InvalidUseError() - def as_dict(self, is_private=False): + def as_dict(self, is_private=False, **params): raise NotImplementedError() - def as_json(self, is_private=False): + def as_json(self, is_private=False, **params): """Represent this key as a JSON string.""" - obj = self.as_dict(is_private) + obj = self.as_dict(is_private, **params) return json_dumps(obj) def thumbprint(self): diff --git a/authlib/jose/rfc7517/key_set.py b/authlib/jose/rfc7517/key_set.py index d7cb2a88..e95c4d0c 100644 --- a/authlib/jose/rfc7517/key_set.py +++ b/authlib/jose/rfc7517/key_set.py @@ -7,13 +7,13 @@ class KeySet(object): def __init__(self, keys): self.keys = keys - def as_dict(self, is_private=False): + def as_dict(self, is_private=False, **params): """Represent this key as a dict of the JSON Web Key Set.""" - return {'keys': [k.as_dict(is_private) for k in self.keys]} + return {'keys': [k.as_dict(is_private, **params) for k in self.keys]} - def as_json(self, is_private=False): + def as_json(self, is_private=False, **params): """Represent this key set as a JSON string.""" - obj = self.as_dict(is_private) + obj = self.as_dict(is_private, **params) return json_dumps(obj) def find_by_kid(self, kid): diff --git a/authlib/jose/rfc7518/oct_key.py b/authlib/jose/rfc7518/oct_key.py index 8c6537d7..c2e16b14 100644 --- a/authlib/jose/rfc7518/oct_key.py +++ b/authlib/jose/rfc7518/oct_key.py @@ -39,10 +39,12 @@ def load_dict_key(self): k = to_unicode(urlsafe_b64encode(self.raw_key)) self._dict_data = {'kty': self.kty, 'k': k} - def as_dict(self, is_private=False): + def as_dict(self, is_private=False, **params): tokens = self.tokens if 'kid' not in tokens: tokens['kid'] = self.thumbprint() + + tokens.update(params) return tokens @classmethod From 11794ef7cd410a7cafe04c3db06f8fabf672c8c7 Mon Sep 17 00:00:00 2001 From: Hsiaoming Yang Date: Sun, 15 Nov 2020 11:32:49 +0900 Subject: [PATCH 05/26] Add tests for import key set --- authlib/jose/rfc7517/jwk.py | 1 + tests/core/test_jose/test_jwk.py | 18 +++++++++++++++++- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/authlib/jose/rfc7517/jwk.py b/authlib/jose/rfc7517/jwk.py index c0d47e62..dcb38b2c 100644 --- a/authlib/jose/rfc7517/jwk.py +++ b/authlib/jose/rfc7517/jwk.py @@ -52,6 +52,7 @@ def import_key_set(cls, raw): if isinstance(raw, dict) and 'keys' in raw: keys = raw.get('keys') return KeySet([cls.import_key(k) for k in keys]) + raise ValueError('Invalid key set format') def _transform_raw_key(raw): diff --git a/tests/core/test_jose/test_jwk.py b/tests/core/test_jose/test_jwk.py index 171280b1..80cb616c 100644 --- a/tests/core/test_jose/test_jwk.py +++ b/tests/core/test_jose/test_jwk.py @@ -1,7 +1,7 @@ import unittest from authlib.jose import JsonWebKey, KeySet from authlib.jose import OctKey, RSAKey, ECKey, OKPKey -from authlib.common.encoding import base64_to_int +from authlib.common.encoding import base64_to_int, json_dumps from tests.util import read_file_path @@ -232,6 +232,22 @@ def test_import_keys(self): self.assertIn('e', dict(key)) self.assertIn('n', dict(key)) + def test_import_key_set(self): + jwks_public = read_file_path('jwks_public.json') + key_set1 = JsonWebKey.import_key_set(jwks_public) + key1 = key_set1.find_by_kid('abc') + self.assertEqual(key1['e'], 'AQAB') + + key_set2 = JsonWebKey.import_key_set(jwks_public['keys']) + key2 = key_set2.find_by_kid('abc') + self.assertEqual(key2['e'], 'AQAB') + + key_set3 = JsonWebKey.import_key_set(json_dumps(jwks_public)) + key3 = key_set3.find_by_kid('abc') + self.assertEqual(key3['e'], 'AQAB') + + self.assertRaises(ValueError, JsonWebKey.import_key_set, 'invalid') + def test_thumbprint(self): # https://tools.ietf.org/html/rfc7638#section-3.1 data = read_file_path('thumbprint_example.json') From 9e8dce2c0ae9a8cf65040d6502529dadf0dd4a26 Mon Sep 17 00:00:00 2001 From: Hsiaoming Yang Date: Sun, 15 Nov 2020 18:00:02 +0900 Subject: [PATCH 06/26] split a OpenIDToken extension --- authlib/oauth2/rfc8628/models.py | 4 +- authlib/oidc/core/grants/__init__.py | 3 +- authlib/oidc/core/grants/code.py | 66 ++++++++++++++++++---------- docs/specs/oidc.rst | 4 ++ 4 files changed, 52 insertions(+), 25 deletions(-) diff --git a/authlib/oauth2/rfc8628/models.py b/authlib/oauth2/rfc8628/models.py index 3cad46d6..f00d4808 100644 --- a/authlib/oauth2/rfc8628/models.py +++ b/authlib/oauth2/rfc8628/models.py @@ -27,4 +27,6 @@ def get_user_code(self): def is_expired(self): expires_at = self.get('expires_at') - return expires_at < time.time() + if expires_at: + return expires_at < time.time() + return False diff --git a/authlib/oidc/core/grants/__init__.py b/authlib/oidc/core/grants/__init__.py index fb60bb72..8b4b0025 100644 --- a/authlib/oidc/core/grants/__init__.py +++ b/authlib/oidc/core/grants/__init__.py @@ -1,8 +1,9 @@ -from .code import OpenIDCode +from .code import OpenIDToken, OpenIDCode from .implicit import OpenIDImplicitGrant from .hybrid import OpenIDHybridGrant __all__ = [ + 'OpenIDToken', 'OpenIDCode', 'OpenIDImplicitGrant', 'OpenIDHybridGrant', diff --git a/authlib/oidc/core/grants/code.py b/authlib/oidc/core/grants/code.py index 61be7a4d..0e01bb23 100644 --- a/authlib/oidc/core/grants/code.py +++ b/authlib/oidc/core/grants/code.py @@ -19,28 +19,7 @@ log = logging.getLogger(__name__) -class OpenIDCode(object): - """An extension from OpenID Connect for "grant_type=code" request. - """ - def __init__(self, require_nonce=False): - self.require_nonce = require_nonce - - def exists_nonce(self, nonce, request): - """Check if the given nonce is existing in your database. Developers - MUST implement this method in subclass, e.g.:: - - def exists_nonce(self, nonce, request): - exists = AuthorizationCode.query.filter_by( - client_id=request.client_id, nonce=nonce - ).first() - return bool(exists) - - :param nonce: A string of "nonce" parameter in request - :param request: OAuth2Request instance - :return: Boolean - """ - raise NotImplementedError() - +class OpenIDToken(object): def get_jwt_config(self, grant): # pragma: no cover """Get the JWT configuration for OpenIDCode extension. The JWT configuration will be used to generate ``id_token``. Developers @@ -59,7 +38,7 @@ def get_jwt_config(self, grant): """ raise NotImplementedError() - def generate_user_info(self, user, scope): # pragma: no cover + def generate_user_info(self, user, scope): """Provide user information for the given scope. Developers MUST implement this method in subclass, e.g.:: @@ -103,6 +82,47 @@ def process_token(self, grant, token): token['id_token'] = id_token return token + def __call__(self, grant): + grant.register_hook('process_token', self.process_token) + + +class OpenIDCode(OpenIDToken): + """An extension from OpenID Connect for "grant_type=code" request. Developers + MUST implement the missing methods:: + + class MyOpenIDCode(OpenIDCode): + def get_jwt_config(self): + return {...} + + def exists_nonce(self, nonce, request): + return check_if_nonce_in_cache(request.client_id, nonce) + + def generate_user_info(self, user, scope): + return {...} + + The register this extension with AuthorizationCodeGrant:: + + authorization_server.register_grant(AuthorizationCodeGrant, extensions=[MyOpenIDCode()]) + """ + def __init__(self, require_nonce=False): + self.require_nonce = require_nonce + + def exists_nonce(self, nonce, request): + """Check if the given nonce is existing in your database. Developers + MUST implement this method in subclass, e.g.:: + + def exists_nonce(self, nonce, request): + exists = AuthorizationCode.query.filter_by( + client_id=request.client_id, nonce=nonce + ).first() + return bool(exists) + + :param nonce: A string of "nonce" parameter in request + :param request: OAuth2Request instance + :return: Boolean + """ + raise NotImplementedError() + def validate_openid_authorization_request(self, grant): validate_nonce(grant.request, self.exists_nonce, self.require_nonce) diff --git a/docs/specs/oidc.rst b/docs/specs/oidc.rst index d767dc60..7c4202ba 100644 --- a/docs/specs/oidc.rst +++ b/docs/specs/oidc.rst @@ -15,6 +15,10 @@ OpenID Grants .. module:: authlib.oidc.core.grants +.. autoclass:: OpenIDToken + :show-inheritance: + :members: + .. autoclass:: OpenIDCode :show-inheritance: :members: From b815d99571cfb7487f767e394a60456809e6c054 Mon Sep 17 00:00:00 2001 From: Jelle Besseling Date: Sun, 15 Nov 2020 17:23:43 +0100 Subject: [PATCH 07/26] Include correct parameters in django example --- docs/client/django.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/client/django.rst b/docs/client/django.rst index 115e3d46..41d9dc2a 100644 --- a/docs/client/django.rst +++ b/docs/client/django.rst @@ -110,7 +110,7 @@ it is also possible to use signal to listen for token updating:: from authlib.integrations.django_client import token_update @receiver(token_update) - def on_token_update(sender, token, refresh_token=None, access_token=None): + def on_token_update(sender, name, token, refresh_token=None, access_token=None, **kwargs): if refresh_token: item = OAuth2Token.find(name=name, refresh_token=refresh_token) elif access_token: From 6dfda77c37bd850cafb8f7fe4e7a93d2b7148efa Mon Sep 17 00:00:00 2001 From: ldng Date: Sat, 21 Nov 2020 02:24:26 +0000 Subject: [PATCH 08/26] Fix #297 : Accept extra auth-param attributes (#298) --- .../integrations/django_oauth2/resource_protector.py | 4 ++-- authlib/oauth2/rfc6750/errors.py | 5 ++++- authlib/oauth2/rfc6750/validator.py | 11 ++++++----- 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/authlib/integrations/django_oauth2/resource_protector.py b/authlib/integrations/django_oauth2/resource_protector.py index 472263c8..3d4b1326 100644 --- a/authlib/integrations/django_oauth2/resource_protector.py +++ b/authlib/integrations/django_oauth2/resource_protector.py @@ -51,9 +51,9 @@ def decorated(request, *args, **kwargs): class BearerTokenValidator(_BearerTokenValidator): - def __init__(self, token_model, realm=None): + def __init__(self, token_model, realm=None, extra_attributes=None): self.token_model = token_model - super(BearerTokenValidator, self).__init__(realm) + super(BearerTokenValidator, self).__init__(realm, extra_attributes) def authenticate_token(self, token_string): try: diff --git a/authlib/oauth2/rfc6750/errors.py b/authlib/oauth2/rfc6750/errors.py index 06d8f5f8..ead765c8 100644 --- a/authlib/oauth2/rfc6750/errors.py +++ b/authlib/oauth2/rfc6750/errors.py @@ -36,10 +36,11 @@ class InvalidTokenError(OAuth2Error): status_code = 401 def __init__(self, description=None, uri=None, status_code=None, - state=None, realm=None): + state=None, realm=None, extra_attributes=None): super(InvalidTokenError, self).__init__( description, uri, status_code, state) self.realm = realm + self.extra_attributes = extra_attributes def get_headers(self): """If the protected resource request does not include authentication @@ -55,6 +56,8 @@ def get_headers(self): extras = [] if self.realm: extras.append('realm="{}"'.format(self.realm)) + if self.extra_attributes: + extras.extend(['{}="{}"'.format(k, v) for k, v in self.extra_attributes.items()]) extras.append('error="{}"'.format(self.error)) error_description = self.get_error_description() extras.append('error_description="{}"'.format(error_description)) diff --git a/authlib/oauth2/rfc6750/validator.py b/authlib/oauth2/rfc6750/validator.py index 0461f828..6bb03af3 100644 --- a/authlib/oauth2/rfc6750/validator.py +++ b/authlib/oauth2/rfc6750/validator.py @@ -16,8 +16,9 @@ class BearerTokenValidator(object): TOKEN_TYPE = 'bearer' - def __init__(self, realm=None): + def __init__(self, realm=None, extra_attributes=None): self.realm = realm + self.extra_attributes = extra_attributes def authenticate_token(self, token_string): """A method to query token from database with the given token string. @@ -67,14 +68,14 @@ def scope_insufficient(self, token, scope, operator='AND'): def __call__(self, token_string, scope, request, scope_operator='AND'): if self.request_invalid(request): - raise InvalidRequestError() + raise InvalidRequestError(realm=self.realm, extra_attributes=self.extra_attributes) token = self.authenticate_token(token_string) if not token: - raise InvalidTokenError(realm=self.realm) + raise InvalidTokenError(realm=self.realm, extra_attributes=self.extra_attributes) if token.is_expired(): - raise InvalidTokenError(realm=self.realm) + raise InvalidTokenError(realm=self.realm, extra_attributes=self.extra_attributes) if token.is_revoked(): - raise InvalidTokenError(realm=self.realm) + raise InvalidTokenError(realm=self.realm, extra_attributes=self.extra_attributes) if self.scope_insufficient(token, scope, scope_operator): raise InsufficientScopeError() return token From 4210c8805169798cac2d5c4be270a98a74f90817 Mon Sep 17 00:00:00 2001 From: Hsiaoming Yang Date: Sun, 15 Nov 2020 20:30:27 +0900 Subject: [PATCH 09/26] Refactor device code flow, support other auth methods --- authlib/oauth2/rfc8628/device_code.py | 21 +++++-------------- authlib/oauth2/rfc8628/models.py | 6 ++++++ .../test_oauth2/test_device_code_grant.py | 11 +++------- 3 files changed, 14 insertions(+), 24 deletions(-) diff --git a/authlib/oauth2/rfc8628/device_code.py b/authlib/oauth2/rfc8628/device_code.py index af7c8c17..0952afe8 100644 --- a/authlib/oauth2/rfc8628/device_code.py +++ b/authlib/oauth2/rfc8628/device_code.py @@ -61,6 +61,7 @@ class DeviceCodeGrant(BaseGrant, TokenEndpointMixin): indication that the client should continue to poll. """ GRANT_TYPE = DEVICE_CODE_GRANT_TYPE + TOKEN_ENDPOINT_AUTH_METHODS = ['client_secret_basic', 'client_secret_post', 'none'] def validate_token_request(self): """After displaying instructions to the user, the client creates an @@ -94,18 +95,15 @@ def validate_token_request(self): if not device_code: raise InvalidRequestError('Missing "device_code" in payload') - if not self.request.client_id: - raise InvalidRequestError('Missing "client_id" in payload') + client = self.authenticate_token_endpoint_client() + if not client.check_grant_type(self.GRANT_TYPE): + raise UnauthorizedClientError() credential = self.query_device_credential(device_code) if not credential: raise InvalidRequestError('Invalid "device_code" in payload') - if credential.get_client_id() != self.request.client_id: - raise UnauthorizedClientError() - - client = self.authenticate_token_endpoint_client() - if not client.check_grant_type(self.GRANT_TYPE): + if credential.get_client_id() != client.get_client_id(): raise UnauthorizedClientError() user = self.validate_device_credential(credential) @@ -148,15 +146,6 @@ def validate_device_credential(self, credential): raise AuthorizationPendingError() - def authenticate_token_endpoint_client(self): - client = self.server.query_client(self.request.client_id) - if not client: - raise InvalidClientError() - self.server.send_signal( - 'after_authenticate_client', - client=client, grant=self) - return client - def query_device_credential(self, device_code): """Get device credential from previously savings via ``DeviceAuthorizationEndpoint``. Developers MUST implement it in subclass:: diff --git a/authlib/oauth2/rfc8628/models.py b/authlib/oauth2/rfc8628/models.py index f00d4808..0ec1e366 100644 --- a/authlib/oauth2/rfc8628/models.py +++ b/authlib/oauth2/rfc8628/models.py @@ -25,6 +25,12 @@ def get_scope(self): def get_user_code(self): return self['user_code'] + def get_nonce(self): + return self.get('nonce') + + def get_auth_time(self): + return self.get('auth_time') + def is_expired(self): expires_at = self.get('expires_at') if expires_at: diff --git a/tests/flask/test_oauth2/test_device_code_grant.py b/tests/flask/test_oauth2/test_device_code_grant.py index eb0b5454..60d4ceec 100644 --- a/tests/flask/test_oauth2/test_device_code_grant.py +++ b/tests/flask/test_oauth2/test_device_code_grant.py @@ -89,6 +89,7 @@ def prepare_data(self, grant_type=DeviceCodeGrant.GRANT_TYPE): 'redirect_uris': ['http://localhost/authorized'], 'scope': 'profile', 'grant_types': [grant_type], + 'token_endpoint_auth_method': 'none', }) db.session.add(client) db.session.commit() @@ -98,13 +99,7 @@ def test_invalid_request(self): self.prepare_data() rv = self.client.post('/oauth/token', data={ 'grant_type': DeviceCodeGrant.GRANT_TYPE, - }) - resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'invalid_request') - - rv = self.client.post('/oauth/token', data={ - 'grant_type': DeviceCodeGrant.GRANT_TYPE, - 'device_code': 'valid-device', + 'client_id': 'test', }) resp = json.loads(rv.data) self.assertEqual(resp['error'], 'invalid_request') @@ -125,7 +120,7 @@ def test_unauthorized_client(self): 'client_id': 'invalid', }) resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'unauthorized_client') + self.assertEqual(resp['error'], 'invalid_client') self.prepare_data(grant_type='password') rv = self.client.post('/oauth/token', data={ From 62d942e62eefc76397498d0e880bb4198ec9ff22 Mon Sep 17 00:00:00 2001 From: Hsiaoming Yang Date: Fri, 20 Nov 2020 21:18:08 +0900 Subject: [PATCH 10/26] Add WWW-Authenticate for resource protector Fixed https://github.com/lepture/authlib/issues/296 --- .../django_oauth2/resource_protector.py | 4 +-- authlib/oauth2/rfc6749/__init__.py | 3 +- authlib/oauth2/rfc6749/errors.py | 31 ++++++++++++++++--- authlib/oauth2/rfc6749/resource_protector.py | 28 ++++++++++++++--- authlib/oauth2/rfc6750/errors.py | 10 +++--- authlib/oauth2/rfc6750/validator.py | 9 ++---- authlib/oauth2/rfc7662/introspection.py | 1 - 7 files changed, 63 insertions(+), 23 deletions(-) diff --git a/authlib/integrations/django_oauth2/resource_protector.py b/authlib/integrations/django_oauth2/resource_protector.py index 3d4b1326..e6b4ea96 100644 --- a/authlib/integrations/django_oauth2/resource_protector.py +++ b/authlib/integrations/django_oauth2/resource_protector.py @@ -51,9 +51,9 @@ def decorated(request, *args, **kwargs): class BearerTokenValidator(_BearerTokenValidator): - def __init__(self, token_model, realm=None, extra_attributes=None): + def __init__(self, token_model, realm=None, **extra_attributes): self.token_model = token_model - super(BearerTokenValidator, self).__init__(realm, extra_attributes) + super(BearerTokenValidator, self).__init__(realm, **extra_attributes) def authenticate_token(self, token_string): try: diff --git a/authlib/oauth2/rfc6749/__init__.py b/authlib/oauth2/rfc6749/__init__.py index 2994f4f4..f4a0c808 100644 --- a/authlib/oauth2/rfc6749/__init__.py +++ b/authlib/oauth2/rfc6749/__init__.py @@ -31,7 +31,7 @@ from .models import ClientMixin, AuthorizationCodeMixin, TokenMixin from .authenticate_client import ClientAuthentication from .authorization_server import AuthorizationServer -from .resource_protector import ResourceProtector +from .resource_protector import ResourceProtector, TokenValidator from .token_endpoint import TokenEndpoint from .grants import ( BaseGrant, @@ -65,6 +65,7 @@ 'ClientAuthentication', 'AuthorizationServer', 'ResourceProtector', + 'TokenValidator', 'TokenEndpoint', 'BaseGrant', 'AuthorizationEndpointMixin', diff --git a/authlib/oauth2/rfc6749/errors.py b/authlib/oauth2/rfc6749/errors.py index deba33fb..a36d44b2 100644 --- a/authlib/oauth2/rfc6749/errors.py +++ b/authlib/oauth2/rfc6749/errors.py @@ -156,15 +156,38 @@ class AccessDeniedError(OAuth2Error): # -- below are extended errors -- # -class MissingAuthorizationError(OAuth2Error): +class ForbiddenError(OAuth2Error): + status_code = 401 + + def __init__(self, auth_type=None, realm=None): + super(ForbiddenError, self).__init__() + self.auth_type = auth_type + self.realm = realm + + def get_headers(self): + headers = super(ForbiddenError, self).get_headers() + if not self.auth_type: + return headers + + extras = [] + if self.realm: + extras.append('realm="{}"'.format(self.realm)) + extras.append('error="{}"'.format(self.error)) + error_description = self.description + extras.append('error_description="{}"'.format(error_description)) + headers.append( + ('WWW-Authenticate', f'{self.auth_type} ' + ', '.join(extras)) + ) + return headers + + +class MissingAuthorizationError(ForbiddenError): error = 'missing_authorization' description = 'Missing "Authorization" in headers.' - status_code = 401 -class UnsupportedTokenTypeError(OAuth2Error): +class UnsupportedTokenTypeError(ForbiddenError): error = 'unsupported_token_type' - status_code = 401 # -- exceptions for clients -- # diff --git a/authlib/oauth2/rfc6749/resource_protector.py b/authlib/oauth2/rfc6749/resource_protector.py index 40567950..a4d4f942 100644 --- a/authlib/oauth2/rfc6749/resource_protector.py +++ b/authlib/oauth2/rfc6749/resource_protector.py @@ -10,28 +10,48 @@ from .errors import MissingAuthorizationError, UnsupportedTokenTypeError +class TokenValidator(object): + """Base token validator class. Subclass this validator to register + into ResourceProtector instance. + """ + TOKEN_TYPE = 'bearer' + + def __init__(self, realm=None, **extra_attributes): + self.realm = realm + self.extra_attributes = extra_attributes + + def __call__(self, token_string, scope, request, scope_operator='AND'): + raise NotImplementedError() + + class ResourceProtector(object): def __init__(self): self._token_validators = {} + self._default_realm = None + self._default_auth_type = None + + def register_token_validator(self, validator: TokenValidator): + if not self._default_auth_type: + self._default_realm = validator.realm + self._default_auth_type = validator.TOKEN_TYPE - def register_token_validator(self, validator): if validator.TOKEN_TYPE not in self._token_validators: self._token_validators[validator.TOKEN_TYPE] = validator def validate_request(self, scope, request, scope_operator='AND'): auth = request.headers.get('Authorization') if not auth: - raise MissingAuthorizationError() + raise MissingAuthorizationError(self._default_auth_type, self._default_realm) # https://tools.ietf.org/html/rfc6749#section-7.1 token_parts = auth.split(None, 1) if len(token_parts) != 2: - raise UnsupportedTokenTypeError() + raise UnsupportedTokenTypeError(self._default_auth_type, self._default_realm) token_type, token_string = token_parts validator = self._token_validators.get(token_type.lower()) if not validator: - raise UnsupportedTokenTypeError() + raise UnsupportedTokenTypeError(self._default_auth_type, self._default_realm) return validator(token_string, scope, request, scope_operator) diff --git a/authlib/oauth2/rfc6750/errors.py b/authlib/oauth2/rfc6750/errors.py index ead765c8..26ca34ff 100644 --- a/authlib/oauth2/rfc6750/errors.py +++ b/authlib/oauth2/rfc6750/errors.py @@ -36,7 +36,7 @@ class InvalidTokenError(OAuth2Error): status_code = 401 def __init__(self, description=None, uri=None, status_code=None, - state=None, realm=None, extra_attributes=None): + state=None, realm=None, **extra_attributes): super(InvalidTokenError, self).__init__( description, uri, status_code, state) self.realm = realm @@ -55,12 +55,12 @@ def get_headers(self): extras = [] if self.realm: - extras.append('realm="{}"'.format(self.realm)) + extras.append(f'realm="{self.realm}"') if self.extra_attributes: - extras.extend(['{}="{}"'.format(k, v) for k, v in self.extra_attributes.items()]) - extras.append('error="{}"'.format(self.error)) + extras.extend([f'{k}="{self.extra_attributes[k]}"' for k in self.extra_attributes]) + extras.append(f'error="{self.error}"') error_description = self.get_error_description() - extras.append('error_description="{}"'.format(error_description)) + extras.append(f'error_description="{error_description}"') headers.append( ('WWW-Authenticate', 'Bearer ' + ', '.join(extras)) ) diff --git a/authlib/oauth2/rfc6750/validator.py b/authlib/oauth2/rfc6750/validator.py index 6bb03af3..aa4ac8f8 100644 --- a/authlib/oauth2/rfc6750/validator.py +++ b/authlib/oauth2/rfc6750/validator.py @@ -6,6 +6,7 @@ """ from ..rfc6749.util import scope_to_list +from ..rfc6749 import TokenValidator from .errors import ( InvalidRequestError, InvalidTokenError, @@ -13,13 +14,9 @@ ) -class BearerTokenValidator(object): +class BearerTokenValidator(TokenValidator): TOKEN_TYPE = 'bearer' - def __init__(self, realm=None, extra_attributes=None): - self.realm = realm - self.extra_attributes = extra_attributes - def authenticate_token(self, token_string): """A method to query token from database with the given token string. Developers MUST re-implement this method. For instance:: @@ -68,7 +65,7 @@ def scope_insufficient(self, token, scope, operator='AND'): def __call__(self, token_string, scope, request, scope_operator='AND'): if self.request_invalid(request): - raise InvalidRequestError(realm=self.realm, extra_attributes=self.extra_attributes) + raise InvalidRequestError() token = self.authenticate_token(token_string) if not token: raise InvalidTokenError(realm=self.realm, extra_attributes=self.extra_attributes) diff --git a/authlib/oauth2/rfc7662/introspection.py b/authlib/oauth2/rfc7662/introspection.py index f1e52027..cca15b83 100644 --- a/authlib/oauth2/rfc7662/introspection.py +++ b/authlib/oauth2/rfc7662/introspection.py @@ -1,4 +1,3 @@ -import time from authlib.consts import default_json_headers from ..rfc6749 import ( TokenEndpoint, From 6e567144246aeeee6d25586fad2c01e6251f2c5e Mon Sep 17 00:00:00 2001 From: Hsiaoming Yang Date: Sat, 21 Nov 2020 11:21:43 +0900 Subject: [PATCH 11/26] Add unsupported_response_type error related: https://github.com/lepture/authlib/issues/299 --- authlib/oauth2/rfc6749/__init__.py | 2 ++ .../oauth2/rfc6749/authorization_server.py | 7 +++--- authlib/oauth2/rfc6749/errors.py | 24 +++++++++++++++++-- .../rfc6749/grants/authorization_code.py | 3 ++- authlib/oauth2/rfc6749/grants/base.py | 9 ++++--- tests/flask/test_oauth2/test_oauth2_server.py | 2 +- .../test_oauth2/test_openid_hybrid_grant.py | 2 +- 7 files changed, 36 insertions(+), 13 deletions(-) diff --git a/authlib/oauth2/rfc6749/__init__.py b/authlib/oauth2/rfc6749/__init__.py index f4a0c808..0b88cc0b 100644 --- a/authlib/oauth2/rfc6749/__init__.py +++ b/authlib/oauth2/rfc6749/__init__.py @@ -20,6 +20,7 @@ InvalidScopeError, InsecureTransportError, UnauthorizedClientError, + UnsupportedResponseTypeError, UnsupportedGrantTypeError, UnsupportedTokenTypeError, # exceptions for clients @@ -55,6 +56,7 @@ 'InvalidScopeError', 'InsecureTransportError', 'UnauthorizedClientError', + 'UnsupportedResponseTypeError', 'UnsupportedGrantTypeError', 'UnsupportedTokenTypeError', 'MissingCodeException', diff --git a/authlib/oauth2/rfc6749/authorization_server.py b/authlib/oauth2/rfc6749/authorization_server.py index 55933676..23e072d6 100644 --- a/authlib/oauth2/rfc6749/authorization_server.py +++ b/authlib/oauth2/rfc6749/authorization_server.py @@ -3,6 +3,7 @@ OAuth2Error, InvalidGrantError, InvalidScopeError, + UnsupportedResponseTypeError, UnsupportedGrantTypeError, ) from .util import scope_to_list @@ -147,7 +148,7 @@ def get_authorization_grant(self, request): for (grant_cls, extensions) in self._authorization_grants: if grant_cls.check_authorization_endpoint(request): return _create_grant(grant_cls, extensions, request, self) - raise InvalidGrantError(f'Response type "{request.response_type}" is not supported') + raise UnsupportedResponseTypeError(request.response_type) def get_token_grant(self, request): """Find the token grant for current request. @@ -159,7 +160,7 @@ def get_token_grant(self, request): if grant_cls.check_token_endpoint(request) and \ request.method in grant_cls.TOKEN_ENDPOINT_HTTP_METHODS: return _create_grant(grant_cls, extensions, request, self) - raise UnsupportedGrantTypeError(f'Grant type {request.grant_type} is not supported') + raise UnsupportedGrantTypeError(request.grant_type) def create_endpoint_response(self, name, request=None): """Validate endpoint request and create endpoint response. @@ -189,7 +190,7 @@ def create_authorization_response(self, request=None, grant_user=None): request = self.create_oauth2_request(request) try: grant = self.get_authorization_grant(request) - except InvalidGrantError as error: + except UnsupportedResponseTypeError as error: return self.handle_error_response(request, error) try: diff --git a/authlib/oauth2/rfc6749/errors.py b/authlib/oauth2/rfc6749/errors.py index a36d44b2..53c2dff6 100644 --- a/authlib/oauth2/rfc6749/errors.py +++ b/authlib/oauth2/rfc6749/errors.py @@ -36,8 +36,8 @@ __all__ = [ 'OAuth2Error', 'InsecureTransportError', 'InvalidRequestError', - 'InvalidClientError', 'InvalidGrantError', - 'UnauthorizedClientError', 'UnsupportedGrantTypeError', + 'InvalidClientError', 'UnauthorizedClientError', 'InvalidGrantError', + 'UnsupportedResponseTypeError', 'UnsupportedGrantTypeError', 'InvalidScopeError', 'AccessDeniedError', 'MissingAuthorizationError', 'UnsupportedTokenTypeError', 'MissingCodeException', 'MissingTokenException', @@ -122,6 +122,19 @@ class UnauthorizedClientError(OAuth2Error): error = 'unauthorized_client' +class UnsupportedResponseTypeError(OAuth2Error): + """The authorization server does not support obtaining + an access token using this method.""" + error = 'unsupported_response_type' + + def __init__(self, response_type): + super(UnsupportedResponseTypeError, self).__init__() + self.response_type = response_type + + def get_error_description(self): + return f'response_type={self.response_type} is not supported' + + class UnsupportedGrantTypeError(OAuth2Error): """The authorization grant type is not supported by the authorization server. @@ -130,6 +143,13 @@ class UnsupportedGrantTypeError(OAuth2Error): """ error = 'unsupported_grant_type' + def __init__(self, grant_type): + super(UnsupportedGrantTypeError, self).__init__() + self.grant_type = grant_type + + def get_error_description(self): + return f'grant_type={self.grant_type} is not supported' + class InvalidScopeError(OAuth2Error): """The requested scope is invalid, unknown, malformed, or diff --git a/authlib/oauth2/rfc6749/grants/authorization_code.py b/authlib/oauth2/rfc6749/grants/authorization_code.py index c9f08e2b..19e765f2 100644 --- a/authlib/oauth2/rfc6749/grants/authorization_code.py +++ b/authlib/oauth2/rfc6749/grants/authorization_code.py @@ -208,7 +208,8 @@ def validate_token_request(self): log.debug('Validate token request of %r', client) if not client.check_grant_type(self.GRANT_TYPE): - raise UnauthorizedClientError() + raise UnauthorizedClientError( + f'The client is not authorized to use "grant_type={self.GRANT_TYPE}"') code = self.request.form.get('code') if code is None: diff --git a/authlib/oauth2/rfc6749/grants/base.py b/authlib/oauth2/rfc6749/grants/base.py index 5762a260..4412be92 100644 --- a/authlib/oauth2/rfc6749/grants/base.py +++ b/authlib/oauth2/rfc6749/grants/base.py @@ -130,16 +130,15 @@ def validate_authorization_redirect_uri(request, client): if request.redirect_uri: if not client.check_redirect_uri(request.redirect_uri): raise InvalidRequestError( - 'Redirect URI {!r} is not supported by client.'.format(request.redirect_uri), - state=request.state, - ) + f'Redirect URI {request.redirect_uri} is not supported by client.', + state=request.state) return request.redirect_uri else: redirect_uri = client.get_default_redirect_uri() if not redirect_uri: raise InvalidRequestError( - 'Missing "redirect_uri" in request.' - ) + 'Missing "redirect_uri" in request.', + state=request.state) return redirect_uri def validate_consent_request(self): diff --git a/tests/flask/test_oauth2/test_oauth2_server.py b/tests/flask/test_oauth2/test_oauth2_server.py index 37e55380..2328e609 100644 --- a/tests/flask/test_oauth2/test_oauth2_server.py +++ b/tests/flask/test_oauth2/test_oauth2_server.py @@ -70,7 +70,7 @@ def test_none_grant(self): '&client_id=implicit-client' ) rv = self.client.get(authorize_url) - self.assertIn(b'invalid_grant', rv.data) + self.assertIn(b'unsupported_response_type', rv.data) rv = self.client.post(authorize_url, data={'user_id': '1'}) self.assertNotEqual(rv.status, 200) diff --git a/tests/flask/test_oauth2/test_openid_hybrid_grant.py b/tests/flask/test_oauth2/test_openid_hybrid_grant.py index 4f274bd8..c9e4a6c9 100644 --- a/tests/flask/test_oauth2/test_openid_hybrid_grant.py +++ b/tests/flask/test_oauth2/test_openid_hybrid_grant.py @@ -130,7 +130,7 @@ def test_invalid_response_type(self): 'user_id': '1', }) resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'invalid_grant') + self.assertEqual(resp['error'], 'unsupported_response_type') def test_invalid_scope(self): self.prepare_data() From d27916e1fe9b45c1edea15ec38dd53167e9b1da6 Mon Sep 17 00:00:00 2001 From: Hsiaoming Yang Date: Sat, 21 Nov 2020 15:49:12 +0900 Subject: [PATCH 12/26] Refactor multiple scopes support on resource protector --- .../django_oauth2/resource_protector.py | 15 ++++++------ .../flask_oauth2/resource_protector.py | 22 ++++++++--------- authlib/oauth2/rfc6749/resource_protector.py | 6 ++--- authlib/oauth2/rfc6750/validator.py | 24 +++++++++---------- docs/django/2/resource-server.rst | 22 ++++++++++------- docs/flask/2/resource-server.rst | 22 +++++++++-------- .../test_oauth2/test_resource_protector.py | 17 ++----------- tests/flask/test_oauth2/test_oauth2_server.py | 15 ++---------- 8 files changed, 61 insertions(+), 82 deletions(-) diff --git a/authlib/integrations/django_oauth2/resource_protector.py b/authlib/integrations/django_oauth2/resource_protector.py index e6b4ea96..1dc36965 100644 --- a/authlib/integrations/django_oauth2/resource_protector.py +++ b/authlib/integrations/django_oauth2/resource_protector.py @@ -15,28 +15,27 @@ class ResourceProtector(_ResourceProtector): - def acquire_token(self, request, scope=None, operator='AND'): + def acquire_token(self, request, scopes=None): """A method to acquire current valid token with the given scope. :param request: Django HTTP request instance - :param scope: string or list of scope values - :param operator: value of "AND" or "OR" + :param scopes: a list of scope values :return: token object """ url = request.get_raw_uri() req = HttpRequest(request.method, url, request.body, request.headers) - if not callable(operator): - operator = operator.upper() - token = self.validate_request(scope, req, operator) + if isinstance(scopes, str): + scopes = [scopes] + token = self.validate_request(scopes, req) token_authenticated.send(sender=self.__class__, token=token) return token - def __call__(self, scope=None, operator='AND', optional=False): + def __call__(self, scopes=None, optional=False): def wrapper(f): @functools.wraps(f) def decorated(request, *args, **kwargs): try: - token = self.acquire_token(request, scope, operator) + token = self.acquire_token(request, scopes) request.oauth_token = token except MissingAuthorizationError as error: if optional: diff --git a/authlib/integrations/flask_oauth2/resource_protector.py b/authlib/integrations/flask_oauth2/resource_protector.py index 41535f35..7f7f6540 100644 --- a/authlib/integrations/flask_oauth2/resource_protector.py +++ b/authlib/integrations/flask_oauth2/resource_protector.py @@ -43,7 +43,7 @@ def token_revoked(self, token): # protect resource with require_oauth @app.route('/user') - @require_oauth('profile') + @require_oauth(['profile']) def user_profile(): user = User.query.get(current_token.user_id) return jsonify(user.to_dict()) @@ -61,11 +61,10 @@ def raise_error_response(self, error): headers = error.get_headers() raise_http_exception(status, body, headers) - def acquire_token(self, scope=None, operator='AND'): + def acquire_token(self, scopes=None): """A method to acquire current valid token with the given scope. - :param scope: string or list of scope values - :param operator: value of "AND" or "OR" + :param scopes: a list of scope values :return: token object """ request = HttpRequest( @@ -74,16 +73,17 @@ def acquire_token(self, scope=None, operator='AND'): _req.data, _req.headers ) - if not callable(operator): - operator = operator.upper() - token = self.validate_request(scope, request, operator) + # backward compatible + if isinstance(scopes, str): + scopes = [scopes] + token = self.validate_request(scopes, request) token_authenticated.send(self, token=token) ctx = _app_ctx_stack.top ctx.authlib_server_oauth2_token = token return token @contextmanager - def acquire(self, scope=None, operator='AND'): + def acquire(self, scopes=None): """The with statement of ``require_oauth``. Instead of using a decorator, you can use a with statement instead:: @@ -94,16 +94,16 @@ def user_api(): return jsonify(user.to_dict()) """ try: - yield self.acquire_token(scope, operator) + yield self.acquire_token(scopes) except OAuth2Error as error: self.raise_error_response(error) - def __call__(self, scope=None, operator='AND', optional=False): + def __call__(self, scopes=None, optional=False): def wrapper(f): @functools.wraps(f) def decorated(*args, **kwargs): try: - self.acquire_token(scope, operator) + self.acquire_token(scopes) except MissingAuthorizationError as error: if optional: return f(*args, **kwargs) diff --git a/authlib/oauth2/rfc6749/resource_protector.py b/authlib/oauth2/rfc6749/resource_protector.py index a4d4f942..b4fe667d 100644 --- a/authlib/oauth2/rfc6749/resource_protector.py +++ b/authlib/oauth2/rfc6749/resource_protector.py @@ -20,7 +20,7 @@ def __init__(self, realm=None, **extra_attributes): self.realm = realm self.extra_attributes = extra_attributes - def __call__(self, token_string, scope, request, scope_operator='AND'): + def __call__(self, token_string, scopes, request): raise NotImplementedError() @@ -38,7 +38,7 @@ def register_token_validator(self, validator: TokenValidator): if validator.TOKEN_TYPE not in self._token_validators: self._token_validators[validator.TOKEN_TYPE] = validator - def validate_request(self, scope, request, scope_operator='AND'): + def validate_request(self, scopes, request): auth = request.headers.get('Authorization') if not auth: raise MissingAuthorizationError(self._default_auth_type, self._default_realm) @@ -54,4 +54,4 @@ def validate_request(self, scope, request, scope_operator='AND'): if not validator: raise UnsupportedTokenTypeError(self._default_auth_type, self._default_realm) - return validator(token_string, scope, request, scope_operator) + return validator(token_string, scopes, request) diff --git a/authlib/oauth2/rfc6750/validator.py b/authlib/oauth2/rfc6750/validator.py index aa4ac8f8..d162edcf 100644 --- a/authlib/oauth2/rfc6750/validator.py +++ b/authlib/oauth2/rfc6750/validator.py @@ -45,8 +45,8 @@ def request_invalid(self, request): """ raise NotImplementedError() - def scope_insufficient(self, token, scope, operator='AND'): - if not scope: + def scope_insufficient(self, token, scopes): + if not scopes: return False token_scopes = scope_to_list(token.get_scope()) @@ -54,16 +54,14 @@ def scope_insufficient(self, token, scope, operator='AND'): return True token_scopes = set(token_scopes) - resource_scopes = set(scope_to_list(scope)) - if operator == 'AND': - return not token_scopes.issuperset(resource_scopes) - if operator == 'OR': - return not token_scopes & resource_scopes - if callable(operator): - return not operator(token_scopes, resource_scopes) - raise ValueError('Invalid operator value') - - def __call__(self, token_string, scope, request, scope_operator='AND'): + for scope in scopes: + resource_scopes = set(scope_to_list(scope)) + if token_scopes.issuperset(resource_scopes): + return False + + return True + + def __call__(self, token_string, scopes, request): if self.request_invalid(request): raise InvalidRequestError() token = self.authenticate_token(token_string) @@ -73,6 +71,6 @@ def __call__(self, token_string, scope, request, scope_operator='AND'): raise InvalidTokenError(realm=self.realm, extra_attributes=self.extra_attributes) if token.is_revoked(): raise InvalidTokenError(realm=self.realm, extra_attributes=self.extra_attributes) - if self.scope_insufficient(token, scope, scope_operator): + if self.scope_insufficient(token, scopes): raise InsufficientScopeError() return token diff --git a/docs/django/2/resource-server.rst b/docs/django/2/resource-server.rst index a1e32815..76d95b31 100644 --- a/docs/django/2/resource-server.rst +++ b/docs/django/2/resource-server.rst @@ -38,12 +38,14 @@ which is the instance of current in-use Token. Multiple Scopes --------------- -You can apply multiple scopes to one endpoint in **AND** and **OR** modes. -The default is **AND** mode. +.. versionchanged:: v1.0 + +You can apply multiple scopes to one endpoint in **AND**, **OR** and mix modes. +Here are some examples: .. code-block:: python - @require_oauth('profile email', 'AND') + @require_oauth(['profile email']) def user_profile(request): user = request.oauth_token.user return JsonResponse(dict(sub=user.pk, username=user.username)) @@ -52,24 +54,26 @@ It requires the token containing both ``profile`` and ``email`` scope. .. code-block:: python - @require_oauth('profile email', 'OR') + @require_oauth(['profile', 'email']) def user_profile(request): user = request.oauth_token.user return JsonResponse(dict(sub=user.pk, username=user.username)) It requires the token containing either ``profile`` or ``email`` scope. -It is also possible to pass a function as the scope operator. e.g.:: - def scope_operator(token_scopes, resource_scopes): - # this equals "AND" - return token_scopes.issuperset(resource_scopes) +It is also possible to mix **AND** and **OR** logic. e.g.:: - @require_oauth('profile email', scope_operator) + @app.route('/profile') + @require_oauth(['profile email', 'user']) def user_profile(request): user = request.oauth_token.user return JsonResponse(dict(sub=user.pk, username=user.username)) +This means if the token will be valid if: + +1. token contains both ``profile`` and ``email`` scope +2. or token contains ``user`` scope Optional ``require_oauth`` -------------------------- diff --git a/docs/flask/2/resource-server.rst b/docs/flask/2/resource-server.rst index 2bbbef7b..849cb255 100644 --- a/docs/flask/2/resource-server.rst +++ b/docs/flask/2/resource-server.rst @@ -73,13 +73,15 @@ If decorator is not your favorite, there is a ``with`` statement for you:: Multiple Scopes --------------- -You can apply multiple scopes to one endpoint in **AND** and **OR** modes. -The default is **AND** mode. +.. versionchanged:: v1.0 + +You can apply multiple scopes to one endpoint in **AND**, **OR** and mix modes. +Here are some examples: .. code-block:: python @app.route('/profile') - @require_oauth('profile email', 'AND') + @require_oauth(['profile email']) def user_profile(): user = current_token.user return jsonify(user) @@ -89,25 +91,25 @@ It requires the token containing both ``profile`` and ``email`` scope. .. code-block:: python @app.route('/profile') - @require_oauth('profile email', 'OR') + @require_oauth(['profile', 'email']') def user_profile(): user = current_token.user return jsonify(user) It requires the token containing either ``profile`` or ``email`` scope. -It is also possible to pass a function as the scope operator. e.g.:: - - def scope_operator(token_scopes, resource_scopes): - # this equals "AND" - return token_scopes.issuperset(resource_scopes) +It is also possible to mix **AND** and **OR** logic. e.g.:: @app.route('/profile') - @require_oauth('profile email', scope_operator) + @require_oauth(['profile email', 'user']) def user_profile(): user = current_token.user return jsonify(user) +This means if the token will be valid if: + +1. token contains both ``profile`` and ``email`` scope +2. or token contains ``user`` scope Optional ``require_oauth`` -------------------------- diff --git a/tests/django/test_oauth2/test_resource_protector.py b/tests/django/test_oauth2/test_resource_protector.py index 4312b895..bb18e821 100644 --- a/tests/django/test_oauth2/test_resource_protector.py +++ b/tests/django/test_oauth2/test_resource_protector.py @@ -110,12 +110,12 @@ def get_user_profile(request): def test_scope_operator(self): self.prepare_data() - @require_oauth('profile email', 'AND') + @require_oauth(['profile email']) def operator_and(request): user = request.oauth_token.user return JsonResponse(dict(sub=user.pk, username=user.username)) - @require_oauth('profile email', 'OR') + @require_oauth(['profile', 'email']) def operator_or(request): user = request.oauth_token.user return JsonResponse(dict(sub=user.pk, username=user.username)) @@ -130,16 +130,3 @@ def operator_or(request): self.assertEqual(resp.status_code, 200) data = json.loads(resp.content) self.assertEqual(data['username'], 'foo') - - def scope_operator(token_scopes, resource_scopes): - return 'profile' in token_scopes and 'email' not in token_scopes - - @require_oauth(operator=scope_operator) - def operator_func(request): - user = request.oauth_token.user - return JsonResponse(dict(sub=user.pk, username=user.username)) - - resp = operator_func(request) - self.assertEqual(resp.status_code, 200) - data = json.loads(resp.content) - self.assertEqual(data['username'], 'foo') diff --git a/tests/flask/test_oauth2/test_oauth2_server.py b/tests/flask/test_oauth2/test_oauth2_server.py index 2328e609..5c25954a 100644 --- a/tests/flask/test_oauth2/test_oauth2_server.py +++ b/tests/flask/test_oauth2/test_oauth2_server.py @@ -29,23 +29,15 @@ def public_info(): return jsonify(status='ok') @app.route('/operator-and') - @require_oauth('profile email', 'AND') + @require_oauth(['profile email']) def operator_and(): return jsonify(status='ok') @app.route('/operator-or') - @require_oauth('profile email', 'OR') + @require_oauth(['profile', 'email']) def operator_or(): return jsonify(status='ok') - def scope_operator(token_scopes, resource_scopes): - return 'profile' in token_scopes and 'email' not in token_scopes - - @app.route('/operator-func') - @require_oauth(operator=scope_operator) - def operator_func(): - return jsonify(status='ok') - @app.route('/acquire') def test_acquire(): with require_oauth.acquire('profile') as token: @@ -188,9 +180,6 @@ def test_scope_operator(self): rv = self.client.get('/operator-or', headers=headers) self.assertEqual(rv.status_code, 200) - rv = self.client.get('/operator-func', headers=headers) - self.assertEqual(rv.status_code, 200) - def test_optional_token(self): self.prepare_data() rv = self.client.get('/optional') From bff1c7225d9b8e76d19e866accb6e337db0b4477 Mon Sep 17 00:00:00 2001 From: Hsiaoming Yang Date: Sun, 22 Nov 2020 11:16:45 +0900 Subject: [PATCH 13/26] Refactor token validator and resource protector --- authlib/oauth2/rfc6749/resource_protector.py | 69 +++++++++++++++++++- authlib/oauth2/rfc6750/__init__.py | 4 +- authlib/oauth2/rfc6750/errors.py | 3 +- authlib/oauth2/rfc6750/validator.py | 41 +++--------- 4 files changed, 79 insertions(+), 38 deletions(-) diff --git a/authlib/oauth2/rfc6749/resource_protector.py b/authlib/oauth2/rfc6749/resource_protector.py index b4fe667d..79a39151 100644 --- a/authlib/oauth2/rfc6749/resource_protector.py +++ b/authlib/oauth2/rfc6749/resource_protector.py @@ -20,7 +20,47 @@ def __init__(self, realm=None, **extra_attributes): self.realm = realm self.extra_attributes = extra_attributes - def __call__(self, token_string, scopes, request): + def authenticate_token(self, token_string): + """A method to query token from database with the given token string. + Developers MUST re-implement this method. For instance:: + + def authenticate_token(self, token_string): + return get_token_from_database(token_string) + + :param token_string: A string to represent the access_token. + :return: token + """ + raise NotImplementedError() + + def validate_request(self, request): + """A method to validate if the HTTP request is valid or not. Developers MUST + re-implement this method. For instance, your server requires a + "X-Device-Version" in the header:: + + def validate_request(self, request): + if 'X-Device-Version' not in request.headers: + raise InvalidRequestError() + + Usually, you don't have to detect if the request is valid or not. If you have + to, you MUST re-implement this method. + + :param request: instance of HttpRequest + :raise: InvalidRequestError + """ + + def validate_token(self, token, scopes): + """A method to validate if the authorized token is valid, if it has the + permission on the given scopes. Developers MUST re-implement this method. + e.g, check if token is expired, revoked:: + + def validate_token(self, token, scopes): + if not token: + raise InvalidTokenError() + if token.is_expired() or token.is_revoked(): + raise InvalidTokenError() + if not match_token_scopes(token, scopes): + raise InsufficientScopeError() + """ raise NotImplementedError() @@ -31,6 +71,9 @@ def __init__(self): self._default_auth_type = None def register_token_validator(self, validator: TokenValidator): + """Register a token validator for a given Authorization type. + Authlib has a built-in BearerTokenValidator per rfc6750. + """ if not self._default_auth_type: self._default_realm = validator.realm self._default_auth_type = validator.TOKEN_TYPE @@ -38,7 +81,19 @@ def register_token_validator(self, validator: TokenValidator): if validator.TOKEN_TYPE not in self._token_validators: self._token_validators[validator.TOKEN_TYPE] = validator - def validate_request(self, scopes, request): + def parse_request_authorization(self, request): + """Parse the token and token validator from request Authorization header. + Here is an example of Authorization header:: + + Authorization: Bearer a-token-string + + This method will parse this header, if it can find the validator for + ``Bearer``, it will return the validator and ``a-token-string``. + + :return: validator, token_string + :raise: MissingAuthorizationError + :raise: UnsupportedTokenTypeError + """ auth = request.headers.get('Authorization') if not auth: raise MissingAuthorizationError(self._default_auth_type, self._default_realm) @@ -54,4 +109,12 @@ def validate_request(self, scopes, request): if not validator: raise UnsupportedTokenTypeError(self._default_auth_type, self._default_realm) - return validator(token_string, scopes, request) + return validator, token_string + + def validate_request(self, scopes, request): + """Validate the request and return a token.""" + validator, token_string = self.parse_request_authorization(request) + validator.validate_request(request) + token = validator.authenticate_token(token_string) + validator.validate_token(token, scopes) + return token diff --git a/authlib/oauth2/rfc6750/__init__.py b/authlib/oauth2/rfc6750/__init__.py index 4ad02126..6539f4cb 100644 --- a/authlib/oauth2/rfc6750/__init__.py +++ b/authlib/oauth2/rfc6750/__init__.py @@ -9,14 +9,14 @@ https://tools.ietf.org/html/rfc6750 """ -from .errors import InvalidRequestError, InvalidTokenError, InsufficientScopeError +from .errors import InvalidTokenError, InsufficientScopeError from .parameters import add_bearer_token from .wrappers import BearerToken from .validator import BearerTokenValidator __all__ = [ - 'InvalidRequestError', 'InvalidTokenError', 'InsufficientScopeError', + 'InvalidTokenError', 'InsufficientScopeError', 'add_bearer_token', 'BearerToken', 'BearerTokenValidator', diff --git a/authlib/oauth2/rfc6750/errors.py b/authlib/oauth2/rfc6750/errors.py index 26ca34ff..3ce462a3 100644 --- a/authlib/oauth2/rfc6750/errors.py +++ b/authlib/oauth2/rfc6750/errors.py @@ -12,10 +12,9 @@ :copyright: (c) 2017 by Hsiaoming Yang. """ from ..base import OAuth2Error -from ..rfc6749.errors import InvalidRequestError __all__ = [ - 'InvalidRequestError', 'InvalidTokenError', 'InsufficientScopeError' + 'InvalidTokenError', 'InsufficientScopeError' ] diff --git a/authlib/oauth2/rfc6750/validator.py b/authlib/oauth2/rfc6750/validator.py index d162edcf..eff26524 100644 --- a/authlib/oauth2/rfc6750/validator.py +++ b/authlib/oauth2/rfc6750/validator.py @@ -8,7 +8,6 @@ from ..rfc6749.util import scope_to_list from ..rfc6749 import TokenValidator from .errors import ( - InvalidRequestError, InvalidTokenError, InsufficientScopeError ) @@ -29,21 +28,16 @@ def authenticate_token(self, token_string): """ raise NotImplementedError() - def request_invalid(self, request): - """Check if the HTTP request is valid or not. Developers MUST - re-implement this method. For instance, your server requires a - "X-Device-Version" in the header:: - - def request_invalid(self, request): - return 'X-Device-Version' in request.headers - - Usually, you don't have to detect if the request is valid or not, - you can just return a ``False``. - - :param request: instance of HttpRequest - :return: Boolean - """ - raise NotImplementedError() + def validate_token(self, token, scopes): + """Check if token is active and matches the requested scopes.""" + if not token: + raise InvalidTokenError(realm=self.realm, extra_attributes=self.extra_attributes) + if token.is_expired(): + raise InvalidTokenError(realm=self.realm, extra_attributes=self.extra_attributes) + if token.is_revoked(): + raise InvalidTokenError(realm=self.realm, extra_attributes=self.extra_attributes) + if self.scope_insufficient(token, scopes): + raise InsufficientScopeError() def scope_insufficient(self, token, scopes): if not scopes: @@ -58,19 +52,4 @@ def scope_insufficient(self, token, scopes): resource_scopes = set(scope_to_list(scope)) if token_scopes.issuperset(resource_scopes): return False - return True - - def __call__(self, token_string, scopes, request): - if self.request_invalid(request): - raise InvalidRequestError() - token = self.authenticate_token(token_string) - if not token: - raise InvalidTokenError(realm=self.realm, extra_attributes=self.extra_attributes) - if token.is_expired(): - raise InvalidTokenError(realm=self.realm, extra_attributes=self.extra_attributes) - if token.is_revoked(): - raise InvalidTokenError(realm=self.realm, extra_attributes=self.extra_attributes) - if self.scope_insufficient(token, scopes): - raise InsufficientScopeError() - return token From ffeeaa9fd7b5bc4ea7cae9fcf0c2ad9d7f5cf22a Mon Sep 17 00:00:00 2001 From: Hsiaoming Yang Date: Sun, 22 Nov 2020 11:58:28 +0900 Subject: [PATCH 14/26] split get_token_validator method on resource protector --- authlib/oauth2/rfc6749/resource_protector.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/authlib/oauth2/rfc6749/resource_protector.py b/authlib/oauth2/rfc6749/resource_protector.py index 79a39151..3dea497c 100644 --- a/authlib/oauth2/rfc6749/resource_protector.py +++ b/authlib/oauth2/rfc6749/resource_protector.py @@ -81,6 +81,13 @@ def register_token_validator(self, validator: TokenValidator): if validator.TOKEN_TYPE not in self._token_validators: self._token_validators[validator.TOKEN_TYPE] = validator + def get_token_validator(self, token_type): + """Get token validator from registry for the given token type.""" + validator = self._token_validators.get(token_type.lower()) + if not validator: + raise UnsupportedTokenTypeError(self._default_auth_type, self._default_realm) + return validator + def parse_request_authorization(self, request): """Parse the token and token validator from request Authorization header. Here is an example of Authorization header:: @@ -104,11 +111,7 @@ def parse_request_authorization(self, request): raise UnsupportedTokenTypeError(self._default_auth_type, self._default_realm) token_type, token_string = token_parts - - validator = self._token_validators.get(token_type.lower()) - if not validator: - raise UnsupportedTokenTypeError(self._default_auth_type, self._default_realm) - + validator = self.get_token_validator(token_type) return validator, token_string def validate_request(self, scopes, request): From 1e641c6116bacfded9e3a4976bec6438845f1b23 Mon Sep 17 00:00:00 2001 From: Hsiaoming Yang Date: Tue, 24 Nov 2020 21:34:57 +0900 Subject: [PATCH 15/26] Use setup.cfg for metadata --- authlib/jose/__init__.py | 8 ++++---- setup.cfg | 23 +++++++++++++++++++++++ setup.py | 31 ------------------------------- 3 files changed, 27 insertions(+), 35 deletions(-) diff --git a/authlib/jose/__init__.py b/authlib/jose/__init__.py index ec6cfb4c..208292bc 100644 --- a/authlib/jose/__init__.py +++ b/authlib/jose/__init__.py @@ -50,13 +50,13 @@ __all__ = [ 'JoseError', - 'JWS', 'JsonWebSignature', 'JWSAlgorithm', 'JWSHeader', 'JWSObject', - 'JWE', 'JsonWebEncryption', 'JWEAlgorithm', 'JWEEncAlgorithm', 'JWEZipAlgorithm', + 'JsonWebSignature', 'JWSAlgorithm', 'JWSHeader', 'JWSObject', + 'JsonWebEncryption', 'JWEAlgorithm', 'JWEEncAlgorithm', 'JWEZipAlgorithm', - 'JWK', 'JsonWebKey', 'Key', 'KeySet', + 'JsonWebKey', 'Key', 'KeySet', 'OctKey', 'RSAKey', 'ECKey', 'OKPKey', - 'JWT', 'JsonWebToken', 'BaseClaims', 'JWTClaims', + 'JsonWebToken', 'BaseClaims', 'JWTClaims', 'jwt', ] diff --git a/setup.cfg b/setup.cfg index 5cb2bc23..fc49e748 100644 --- a/setup.cfg +++ b/setup.cfg @@ -2,7 +2,30 @@ universal = 1 [metadata] +author = Hsiaoming Yang +author_email = me@lepture.com license_file = LICENSE +description = The ultimate Python library in building OAuth and OpenID Connect servers and clients. +long_description = file: README.rst +long_description_content_type = text/x-rst +classifiers = + Development Status :: 4 - Beta + Environment :: Console + Environment :: Web Environment + Framework :: Flask + Framework :: Django + Intended Audience :: Developers + License :: OSI Approved :: BSD License + Operating System :: OS Independent + Programming Language :: Python + Programming Language :: Python :: 3 + Programming Language :: Python :: 3.6 + Programming Language :: Python :: 3.7 + Programming Language :: Python :: 3.8 + Programming Language :: Python :: 3.9 + Topic :: Internet :: WWW/HTTP :: Dynamic Content + Topic :: Internet :: WWW/HTTP :: WSGI :: Application + [check-manifest] ignore = diff --git a/setup.py b/setup.py index 0d229b69..b2beba1c 100755 --- a/setup.py +++ b/setup.py @@ -5,11 +5,6 @@ from setuptools import setup, find_packages from authlib.consts import version, homepage - -with open('README.rst') as f: - readme = f.read() - - client_requires = ['requests'] crypto_requires = ['cryptography>=3.2,<4'] @@ -17,18 +12,11 @@ setup( name='Authlib', version=version, - author='Hsiaoming Yang', - author_email='me@lepture.com', url=homepage, packages=find_packages(include=('authlib', 'authlib.*')), - description=( - 'The ultimate Python library in building OAuth and ' - 'OpenID Connect servers.' - ), zip_safe=False, include_package_data=True, platforms='any', - long_description=readme, license='BSD-3-Clause', install_requires=crypto_requires, extras_require={ @@ -42,23 +30,4 @@ 'Blog': 'https://blog.authlib.org/', 'Donate': 'https://lepture.com/donate', }, - classifiers=[ - 'Development Status :: 4 - Beta', - 'Environment :: Console', - 'Environment :: Web Environment', - 'Framework :: Flask', - 'Framework :: Django', - 'Intended Audience :: Developers', - 'License :: OSI Approved :: BSD License', - 'Operating System :: OS Independent', - 'Programming Language :: Python', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.6', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', - 'Topic :: Internet :: WWW/HTTP :: Dynamic Content', - 'Topic :: Internet :: WWW/HTTP :: WSGI :: Application', - 'Topic :: Software Development :: Libraries :: Python Modules', - ] ) From 3681c4656087e553ed5ac68993fa9c872566bf88 Mon Sep 17 00:00:00 2001 From: Hsiaoming Yang Date: Tue, 24 Nov 2020 21:35:35 +0900 Subject: [PATCH 16/26] Refactor device authorization endpoint 1. Device authorization endpoint Accept many client auth methods 2. validate scope with client --- .../oauth2/rfc6749/authorization_server.py | 2 +- authlib/oauth2/rfc6749/grants/base.py | 3 +- authlib/oauth2/rfc6749/token_endpoint.py | 5 +--- authlib/oauth2/rfc8628/endpoint.py | 29 ++++++++++++++++--- 4 files changed, 29 insertions(+), 10 deletions(-) diff --git a/authlib/oauth2/rfc6749/authorization_server.py b/authlib/oauth2/rfc6749/authorization_server.py index 23e072d6..b1d1560a 100644 --- a/authlib/oauth2/rfc6749/authorization_server.py +++ b/authlib/oauth2/rfc6749/authorization_server.py @@ -101,7 +101,7 @@ def handle_response(self, status, body, headers): """Return HTTP response. Framework MUST implement this function.""" raise NotImplementedError() - def validate_requested_scope(self, scope, state=None): + def validate_requested_scope(self, scope, client, state=None): """Validate if requested scope is supported by Authorization Server. Developers CAN re-write this method to meet your needs. """ diff --git a/authlib/oauth2/rfc6749/grants/base.py b/authlib/oauth2/rfc6749/grants/base.py index 4412be92..75fb5f2e 100644 --- a/authlib/oauth2/rfc6749/grants/base.py +++ b/authlib/oauth2/rfc6749/grants/base.py @@ -85,8 +85,9 @@ def save_token(self, token): def validate_requested_scope(self): """Validate if requested scope is supported by Authorization Server.""" scope = self.request.scope + client = self.request.client state = self.request.state - return self.server.validate_requested_scope(scope, state) + return self.server.validate_requested_scope(scope, client, state) def register_hook(self, hook_type, hook): if hook_type not in self._hooks: diff --git a/authlib/oauth2/rfc6749/token_endpoint.py b/authlib/oauth2/rfc6749/token_endpoint.py index a5c6e5ff..5d001348 100644 --- a/authlib/oauth2/rfc6749/token_endpoint.py +++ b/authlib/oauth2/rfc6749/token_endpoint.py @@ -20,10 +20,7 @@ def create_endpoint_request(self, request): def authenticate_endpoint_client(self, request): """Authentication client for endpoint with ``CLIENT_AUTH_METHODS``. """ - client = self.server.authenticate_client( - request=request, - methods=self.CLIENT_AUTH_METHODS, - ) + client = self.server.authenticate_client(request, self.CLIENT_AUTH_METHODS) request.client = client return client diff --git a/authlib/oauth2/rfc8628/endpoint.py b/authlib/oauth2/rfc8628/endpoint.py index fda5f1a3..2e820085 100644 --- a/authlib/oauth2/rfc8628/endpoint.py +++ b/authlib/oauth2/rfc8628/endpoint.py @@ -1,7 +1,6 @@ from authlib.consts import default_json_headers from authlib.common.security import generate_token from authlib.common.urls import add_params_to_uri -from ..rfc6749.errors import InvalidRequestError class DeviceAuthorizationEndpoint(object): @@ -46,6 +45,7 @@ class DeviceAuthorizationEndpoint(object): """ ENDPOINT_NAME = 'device_authorization' + CLIENT_AUTH_METHODS = ['client_secret_basic', 'client_secret_post', 'none'] #: customize "user_code" type, string or digital USER_CODE_TYPE = 'string' @@ -68,12 +68,33 @@ def __call__(self, request): def create_endpoint_request(self, request): return self.server.create_oauth2_request(request) + def authenticate_client(self, request): + """client_id is REQUIRED **if the client is not** authenticating with the + authorization server as described in Section 3.2.1. of [RFC6749]. + + This means the endpoint support "none" authentication method. In this case, + this endpoint's auth methods are: + + - client_secret_basic + - client_secret_post + - none + + Developers change the value of ``CLIENT_AUTH_METHODS`` in subclass. For + instance:: + + class MyDeviceAuthorizationEndpoint(DeviceAuthorizationEndpoint): + # only support ``client_secret_basic`` auth method + CLIENT_AUTH_METHODS = ['client_secret_basic'] + """ + client = self.server.authenticate_client(request, self.CLIENT_AUTH_METHODS) + request.client = client + return client + def create_endpoint_response(self, request): # https://tools.ietf.org/html/rfc8628#section-3.1 - if not request.client_id: - raise InvalidRequestError('Missing "client_id" in payload') - self.server.validate_requested_scope(request.scope) + client = self.authenticate_client(request) + self.server.validate_requested_scope(request.scope, client) device_code = self.generate_device_code() user_code = self.generate_user_code() From 20b994cf45944e8c754035f92ab58f9a640c6a2d Mon Sep 17 00:00:00 2001 From: Hsiaoming Yang Date: Tue, 24 Nov 2020 21:54:28 +0900 Subject: [PATCH 17/26] Refactor, move get_allowed_scope to BearerToken --- authlib/oauth2/rfc6749/authorization_server.py | 3 +-- authlib/oauth2/rfc6749/grants/base.py | 11 ++--------- authlib/oauth2/rfc6750/wrappers.py | 7 +++++++ authlib/oauth2/rfc8628/endpoint.py | 4 ++-- 4 files changed, 12 insertions(+), 13 deletions(-) diff --git a/authlib/oauth2/rfc6749/authorization_server.py b/authlib/oauth2/rfc6749/authorization_server.py index b1d1560a..770efc34 100644 --- a/authlib/oauth2/rfc6749/authorization_server.py +++ b/authlib/oauth2/rfc6749/authorization_server.py @@ -1,7 +1,6 @@ from .authenticate_client import ClientAuthentication from .errors import ( OAuth2Error, - InvalidGrantError, InvalidScopeError, UnsupportedResponseTypeError, UnsupportedGrantTypeError, @@ -101,7 +100,7 @@ def handle_response(self, status, body, headers): """Return HTTP response. Framework MUST implement this function.""" raise NotImplementedError() - def validate_requested_scope(self, scope, client, state=None): + def validate_requested_scope(self, scope, state=None): """Validate if requested scope is supported by Authorization Server. Developers CAN re-write this method to meet your needs. """ diff --git a/authlib/oauth2/rfc6749/grants/base.py b/authlib/oauth2/rfc6749/grants/base.py index 75fb5f2e..7659ba07 100644 --- a/authlib/oauth2/rfc6749/grants/base.py +++ b/authlib/oauth2/rfc6749/grants/base.py @@ -33,16 +33,10 @@ def client(self): def generate_token(self, user=None, scope=None, grant_type=None, expires_in=None, include_refresh_token=True): - if grant_type is None: grant_type = self.GRANT_TYPE - - client = self.request.client - if scope is not None: - scope = client.get_allowed_scope(scope) - return self.server.generate_token( - client=client, + client=self.request.client, grant_type=grant_type, user=user, scope=scope, @@ -85,9 +79,8 @@ def save_token(self, token): def validate_requested_scope(self): """Validate if requested scope is supported by Authorization Server.""" scope = self.request.scope - client = self.request.client state = self.request.state - return self.server.validate_requested_scope(scope, client, state) + return self.server.validate_requested_scope(scope, state) def register_hook(self, hook_type, hook): if hook_type not in self._hooks: diff --git a/authlib/oauth2/rfc6750/wrappers.py b/authlib/oauth2/rfc6750/wrappers.py index 9e2c226c..4b267dc3 100644 --- a/authlib/oauth2/rfc6750/wrappers.py +++ b/authlib/oauth2/rfc6750/wrappers.py @@ -78,8 +78,15 @@ def _get_expires_in(self, client, grant_type): expires_in = self.DEFAULT_EXPIRES_IN return expires_in + @staticmethod + def get_allowed_scope(client, scope): + if scope: + scope = client.get_allowed_scope(scope) + return scope + def __call__(self, client, grant_type, user=None, scope=None, expires_in=None, include_refresh_token=True): + scope = self.get_allowed_scope(client, scope) access_token = self.access_token_generator(client, grant_type, user, scope) if expires_in is None: expires_in = self._get_expires_in(client, grant_type) diff --git a/authlib/oauth2/rfc8628/endpoint.py b/authlib/oauth2/rfc8628/endpoint.py index 2e820085..06e9f3fd 100644 --- a/authlib/oauth2/rfc8628/endpoint.py +++ b/authlib/oauth2/rfc8628/endpoint.py @@ -93,8 +93,8 @@ class MyDeviceAuthorizationEndpoint(DeviceAuthorizationEndpoint): def create_endpoint_response(self, request): # https://tools.ietf.org/html/rfc8628#section-3.1 - client = self.authenticate_client(request) - self.server.validate_requested_scope(request.scope, client) + self.authenticate_client(request) + self.server.validate_requested_scope(request.scope) device_code = self.generate_device_code() user_code = self.generate_user_code() From b1d14c0f47f7095397ed78d922008c202c2b601b Mon Sep 17 00:00:00 2001 From: Hsiaoming Yang Date: Tue, 24 Nov 2020 22:28:06 +0900 Subject: [PATCH 18/26] refactor client model Use ``.check_endpoint_auth_method`` instead of ``check_token_endpoint_auth_method`` to support more situations --- .../integrations/sqla_oauth2/client_mixin.py | 7 ++- authlib/oauth2/rfc6749/authenticate_client.py | 47 +++++-------------- .../oauth2/rfc6749/authorization_server.py | 4 +- authlib/oauth2/rfc6749/grants/base.py | 3 +- authlib/oauth2/rfc6749/models.py | 20 ++++++-- authlib/oauth2/rfc6749/token_endpoint.py | 3 +- authlib/oauth2/rfc7523/client.py | 2 +- authlib/oauth2/rfc8628/device_code.py | 1 - authlib/oauth2/rfc8628/endpoint.py | 3 +- docs/django/2/authorization-server.rst | 12 ++++- tests/django/test_oauth2/models.py | 6 ++- .../test_oauth2/test_device_code_grant.py | 11 ++++- 12 files changed, 66 insertions(+), 53 deletions(-) diff --git a/authlib/integrations/sqla_oauth2/client_mixin.py b/authlib/integrations/sqla_oauth2/client_mixin.py index b88b4ad8..c8ea2512 100644 --- a/authlib/integrations/sqla_oauth2/client_mixin.py +++ b/authlib/integrations/sqla_oauth2/client_mixin.py @@ -124,8 +124,11 @@ def has_client_secret(self): def check_client_secret(self, client_secret): return self.client_secret == client_secret - def check_token_endpoint_auth_method(self, method): - return self.token_endpoint_auth_method == method + def check_endpoint_auth_method(self, method, endpoint): + if endpoint == 'token': + return self.token_endpoint_auth_method == method + # TODO + return True def check_response_type(self, response_type): return response_type in self.response_types diff --git a/authlib/oauth2/rfc6749/authenticate_client.py b/authlib/oauth2/rfc6749/authenticate_client.py index d21289a1..c07bb282 100644 --- a/authlib/oauth2/rfc6749/authenticate_client.py +++ b/authlib/oauth2/rfc6749/authenticate_client.py @@ -36,11 +36,11 @@ def __init__(self, query_client): def register(self, method, func): self._methods[method] = func - def authenticate(self, request, methods): + def authenticate(self, request, methods, endpoint): for method in methods: func = self._methods[method] client = func(self.query_client, request) - if client: + if client and client.check_endpoint_auth_method(method, endpoint): request.auth_method = method return client @@ -48,8 +48,8 @@ def authenticate(self, request, methods): raise InvalidClientError(state=request.state, status_code=401) raise InvalidClientError(state=request.state) - def __call__(self, request, methods): - return self.authenticate(request, methods) + def __call__(self, request, methods, endpoint='token'): + return self.authenticate(request, methods, endpoint) def authenticate_client_secret_basic(query_client, request): @@ -59,17 +59,10 @@ def authenticate_client_secret_basic(query_client, request): client_id, client_secret = extract_basic_authorization(request.headers) if client_id and client_secret: client = _validate_client(query_client, client_id, request.state, 401) - if client.check_token_endpoint_auth_method('client_secret_basic') \ - and client.check_client_secret(client_secret): - log.debug( - 'Authenticate %s via "client_secret_basic" ' - 'success', client_id - ) + if client.check_client_secret(client_secret): + log.debug(f'Authenticate {client_id} via "client_secret_basic" success') return client - log.debug( - 'Authenticate %s via "client_secret_basic" ' - 'failed', client_id - ) + log.debug(f'Authenticate {client_id} via "client_secret_basic" failed') def authenticate_client_secret_post(query_client, request): @@ -81,17 +74,10 @@ def authenticate_client_secret_post(query_client, request): client_secret = data.get('client_secret') if client_id and client_secret: client = _validate_client(query_client, client_id, request.state) - if client.check_token_endpoint_auth_method('client_secret_post') \ - and client.check_client_secret(client_secret): - log.debug( - 'Authenticate %s via "client_secret_post" ' - 'success', client_id - ) + if client.check_client_secret(client_secret): + log.debug(f'Authenticate {client_id} via "client_secret_post" success') return client - log.debug( - 'Authenticate %s via "client_secret_post" ' - 'failed', client_id - ) + log.debug(f'Authenticate {client_id} via "client_secret_post" failed') def authenticate_none(query_client, request): @@ -101,16 +87,9 @@ def authenticate_none(query_client, request): client_id = request.client_id if client_id and 'client_secret' not in request.data: client = _validate_client(query_client, client_id, request.state) - if client.check_token_endpoint_auth_method('none'): - log.debug( - 'Authenticate %s via "none" ' - 'success', client_id - ) - return client - log.debug( - 'Authenticate {} via "none" ' - 'failed'.format(client_id) - ) + log.debug(f'Authenticate {client_id} via "none" success') + return client + log.debug(f'Authenticate {client_id} via "none" failed') def _validate_client(query_client, client_id, state=None, status_code=400): diff --git a/authlib/oauth2/rfc6749/authorization_server.py b/authlib/oauth2/rfc6749/authorization_server.py index 770efc34..ef408991 100644 --- a/authlib/oauth2/rfc6749/authorization_server.py +++ b/authlib/oauth2/rfc6749/authorization_server.py @@ -33,13 +33,13 @@ def save_token(self, token, request): """Define function to save the generated token into database.""" raise NotImplementedError() - def authenticate_client(self, request, methods): + def authenticate_client(self, request, methods, endpoint='token'): """Authenticate client via HTTP request information with the given methods, such as ``client_secret_basic``, ``client_secret_post``. """ if self._client_auth is None and self.query_client: self._client_auth = ClientAuthentication(self.query_client) - return self._client_auth(request, methods) + return self._client_auth(request, methods, endpoint) def register_client_auth_method(self, method, func): """Add more client auth method. The default methods are: diff --git a/authlib/oauth2/rfc6749/grants/base.py b/authlib/oauth2/rfc6749/grants/base.py index 7659ba07..dcb1a265 100644 --- a/authlib/oauth2/rfc6749/grants/base.py +++ b/authlib/oauth2/rfc6749/grants/base.py @@ -65,8 +65,7 @@ def authenticate_token_endpoint_client(self): :return: client """ client = self.server.authenticate_client( - self.request, - self.TOKEN_ENDPOINT_AUTH_METHODS) + self.request, self.TOKEN_ENDPOINT_AUTH_METHODS) self.server.send_signal( 'after_authenticate_client', client=client, grant=self) diff --git a/authlib/oauth2/rfc6749/models.py b/authlib/oauth2/rfc6749/models.py index 47e5c2d9..e05bc8e4 100644 --- a/authlib/oauth2/rfc6749/models.py +++ b/authlib/oauth2/rfc6749/models.py @@ -4,6 +4,7 @@ This module defines how to construct Client, AuthorizationCode and Token. """ +from authlib.deprecate import deprecate class ClientMixin(object): @@ -91,9 +92,18 @@ def check_client_secret(self, client_secret): """ raise NotImplementedError() - def check_token_endpoint_auth_method(self, method): - """Check client ``token_endpoint_auth_method`` defined via `RFC7591`_. - Values defined by this specification are: + def check_endpoint_auth_method(self, method, endpoint): + """Check if client support the given method for the given endpoint. + There is a ``token_endpoint_auth_method`` defined via `RFC7591`_. + Developers MAY re-implement this method with:: + + def check_endpoint_auth_method(self, method, endpoint): + if endpoint == 'token': + # if client table has ``token_endpoint_auth_method`` + return self.token_endpoint_auth_method == method + return True + + Method values defined by this specification are: * "none": The client is a public client as defined in OAuth 2.0, and does not have a client secret. @@ -108,6 +118,10 @@ def check_token_endpoint_auth_method(self, method): """ raise NotImplementedError() + def check_token_endpoint_auth_method(self, method): + deprecate('Please implement ``check_endpoint_auth_method`` instead.') + return self.check_endpoint_auth_method(method, 'token') + def check_response_type(self, response_type): """Validate if the client can handle the given response_type. There are two response types defined by RFC6749: code and token. For diff --git a/authlib/oauth2/rfc6749/token_endpoint.py b/authlib/oauth2/rfc6749/token_endpoint.py index 5d001348..fb0bd403 100644 --- a/authlib/oauth2/rfc6749/token_endpoint.py +++ b/authlib/oauth2/rfc6749/token_endpoint.py @@ -20,7 +20,8 @@ def create_endpoint_request(self, request): def authenticate_endpoint_client(self, request): """Authentication client for endpoint with ``CLIENT_AUTH_METHODS``. """ - client = self.server.authenticate_client(request, self.CLIENT_AUTH_METHODS) + client = self.server.authenticate_client( + request, self.CLIENT_AUTH_METHODS, self.ENDPOINT_NAME) request.client = client return client diff --git a/authlib/oauth2/rfc7523/client.py b/authlib/oauth2/rfc7523/client.py index cda82c84..8127c7be 100644 --- a/authlib/oauth2/rfc7523/client.py +++ b/authlib/oauth2/rfc7523/client.py @@ -68,7 +68,7 @@ def process_assertion_claims(self, assertion, resolve_key): return claims def authenticate_client(self, client): - if client.check_token_endpoint_auth_method(self.CLIENT_AUTH_METHOD): + if client.check_endpoint_auth_method(self.CLIENT_AUTH_METHOD, 'token'): return client raise InvalidClientError() diff --git a/authlib/oauth2/rfc8628/device_code.py b/authlib/oauth2/rfc8628/device_code.py index 0952afe8..1d560f35 100644 --- a/authlib/oauth2/rfc8628/device_code.py +++ b/authlib/oauth2/rfc8628/device_code.py @@ -1,7 +1,6 @@ import logging from ..rfc6749.errors import ( InvalidRequestError, - InvalidClientError, UnauthorizedClientError, AccessDeniedError, ) diff --git a/authlib/oauth2/rfc8628/endpoint.py b/authlib/oauth2/rfc8628/endpoint.py index 06e9f3fd..5bcdb9fc 100644 --- a/authlib/oauth2/rfc8628/endpoint.py +++ b/authlib/oauth2/rfc8628/endpoint.py @@ -86,7 +86,8 @@ class MyDeviceAuthorizationEndpoint(DeviceAuthorizationEndpoint): # only support ``client_secret_basic`` auth method CLIENT_AUTH_METHODS = ['client_secret_basic'] """ - client = self.server.authenticate_client(request, self.CLIENT_AUTH_METHODS) + client = self.server.authenticate_client( + request, self.CLIENT_AUTH_METHODS, self.ENDPOINT_NAME) request.client = client return client diff --git a/docs/django/2/authorization-server.rst b/docs/django/2/authorization-server.rst index 2e61bb8c..4b105741 100644 --- a/docs/django/2/authorization-server.rst +++ b/docs/django/2/authorization-server.rst @@ -24,6 +24,11 @@ an example. Client ------ +.. versionchanged:: v1.0 + + ``check_token_endpoint_auth_method`` is deprecated, developers should + implement ``check_endpoint_auth_method`` instead. + A client is an application making protected resource requests on behalf of the resource owner and with its authorization. It contains at least three information: @@ -73,8 +78,11 @@ the missing methods of :class:`~authlib.oauth2.rfc6749.ClientMixin`:: def check_client_secret(self, client_secret): return self.client_secret == client_secret - def check_token_endpoint_auth_method(self, method): - return self.token_endpoint_auth_method == method + def check_endpoint_auth_method(self, method, endpoint): + if endpoint == 'token': + return self.token_endpoint_auth_method == method + # TODO: developers can update this check method + return True def check_response_type(self, response_type): allowed = self.response_type.split() diff --git a/tests/django/test_oauth2/models.py b/tests/django/test_oauth2/models.py index 434d53f1..519eef66 100644 --- a/tests/django/test_oauth2/models.py +++ b/tests/django/test_oauth2/models.py @@ -55,8 +55,10 @@ def has_client_secret(self): def check_client_secret(self, client_secret): return self.client_secret == client_secret - def check_token_endpoint_auth_method(self, method): - return self.token_endpoint_auth_method == method + def check_endpoint_auth_method(self, method, endpoint): + if endpoint == 'token': + return self.token_endpoint_auth_method == method + return True def check_response_type(self, response_type): allowed = self.response_type.split() diff --git a/tests/flask/test_oauth2/test_device_code_grant.py b/tests/flask/test_oauth2/test_device_code_grant.py index 60d4ceec..6d436c68 100644 --- a/tests/flask/test_oauth2/test_device_code_grant.py +++ b/tests/flask/test_oauth2/test_device_code_grant.py @@ -213,12 +213,19 @@ def test_missing_client_id(self): rv = self.client.post('/device_authorize', data={ 'scope': 'profile' }) - self.assertEqual(rv.status_code, 400) + self.assertEqual(rv.status_code, 401) resp = json.loads(rv.data) - self.assertEqual(resp['error'], 'invalid_request') + self.assertEqual(resp['error'], 'invalid_client') def test_create_authorization_response(self): self.create_server() + client = Client( + user_id=1, + client_id='client', + client_secret='secret', + ) + db.session.add(client) + db.session.commit() rv = self.client.post('/device_authorize', data={ 'client_id': 'client', }) From ae1ab049a3fd359d2049f480a1bedc9d5fdb074f Mon Sep 17 00:00:00 2001 From: Hsiaoming Yang Date: Fri, 27 Nov 2020 17:29:52 +0900 Subject: [PATCH 19/26] Add BearerTokenGenerator --- .../oauth2/rfc6749/authorization_server.py | 22 +-- .../rfc6749/grants/authorization_code.py | 2 +- authlib/oauth2/rfc6750/__init__.py | 3 +- .../oauth2/rfc6750/{wrappers.py => token.py} | 127 +++++++++++------- authlib/oauth2/rfc7523/jwt_bearer.py | 1 + 5 files changed, 93 insertions(+), 62 deletions(-) rename authlib/oauth2/rfc6750/{wrappers.py => token.py} (52%) diff --git a/authlib/oauth2/rfc6749/authorization_server.py b/authlib/oauth2/rfc6749/authorization_server.py index ef408991..109f2b5a 100644 --- a/authlib/oauth2/rfc6749/authorization_server.py +++ b/authlib/oauth2/rfc6749/authorization_server.py @@ -149,6 +149,17 @@ def get_authorization_grant(self, request): return _create_grant(grant_cls, extensions, request, self) raise UnsupportedResponseTypeError(request.response_type) + def get_consent_grant(self, request=None, end_user=None): + """Validate current HTTP request for authorization page. This page + is designed for resource owner to grant or deny the authorization. + """ + request = self.create_oauth2_request(request) + request.user = end_user + + grant = self.get_authorization_grant(request) + grant.validate_consent_request() + return grant + def get_token_grant(self, request): """Find the token grant for current request. @@ -217,17 +228,6 @@ def create_token_response(self, request=None): except OAuth2Error as error: return self.handle_error_response(request, error) - def get_consent_grant(self, request=None, end_user=None): - """Validate current HTTP request for authorization page. This page - is designed for resource owner to grant or deny the authorization. - """ - request = self.create_oauth2_request(request) - request.user = end_user - - grant = self.get_authorization_grant(request) - grant.validate_consent_request() - return grant - def handle_error_response(self, request, error): return self.handle_response(*error(self.get_error_uri(request, error))) diff --git a/authlib/oauth2/rfc6749/grants/authorization_code.py b/authlib/oauth2/rfc6749/grants/authorization_code.py index 19e765f2..570ebf26 100644 --- a/authlib/oauth2/rfc6749/grants/authorization_code.py +++ b/authlib/oauth2/rfc6749/grants/authorization_code.py @@ -268,6 +268,7 @@ def create_token_response(self): user = self.authenticate_user(authorization_code) if not user: raise InvalidRequestError('There is no "user" for this code.') + self.request.user = user scope = authorization_code.get_scope() token = self.generate_token( @@ -277,7 +278,6 @@ def create_token_response(self): ) log.debug('Issue token %r to %r', token, client) - self.request.user = user self.save_token(token) self.execute_hook('process_token', token=token) self.delete_authorization_code(authorization_code) diff --git a/authlib/oauth2/rfc6750/__init__.py b/authlib/oauth2/rfc6750/__init__.py index 6539f4cb..0d12e426 100644 --- a/authlib/oauth2/rfc6750/__init__.py +++ b/authlib/oauth2/rfc6750/__init__.py @@ -11,7 +11,7 @@ from .errors import InvalidTokenError, InsufficientScopeError from .parameters import add_bearer_token -from .wrappers import BearerToken +from .token import BearerToken, BearerTokenGenerator from .validator import BearerTokenValidator @@ -19,5 +19,6 @@ 'InvalidTokenError', 'InsufficientScopeError', 'add_bearer_token', 'BearerToken', + 'BearerTokenGenerator', 'BearerTokenValidator', ] diff --git a/authlib/oauth2/rfc6750/wrappers.py b/authlib/oauth2/rfc6750/token.py similarity index 52% rename from authlib/oauth2/rfc6750/wrappers.py rename to authlib/oauth2/rfc6750/token.py index 4b267dc3..faa6c16b 100644 --- a/authlib/oauth2/rfc6750/wrappers.py +++ b/authlib/oauth2/rfc6750/token.py @@ -1,54 +1,5 @@ class BearerToken(object): - """Bearer Token generator which can create the payload for token response - by OAuth 2 server. A typical token response would be: - - .. code-block:: http - - HTTP/1.1 200 OK - Content-Type: application/json;charset=UTF-8 - Cache-Control: no-store - Pragma: no-cache - - { - "access_token":"mF_9.B5f-4.1JqM", - "token_type":"Bearer", - "expires_in":3600, - "refresh_token":"tGzv3JOkF0XG5Qx2TlKWIA" - } - - :param access_token_generator: a function to generate access_token. - :param refresh_token_generator: a function to generate refresh_token, - if not provided, refresh_token will not be added into token response. - :param expires_generator: The expires_generator can be an int value or a - function. If it is int, all token expires_in will be this value. If it - is function, it can generate expires_in depending on client and - grant_type:: - - def expires_generator(client, grant_type): - if is_official_client(client): - return 3600 * 1000 - if grant_type == 'implicit': - return 3600 - return 3600 * 10 - :return: Callable - - When BearerToken is initialized, it will be callable:: - - token_generator = BearerToken(access_token_generator) - token = token_generator(client, grant_type, expires_in=None, - scope=None, include_refresh_token=True) - - The callable function that BearerToken created accepts these parameters: - - :param client: the client that making the request. - :param grant_type: current requested grant_type. - :param expires_in: if provided, use this value as expires_in. - :param scope: current requested scope. - :param include_refresh_token: should refresh_token be included. - :return: Token dict - """ - #: default expires_in value DEFAULT_EXPIRES_IN = 3600 #: default expires_in value differentiate by grant_type @@ -102,3 +53,81 @@ def __call__(self, client, grant_type, user=None, scope=None, if scope: token['scope'] = scope return token + + +class BearerTokenGenerator(object): + """Bearer token generator which can create the payload for token response + by OAuth 2 server. A typical token response would be: + + .. code-block:: http + + HTTP/1.1 200 OK + Content-Type: application/json;charset=UTF-8 + Cache-Control: no-store + Pragma: no-cache + + { + "access_token":"mF_9.B5f-4.1JqM", + "token_type":"Bearer", + "expires_in":3600, + "refresh_token":"tGzv3JOkF0XG5Qx2TlKWIA" + } + """ + TOKEN_TYPE = 'Bearer' + + #: default expires_in value + DEFAULT_EXPIRES_IN = 3600 + #: default expires_in value differentiate by grant_type + GRANT_TYPES_EXPIRES_IN = { + 'authorization_code': 864000, + 'implicit': 3600, + 'password': 864000, + 'client_credentials': 864000 + } + + def generate_access_token(self, client, grant_type, user, scope=None): + raise NotImplementedError() + + def generate_refresh_token(self, client, grant_type, user, scope=None): + raise NotImplementedError() + + def get_expires_in(self, client, grant_type): + return self.GRANT_TYPES_EXPIRES_IN.get(grant_type, self.DEFAULT_EXPIRES_IN) + + def normalize_scope(self, client, scope): + return scope + + def generate(self, client, grant_type, user=None, scope=None, + expires_in=None, include_refresh_token=True): + """Generate the token dict. + + :param client: the client that making the request. + :param grant_type: current requested grant_type. + :param user: current authorized user. + :param expires_in: if provided, use this value as expires_in. + :param scope: current requested scope. + :param include_refresh_token: should refresh_token be included. + :return: Token dict + """ + access_token = self.generate_access_token(client, grant_type, user, scope) + if expires_in is None: + expires_in = self.get_expires_in(client, grant_type) + + token = { + 'token_type': self.TOKEN_TYPE, + 'access_token': access_token, + 'expires_in': expires_in + } + + if include_refresh_token: + refresh_token = self.generate_refresh_token(client, grant_type, user, scope) + if refresh_token: + token['refresh_token'] = refresh_token + + if scope: + token['scope'] = self.normalize_scope(client, scope) + return token + + def __call__(self, client, grant_type, user=None, scope=None, + expires_in=None, include_refresh_token=True): + return self.generate(client, grant_type, user, scope, expires_in, include_refresh_token) diff --git a/authlib/oauth2/rfc7523/jwt_bearer.py b/authlib/oauth2/rfc7523/jwt_bearer.py index a11336d5..b1732930 100644 --- a/authlib/oauth2/rfc7523/jwt_bearer.py +++ b/authlib/oauth2/rfc7523/jwt_bearer.py @@ -107,6 +107,7 @@ def create_token_response(self): """ token = self.generate_token( scope=self.request.scope, + user=self.request.user, include_refresh_token=False, ) log.debug('Issue token %r to %r', token, self.request.client) From 3d70e54a12d150a870fb19128ccfafdd55ff6e30 Mon Sep 17 00:00:00 2001 From: Hsiaoming Yang Date: Fri, 27 Nov 2020 23:07:26 +0900 Subject: [PATCH 20/26] Refactor generate_token for authorization server With this change, developers can register generator for a given grant type. --- .../django_oauth2/authorization_server.py | 11 +- .../flask_oauth2/authorization_server.py | 2 +- .../oauth2/rfc6749/authorization_server.py | 55 ++++++++- authlib/oauth2/rfc6750/__init__.py | 3 +- authlib/oauth2/rfc6750/token.py | 108 +++++------------- 5 files changed, 90 insertions(+), 89 deletions(-) diff --git a/authlib/integrations/django_oauth2/authorization_server.py b/authlib/integrations/django_oauth2/authorization_server.py index 119cc7ab..a7115771 100644 --- a/authlib/integrations/django_oauth2/authorization_server.py +++ b/authlib/integrations/django_oauth2/authorization_server.py @@ -23,15 +23,14 @@ class AuthorizationServer(_AuthorizationServer): server = AuthorizationServer(OAuth2Client, OAuth2Token) """ - def __init__(self, client_model, token_model, generate_token=None): + def __init__(self, client_model, token_model): self.config = getattr(settings, 'AUTHLIB_OAUTH2_PROVIDER', {}) self.client_model = client_model self.token_model = token_model - if generate_token is None: - generate_token = self.create_bearer_token_generator() - - super(AuthorizationServer, self).__init__(generate_token=generate_token) - self.scopes_supported = self.config.get('scopes_supported') + scopes_supported = self.config.get('scopes_supported') + super(AuthorizationServer, self).__init__(scopes_supported=scopes_supported) + # add default token generator + self.register_token_generator('none', self.create_bearer_token_generator()) def query_client(self, client_id): """Default method for ``AuthorizationServer.query_client``. Developers MAY diff --git a/authlib/integrations/flask_oauth2/authorization_server.py b/authlib/integrations/flask_oauth2/authorization_server.py index 08715d0d..59cda1a1 100644 --- a/authlib/integrations/flask_oauth2/authorization_server.py +++ b/authlib/integrations/flask_oauth2/authorization_server.py @@ -54,7 +54,7 @@ def init_app(self, app, query_client=None, save_token=None): if save_token is not None: self._save_token = save_token - self.generate_token = self.create_bearer_token_generator(app.config) + self.register_token_generator('none', self.create_bearer_token_generator(app.config)) self.scopes_supported = app.config.get('OAUTH2_SCOPES_SUPPORTED') self._error_uris = app.config.get('OAUTH2_ERROR_URIS') diff --git a/authlib/oauth2/rfc6749/authorization_server.py b/authlib/oauth2/rfc6749/authorization_server.py index 109f2b5a..2def4a60 100644 --- a/authlib/oauth2/rfc6749/authorization_server.py +++ b/authlib/oauth2/rfc6749/authorization_server.py @@ -12,11 +12,11 @@ class AuthorizationServer(object): """Authorization server that handles Authorization Endpoint and Token Endpoint. - :param generate_token: A method to generate tokens. + :param scopes_supported: A list of supported scopes by this authorization server. """ - def __init__(self, generate_token=None, scopes_supported=None): - self.generate_token = generate_token + def __init__(self, scopes_supported=None): self.scopes_supported = scopes_supported + self._token_generators = {} self._client_auth = None self._authorization_grants = [] self._token_grants = [] @@ -33,6 +33,55 @@ def save_token(self, token, request): """Define function to save the generated token into database.""" raise NotImplementedError() + def generate_token(self, grant_type, client, user=None, scope=None, + expires_in=None, include_refresh_token=True): + """Generate the token dict. + + :param grant_type: current requested grant_type. + :param client: the client that making the request. + :param user: current authorized user. + :param expires_in: if provided, use this value as expires_in. + :param scope: current requested scope. + :param include_refresh_token: should refresh_token be included. + :return: Token dict + """ + # generator for a specified grant type + func = self._token_generators.get(grant_type) + if not func: + # default generator for all grant types + func = self._token_generators.get('none') + if not func: + raise RuntimeError('No configured token generator') + + return func( + grant_type=grant_type, client=client, user=user, scope=scope, + expires_in=expires_in, include_refresh_token=include_refresh_token) + + def register_token_generator(self, grant_type, func): + """Register a function as token generator for the given ``grant_type``. + Developers MUST register a default token generator with a special + ``grant_type=none``:: + + def generate_bearer_token(grant_type, client, user=None, scope=None, + expires_in=None, include_refresh_token=True): + token = {'token_type': 'Bearer', 'access_token': ...} + if include_refresh_token: + token['refresh_token'] = ... + ... + return token + + authorization_server.register_token_generator('none', generate_bearer_token) + + If you register a generator for a certain grant type, that generator will only works + for the given grant type:: + + authorization_server.register_token_generator('client_credentials', generate_bearer_token) + + :param grant_type: string name of the grant type + :param func: a function to generate token + """ + self._token_generators[grant_type] = func + def authenticate_client(self, request, methods, endpoint='token'): """Authenticate client via HTTP request information with the given methods, such as ``client_secret_basic``, ``client_secret_post``. diff --git a/authlib/oauth2/rfc6750/__init__.py b/authlib/oauth2/rfc6750/__init__.py index 0d12e426..598d9b46 100644 --- a/authlib/oauth2/rfc6750/__init__.py +++ b/authlib/oauth2/rfc6750/__init__.py @@ -11,7 +11,7 @@ from .errors import InvalidTokenError, InsufficientScopeError from .parameters import add_bearer_token -from .token import BearerToken, BearerTokenGenerator +from .token import BearerToken from .validator import BearerTokenValidator @@ -19,6 +19,5 @@ 'InvalidTokenError', 'InsufficientScopeError', 'add_bearer_token', 'BearerToken', - 'BearerTokenGenerator', 'BearerTokenValidator', ] diff --git a/authlib/oauth2/rfc6750/token.py b/authlib/oauth2/rfc6750/token.py index faa6c16b..1772eb85 100644 --- a/authlib/oauth2/rfc6750/token.py +++ b/authlib/oauth2/rfc6750/token.py @@ -1,5 +1,22 @@ - class BearerToken(object): + """Bearer token generator which can create the payload for token response + by OAuth 2 server. A typical token response would be: + + .. code-block:: http + + HTTP/1.1 200 OK + Content-Type: application/json;charset=UTF-8 + Cache-Control: no-store + Pragma: no-cache + + { + "access_token":"mF_9.B5f-4.1JqM", + "token_type":"Bearer", + "expires_in":3600, + "refresh_token":"tGzv3JOkF0XG5Qx2TlKWIA" + } + """ + #: default expires_in value DEFAULT_EXPIRES_IN = 3600 #: default expires_in value differentiate by grant_type @@ -35,71 +52,9 @@ def get_allowed_scope(client, scope): scope = client.get_allowed_scope(scope) return scope - def __call__(self, client, grant_type, user=None, scope=None, + def generate(self, grant_type, client, user=None, scope=None, expires_in=None, include_refresh_token=True): - scope = self.get_allowed_scope(client, scope) - access_token = self.access_token_generator(client, grant_type, user, scope) - if expires_in is None: - expires_in = self._get_expires_in(client, grant_type) - - token = { - 'token_type': 'Bearer', - 'access_token': access_token, - 'expires_in': expires_in - } - if include_refresh_token and self.refresh_token_generator: - token['refresh_token'] = self.refresh_token_generator( - client, grant_type, user, scope) - if scope: - token['scope'] = scope - return token - - -class BearerTokenGenerator(object): - """Bearer token generator which can create the payload for token response - by OAuth 2 server. A typical token response would be: - - .. code-block:: http - - HTTP/1.1 200 OK - Content-Type: application/json;charset=UTF-8 - Cache-Control: no-store - Pragma: no-cache - - { - "access_token":"mF_9.B5f-4.1JqM", - "token_type":"Bearer", - "expires_in":3600, - "refresh_token":"tGzv3JOkF0XG5Qx2TlKWIA" - } - """ - TOKEN_TYPE = 'Bearer' - - #: default expires_in value - DEFAULT_EXPIRES_IN = 3600 - #: default expires_in value differentiate by grant_type - GRANT_TYPES_EXPIRES_IN = { - 'authorization_code': 864000, - 'implicit': 3600, - 'password': 864000, - 'client_credentials': 864000 - } - - def generate_access_token(self, client, grant_type, user, scope=None): - raise NotImplementedError() - - def generate_refresh_token(self, client, grant_type, user, scope=None): - raise NotImplementedError() - - def get_expires_in(self, client, grant_type): - return self.GRANT_TYPES_EXPIRES_IN.get(grant_type, self.DEFAULT_EXPIRES_IN) - - def normalize_scope(self, client, scope): - return scope - - def generate(self, client, grant_type, user=None, scope=None, - expires_in=None, include_refresh_token=True): - """Generate the token dict. + """Generate a bearer token for OAuth 2.0 authorization token endpoint. :param client: the client that making the request. :param grant_type: current requested grant_type. @@ -109,25 +64,24 @@ def generate(self, client, grant_type, user=None, scope=None, :param include_refresh_token: should refresh_token be included. :return: Token dict """ - access_token = self.generate_access_token(client, grant_type, user, scope) + scope = self.get_allowed_scope(client, scope) + access_token = self.access_token_generator( + client=client, grant_type=grant_type, user=user, scope=scope) if expires_in is None: - expires_in = self.get_expires_in(client, grant_type) + expires_in = self._get_expires_in(client, grant_type) token = { - 'token_type': self.TOKEN_TYPE, + 'token_type': 'Bearer', 'access_token': access_token, 'expires_in': expires_in } - - if include_refresh_token: - refresh_token = self.generate_refresh_token(client, grant_type, user, scope) - if refresh_token: - token['refresh_token'] = refresh_token - + if include_refresh_token and self.refresh_token_generator: + token['refresh_token'] = self.refresh_token_generator( + client=client, grant_type=grant_type, user=user, scope=scope) if scope: - token['scope'] = self.normalize_scope(client, scope) + token['scope'] = scope return token - def __call__(self, client, grant_type, user=None, scope=None, + def __call__(self, grant_type, client, user=None, scope=None, expires_in=None, include_refresh_token=True): - return self.generate(client, grant_type, user, scope, expires_in, include_refresh_token) + return self.generate(grant_type, client, user, scope, expires_in, include_refresh_token) From 695af265255853310c905dcd48b439955148516f Mon Sep 17 00:00:00 2001 From: Hsiaoming Yang Date: Tue, 8 Dec 2020 23:46:09 +0900 Subject: [PATCH 21/26] Add JWTBearerTokenGenerator and JWTBearerTokenValidator Although these token generator and validator are designed for jwt-bearer grant type, it can also be used for other grant types. In this way, it solved the issue: https://github.com/lepture/authlib/issues/89 --- .../django_oauth2/authorization_server.py | 8 +- .../flask_oauth2/authorization_server.py | 8 +- authlib/oauth2/rfc6750/__init__.py | 6 +- authlib/oauth2/rfc6750/token.py | 2 +- authlib/oauth2/rfc7523/__init__.py | 6 ++ authlib/oauth2/rfc7523/jwt_bearer.py | 23 +++-- authlib/oauth2/rfc7523/token.py | 86 +++++++++++++++++++ authlib/oauth2/rfc7523/validator.py | 53 ++++++++++++ tox.ini | 4 +- 9 files changed, 172 insertions(+), 24 deletions(-) create mode 100755 authlib/oauth2/rfc7523/token.py create mode 100755 authlib/oauth2/rfc7523/validator.py diff --git a/authlib/integrations/django_oauth2/authorization_server.py b/authlib/integrations/django_oauth2/authorization_server.py index a7115771..1f634acb 100644 --- a/authlib/integrations/django_oauth2/authorization_server.py +++ b/authlib/integrations/django_oauth2/authorization_server.py @@ -6,7 +6,7 @@ HttpRequest, AuthorizationServer as _AuthorizationServer, ) -from authlib.oauth2.rfc6750 import BearerToken +from authlib.oauth2.rfc6750 import BearerTokenGenerator from authlib.common.security import generate_token as _generate_token from authlib.common.encoding import json_dumps from .signals import client_authenticated, token_revoked @@ -91,7 +91,7 @@ def create_bearer_token_generator(self): conf = self.config.get('token_expires_in') expires_generator = create_token_expires_in_generator(conf) - return BearerToken( + return BearerTokenGenerator( access_token_generator=access_token_generator, refresh_token_generator=refresh_token_generator, expires_generator=expires_generator, @@ -112,11 +112,11 @@ def token_generator(*args, **kwargs): def create_token_expires_in_generator(expires_in_conf=None): data = {} - data.update(BearerToken.GRANT_TYPES_EXPIRES_IN) + data.update(BearerTokenGenerator.GRANT_TYPES_EXPIRES_IN) if expires_in_conf: data.update(expires_in_conf) def expires_in(client, grant_type): - return data.get(grant_type, BearerToken.DEFAULT_EXPIRES_IN) + return data.get(grant_type, BearerTokenGenerator.DEFAULT_EXPIRES_IN) return expires_in diff --git a/authlib/integrations/flask_oauth2/authorization_server.py b/authlib/integrations/flask_oauth2/authorization_server.py index 59cda1a1..b828ae14 100644 --- a/authlib/integrations/flask_oauth2/authorization_server.py +++ b/authlib/integrations/flask_oauth2/authorization_server.py @@ -5,7 +5,7 @@ HttpRequest, AuthorizationServer as _AuthorizationServer, ) -from authlib.oauth2.rfc6750 import BearerToken +from authlib.oauth2.rfc6750 import BearerTokenGenerator from authlib.common.security import generate_token from .signals import client_authenticated, token_revoked from ..flask_helpers import create_oauth_request @@ -126,7 +126,7 @@ def gen_token(client, grant_type, user, scope): expires_conf = config.get('OAUTH2_TOKEN_EXPIRES_IN') expires_generator = create_token_expires_in_generator(expires_conf) - return BearerToken( + return BearerTokenGenerator( access_token_generator, refresh_token_generator, expires_generator @@ -138,12 +138,12 @@ def create_token_expires_in_generator(expires_in_conf=None): return import_string(expires_in_conf) data = {} - data.update(BearerToken.GRANT_TYPES_EXPIRES_IN) + data.update(BearerTokenGenerator.GRANT_TYPES_EXPIRES_IN) if isinstance(expires_in_conf, dict): data.update(expires_in_conf) def expires_in(client, grant_type): - return data.get(grant_type, BearerToken.DEFAULT_EXPIRES_IN) + return data.get(grant_type, BearerTokenGenerator.DEFAULT_EXPIRES_IN) return expires_in diff --git a/authlib/oauth2/rfc6750/__init__.py b/authlib/oauth2/rfc6750/__init__.py index 598d9b46..ac88cce4 100644 --- a/authlib/oauth2/rfc6750/__init__.py +++ b/authlib/oauth2/rfc6750/__init__.py @@ -11,13 +11,17 @@ from .errors import InvalidTokenError, InsufficientScopeError from .parameters import add_bearer_token -from .token import BearerToken +from .token import BearerTokenGenerator from .validator import BearerTokenValidator +# TODO: add deprecation +BearerToken = BearerTokenGenerator + __all__ = [ 'InvalidTokenError', 'InsufficientScopeError', 'add_bearer_token', 'BearerToken', + 'BearerTokenGenerator', 'BearerTokenValidator', ] diff --git a/authlib/oauth2/rfc6750/token.py b/authlib/oauth2/rfc6750/token.py index 1772eb85..1b5154eb 100644 --- a/authlib/oauth2/rfc6750/token.py +++ b/authlib/oauth2/rfc6750/token.py @@ -1,4 +1,4 @@ -class BearerToken(object): +class BearerTokenGenerator(object): """Bearer token generator which can create the payload for token response by OAuth 2 server. A typical token response would be: diff --git a/authlib/oauth2/rfc7523/__init__.py b/authlib/oauth2/rfc7523/__init__.py index d8404bc2..627992b8 100644 --- a/authlib/oauth2/rfc7523/__init__.py +++ b/authlib/oauth2/rfc7523/__init__.py @@ -21,6 +21,8 @@ from .auth import ( ClientSecretJWT, PrivateKeyJWT, ) +from .token import JWTBearerTokenGenerator +from .validator import JWTBearerToken, JWTBearerTokenValidator __all__ = [ 'JWTBearerGrant', @@ -29,4 +31,8 @@ 'private_key_jwt_sign', 'ClientSecretJWT', 'PrivateKeyJWT', + + 'JWTBearerToken', + 'JWTBearerTokenGenerator', + 'JWTBearerTokenValidator', ] diff --git a/authlib/oauth2/rfc7523/jwt_bearer.py b/authlib/oauth2/rfc7523/jwt_bearer.py index b1732930..dc0fe171 100644 --- a/authlib/oauth2/rfc7523/jwt_bearer.py +++ b/authlib/oauth2/rfc7523/jwt_bearer.py @@ -16,6 +16,15 @@ class JWTBearerGrant(BaseGrant, TokenEndpointMixin): GRANT_TYPE = JWT_BEARER_GRANT_TYPE + #: Options for verifying JWT payload claims. Developers MAY + #: overwrite this constant to create a more strict options. + CLAIMS_OPTIONS = { + 'iss': {'essential': True}, + 'sub': {'essential': True}, + 'aud': {'essential': True}, + 'exp': {'essential': True}, + } + @staticmethod def sign(key, issuer, audience, subject=None, issued_at=None, expires_at=None, claims=None, **kwargs): @@ -23,18 +32,6 @@ def sign(key, issuer, audience, subject=None, key, issuer, audience, subject, issued_at, expires_at, claims, **kwargs) - def create_claims_options(self): - """Create a claims_options for verify JWT payload claims. Developers - MAY overwrite this method to create a more strict options. - """ - # https://tools.ietf.org/html/rfc7523#section-3 - return { - 'iss': {'essential': True}, - 'sub': {'essential': True}, - 'aud': {'essential': True}, - 'exp': {'essential': True}, - } - def process_assertion_claims(self, assertion): """Extract JWT payload claims from request "assertion", per `Section 3.1`_. @@ -47,7 +44,7 @@ def process_assertion_claims(self, assertion): """ claims = jwt.decode( assertion, self.resolve_public_key, - claims_options=self.create_claims_options()) + claims_options=self.CLAIMS_OPTIONS) try: claims.validate() except JoseError as e: diff --git a/authlib/oauth2/rfc7523/token.py b/authlib/oauth2/rfc7523/token.py new file mode 100755 index 00000000..8ef9a162 --- /dev/null +++ b/authlib/oauth2/rfc7523/token.py @@ -0,0 +1,86 @@ +import time +from authlib.common.encoding import to_unicode +from authlib.jose import jwt + + +class JWTBearerTokenGenerator(object): + """A JSON Web Token formatted bearer token generator for jwt-bearer grant type. + This token generator can be registered into authorization server:: + + authorization_server.register_token_generator( + 'urn:ietf:params:oauth:grant-type:jwt-bearer', + JWTBearerTokenGenerator(private_rsa_key), + ) + + In this way, we can generate the token into JWT format. And we don't have to + save this token into database, since it will be short time valid. Consider to + rewrite ``JWTBearerGrant.save_token``:: + + class MyJWTBearerGrant(JWTBearerGrant): + def save_token(self, token): + pass + + :param secret_key: private RSA key in bytes, JWK or JWK Set. + :param issuer: a string or URI of the issuer + :param alg: ``alg`` to use in JWT + """ + DEFAULT_EXPIRES_IN = 3600 + + def __init__(self, secret_key, issuer=None, alg='RS256'): + self.secret_key = secret_key + self.issuer = issuer + self.alg = alg + + @staticmethod + def get_allowed_scope(client, scope): + if scope: + scope = client.get_allowed_scope(scope) + return scope + + @staticmethod + def get_user_id(user): + return user.get_user_id() + + def get_token_data(self, grant_type, client, user=None, scope=None, expires_in=None): + scope = self.get_allowed_scope(client, scope) + if not expires_in: + expires_in = self.DEFAULT_EXPIRES_IN + issued_at = int(time.time()) + data = { + 'scope': scope, + 'grant_type': grant_type, + 'iat': issued_at, + 'exp': issued_at + expires_in, + 'client_id': client.get_client_id(), + } + if self.issuer: + data['iss'] = self.issuer + if user: + data['sub'] = self.get_user_id(user) + return data + + def generate(self, grant_type, client, user=None, scope=None, expires_in=None): + """Generate a bearer token for OAuth 2.0 authorization token endpoint. + + :param client: the client that making the request. + :param grant_type: current requested grant_type. + :param user: current authorized user. + :param expires_in: if provided, use this value as expires_in. + :param scope: current requested scope. + :return: Token dict + """ + token_data = self.get_token_data(grant_type, client, user, scope, expires_in) + access_token = jwt.encode({'alg': self.alg}, token_data, check=False) + token = { + 'token_type': 'Bearer', + 'access_token': to_unicode(access_token), + 'expires_in': expires_in + } + if scope: + token['scope'] = scope + return token + + def __call__(self, grant_type, client, user=None, scope=None, + expires_in=None, include_refresh_token=True): + # there is absolutely no refresh token in JWT format + return self.generate(grant_type, client, user, scope, expires_in) diff --git a/authlib/oauth2/rfc7523/validator.py b/authlib/oauth2/rfc7523/validator.py new file mode 100755 index 00000000..83222436 --- /dev/null +++ b/authlib/oauth2/rfc7523/validator.py @@ -0,0 +1,53 @@ +import time +from authlib.jose import jwt, JoseError +from ..rfc6749 import TokenMixin +from ..rfc6750 import BearerTokenValidator + + +class JWTBearerToken(TokenMixin, dict): + def __init__(self, data): + super(JWTBearerToken, self).__init__(data) + + def check_client(self, client): + return self['client_id'] == client.get_client_id() + + def get_scope(self): + return self.get('scope') + + def get_expires_in(self): + return self['exp'] - self['iat'] + + def is_expired(self): + return self['exp'] < time.time() + + def is_revoked(self): + return False + + +class JWTBearerTokenValidator(BearerTokenValidator): + TOKEN_TYPE = 'bearer' + token_cls = JWTBearerToken + + def __init__(self, public_key, issuer=None, realm=None, **extra_attributes): + super(JWTBearerTokenValidator, self).__init__(realm, **extra_attributes) + self.public_key = public_key + claims_options = { + 'sub': {'essential': True}, + 'exp': {'essential': True}, + 'client_id': {'essential': True}, + 'grant_type': {'essential': True}, + } + if issuer: + claims_options['iss'] = {'essential': True, 'value': issuer} + self.claims_options = claims_options + + def authenticate_token(self, token_string): + try: + claims = jwt.decode( + token_string, self.public_key, + claims_options=self.claims_options, + ) + claims.validate() + return self.token_cls(dict(claims)) + except JoseError: + return None diff --git a/tox.ini b/tox.ini index ca4490aa..94075413 100644 --- a/tox.ini +++ b/tox.ini @@ -23,10 +23,12 @@ setenv = starlette: TESTPATH=tests/starlette flask: TESTPATH=tests/flask django: TESTPATH=tests/django - django: DJANGO_SETTINGS_MODULE=tests.django.settings commands = coverage run --source=authlib -p -m pytest {env:TESTPATH} +[pytest] +DJANGO_SETTINGS_MODULE=tests.django.settings + [testenv:coverage] skip_install = true commands = From 750de5daef7ac3c62377fdfe537c1b9b5c52184d Mon Sep 17 00:00:00 2001 From: Hsiaoming Yang Date: Wed, 9 Dec 2020 00:34:31 +0900 Subject: [PATCH 22/26] Add tests for JWTBearerTokenGenerator Move random key into jwt.encode function --- authlib/jose/rfc7517/base_key.py | 7 +++ authlib/jose/rfc7517/key_set.py | 2 +- authlib/jose/rfc7519/jwt.py | 51 ++++++++++++++++--- authlib/oauth2/rfc7523/token.py | 6 +-- authlib/oidc/core/grants/util.py | 19 +------ .../test_oauth2/test_jwt_bearer_grant.py | 27 ++++++++-- 6 files changed, 80 insertions(+), 32 deletions(-) diff --git a/authlib/jose/rfc7517/base_key.py b/authlib/jose/rfc7517/base_key.py index f8fe7b4a..9413f988 100644 --- a/authlib/jose/rfc7517/base_key.py +++ b/authlib/jose/rfc7517/base_key.py @@ -43,6 +43,13 @@ def tokens(self): rv[k] = self.options[k] return rv + @property + def kid(self): + rv = self.tokens.get('kid') + if not rv: + rv = self.thumbprint() + return rv + def keys(self): return self.tokens.keys() diff --git a/authlib/jose/rfc7517/key_set.py b/authlib/jose/rfc7517/key_set.py index e95c4d0c..c4f7720b 100644 --- a/authlib/jose/rfc7517/key_set.py +++ b/authlib/jose/rfc7517/key_set.py @@ -24,6 +24,6 @@ def find_by_kid(self, kid): :raise: ValueError """ for k in self.keys: - if k.tokens.get('kid') == kid: + if k.kid == kid: return k raise ValueError('Invalid JSON Web Key Set') diff --git a/authlib/jose/rfc7519/jwt.py b/authlib/jose/rfc7519/jwt.py index 28cec79b..c76b583f 100644 --- a/authlib/jose/rfc7519/jwt.py +++ b/authlib/jose/rfc7519/jwt.py @@ -1,4 +1,5 @@ import re +import random import datetime import calendar from authlib.common.encoding import ( @@ -60,7 +61,7 @@ def encode(self, header, payload, key, check=True): if check: self.check_sensitive_data(payload) - key = prepare_raw_key(key, header) + key = find_encode_key(key, header) text = to_bytes(json_dumps(payload)) if 'enc' in header: return self._jwe.serialize_compact(header, text, key) @@ -87,8 +88,7 @@ def decode(self, s, key, claims_cls=None, if callable(key): load_key = key else: - def load_key(header, payload): - return prepare_raw_key(key, header) + load_key = create_load_key(prepare_raw_key(key)) s = to_bytes(s) dot_count = s.count(b'.') @@ -115,21 +115,56 @@ def decode_payload(bytes_payload): return payload -def prepare_raw_key(raw, header): +def prepare_raw_key(raw): if isinstance(raw, KeySet): - return raw.find_by_kid(header.get('kid')) + return raw if isinstance(raw, str) and \ raw.startswith('{') and raw.endswith('}'): raw = json_loads(raw) elif isinstance(raw, (tuple, list)): raw = {'keys': raw} + return raw + + +def find_encode_key(key, header): + if isinstance(key, KeySet): + kid = header.get('kid') + if kid: + return key.find_by_kid(kid) + + rv = random.choice(key.keys) + # use side effect to add kid value into header + header['kid'] = rv.kid + return rv - if isinstance(raw, dict) and 'keys' in raw: - keys = raw['keys'] + if isinstance(key, dict) and 'keys' in key: + keys = key['keys'] kid = header.get('kid') for k in keys: if k.get('kid') == kid: return k + + if not kid: + rv = random.choice(keys) + header['kid'] = rv['kid'] + return rv raise ValueError('Invalid JSON Web Key Set') - return raw + return key + + +def create_load_key(key): + def load_key(header, payload): + if isinstance(key, KeySet): + return key.find_by_kid(header.get('kid')) + + if isinstance(key, dict) and 'keys' in key: + keys = key['keys'] + kid = header.get('kid') + for k in keys: + if k.get('kid') == kid: + return k + raise ValueError('Invalid JSON Web Key Set') + return key + + return load_key diff --git a/authlib/oauth2/rfc7523/token.py b/authlib/oauth2/rfc7523/token.py index 8ef9a162..352994a2 100755 --- a/authlib/oauth2/rfc7523/token.py +++ b/authlib/oauth2/rfc7523/token.py @@ -1,5 +1,5 @@ import time -from authlib.common.encoding import to_unicode +from authlib.common.encoding import to_native from authlib.jose import jwt @@ -70,10 +70,10 @@ def generate(self, grant_type, client, user=None, scope=None, expires_in=None): :return: Token dict """ token_data = self.get_token_data(grant_type, client, user, scope, expires_in) - access_token = jwt.encode({'alg': self.alg}, token_data, check=False) + access_token = jwt.encode({'alg': self.alg}, token_data, key=self.secret_key, check=False) token = { 'token_type': 'Bearer', - 'access_token': to_unicode(access_token), + 'access_token': to_native(access_token), 'expires_in': expires_in } if scope: diff --git a/authlib/oidc/core/grants/util.py b/authlib/oidc/core/grants/util.py index ba8e5ea8..cb366260 100644 --- a/authlib/oidc/core/grants/util.py +++ b/authlib/oidc/core/grants/util.py @@ -1,8 +1,7 @@ import time -import random from authlib.oauth2.rfc6749 import InvalidRequestError from authlib.oauth2.rfc6749.util import scope_to_list -from authlib.jose import JsonWebToken +from authlib.jose import jwt from authlib.common.encoding import to_native from authlib.common.urls import add_params_to_uri, quote_url from ..util import create_half_hash @@ -68,7 +67,7 @@ def generate_id_token( access_token=token.get('access_token'), ) payload.update(user_info) - return _jwt_encode(alg, payload, key) + return to_native(jwt.encode({'alg': alg}, payload, key)) def create_response_mode_response(redirect_uri, params, response_mode): @@ -139,17 +138,3 @@ def _generate_id_token_payload( if access_token: payload['at_hash'] = to_native(create_half_hash(access_token, alg)) return payload - - -def _jwt_encode(alg, payload, key): - jwt = JsonWebToken(algorithms=[alg]) - header = {'alg': alg} - if isinstance(key, dict): - # JWK set format - if 'keys' in key: - key = random.choice(key['keys']) - header['kid'] = key['kid'] - elif 'kid' in key: - header['kid'] = key['kid'] - - return to_native(jwt.encode(header, payload, key)) diff --git a/tests/flask/test_oauth2/test_jwt_bearer_grant.py b/tests/flask/test_oauth2/test_jwt_bearer_grant.py index 41ca77e9..e5512878 100644 --- a/tests/flask/test_oauth2/test_jwt_bearer_grant.py +++ b/tests/flask/test_oauth2/test_jwt_bearer_grant.py @@ -1,5 +1,7 @@ from flask import json from authlib.oauth2.rfc7523 import JWTBearerGrant as _JWTBearerGrant +from authlib.oauth2.rfc7523 import JWTBearerTokenGenerator, JWTBearerTokenValidator +from tests.util import read_file_path from .models import db, User, Client from .oauth2_server import TestCase from .oauth2_server import create_authorization_server @@ -19,15 +21,19 @@ def resolve_public_key(self, headers, payload): class JWTBearerGrantTest(TestCase): - def prepare_data(self, grant_type=None): + def prepare_data(self, grant_type=None, token_generator=None): server = create_authorization_server(self.app) server.register_grant(JWTBearerGrant) + if token_generator: + server.register_token_generator(JWTBearerGrant.GRANT_TYPE, token_generator) + + if grant_type is None: + grant_type = JWTBearerGrant.GRANT_TYPE + user = User(username='foo') db.session.add(user) db.session.commit() - if grant_type is None: - grant_type = JWTBearerGrant.GRANT_TYPE client = Client( user_id=user.id, client_id='jwt-client', @@ -104,3 +110,18 @@ def test_token_generator(self): resp = json.loads(rv.data) self.assertIn('access_token', resp) self.assertIn('j-', resp['access_token']) + + def test_jwt_bearer_token_generator(self): + private_key = read_file_path('jwks_private.json') + self.prepare_data(token_generator=JWTBearerTokenGenerator(private_key)) + assertion = JWTBearerGrant.sign( + 'foo', issuer='jwt-client', audience='https://i.b/token', + subject='self', header={'alg': 'HS256', 'kid': '1'} + ) + rv = self.client.post('/oauth/token', data={ + 'grant_type': JWTBearerGrant.GRANT_TYPE, + 'assertion': assertion + }) + resp = json.loads(rv.data) + self.assertIn('access_token', resp) + self.assertEqual(resp['access_token'].count('.'), 2) From c0cd15f9ac3eef702b1f850e98c8f50de411117f Mon Sep 17 00:00:00 2001 From: Hsiaoming Yang Date: Wed, 9 Dec 2020 00:59:55 +0900 Subject: [PATCH 23/26] Use linux line separator --- authlib/jose/rfc7517/base_key.py | 5 +- authlib/oauth2/rfc7523/token.py | 172 ++++++++++++++-------------- authlib/oauth2/rfc7523/validator.py | 104 +++++++++-------- 3 files changed, 138 insertions(+), 143 deletions(-) diff --git a/authlib/jose/rfc7517/base_key.py b/authlib/jose/rfc7517/base_key.py index 9413f988..c8c958ce 100644 --- a/authlib/jose/rfc7517/base_key.py +++ b/authlib/jose/rfc7517/base_key.py @@ -45,10 +45,7 @@ def tokens(self): @property def kid(self): - rv = self.tokens.get('kid') - if not rv: - rv = self.thumbprint() - return rv + return self.tokens.get('kid') def keys(self): return self.tokens.keys() diff --git a/authlib/oauth2/rfc7523/token.py b/authlib/oauth2/rfc7523/token.py index 352994a2..ea7c2dea 100755 --- a/authlib/oauth2/rfc7523/token.py +++ b/authlib/oauth2/rfc7523/token.py @@ -1,86 +1,86 @@ -import time -from authlib.common.encoding import to_native -from authlib.jose import jwt - - -class JWTBearerTokenGenerator(object): - """A JSON Web Token formatted bearer token generator for jwt-bearer grant type. - This token generator can be registered into authorization server:: - - authorization_server.register_token_generator( - 'urn:ietf:params:oauth:grant-type:jwt-bearer', - JWTBearerTokenGenerator(private_rsa_key), - ) - - In this way, we can generate the token into JWT format. And we don't have to - save this token into database, since it will be short time valid. Consider to - rewrite ``JWTBearerGrant.save_token``:: - - class MyJWTBearerGrant(JWTBearerGrant): - def save_token(self, token): - pass - - :param secret_key: private RSA key in bytes, JWK or JWK Set. - :param issuer: a string or URI of the issuer - :param alg: ``alg`` to use in JWT - """ - DEFAULT_EXPIRES_IN = 3600 - - def __init__(self, secret_key, issuer=None, alg='RS256'): - self.secret_key = secret_key - self.issuer = issuer - self.alg = alg - - @staticmethod - def get_allowed_scope(client, scope): - if scope: - scope = client.get_allowed_scope(scope) - return scope - - @staticmethod - def get_user_id(user): - return user.get_user_id() - - def get_token_data(self, grant_type, client, user=None, scope=None, expires_in=None): - scope = self.get_allowed_scope(client, scope) - if not expires_in: - expires_in = self.DEFAULT_EXPIRES_IN - issued_at = int(time.time()) - data = { - 'scope': scope, - 'grant_type': grant_type, - 'iat': issued_at, - 'exp': issued_at + expires_in, - 'client_id': client.get_client_id(), - } - if self.issuer: - data['iss'] = self.issuer - if user: - data['sub'] = self.get_user_id(user) - return data - - def generate(self, grant_type, client, user=None, scope=None, expires_in=None): - """Generate a bearer token for OAuth 2.0 authorization token endpoint. - - :param client: the client that making the request. - :param grant_type: current requested grant_type. - :param user: current authorized user. - :param expires_in: if provided, use this value as expires_in. - :param scope: current requested scope. - :return: Token dict - """ - token_data = self.get_token_data(grant_type, client, user, scope, expires_in) - access_token = jwt.encode({'alg': self.alg}, token_data, key=self.secret_key, check=False) - token = { - 'token_type': 'Bearer', - 'access_token': to_native(access_token), - 'expires_in': expires_in - } - if scope: - token['scope'] = scope - return token - - def __call__(self, grant_type, client, user=None, scope=None, - expires_in=None, include_refresh_token=True): - # there is absolutely no refresh token in JWT format - return self.generate(grant_type, client, user, scope, expires_in) +import time +from authlib.common.encoding import to_native +from authlib.jose import jwt + + +class JWTBearerTokenGenerator(object): + """A JSON Web Token formatted bearer token generator for jwt-bearer grant type. + This token generator can be registered into authorization server:: + + authorization_server.register_token_generator( + 'urn:ietf:params:oauth:grant-type:jwt-bearer', + JWTBearerTokenGenerator(private_rsa_key), + ) + + In this way, we can generate the token into JWT format. And we don't have to + save this token into database, since it will be short time valid. Consider to + rewrite ``JWTBearerGrant.save_token``:: + + class MyJWTBearerGrant(JWTBearerGrant): + def save_token(self, token): + pass + + :param secret_key: private RSA key in bytes, JWK or JWK Set. + :param issuer: a string or URI of the issuer + :param alg: ``alg`` to use in JWT + """ + DEFAULT_EXPIRES_IN = 3600 + + def __init__(self, secret_key, issuer=None, alg='RS256'): + self.secret_key = secret_key + self.issuer = issuer + self.alg = alg + + @staticmethod + def get_allowed_scope(client, scope): + if scope: + scope = client.get_allowed_scope(scope) + return scope + + @staticmethod + def get_user_id(user): + return user.get_user_id() + + def get_token_data(self, grant_type, client, user=None, scope=None, expires_in=None): + scope = self.get_allowed_scope(client, scope) + if not expires_in: + expires_in = self.DEFAULT_EXPIRES_IN + issued_at = int(time.time()) + data = { + 'scope': scope, + 'grant_type': grant_type, + 'iat': issued_at, + 'exp': issued_at + expires_in, + 'client_id': client.get_client_id(), + } + if self.issuer: + data['iss'] = self.issuer + if user: + data['sub'] = self.get_user_id(user) + return data + + def generate(self, grant_type, client, user=None, scope=None, expires_in=None): + """Generate a bearer token for OAuth 2.0 authorization token endpoint. + + :param client: the client that making the request. + :param grant_type: current requested grant_type. + :param user: current authorized user. + :param expires_in: if provided, use this value as expires_in. + :param scope: current requested scope. + :return: Token dict + """ + token_data = self.get_token_data(grant_type, client, user, scope, expires_in) + access_token = jwt.encode({'alg': self.alg}, token_data, key=self.secret_key, check=False) + token = { + 'token_type': 'Bearer', + 'access_token': to_native(access_token), + 'expires_in': expires_in + } + if scope: + token['scope'] = scope + return token + + def __call__(self, grant_type, client, user=None, scope=None, + expires_in=None, include_refresh_token=True): + # there is absolutely no refresh token in JWT format + return self.generate(grant_type, client, user, scope, expires_in) diff --git a/authlib/oauth2/rfc7523/validator.py b/authlib/oauth2/rfc7523/validator.py index 83222436..fd64d3b0 100755 --- a/authlib/oauth2/rfc7523/validator.py +++ b/authlib/oauth2/rfc7523/validator.py @@ -1,53 +1,51 @@ -import time -from authlib.jose import jwt, JoseError -from ..rfc6749 import TokenMixin -from ..rfc6750 import BearerTokenValidator - - -class JWTBearerToken(TokenMixin, dict): - def __init__(self, data): - super(JWTBearerToken, self).__init__(data) - - def check_client(self, client): - return self['client_id'] == client.get_client_id() - - def get_scope(self): - return self.get('scope') - - def get_expires_in(self): - return self['exp'] - self['iat'] - - def is_expired(self): - return self['exp'] < time.time() - - def is_revoked(self): - return False - - -class JWTBearerTokenValidator(BearerTokenValidator): - TOKEN_TYPE = 'bearer' - token_cls = JWTBearerToken - - def __init__(self, public_key, issuer=None, realm=None, **extra_attributes): - super(JWTBearerTokenValidator, self).__init__(realm, **extra_attributes) - self.public_key = public_key - claims_options = { - 'sub': {'essential': True}, - 'exp': {'essential': True}, - 'client_id': {'essential': True}, - 'grant_type': {'essential': True}, - } - if issuer: - claims_options['iss'] = {'essential': True, 'value': issuer} - self.claims_options = claims_options - - def authenticate_token(self, token_string): - try: - claims = jwt.decode( - token_string, self.public_key, - claims_options=self.claims_options, - ) - claims.validate() - return self.token_cls(dict(claims)) - except JoseError: - return None +import time +from authlib.jose import jwt, JoseError, JWTClaims +from ..rfc6749 import TokenMixin +from ..rfc6750 import BearerTokenValidator + + +class JWTBearerToken(TokenMixin, JWTClaims): + def check_client(self, client): + return self['client_id'] == client.get_client_id() + + def get_scope(self): + return self.get('scope') + + def get_expires_in(self): + return self['exp'] - self['iat'] + + def is_expired(self): + return self['exp'] < time.time() + + def is_revoked(self): + return False + + +class JWTBearerTokenValidator(BearerTokenValidator): + TOKEN_TYPE = 'bearer' + token_cls = JWTBearerToken + + def __init__(self, public_key, issuer=None, realm=None, **extra_attributes): + super(JWTBearerTokenValidator, self).__init__(realm, **extra_attributes) + self.public_key = public_key + claims_options = { + 'sub': {'essential': True}, + 'exp': {'essential': True}, + 'client_id': {'essential': True}, + 'grant_type': {'essential': True}, + } + if issuer: + claims_options['iss'] = {'essential': True, 'value': issuer} + self.claims_options = claims_options + + def authenticate_token(self, token_string): + try: + claims = jwt.decode( + token_string, self.public_key, + claims_options=self.claims_options, + claims_cls=self.token_cls, + ) + claims.validate() + return claims + except JoseError: + return None From cb2bbe2e82491b783a31e1c58428e0716e43575b Mon Sep 17 00:00:00 2001 From: Hsiaoming Yang Date: Wed, 9 Dec 2020 01:04:51 +0900 Subject: [PATCH 24/26] Append kid into header when jwt.encode --- authlib/jose/rfc7519/jwt.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/authlib/jose/rfc7519/jwt.py b/authlib/jose/rfc7519/jwt.py index c76b583f..1866c4e0 100644 --- a/authlib/jose/rfc7519/jwt.py +++ b/authlib/jose/rfc7519/jwt.py @@ -10,7 +10,7 @@ from ..errors import DecodeError, InsecureClaimError from ..rfc7515 import JsonWebSignature from ..rfc7516 import JsonWebEncryption -from ..rfc7517 import KeySet +from ..rfc7517 import KeySet, Key class JsonWebToken(object): @@ -150,6 +150,12 @@ def find_encode_key(key, header): header['kid'] = rv['kid'] return rv raise ValueError('Invalid JSON Web Key Set') + + # append kid into header + if isinstance(key, dict) and 'kid' in key: + header['kid'] = key['kid'] + elif isinstance(key, Key) and key.kid: + header['kid'] = key.kid return key From 2468c5af745e3025b481caa848debaa574de59ab Mon Sep 17 00:00:00 2001 From: Hsiaoming Yang Date: Wed, 9 Dec 2020 21:56:47 +0900 Subject: [PATCH 25/26] Add OpenIDToken extension for other flow This will fix https://github.com/lepture/authlib/issues/301 --- authlib/oidc/core/grants/code.py | 8 +++-- authlib/oidc/core/grants/util.py | 51 +++++++++++++------------------- 2 files changed, 27 insertions(+), 32 deletions(-) diff --git a/authlib/oidc/core/grants/code.py b/authlib/oidc/core/grants/code.py index 0e01bb23..e2059211 100644 --- a/authlib/oidc/core/grants/code.py +++ b/authlib/oidc/core/grants/code.py @@ -74,8 +74,12 @@ def process_token(self, grant, token): config = self.get_jwt_config(grant) config['aud'] = self.get_audiences(request) - config['nonce'] = credential.get_nonce() - config['auth_time'] = credential.get_auth_time() + + if credential: + config['nonce'] = credential.get_nonce() + config['auth_time'] = credential.get_auth_time() + else: + config['nonce'] = request.data.get('nonce') user_info = self.generate_user_info(request.user, token['scope']) id_token = generate_id_token(token, user_info, **config) diff --git a/authlib/oidc/core/grants/util.py b/authlib/oidc/core/grants/util.py index cb366260..e10b4596 100644 --- a/authlib/oidc/core/grants/util.py +++ b/authlib/oidc/core/grants/util.py @@ -61,11 +61,27 @@ def generate_id_token( token, user_info, key, iss, aud, alg='RS256', exp=3600, nonce=None, auth_time=None, code=None): - payload = _generate_id_token_payload( - alg=alg, iss=iss, aud=aud, exp=exp, nonce=nonce, - auth_time=auth_time, code=code, - access_token=token.get('access_token'), - ) + now = int(time.time()) + if auth_time is None: + auth_time = now + + payload = { + 'iss': iss, + 'aud': aud, + 'iat': now, + 'exp': now + exp, + 'auth_time': auth_time, + } + if nonce: + payload['nonce'] = nonce + + if code: + payload['c_hash'] = to_native(create_half_hash(code, alg)) + + access_token = token.get('access_token') + if access_token: + payload['at_hash'] = to_native(create_half_hash(access_token, alg)) + payload.update(user_info) return to_native(jwt.encode({'alg': alg}, payload, key)) @@ -113,28 +129,3 @@ def _guess_prompt_value(end_user, prompts, redirect_uri, redirect_fragment): redirect_uri=redirect_uri, redirect_fragment=redirect_fragment) return 'select_account' - - -def _generate_id_token_payload( - alg, iss, aud, exp, nonce=None, auth_time=None, - code=None, access_token=None): - now = int(time.time()) - if auth_time is None: - auth_time = now - - payload = { - 'iss': iss, - 'aud': aud, - 'iat': now, - 'exp': now + exp, - 'auth_time': auth_time, - } - if nonce: - payload['nonce'] = nonce - - if code: - payload['c_hash'] = to_native(create_half_hash(code, alg)) - - if access_token: - payload['at_hash'] = to_native(create_half_hash(access_token, alg)) - return payload From 36d5b3667520baada9135655c4e1377f4aec1177 Mon Sep 17 00:00:00 2001 From: Hsiaoming Yang Date: Wed, 9 Dec 2020 22:07:16 +0900 Subject: [PATCH 26/26] Add tests for adding OpenIDToken to password flow Related: https://github.com/lepture/authlib/issues/301 --- authlib/oidc/core/__init__.py | 4 +-- authlib/oidc/core/grants/code.py | 2 -- .../flask/test_oauth2/test_password_grant.py | 34 +++++++++++++++++-- 3 files changed, 33 insertions(+), 7 deletions(-) diff --git a/authlib/oidc/core/__init__.py b/authlib/oidc/core/__init__.py index 8ee628fa..212ebc03 100644 --- a/authlib/oidc/core/__init__.py +++ b/authlib/oidc/core/__init__.py @@ -12,12 +12,12 @@ IDToken, CodeIDToken, ImplicitIDToken, HybridIDToken, UserInfo, get_claim_cls_by_response_type, ) -from .grants import OpenIDCode, OpenIDHybridGrant, OpenIDImplicitGrant +from .grants import OpenIDToken, OpenIDCode, OpenIDHybridGrant, OpenIDImplicitGrant __all__ = [ 'AuthorizationCodeMixin', 'IDToken', 'CodeIDToken', 'ImplicitIDToken', 'HybridIDToken', 'UserInfo', 'get_claim_cls_by_response_type', - 'OpenIDCode', 'OpenIDHybridGrant', 'OpenIDImplicitGrant', + 'OpenIDToken', 'OpenIDCode', 'OpenIDHybridGrant', 'OpenIDImplicitGrant', ] diff --git a/authlib/oidc/core/grants/code.py b/authlib/oidc/core/grants/code.py index e2059211..040a360c 100644 --- a/authlib/oidc/core/grants/code.py +++ b/authlib/oidc/core/grants/code.py @@ -78,8 +78,6 @@ def process_token(self, grant, token): if credential: config['nonce'] = credential.get_nonce() config['auth_time'] = credential.get_auth_time() - else: - config['nonce'] = request.data.get('nonce') user_info = self.generate_user_info(request.user, token['scope']) id_token = generate_id_token(token, user_info, **config) diff --git a/tests/flask/test_oauth2/test_password_grant.py b/tests/flask/test_oauth2/test_password_grant.py index c5fb3694..9ddfcb19 100644 --- a/tests/flask/test_oauth2/test_password_grant.py +++ b/tests/flask/test_oauth2/test_password_grant.py @@ -3,11 +3,24 @@ from authlib.oauth2.rfc6749.grants import ( ResourceOwnerPasswordCredentialsGrant as _PasswordGrant, ) +from authlib.oidc.core import OpenIDToken from .models import db, User, Client from .oauth2_server import TestCase from .oauth2_server import create_authorization_server +class IDToken(OpenIDToken): + def get_jwt_config(self, grant): + return { + 'iss': 'Authlib', + 'key': 'secret', + 'alg': 'HS256', + } + + def generate_user_info(self, user, scopes): + return user.generate_user_info(scopes) + + class PasswordGrant(_PasswordGrant): def authenticate_user(self, username, password): user = User.query.filter_by(username=username).first() @@ -16,9 +29,9 @@ def authenticate_user(self, username, password): class PasswordTest(TestCase): - def prepare_data(self, grant_type='password'): + def prepare_data(self, grant_type='password', extensions=None): server = create_authorization_server(self.app) - server.register_grant(PasswordGrant) + server.register_grant(PasswordGrant, extensions) self.server = server user = User(username='foo') @@ -30,7 +43,7 @@ def prepare_data(self, grant_type='password'): client_secret='password-secret', ) client.set_client_metadata({ - 'scope': 'profile', + 'scope': 'openid profile', 'grant_types': [grant_type], 'redirect_uris': ['http://localhost/authorized'], }) @@ -164,3 +177,18 @@ def test_custom_expires_in(self): resp = json.loads(rv.data) self.assertIn('access_token', resp) self.assertEqual(resp['expires_in'], 1800) + + def test_id_token_extension(self): + self.prepare_data(extensions=[IDToken()]) + headers = self.create_basic_header( + 'password-client', 'password-secret' + ) + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'password', + 'username': 'foo', + 'password': 'ok', + 'scope': 'openid profile', + }, headers=headers) + resp = json.loads(rv.data) + self.assertIn('access_token', resp) + self.assertIn('id_token', resp)