Skip to content

Commit

Permalink
Merge pull request #285 from chaen/fix_nested_access_policy
Browse files Browse the repository at this point in the history
Fix nested access policy and violent crash
  • Loading branch information
chaen authored Aug 29, 2024
2 parents b570671 + 2903afb commit c2cd486
Show file tree
Hide file tree
Showing 15 changed files with 82 additions and 45 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ jobs:
- name: Start demo
run: |
git clone https://github.com/DIRACGrid/diracx-charts.git ../diracx-charts
../diracx-charts/run_demo.sh --enable-open-telemetry --enable-coverage --exit-when-done --set-value developer.autoReload=false $PWD
../diracx-charts/run_demo.sh --enable-open-telemetry --enable-coverage --exit-when-done --set-value developer.autoReload=false --ci-values ../diracx-charts/demo/ci_values.yaml $PWD
- name: Debugging information
run: |
DIRACX_DEMO_DIR=$PWD/../diracx-charts/.demo
Expand Down
1 change: 0 additions & 1 deletion diracx-core/src/diracx/core/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,6 @@ def __hash__(self):
@cachedmethod(lambda self: self._pull_cache)
def _pull(self):
"""Git pull from remote repo."""
print("CHRIS PULL")
self.repo.remotes.origin.pull()

def latest_revision(self) -> tuple[str, datetime]:
Expand Down
10 changes: 10 additions & 0 deletions diracx-core/src/diracx/core/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,13 @@ def create(cls) -> Self:
async def lifetime_function(self) -> AsyncIterator[None]:
"""A context manager that can be used to run code at startup and shutdown."""
yield


class DevelopmentSettings(ServiceSettingsBase):
"""Settings for the Development Configuration that can influence run time."""

model_config = SettingsConfigDict(env_prefix="DIRACX_DEV_")

# When then to true (only for demo/CI), crash if an access policy isn't
# called
crash_on_missed_access_policy: bool = False
11 changes: 10 additions & 1 deletion diracx-routers/src/diracx/routers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,15 @@
from collections.abc import AsyncGenerator
from functools import partial
from logging import Formatter, StreamHandler
from typing import Any, Awaitable, Callable, Iterable, Sequence, TypeVar, cast
from typing import (
Any,
Awaitable,
Callable,
Iterable,
Sequence,
TypeVar,
cast,
)

import dotenv
from cachetools import TTLCache
Expand Down Expand Up @@ -139,6 +147,7 @@ def create_app_inner(
# Please see ServiceSettingsBase for more details

available_settings_classes: set[type[ServiceSettingsBase]] = set()

for service_settings in all_service_settings:
cls = type(service_settings)
assert cls not in available_settings_classes
Expand Down
11 changes: 8 additions & 3 deletions diracx-routers/src/diracx/routers/access_policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from fastapi import Depends

from diracx.core.extensions import select_from_extension
from diracx.routers.dependencies import DevelopmentSettings
from diracx.routers.utils.users import AuthorizedUserInfo, verify_dirac_access_token

# FastAPI bug:
Expand Down Expand Up @@ -99,6 +100,7 @@ def check_permissions(
policy: Callable,
policy_name: str,
user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)],
dev_settings: DevelopmentSettings,
):
"""This wrapper just calls the actual implementation, but also makes sure
that the policy has been called.
Expand All @@ -120,6 +122,7 @@ async def wrapped_policy(**kwargs):
try:
yield wrapped_policy
finally:

if not has_been_called:
# TODO nice error message with inspect
# That should really not happen
Expand All @@ -128,9 +131,11 @@ async def wrapped_policy(**kwargs):
"(PS: I hope you are in a CI)",
flush=True,
)
# Sleep a bit to make sure the flush happened
time.sleep(1)
os._exit(1)
# If enable, just crash, meanly
if dev_settings.crash_on_missed_access_policy:
# Sleep a bit to make sure the flush happened
time.sleep(1)
os._exit(1)


def open_access(f):
Expand Down
6 changes: 4 additions & 2 deletions diracx-routers/src/diracx/routers/auth/well_known.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from fastapi import Request
from typing_extensions import TypedDict

from ..dependencies import Config
from ..dependencies import Config, DevelopmentSettings
from ..fastapi_classes import DiracxRouter
from ..utils.users import AuthSettings

Expand All @@ -17,7 +17,6 @@ async def openid_configuration(
request: Request,
config: Config,
settings: AuthSettings,
# check_permissions: OpenAccessPolicyCallable,
):
"""OpenID Connect discovery endpoint."""
# await check_permissions()
Expand Down Expand Up @@ -65,17 +64,20 @@ class VOInfo(TypedDict):

class Metadata(TypedDict):
virtual_organizations: dict[str, VOInfo]
development_settings: DevelopmentSettings


@router.get("/dirac-metadata")
async def installation_metadata(
config: Config,
# check_permissions: OpenAccessPolicyCallable,
dev_settings: DevelopmentSettings,
) -> Metadata:
"""Get metadata about the dirac installation."""
# await check_permissions()
metadata: Metadata = {
"virtual_organizations": {},
"development_settings": dev_settings,
}
for vo, vo_info in config.Registry.items():
groups: dict[str, GroupInfo] = {
Expand Down
5 changes: 5 additions & 0 deletions diracx-routers/src/diracx/routers/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from diracx.core.config import Config as _Config
from diracx.core.config import ConfigSource
from diracx.core.properties import SecurityProperty
from diracx.core.settings import DevelopmentSettings as _DevelopmentSettings
from diracx.db.sql import AuthDB as _AuthDB
from diracx.db.sql import JobDB as _JobDB
from diracx.db.sql import JobLoggingDB as _JobLoggingDB
Expand Down Expand Up @@ -46,3 +47,7 @@ def add_settings_annotation(cls: T) -> T:
AvailableSecurityProperties = Annotated[
set[SecurityProperty], Depends(SecurityProperty.available_properties)
]

DevelopmentSettings = Annotated[
_DevelopmentSettings, Depends(_DevelopmentSettings.create)
]
18 changes: 2 additions & 16 deletions diracx-routers/src/diracx/routers/job_manager/access_policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,8 @@ async def policy(


class SandboxAccessPolicy(BaseAccessPolicy):
"""Policy for the sandbox
It delegates most of it to the WMSPolicy.
"""Policy for the sandbox.
They are similar to the WMS access policies.
"""

@staticmethod
Expand All @@ -108,25 +108,11 @@ async def policy(
/,
*,
action: ActionType | None = None,
job_db: JobDB | None = None,
sandbox_metadata_db: SandboxMetadataDB | None = None,
pfns: list[str] | None = None,
required_prefix: str | None = None,
job_ids: list[int] | None = None,
check_wms_permissions: CheckWMSPolicyCallable | None = None,
):

assert action, "action is a mandatory parameter"

# if we pass the job_db or job_ids,
# delegate the check to the WMSAccessPolicy
if job_db or job_ids:
# Make sure that check_wms_permission is set
# It should always be by fastapi Depends,
# but not when we test the policy in itself
assert check_wms_permissions
return check_wms_permissions(action=action, job_db=job_db, job_ids=job_ids)

assert sandbox_metadata_db, "sandbox_metadata_db is a mandatory parameter"
assert pfns, "pfns is a mandatory parameter"

Expand Down
16 changes: 10 additions & 6 deletions diracx-routers/src/diracx/routers/job_manager/sandboxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,11 @@
from diracx.core.settings import ServiceSettingsBase

from ..utils.users import AuthorizedUserInfo, verify_dirac_access_token
from .access_policies import ActionType, CheckSandboxPolicyCallable
from .access_policies import (
ActionType,
CheckSandboxPolicyCallable,
CheckWMSPolicyCallable,
)

if TYPE_CHECKING:
from types_aiobotocore_s3.client import S3Client
Expand Down Expand Up @@ -221,7 +225,7 @@ async def get_job_sandboxes(
job_id: int,
sandbox_metadata_db: SandboxMetadataDB,
job_db: JobDB,
check_permissions: CheckSandboxPolicyCallable,
check_permissions: CheckWMSPolicyCallable,
) -> dict[str, list[Any]]:
"""Get input and output sandboxes of given job."""
await check_permissions(action=ActionType.READ, job_db=job_db, job_ids=[job_id])
Expand All @@ -241,7 +245,7 @@ async def get_job_sandbox(
sandbox_metadata_db: SandboxMetadataDB,
job_db: JobDB,
sandbox_type: Literal["input", "output"],
check_permissions: CheckSandboxPolicyCallable,
check_permissions: CheckWMSPolicyCallable,
) -> list[Any]:
"""Get input or output sandbox of given job."""
await check_permissions(action=ActionType.READ, job_db=job_db, job_ids=[job_id])
Expand All @@ -259,7 +263,7 @@ async def assign_sandbox_to_job(
sandbox_metadata_db: SandboxMetadataDB,
job_db: JobDB,
settings: SandboxStoreSettings,
check_permissions: CheckSandboxPolicyCallable,
check_permissions: CheckWMSPolicyCallable,
):
"""Map the pfn as output sandbox to job."""
await check_permissions(action=ActionType.MANAGE, job_db=job_db, job_ids=[job_id])
Expand All @@ -277,7 +281,7 @@ async def unassign_job_sandboxes(
job_id: int,
sandbox_metadata_db: SandboxMetadataDB,
job_db: JobDB,
check_permissions: CheckSandboxPolicyCallable,
check_permissions: CheckWMSPolicyCallable,
):
"""Delete single job sandbox mapping."""
await check_permissions(action=ActionType.MANAGE, job_db=job_db, job_ids=[job_id])
Expand All @@ -289,7 +293,7 @@ async def unassign_bulk_jobs_sandboxes(
jobs_ids: Annotated[list[int], Query()],
sandbox_metadata_db: SandboxMetadataDB,
job_db: JobDB,
check_permissions: CheckSandboxPolicyCallable,
check_permissions: CheckWMSPolicyCallable,
):
"""Delete bulk jobs sandbox mapping."""
await check_permissions(action=ActionType.MANAGE, job_db=job_db, job_ids=jobs_ids)
Expand Down
8 changes: 7 additions & 1 deletion diracx-routers/tests/auth/test_legacy_exchange.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,13 @@

DIRAC_CLIENT_ID = "myDIRACClientID"
pytestmark = pytest.mark.enabled_dependencies(
["AuthDB", "AuthSettings", "ConfigSource", "BaseAccessPolicy"]
[
"AuthDB",
"AuthSettings",
"ConfigSource",
"BaseAccessPolicy",
"DevelopmentSettings",
]
)


Expand Down
1 change: 1 addition & 0 deletions diracx-routers/tests/jobs/test_sandboxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
"SandboxStoreSettings",
"WMSAccessPolicy",
"SandboxAccessPolicy",
"DevelopmentSettings",
]
)

Expand Down
12 changes: 0 additions & 12 deletions diracx-routers/tests/jobs/test_wms_access_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,18 +223,6 @@ async def summary_other_vo(*args):
)


async def test_sandbox_access_policy_delegate_to_wms(job_db):
"""We expect that the policy delegates to the WMS policy when given job info
This will trigger an Assert as the WMSAccessPolicy is None
in these tests.
"""
normal_user = AuthorizedUserInfo(properties=[NORMAL_USER], **base_payload)
with pytest.raises(AssertionError):
await SandboxAccessPolicy.policy(
SANDBOX_POLICY_NAME, normal_user, action=ActionType.CREATE, job_db=job_db
)


async def test_sandbox_access_policy_create(sandbox_db):

admin_user = AuthorizedUserInfo(properties=[JOB_ADMINISTRATOR], **base_payload)
Expand Down
7 changes: 6 additions & 1 deletion diracx-routers/tests/test_generic.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
import pytest

pytestmark = pytest.mark.enabled_dependencies(
["ConfigSource", "AuthSettings", "OpenAccessPolicy"]
[
"ConfigSource",
"AuthSettings",
"OpenAccessPolicy",
"DevelopmentSettings",
]
)


Expand Down
1 change: 1 addition & 0 deletions diracx-routers/tests/test_job_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
"TaskQueueDB",
"SandboxMetadataDB",
"WMSAccessPolicy",
"DevelopmentSettings",
]
)

Expand Down
18 changes: 17 additions & 1 deletion diracx-testing/src/diracx/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import requests

if TYPE_CHECKING:
from diracx.core.settings import DevelopmentSettings
from diracx.routers.job_manager.sandboxes import SandboxStoreSettings
from diracx.routers.utils.users import AuthorizedUserInfo, AuthSettings

Expand Down Expand Up @@ -76,6 +77,13 @@ def fernet_key() -> str:
return Fernet.generate_key().decode()


@pytest.fixture(scope="session")
def test_dev_settings() -> DevelopmentSettings:
from diracx.core.settings import DevelopmentSettings

yield DevelopmentSettings()


@pytest.fixture(scope="session")
def test_auth_settings(private_key_pem, fernet_key) -> AuthSettings:
from diracx.routers.utils.users import AuthSettings
Expand Down Expand Up @@ -141,6 +149,7 @@ def __init__(
with_config_repo,
test_auth_settings,
test_sandbox_settings,
test_dev_settings,
):
from diracx.core.config import ConfigSource
from diracx.core.extensions import select_from_extension
Expand Down Expand Up @@ -171,6 +180,7 @@ def enrich_tokens(access_payload: dict, refresh_payload: dict):
self._cache_dir = tmp_path_factory.mktemp("empty-dbs")

self.test_auth_settings = test_auth_settings
self.test_dev_settings = test_dev_settings

all_access_policies = {
e.name: [AlwaysAllowAccessPolicy]
Expand All @@ -183,6 +193,7 @@ def enrich_tokens(access_payload: dict, refresh_payload: dict):
all_service_settings=[
test_auth_settings,
test_sandbox_settings,
test_dev_settings,
],
database_urls=database_urls,
os_database_conn_kwargs={
Expand Down Expand Up @@ -346,13 +357,18 @@ def session_client_factory(
test_sandbox_settings,
with_config_repo,
tmp_path_factory,
test_dev_settings,
):
"""TODO.
----
"""
yield ClientFactory(
tmp_path_factory, with_config_repo, test_auth_settings, test_sandbox_settings
tmp_path_factory,
with_config_repo,
test_auth_settings,
test_sandbox_settings,
test_dev_settings,
)


Expand Down

0 comments on commit c2cd486

Please sign in to comment.