Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: restrict FI data retrieval #120

Merged
merged 2 commits into from
Mar 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/.env.local
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,5 @@ INST_DB_HOST=localhost:5432
INST_DB_SCHEMA=public
JWT_OPTS_VERIFY_AT_HASH="false"
JWT_OPTS_VERIFY_AUD="false"
JWT_OPTS_VERIFY_ISS="false"
JWT_OPTS_VERIFY_ISS="false"
ADMIN_SCOPES=["query-groups","manage-users"]
3 changes: 2 additions & 1 deletion src/config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
from urllib import parse
from typing import Any
from typing import Any, Set

from pydantic import field_validator, ValidationInfo
from pydantic.networks import PostgresDsn
Expand All @@ -24,6 +24,7 @@ class Settings(BaseSettings):
inst_db_host: str
inst_db_scheme: str = "postgresql+asyncpg"
inst_conn: PostgresDsn | None = None
admin_scopes: Set[str] = set(["query-groups", "manage-users"])

def __init__(self, **data):
super().__init__(**data)
Expand Down
60 changes: 57 additions & 3 deletions src/dependencies.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
import functools

from http import HTTPStatus
from typing import Annotated
from fastapi import Depends, HTTPException, Request
from fastapi import Depends, Query, HTTPException, Request, Response
from fastapi.types import DecoratedCallable
from sqlalchemy.ext.asyncio import AsyncSession
from typing import List, Optional
from itertools import chain

from fastapi import Query
from config import settings

from entities.engine import get_session
from entities.repos import institutions_repo as repo
from starlette.authentication import AuthCredentials
from regtech_api_commons.models.auth import AuthenticatedUser


async def check_domain(request: Request, session: Annotated[AsyncSession, Depends(get_session)]) -> None:
Expand Down Expand Up @@ -41,3 +45,53 @@ def get_email_domain(email: str) -> str:
if email:
return email.split("@")[-1]
return None


def is_admin(auth: AuthCredentials):
return settings.admin_scopes.issubset(auth.scopes)


def lei_association_check(func: DecoratedCallable) -> DecoratedCallable:
@functools.wraps(func)
async def wrapper(request: Request, *args, **kwargs) -> Response:
lei = kwargs.get("lei")
user: AuthenticatedUser = request.user
auth: AuthCredentials = request.auth
if not is_admin(auth) and lei not in user.institutions:
raise HTTPException(HTTPStatus.FORBIDDEN, detail=f"LEI {lei} is not associated with the user.")
return await func(request, *args, **kwargs)

return wrapper # type: ignore[return-value]


def fi_search_association_check(func: DecoratedCallable) -> DecoratedCallable:
def verify_leis(user: AuthenticatedUser, leis: List[str]) -> None:
if not set(filter(len, leis)).issubset(set(filter(len, user.institutions))):
raise HTTPException(
HTTPStatus.FORBIDDEN,
detail=f"Institutions query with LEIs ({leis}) not associated with user is forbidden.",
)

def verify_domain(user: AuthenticatedUser, domain: str) -> None:
if domain != get_email_domain(user.email):
raise HTTPException(
HTTPStatus.FORBIDDEN,
detail=f"Institutions query with domain ({domain}) not associated with user is forbidden.",
)

@functools.wraps(func)
async def wrapper(request: Request, *args, **kwargs) -> Response:
user: AuthenticatedUser = request.user
auth: AuthCredentials = request.auth
if not is_admin(auth):
leis = kwargs.get("leis")
domain = kwargs.get("domain")
if leis:
verify_leis(user, leis)
elif domain:
verify_domain(user, domain)
elif not leis and not domain:
raise HTTPException(HTTPStatus.FORBIDDEN, detail="Retrieving institutions without filter is forbidden.")
return await func(request=request, *args, **kwargs)

return wrapper # type: ignore[return-value]
6 changes: 5 additions & 1 deletion src/routers/institutions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from regtech_api_commons.oauth2.oauth2_admin import OAuth2Admin
from config import kc_settings
from regtech_api_commons.api import Router
from dependencies import check_domain, parse_leis, get_email_domain
from dependencies import check_domain, parse_leis, get_email_domain, lei_association_check, fi_search_association_check
from typing import Annotated, List, Tuple, Literal
from entities.engine import get_session
from entities.repos import institutions_repo as repo
Expand Down Expand Up @@ -38,6 +38,7 @@ async def set_db(request: Request, session: Annotated[AsyncSession, Depends(get_

@router.get("/", response_model=List[FinancialInstitutionWithRelationsDto])
@requires("authenticated")
@fi_search_association_check
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

went the decorator / annotation route, how do we feel about this vs dependencies?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks pretty slick to me

async def get_institutions(
request: Request,
leis: List[str] = Depends(parse_leis),
Expand Down Expand Up @@ -98,6 +99,7 @@ async def get_federal_regulators(request: Request):

@router.get("/{lei}", response_model=FinancialInstitutionWithRelationsDto)
@requires("authenticated")
@lei_association_check
async def get_institution(
request: Request,
lei: str,
Expand All @@ -110,6 +112,7 @@ async def get_institution(

@router.get("/{lei}/types/{type}", response_model=VersionedData[List[SblTypeAssociationDetailsDto]] | None)
@requires("authenticated")
@lei_association_check
async def get_types(request: Request, response: Response, lei: str, type: InstitutionType):
match type:
case "sbl":
Expand All @@ -123,6 +126,7 @@ async def get_types(request: Request, response: Response, lei: str, type: Instit

@router.put("/{lei}/types/{type}", response_model=VersionedData[List[SblTypeAssociationDetailsDto]] | None)
@requires("authenticated")
@lei_association_check
async def update_types(
request: Request, response: Response, lei: str, type: InstitutionType, types_patch: SblTypeAssociationPatchDto
):
Expand Down
20 changes: 20 additions & 0 deletions tests/api/routers/test_institutions_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,26 @@ def test_get_institutions_authed(
assert res.status_code == 200
assert res.json()[0].get("name") == "Test Bank 123"

def test_get_institutions_authed_not_admin(
self,
mocker: MockerFixture,
app_fixture: FastAPI,
auth_mock: Mock,
):
claims = {
"name": "test",
"preferred_username": "test_user",
"email": "[email protected]",
"sub": "testuser123",
}
auth_mock.return_value = (
AuthCredentials(["manage-account", "authenticated"]),
AuthenticatedUser.from_claim(claims),
)
client = TestClient(app_fixture)
res = client.get("/v1/institutions/")
assert res.status_code == 403

def test_create_institution_unauthed(self, app_fixture: FastAPI, unauthed_user_mock: Mock):
client = TestClient(app_fixture)
res = client.post("/v1/institutions/", json={"name": "testName", "lei": "testLei"})
Expand Down
27 changes: 27 additions & 0 deletions tests/app/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from typing import Tuple
from fastapi import Request
import pytest

from pytest_mock import MockerFixture
from starlette.authentication import AuthCredentials
from regtech_api_commons.models import AuthenticatedUser, RegTechUser


@pytest.fixture(autouse=True)
Expand All @@ -10,3 +14,26 @@ def setup(mocker: MockerFixture):
mocked_engine.return_value = MockedEngine.return_value
mocker.patch("fastapi.security.OAuth2AuthorizationCodeBearer")
mocker.patch("entities.engine.get_session")


@pytest.fixture
def mock_auth() -> Tuple[AuthCredentials, RegTechUser]:
creds = AuthCredentials(["manage-account", "authenticated"])
user = AuthenticatedUser.from_claim(
{
"name": "test",
"preferred_username": "test_user",
"email": "[email protected]",
"sub": "testuser123",
"institutions": ["TESTBANK123"],
}
)
return creds, user


@pytest.fixture
def mock_request(mocker: MockerFixture, mock_auth: Tuple[AuthCredentials, RegTechUser]) -> Request:
request: Request = mocker.patch("fastapi.Request").return_value
request.auth = mock_auth[0]
request.user = mock_auth[1]
return request
113 changes: 113 additions & 0 deletions tests/app/test_dependencies.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
from http import HTTPStatus
from typing import List
from fastapi import HTTPException, Request
from pytest_mock import MockerFixture
from sqlalchemy.ext.asyncio import AsyncSession
from dependencies import lei_association_check, fi_search_association_check
from starlette.authentication import AuthCredentials

import pytest

Expand Down Expand Up @@ -29,3 +34,111 @@ async def test_domain_allowed(mocker: MockerFixture, mock_session: AsyncSession)

assert await email_domain_denied(mock_session, allowed_domain) is False
domain_allowed_mock.assert_called_once_with(mock_session, allowed_domain)


async def test_lei_association_check_matching_lei(mock_request: Request):
@lei_association_check
async def method_to_wrap(request: Request, lei: str):
pass

await method_to_wrap(mock_request, lei="TESTBANK123")


async def test_lei_association_check_is_admin(mock_request: Request):
mock_request.auth = AuthCredentials(["manage-account", "query-groups", "manage-users", "authenticated"])

@lei_association_check
async def method_to_wrap(request: Request, lei: str):
pass

await method_to_wrap(mock_request, lei="TESTBANK1234")


async def test_lei_association_check_not_matching(mock_request: Request):
@lei_association_check
async def method_to_wrap(request: Request, lei: str):
pass

with pytest.raises(HTTPException) as e:
await method_to_wrap(mock_request, lei="NOTMYBANK")
assert e.value.status_code == HTTPStatus.FORBIDDEN
assert "not associated" in e.value.detail


async def test_fi_search_association_check_matching_lei(mock_request: Request):
@fi_search_association_check
async def method_to_wrap(request: Request, leis: List[str] = [], domain: str = ""):
pass

await method_to_wrap(mock_request, leis=["TESTBANK123"])


async def test_fi_search_association_check_invalid_lei(mock_request: Request):
@fi_search_association_check
async def method_to_wrap(request: Request, leis: List[str] = [], domain: str = ""):
pass

with pytest.raises(HTTPException) as e:
await method_to_wrap(mock_request, leis=["NOTMYBANK"])
assert e.value.status_code == HTTPStatus.FORBIDDEN
assert "not associated" in e.value.detail


async def test_fi_search_association_check_matching_domain(mock_request: Request):
@fi_search_association_check
async def method_to_wrap(request: Request, leis: List[str] = [], domain: str = ""):
pass

await method_to_wrap(mock_request, domain="local.host")


async def test_fi_search_association_check_invalid_domain(mock_request: Request):
@fi_search_association_check
async def method_to_wrap(request: Request, leis: List[str] = [], domain: str = ""):
pass

with pytest.raises(HTTPException) as e:
await method_to_wrap(mock_request, domain="not.myhost")
assert e.value.status_code == HTTPStatus.FORBIDDEN
assert "not associated" in e.value.detail


async def test_fi_search_association_check_no_filter(mock_request: Request):
@fi_search_association_check
async def method_to_wrap(request: Request, leis: List[str] = [], domain: str = ""):
pass

with pytest.raises(HTTPException) as e:
await method_to_wrap(mock_request)
assert e.value.status_code == HTTPStatus.FORBIDDEN
assert "without filter" in e.value.detail


async def test_fi_search_association_check_lei_admin(mock_request: Request):
mock_request.auth = AuthCredentials(["manage-account", "query-groups", "manage-users", "authenticated"])

@fi_search_association_check
async def method_to_wrap(request: Request, leis: List[str] = [], domain: str = ""):
pass

await method_to_wrap(mock_request, leis=["TESTBANK123", "ANOTHERBANK", "NOTMYBANK"])


async def test_fi_search_association_check_domain_admin(mock_request: Request):
mock_request.auth = AuthCredentials(["manage-account", "query-groups", "manage-users", "authenticated"])

@fi_search_association_check
async def method_to_wrap(request: Request, leis: List[str] = [], domain: str = ""):
pass

await method_to_wrap(mock_request, domain="not.myhost")


async def test_fi_search_association_check_no_filter_admin(mock_request: Request):
mock_request.auth = AuthCredentials(["manage-account", "query-groups", "manage-users", "authenticated"])

@fi_search_association_check
async def method_to_wrap(request: Request, leis: List[str] = [], domain: str = ""):
pass

await method_to_wrap(mock_request)
Loading