Skip to content

Commit

Permalink
feat: update sqlalchemy to use standard api instead of async one, sta…
Browse files Browse the repository at this point in the history
…ndard has better documentation, more community support, and cleaner code (#276)

closes #275
  • Loading branch information
lchen-2101 authored Jan 3, 2025
1 parent 344e0a7 commit c30117c
Show file tree
Hide file tree
Showing 16 changed files with 316 additions and 420 deletions.
296 changes: 131 additions & 165 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 0 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@ regtech-regex = {git = "https://github.com/cfpb/regtech-regex.git"}

[tool.poetry.group.dev.dependencies]
pytest = "^8.3.3"
pytest-asyncio = "^0.24.0"
aiosqlite = "^0.20.0"
pytest-cov = "^6.0.0"
pytest-mock = "^3.11.1"
pytest-env = "^1.1.4"
Expand Down
2 changes: 1 addition & 1 deletion src/regtech_user_fi_management/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class Settings(BaseSettings):
inst_db_user: str
inst_db_pwd: str
inst_db_host: str
inst_db_scheme: str = "postgresql+asyncpg"
inst_db_scheme: str = "postgresql+psycopg2"
inst_conn: str | None = None
admin_scopes: Set[str] = set(["query-groups", "manage-users"])
db_logging: bool = True
Expand Down
10 changes: 5 additions & 5 deletions src/regtech_user_fi_management/dependencies.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,24 @@
from http import HTTPStatus
from typing import Annotated
from fastapi import Depends, Request
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session

from regtech_user_fi_management.entities.engine.engine import get_session
import regtech_user_fi_management.entities.repos.institutions_repo as repo
from regtech_api_commons.api.exceptions import RegTechHttpException
from regtech_api_commons.api.dependencies import get_email_domain


async def check_domain(request: Request, session: Annotated[AsyncSession, Depends(get_session)]) -> None:
def check_domain(request: Request, session: Annotated[Session, Depends(get_session)]) -> None:
if not request.user.is_authenticated:
raise RegTechHttpException(
status_code=HTTPStatus.FORBIDDEN, name="Request Forbidden", detail="unauthenticated user"
)
if await email_domain_denied(session, get_email_domain(request.user.email)):
if email_domain_denied(session, get_email_domain(request.user.email)):
raise RegTechHttpException(
status_code=HTTPStatus.FORBIDDEN, name="Request Forbidden", detail="email domain denied"
)


async def email_domain_denied(session: AsyncSession, email: str) -> bool:
return not await repo.is_domain_allowed(session, email)
def email_domain_denied(session: Session, email: str) -> bool:
return not repo.is_domain_allowed(session, email)
16 changes: 6 additions & 10 deletions src/regtech_user_fi_management/entities/engine/engine.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,16 @@
from sqlalchemy.ext.asyncio import (
create_async_engine,
async_sessionmaker,
async_scoped_session,
)
from asyncio import current_task
from sqlalchemy import create_engine
from sqlalchemy.orm import scoped_session, sessionmaker
from regtech_user_fi_management.config import settings

engine = create_async_engine(str(settings.inst_conn), echo=settings.db_logging).execution_options(
engine = create_engine(str(settings.inst_conn), echo=settings.db_logging).execution_options(
schema_translate_map={None: settings.inst_db_schema}
)
SessionLocal = async_scoped_session(async_sessionmaker(engine, expire_on_commit=False), current_task)
SessionLocal = scoped_session(sessionmaker(engine, expire_on_commit=False))


async def get_session():
def get_session():
session = SessionLocal()
try:
yield session
finally:
await session.close()
session.close()
11 changes: 3 additions & 8 deletions src/regtech_user_fi_management/entities/listeners.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,9 @@ def _insert_history(
return _insert_history


async def setup_dao_listeners():
async with engine.begin() as connection:
fi_history, mapping_history = await connection.run_sync(
lambda conn: (
Table("financial_institutions_history", Base.metadata, autoload_with=conn),
Table("fi_to_type_mapping_history", Base.metadata, autoload_with=conn),
)
)
def setup_dao_listeners():
fi_history = Table("financial_institutions_history", Base.metadata, autoload_with=engine)
mapping_history = Table("fi_to_type_mapping_history", Base.metadata, autoload_with=engine)

insert_fi_history = _setup_fi_history(fi_history, mapping_history)

Expand Down
3 changes: 1 addition & 2 deletions src/regtech_user_fi_management/entities/models/dao.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@
from typing import List
from sqlalchemy import ForeignKey, func, String, inspect
from sqlalchemy.orm import Mapped, mapped_column, relationship, DeclarativeBase
from sqlalchemy.ext.asyncio import AsyncAttrs


class Base(AsyncAttrs, DeclarativeBase):
class Base(DeclarativeBase):
pass


Expand Down
124 changes: 49 additions & 75 deletions src/regtech_user_fi_management/entities/repos/institutions_repo.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
from typing import List, Sequence, Set

from sqlalchemy import select, func
from sqlalchemy.orm import joinedload
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session

from regtech_api_commons.models.auth import AuthenticatedUser

from .repo_utils import get_associated_sbl_types, query_type
from .repo_utils import get_associated_sbl_types

from regtech_user_fi_management.entities.models.dao import (
FinancialInstitutionDao,
Expand All @@ -25,76 +23,61 @@
)


async def get_institutions(
session: AsyncSession,
def get_institutions(
session: Session,
leis: List[str] | None = None,
domain: str = "",
page: int = 0,
count: int = 100,
) -> Sequence[FinancialInstitutionDao]:
async with session.begin():
stmt = (
select(FinancialInstitutionDao)
.options(joinedload(FinancialInstitutionDao.domains))
.limit(count)
.offset(page * count)
)
if leis is not None:
stmt = stmt.filter(FinancialInstitutionDao.lei.in_(leis))
elif d := domain.strip():
stmt = stmt.join(FinancialInstitutionDomainDao).filter(FinancialInstitutionDomainDao.domain == d)
res = await session.scalars(stmt)
return res.unique().all()


async def get_institution(session: AsyncSession, lei: str) -> FinancialInstitutionDao | None:
async with session.begin():
stmt = (
select(FinancialInstitutionDao)
.options(joinedload(FinancialInstitutionDao.domains))
.filter(FinancialInstitutionDao.lei == lei)
)
return await session.scalar(stmt)
query = session.query(FinancialInstitutionDao)
if leis is not None:
query = query.filter(FinancialInstitutionDao.lei.in_(leis))
elif d := domain.strip():
query = query.join(FinancialInstitutionDomainDao).filter(FinancialInstitutionDomainDao.domain == d)
return query.limit(count).offset(page * count).all()


def get_institution(session: Session, lei: str) -> FinancialInstitutionDao | None:
return session.get(FinancialInstitutionDao, lei)

async def get_sbl_types(session: AsyncSession) -> Sequence[SBLInstitutionTypeDao]:
return await query_type(session, SBLInstitutionTypeDao)

def get_sbl_types(session: Session) -> Sequence[SBLInstitutionTypeDao]:
return session.query(SBLInstitutionTypeDao).all()

async def get_hmda_types(session: AsyncSession) -> Sequence[HMDAInstitutionTypeDao]:
return await query_type(session, HMDAInstitutionTypeDao)

def get_hmda_types(session: Session) -> Sequence[HMDAInstitutionTypeDao]:
return session.query(HMDAInstitutionTypeDao).all()

async def get_address_states(session: AsyncSession) -> Sequence[AddressStateDao]:
return await query_type(session, AddressStateDao)

def get_address_states(session: Session) -> Sequence[AddressStateDao]:
return session.query(AddressStateDao).all()

async def get_federal_regulators(session: AsyncSession) -> Sequence[FederalRegulatorDao]:
return await query_type(session, FederalRegulatorDao)

def get_federal_regulators(session: Session) -> Sequence[FederalRegulatorDao]:
return session.query(FederalRegulatorDao).all()

async def upsert_institution(
session: AsyncSession, fi: FinancialInstitutionDto, user: AuthenticatedUser

def upsert_institution(
session: Session, fi: FinancialInstitutionDto, user: AuthenticatedUser
) -> FinancialInstitutionDao:
async with session.begin():
fi_data = fi.__dict__.copy()
fi_data.pop("_sa_instance_state", None)
fi_data.pop("version", None)
fi_data = fi.__dict__.copy()
fi_data.pop("_sa_instance_state", None)
fi_data.pop("version", None)

if "sbl_institution_types" in fi_data:
types_association = get_associated_sbl_types(fi.lei, user.id, fi.sbl_institution_types)
fi_data["sbl_institution_types"] = types_association
if "sbl_institution_types" in fi_data:
types_association = get_associated_sbl_types(fi.lei, user.id, fi.sbl_institution_types)
fi_data["sbl_institution_types"] = types_association

db_fi = await session.merge(FinancialInstitutionDao(**fi_data, modified_by=user.id))
await session.flush()
await session.refresh(db_fi)
return db_fi
db_fi = session.merge(FinancialInstitutionDao(**fi_data, modified_by=user.id))
session.commit()
return db_fi


async def update_sbl_types(
session: AsyncSession, user: AuthenticatedUser, lei: str, sbl_types: Sequence[SblTypeAssociationDto | str]
def update_sbl_types(
session: Session, user: AuthenticatedUser, lei: str, sbl_types: Sequence[SblTypeAssociationDto | str]
) -> FinancialInstitutionDao | None:
if fi := await get_institution(session, lei):
if fi := get_institution(session, lei):
new_types = set(get_associated_sbl_types(lei, user.id, sbl_types))
old_types = set(fi.sbl_institution_types)
add_types = new_types.difference(old_types)
Expand All @@ -104,34 +87,25 @@ async def update_sbl_types(
fi.sbl_institution_types.extend(add_types)
for type in fi.sbl_institution_types:
type.version = fi.version
await session.commit()
"""
load the async relational attributes so dto can be properly serialized
"""
for type in fi.sbl_institution_types:
await type.awaitable_attrs.sbl_type
session.commit()
return fi


async def add_domains(
session: AsyncSession, lei: str, domains: List[FinancialInsitutionDomainCreate]
def add_domains(
session: Session, lei: str, domains: List[FinancialInsitutionDomainCreate]
) -> Set[FinancialInstitutionDomainDao]:
async with session.begin():
daos = set(
map(
lambda dto: FinancialInstitutionDomainDao(domain=dto.domain, lei=lei),
domains,
)
daos = set(
map(
lambda dto: FinancialInstitutionDomainDao(domain=dto.domain, lei=lei),
domains,
)
session.add_all(daos)
await session.commit()
return daos
)
session.add_all(daos)
session.commit()
return daos


async def is_domain_allowed(session: AsyncSession, domain: str) -> bool:
def is_domain_allowed(session: Session, domain: str) -> bool:
if domain:
async with session:
stmt = select(func.count()).filter(DeniedDomainDao.domain == domain)
res = await session.scalar(stmt)
return res == 0
return session.get(DeniedDomainDao, domain) is None
return False
11 changes: 1 addition & 10 deletions src/regtech_user_fi_management/entities/repos/repo_utils.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,10 @@
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from typing import Sequence, TypeVar, Type
from typing import Sequence, TypeVar
from regtech_user_fi_management.entities.models.dao import Base, SblTypeMappingDao
from regtech_user_fi_management.entities.models.dto import SblTypeAssociationDto

T = TypeVar("T", bound=Base)


async def query_type(session: AsyncSession, type: Type[T]) -> Sequence[T]:
async with session.begin():
stmt = select(type)
res = await session.scalars(stmt)
return res.all()


def get_associated_sbl_types(
lei: str, user_id: str, types: Sequence[SblTypeAssociationDto | str]
) -> Sequence[SblTypeMappingDao]:
Expand Down
2 changes: 1 addition & 1 deletion src/regtech_user_fi_management/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ async def lifespan(app_: FastAPI):
log.info("Starting up...")
log.info("run alembic upgrade head...")
run_migrations()
await setup_dao_listeners()
setup_dao_listeners()
yield
log.info("Shutting down...")

Expand Down
Loading

0 comments on commit c30117c

Please sign in to comment.