Skip to content

Commit

Permalink
feat: refactor user model, and add institutions attribute
Browse files Browse the repository at this point in the history
  • Loading branch information
lchen-2101 committed Sep 14, 2023
1 parent 98635a0 commit 9b840aa
Show file tree
Hide file tree
Showing 7 changed files with 66 additions and 31 deletions.
2 changes: 2 additions & 0 deletions src/entities/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
"FinancialInsitutionDomainCreate",
"DeniedDomainDao",
"DeniedDomainDto",
"AuthenticatedUser",
]

from .dao import (
Expand All @@ -22,4 +23,5 @@
FinancialInsitutionDomainDto,
FinancialInsitutionDomainCreate,
DeniedDomainDto,
AuthenticatedUser,
)
35 changes: 34 additions & 1 deletion src/entities/models/dto.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import List
from typing import Any, Dict, List

from pydantic import BaseModel
from starlette.authentication import BaseUser


class FinancialInsitutionDomainBase(BaseModel):
Expand Down Expand Up @@ -37,3 +39,34 @@ class DeniedDomainDto(BaseModel):

class Config:
orm_mode = True


class AuthenticatedUser(BaseUser, BaseModel):
claims: Dict[str, Any]
name: str
username: str
email: str
id: str
institutions: List[str]

@classmethod
def from_claim(cls, claims: Dict[str, Any]) -> "AuthenticatedUser":
return cls(
claims=claims,
name=claims.get("name", ""),
username=claims.get("preferred_username", ""),
email=claims.get("email", ""),
id=claims.get("sub", ""),
institutions=cls.parse_institutions(claims.get("institutions")),
)

@classmethod
def parse_institutions(cls, institutions: List[str] | None) -> List[str]:
if institutions:
return list(map(lambda institution: institution.lstrip("/"), institutions))
else:
return []

@property
def is_authenticated(self) -> bool:
return True
4 changes: 2 additions & 2 deletions src/oauth2/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__all__ = ["oauth2_admin", "BearerTokenAuthBackend", "AuthenticatedUser"]
__all__ = ["oauth2_admin", "BearerTokenAuthBackend"]

from .oauth2_admin import oauth2_admin
from .oauth2_backend import BearerTokenAuthBackend, AuthenticatedUser
from .oauth2_backend import BearerTokenAuthBackend
25 changes: 2 additions & 23 deletions src/oauth2/oauth2_backend.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import logging
from typing import Coroutine, Any, Dict, List, Tuple
from fastapi import HTTPException
from pydantic import BaseModel
from starlette.authentication import (
AuthCredentials,
AuthenticationBackend,
Expand All @@ -11,33 +10,13 @@
from fastapi.security import OAuth2AuthorizationCodeBearer
from starlette.requests import HTTPConnection

from entities.models import AuthenticatedUser

from .oauth2_admin import oauth2_admin

log = logging.getLogger(__name__)


class AuthenticatedUser(BaseUser, BaseModel):
claims: Dict[str, Any]
name: str | None
username: str | None
email: str | None
id: str | None

@classmethod
def from_claim(cls, claims: Dict[str, Any]) -> "AuthenticatedUser":
return cls(
claims=claims,
name=claims.get("name"),
username=claims.get("preferred_username"),
email=claims.get("email"),
id=claims.get("sub"),
)

@property
def is_authenticated(self) -> bool:
return True


class BearerTokenAuthBackend(AuthenticationBackend):
def __init__(self, token_bearer: OAuth2AuthorizationCodeBearer) -> None:
self.token_bearer = token_bearer
Expand Down
3 changes: 2 additions & 1 deletion src/routers/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from starlette.authentication import requires
from util import Router

from oauth2 import AuthenticatedUser, oauth2_admin
from entities.models import AuthenticatedUser
from oauth2 import oauth2_admin

router = Router()

Expand Down
2 changes: 1 addition & 1 deletion tests/api/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pytest_mock import MockerFixture
from starlette.authentication import AuthCredentials, UnauthenticatedUser

from oauth2.oauth2_backend import AuthenticatedUser
from entities.models import AuthenticatedUser


@pytest.fixture
Expand Down
26 changes: 23 additions & 3 deletions tests/api/routers/test_admin_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pytest_mock import MockerFixture
from starlette.authentication import AuthCredentials

from oauth2.oauth2_backend import AuthenticatedUser
from entities.models import AuthenticatedUser


class TestAdminApi:
Expand All @@ -14,13 +14,33 @@ def test_get_me_unauthed(self, app_fixture: FastAPI, unauthed_user_mock: Mock):
res = client.get("/v1/admin/me")
assert res.status_code == 403

def test_get_me_authed(
self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock
def test_get_me_authed_with_no_institutions(
self, app_fixture: FastAPI, authed_user_mock: Mock
):
client = TestClient(app_fixture)
res = client.get("/v1/admin/me")
assert res.status_code == 200
assert res.json().get("name") == "test"
assert res.json().get("institutions") == []

def test_get_me_authed_with_institutions(
self, app_fixture: FastAPI, auth_mock: Mock
):
claims = {
"name": "test",
"preferred_username": "test_user",
"email": "[email protected]",
"sub": "testuser123",
"institutions": ["/TEST1LEI", "/TEST2LEI"],
}
auth_mock.return_value = (
AuthCredentials(["authenticated"]),
AuthenticatedUser.from_claim(claims),
)
client = TestClient(app_fixture)
res = client.get("/v1/admin/me")
assert res.status_code == 200
assert res.json().get("institutions") == ["TEST1LEI", "TEST2LEI"]

def test_update_me_unauthed(self, app_fixture: FastAPI, unauthed_user_mock: Mock):
client = TestClient(app_fixture)
Expand Down

0 comments on commit 9b840aa

Please sign in to comment.