diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index dbbdfa88..a9256cf3 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -41,7 +41,7 @@ jobs: - name: Test with tox ${{ matrix.python.toxenv }} env: - TOXENV: py,flask,django,starlette + TOXENV: py,flask,django,fastapi,starlette run: tox - name: Report coverage diff --git a/.gitignore b/.gitignore index b0bcd0b1..467ba799 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ *.pyo *.egg-info *.swp +*.db __pycache__ build develop-eggs diff --git a/Makefile b/Makefile index 617a66e2..810d2ee3 100644 --- a/Makefile +++ b/Makefile @@ -3,7 +3,7 @@ clean: clean-build clean-pyc clean-docs clean-tox tests: - @TOXENV=py,flask,django,coverage tox + @TOXENV=py,flask,django,fastapi,coverage tox clean-build: @rm -fr build/ diff --git a/authlib/integrations/fastapi_oauth2/__init__.py b/authlib/integrations/fastapi_oauth2/__init__.py new file mode 100644 index 00000000..9a001a77 --- /dev/null +++ b/authlib/integrations/fastapi_oauth2/__init__.py @@ -0,0 +1,4 @@ +"""FastAPI package implementation.""" + +from .authorization_server import AuthorizationServer +from .resource_protector import ResourceProtector diff --git a/authlib/integrations/fastapi_oauth2/authorization_server.py b/authlib/integrations/fastapi_oauth2/authorization_server.py new file mode 100644 index 00000000..9e5feb13 --- /dev/null +++ b/authlib/integrations/fastapi_oauth2/authorization_server.py @@ -0,0 +1,145 @@ +"""Implementation of authlib.oauth2.rfc6749.AuthorizationServer class for FastAPI.""" + +import json + +from authlib.common.security import generate_token +from authlib.oauth2 import AuthorizationServer as _AuthorizationServer +from authlib.oauth2 import HttpRequest, OAuth2Request +from authlib.oauth2.rfc6750 import BearerToken +from authlib.oauth2.rfc8414 import AuthorizationServerMetadata +from fastapi.responses import JSONResponse +from werkzeug.utils import import_string + + +class AuthorizationServer(_AuthorizationServer): + """AuthorizationServer class.""" + + def __init__(self, app=None, query_client=None, save_token=None): + super(AuthorizationServer, self).__init__() + self._query_client = query_client + self._save_token = save_token + self.config = {} + if app: + self.init_app(app) + + def init_app(self, app, query_client=None, save_token=None): + """Initialize the FastAPI app.""" + if query_client: + self.query_client = query_client + if save_token: + self.save_token = save_token + + self.generate_token = create_bearer_token_generator(app.config) + + metadata_class = AuthorizationServerMetadata + + metadata_file = app.config.get("OAUTH2_METADATA_FILE") + if metadata_file: + with open(metadata_file) as metadata_file_content: + metadata = metadata_class(json.loads(metadata_file_content)) + metadata.validate() + self.metadata = metadata + + self.scopes_supported = app.config.get("OAUTH2_SCOPES_SUPPORTED") + self._error_uris = app.config.get("OAUTH2_ERROR_URIS") + + def query_client(self, client_id): + return self._query_client(client_id) + + def save_token(self, token, request): + return self._save_token(token, request) + + def get_error_uri(self, request, error): + if self._error_uris: + uris = dict(self._error_uris) + return uris.get(error.error) + + def create_oauth2_request(self, request): + return OAuth2Request( + request.method, str(request.url), request.body, request.headers + ) + + def create_json_request(self, request): + return HttpRequest( + request.method, str(request.url), request.body, request.headers + ) + + def send_signal(self, name, *args, **kwargs): + pass + + def handle_response(self, status, body, headers): + return JSONResponse(content=body, status_code=status, headers=dict(headers)) + + def validate_consent_request(self, request=None, end_user=None): + """Validate current HTTP request for authorization page. This page + is designed for resource owner to grant or deny the authorization""" + req = self.create_oauth2_request(request) + req.user = end_user + + grant = self.get_authorization_grant(req) + grant.validate_consent_request() + if not hasattr(grant, "prompt"): + grant.prompt = None + return grant + + +def create_bearer_token_generator(config): + """Create a generator function for generating ``token`` value. This + method will create a Bearer Token generator with + :class:`authlib.oauth2.rfc6750.BearerToken`. By default, it will not + generate ``refresh_token``, which can be turn on by configuration + ``OAUTH2_REFRESH_TOKEN_GENERATOR=True``. + """ + conf = config.get("OAUTH2_ACCESS_TOKEN_GENERATOR", True) + access_token_generator = create_token_generator(conf, 42) + + conf = config.get("OAUTH2_REFRESH_TOKEN_GENERATOR", False) + refresh_token_generator = create_token_generator(conf, 48) + + expires_generator = create_token_expires_in_generator(config) + + return BearerToken( + access_token_generator, refresh_token_generator, expires_generator + ) + + +def create_token_expires_in_generator(config): + """Create a generator function for generating ``expires_in`` value. + Developers can re-implement this method with a subclass if other means + required. The default expires_in value is defined by ``grant_type``, + different ``grant_type`` has different value. It can be configured + with:: + + OAUTH2_TOKEN_EXPIRES_IN = { + 'authorization_code': 864000 + } + """ + data = {} + data.update(BearerToken.GRANT_TYPES_EXPIRES_IN) + + expires_in_conf = config.get("OAUTH2_TOKEN_EXPIRES_IN") + if expires_in_conf: + data.update(expires_in_conf) + + def expires_in(client, grant_type): # pylint: disable=W0613 + return data.get(grant_type, BearerToken.DEFAULT_EXPIRES_IN) + + return expires_in + + +def create_token_generator(token_generator_conf, length=42): + """Create a token generator function.""" + if callable(token_generator_conf): + return token_generator_conf + + if isinstance(token_generator_conf, str): + return import_string(token_generator_conf) + + if token_generator_conf is True: + + def token_generator(*args, **kwargs): # pylint: disable=W0613 + return generate_token(length) + + return token_generator + + return None diff --git a/authlib/integrations/fastapi_oauth2/resource_protector.py b/authlib/integrations/fastapi_oauth2/resource_protector.py new file mode 100644 index 00000000..df1d8d97 --- /dev/null +++ b/authlib/integrations/fastapi_oauth2/resource_protector.py @@ -0,0 +1,60 @@ +"""Implementation of authlib.oauth2.rfc6749.ResourceProtector class for FastAPI.""" + +import functools +from contextlib import contextmanager + +from authlib.oauth2 import OAuth2Error +from authlib.oauth2 import ResourceProtector as _ResourceProtector +from authlib.oauth2.rfc6749 import HttpRequest, MissingAuthorizationError +from fastapi import HTTPException + + +class ResourceProtector(_ResourceProtector): + """ResourceProtector class.""" + + def acquire_token(self, request=None, scope=None): + """A method to acquire current valid token with the given scope. + + :param request: request object + :param scope: string or list of scope values + :return: token object + """ + http_request = HttpRequest(request.method, request.url, {}, request.headers) + token = self.validate_request(scope, http_request) + request.state.token = token + return token + + @contextmanager + def acquire(self, request=None, scope=None): + """The with statement of ``require_oauth``. Instead of using a + decorator, you can use a with statement instead.""" + try: + yield self.acquire_token(request, scope) + except OAuth2Error as error: + raise_error_response(error) + + def __call__(self, scope=None, optional=False): + def wrapper(func): + @functools.wraps(func) + def decorated(request, *args, **kwargs): + try: + self.acquire_token(request, scope) + except MissingAuthorizationError as error: + if optional: + return func(request, *args, **kwargs) + raise_error_response(error) + except OAuth2Error as error: + raise_error_response(error) + return func(request, *args, **kwargs) + + return decorated + + return wrapper + + +def raise_error_response(error): + """Raise the FastAPI HTTPException method.""" + status = error.status_code + body = dict(error.get_body()) + headers = error.get_headers() + raise HTTPException(status_code=status, detail=body, headers=dict(headers)) diff --git a/tests/fastapi/__init__.py b/tests/fastapi/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/fastapi/test_oauth2/__init__.py b/tests/fastapi/test_oauth2/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/fastapi/test_oauth2/models.py b/tests/fastapi/test_oauth2/models.py new file mode 100644 index 00000000..13e2783d --- /dev/null +++ b/tests/fastapi/test_oauth2/models.py @@ -0,0 +1,115 @@ +import time + +from authlib.integrations.sqla_oauth2 import (OAuth2AuthorizationCodeMixin, + OAuth2ClientMixin, + OAuth2TokenMixin) +from authlib.oidc.core import UserInfo +from sqlalchemy import (Boolean, Column, ForeignKey, Integer, String, + create_engine) +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import relationship, sessionmaker + +engine = create_engine( + "sqlite:///fastapi_auth2_sql.db", connect_args={"check_same_thread": False} +) + +SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + +Base = declarative_base() + +db = SessionLocal() + + +class User(Base): + __tablename__ = "user" + + id = Column(Integer, primary_key=True) + username = Column(String(40), unique=True, nullable=False) + + def get_user_id(self): + return self.id + + def check_password(self, password): + return password != "wrong" + + def generate_user_info(self, scopes): + profile = {"sub": str(self.id), "name": self.username} + return UserInfo(profile) + + +class Client(Base, OAuth2ClientMixin): + __tablename__ = "oauth2_client" + + id = Column(Integer, primary_key=True) + user_id = Column(Integer, ForeignKey("user.id", ondelete="CASCADE")) + user = relationship("User") + + +class AuthorizationCode(Base, OAuth2AuthorizationCodeMixin): + __tablename__ = "oauth2_code" + + id = Column(Integer, primary_key=True) + user_id = Column(Integer, nullable=False) + + @property + def user(self): + return db.query(User).filter(User.id == self.user_id).first() + + +class Token(Base, OAuth2TokenMixin): + __tablename__ = "oauth2_token" + + id = Column(Integer, primary_key=True) + user_id = Column(Integer, ForeignKey("user.id", ondelete="CASCADE")) + user = relationship("User") + revoked = Column(Boolean) + + def is_refresh_token_expired(self): + expired_at = self.issued_at + self.expires_in * 2 + return expired_at < time.time() + + +class CodeGrantMixin(object): + def query_authorization_code(self, code, client): + item = ( + db.query(AuthorizationCode) + .filter( + AuthorizationCode.code == code, Client.client_id == client.client_id + ) + .first() + ) + if item and not item.is_expired(): + return item + + def delete_authorization_code(self, authorization_code): + db.delete(authorization_code) + db.commit() + + def authenticate_user(self, authorization_code): + return db.query(User).filter(User.id == authorization_code.user_id).first() + + +def save_authorization_code(code, request): + client = request.client + auth_code = AuthorizationCode( + code=code, + client_id=client.client_id, + redirect_uri=request.redirect_uri, + scope=request.scope, + nonce=request.data.get("nonce"), + user_id=request.user.id, + code_challenge=request.data.get("code_challenge"), + code_challenge_method=request.data.get("code_challenge_method"), + ) + db.add(auth_code) + db.commit() + return auth_code + + +def exists_nonce(nonce, req): + exists = ( + db.query(AuthorizationCode) + .filter(Client.client_id == req.client_id, AuthorizationCode.nonce == nonce) + .first() + ) + return bool(exists) diff --git a/tests/fastapi/test_oauth2/oauth2_server.py b/tests/fastapi/test_oauth2/oauth2_server.py new file mode 100644 index 00000000..d5a70760 --- /dev/null +++ b/tests/fastapi/test_oauth2/oauth2_server.py @@ -0,0 +1,189 @@ +import base64 +import os +import unittest + +from authlib.common.encoding import to_bytes, to_unicode +from authlib.common.security import generate_token +from authlib.common.urls import url_encode +from authlib.integrations.fastapi_oauth2 import AuthorizationServer +from authlib.integrations.sqla_oauth2 import (create_query_client_func, + create_save_token_func) +from authlib.oauth2 import OAuth2Error +from fastapi import FastAPI, Form, Request +from fastapi.testclient import TestClient + +from .models import Base, Client, Token, User, db, engine + + +def token_generator(client, grant_type, user=None, scope=None): + token = "{}-{}".format(client.client_id[0], grant_type) + if user: + token = "{}.{}".format(token, user.get_user_id()) + return "{}.{}".format(token, generate_token(32)) + + +def create_authorization_server(app): + query_client = create_query_client_func(db, Client) + save_token = create_save_token_func(db, Token) + + server = AuthorizationServer() + server.init_app(app, query_client, save_token) + + @app.get("/oauth/authorize") + def authorize_get(request: Request): + user_id = request.query_params.get("user_id") + request.body = {} + if user_id: + end_user = db.query(User).filter(User.id == int(user_id)).first() + else: + end_user = None + try: + grant = server.validate_consent_request(request=request, end_user=end_user) + return grant.prompt or "ok" + except OAuth2Error as error: + return url_encode(error.get_body()) + + @app.post("/oauth/authorize") + def authorize_post( + request: Request, + response_type: str = Form(None), + client_id: str = Form(None), + state: str = Form(None), + scope: str = Form(None), + nonce: str = Form(None), + redirect_uri: str = Form(None), + response_mode: str = Form(None), + user_id: str = Form(None), + ): + if not user_id: + user_id = request.query_params.get("user_id") + + request.body = {"user_id": user_id} + + if response_type: + request.body.update({"response_type": response_type}) + + if client_id: + request.body.update({"client_id": client_id}) + + if state: + request.body.update({"state": state}) + + if nonce: + request.body.update({"nonce": nonce}) + + if scope: + request.body.update({"scope": scope}) + + if redirect_uri: + request.body.update({"redirect_uri": redirect_uri}) + + if response_mode: + request.body.update({"response_mode": response_mode}) + + if user_id: + grant_user = db.query(User).filter(User.id == int(user_id)).first() + else: + grant_user = None + + return server.create_authorization_response( + request=request, grant_user=grant_user + ) + + @app.api_route("/oauth/token", methods=["GET", "POST"]) + def issue_token( + request: Request, + grant_type: str = Form(None), + scope: str = Form(None), + code: str = Form(None), + refresh_token: str = Form(None), + code_verifier: str = Form(None), + client_id: str = Form(None), + client_secret: str = Form(None), + device_code: str = Form(None), + client_assertion_type: str = Form(None), + client_assertion: str = Form(None), + assertion: str = Form(None), + username: str = Form(None), + password: str = Form(None), + redirect_uri: str = Form(None), + ): + request.body = { + "grant_type": grant_type, + "scope": scope, + } + + if not grant_type: + grant_type = request.query_params.get("grant_type") + request.body.update({"grant_type": grant_type}) + + if grant_type == "authorization_code": + request.body.update({"code": code}) + elif grant_type == "refresh_token": + request.body.update({"refresh_token": refresh_token}) + + if code_verifier: + request.body.update({"code_verifier": code_verifier}) + + if client_id: + request.body.update({"client_id": client_id}) + + if client_secret: + request.body.update({"client_secret": client_secret}) + + if device_code: + request.body.update({"device_code": device_code}) + + if client_assertion_type: + request.body.update({"client_assertion_type": client_assertion_type}) + + if client_assertion: + request.body.update({"client_assertion": client_assertion}) + + if assertion: + request.body.update({"assertion": assertion}) + + if redirect_uri: + request.body.update({"redirect_uri": redirect_uri}) + + if username: + request.body.update({"username": username}) + + if password: + request.body.update({"password": password}) + + return server.create_token_response(request=request) + + return server + + +def create_fastapi_app(): + app = FastAPI() + app.debug = True + app.testing = True + app.secret_key = "testing" + app.test_client = TestClient(app) + app.config = { + "OAUTH2_ERROR_URIS": [("invalid_client", "https://a.b/e#invalid_client")] + } + return app + + +class TestCase(unittest.TestCase): + def setUp(self): + os.environ["AUTHLIB_INSECURE_TRANSPORT"] = "true" + app = create_fastapi_app() + + Base.metadata.create_all(bind=engine) + + self.app = app + self.client = app.test_client + + def tearDown(self): + Base.metadata.drop_all(bind=engine) + os.environ.pop("AUTHLIB_INSECURE_TRANSPORT") + + def create_basic_header(self, username, password): + text = "{}:{}".format(username, password) + auth = to_unicode(base64.b64encode(to_bytes(text))) + return {"Authorization": "Basic " + auth} diff --git a/tests/fastapi/test_oauth2/test_authorization_code_grant.py b/tests/fastapi/test_oauth2/test_authorization_code_grant.py new file mode 100644 index 00000000..d244e4fd --- /dev/null +++ b/tests/fastapi/test_oauth2/test_authorization_code_grant.py @@ -0,0 +1,288 @@ +from authlib.common.urls import url_decode, urlparse +from authlib.oauth2.rfc6749.grants import \ + AuthorizationCodeGrant as _AuthorizationCodeGrant + +from .models import (AuthorizationCode, Client, CodeGrantMixin, User, db, + save_authorization_code) +from .oauth2_server import TestCase, create_authorization_server + + +class AuthorizationCodeGrant(CodeGrantMixin, _AuthorizationCodeGrant): + TOKEN_ENDPOINT_AUTH_METHODS = ["client_secret_basic", "client_secret_post", "none"] + + def save_authorization_code(self, code, request): + return save_authorization_code(code, request) + + +class AuthorizationCodeTest(TestCase): + def register_grant(self, server): + server.register_grant(AuthorizationCodeGrant) + + def prepare_data( + self, + is_confidential=True, + response_type="code", + grant_type="authorization_code", + token_endpoint_auth_method="client_secret_basic", + ): + server = create_authorization_server(self.app) + self.register_grant(server) + self.server = server + + user = User(username="foo") + db.add(user) + db.commit() + + if is_confidential: + client_secret = "code-secret" + else: + client_secret = "" + client = Client( + user_id=user.id, + client_id="code-client", + client_secret=client_secret, + ) + client.set_client_metadata( + { + "redirect_uris": ["https://a.b"], + "scope": "profile address", + "token_endpoint_auth_method": token_endpoint_auth_method, + "response_types": [response_type], + "grant_types": grant_type.splitlines(), + } + ) + self.authorize_url = ( + "/oauth/authorize?response_type=code" "&client_id=code-client" + ) + db.add(client) + db.commit() + + def test_get_authorize(self): + self.prepare_data() + rv = self.client.get(self.authorize_url) + self.assertEqual(rv.json(), "ok") + + def test_invalid_client_id(self): + self.prepare_data() + url = "/oauth/authorize?response_type=code" + rv = self.client.get(url) + self.assertIn("invalid_client", rv.json()) + + url = "/oauth/authorize?response_type=code&client_id=invalid" + rv = self.client.get(url) + self.assertIn("invalid_client", rv.json()) + + def test_invalid_authorize(self): + self.prepare_data() + rv = self.client.post(self.authorize_url) + self.assertIn("error=access_denied", rv.headers["location"]) + + self.server.scopes_supported = ["profile"] + rv = self.client.post(self.authorize_url + "&scope=invalid&state=foo") + self.assertIn("error=invalid_scope", rv.headers["location"]) + self.assertIn("state=foo", rv.headers["location"]) + + def test_unauthorized_client(self): + self.prepare_data(True, "token") + rv = self.client.get(self.authorize_url) + self.assertIn("unauthorized_client", rv.json()) + + def test_invalid_client(self): + self.prepare_data() + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": "authorization_code", + "code": "invalid", + "client_id": "invalid-id", + }, + ) + resp = rv.json() + self.assertEqual(resp["error"], "invalid_client") + + headers = self.create_basic_header("code-client", "invalid-secret") + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": "authorization_code", + "code": "invalid", + }, + headers=headers, + ) + resp = rv.json() + self.assertEqual(resp["error"], "invalid_client") + self.assertEqual(resp["error_uri"], "https://a.b/e#invalid_client") + + def test_invalid_code(self): + self.prepare_data() + + headers = self.create_basic_header("code-client", "code-secret") + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": "authorization_code", + }, + headers=headers, + ) + resp = rv.json() + self.assertEqual(resp["error"], "invalid_request") + + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": "authorization_code", + "code": "invalid", + }, + headers=headers, + ) + resp = rv.json() + self.assertEqual(resp["error"], "invalid_request") + + code = AuthorizationCode(code="no-user", client_id="code-client", user_id=0) + db.add(code) + db.commit() + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": "authorization_code", + "code": "no-user", + }, + headers=headers, + ) + resp = rv.json() + self.assertEqual(resp["error"], "invalid_request") + + def test_invalid_redirect_uri(self): + self.prepare_data() + uri = self.authorize_url + "&redirect_uri=https%3A%2F%2Fa.c" + rv = self.client.post(uri, data={"user_id": "1"}) + resp = rv.json() + self.assertEqual(resp["error"], "invalid_request") + + uri = self.authorize_url + "&redirect_uri=https%3A%2F%2Fa.b" + rv = self.client.post(uri, data={"user_id": "1"}) + self.assertIn("code=", rv.headers["location"]) + + params = dict(url_decode(urlparse.urlparse(rv.headers["location"]).query)) + code = params["code"] + headers = self.create_basic_header("code-client", "code-secret") + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": "authorization_code", + "code": code, + }, + headers=headers, + ) + resp = rv.json() + self.assertEqual(resp["error"], "invalid_request") + + def test_invalid_grant_type(self): + self.prepare_data( + False, token_endpoint_auth_method="none", grant_type="invalid" + ) + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": "authorization_code", + "client_id": "code-client", + "code": "a", + }, + ) + resp = rv.json() + self.assertEqual(resp["error"], "unauthorized_client") + + def test_authorize_token_no_refresh_token(self): + self.app.config.update({"OAUTH2_REFRESH_TOKEN_GENERATOR": True}) + self.prepare_data(False, token_endpoint_auth_method="none") + + rv = self.client.post(self.authorize_url, data={"user_id": "1"}) + self.assertIn("code=", rv.headers["location"]) + + params = dict(url_decode(urlparse.urlparse(rv.headers["location"]).query)) + code = params["code"] + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": "authorization_code", + "code": code, + "client_id": "code-client", + }, + ) + resp = rv.json() + self.assertIn("access_token", resp) + self.assertNotIn("refresh_token", resp) + + def test_authorize_token_has_refresh_token(self): + # generate refresh token + self.app.config.update({"OAUTH2_REFRESH_TOKEN_GENERATOR": True}) + self.prepare_data(grant_type="authorization_code\nrefresh_token") + url = self.authorize_url + "&state=bar" + rv = self.client.post(url, data={"user_id": "1"}) + self.assertIn("code=", rv.headers["location"]) + + params = dict(url_decode(urlparse.urlparse(rv.headers["location"]).query)) + self.assertEqual(params["state"], "bar") + + code = params["code"] + headers = self.create_basic_header("code-client", "code-secret") + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": "authorization_code", + "code": code, + }, + headers=headers, + ) + resp = rv.json() + self.assertIn("access_token", resp) + self.assertIn("refresh_token", resp) + + def test_client_secret_post(self): + self.app.config.update({"OAUTH2_REFRESH_TOKEN_GENERATOR": True}) + self.prepare_data( + grant_type="authorization_code\nrefresh_token", + token_endpoint_auth_method="client_secret_post", + ) + url = self.authorize_url + "&state=bar" + rv = self.client.post(url, data={"user_id": "1"}) + self.assertIn("code=", rv.headers["location"]) + + params = dict(url_decode(urlparse.urlparse(rv.headers["location"]).query)) + self.assertEqual(params["state"], "bar") + + code = params["code"] + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": "authorization_code", + "client_id": "code-client", + "client_secret": "code-secret", + "code": code, + }, + ) + resp = rv.json() + self.assertIn("access_token", resp) + self.assertIn("refresh_token", resp) + + def test_token_generator(self): + m = "tests.fastapi.test_oauth2.oauth2_server:token_generator" + self.app.config.update({"OAUTH2_ACCESS_TOKEN_GENERATOR": m}) + self.prepare_data(False, token_endpoint_auth_method="none") + + rv = self.client.post(self.authorize_url, data={"user_id": "1"}) + self.assertIn("code=", rv.headers["location"]) + + params = dict(url_decode(urlparse.urlparse(rv.headers["location"]).query)) + code = params["code"] + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": "authorization_code", + "code": code, + "client_id": "code-client", + }, + ) + resp = rv.json() + self.assertIn("access_token", resp) + self.assertIn("c-authorization_code.1.", resp["access_token"]) diff --git a/tests/fastapi/test_oauth2/test_client_credentials_grant.py b/tests/fastapi/test_oauth2/test_client_credentials_grant.py new file mode 100644 index 00000000..29470d78 --- /dev/null +++ b/tests/fastapi/test_oauth2/test_client_credentials_grant.py @@ -0,0 +1,109 @@ +from authlib.oauth2.rfc6749.grants import ClientCredentialsGrant + +from .models import Client, User, db +from .oauth2_server import TestCase, create_authorization_server + + +class ClientCredentialsTest(TestCase): + def prepare_data(self, grant_type="client_credentials"): + server = create_authorization_server(self.app) + server.register_grant(ClientCredentialsGrant) + self.server = server + + user = User(username="foo") + db.add(user) + db.commit() + client = Client( + user_id=user.id, + client_id="credential-client", + client_secret="credential-secret", + ) + client.set_client_metadata( + { + "scope": "profile", + "redirect_uris": ["http://localhost/authorized"], + "grant_types": [grant_type], + } + ) + db.add(client) + db.commit() + + def test_invalid_client(self): + self.prepare_data() + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": "client_credentials", + }, + ) + resp = rv.json() + self.assertEqual(resp["error"], "invalid_client") + + headers = self.create_basic_header("credential-client", "invalid-secret") + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": "client_credentials", + }, + headers=headers, + ) + resp = rv.json() + self.assertEqual(resp["error"], "invalid_client") + + def test_invalid_grant_type(self): + self.prepare_data(grant_type="invalid") + headers = self.create_basic_header("credential-client", "credential-secret") + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": "client_credentials", + }, + headers=headers, + ) + resp = rv.json() + self.assertEqual(resp["error"], "unauthorized_client") + + def test_invalid_scope(self): + self.prepare_data() + self.server.scopes_supported = ["profile"] + headers = self.create_basic_header("credential-client", "credential-secret") + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": "client_credentials", + "scope": "invalid", + }, + headers=headers, + ) + resp = rv.json() + self.assertEqual(resp["error"], "invalid_scope") + + def test_authorize_token(self): + self.prepare_data() + headers = self.create_basic_header("credential-client", "credential-secret") + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": "client_credentials", + }, + headers=headers, + ) + resp = rv.json() + self.assertIn("access_token", resp) + + def test_token_generator(self): + m = "tests.fastapi.test_oauth2.oauth2_server:token_generator" + self.app.config.update({"OAUTH2_ACCESS_TOKEN_GENERATOR": m}) + + self.prepare_data() + headers = self.create_basic_header("credential-client", "credential-secret") + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": "client_credentials", + }, + headers=headers, + ) + resp = rv.json() + self.assertIn("access_token", resp) + self.assertIn("c-client_credentials.", resp["access_token"]) diff --git a/tests/fastapi/test_oauth2/test_client_registration_endpoint.py b/tests/fastapi/test_oauth2/test_client_registration_endpoint.py new file mode 100644 index 00000000..a0a4c936 --- /dev/null +++ b/tests/fastapi/test_oauth2/test_client_registration_endpoint.py @@ -0,0 +1,199 @@ +from authlib.jose import jwt +from authlib.oauth2.rfc7591 import \ + ClientRegistrationEndpoint as _ClientRegistrationEndpoint +from fastapi import Request +from pydantic import BaseModel +from tests.util import read_file_path + +from .models import Client, User, db +from .oauth2_server import TestCase, create_authorization_server + + +class ClientRegistrationEndpoint(_ClientRegistrationEndpoint): + software_statement_alg_values_supported = ["RS256"] + + def authenticate_token(self, request): + auth_header = request.headers.get("Authorization") + if auth_header: + request.user_id = 1 + return auth_header + + def resolve_public_key(self, request): + return read_file_path("rsa_public.pem") + + def save_client(self, client_info, client_metadata, request): + client = Client(user_id=request.user_id, **client_info) + client.set_client_metadata(client_metadata) + db.add(client) + db.commit() + return client + + +class ClientRegistrationTest(TestCase): + def prepare_data(self, endpoint_cls=None, metadata=None): + app = self.app + server = create_authorization_server(app) + if metadata: + server.metadata = metadata + + if endpoint_cls: + server.register_endpoint(endpoint_cls) + else: + + class MyClientRegistration(ClientRegistrationEndpoint): + def get_server_metadata(self): + return metadata + + server.register_endpoint(MyClientRegistration) + + class Item(BaseModel): + client_name: str = None + client_uri: str = None + redirect_uri: str = None + scope: str = None + software_statement: str = None + token_endpoint_auth_method: str = None + grant_types: list = None + response_types: list = None + + @app.post("/create_client") + def create_client(request: Request, item: Item = None): + request.body = {} + if item: + request.body = { + "client_name": item.client_name, + "client_uri": item.client_uri, + "redirect_uri": item.redirect_uri, + "scope": item.scope, + "software_statement": item.software_statement, + "token_endpoint_auth_method": item.token_endpoint_auth_method, + "grant_types": item.grant_types, + "response_types": item.response_types, + } + return server.create_endpoint_response( + "client_registration", request=request + ) + + user = User(username="foo") + db.add(user) + db.commit() + + def test_access_denied(self): + self.prepare_data() + rv = self.client.post("/create_client") + resp = rv.json() + self.assertEqual(resp["error"], "access_denied") + + def test_invalid_request(self): + self.prepare_data() + headers = {"Authorization": "bearer abc"} + rv = self.client.post("/create_client", headers=headers) + resp = rv.json() + self.assertEqual(resp["error"], "invalid_request") + + def test_create_client(self): + self.prepare_data() + headers = {"Authorization": "bearer abc"} + body = {"client_name": "Authlib"} + rv = self.client.post("/create_client", json=body, headers=headers) + resp = rv.json() + self.assertIn("client_id", resp) + self.assertEqual(resp["client_name"], "Authlib") + + def test_software_statement(self): + payload = {"software_id": "uuid-123", "client_name": "Authlib"} + s = jwt.encode({"alg": "RS256"}, payload, read_file_path("rsa_private.pem")) + body = { + "software_statement": s.decode("utf-8"), + } + + self.prepare_data() + headers = {"Authorization": "bearer abc"} + rv = self.client.post("/create_client", json=body, headers=headers) + resp = rv.json() + self.assertIn("client_id", resp) + self.assertEqual(resp["client_name"], "Authlib") + + def test_no_public_key(self): + class ClientRegistrationEndpoint2(ClientRegistrationEndpoint): + def resolve_public_key(self, request): + return None + + payload = {"software_id": "uuid-123", "client_name": "Authlib"} + s = jwt.encode({"alg": "RS256"}, payload, read_file_path("rsa_private.pem")) + body = { + "software_statement": s.decode("utf-8"), + } + + self.prepare_data(ClientRegistrationEndpoint2) + headers = {"Authorization": "bearer abc"} + rv = self.client.post("/create_client", json=body, headers=headers) + resp = rv.json() + self.assertIn(resp["error"], "unapproved_software_statement") + + def test_scopes_supported(self): + metadata = {"scopes_supported": ["profile", "email"]} + self.prepare_data(metadata=metadata) + + headers = {"Authorization": "bearer abc"} + body = {"scope": "profile email", "client_name": "Authlib"} + rv = self.client.post("/create_client", json=body, headers=headers) + resp = rv.json() + self.assertIn("client_id", resp) + self.assertEqual(resp["client_name"], "Authlib") + + body = {"scope": "profile email address", "client_name": "Authlib"} + rv = self.client.post("/create_client", json=body, headers=headers) + resp = rv.json() + self.assertIn(resp["error"], "invalid_client_metadata") + + def test_response_types_supported(self): + metadata = {"response_types_supported": ["code"]} + self.prepare_data(metadata=metadata) + + headers = {"Authorization": "bearer abc"} + body = {"response_types": ["code"], "client_name": "Authlib"} + rv = self.client.post("/create_client", json=body, headers=headers) + resp = rv.json() + self.assertIn("client_id", resp) + self.assertEqual(resp["client_name"], "Authlib") + + body = {"response_types": ["code", "token"], "client_name": "Authlib"} + rv = self.client.post("/create_client", json=body, headers=headers) + resp = rv.json() + self.assertIn(resp["error"], "invalid_client_metadata") + + def test_grant_types_supported(self): + metadata = {"grant_types_supported": ["authorization_code", "password"]} + self.prepare_data(metadata=metadata) + + headers = {"Authorization": "bearer abc"} + body = {"grant_types": ["password"], "client_name": "Authlib"} + rv = self.client.post("/create_client", json=body, headers=headers) + resp = rv.json() + self.assertIn("client_id", resp) + self.assertEqual(resp["client_name"], "Authlib") + + body = {"grant_types": ["client_credentials"], "client_name": "Authlib"} + rv = self.client.post("/create_client", json=body, headers=headers) + resp = rv.json() + self.assertIn(resp["error"], "invalid_client_metadata") + + def test_token_endpoint_auth_methods_supported(self): + metadata = {"token_endpoint_auth_methods_supported": ["client_secret_basic"]} + self.prepare_data(metadata=metadata) + + headers = {"Authorization": "bearer abc"} + body = { + "token_endpoint_auth_method": "client_secret_basic", + "client_name": "Authlib", + } + rv = self.client.post("/create_client", json=body, headers=headers) + resp = rv.json() + self.assertIn("client_id", resp) + self.assertEqual(resp["client_name"], "Authlib") + + body = {"token_endpoint_auth_method": "none", "client_name": "Authlib"} + rv = self.client.post("/create_client", json=body, headers=headers) + resp = rv.json() + self.assertIn(resp["error"], "invalid_client_metadata") diff --git a/tests/fastapi/test_oauth2/test_code_challenge.py b/tests/fastapi/test_oauth2/test_code_challenge.py new file mode 100644 index 00000000..c9e8188c --- /dev/null +++ b/tests/fastapi/test_oauth2/test_code_challenge.py @@ -0,0 +1,224 @@ +from authlib.common.security import generate_token +from authlib.common.urls import urlparse, url_decode +from authlib.oauth2.rfc6749 import grants +from authlib.oauth2.rfc7636 import ( + CodeChallenge as _CodeChallenge, + create_s256_code_challenge, +) +from .models import db, User, Client +from .models import CodeGrantMixin, save_authorization_code +from .oauth2_server import TestCase +from .oauth2_server import create_authorization_server + + +class AuthorizationCodeGrant(CodeGrantMixin, grants.AuthorizationCodeGrant): + TOKEN_ENDPOINT_AUTH_METHODS = ['client_secret_basic', 'client_secret_post', 'none'] + + def save_authorization_code(self, code, request): + return save_authorization_code(code, request) + + +class CodeChallenge(_CodeChallenge): + SUPPORTED_CODE_CHALLENGE_METHOD = ['plain', 'S256', 'S128'] + + +class CodeChallengeTest(TestCase): + def prepare_data(self, token_endpoint_auth_method='none'): + server = create_authorization_server(self.app) + server.register_grant( + AuthorizationCodeGrant, + [CodeChallenge(required=True)] + ) + + user = User(username='foo') + db.add(user) + db.commit() + + client_secret = '' + if token_endpoint_auth_method != 'none': + client_secret = 'code-secret' + + client = Client( + user_id=user.id, + client_id='code-client', + client_secret=client_secret, + ) + client.set_client_metadata({ + 'redirect_uris': ['https://a.b'], + 'scope': 'profile address', + 'token_endpoint_auth_method': token_endpoint_auth_method, + 'response_types': ['code'], + 'grant_types': ['authorization_code'], + }) + self.authorize_url = ( + '/oauth/authorize?response_type=code' + '&client_id=code-client' + ) + db.add(client) + db.commit() + + def test_missing_code_challenge(self): + self.prepare_data() + rv = self.client.get(self.authorize_url + '&code_challenge_method=plain') + self.assertIn('Missing', rv.json()) + + def test_has_code_challenge(self): + self.prepare_data() + rv = self.client.get(self.authorize_url + '&code_challenge=abc') + self.assertEqual(rv.json(), 'ok') + + def test_invalid_code_challenge_method(self): + self.prepare_data() + suffix = '&code_challenge=abc&code_challenge_method=invalid' + rv = self.client.get(self.authorize_url + suffix) + self.assertIn('Unsupported', rv.json()) + + def test_supported_code_challenge_method(self): + self.prepare_data() + suffix = '&code_challenge=abc&code_challenge_method=plain' + rv = self.client.get(self.authorize_url + suffix) + self.assertEqual(rv.json(), 'ok') + + def test_trusted_client_without_code_challenge(self): + self.prepare_data('client_secret_basic') + rv = self.client.get(self.authorize_url) + self.assertEqual(rv.json(), 'ok') + + rv = self.client.post(self.authorize_url, data={'user_id': '1'}) + self.assertIn('code=', rv.headers['location']) + + params = dict(url_decode(urlparse.urlparse(rv.headers['location']).query)) + + code = params['code'] + headers = self.create_basic_header('code-client', 'code-secret') + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'authorization_code', + 'code': code, + }, headers=headers) + resp = rv.json() + self.assertIn('access_token', resp) + + def test_missing_code_verifier(self): + self.prepare_data() + url = self.authorize_url + '&code_challenge=foo' + rv = self.client.post(url, data={'user_id': '1'}) + self.assertIn('code=', rv.headers['location']) + + params = dict(url_decode(urlparse.urlparse(rv.headers['location']).query)) + code = params['code'] + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'authorization_code', + 'code': code, + 'client_id': 'code-client', + }) + resp = rv.json() + self.assertIn('Missing', resp['error_description']) + + def test_trusted_client_missing_code_verifier(self): + self.prepare_data('client_secret_basic') + url = self.authorize_url + '&code_challenge=foo' + rv = self.client.post(url, data={'user_id': '1'}) + self.assertIn('code=', rv.headers['location']) + + params = dict(url_decode(urlparse.urlparse(rv.headers['location']).query)) + code = params['code'] + headers = self.create_basic_header('code-client', 'code-secret') + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'authorization_code', + 'code': code, + }, headers=headers) + resp = rv.json() + self.assertIn('Missing', resp['error_description']) + + def test_plain_code_challenge_invalid(self): + self.prepare_data() + url = self.authorize_url + '&code_challenge=foo' + rv = self.client.post(url, data={'user_id': '1'}) + self.assertIn('code=', rv.headers['location']) + + params = dict(url_decode(urlparse.urlparse(rv.headers['location']).query)) + code = params['code'] + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'authorization_code', + 'code': code, + 'code_verifier': 'bar', + 'client_id': 'code-client', + }) + resp = rv.json() + self.assertIn('Invalid', resp['error_description']) + + def test_plain_code_challenge_failed(self): + self.prepare_data() + url = self.authorize_url + '&code_challenge=foo' + rv = self.client.post(url, data={'user_id': '1'}) + self.assertIn('code=', rv.headers['location']) + + params = dict(url_decode(urlparse.urlparse(rv.headers['location']).query)) + code = params['code'] + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'authorization_code', + 'code': code, + 'code_verifier': generate_token(48), + 'client_id': 'code-client', + }) + resp = rv.json() + self.assertIn('failed', resp['error_description']) + + def test_plain_code_challenge_success(self): + self.prepare_data() + code_verifier = generate_token(48) + url = self.authorize_url + '&code_challenge=' + code_verifier + rv = self.client.post(url, data={'user_id': '1'}) + self.assertIn('code=', rv.headers['location']) + + params = dict(url_decode(urlparse.urlparse(rv.headers['location']).query)) + code = params['code'] + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'authorization_code', + 'code': code, + 'code_verifier': code_verifier, + 'client_id': 'code-client', + }) + resp = rv.json() + self.assertIn('access_token', resp) + + def test_s256_code_challenge_success(self): + self.prepare_data() + code_verifier = generate_token(48) + code_challenge = create_s256_code_challenge(code_verifier) + url = self.authorize_url + '&code_challenge=' + code_challenge + url += '&code_challenge_method=S256' + + rv = self.client.post(url, data={'user_id': '1'}) + self.assertIn('code=', rv.headers['location']) + + params = dict(url_decode(urlparse.urlparse(rv.headers['location']).query)) + code = params['code'] + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'authorization_code', + 'code': code, + 'code_verifier': code_verifier, + 'client_id': 'code-client', + }) + resp = rv.json() + self.assertIn('access_token', resp) + + def test_not_implemented_code_challenge_method(self): + self.prepare_data() + url = self.authorize_url + '&code_challenge=foo' + url += '&code_challenge_method=S128' + + rv = self.client.post(url, data={'user_id': '1'}) + self.assertIn('code=', rv.headers['location']) + + params = dict(url_decode(urlparse.urlparse(rv.headers['location']).query)) + code = params['code'] + self.assertRaises( + RuntimeError, self.client.post, '/oauth/token', + data={ + 'grant_type': 'authorization_code', + 'code': code, + 'code_verifier': generate_token(48), + 'client_id': 'code-client', + } + ) diff --git a/tests/fastapi/test_oauth2/test_device_code_grant.py b/tests/fastapi/test_oauth2/test_device_code_grant.py new file mode 100644 index 00000000..45729625 --- /dev/null +++ b/tests/fastapi/test_oauth2/test_device_code_grant.py @@ -0,0 +1,275 @@ +import time + +from authlib.oauth2.rfc8628 import \ + DeviceAuthorizationEndpoint as _DeviceAuthorizationEndpoint +from authlib.oauth2.rfc8628 import DeviceCodeGrant as _DeviceCodeGrant +from authlib.oauth2.rfc8628 import DeviceCredentialDict +from fastapi import Form, Request + +from .models import Client, User, db +from .oauth2_server import TestCase, create_authorization_server + +device_credentials = { + "valid-device": { + "client_id": "client", + "expires_in": 1800, + "user_code": "code", + }, + "expired-token": { + "client_id": "client", + "expires_in": -100, + "user_code": "none", + }, + "invalid-client": { + "client_id": "invalid", + "expires_in": 1800, + "user_code": "none", + }, + "denied-code": { + "client_id": "client", + "expires_in": 1800, + "user_code": "denied", + }, + "grant-code": { + "client_id": "client", + "expires_in": 1800, + "user_code": "code", + }, + "pending-code": { + "client_id": "client", + "expires_in": 1800, + "user_code": "none", + }, +} + + +class DeviceCodeGrant(_DeviceCodeGrant): + def query_device_credential(self, device_code): + data = device_credentials.get(device_code) + if not data: + return None + + now = int(time.time()) + data["expires_at"] = now + data["expires_in"] + data["device_code"] = device_code + data["scope"] = "profile" + data["interval"] = 5 + data["verification_uri"] = "https://example.com/activate" + return DeviceCredentialDict(data) + + def query_user_grant(self, user_code): + if user_code == "code": + return db.query(User).filter(User.id == 1).first(), True + if user_code == "denied": + return db.query(User).filter(User.id == 1).first(), False + return None + + def should_slow_down(self, credential): + return False + + +class DeviceCodeGrantTest(TestCase): + def create_server(self): + server = create_authorization_server(self.app) + server.register_grant(DeviceCodeGrant) + self.server = server + return server + + def prepare_data(self, grant_type=DeviceCodeGrant.GRANT_TYPE): + user = User(username="foo") + db.add(user) + db.commit() + client = Client( + user_id=user.id, + client_id="client", + client_secret="secret", + ) + client.set_client_metadata( + { + "redirect_uris": ["http://localhost/authorized"], + "scope": "profile", + "grant_types": [grant_type], + "token_endpoint_auth_method": "none", + } + ) + db.add(client) + db.commit() + + def test_invalid_request(self): + self.create_server() + self.prepare_data() + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": DeviceCodeGrant.GRANT_TYPE, + "client_id": "client", + }, + ) + resp = rv.json() + self.assertEqual(resp["error"], "invalid_request") + + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": DeviceCodeGrant.GRANT_TYPE, + "device_code": "missing", + "client_id": "client", + }, + ) + resp = rv.json() + self.assertEqual(resp["error"], "invalid_request") + + def test_unauthorized_client(self): + self.create_server() + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": DeviceCodeGrant.GRANT_TYPE, + "device_code": "valid-device", + "client_id": "invalid", + }, + ) + resp = rv.json() + self.assertEqual(resp["error"], "invalid_client") + + self.prepare_data(grant_type="password") + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": DeviceCodeGrant.GRANT_TYPE, + "device_code": "valid-device", + "client_id": "client", + }, + ) + resp = rv.json() + self.assertEqual(resp["error"], "unauthorized_client") + + def test_invalid_client(self): + self.create_server() + self.prepare_data() + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": DeviceCodeGrant.GRANT_TYPE, + "device_code": "invalid-client", + "client_id": "invalid", + }, + ) + resp = rv.json() + self.assertEqual(resp["error"], "invalid_client") + + def test_expired_token(self): + self.create_server() + self.prepare_data() + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": DeviceCodeGrant.GRANT_TYPE, + "device_code": "expired-token", + "client_id": "client", + }, + ) + resp = rv.json() + self.assertEqual(resp["error"], "expired_token") + + def test_denied_by_user(self): + self.create_server() + self.prepare_data() + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": DeviceCodeGrant.GRANT_TYPE, + "device_code": "denied-code", + "client_id": "client", + }, + ) + resp = rv.json() + self.assertEqual(resp["error"], "access_denied") + + def test_authorization_pending(self): + self.create_server() + self.prepare_data() + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": DeviceCodeGrant.GRANT_TYPE, + "device_code": "pending-code", + "client_id": "client", + }, + ) + resp = rv.json() + self.assertEqual(resp["error"], "authorization_pending") + + def test_get_access_token(self): + self.create_server() + self.prepare_data() + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": DeviceCodeGrant.GRANT_TYPE, + "device_code": "grant-code", + "client_id": "client", + }, + ) + resp = rv.json() + self.assertIn("access_token", resp) + + +class DeviceAuthorizationEndpoint(_DeviceAuthorizationEndpoint): + def get_verification_uri(self): + return "https://example.com/activate" + + def save_device_credential(self, client_id, scope, data): + pass + + +class DeviceAuthorizationEndpointTest(TestCase): + def create_server(self): + server = create_authorization_server(self.app) + server.register_endpoint(DeviceAuthorizationEndpoint) + self.server = server + + @self.app.post("/device_authorize") + def device_authorize( + request: Request, scope: str = Form(None), client_id: str = Form(None) + ): + request.body = { + "scope": scope, + "client_id": client_id, + } + name = DeviceAuthorizationEndpoint.ENDPOINT_NAME + return server.create_endpoint_response(name, request=request) + + return server + + def test_missing_client_id(self): + self.create_server() + rv = self.client.post("/device_authorize", data={"scope": "profile"}) + self.assertEqual(rv.status_code, 401) + resp = rv.json() + self.assertEqual(resp["error"], "invalid_client") + + def test_create_authorization_response(self): + self.create_server() + client = Client( + user_id=1, + client_id="client", + client_secret="secret", + ) + db.add(client) + db.commit() + rv = self.client.post( + "/device_authorize", + data={ + "client_id": "client", + }, + ) + self.assertEqual(rv.status_code, 200) + resp = rv.json() + self.assertIn("device_code", resp) + self.assertIn("user_code", resp) + self.assertEqual(resp["verification_uri"], "https://example.com/activate") + self.assertEqual( + resp["verification_uri_complete"], + "https://example.com/activate?user_code=" + resp["user_code"], + ) diff --git a/tests/fastapi/test_oauth2/test_implicit_grant.py b/tests/fastapi/test_oauth2/test_implicit_grant.py new file mode 100644 index 00000000..b875cb6d --- /dev/null +++ b/tests/fastapi/test_oauth2/test_implicit_grant.py @@ -0,0 +1,83 @@ +from authlib.oauth2.rfc6749.grants import ImplicitGrant + +from .models import Client, User, db +from .oauth2_server import TestCase, create_authorization_server + + +class ImplicitTest(TestCase): + def prepare_data(self, is_confidential=False, response_type="token"): + server = create_authorization_server(self.app) + server.register_grant(ImplicitGrant) + self.server = server + + user = User(username="foo") + db.add(user) + db.commit() + if is_confidential: + client_secret = "implicit-secret" + token_endpoint_auth_method = "client_secret_basic" + else: + client_secret = "" + token_endpoint_auth_method = "none" + + client = Client( + user_id=user.id, + client_id="implicit-client", + client_secret=client_secret, + ) + client.set_client_metadata( + { + "redirect_uris": ["http://localhost/authorized"], + "scope": "profile", + "response_types": [response_type], + "grant_types": ["implicit"], + "token_endpoint_auth_method": token_endpoint_auth_method, + } + ) + self.authorize_url = ( + "/oauth/authorize?response_type=token" "&client_id=implicit-client" + ) + db.add(client) + db.commit() + + def test_get_authorize(self): + self.prepare_data() + rv = self.client.get(self.authorize_url) + self.assertEqual(rv.json(), "ok") + + def test_confidential_client(self): + self.prepare_data(True) + rv = self.client.get(self.authorize_url) + self.assertIn("invalid_client", rv.json()) + + def test_unsupported_client(self): + self.prepare_data(response_type="code") + rv = self.client.get(self.authorize_url) + self.assertIn("unauthorized_client", rv.json()) + + def test_invalid_authorize(self): + self.prepare_data() + rv = self.client.post(self.authorize_url) + self.assertIn("#error=access_denied", rv.headers["location"]) + + self.server.scopes_supported = ["profile"] + rv = self.client.post(self.authorize_url + "&scope=invalid") + self.assertIn("#error=invalid_scope", rv.headers["location"]) + + def test_authorize_token(self): + self.prepare_data() + rv = self.client.post(self.authorize_url, data={"user_id": "1"}) + self.assertIn("access_token=", rv.headers["location"]) + + url = self.authorize_url + "&state=bar&scope=profile" + rv = self.client.post(url, data={"user_id": "1"}) + self.assertIn("access_token=", rv.headers["location"]) + self.assertIn("state=bar", rv.headers["location"]) + self.assertIn("scope=profile", rv.headers["location"]) + + def test_token_generator(self): + m = "tests.fastapi.test_oauth2.oauth2_server:token_generator" + self.app.config.update({"OAUTH2_ACCESS_TOKEN_GENERATOR": m}) + self.prepare_data() + rv = self.client.post(self.authorize_url, data={"user_id": "1"}) + self.assertIn("access_token=i-implicit.1.", rv.headers["location"]) diff --git a/tests/fastapi/test_oauth2/test_introspection_endpoint.py b/tests/fastapi/test_oauth2/test_introspection_endpoint.py new file mode 100644 index 00000000..2a9c3cca --- /dev/null +++ b/tests/fastapi/test_oauth2/test_introspection_endpoint.py @@ -0,0 +1,184 @@ +from authlib.integrations.sqla_oauth2 import create_query_token_func +from authlib.oauth2.rfc7662 import IntrospectionEndpoint +from fastapi import Form, Request + +from .models import Client, Token, User, db +from .oauth2_server import TestCase, create_authorization_server + +query_token = create_query_token_func(db, Token) + + +class MyIntrospectionEndpoint(IntrospectionEndpoint): + def check_permission(self, token, client, request): + return True + + def query_token(self, token, token_type_hint): + return query_token(token, token_type_hint) + + def introspect_token(self, token): + user = db.query(User).filter(User.id == int(token.user_id)).first() + return { + "active": not token.revoked, + "client_id": token.client_id, + "username": user.username, + "scope": token.scope, + "sub": user.get_user_id(), + "aud": token.client_id, + "iss": "https://server.example.com/", + "exp": token.issued_at + token.expires_in, + "iat": token.issued_at, + } + + +class IntrospectTokenTest(TestCase): + def prepare_data(self): + app = self.app + + server = create_authorization_server(app) + server.register_endpoint(MyIntrospectionEndpoint) + + @app.post("/oauth/introspect") + def introspect_token( + request: Request, token: str = Form(None), token_type_hint: str = Form(None) + ): + request.body = {} + + if token: + request.body.update({"token": token}) + + if token_type_hint: + request.body.update({"token_type_hint": token_type_hint}) + + return server.create_endpoint_response("introspection", request=request) + + user = User(username="foo") + db.add(user) + db.commit() + client = Client( + user_id=user.id, + client_id="introspect-client", + client_secret="introspect-secret", + ) + client.set_client_metadata( + { + "scope": "profile", + "redirect_uris": ["http://a.b/c"], + } + ) + db.add(client) + db.commit() + + def create_token(self): + token = Token( + user_id=1, + client_id="introspect-client", + token_type="bearer", + access_token="a1", + refresh_token="r1", + scope="profile", + expires_in=3600, + revoked=False, + ) + db.add(token) + db.commit() + + def test_invalid_client(self): + self.prepare_data() + rv = self.client.post("/oauth/introspect") + resp = rv.json() + self.assertEqual(resp["error"], "invalid_client") + + headers = {"Authorization": "invalid token_string"} + rv = self.client.post("/oauth/introspect", headers=headers) + resp = rv.json() + self.assertEqual(resp["error"], "invalid_client") + + headers = self.create_basic_header("invalid-client", "introspect-secret") + rv = self.client.post("/oauth/introspect", headers=headers) + resp = rv.json() + self.assertEqual(resp["error"], "invalid_client") + + headers = self.create_basic_header("introspect-client", "invalid-secret") + rv = self.client.post("/oauth/introspect", headers=headers) + resp = rv.json() + self.assertEqual(resp["error"], "invalid_client") + + def test_invalid_token(self): + self.prepare_data() + headers = self.create_basic_header("introspect-client", "introspect-secret") + rv = self.client.post("/oauth/introspect", headers=headers) + resp = rv.json() + self.assertEqual(resp["error"], "invalid_request") + + rv = self.client.post( + "/oauth/introspect", + data={ + "token_type_hint": "refresh_token", + }, + headers=headers, + ) + resp = rv.json() + self.assertEqual(resp["error"], "invalid_request") + + rv = self.client.post( + "/oauth/introspect", + data={ + "token": "a1", + "token_type_hint": "unsupported_token_type", + }, + headers=headers, + ) + resp = rv.json() + self.assertEqual(resp["error"], "unsupported_token_type") + + rv = self.client.post( + "/oauth/introspect", + data={ + "token": "invalid-token", + }, + headers=headers, + ) + resp = rv.json() + self.assertEqual(resp["active"], False) + + rv = self.client.post( + "/oauth/introspect", + data={ + "token": "a1", + "token_type_hint": "refresh_token", + }, + headers=headers, + ) + resp = rv.json() + self.assertEqual(resp["active"], False) + + def test_introspect_token_with_hint(self): + self.prepare_data() + self.create_token() + headers = self.create_basic_header("introspect-client", "introspect-secret") + rv = self.client.post( + "/oauth/introspect", + data={ + "token": "a1", + "token_type_hint": "access_token", + }, + headers=headers, + ) + self.assertEqual(rv.status_code, 200) + resp = rv.json() + self.assertEqual(resp["client_id"], "introspect-client") + + def test_introspect_token_without_hint(self): + self.prepare_data() + self.create_token() + headers = self.create_basic_header("introspect-client", "introspect-secret") + rv = self.client.post( + "/oauth/introspect", + data={ + "token": "a1", + }, + headers=headers, + ) + self.assertEqual(rv.status_code, 200) + resp = rv.json() + self.assertEqual(resp["client_id"], "introspect-client") diff --git a/tests/fastapi/test_oauth2/test_jwt_bearer_client_auth.py b/tests/fastapi/test_oauth2/test_jwt_bearer_client_auth.py new file mode 100644 index 00000000..f038374f --- /dev/null +++ b/tests/fastapi/test_oauth2/test_jwt_bearer_client_auth.py @@ -0,0 +1,152 @@ +from authlib.oauth2.rfc6749.grants import ClientCredentialsGrant +from authlib.oauth2.rfc7523 import ( + JWTBearerClientAssertion, + client_secret_jwt_sign, + private_key_jwt_sign, +) +from tests.util import read_file_path +from .models import db, User, Client +from .oauth2_server import TestCase +from .oauth2_server import create_authorization_server + + +class JWTClientCredentialsGrant(ClientCredentialsGrant): + TOKEN_ENDPOINT_AUTH_METHODS = [ + JWTBearerClientAssertion.CLIENT_AUTH_METHOD, + ] + + +class JWTClientAuth(JWTBearerClientAssertion): + def validate_jti(self, claims, jti): + return True + + def resolve_client_public_key(self, client, headers): + if headers['alg'] == 'RS256': + return read_file_path('jwk_public.json') + return client.client_secret + + +class ClientCredentialsTest(TestCase): + def prepare_data(self, auth_method, validate_jti=True): + server = create_authorization_server(self.app) + server.register_grant(JWTClientCredentialsGrant) + server.register_client_auth_method( + JWTClientAuth.CLIENT_AUTH_METHOD, + JWTClientAuth('https://localhost/oauth/token', validate_jti) + ) + + user = User(username='foo') + db.add(user) + db.commit() + client = Client( + user_id=user.id, + client_id='credential-client', + client_secret='credential-secret', + ) + client.set_client_metadata({ + 'scope': 'profile', + 'redirect_uris': ['http://localhost/authorized'], + 'grant_types': ['client_credentials'], + 'token_endpoint_auth_method': auth_method, + }) + db.add(client) + db.commit() + + def test_invalid_client(self): + self.prepare_data(JWTBearerClientAssertion.CLIENT_AUTH_METHOD) + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'client_credentials', + 'client_assertion_type': JWTBearerClientAssertion.CLIENT_ASSERTION_TYPE + }) + resp = rv.json() + self.assertEqual(resp['error'], 'invalid_client') + + def test_invalid_jwt(self): + self.prepare_data(JWTBearerClientAssertion.CLIENT_AUTH_METHOD) + + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'client_credentials', + 'client_assertion_type': JWTBearerClientAssertion.CLIENT_ASSERTION_TYPE, + 'client_assertion': client_secret_jwt_sign( + client_secret='invalid-secret', + client_id='credential-client', + token_endpoint='https://localhost/oauth/token', + ) + }) + resp = rv.json() + self.assertEqual(resp['error'], 'invalid_client') + + def test_not_found_client(self): + self.prepare_data(JWTBearerClientAssertion.CLIENT_AUTH_METHOD) + + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'client_credentials', + 'client_assertion_type': JWTBearerClientAssertion.CLIENT_ASSERTION_TYPE, + 'client_assertion': client_secret_jwt_sign( + client_secret='credential-secret', + client_id='invalid-client', + token_endpoint='https://localhost/oauth/token', + ) + }) + resp = rv.json() + self.assertEqual(resp['error'], 'invalid_client') + + def test_not_supported_auth_method(self): + self.prepare_data('invalid') + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'client_credentials', + 'client_assertion_type': JWTBearerClientAssertion.CLIENT_ASSERTION_TYPE, + 'client_assertion': client_secret_jwt_sign( + client_secret='credential-secret', + client_id='credential-client', + token_endpoint='https://localhost/oauth/token', + ) + }) + resp = rv.json() + self.assertEqual(resp['error'], 'invalid_client') + + def test_client_secret_jwt(self): + self.prepare_data(JWTBearerClientAssertion.CLIENT_AUTH_METHOD) + + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'client_credentials', + 'client_assertion_type': JWTBearerClientAssertion.CLIENT_ASSERTION_TYPE, + 'client_assertion': client_secret_jwt_sign( + client_secret='credential-secret', + client_id='credential-client', + token_endpoint='https://localhost/oauth/token', + claims={'jti': 'nonce'}, + ) + }) + resp = rv.json() + self.assertIn('access_token', resp) + + def test_private_key_jwt(self): + self.prepare_data(JWTBearerClientAssertion.CLIENT_AUTH_METHOD) + + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'client_credentials', + 'client_assertion_type': JWTBearerClientAssertion.CLIENT_ASSERTION_TYPE, + 'client_assertion': private_key_jwt_sign( + private_key=read_file_path('jwk_private.json'), + client_id='credential-client', + token_endpoint='https://localhost/oauth/token', + ) + }) + resp = rv.json() + self.assertIn('access_token', resp) + + def test_not_validate_jti(self): + self.prepare_data(JWTBearerClientAssertion.CLIENT_AUTH_METHOD, False) + + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'client_credentials', + 'client_assertion_type': JWTBearerClientAssertion.CLIENT_ASSERTION_TYPE, + 'client_assertion': client_secret_jwt_sign( + client_secret='credential-secret', + client_id='credential-client', + token_endpoint='https://localhost/oauth/token', + ) + }) + resp = rv.json() + self.assertIn('access_token', resp) diff --git a/tests/fastapi/test_oauth2/test_jwt_bearer_grant.py b/tests/fastapi/test_oauth2/test_jwt_bearer_grant.py new file mode 100644 index 00000000..a3d80abf --- /dev/null +++ b/tests/fastapi/test_oauth2/test_jwt_bearer_grant.py @@ -0,0 +1,105 @@ +from authlib.oauth2.rfc7523 import JWTBearerGrant as _JWTBearerGrant +from .models import db, User, Client +from .oauth2_server import TestCase +from .oauth2_server import create_authorization_server + + +class JWTBearerGrant(_JWTBearerGrant): + def authenticate_user(self, client, claims): + return None + + def authenticate_client(self, claims): + iss = claims['iss'] + return db.query(Client).filter(Client.client_id == iss).first() + + def resolve_public_key(self, headers, payload): + keys = {'1': 'foo', '2': 'bar'} + return keys[headers['kid']] + + +class JWTBearerGrantTest(TestCase): + def prepare_data(self, grant_type=None): + server = create_authorization_server(self.app) + server.register_grant(JWTBearerGrant) + + user = User(username='foo') + db.add(user) + db.commit() + if grant_type is None: + grant_type = JWTBearerGrant.GRANT_TYPE + client = Client( + user_id=user.id, + client_id='jwt-client', + client_secret='jwt-secret', + ) + client.set_client_metadata({ + 'scope': 'profile', + 'redirect_uris': ['http://localhost/authorized'], + 'grant_types': [grant_type], + }) + db.add(client) + db.commit() + + def test_missing_assertion(self): + self.prepare_data() + rv = self.client.post('/oauth/token', data={ + 'grant_type': JWTBearerGrant.GRANT_TYPE + }) + resp = rv.json() + self.assertEqual(resp['error'], 'invalid_request') + self.assertIn('assertion', resp['error_description']) + + def test_invalid_assertion(self): + self.prepare_data() + assertion = JWTBearerGrant.sign( + 'foo', issuer='jwt-client', audience='https://i.b/token', + header={'alg': 'HS256', 'kid': '1'} + ) + rv = self.client.post('/oauth/token', data={ + 'grant_type': JWTBearerGrant.GRANT_TYPE, + 'assertion': assertion + }) + resp = rv.json() + self.assertEqual(resp['error'], 'invalid_grant') + + def test_authorize_token(self): + self.prepare_data() + assertion = JWTBearerGrant.sign( + 'foo', issuer='jwt-client', audience='https://i.b/token', + subject='self', header={'alg': 'HS256', 'kid': '1'} + ) + rv = self.client.post('/oauth/token', data={ + 'grant_type': JWTBearerGrant.GRANT_TYPE, + 'assertion': assertion + }) + resp = rv.json() + self.assertIn('access_token', resp) + + def test_unauthorized_client(self): + self.prepare_data('password') + assertion = JWTBearerGrant.sign( + 'bar', issuer='jwt-client', audience='https://i.b/token', + subject='self', header={'alg': 'HS256', 'kid': '2'} + ) + rv = self.client.post('/oauth/token', data={ + 'grant_type': JWTBearerGrant.GRANT_TYPE, + 'assertion': assertion + }) + resp = rv.json() + self.assertEqual(resp['error'], 'unauthorized_client') + + def test_token_generator(self): + m = 'tests.fastapi.test_oauth2.oauth2_server:token_generator' + self.app.config.update({'OAUTH2_ACCESS_TOKEN_GENERATOR': m}) + self.prepare_data() + assertion = JWTBearerGrant.sign( + 'foo', issuer='jwt-client', audience='https://i.b/token', + subject='self', header={'alg': 'HS256', 'kid': '1'} + ) + rv = self.client.post('/oauth/token', data={ + 'grant_type': JWTBearerGrant.GRANT_TYPE, + 'assertion': assertion + }) + resp = rv.json() + self.assertIn('access_token', resp) + self.assertIn('j-', resp['access_token']) diff --git a/tests/fastapi/test_oauth2/test_oauth2_server.py b/tests/fastapi/test_oauth2/test_oauth2_server.py new file mode 100644 index 00000000..e1b54534 --- /dev/null +++ b/tests/fastapi/test_oauth2/test_oauth2_server.py @@ -0,0 +1,178 @@ +from authlib.integrations.fastapi_oauth2 import ResourceProtector +from authlib.integrations.sqla_oauth2 import create_bearer_token_validator +from fastapi import Request + +from .models import Client, Token, User, db +from .oauth2_server import TestCase, create_authorization_server + +require_oauth = ResourceProtector() +BearerTokenValidator = create_bearer_token_validator(db, Token) +require_oauth.register_token_validator(BearerTokenValidator()) + + +def create_resource_server(app): + @app.get("/user") + @require_oauth(["profile"]) + def user_profile(request: Request): + user = request.state.token.user + return {"id": user.id, "username": user.username} + + @app.get("/user/email") + @require_oauth("email") + def user_email(request: Request): + pass + + @app.get("/info") + @require_oauth() + def public_info(request: Request): + return {"status": "ok"} + + @app.get("/operator-and") + @require_oauth(["profile email"]) + def operator_and(request: Request): + return {"status": "ok"} + + @app.get("/operator-or") + @require_oauth(["profile", "email"]) + def operator_or(request: Request): + return {"status": "ok"} + + @app.get("/acquire") + def test_acquire(request: Request): + with require_oauth.acquire(request, ["profile"]) as token: + user = token.user + return {"id": user.id, "username": user.username} + + +class AuthorizationTest(TestCase): + def test_none_grant(self): + create_authorization_server(self.app) + authorize_url = ( + "/oauth/authorize?response_type=token" "&client_id=implicit-client" + ) + rv = self.client.get(authorize_url) + self.assertIn("unsupported_response_type", rv.text) + + rv = self.client.post(authorize_url, data={"user_id": "1"}) + self.assertNotEqual(rv.status_code, 200) + + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": "authorization_code", + "code": "x", + }, + ) + data = rv.json() + self.assertEqual(data["error"], "unsupported_grant_type") + + +class ResourceTest(TestCase): + def prepare_data(self): + create_resource_server(self.app) + + user = User(username="foo") + db.add(user) + db.commit() + client = Client( + user_id=user.id, + client_id="resource-client", + client_secret="resource-secret", + ) + client.set_client_metadata( + { + "scope": "profile", + "redirect_uris": ["http://localhost/authorized"], + } + ) + db.add(client) + db.commit() + + def create_token(self, expires_in=3600): + token = Token( + user_id=1, + client_id="resource-client", + token_type="bearer", + access_token="a1", + scope="profile", + expires_in=expires_in, + ) + db.add(token) + db.commit() + + def create_bearer_header(self, token): + return {"Authorization": "Bearer " + token} + + def test_invalid_token(self): + self.prepare_data() + + rv = self.client.get("/user") + self.assertEqual(rv.status_code, 401) + resp = rv.json() + self.assertEqual(resp["detail"]["error"], "missing_authorization") + + headers = {"Authorization": "invalid token"} + rv = self.client.get("/user", headers=headers) + self.assertEqual(rv.status_code, 401) + resp = rv.json() + self.assertEqual(resp["detail"]["error"], "unsupported_token_type") + + headers = self.create_bearer_header("invalid") + rv = self.client.get("/user", headers=headers) + self.assertEqual(rv.status_code, 401) + resp = rv.json() + self.assertEqual(resp["detail"]["error"], "invalid_token") + + def test_expired_token(self): + self.prepare_data() + self.create_token(-10) + headers = self.create_bearer_header("a1") + + rv = self.client.get("/user", headers=headers) + self.assertEqual(rv.status_code, 401) + resp = rv.json() + self.assertEqual(resp["detail"]["error"], "invalid_token") + + rv = self.client.get("/acquire", headers=headers) + self.assertEqual(rv.status_code, 401) + + def test_insufficient_token(self): + self.prepare_data() + self.create_token() + headers = self.create_bearer_header("a1") + rv = self.client.get("/user/email", headers=headers) + self.assertEqual(rv.status_code, 403) + resp = rv.json() + self.assertEqual(resp["detail"]["error"], "insufficient_scope") + + def test_access_resource(self): + self.prepare_data() + self.create_token() + headers = self.create_bearer_header("a1") + + rv = self.client.get("/user", headers=headers) + resp = rv.json() + self.assertEqual(rv.status_code, 200) + self.assertEqual(resp["username"], "foo") + + rv = self.client.get("/acquire", headers=headers) + resp = rv.json() + self.assertEqual(rv.status_code, 200) + self.assertEqual(resp["username"], "foo") + + rv = self.client.get("/info", headers=headers) + resp = rv.json() + self.assertEqual(rv.status_code, 200) + self.assertEqual(resp["status"], "ok") + + def test_scope_operator(self): + self.prepare_data() + self.create_token() + headers = self.create_bearer_header("a1") + rv = self.client.get("/operator-and", headers=headers) + self.assertEqual(rv.status_code, 403) + resp = rv.json() + self.assertEqual(resp["detail"]["error"], "insufficient_scope") + + rv = self.client.get("/operator-or", headers=headers) + self.assertEqual(rv.status_code, 200) diff --git a/tests/fastapi/test_oauth2/test_openid_code_grant.py b/tests/fastapi/test_oauth2/test_openid_code_grant.py new file mode 100644 index 00000000..ed230299 --- /dev/null +++ b/tests/fastapi/test_oauth2/test_openid_code_grant.py @@ -0,0 +1,274 @@ +import json +from authlib.common.encoding import to_unicode +from authlib.common.urls import urlparse, url_decode, url_encode +from authlib.jose import JsonWebToken, JsonWebKey +from authlib.oidc.core import CodeIDToken +from authlib.oidc.core.grants import OpenIDCode as _OpenIDCode +from authlib.oauth2.rfc6749.grants import ( + AuthorizationCodeGrant as _AuthorizationCodeGrant, +) +from tests.util import get_file_path +from .models import db, User, Client, exists_nonce +from .models import CodeGrantMixin, save_authorization_code +from .oauth2_server import TestCase +from .oauth2_server import create_authorization_server + +DUMMY_JWT_CONFIG = { + 'key': 'secret', + 'alg': 'HS256', + 'iss': 'Authlib', + 'exp': 3600, +} + + +class AuthorizationCodeGrant(CodeGrantMixin, _AuthorizationCodeGrant): + def save_authorization_code(self, code, request): + return save_authorization_code(code, request) + + +class OpenIDCode(_OpenIDCode): + def get_jwt_config(self, grant): + return DUMMY_JWT_CONFIG + + def exists_nonce(self, nonce, request): + return exists_nonce(nonce, request) + + def generate_user_info(self, user, scopes): + return user.generate_user_info(scopes) + + +class BaseTestCase(TestCase): + def config_app(self): + DUMMY_JWT_CONFIG.update({ + 'iss': 'Authlib', + 'key': 'secret', + 'alg': 'HS256', + }) + + def prepare_data(self): + self.config_app() + server = create_authorization_server(self.app) + server.register_grant(AuthorizationCodeGrant, [OpenIDCode()]) + + user = User(username='foo') + db.add(user) + db.commit() + + client = Client( + user_id=user.id, + client_id='code-client', + client_secret='code-secret', + ) + client.set_client_metadata({ + 'redirect_uris': ['https://a.b'], + 'scope': 'openid profile address', + 'response_types': ['code'], + 'grant_types': ['authorization_code'], + }) + db.add(client) + db.commit() + + +class OpenIDCodeTest(BaseTestCase): + def test_authorize_token(self): + self.prepare_data() + rv = self.client.post('/oauth/authorize', data={ + 'response_type': 'code', + 'client_id': 'code-client', + 'state': 'bar', + 'scope': 'openid profile', + 'redirect_uri': 'https://a.b', + 'user_id': '1' + }) + self.assertIn('code=', rv.headers['location']) + + params = dict(url_decode(urlparse.urlparse(rv.headers['location']).query)) + self.assertEqual(params['state'], 'bar') + + code = params['code'] + headers = self.create_basic_header('code-client', 'code-secret') + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'authorization_code', + 'redirect_uri': 'https://a.b', + 'code': code, + }, headers=headers) + resp = rv.json() + self.assertIn('access_token', resp) + self.assertIn('id_token', resp) + + jwt = JsonWebToken() + claims = jwt.decode( + resp['id_token'], 'secret', + claims_cls=CodeIDToken, + claims_options={'iss': {'value': 'Authlib'}} + ) + claims.validate() + + def test_pure_code_flow(self): + self.prepare_data() + rv = self.client.post('/oauth/authorize', data={ + 'response_type': 'code', + 'client_id': 'code-client', + 'state': 'bar', + 'scope': 'profile', + 'redirect_uri': 'https://a.b', + 'user_id': '1' + }) + self.assertIn('code=', rv.headers['location']) + + params = dict(url_decode(urlparse.urlparse(rv.headers['location']).query)) + self.assertEqual(params['state'], 'bar') + + code = params['code'] + headers = self.create_basic_header('code-client', 'code-secret') + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'authorization_code', + 'redirect_uri': 'https://a.b', + 'code': code, + }, headers=headers) + resp = rv.json() + self.assertIn('access_token', resp) + self.assertNotIn('id_token', resp) + + def test_nonce_replay(self): + self.prepare_data() + data = { + 'response_type': 'code', + 'client_id': 'code-client', + 'user_id': '1', + 'state': 'bar', + 'nonce': 'abc', + 'scope': 'openid profile', + 'redirect_uri': 'https://a.b' + } + rv = self.client.post('/oauth/authorize', data=data) + self.assertIn('code=', rv.headers['location']) + + rv = self.client.post('/oauth/authorize', data=data) + self.assertIn('error=', rv.headers['location']) + + def test_prompt(self): + self.prepare_data() + params = [ + ('response_type', 'code'), + ('client_id', 'code-client'), + ('state', 'bar'), + ('nonce', 'abc'), + ('scope', 'openid profile'), + ('redirect_uri', 'https://a.b') + ] + query = url_encode(params) + rv = self.client.get('/oauth/authorize?' + query) + self.assertEqual(rv.json(), 'login') + + query = url_encode(params + [('user_id', '1')]) + rv = self.client.get('/oauth/authorize?' + query) + self.assertEqual(rv.json(), 'ok') + + query = url_encode(params + [('prompt', 'login')]) + rv = self.client.get('/oauth/authorize?' + query) + self.assertEqual(rv.json(), 'login') + + +class RSAOpenIDCodeTest(BaseTestCase): + def config_app(self): + jwt_key_path = get_file_path('jwk_private.json') + with open(jwt_key_path, 'r') as f: + jwt_key = json.load(f) + + DUMMY_JWT_CONFIG.update({ + 'iss': 'Authlib', + 'key': jwt_key, + 'alg': 'RS256', + }) + + def get_validate_key(self): + with open(get_file_path('jwk_public.json'), 'r') as f: + return json.load(f) + + def test_authorize_token(self): + # generate refresh token + self.prepare_data() + rv = self.client.post('/oauth/authorize', data={ + 'response_type': 'code', + 'client_id': 'code-client', + 'state': 'bar', + 'scope': 'openid profile', + 'redirect_uri': 'https://a.b', + 'user_id': '1' + }) + self.assertIn('code=', rv.headers['location']) + + params = dict(url_decode(urlparse.urlparse(rv.headers['location']).query)) + self.assertEqual(params['state'], 'bar') + + code = params['code'] + headers = self.create_basic_header('code-client', 'code-secret') + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'authorization_code', + 'redirect_uri': 'https://a.b', + 'code': code, + }, headers=headers) + resp = rv.json() + self.assertIn('access_token', resp) + self.assertIn('id_token', resp) + + jwt = JsonWebToken() + claims = jwt.decode( + resp['id_token'], + self.get_validate_key(), + claims_cls=CodeIDToken, + claims_options={'iss': {'value': 'Authlib'}} + ) + claims.validate() + + +class JWKSOpenIDCodeTest(RSAOpenIDCodeTest): + def config_app(self): + jwt_key_path = get_file_path('jwks_private.json') + with open(jwt_key_path, 'r') as f: + jwt_key = json.load(f) + + DUMMY_JWT_CONFIG.update({ + 'iss': 'Authlib', + 'key': jwt_key, + 'alg': 'PS256', + }) + + def get_validate_key(self): + with open(get_file_path('jwks_public.json'), 'r') as f: + return JsonWebKey.import_key_set(json.load(f)) + + +class ECOpenIDCodeTest(RSAOpenIDCodeTest): + def config_app(self): + jwt_key_path = get_file_path('secp521r1-private.json') + with open(jwt_key_path, 'r') as f: + jwt_key = json.load(f) + + DUMMY_JWT_CONFIG.update({ + 'iss': 'Authlib', + 'key': jwt_key, + 'alg': 'ES512', + }) + + def get_validate_key(self): + with open(get_file_path('secp521r1-public.json'), 'r') as f: + return json.load(f) + + +class PEMOpenIDCodeTest(RSAOpenIDCodeTest): + def config_app(self): + jwt_key_path = get_file_path('rsa_private.pem') + with open(jwt_key_path, 'r') as f: + jwt_key = to_unicode(f.read()) + + DUMMY_JWT_CONFIG.update({ + 'iss': 'Authlib', + 'key': jwt_key, + 'alg': 'RS256', + }) + + def get_validate_key(self): + with open(get_file_path('rsa_public.pem'), 'r') as f: + return f.read() diff --git a/tests/fastapi/test_oauth2/test_openid_hybrid_grant.py b/tests/fastapi/test_oauth2/test_openid_hybrid_grant.py new file mode 100644 index 00000000..3050504d --- /dev/null +++ b/tests/fastapi/test_oauth2/test_openid_hybrid_grant.py @@ -0,0 +1,330 @@ +from authlib.common.urls import url_decode, urlparse +from authlib.jose import jwt +from authlib.oauth2.rfc6749.grants import \ + AuthorizationCodeGrant as _AuthorizationCodeGrant +from authlib.oidc.core import HybridIDToken +from authlib.oidc.core.grants import OpenIDCode as _OpenIDCode +from authlib.oidc.core.grants import OpenIDHybridGrant as _OpenIDHybridGrant + +from .models import (Client, CodeGrantMixin, User, db, exists_nonce, + save_authorization_code) +from .oauth2_server import TestCase, create_authorization_server + +JWT_CONFIG = {"iss": "Authlib", "key": "secret", "alg": "HS256", "exp": 3600} + + +class AuthorizationCodeGrant(CodeGrantMixin, _AuthorizationCodeGrant): + def save_authorization_code(self, code, request): + return save_authorization_code(code, request) + + +class OpenIDCode(_OpenIDCode): + def get_jwt_config(self, grant): + return dict(JWT_CONFIG) + + def exists_nonce(self, nonce, request): + return exists_nonce(nonce, request) + + def generate_user_info(self, user, scopes): + return user.generate_user_info(scopes) + + +class OpenIDHybridGrant(_OpenIDHybridGrant): + def save_authorization_code(self, code, request): + return save_authorization_code(code, request) + + def get_jwt_config(self): + return dict(JWT_CONFIG) + + def exists_nonce(self, nonce, request): + return exists_nonce(nonce, request) + + def generate_user_info(self, user, scopes): + return user.generate_user_info(scopes) + + +class OpenIDCodeTest(TestCase): + def prepare_data(self): + server = create_authorization_server(self.app) + server.register_grant(OpenIDHybridGrant) + server.register_grant(AuthorizationCodeGrant, [OpenIDCode()]) + + user = User(username="foo") + db.add(user) + db.commit() + + client = Client( + user_id=user.id, + client_id="hybrid-client", + client_secret="hybrid-secret", + ) + client.set_client_metadata( + { + "redirect_uris": ["https://a.b"], + "scope": "openid profile address", + "response_types": [ + "code id_token", + "code token", + "code id_token token", + ], + "grant_types": ["authorization_code"], + } + ) + db.add(client) + db.commit() + + def validate_claims(self, id_token, params): + claims = jwt.decode( + id_token, "secret", claims_cls=HybridIDToken, claims_params=params + ) + claims.validate() + + def test_invalid_client_id(self): + self.prepare_data() + rv = self.client.post( + "/oauth/authorize", + data={ + "response_type": "code token", + "state": "bar", + "nonce": "abc", + "scope": "openid profile", + "redirect_uri": "https://a.b", + "user_id": "1", + }, + ) + resp = rv.json() + self.assertEqual(resp["error"], "invalid_client") + + rv = self.client.post( + "/oauth/authorize", + data={ + "client_id": "invalid-client", + "response_type": "code token", + "state": "bar", + "nonce": "abc", + "scope": "openid profile", + "redirect_uri": "https://a.b", + "user_id": "1", + }, + ) + resp = rv.json() + self.assertEqual(resp["error"], "invalid_client") + + def test_require_nonce(self): + self.prepare_data() + rv = self.client.post( + "/oauth/authorize", + data={ + "client_id": "hybrid-client", + "response_type": "code token", + "scope": "openid profile", + "state": "bar", + "redirect_uri": "https://a.b", + "user_id": "1", + }, + ) + self.assertIn("error=invalid_request", rv.headers["location"]) + self.assertIn("nonce", rv.headers["location"]) + + def test_invalid_response_type(self): + self.prepare_data() + rv = self.client.post( + "/oauth/authorize", + data={ + "client_id": "hybrid-client", + "response_type": "code id_token invalid", + "state": "bar", + "nonce": "abc", + "scope": "profile", + "redirect_uri": "https://a.b", + "user_id": "1", + }, + ) + resp = rv.json() + self.assertEqual(resp["error"], "unsupported_response_type") + + def test_invalid_scope(self): + self.prepare_data() + rv = self.client.post( + "/oauth/authorize", + data={ + "client_id": "hybrid-client", + "response_type": "code id_token", + "state": "bar", + "nonce": "abc", + "scope": "profile", + "redirect_uri": "https://a.b", + "user_id": "1", + }, + ) + self.assertIn("error=invalid_scope", rv.headers["location"]) + + def test_access_denied(self): + self.prepare_data() + rv = self.client.post( + "/oauth/authorize", + data={ + "client_id": "hybrid-client", + "response_type": "code token", + "state": "bar", + "nonce": "abc", + "scope": "openid profile", + "redirect_uri": "https://a.b", + }, + ) + self.assertIn("error=access_denied", rv.headers["location"]) + + def test_code_access_token(self): + self.prepare_data() + rv = self.client.post( + "/oauth/authorize", + data={ + "client_id": "hybrid-client", + "response_type": "code token", + "state": "bar", + "nonce": "abc", + "scope": "openid profile", + "redirect_uri": "https://a.b", + "user_id": "1", + }, + ) + self.assertIn("code=", rv.headers["location"]) + self.assertIn("access_token=", rv.headers["location"]) + self.assertNotIn("id_token=", rv.headers["location"]) + + params = dict(url_decode(urlparse.urlparse(rv.headers["location"]).fragment)) + self.assertEqual(params["state"], "bar") + + code = params["code"] + headers = self.create_basic_header("hybrid-client", "hybrid-secret") + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": "authorization_code", + "redirect_uri": "https://a.b", + "code": code, + }, + headers=headers, + ) + resp = rv.json() + self.assertIn("access_token", resp) + self.assertIn("id_token", resp) + + def test_code_id_token(self): + self.prepare_data() + rv = self.client.post( + "/oauth/authorize", + data={ + "client_id": "hybrid-client", + "response_type": "code id_token", + "state": "bar", + "nonce": "abc", + "scope": "openid profile", + "redirect_uri": "https://a.b", + "user_id": "1", + }, + ) + self.assertIn("code=", rv.headers["location"]) + self.assertIn("id_token=", rv.headers["location"]) + self.assertNotIn("access_token=", rv.headers["location"]) + + params = dict(url_decode(urlparse.urlparse(rv.headers["location"]).fragment)) + self.assertEqual(params["state"], "bar") + + params["nonce"] = "abc" + params["client_id"] = "hybrid-client" + self.validate_claims(params["id_token"], params) + + code = params["code"] + headers = self.create_basic_header("hybrid-client", "hybrid-secret") + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": "authorization_code", + "redirect_uri": "https://a.b", + "code": code, + }, + headers=headers, + ) + resp = rv.json() + self.assertIn("access_token", resp) + self.assertIn("id_token", resp) + + def test_code_id_token_access_token(self): + self.prepare_data() + rv = self.client.post( + "/oauth/authorize", + data={ + "client_id": "hybrid-client", + "response_type": "code id_token token", + "state": "bar", + "nonce": "abc", + "scope": "openid profile", + "redirect_uri": "https://a.b", + "user_id": "1", + }, + ) + self.assertIn("code=", rv.headers["location"]) + self.assertIn("id_token=", rv.headers["location"]) + self.assertIn("access_token=", rv.headers["location"]) + + params = dict(url_decode(urlparse.urlparse(rv.headers["location"]).fragment)) + self.assertEqual(params["state"], "bar") + self.validate_claims(params["id_token"], params) + + code = params["code"] + headers = self.create_basic_header("hybrid-client", "hybrid-secret") + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": "authorization_code", + "redirect_uri": "https://a.b", + "code": code, + }, + headers=headers, + ) + resp = rv.json() + self.assertIn("access_token", resp) + self.assertIn("id_token", resp) + + def test_response_mode_query(self): + self.prepare_data() + rv = self.client.post( + "/oauth/authorize", + data={ + "client_id": "hybrid-client", + "response_type": "code id_token token", + "response_mode": "query", + "state": "bar", + "nonce": "abc", + "scope": "openid profile", + "redirect_uri": "https://a.b", + "user_id": "1", + }, + ) + self.assertIn("code=", rv.headers["location"]) + self.assertIn("id_token=", rv.headers["location"]) + self.assertIn("access_token=", rv.headers["location"]) + + params = dict(url_decode(urlparse.urlparse(rv.headers["location"]).query)) + self.assertEqual(params["state"], "bar") + + def test_response_mode_form_post(self): + self.prepare_data() + rv = self.client.post( + "/oauth/authorize", + data={ + "client_id": "hybrid-client", + "response_type": "code id_token token", + "response_mode": "form_post", + "state": "bar", + "nonce": "abc", + "scope": "openid profile", + "redirect_uri": "https://a.b", + "user_id": "1", + }, + ) + resp = rv.json() + self.assertIn('name="code"', resp) + self.assertIn('name="id_token"', resp) + self.assertIn('name="access_token"', resp) diff --git a/tests/fastapi/test_oauth2/test_openid_implict_grant.py b/tests/fastapi/test_oauth2/test_openid_implict_grant.py new file mode 100644 index 00000000..67b2fbbd --- /dev/null +++ b/tests/fastapi/test_oauth2/test_openid_implict_grant.py @@ -0,0 +1,195 @@ +from authlib.common.urls import add_params_to_uri, url_decode, urlparse +from authlib.jose import jwt +from authlib.oidc.core import ImplicitIDToken +from authlib.oidc.core.grants import \ + OpenIDImplicitGrant as _OpenIDImplicitGrant + +from .models import Client, User, db, exists_nonce +from .oauth2_server import TestCase, create_authorization_server + + +class OpenIDImplicitGrant(_OpenIDImplicitGrant): + def get_jwt_config(self): + return dict(key="secret", alg="HS256", iss="Authlib", exp=3600) + + def generate_user_info(self, user, scopes): + return user.generate_user_info(scopes) + + def exists_nonce(self, nonce, request): + return exists_nonce(nonce, request) + + +class ImplicitTest(TestCase): + def prepare_data(self): + server = create_authorization_server(self.app) + server.register_grant(OpenIDImplicitGrant) + + user = User(username="foo") + db.add(user) + db.commit() + client = Client( + user_id=user.id, + client_id="implicit-client", + client_secret="", + ) + client.set_client_metadata( + { + "redirect_uris": ["https://a.b/c"], + "scope": "openid profile", + "token_endpoint_auth_method": "none", + "response_types": ["id_token", "id_token token"], + } + ) + self.authorize_url = ( + "/oauth/authorize?response_type=token" "&client_id=implicit-client" + ) + db.add(client) + db.commit() + + def validate_claims(self, id_token, params): + claims = jwt.decode( + id_token, "secret", claims_cls=ImplicitIDToken, claims_params=params + ) + claims.validate() + + def test_consent_view(self): + self.prepare_data() + rv = self.client.get( + add_params_to_uri( + "/oauth/authorize", + { + "response_type": "id_token", + "client_id": "implicit-client", + "scope": "openid profile", + "state": "foo", + "redirect_uri": "https://a.b/c", + "user_id": "1", + }, + ) + ) + self.assertIn("error=invalid_request", rv.json()) + self.assertIn("nonce", rv.json()) + + def test_require_nonce(self): + self.prepare_data() + rv = self.client.post( + "/oauth/authorize", + data={ + "response_type": "id_token", + "client_id": "implicit-client", + "scope": "openid profile", + "state": "bar", + "redirect_uri": "https://a.b/c", + "user_id": "1", + }, + ) + self.assertIn("error=invalid_request", rv.headers["location"]) + self.assertIn("nonce", rv.headers["location"]) + + def test_missing_openid_in_scope(self): + self.prepare_data() + rv = self.client.post( + "/oauth/authorize", + data={ + "response_type": "id_token token", + "client_id": "implicit-client", + "scope": "profile", + "state": "bar", + "nonce": "abc", + "redirect_uri": "https://a.b/c", + "user_id": "1", + }, + ) + self.assertIn("error=invalid_scope", rv.headers["location"]) + + def test_denied(self): + self.prepare_data() + rv = self.client.post( + "/oauth/authorize", + data={ + "response_type": "id_token", + "client_id": "implicit-client", + "scope": "openid profile", + "state": "bar", + "nonce": "abc", + "redirect_uri": "https://a.b/c", + }, + ) + self.assertIn("error=access_denied", rv.headers["location"]) + + def test_authorize_access_token(self): + self.prepare_data() + rv = self.client.post( + "/oauth/authorize", + data={ + "response_type": "id_token token", + "client_id": "implicit-client", + "scope": "openid profile", + "state": "bar", + "nonce": "abc", + "redirect_uri": "https://a.b/c", + "user_id": "1", + }, + ) + self.assertIn("access_token=", rv.headers["location"]) + self.assertIn("id_token=", rv.headers["location"]) + self.assertIn("state=bar", rv.headers["location"]) + params = dict(url_decode(urlparse.urlparse(rv.headers["location"]).fragment)) + self.validate_claims(params["id_token"], params) + + def test_authorize_id_token(self): + self.prepare_data() + rv = self.client.post( + "/oauth/authorize", + data={ + "response_type": "id_token", + "client_id": "implicit-client", + "scope": "openid profile", + "state": "bar", + "nonce": "abc", + "redirect_uri": "https://a.b/c", + "user_id": "1", + }, + ) + self.assertIn("id_token=", rv.headers["location"]) + self.assertIn("state=bar", rv.headers["location"]) + params = dict(url_decode(urlparse.urlparse(rv.headers["location"]).fragment)) + self.validate_claims(params["id_token"], params) + + def test_response_mode_query(self): + self.prepare_data() + rv = self.client.post( + "/oauth/authorize", + data={ + "response_type": "id_token", + "response_mode": "query", + "client_id": "implicit-client", + "scope": "openid profile", + "state": "bar", + "nonce": "abc", + "redirect_uri": "https://a.b/c", + "user_id": "1", + }, + ) + self.assertIn("id_token=", rv.headers["location"]) + self.assertIn("state=bar", rv.headers["location"]) + params = dict(url_decode(urlparse.urlparse(rv.headers["location"]).query)) + self.validate_claims(params["id_token"], params) + + def test_response_mode_form_post(self): + self.prepare_data() + rv = self.client.post( + "/oauth/authorize", + data={ + "response_type": "id_token", + "response_mode": "form_post", + "client_id": "implicit-client", + "scope": "openid profile", + "state": "bar", + "nonce": "abc", + "redirect_uri": "https://a.b/c", + "user_id": "1", + }, + ) + self.assertIn('name="id_token"', rv.json()) + self.assertIn('name="state"', rv.json()) diff --git a/tests/fastapi/test_oauth2/test_password_grant.py b/tests/fastapi/test_oauth2/test_password_grant.py new file mode 100644 index 00000000..92c467a5 --- /dev/null +++ b/tests/fastapi/test_oauth2/test_password_grant.py @@ -0,0 +1,195 @@ +from authlib.common.urls import add_params_to_uri +from authlib.oauth2.rfc6749.grants import \ + ResourceOwnerPasswordCredentialsGrant as _PasswordGrant + +from .models import Client, User, db +from .oauth2_server import TestCase, create_authorization_server + + +class PasswordGrant(_PasswordGrant): + def authenticate_user(self, username, password): + user = db.query(User).filter(User.username == username).first() + if user.check_password(password): + return user + + +class PasswordTest(TestCase): + def prepare_data(self, grant_type="password"): + server = create_authorization_server(self.app) + server.register_grant(PasswordGrant) + self.server = server + + user = User(username="foo") + db.add(user) + db.commit() + client = Client( + user_id=user.id, + client_id="password-client", + client_secret="password-secret", + ) + client.set_client_metadata( + { + "scope": "profile", + "grant_types": [grant_type], + "redirect_uris": ["http://localhost/authorized"], + } + ) + db.add(client) + db.commit() + + def test_invalid_client(self): + self.prepare_data() + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": "password", + "username": "foo", + "password": "ok", + }, + ) + resp = rv.json() + self.assertEqual(resp["error"], "invalid_client") + + headers = self.create_basic_header("password-client", "invalid-secret") + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": "password", + "username": "foo", + "password": "ok", + }, + headers=headers, + ) + resp = rv.json() + self.assertEqual(resp["error"], "invalid_client") + + def test_invalid_scope(self): + self.prepare_data() + self.server.scopes_supported = "profile" + headers = self.create_basic_header("password-client", "password-secret") + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": "password", + "username": "foo", + "password": "ok", + "scope": "invalid", + }, + headers=headers, + ) + resp = rv.json() + self.assertEqual(resp["error"], "invalid_scope") + + def test_invalid_request(self): + self.prepare_data() + headers = self.create_basic_header("password-client", "password-secret") + + rv = self.client.get( + add_params_to_uri( + "/oauth/token", + { + "grant_type": "password", + }, + ), + headers=headers, + ) + resp = rv.json() + self.assertEqual(resp["error"], "unsupported_grant_type") + + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": "password", + }, + headers=headers, + ) + resp = rv.json() + self.assertEqual(resp["error"], "invalid_request") + + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": "password", + "username": "foo", + }, + headers=headers, + ) + resp = rv.json() + self.assertEqual(resp["error"], "invalid_request") + + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": "password", + "username": "foo", + "password": "wrong", + }, + headers=headers, + ) + resp = rv.json() + self.assertEqual(resp["error"], "invalid_request") + + def test_invalid_grant_type(self): + self.prepare_data(grant_type="invalid") + headers = self.create_basic_header("password-client", "password-secret") + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": "password", + "username": "foo", + "password": "ok", + }, + headers=headers, + ) + resp = rv.json() + self.assertEqual(resp["error"], "unauthorized_client") + + def test_authorize_token(self): + self.prepare_data() + headers = self.create_basic_header("password-client", "password-secret") + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": "password", + "username": "foo", + "password": "ok", + }, + headers=headers, + ) + resp = rv.json() + self.assertIn("access_token", resp) + + def test_token_generator(self): + m = "tests.fastapi.test_oauth2.oauth2_server:token_generator" + self.app.config.update({"OAUTH2_ACCESS_TOKEN_GENERATOR": m}) + self.prepare_data() + headers = self.create_basic_header("password-client", "password-secret") + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": "password", + "username": "foo", + "password": "ok", + }, + headers=headers, + ) + resp = rv.json() + self.assertIn("access_token", resp) + self.assertIn("p-password.1.", resp["access_token"]) + + def test_custom_expires_in(self): + self.app.config.update({"OAUTH2_TOKEN_EXPIRES_IN": {"password": 1800}}) + self.prepare_data() + headers = self.create_basic_header("password-client", "password-secret") + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": "password", + "username": "foo", + "password": "ok", + }, + headers=headers, + ) + resp = rv.json() + self.assertIn("access_token", resp) + self.assertEqual(resp["expires_in"], 1800) diff --git a/tests/fastapi/test_oauth2/test_refresh_token.py b/tests/fastapi/test_oauth2/test_refresh_token.py new file mode 100644 index 00000000..a6564aa9 --- /dev/null +++ b/tests/fastapi/test_oauth2/test_refresh_token.py @@ -0,0 +1,227 @@ +from authlib.oauth2.rfc6749.grants import ( + RefreshTokenGrant as _RefreshTokenGrant, +) +from .models import db, User, Client, Token +from .oauth2_server import TestCase +from .oauth2_server import create_authorization_server + + +class RefreshTokenGrant(_RefreshTokenGrant): + def authenticate_refresh_token(self, refresh_token): + item = db.query(Token).filter(Token.refresh_token == refresh_token).first() + if item and not item.revoked and not item.is_refresh_token_expired(): + return item + + def authenticate_user(self, credential): + return db.query(User).filter(User.id == int(credential.user_id)).first() + + def revoke_old_credential(self, credential): + credential.revoked = True + db.add(credential) + db.commit() + + +class RefreshTokenTest(TestCase): + def prepare_data(self, grant_type='refresh_token'): + server = create_authorization_server(self.app) + server.register_grant(RefreshTokenGrant) + + user = User(username='foo') + db.add(user) + db.commit() + client = Client( + user_id=user.id, + client_id='refresh-client', + client_secret='refresh-secret', + ) + client.set_client_metadata({ + 'scope': 'profile', + 'grant_types': [grant_type], + 'redirect_uris': ['http://localhost/authorized'], + }) + db.add(client) + db.commit() + + def create_token(self, scope='profile', user_id=1): + token = Token( + user_id=user_id, + client_id='refresh-client', + token_type='bearer', + access_token='a1', + refresh_token='r1', + scope=scope, + expires_in=3600, + ) + db.add(token) + db.commit() + + def test_invalid_client(self): + self.prepare_data() + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'refresh_token', + 'refresh_token': 'foo', + }) + resp = rv.json() + self.assertEqual(resp['error'], 'invalid_client') + + headers = self.create_basic_header( + 'invalid-client', 'refresh-secret' + ) + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'refresh_token', + 'refresh_token': 'foo', + }, headers=headers) + resp = rv.json() + self.assertEqual(resp['error'], 'invalid_client') + + headers = self.create_basic_header( + 'refresh-client', 'invalid-secret' + ) + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'refresh_token', + 'refresh_token': 'foo', + }, headers=headers) + resp = rv.json() + self.assertEqual(resp['error'], 'invalid_client') + + def test_invalid_refresh_token(self): + self.prepare_data() + headers = self.create_basic_header( + 'refresh-client', 'refresh-secret' + ) + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'refresh_token', + }, headers=headers) + resp = rv.json() + self.assertEqual(resp['error'], 'invalid_request') + self.assertIn('Missing', resp['error_description']) + + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'refresh_token', + 'refresh_token': 'foo', + }, headers=headers) + resp = rv.json() + self.assertEqual(resp['error'], 'invalid_grant') + + def test_invalid_scope(self): + self.prepare_data() + self.create_token() + headers = self.create_basic_header( + 'refresh-client', 'refresh-secret' + ) + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'refresh_token', + 'refresh_token': 'r1', + 'scope': 'invalid', + }, headers=headers) + resp = rv.json() + self.assertEqual(resp['error'], 'invalid_scope') + + def test_invalid_scope_none(self): + self.prepare_data() + self.create_token(scope=None) + headers = self.create_basic_header( + 'refresh-client', 'refresh-secret' + ) + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'refresh_token', + 'refresh_token': 'r1', + 'scope': 'invalid', + }, headers=headers) + resp = rv.json() + self.assertEqual(resp['error'], 'invalid_scope') + + def test_invalid_user(self): + self.prepare_data() + self.create_token(user_id=5) + headers = self.create_basic_header( + 'refresh-client', 'refresh-secret' + ) + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'refresh_token', + 'refresh_token': 'r1', + 'scope': 'profile', + }, headers=headers) + resp = rv.json() + self.assertEqual(resp['error'], 'invalid_request') + + def test_invalid_grant_type(self): + self.prepare_data(grant_type='invalid') + self.create_token() + headers = self.create_basic_header( + 'refresh-client', 'refresh-secret' + ) + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'refresh_token', + 'refresh_token': 'r1', + 'scope': 'profile', + }, headers=headers) + resp = rv.json() + self.assertEqual(resp['error'], 'unauthorized_client') + + def test_authorize_token_no_scope(self): + self.prepare_data() + self.create_token() + headers = self.create_basic_header( + 'refresh-client', 'refresh-secret' + ) + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'refresh_token', + 'refresh_token': 'r1', + }, headers=headers) + resp = rv.json() + self.assertIn('access_token', resp) + + def test_authorize_token_scope(self): + self.prepare_data() + self.create_token() + headers = self.create_basic_header( + 'refresh-client', 'refresh-secret' + ) + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'refresh_token', + 'refresh_token': 'r1', + 'scope': 'profile', + }, headers=headers) + resp = rv.json() + self.assertIn('access_token', resp) + + def test_revoke_old_credential(self): + self.prepare_data() + self.create_token() + headers = self.create_basic_header( + 'refresh-client', 'refresh-secret' + ) + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'refresh_token', + 'refresh_token': 'r1', + 'scope': 'profile', + }, headers=headers) + resp = rv.json() + self.assertIn('access_token', resp) + + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'refresh_token', + 'refresh_token': 'r1', + 'scope': 'profile', + }, headers=headers) + self.assertEqual(rv.status_code, 400) + resp = rv.json() + self.assertEqual(resp['error'], 'invalid_grant') + + def test_token_generator(self): + m = 'tests.fastapi.test_oauth2.oauth2_server:token_generator' + self.app.config.update({'OAUTH2_ACCESS_TOKEN_GENERATOR': m}) + + self.prepare_data() + self.create_token() + headers = self.create_basic_header( + 'refresh-client', 'refresh-secret' + ) + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'refresh_token', + 'refresh_token': 'r1', + }, headers=headers) + resp = rv.json() + self.assertIn('access_token', resp) + self.assertIn('r-refresh_token.1.', resp['access_token']) diff --git a/tests/fastapi/test_oauth2/test_revocation_endpoint.py b/tests/fastapi/test_oauth2/test_revocation_endpoint.py new file mode 100644 index 00000000..2c2bc14a --- /dev/null +++ b/tests/fastapi/test_oauth2/test_revocation_endpoint.py @@ -0,0 +1,129 @@ +from fastapi import Request, Form +from authlib.integrations.sqla_oauth2 import create_revocation_endpoint +from .models import db, User, Client, Token +from .oauth2_server import TestCase +from .oauth2_server import create_authorization_server + + +RevocationEndpoint = create_revocation_endpoint(db, Token) + + +class RevokeTokenTest(TestCase): + def prepare_data(self): + app = self.app + server = create_authorization_server(app) + server.register_endpoint(RevocationEndpoint) + + @app.post('/oauth/revoke') + def revoke_token(request: Request, + token: str = Form(None), + token_type_hint: str = Form(None)): + request.body = {} + if token: + request.body.update({'token': token}) + if token_type_hint: + request.body.update({'token_type_hint': token_type_hint}) + return server.create_endpoint_response('revocation', request=request) + + user = User(username='foo') + db.add(user) + db.commit() + client = Client( + user_id=user.id, + client_id='revoke-client', + client_secret='revoke-secret', + ) + client.set_client_metadata({ + 'scope': 'profile', + 'redirect_uris': ['http://localhost/authorized'], + }) + db.add(client) + db.commit() + + def create_token(self): + token = Token( + user_id=1, + client_id='revoke-client', + token_type='bearer', + access_token='a1', + refresh_token='r1', + scope='profile', + expires_in=3600, + ) + db.add(token) + db.commit() + + def test_invalid_client(self): + self.prepare_data() + rv = self.client.post('/oauth/revoke') + resp = rv.json() + self.assertEqual(resp['error'], 'invalid_client') + + headers = {'Authorization': 'invalid token_string'} + rv = self.client.post('/oauth/revoke', headers=headers) + resp = rv.json() + self.assertEqual(resp['error'], 'invalid_client') + + headers = self.create_basic_header( + 'invalid-client', 'revoke-secret' + ) + rv = self.client.post('/oauth/revoke', headers=headers) + resp = rv.json() + self.assertEqual(resp['error'], 'invalid_client') + + headers = self.create_basic_header( + 'revoke-client', 'invalid-secret' + ) + rv = self.client.post('/oauth/revoke', headers=headers) + resp = rv.json() + self.assertEqual(resp['error'], 'invalid_client') + + def test_invalid_token(self): + self.prepare_data() + headers = self.create_basic_header( + 'revoke-client', 'revoke-secret' + ) + rv = self.client.post('/oauth/revoke', headers=headers) + resp = rv.json() + self.assertEqual(resp['error'], 'invalid_request') + + rv = self.client.post('/oauth/revoke', data={ + 'token': 'invalid-token', + }, headers=headers) + self.assertEqual(rv.status_code, 200) + + rv = self.client.post('/oauth/revoke', data={ + 'token': 'a1', + 'token_type_hint': 'unsupported_token_type', + }, headers=headers) + resp = rv.json() + self.assertEqual(resp['error'], 'unsupported_token_type') + + rv = self.client.post('/oauth/revoke', data={ + 'token': 'a1', + 'token_type_hint': 'refresh_token', + }, headers=headers) + self.assertEqual(rv.status_code, 200) + + def test_revoke_token_with_hint(self): + self.prepare_data() + self.create_token() + headers = self.create_basic_header( + 'revoke-client', 'revoke-secret' + ) + rv = self.client.post('/oauth/revoke', data={ + 'token': 'a1', + 'token_type_hint': 'access_token', + }, headers=headers) + self.assertEqual(rv.status_code, 200) + + def test_revoke_token_without_hint(self): + self.prepare_data() + self.create_token() + headers = self.create_basic_header( + 'revoke-client', 'revoke-secret' + ) + rv = self.client.post('/oauth/revoke', data={ + 'token': 'a1', + }, headers=headers) + self.assertEqual(rv.status_code, 200) diff --git a/tox.ini b/tox.ini index 94075413..2ae88c74 100644 --- a/tox.ini +++ b/tox.ini @@ -1,12 +1,16 @@ [tox] envlist = py{36,37,38,39} - py{36,37,38,39}-{flask,django,starlette} + py{36,37,38,39}-{flask,django,starlette,fastapi} coverage [testenv] deps = -rrequirements-test.txt + fastapi: FastAPI + fastapi: sqlalchemy + fastapi: werkzeug + fastapi: python-multipart flask: Flask flask: Flask-SQLAlchemy flask: itsdangerous @@ -21,6 +25,7 @@ deps = setenv = TESTPATH=tests/core starlette: TESTPATH=tests/starlette + fastapi: TESTPATH=tests/fastapi flask: TESTPATH=tests/flask django: TESTPATH=tests/django commands =