Skip to content

Commit

Permalink
reduce type ignore
Browse files Browse the repository at this point in the history
  • Loading branch information
keakon committed May 13, 2024
1 parent 204b4e5 commit 01ec996
Show file tree
Hide file tree
Showing 8 changed files with 76 additions and 42 deletions.
5 changes: 3 additions & 2 deletions app/controllers/hello.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from fastapi import Depends
from sqlmodel import col

from app.clients.mysql import async_session
from app.models.user import get_current_user_id, User
from app.models.user import User, get_current_user_id
from app.router import router
from app.schemas.resp import Resp

Expand All @@ -14,5 +15,5 @@ def hello(user_name: str):
@router.get('/hello', response_model=Resp, response_model_exclude_none=True)
async def hello_to_self(current_user_id: int = Depends(get_current_user_id)):
async with async_session() as session:
user_name = await User.get_by_id(session, current_user_id, User.name) # type: ignore
user_name = await User.get_by_id(session, current_user_id, col(User.name))
return Resp(msg=f'Hello, {user_name}!')
3 changes: 2 additions & 1 deletion app/controllers/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from fastapi import Body, Depends
from fastapi.security import OAuth2PasswordRequestForm
from sqlalchemy import Row
from sqlmodel import col

from app.clients.mysql import async_session
from app.models import all_is_instance
Expand Down Expand Up @@ -57,7 +58,7 @@ async def update_user(user_id: int, req: UserRequest, current_user_id: int = Dep
@router.get('/user/{user_id}/name', response_model=Resp, response_model_exclude_none=True)
async def get_user_name(user_id: int, _=Depends(get_current_user_id)):
async with async_session() as session:
user_name = await User.get_by_id(session, user_id, User.name) # type: ignore
user_name = await User.get_by_id(session, user_id, col(User.name))
if user_name:
return Resp(data={'name': user_name})
raise not_found_error
Expand Down
17 changes: 9 additions & 8 deletions app/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@

from sqlalchemy import Column, Row
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Mapped
from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.sql import func, text
from sqlalchemy.sql.elements import TextClause
from sqlmodel import Field, SQLModel, delete, insert, select, update
from sqlmodel import Field, SQLModel, col, delete, insert, select, update

Values = dict[str, Any]

Expand All @@ -26,7 +27,7 @@ async def get_by_id(
cls,
session: AsyncSession,
id: int,
columns: list | tuple | InstrumentedAttribute | TextClause | Column | None = None,
columns: list | tuple | InstrumentedAttribute | TextClause | Column | Mapped | None = None,
for_update: bool = False,
for_read: bool = False,
) -> Any:
Expand Down Expand Up @@ -55,7 +56,7 @@ async def get_by_ids(
cls,
session: AsyncSession,
ids: Sequence[int],
columns: list | tuple | InstrumentedAttribute | TextClause | Column | None = None,
columns: list | tuple | InstrumentedAttribute | TextClause | Column | Mapped | None = None,
for_update: bool = False,
for_read: bool = False,
) -> 'Sequence[BaseModel | Row]':
Expand All @@ -71,7 +72,7 @@ async def get_by_ids(
else:
query = select(columns)
scalar = True
query = query.where(cls.id.in_(ids)) # type: ignore
query = query.where(col(cls.id).in_(ids))
if for_update:
query = query.with_for_update()
elif for_read:
Expand All @@ -94,7 +95,7 @@ async def exist(cls, session: AsyncSession, id: int, for_update: bool = False, f
async def get_all(
cls,
session: AsyncSession,
columns: list | tuple | InstrumentedAttribute | None = None,
columns: list | tuple | InstrumentedAttribute | Mapped | None = None,
for_update: bool = False,
for_read: bool = False,
) -> 'Sequence[BaseModel | Row]':
Expand Down Expand Up @@ -123,15 +124,15 @@ async def count_all(cls, session: AsyncSession) -> int:

@classmethod
async def update_by_id(cls, session: AsyncSession, id: int, values: Values) -> int:
return (await session.execute(update(cls).where(cls.id == id).values(**values))).rowcount # type: ignore
return (await session.execute(update(cls).where(col(cls.id) == id).values(**values))).rowcount

@classmethod
async def delete_by_id(cls, session: AsyncSession, id: int) -> int:
return (await session.execute(delete(cls).where(cls.id == id))).rowcount # type: ignore
return (await session.execute(delete(cls).where(col(cls.id) == id))).rowcount

@classmethod
async def delete_by_ids(cls, session: AsyncSession, ids: Sequence[int]) -> int:
return (await session.execute(delete(cls).where(cls.id.in_(ids)))).rowcount # type: ignore
return (await session.execute(delete(cls).where(col(cls.id).in_(ids)))).rowcount

@classmethod
async def insert(cls, session: AsyncSession, values: Values) -> int:
Expand Down
5 changes: 5 additions & 0 deletions app/utils/exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ def unauthorized_error(msg) -> HTTPError:
invalid_token_error = unauthorized_error('Invalid token')
not_authenticated_error = unauthorized_error('Not authenticated')


def bad_request_error(msg) -> HTTPError:
return HTTPError(status_code=400, code=ErrorCode.BAD_REQUEST, msg=msg)


forbidden_error = HTTPError(
status_code=403,
code=ErrorCode.FORBIDDEN,
Expand Down
4 changes: 2 additions & 2 deletions tests/controllers/test_hello.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pytest
from sqlmodel import delete
from sqlmodel import col, delete

from app.clients.mysql import async_session
from app.models.user import User
Expand All @@ -17,7 +17,7 @@ def test_hello():
@pytest.mark.asyncio(scope='session')
async def test_hello_to_self():
async with async_session() as session:
if (await session.execute(delete(User).where(User.id > 1))).rowcount > 0: # type: ignore
if (await session.execute(delete(User).where(col(User.id) > 1))).rowcount > 0:
await session.commit()

async with async_client() as client:
Expand Down
34 changes: 22 additions & 12 deletions tests/controllers/test_user.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import pytest
from sqlalchemy.exc import IntegrityError
from sqlmodel import delete
from sqlmodel import col, delete

from app.clients.mysql import async_session
from app.models.user import User
from app.schemas.token import TokenPayload
from app.schemas.user import UserRequest
from app.clients.mysql import async_session
from app.utils.token import decode_token

from . import async_client
Expand Down Expand Up @@ -63,7 +63,7 @@ async def test_get_user():
@pytest.mark.asyncio(scope='session')
async def test_update_user():
async with async_session() as session:
if (await session.execute(delete(User).where(User.id > 1))).rowcount > 0: # type: ignore
if (await session.execute(delete(User).where(col(User.id) > 1))).rowcount > 0:
await session.commit()

async with async_client() as client:
Expand All @@ -81,14 +81,18 @@ async def test_update_user():
user_id = payload.user_id

data = UserRequest(name='test2', password='test2').model_dump()
response = await client.put(f'/api/v1/user/{user_id}', headers={'Authorization': f'Bearer {access_token}'}, json=data)
response = await client.put(
f'/api/v1/user/{user_id}', headers={'Authorization': f'Bearer {access_token}'}, json=data
)
assert response.status_code == 403

response = await client.post('/api/v1/login', data={'username': 'admin', 'password': 'admin'})
assert response.status_code == 200
access_token = response.json()['access_token']

response = await client.put(f'/api/v1/user/{user_id}', headers={'Authorization': f'Bearer {access_token}'}, json=data)
response = await client.put(
f'/api/v1/user/{user_id}', headers={'Authorization': f'Bearer {access_token}'}, json=data
)
assert response.status_code == 200

response = await client.post('/api/v1/login', data={'username': 'test', 'password': 'test'})
Expand All @@ -104,7 +108,7 @@ async def test_update_user():
@pytest.mark.asyncio(scope='session')
async def test_get_user_name():
async with async_session() as session:
if (await session.execute(delete(User).where(User.id > 1))).rowcount > 0: # type: ignore
if (await session.execute(delete(User).where(col(User.id) > 1))).rowcount > 0:
await session.commit()

async with async_client() as client:
Expand Down Expand Up @@ -136,7 +140,7 @@ async def test_get_user_name():
@pytest.mark.asyncio(scope='session')
async def test_set_user_name():
async with async_session() as session:
if (await session.execute(delete(User).where(User.id > 1))).rowcount > 0: # type: ignore
if (await session.execute(delete(User).where(col(User.id) > 1))).rowcount > 0:
await session.commit()

async with async_client() as client:
Expand All @@ -153,7 +157,9 @@ async def test_set_user_name():
payload = TokenPayload.model_validate_json(token.payload)
user_id = payload.user_id

response = await client.patch(f'/api/v1/user/{user_id}/name', headers={'Authorization': f'Bearer {access_token}'}, json={'name': 'test2'})
response = await client.patch(
f'/api/v1/user/{user_id}/name', headers={'Authorization': f'Bearer {access_token}'}, json={'name': 'test2'}
)
assert response.status_code == 200

response = await client.post('/api/v1/login', data={'username': 'test', 'password': 'test'})
Expand All @@ -162,14 +168,16 @@ async def test_set_user_name():
response = await client.post('/api/v1/login', data={'username': 'test2', 'password': 'test'})
assert response.status_code == 200

response = await client.patch('/api/v1/user/0/name', headers={'Authorization': f'Bearer {access_token}'}, json={'name': 'test2'})
response = await client.patch(
'/api/v1/user/0/name', headers={'Authorization': f'Bearer {access_token}'}, json={'name': 'test2'}
)
assert response.status_code == 404


@pytest.mark.asyncio(scope='session')
async def test_get_user_time():
async with async_session() as session:
if (await session.execute(delete(User).where(User.id > 1))).rowcount > 0: # type: ignore
if (await session.execute(delete(User).where(col(User.id) > 1))).rowcount > 0:
await session.commit()

async with async_client() as client:
Expand All @@ -193,7 +201,9 @@ async def test_get_user_time():
updated_at = data['updated_at']
assert created_at == updated_at

response = await client.patch(f'/api/v1/user/{user_id}/name', headers={'Authorization': f'Bearer {access_token}'}, json={'name': 'test2'})
response = await client.patch(
f'/api/v1/user/{user_id}/name', headers={'Authorization': f'Bearer {access_token}'}, json={'name': 'test2'}
)
assert response.status_code == 200

response = await client.get(f'/api/v1/user/{user_id}/time', headers={'Authorization': f'Bearer {access_token}'})
Expand All @@ -211,7 +221,7 @@ async def test_get_user_time():
@pytest.mark.asyncio(scope='session')
async def test_get_user_list():
async with async_session() as session:
if (await session.execute(delete(User).where(User.id > 1))).rowcount > 0: # type: ignore
if (await session.execute(delete(User).where(col(User.id) > 1))).rowcount > 0:
await session.commit()

async with async_client() as client:
Expand Down
44 changes: 30 additions & 14 deletions tests/models/test_base_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest
from sqlalchemy import Column, Row
from sqlalchemy.sql import text
from sqlmodel import col

from app.clients.mysql import async_session
from app.models import BaseModel, all_is_instance
Expand All @@ -14,7 +15,7 @@ class Model(BaseModel, table=True):
class TestBaseModel:
async def test_get_by_id(self):
async with async_session() as session:
await session.execute(text('TRUNCATE TABLE model'))
await session.execute(text(f'TRUNCATE TABLE {Model.__tablename__}'))

model = await Model.get_by_id(session, 1)
assert model is None
Expand Down Expand Up @@ -55,6 +56,9 @@ async def test_get_by_id(self):
name = await Model.get_by_id(session, 1, Column('name'))
assert name == 'test'

name = await Model.get_by_id(session, 1, col(Model.name))
assert name == 'test'

row = await Model.get_by_id(session, 1, (Model.name,))
assert isinstance(row, Row)
assert row.name == 'test'
Expand All @@ -74,7 +78,7 @@ async def test_get_by_id(self):

async def test_get_by_ids(self):
async with async_session() as session:
await session.execute(text('TRUNCATE TABLE model'))
await session.execute(text(f'TRUNCATE TABLE {Model.__tablename__}'))

models = await Model.get_by_ids(session, (1, 2, 3))
assert len(models) == 0
Expand Down Expand Up @@ -152,6 +156,12 @@ async def test_get_by_ids(self):
assert row0 == 'test'
assert row1 == 'test2'

rows = await Model.get_by_ids(session, (1, 2, 3), col(Model.name))
assert len(rows) == 2
row0, row1 = rows
assert row0 == 'test'
assert row1 == 'test2'

rows = await Model.get_by_ids(session, (1, 2, 3), (Model.name,))
assert len(rows) == 2
assert all_is_instance(rows, Row)
Expand Down Expand Up @@ -181,7 +191,7 @@ async def test_get_by_ids(self):

async def test_exist(self):
async with async_session() as session:
await session.execute(text('TRUNCATE TABLE model'))
await session.execute(text(f'TRUNCATE TABLE {Model.__tablename__}'))

assert not await Model.exist(session, 1)

Expand All @@ -191,7 +201,7 @@ async def test_exist(self):

async def test_get_all(self):
async with async_session() as session:
await session.execute(text('TRUNCATE TABLE model'))
await session.execute(text(f'TRUNCATE TABLE {Model.__tablename__}'))

models = await Model.get_all(session)
assert len(models) == 0
Expand Down Expand Up @@ -239,6 +249,12 @@ async def test_get_all(self):
assert row0 == 'test'
assert row1 == 'test2'

rows = await Model.get_all(session, col(Model.name))
assert len(rows) == 2
row0, row1 = rows
assert row0 == 'test'
assert row1 == 'test2'

rows = await Model.get_all(session, Model.name, for_update=True) # type: ignore
assert len(rows) == 2
row0, row1 = rows
Expand Down Expand Up @@ -280,7 +296,7 @@ async def test_get_all(self):

async def test_count_all(self):
async with async_session() as session:
await session.execute(text('TRUNCATE TABLE model'))
await session.execute(text(f'TRUNCATE TABLE {Model.__tablename__}'))

count = await Model.count_all(session)
assert count == 0
Expand All @@ -294,7 +310,7 @@ async def test_count_all(self):

async def test_update_by_id(self):
async with async_session() as session:
await session.execute(text('TRUNCATE TABLE model'))
await session.execute(text(f'TRUNCATE TABLE {Model.__tablename__}'))

count = await Model.update_by_id(session, 1, {'name': 'test2'})
assert count == 0
Expand All @@ -304,17 +320,17 @@ async def test_update_by_id(self):

count = await Model.update_by_id(session, 1, {'name': 'test2'})
assert count == 1
name = await Model.get_by_id(session, 1, Model.name) # type: ignore
name = await Model.get_by_id(session, 1, col(Model.name))
assert name == 'test2'

count = await Model.update_by_id(session, 1, {Model.id.name: 2}) # type: ignore
assert count == 1
name = await Model.get_by_id(session, 2, Model.name) # type: ignore
name = await Model.get_by_id(session, 2, col(Model.name))
assert name == 'test2'

async def test_delete_by_id(self):
async with async_session() as session:
await session.execute(text('TRUNCATE TABLE model'))
await session.execute(text(f'TRUNCATE TABLE {Model.__tablename__}'))

count = await Model.delete_by_id(session, 1)
assert count == 0
Expand All @@ -328,7 +344,7 @@ async def test_delete_by_id(self):

async def test_delete_by_ids(self):
async with async_session() as session:
await session.execute(text('TRUNCATE TABLE model'))
await session.execute(text(f'TRUNCATE TABLE {Model.__tablename__}'))

count = await Model.delete_by_ids(session, (1, 2))
assert count == 0
Expand All @@ -353,21 +369,21 @@ async def test_delete_by_ids(self):

async def test_insert(self):
async with async_session() as session:
await session.execute(text('TRUNCATE TABLE model'))
await session.execute(text(f'TRUNCATE TABLE {Model.__tablename__}'))

row_id = await Model.insert(session, {'name': 'test'})
assert row_id == 1
name = await Model.get_by_id(session, 1, Model.name) # type: ignore
name = await Model.get_by_id(session, 1, col(Model.name))
assert name == 'test'

row_id = await Model.insert(session, {'id': 3, 'name': 'test3'})
assert row_id == 3
name = await Model.get_by_id(session, 3, Model.name) # type: ignore
name = await Model.get_by_id(session, 3, col(Model.name))
assert name == 'test3'

async def test_batch_insert(self):
async with async_session() as session:
await session.execute(text('TRUNCATE TABLE model'))
await session.execute(text(f'TRUNCATE TABLE {Model.__tablename__}'))

count = await Model.batch_insert(session, [{'name': 'test'}, {'name': 'test2'}])
assert count == 2
Expand Down
Loading

0 comments on commit 01ec996

Please sign in to comment.