diff --git a/server/api/datasets/routes.py b/server/api/datasets/routes.py index d7aa31e8..8f83cc7a 100644 --- a/server/api/datasets/routes.py +++ b/server/api/datasets/routes.py @@ -9,7 +9,10 @@ DeleteDataset, UpdateDataset, ) -from server.application.datasets.exceptions import CannotCreateDataset +from server.application.datasets.exceptions import ( + CannotCreateDataset, + CannotUpdateDataset, +) from server.application.datasets.queries import GetAllDatasets, GetDatasetByID from server.application.datasets.views import DatasetView from server.config.di import resolve @@ -107,15 +110,20 @@ async def create_dataset(data: DatasetCreate, request: "APIRequest") -> DatasetV response_model=DatasetView, responses={404: {}}, ) -async def update_dataset(id: ID, data: DatasetUpdate) -> DatasetView: +async def update_dataset( + id: ID, data: DatasetUpdate, request: "APIRequest" +) -> DatasetView: bus = resolve(MessageBus) - command = UpdateDataset(id=id, **data.dict()) + command = UpdateDataset(account=request.user.account, id=id, **data.dict()) try: await bus.execute(command) except DatasetDoesNotExist: raise HTTPException(404) + except CannotUpdateDataset as exc: + logger.exception(exc) + raise HTTPException(403, detail="Permission denied") query = GetDatasetByID(id=id) return await bus.execute(query) diff --git a/server/application/datasets/commands.py b/server/application/datasets/commands.py index 4ae0ae54..17b2eef7 100644 --- a/server/application/datasets/commands.py +++ b/server/application/datasets/commands.py @@ -35,6 +35,8 @@ class CreateDataset(CreateDatasetValidationMixin, Command[ID]): class UpdateDataset(UpdateDatasetValidationMixin, Command[None]): + account: Union[Account, Skip] + id: ID title: str description: str diff --git a/server/application/datasets/exceptions.py b/server/application/datasets/exceptions.py index 1c9aa84d..b237147a 100644 --- a/server/application/datasets/exceptions.py +++ b/server/application/datasets/exceptions.py @@ -1,2 +1,6 @@ class CannotCreateDataset(Exception): pass + + +class CannotUpdateDataset(Exception): + pass diff --git a/server/application/datasets/handlers.py b/server/application/datasets/handlers.py index b26a0f1f..0c73b1ad 100644 --- a/server/application/datasets/handlers.py +++ b/server/application/datasets/handlers.py @@ -15,9 +15,9 @@ from server.seedwork.application.messages import MessageBus from .commands import CreateDataset, DeleteDataset, UpdateDataset -from .exceptions import CannotCreateDataset +from .exceptions import CannotCreateDataset, CannotUpdateDataset from .queries import GetAllDatasets, GetDatasetByID, GetDatasetFilters -from .specifications import can_create_dataset +from .specifications import can_create_dataset, can_update_dataset from .views import DatasetFiltersView, DatasetView @@ -72,9 +72,14 @@ async def update_dataset(command: UpdateDataset) -> None: if dataset is None: raise DatasetDoesNotExist(pk) + if not isinstance(command.account, Skip) and not can_update_dataset( + dataset, command.account + ): + raise CannotUpdateDataset(f"{command.account=}, {dataset=}") + tags = await tag_repository.get_all(ids=command.tag_ids) dataset.update( - **command.dict(exclude={"id", "tag_ids", "extra_field_values"}), + **command.dict(exclude={"account", "id", "tag_ids", "extra_field_values"}), tags=tags, extra_field_values=command.extra_field_values, ) diff --git a/server/application/datasets/specifications.py b/server/application/datasets/specifications.py index c5f386a7..3ca31a02 100644 --- a/server/application/datasets/specifications.py +++ b/server/application/datasets/specifications.py @@ -1,6 +1,11 @@ from server.domain.auth.entities import Account from server.domain.catalogs.entities import Catalog +from server.domain.datasets.entities import Dataset def can_create_dataset(catalog: Catalog, account: Account) -> bool: return catalog.organization.siret == account.organization_siret + + +def can_update_dataset(dataset: Dataset, account: Account) -> bool: + return dataset.catalog_record.organization.siret == account.organization_siret diff --git a/tests/api/test_datasets.py b/tests/api/test_datasets.py index be65efdd..f11fdff4 100644 --- a/tests/api/test_datasets.py +++ b/tests/api/test_datasets.py @@ -12,7 +12,7 @@ from server.application.tags.queries import GetTagByID from server.config.di import resolve from server.domain.catalogs.entities import ExtraFieldValue, TextExtraField -from server.domain.common.types import ID, id_factory +from server.domain.common.types import ID, Skip, id_factory from server.domain.datasets.entities import DataFormat, UpdateFrequency from server.domain.datasets.exceptions import DatasetDoesNotExist from server.domain.organizations.entities import LEGACY_ORGANIZATION @@ -26,7 +26,7 @@ CreateDatasetPayloadFactory, CreateOrganizationFactory, CreatePasswordUserFactory, - UpdateDatasetFactory, + UpdateDatasetPayloadFactory, fake, ) from ..helpers import TestPasswordUser, create_test_password_user, to_payload @@ -257,6 +257,51 @@ async def test_update_not_authenticated(self, client: httpx.AsyncClient) -> None response = await client.put(f"/datasets/{pk}/", json={}) assert response.status_code == 401 + async def test_update_in_other_org_denied( + self, client: httpx.AsyncClient, temp_user: TestPasswordUser + ) -> None: + bus = resolve(MessageBus) + + other_org_siret = await bus.execute(CreateOrganizationFactory.build()) + await bus.execute(CreateCatalog(organization_siret=other_org_siret)) + + command = CreateDatasetFactory.build( + organization_siret=other_org_siret, account=Skip() + ) + dataset_id = await bus.execute(command) + + payload = to_payload( + UpdateDatasetPayloadFactory.build_from_create_command(command) + ) + response = await client.put( + f"/datasets/{dataset_id}/", json=payload, auth=temp_user.auth + ) + + assert response.status_code == 403 + + async def test_update_in_other_org_admin_denied( + self, client: httpx.AsyncClient, admin_user: TestPasswordUser + ) -> None: + bus = resolve(MessageBus) + + other_org_siret = await bus.execute(CreateOrganizationFactory.build()) + assert admin_user.account.organization_siret != other_org_siret + await bus.execute(CreateCatalog(organization_siret=other_org_siret)) + + command = CreateDatasetFactory.build( + organization_siret=other_org_siret, account=Skip() + ) + dataset_id = await bus.execute(command) + + payload = to_payload( + UpdateDatasetPayloadFactory.build_from_create_command(command) + ) + response = await client.put( + f"/datasets/{dataset_id}/", json=payload, auth=admin_user.auth + ) + + assert response.status_code == 403 + async def test_delete_not_authenticated(self, client: httpx.AsyncClient) -> None: pk = id_factory() response = await client.delete(f"/datasets/{pk}/") @@ -425,7 +470,7 @@ async def test_not_found( pk = id_factory() response = await client.put( f"/datasets/{pk}/", - json=to_payload(UpdateDatasetFactory.build(id=pk)), + json=to_payload(UpdateDatasetPayloadFactory.build(id=pk)), auth=temp_user.auth, ) assert response.status_code == 404 @@ -480,13 +525,13 @@ async def test_fields_empty_invalid( response = await client.put( f"/datasets/{dataset_id}/", json=to_payload( - UpdateDatasetFactory.build( + UpdateDatasetPayloadFactory.build_from_create_command( + command.copy(exclude={"title", "description", "service", "url"}), factory_use_construct=True, # Skip validation title="", description="", service="", url="", - **command.dict(exclude={"title", "description", "service", "url"}), ) ), auth=temp_user.auth, @@ -523,7 +568,7 @@ async def test_update( other_last_updated_at = fake.date_time_tz() payload = to_payload( - UpdateDatasetFactory.build( + UpdateDatasetPayloadFactory.build( title="Other title", description="Other description", service="Other service", @@ -603,9 +648,9 @@ async def test_formats_add( response = await client.put( f"/datasets/{dataset_id}/", json=to_payload( - UpdateDatasetFactory.build( + UpdateDatasetPayloadFactory.build_from_create_command( + command.copy(exclude={"formats"}), formats=[DataFormat.WEBSITE, DataFormat.API, DataFormat.FILE_GIS], - **command.dict(exclude={"formats"}), ) ), auth=temp_user.auth, @@ -627,9 +672,9 @@ async def test_formats_remove( response = await client.put( f"/datasets/{dataset_id}/", json=to_payload( - UpdateDatasetFactory.build( + UpdateDatasetPayloadFactory.build_from_create_command( + command.copy(exclude={"formats"}), formats=[DataFormat.WEBSITE], - **command.dict(exclude={"formats"}), ) ), auth=temp_user.auth, @@ -654,9 +699,9 @@ async def test_tags_add( response = await client.put( f"/datasets/{dataset_id}/", json=to_payload( - UpdateDatasetFactory.build( + UpdateDatasetPayloadFactory.build_from_create_command( + command.copy(exclude={"tag_ids"}), tag_ids=[str(tag_architecture_id)], - **command.dict(exclude={"tag_ids"}), ) ), auth=temp_user.auth, @@ -684,9 +729,9 @@ async def test_tags_remove( response = await client.put( f"/datasets/{dataset_id}/", json=to_payload( - UpdateDatasetFactory.build( + UpdateDatasetPayloadFactory.build_from_create_command( + command.copy(exclude={"tag_ids"}), tag_ids=[], - **command.dict(exclude={"tag_ids"}), ) ), auth=temp_user.auth, @@ -752,9 +797,7 @@ async def test_create_dataset_with_extra_field_values( } ] - async def test_add_extra_field_value( - self, client: httpx.AsyncClient, temp_user: TestPasswordUser - ) -> None: + async def test_add_extra_field_value(self, client: httpx.AsyncClient) -> None: bus = resolve(MessageBus) siret, user, extra_field_id = await self._setup() @@ -766,19 +809,18 @@ async def test_add_extra_field_value( assert not dataset.extra_field_values payload = to_payload( - UpdateDatasetFactory.build( - id=dataset_id, + UpdateDatasetPayloadFactory.build_from_create_command( + command.copy(exclude={"extra_field_values"}), extra_field_values=[ ExtraFieldValue( extra_field_id=extra_field_id, value="Environ 10 To", ) ], - **command.dict(exclude={"extra_field_values"}), ) ) response = await client.put( - f"/datasets/{dataset_id}/", json=payload, auth=temp_user.auth + f"/datasets/{dataset_id}/", json=payload, auth=user.auth ) assert response.status_code == 200 data = response.json() @@ -789,9 +831,7 @@ async def test_add_extra_field_value( } ] - async def test_remove_extra_field_value( - self, client: httpx.AsyncClient, temp_user: TestPasswordUser - ) -> None: + async def test_remove_extra_field_value(self, client: httpx.AsyncClient) -> None: bus = resolve(MessageBus) siret, user, extra_field_id = await self._setup() @@ -810,14 +850,13 @@ async def test_remove_extra_field_value( assert len(dataset.extra_field_values) == 1 payload = to_payload( - UpdateDatasetFactory.build( - id=dataset_id, + UpdateDatasetPayloadFactory.build_from_create_command( + command.copy(exclude={"extra_field_values"}), extra_field_values=[], - **command.dict(exclude={"extra_field_values"}), ) ) response = await client.put( - f"/datasets/{dataset_id}/", json=payload, auth=temp_user.auth + f"/datasets/{dataset_id}/", json=payload, auth=user.auth ) assert response.status_code == 200 data = response.json() diff --git a/tests/api/test_datasets_search.py b/tests/api/test_datasets_search.py index 150e8475..f524ec51 100644 --- a/tests/api/test_datasets_search.py +++ b/tests/api/test_datasets_search.py @@ -183,7 +183,10 @@ async def test_search_results_change_when_data_changes( # Update dataset title update_command = UpdateDatasetFactory.build( - id=pk, title="Modifié", **command.dict(exclude={"title"}) + account=temp_user.account, + id=pk, + title="Modifié", + **command.dict(exclude={"title", "account", "organization_siret"}), ) await bus.execute(update_command) @@ -200,7 +203,7 @@ async def test_search_results_change_when_data_changes( # Same on description update_command = UpdateDatasetFactory.build( description="Jeu de données spécial", - **update_command.dict(exclude={"description"}) + **update_command.dict(exclude={"description"}), ) await bus.execute(update_command) response = await client.get( diff --git a/tests/factories.py b/tests/factories.py index a7a5c9d5..b616c9cd 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -7,7 +7,7 @@ from pydantic import BaseModel from pydantic_factories import ModelFactory, Require, Use -from server.api.datasets.schemas import DatasetCreate +from server.api.datasets.schemas import DatasetCreate, DatasetUpdate from server.application.auth.commands import CreateDataPassUser, CreatePasswordUser from server.application.datasets.commands import CreateDataset, UpdateDataset from server.application.organizations.commands import CreateOrganization @@ -93,13 +93,29 @@ class CreateDatasetPayloadFactory(_BaseCreateDatasetFactory, Factory[DatasetCrea __model__ = DatasetCreate -class UpdateDatasetFactory(Factory[UpdateDataset]): - __model__ = UpdateDataset - +class _BaseUpdateDatasetFactory: tag_ids = Use(lambda: []) extra_field_values = Use(lambda: []) +class UpdateDatasetFactory(_BaseCreateDatasetFactory, Factory[UpdateDataset]): + __model__ = UpdateDataset + + account = Require() + + +class UpdateDatasetPayloadFactory(_BaseUpdateDatasetFactory, Factory[DatasetUpdate]): + __model__ = DatasetUpdate + + @classmethod + def build_from_create_command( + cls, command: CreateDataset, **kwargs: Any + ) -> DatasetUpdate: + return cls.build( + **command.dict(exclude={"account", "organization_siret"}), **kwargs + ) + + class CreateOrganizationFactory(Factory[CreateOrganization]): __model__ = CreateOrganization diff --git a/tests/tools/test_initdata.py b/tests/tools/test_initdata.py index 6af08a66..96707553 100644 --- a/tests/tools/test_initdata.py +++ b/tests/tools/test_initdata.py @@ -11,7 +11,7 @@ from server.application.datasets.commands import UpdateDataset from server.application.datasets.queries import GetAllDatasets, GetDatasetByID from server.config.di import resolve -from server.domain.common.types import ID +from server.domain.common.types import ID, Skip from server.seedwork.application.messages import MessageBus from tools import initdata @@ -146,6 +146,7 @@ async def test_repo_initdata( # Make a change. command = UpdateDataset( + account=Skip(), **dataset.dict(exclude={"title"}), tag_ids=[tag.id for tag in dataset.tags], title="Changed", diff --git a/tools/initdata.py b/tools/initdata.py index bed226a8..1c531c22 100644 --- a/tools/initdata.py +++ b/tools/initdata.py @@ -159,7 +159,7 @@ def _get_dataset_attr(dataset: Dataset, attr: str) -> Any: existing_dataset = await repository.get_by_id(id_) if existing_dataset is not None: - update_command = UpdateDataset(id=id_, **item["params"]) + update_command = UpdateDataset(account=Skip(), id=id_, **item["params"]) changed = any( getattr(update_command, k) != _get_dataset_attr(existing_dataset, k)