diff --git a/auth/auth.py b/auth/auth.py index 71c7928..2b80720 100644 --- a/auth/auth.py +++ b/auth/auth.py @@ -1,10 +1,15 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Dict, Union +import aiohttp_jinja2 +from aiohttp import web +from aiohttp.web_exceptions import HTTPFound +from aiohttp_security import remember from aiopg.sa.engine import Engine from aiopg.sa.result import RowProxy from aiohttp.web import Application from aiohttp_security.abc import AbstractAuthorizationPolicy +from dynaconf import settings from passlib.hash import sha256_crypt from sqlalchemy import and_, func, not_ @@ -14,6 +19,31 @@ from typing import Optional +async def get_login_context(error: str | None = None) -> Dict[str, Union[str | bool]]: + use_oauth = getattr(getattr(settings, 'OAUTH', None), 'IS_USED', False) + only_oauth = getattr(getattr(settings, 'OAUTH', None), 'ONLY_OAUTH', False) + oauth_sign_in_title = getattr(getattr(settings, 'OAUTH', None), 'SIGN_IN_TITLE', '') + context = { + 'context': '', + 'use_oauth': use_oauth, + 'only_oauth': only_oauth, + 'oauth_sign_in_title': oauth_sign_in_title, + } + if error: + context['error'] = error + return context + + +async def oauth_on_login(request: web.Request, user_data: dict) -> web.Response: + await remember(request, HTTPFound('/zbs/switches'), 'admin') + return HTTPFound('/zbs/switches') + + +@aiohttp_jinja2.template('users/login.html') +async def oauth_on_error(request: web.Request) -> Dict[str, Union[str | bool]]: + return await get_login_context(error='OAUTH failed') + + async def check_credentials(db_engine: Engine, username: str, password: str) -> bool: """Производит аутентификацию пользователя.""" async with db_engine.acquire() as conn: diff --git a/auth/views.py b/auth/views.py index 0c55027..497c403 100644 --- a/auth/views.py +++ b/auth/views.py @@ -1,14 +1,15 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Union import aiohttp_jinja2 from aiohttp.abc import StreamResponse from aiohttp.web import HTTPFound, View, Response from aiohttp_security import forget, remember +from dynaconf import settings from marshmallow.exceptions import ValidationError from multidict import MultiDictProxy -from auth.auth import check_credentials +from auth.auth import check_credentials, get_login_context from auth.schemes import LoginPostRequestSchema if TYPE_CHECKING: @@ -17,12 +18,16 @@ class LoginView(View): @aiohttp_jinja2.template('users/login.html') - async def get(self, error: Optional[str] = None) -> Dict[str, str]: - return {'context': ''} + async def get(self, error: Optional[str] = None) -> Dict[str, Union[str | bool]]: + return await get_login_context() @aiohttp_jinja2.template('users/login.html') - async def error(self) -> Dict[str, str]: - return {'context': '', 'error': 'Authorization failed'} + async def error(self) -> Dict[str, Union[str | bool]]: + return await get_login_context('Authorization failed') + + @aiohttp_jinja2.template('users/login.html') + async def only_oauth_error(self) -> Dict[str, Union[str | bool]]: + return await get_login_context('Classic login is forbidden') async def authorise( self, response_location: Response, login: str, password: str, @@ -33,6 +38,9 @@ async def authorise( return await self.error() async def post(self) -> StreamResponse: + only_oauth = getattr(getattr(settings, 'OAUTH', None), 'ONLY_OAUTH', False) + if only_oauth: + return await self.only_oauth_error() response_location = HTTPFound('/zbs/switches') form_data = await self.request.post() validated_data = self.validate_form_data(form_data) diff --git a/its_on/main.py b/its_on/main.py index 037b3e1..de9cdf1 100644 --- a/its_on/main.py +++ b/its_on/main.py @@ -4,6 +4,7 @@ import pathlib from aiohttp import web +from aiohttp_oauth2 import oauth2_app from aiohttp_security import setup as setup_security from aiohttp_security import SessionIdentityPolicy import aiohttp_cors @@ -16,7 +17,7 @@ from dynaconf import settings import uvloop -from auth.auth import DBAuthorizationPolicy +from auth.auth import DBAuthorizationPolicy, oauth_on_login, oauth_on_error from its_on.cache import setup_cache from its_on.db_utils import init_pg, close_pg from its_on.middlewares import setup_middlewares @@ -42,6 +43,20 @@ def init_app( ) -> web.Application: app = web.Application(loop=loop) + if settings.OAUTH.IS_USED: + app.add_subapp( + '/oauth/', + oauth2_app( + client_id=settings.OAUTH.CLIENT_ID, + client_secret=settings.OAUTH.CLIENT_SECRET, + authorize_url=settings.OAUTH.AUTHORIZE_URL, + token_url=settings.OAUTH.TOKEN_URL, + on_login=oauth_on_login, + on_error=oauth_on_error, + json_data=False, + ), + ) + app['config'] = settings if not redis_pool: diff --git a/its_on/templates/users/login.html b/its_on/templates/users/login.html index 22f55cc..8df268d 100644 --- a/its_on/templates/users/login.html +++ b/its_on/templates/users/login.html @@ -12,6 +12,7 @@