From 01ec996c952f918ee2dca16e62931d9a662ee7a4 Mon Sep 17 00:00:00 2001 From: keakon Date: Mon, 13 May 2024 23:41:01 +0800 Subject: [PATCH] reduce type ignore --- app/controllers/hello.py | 5 ++-- app/controllers/user.py | 3 ++- app/models/__init__.py | 17 +++++++------ app/utils/exception.py | 5 ++++ tests/controllers/test_hello.py | 4 +-- tests/controllers/test_user.py | 34 ++++++++++++++++--------- tests/models/test_base_model.py | 44 ++++++++++++++++++++++----------- tests/models/test_user.py | 6 ++--- 8 files changed, 76 insertions(+), 42 deletions(-) diff --git a/app/controllers/hello.py b/app/controllers/hello.py index 70c9d61..3afcd0c 100644 --- a/app/controllers/hello.py +++ b/app/controllers/hello.py @@ -1,7 +1,8 @@ from fastapi import Depends +from sqlmodel import col from app.clients.mysql import async_session -from app.models.user import get_current_user_id, User +from app.models.user import User, get_current_user_id from app.router import router from app.schemas.resp import Resp @@ -14,5 +15,5 @@ def hello(user_name: str): @router.get('/hello', response_model=Resp, response_model_exclude_none=True) async def hello_to_self(current_user_id: int = Depends(get_current_user_id)): async with async_session() as session: - user_name = await User.get_by_id(session, current_user_id, User.name) # type: ignore + user_name = await User.get_by_id(session, current_user_id, col(User.name)) return Resp(msg=f'Hello, {user_name}!') diff --git a/app/controllers/user.py b/app/controllers/user.py index 3311e70..7b87959 100644 --- a/app/controllers/user.py +++ b/app/controllers/user.py @@ -3,6 +3,7 @@ from fastapi import Body, Depends from fastapi.security import OAuth2PasswordRequestForm from sqlalchemy import Row +from sqlmodel import col from app.clients.mysql import async_session from app.models import all_is_instance @@ -57,7 +58,7 @@ async def update_user(user_id: int, req: UserRequest, current_user_id: int = Dep @router.get('/user/{user_id}/name', response_model=Resp, response_model_exclude_none=True) async def get_user_name(user_id: int, _=Depends(get_current_user_id)): async with async_session() as session: - user_name = await User.get_by_id(session, user_id, User.name) # type: ignore + user_name = await User.get_by_id(session, user_id, col(User.name)) if user_name: return Resp(data={'name': user_name}) raise not_found_error diff --git a/app/models/__init__.py b/app/models/__init__.py index 6aafd25..b8f36c2 100644 --- a/app/models/__init__.py +++ b/app/models/__init__.py @@ -2,10 +2,11 @@ from sqlalchemy import Column, Row from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import Mapped from sqlalchemy.orm.attributes import InstrumentedAttribute from sqlalchemy.sql import func, text from sqlalchemy.sql.elements import TextClause -from sqlmodel import Field, SQLModel, delete, insert, select, update +from sqlmodel import Field, SQLModel, col, delete, insert, select, update Values = dict[str, Any] @@ -26,7 +27,7 @@ async def get_by_id( cls, session: AsyncSession, id: int, - columns: list | tuple | InstrumentedAttribute | TextClause | Column | None = None, + columns: list | tuple | InstrumentedAttribute | TextClause | Column | Mapped | None = None, for_update: bool = False, for_read: bool = False, ) -> Any: @@ -55,7 +56,7 @@ async def get_by_ids( cls, session: AsyncSession, ids: Sequence[int], - columns: list | tuple | InstrumentedAttribute | TextClause | Column | None = None, + columns: list | tuple | InstrumentedAttribute | TextClause | Column | Mapped | None = None, for_update: bool = False, for_read: bool = False, ) -> 'Sequence[BaseModel | Row]': @@ -71,7 +72,7 @@ async def get_by_ids( else: query = select(columns) scalar = True - query = query.where(cls.id.in_(ids)) # type: ignore + query = query.where(col(cls.id).in_(ids)) if for_update: query = query.with_for_update() elif for_read: @@ -94,7 +95,7 @@ async def exist(cls, session: AsyncSession, id: int, for_update: bool = False, f async def get_all( cls, session: AsyncSession, - columns: list | tuple | InstrumentedAttribute | None = None, + columns: list | tuple | InstrumentedAttribute | Mapped | None = None, for_update: bool = False, for_read: bool = False, ) -> 'Sequence[BaseModel | Row]': @@ -123,15 +124,15 @@ async def count_all(cls, session: AsyncSession) -> int: @classmethod async def update_by_id(cls, session: AsyncSession, id: int, values: Values) -> int: - return (await session.execute(update(cls).where(cls.id == id).values(**values))).rowcount # type: ignore + return (await session.execute(update(cls).where(col(cls.id) == id).values(**values))).rowcount @classmethod async def delete_by_id(cls, session: AsyncSession, id: int) -> int: - return (await session.execute(delete(cls).where(cls.id == id))).rowcount # type: ignore + return (await session.execute(delete(cls).where(col(cls.id) == id))).rowcount @classmethod async def delete_by_ids(cls, session: AsyncSession, ids: Sequence[int]) -> int: - return (await session.execute(delete(cls).where(cls.id.in_(ids)))).rowcount # type: ignore + return (await session.execute(delete(cls).where(col(cls.id).in_(ids)))).rowcount @classmethod async def insert(cls, session: AsyncSession, values: Values) -> int: diff --git a/app/utils/exception.py b/app/utils/exception.py index 163e863..1e59821 100644 --- a/app/utils/exception.py +++ b/app/utils/exception.py @@ -40,6 +40,11 @@ def unauthorized_error(msg) -> HTTPError: invalid_token_error = unauthorized_error('Invalid token') not_authenticated_error = unauthorized_error('Not authenticated') + +def bad_request_error(msg) -> HTTPError: + return HTTPError(status_code=400, code=ErrorCode.BAD_REQUEST, msg=msg) + + forbidden_error = HTTPError( status_code=403, code=ErrorCode.FORBIDDEN, diff --git a/tests/controllers/test_hello.py b/tests/controllers/test_hello.py index e09ed39..b142e31 100644 --- a/tests/controllers/test_hello.py +++ b/tests/controllers/test_hello.py @@ -1,5 +1,5 @@ import pytest -from sqlmodel import delete +from sqlmodel import col, delete from app.clients.mysql import async_session from app.models.user import User @@ -17,7 +17,7 @@ def test_hello(): @pytest.mark.asyncio(scope='session') async def test_hello_to_self(): async with async_session() as session: - if (await session.execute(delete(User).where(User.id > 1))).rowcount > 0: # type: ignore + if (await session.execute(delete(User).where(col(User.id) > 1))).rowcount > 0: await session.commit() async with async_client() as client: diff --git a/tests/controllers/test_user.py b/tests/controllers/test_user.py index 7988bea..6ca4ff6 100644 --- a/tests/controllers/test_user.py +++ b/tests/controllers/test_user.py @@ -1,11 +1,11 @@ import pytest from sqlalchemy.exc import IntegrityError -from sqlmodel import delete +from sqlmodel import col, delete +from app.clients.mysql import async_session from app.models.user import User from app.schemas.token import TokenPayload from app.schemas.user import UserRequest -from app.clients.mysql import async_session from app.utils.token import decode_token from . import async_client @@ -63,7 +63,7 @@ async def test_get_user(): @pytest.mark.asyncio(scope='session') async def test_update_user(): async with async_session() as session: - if (await session.execute(delete(User).where(User.id > 1))).rowcount > 0: # type: ignore + if (await session.execute(delete(User).where(col(User.id) > 1))).rowcount > 0: await session.commit() async with async_client() as client: @@ -81,14 +81,18 @@ async def test_update_user(): user_id = payload.user_id data = UserRequest(name='test2', password='test2').model_dump() - response = await client.put(f'/api/v1/user/{user_id}', headers={'Authorization': f'Bearer {access_token}'}, json=data) + response = await client.put( + f'/api/v1/user/{user_id}', headers={'Authorization': f'Bearer {access_token}'}, json=data + ) assert response.status_code == 403 response = await client.post('/api/v1/login', data={'username': 'admin', 'password': 'admin'}) assert response.status_code == 200 access_token = response.json()['access_token'] - response = await client.put(f'/api/v1/user/{user_id}', headers={'Authorization': f'Bearer {access_token}'}, json=data) + response = await client.put( + f'/api/v1/user/{user_id}', headers={'Authorization': f'Bearer {access_token}'}, json=data + ) assert response.status_code == 200 response = await client.post('/api/v1/login', data={'username': 'test', 'password': 'test'}) @@ -104,7 +108,7 @@ async def test_update_user(): @pytest.mark.asyncio(scope='session') async def test_get_user_name(): async with async_session() as session: - if (await session.execute(delete(User).where(User.id > 1))).rowcount > 0: # type: ignore + if (await session.execute(delete(User).where(col(User.id) > 1))).rowcount > 0: await session.commit() async with async_client() as client: @@ -136,7 +140,7 @@ async def test_get_user_name(): @pytest.mark.asyncio(scope='session') async def test_set_user_name(): async with async_session() as session: - if (await session.execute(delete(User).where(User.id > 1))).rowcount > 0: # type: ignore + if (await session.execute(delete(User).where(col(User.id) > 1))).rowcount > 0: await session.commit() async with async_client() as client: @@ -153,7 +157,9 @@ async def test_set_user_name(): payload = TokenPayload.model_validate_json(token.payload) user_id = payload.user_id - response = await client.patch(f'/api/v1/user/{user_id}/name', headers={'Authorization': f'Bearer {access_token}'}, json={'name': 'test2'}) + response = await client.patch( + f'/api/v1/user/{user_id}/name', headers={'Authorization': f'Bearer {access_token}'}, json={'name': 'test2'} + ) assert response.status_code == 200 response = await client.post('/api/v1/login', data={'username': 'test', 'password': 'test'}) @@ -162,14 +168,16 @@ async def test_set_user_name(): response = await client.post('/api/v1/login', data={'username': 'test2', 'password': 'test'}) assert response.status_code == 200 - response = await client.patch('/api/v1/user/0/name', headers={'Authorization': f'Bearer {access_token}'}, json={'name': 'test2'}) + response = await client.patch( + '/api/v1/user/0/name', headers={'Authorization': f'Bearer {access_token}'}, json={'name': 'test2'} + ) assert response.status_code == 404 @pytest.mark.asyncio(scope='session') async def test_get_user_time(): async with async_session() as session: - if (await session.execute(delete(User).where(User.id > 1))).rowcount > 0: # type: ignore + if (await session.execute(delete(User).where(col(User.id) > 1))).rowcount > 0: await session.commit() async with async_client() as client: @@ -193,7 +201,9 @@ async def test_get_user_time(): updated_at = data['updated_at'] assert created_at == updated_at - response = await client.patch(f'/api/v1/user/{user_id}/name', headers={'Authorization': f'Bearer {access_token}'}, json={'name': 'test2'}) + response = await client.patch( + f'/api/v1/user/{user_id}/name', headers={'Authorization': f'Bearer {access_token}'}, json={'name': 'test2'} + ) assert response.status_code == 200 response = await client.get(f'/api/v1/user/{user_id}/time', headers={'Authorization': f'Bearer {access_token}'}) @@ -211,7 +221,7 @@ async def test_get_user_time(): @pytest.mark.asyncio(scope='session') async def test_get_user_list(): async with async_session() as session: - if (await session.execute(delete(User).where(User.id > 1))).rowcount > 0: # type: ignore + if (await session.execute(delete(User).where(col(User.id) > 1))).rowcount > 0: await session.commit() async with async_client() as client: diff --git a/tests/models/test_base_model.py b/tests/models/test_base_model.py index 0dce83c..c4c4426 100644 --- a/tests/models/test_base_model.py +++ b/tests/models/test_base_model.py @@ -1,6 +1,7 @@ import pytest from sqlalchemy import Column, Row from sqlalchemy.sql import text +from sqlmodel import col from app.clients.mysql import async_session from app.models import BaseModel, all_is_instance @@ -14,7 +15,7 @@ class Model(BaseModel, table=True): class TestBaseModel: async def test_get_by_id(self): async with async_session() as session: - await session.execute(text('TRUNCATE TABLE model')) + await session.execute(text(f'TRUNCATE TABLE {Model.__tablename__}')) model = await Model.get_by_id(session, 1) assert model is None @@ -55,6 +56,9 @@ async def test_get_by_id(self): name = await Model.get_by_id(session, 1, Column('name')) assert name == 'test' + name = await Model.get_by_id(session, 1, col(Model.name)) + assert name == 'test' + row = await Model.get_by_id(session, 1, (Model.name,)) assert isinstance(row, Row) assert row.name == 'test' @@ -74,7 +78,7 @@ async def test_get_by_id(self): async def test_get_by_ids(self): async with async_session() as session: - await session.execute(text('TRUNCATE TABLE model')) + await session.execute(text(f'TRUNCATE TABLE {Model.__tablename__}')) models = await Model.get_by_ids(session, (1, 2, 3)) assert len(models) == 0 @@ -152,6 +156,12 @@ async def test_get_by_ids(self): assert row0 == 'test' assert row1 == 'test2' + rows = await Model.get_by_ids(session, (1, 2, 3), col(Model.name)) + assert len(rows) == 2 + row0, row1 = rows + assert row0 == 'test' + assert row1 == 'test2' + rows = await Model.get_by_ids(session, (1, 2, 3), (Model.name,)) assert len(rows) == 2 assert all_is_instance(rows, Row) @@ -181,7 +191,7 @@ async def test_get_by_ids(self): async def test_exist(self): async with async_session() as session: - await session.execute(text('TRUNCATE TABLE model')) + await session.execute(text(f'TRUNCATE TABLE {Model.__tablename__}')) assert not await Model.exist(session, 1) @@ -191,7 +201,7 @@ async def test_exist(self): async def test_get_all(self): async with async_session() as session: - await session.execute(text('TRUNCATE TABLE model')) + await session.execute(text(f'TRUNCATE TABLE {Model.__tablename__}')) models = await Model.get_all(session) assert len(models) == 0 @@ -239,6 +249,12 @@ async def test_get_all(self): assert row0 == 'test' assert row1 == 'test2' + rows = await Model.get_all(session, col(Model.name)) + assert len(rows) == 2 + row0, row1 = rows + assert row0 == 'test' + assert row1 == 'test2' + rows = await Model.get_all(session, Model.name, for_update=True) # type: ignore assert len(rows) == 2 row0, row1 = rows @@ -280,7 +296,7 @@ async def test_get_all(self): async def test_count_all(self): async with async_session() as session: - await session.execute(text('TRUNCATE TABLE model')) + await session.execute(text(f'TRUNCATE TABLE {Model.__tablename__}')) count = await Model.count_all(session) assert count == 0 @@ -294,7 +310,7 @@ async def test_count_all(self): async def test_update_by_id(self): async with async_session() as session: - await session.execute(text('TRUNCATE TABLE model')) + await session.execute(text(f'TRUNCATE TABLE {Model.__tablename__}')) count = await Model.update_by_id(session, 1, {'name': 'test2'}) assert count == 0 @@ -304,17 +320,17 @@ async def test_update_by_id(self): count = await Model.update_by_id(session, 1, {'name': 'test2'}) assert count == 1 - name = await Model.get_by_id(session, 1, Model.name) # type: ignore + name = await Model.get_by_id(session, 1, col(Model.name)) assert name == 'test2' count = await Model.update_by_id(session, 1, {Model.id.name: 2}) # type: ignore assert count == 1 - name = await Model.get_by_id(session, 2, Model.name) # type: ignore + name = await Model.get_by_id(session, 2, col(Model.name)) assert name == 'test2' async def test_delete_by_id(self): async with async_session() as session: - await session.execute(text('TRUNCATE TABLE model')) + await session.execute(text(f'TRUNCATE TABLE {Model.__tablename__}')) count = await Model.delete_by_id(session, 1) assert count == 0 @@ -328,7 +344,7 @@ async def test_delete_by_id(self): async def test_delete_by_ids(self): async with async_session() as session: - await session.execute(text('TRUNCATE TABLE model')) + await session.execute(text(f'TRUNCATE TABLE {Model.__tablename__}')) count = await Model.delete_by_ids(session, (1, 2)) assert count == 0 @@ -353,21 +369,21 @@ async def test_delete_by_ids(self): async def test_insert(self): async with async_session() as session: - await session.execute(text('TRUNCATE TABLE model')) + await session.execute(text(f'TRUNCATE TABLE {Model.__tablename__}')) row_id = await Model.insert(session, {'name': 'test'}) assert row_id == 1 - name = await Model.get_by_id(session, 1, Model.name) # type: ignore + name = await Model.get_by_id(session, 1, col(Model.name)) assert name == 'test' row_id = await Model.insert(session, {'id': 3, 'name': 'test3'}) assert row_id == 3 - name = await Model.get_by_id(session, 3, Model.name) # type: ignore + name = await Model.get_by_id(session, 3, col(Model.name)) assert name == 'test3' async def test_batch_insert(self): async with async_session() as session: - await session.execute(text('TRUNCATE TABLE model')) + await session.execute(text(f'TRUNCATE TABLE {Model.__tablename__}')) count = await Model.batch_insert(session, [{'name': 'test'}, {'name': 'test2'}]) assert count == 2 diff --git a/tests/models/test_user.py b/tests/models/test_user.py index a7f5dcf..15b2441 100644 --- a/tests/models/test_user.py +++ b/tests/models/test_user.py @@ -1,6 +1,6 @@ import pytest from argon2.exceptions import VerifyMismatchError -from sqlmodel import delete +from sqlmodel import col, delete from app.clients.mysql import async_session from app.models.user import User, get_current_user_id @@ -34,7 +34,7 @@ def test_hash_and_verify_password(self): @pytest.mark.asyncio(scope='session') async def test_get_verified_user_id(self): async with async_session() as session: - if (await session.execute(delete(User).where(User.id > 1))).rowcount > 0: # type: ignore + if (await session.execute(delete(User).where(col(User.id) > 1))).rowcount > 0: await session.commit() assert await User.get_verified_user_id(session, 'admin', 'admin') == 1 @@ -52,7 +52,7 @@ def test_generate_token(self): @pytest.mark.asyncio(scope='session') async def test_delete_by_name(self): async with async_session() as session: - if (await session.execute(delete(User).where(User.id > 1))).rowcount > 0: # type: ignore + if (await session.execute(delete(User).where(col(User.id) > 1))).rowcount > 0: await session.commit() assert await User.delete_by_name(session, 'test') == 0