Skip to content

Commit

Permalink
feat: add denied domain check (#21)
Browse files Browse the repository at this point in the history
closes #12

---------

Co-authored-by: Hans Keeler <[email protected]>
  • Loading branch information
lchen-2101 and hkeeler authored Sep 14, 2023
1 parent a656c63 commit 98635a0
Show file tree
Hide file tree
Showing 12 changed files with 218 additions and 56 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ poetry run uvicorn main:app --reload --port 8888
```
### Local development notes
- [.env.template](.env.template) is added to allow VS Code to search the correct path for imports when writing tests, just copy the [.env.template](.env.template) file into `.env` file locally
- [src/.env.template](./src/.env.template) is added as the template for the app's environment variables, appropriate values are already provided in [.env.local](./src/.env.local) for local development. If `ENV` variable with default of `LOCAL` is changed, copy this template into `src/.env`, and provide appropriate values, and set all listed empty variables in the environment.

---
## Retrieve credentials
Expand Down
16 changes: 16 additions & 0 deletions src/.env.template
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
KC_URL=
KC_REALM=
KC_ADMIN_CLIENT_ID=
KC_ADMIN_CLIENT_SECRET=
KC_REALM_URL=${KC_URL}/realms/${KC_REALM}
KC_REALM_ADMIN_URL=${KC_URL}/admin/realms/${KC_REALM}
AUTH_URL=${KC_REALM_URL}/protocol/openid-connect/auth
TOKEN_URL=${KC_REALM_URL}/protocol/openid-connect/token
CERTS_URL=${KC_REALM_URL}/protocol/openid-connect/certs
AUTH_CLIENT=
INST_DB_NAME=
INST_DB_USER=
INST_DB_PWD=
INST_DB_HOST=
INST_DB_SCHEMA=
INST_CONN=postgresql+asyncpg://${INST_DB_USER}:${INST_DB_PWD}@${INST_DB_HOST}/${INST_DB_NAME}
37 changes: 37 additions & 0 deletions src/dependencies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from http import HTTPStatus
from typing import Annotated
from fastapi import Depends, HTTPException, Request
from sqlalchemy.ext.asyncio import AsyncSession

from entities.engine import get_session
from entities.repos import institutions_repo as repo

OPEN_DOMAIN_REQUESTS = {
"/v1/admin/me": {"GET"},
"/v1/institutions": {"GET"},
"/v1/institutions/domains/allowed": {"GET"},
}


async def check_domain(
request: Request, session: Annotated[AsyncSession, Depends(get_session)]
) -> None:
if request_needs_domain_check(request):
if not request.user.is_authenticated:
raise HTTPException(status_code=HTTPStatus.FORBIDDEN)
if await email_domain_denied(session, request.user.email):
raise HTTPException(
status_code=HTTPStatus.FORBIDDEN, detail="email domain denied"
)


def request_needs_domain_check(request: Request) -> bool:
path = request.scope["path"].rstrip("/")
return not (
path in OPEN_DOMAIN_REQUESTS
and request.scope["method"] in OPEN_DOMAIN_REQUESTS[path]
)


async def email_domain_denied(session: AsyncSession, email: str) -> bool:
return not await repo.is_email_domain_allowed(session, email)
10 changes: 9 additions & 1 deletion src/entities/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,20 @@
"FinancialInstitutionWithDomainsDto",
"FinancialInsitutionDomainDto",
"FinancialInsitutionDomainCreate",
"DeniedDomainDao",
"DeniedDomainDto",
]

from .dao import Base, FinancialInstitutionDao, FinancialInstitutionDomainDao
from .dao import (
Base,
FinancialInstitutionDao,
FinancialInstitutionDomainDao,
DeniedDomainDao,
)
from .dto import (
FinancialInstitutionDto,
FinancialInstitutionWithDomainsDto,
FinancialInsitutionDomainDto,
FinancialInsitutionDomainCreate,
DeniedDomainDto,
)
5 changes: 5 additions & 0 deletions src/entities/models/dao.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,8 @@ class FinancialInstitutionDomainDao(AuditMixin, Base):
ForeignKey("financial_institutions.lei"), index=True, primary_key=True
)
fi = relationship("FinancialInstitutionDao", back_populates="domains")


class DeniedDomainDao(AuditMixin, Base):
__tablename__ = "denied_domains"
domain: Mapped[str] = mapped_column(index=True, primary_key=True)
7 changes: 7 additions & 0 deletions src/entities/models/dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,10 @@ class Config:

class FinancialInstitutionWithDomainsDto(FinancialInstitutionDto):
domains: List[FinancialInsitutionDomainDto] = []


class DeniedDomainDto(BaseModel):
domain: str

class Config:
orm_mode = True
19 changes: 18 additions & 1 deletion src/entities/repos/institutions_repo.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import List

from sqlalchemy import select
from sqlalchemy import select, func
from sqlalchemy.orm import joinedload
from sqlalchemy.ext.asyncio import AsyncSession

Expand All @@ -9,6 +9,7 @@
FinancialInstitutionDomainDao,
FinancialInstitutionDto,
FinancialInsitutionDomainCreate,
DeniedDomainDao,
)


Expand Down Expand Up @@ -73,3 +74,19 @@ async def add_domains(
session.add_all(daos)
await session.commit()
return daos


async def is_email_domain_allowed(session: AsyncSession, email: str) -> bool:
domain = get_email_domain(email)
if domain:
async with session:
stmt = select(func.count()).filter(DeniedDomainDao.domain == domain)
res = await session.scalar(stmt)
return res == 0
return False


def get_email_domain(email: str) -> str:
if email:
return email.split("@")[-1]
return None
9 changes: 5 additions & 4 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,20 @@
import logging
import env # noqa: F401
from http import HTTPStatus
from fastapi import FastAPI, HTTPException, Request
from fastapi import FastAPI, HTTPException, Request, Depends
from fastapi.responses import JSONResponse
from fastapi.security import OAuth2AuthorizationCodeBearer
from fastapi.middleware.cors import CORSMiddleware
from starlette.middleware.authentication import AuthenticationMiddleware
from dependencies import check_domain

from routers import admin_router, institutions_router

from oauth2 import BearerTokenAuthBackend

log = logging.getLogger()

app = FastAPI()
app = FastAPI(dependencies=[Depends(check_domain)])


@app.exception_handler(HTTPException)
Expand All @@ -23,7 +24,7 @@ async def http_exception_handler(
) -> JSONResponse:
log.error(exception, exc_info=True, stack_info=True)
return JSONResponse(
status_code=exception.status_code, content={"message": exception.detail}
status_code=exception.status_code, content={"detail": exception.detail}
)


Expand All @@ -34,7 +35,7 @@ async def general_exception_handler(
log.error(exception, exc_info=True, stack_info=True)
return JSONResponse(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
content={"message": "server error"},
content={"detail": "server error"},
)


Expand Down
30 changes: 20 additions & 10 deletions src/routers/institutions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from http import HTTPStatus
from oauth2 import oauth2_admin
from util import Router
from typing import List, Tuple
from typing import Annotated, List, Tuple
from entities.engine import get_session
from entities.repos import institutions_repo as repo
from entities.models import (
Expand All @@ -14,7 +14,14 @@
from sqlalchemy.ext.asyncio import AsyncSession
from starlette.authentication import requires

router = Router()

async def set_db(
request: Request, session: Annotated[AsyncSession, Depends(get_session)]
):
request.state.db_session = session


router = Router(dependencies=[Depends(set_db)])


@router.get("/", response_model=List[FinancialInstitutionWithDomainsDto])
Expand All @@ -24,29 +31,28 @@ async def get_institutions(
domain: str = "",
page: int = 0,
count: int = 100,
session: AsyncSession = Depends(get_session),
):
return await repo.get_institutions(session, domain, page, count)
return await repo.get_institutions(request.state.db_session, domain, page, count)


@router.post("/", response_model=Tuple[str, FinancialInstitutionDto])
@requires(["query-groups", "manage-users"])
async def create_institution(
request: Request,
fi: FinancialInstitutionDto,
session: AsyncSession = Depends(get_session),
):
db_fi = await repo.upsert_institution(session, fi)
db_fi = await repo.upsert_institution(request.state.db_session, fi)
kc_id = oauth2_admin.upsert_group(fi.lei, fi.name)
return kc_id, db_fi


@router.get("/{lei}", response_model=FinancialInstitutionWithDomainsDto)
@requires("authenticated")
async def get_institution(
request: Request, lei: str, session: AsyncSession = Depends(get_session)
request: Request,
lei: str,
):
res = await repo.get_institution(session, lei)
res = await repo.get_institution(request.state.db_session, lei)
if not res:
raise HTTPException(HTTPStatus.NOT_FOUND, f"{lei} not found.")
return res
Expand All @@ -58,6 +64,10 @@ async def add_domains(
request: Request,
lei: str,
domains: List[FinancialInsitutionDomainCreate],
session: AsyncSession = Depends(get_session),
):
return await repo.add_domains(session, lei, domains)
return await repo.add_domains(request.state.db_session, lei, domains)


@router.get("/domains/allowed", response_model=bool)
async def is_domain_allowed(request: Request, domain: str):
return await repo.is_email_domain_allowed(request.state.db_session, domain)
10 changes: 7 additions & 3 deletions tests/api/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,20 @@

import pytest
from fastapi import FastAPI
from pytest_mock import MockerFixture, MockFixture
from pytest_mock import MockerFixture
from starlette.authentication import AuthCredentials, UnauthenticatedUser

from oauth2.oauth2_backend import AuthenticatedUser


@pytest.fixture
def app_fixture(mocker: MockFixture) -> FastAPI:
def app_fixture(mocker: MockerFixture) -> FastAPI:
mocked_engine = mocker.patch("sqlalchemy.ext.asyncio.create_async_engine")
MockedEngine = mocker.patch("sqlalchemy.ext.asyncio.AsyncEngine")
mocked_engine.return_value = MockedEngine.return_value
mocker.patch("fastapi.security.OAuth2AuthorizationCodeBearer")
domain_denied_mock = mocker.patch("dependencies.email_domain_denied")
domain_denied_mock.return_value = False
from main import app

return app
Expand All @@ -33,7 +35,9 @@ def authed_user_mock(auth_mock: Mock) -> Mock:
"sub": "testuser123",
}
auth_mock.return_value = (
AuthCredentials(["manage-account", "query-groups", "manage-users", "authenticated"]),
AuthCredentials(
["manage-account", "query-groups", "manage-users", "authenticated"]
),
AuthenticatedUser.from_claim(claims),
)
return auth_mock
Expand Down
Loading

0 comments on commit 98635a0

Please sign in to comment.