diff --git a/src/blueapi/data_management/__init__.py b/src/blueapi/data_management/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/blueapi/data_management/visit_directory_provider.py b/src/blueapi/data_management/visit_directory_provider.py new file mode 100644 index 000000000..9b27e9d9c --- /dev/null +++ b/src/blueapi/data_management/visit_directory_provider.py @@ -0,0 +1,128 @@ +import logging +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Optional + +from aiohttp import ClientSession +from ophyd_async.core import DirectoryInfo, DirectoryProvider +from pydantic import BaseModel + + +class DataCollectionIdentifier(BaseModel): + collectionNumber: int + + +class VisitServiceClientBase(ABC): + """ + Object responsible for I/O in determining collection number + """ + + @abstractmethod + async def create_new_collection(self) -> DataCollectionIdentifier: + """Create new collection""" + + @abstractmethod + async def get_current_collection(self) -> DataCollectionIdentifier: + """Get current collection""" + + +class VisitServiceClient(VisitServiceClientBase): + _url: str + + def __init__(self, url: str) -> None: + self._url = url + + async def create_new_collection(self) -> DataCollectionIdentifier: + async with ClientSession() as session: + async with session.post(f"{self._url}/numtracker") as response: + if response.status == 200: + json = await response.json() + return DataCollectionIdentifier.parse_obj(json) + else: + raise Exception(response.status) + + async def get_current_collection(self) -> DataCollectionIdentifier: + async with ClientSession() as session: + async with session.get(f"{self._url}/numtracker") as response: + if response.status == 200: + json = await response.json() + return DataCollectionIdentifier.parse_obj(json) + else: + raise Exception(response.status) + + +class LocalVisitServiceClient(VisitServiceClientBase): + _count: int + + def __init__(self) -> None: + self._count = 0 + + 
async def create_new_collection(self) -> DataCollectionIdentifier: + self._count += 1 + return DataCollectionIdentifier(collectionNumber=self._count) + + async def get_current_collection(self) -> DataCollectionIdentifier: + return DataCollectionIdentifier(collectionNumber=self._count) + + +class VisitDirectoryProvider(DirectoryProvider): + """ + Gets information from a remote service to construct the path that detectors + should write to, and determine how their files should be named. + """ + + _data_group_name: str + _data_directory: Path + + _client: VisitServiceClientBase + _current_collection: Optional[DirectoryInfo] + _session: Optional[ClientSession] + + def __init__( + self, + data_group_name: str, + data_directory: Path, + client: VisitServiceClientBase, + ): + self._data_group_name = data_group_name + self._data_directory = data_directory + self._client = client + + self._current_collection = None + self._session = None + + async def update(self) -> None: + """ + Calls the visit service to create a new data collection in the current visit. 
+ """ + # TODO: After visit service is more feature complete: + # TODO: Allow selecting visit as part of the request to BlueAPI + # TODO: Consume visit information from BlueAPI and pass down to this class + # TODO: Query visit service to get information about visit and data collection + # TODO: Use AuthN information as part of verification with visit service + + try: + collection_id_info = await self._client.create_new_collection() + self._current_collection = self._generate_directory_info(collection_id_info) + except Exception as ex: + # TODO: The catch all is needed because the RunEngine will not + # currently handle it, see + # https://github.com/bluesky/bluesky/pull/1623 + self._current_collection = None + logging.exception(ex) + + def _generate_directory_info( + self, + collection_id_info: DataCollectionIdentifier, + ) -> DirectoryInfo: + collection_id = collection_id_info.collectionNumber + file_prefix = f"{self._data_group_name}-{collection_id}" + return DirectoryInfo(str(self._data_directory), file_prefix) + + def __call__(self) -> DirectoryInfo: + if self._current_collection is not None: + return self._current_collection + else: + raise ValueError( + "No current collection, update() needs to be called at least once" + ) diff --git a/src/blueapi/preprocessors/attach_metadata.py b/src/blueapi/preprocessors/attach_metadata.py index 3de02201e..21d9ed8b4 100644 --- a/src/blueapi/preprocessors/attach_metadata.py +++ b/src/blueapi/preprocessors/attach_metadata.py @@ -1,8 +1,9 @@ +import bluesky.plan_stubs as bps import bluesky.preprocessors as bpp from bluesky.utils import make_decorator -from ophyd_async.core import DirectoryProvider from blueapi.core import MsgGenerator +from blueapi.data_management.visit_directory_provider import VisitDirectoryProvider DATA_SESSION = "data_session" DATA_GROUPS = "data_groups" @@ -10,14 +11,15 @@ def attach_metadata( plan: MsgGenerator, - provider: DirectoryProvider, + provider: VisitDirectoryProvider, ) -> MsgGenerator: """ 
Attach data session metadata to the runs within a plan and make it correlate with an ophyd-async DirectoryProvider. - This calls the directory provider and ensures the start document contains - the correct data session. + This updates the directory provider (which in turn makes a call to a service + to figure out which scan number we are using for such a scan), and ensures the + start document contains the correct data session. Args: plan: The plan to preprocess @@ -29,6 +31,7 @@ def attach_metadata( Yields: Iterator[Msg]: Plan messages """ + yield from bps.wait_for([provider.update]) directory_info = provider() yield from bpp.inject_md_wrapper( plan, md={DATA_SESSION: directory_info.filename_prefix} diff --git a/src/blueapi/service/handler.py b/src/blueapi/service/handler.py index 3e3dbef8a..4a006a960 100644 --- a/src/blueapi/service/handler.py +++ b/src/blueapi/service/handler.py @@ -1,11 +1,15 @@ import logging from typing import Mapping, Optional -from ophyd_async.core import StaticDirectoryProvider - from blueapi.config import ApplicationConfig from blueapi.core import BlueskyContext from blueapi.core.event import EventStream +from blueapi.data_management.visit_directory_provider import ( + LocalVisitServiceClient, + VisitDirectoryProvider, + VisitServiceClient, + VisitServiceClientBase, +) from blueapi.messaging import StompMessagingTemplate from blueapi.messaging.base import MessagingTemplate from blueapi.preprocessors.attach_metadata import attach_metadata @@ -88,9 +92,18 @@ def setup_handler( plan_wrappers = [] if config: - provider = StaticDirectoryProvider( - filename_prefix=f"{config.env.data_writing.group_name}-blueapi", - directory_path=str(config.env.data_writing.visit_directory), + visit_service_client: VisitServiceClientBase + if config.env.data_writing.visit_service_url is not None: + visit_service_client = VisitServiceClient( + config.env.data_writing.visit_service_url + ) + else: + visit_service_client = LocalVisitServiceClient() + + provider 
= VisitDirectoryProvider( + data_group_name=config.env.data_writing.group_name, + data_directory=config.env.data_writing.visit_directory, + client=visit_service_client, ) # Make all dodal devices created by the context use provider if they can diff --git a/tests/data_management/test_visit_directory_provider.py b/tests/data_management/test_visit_directory_provider.py new file mode 100644 index 000000000..57d93d0ef --- /dev/null +++ b/tests/data_management/test_visit_directory_provider.py @@ -0,0 +1,66 @@ +from pathlib import Path + +import pytest +from ophyd_async.core import DirectoryInfo + +from blueapi.data_management.visit_directory_provider import ( + DataCollectionIdentifier, + LocalVisitServiceClient, + VisitDirectoryProvider, + VisitServiceClientBase, +) + + +@pytest.fixture +def visit_service_client() -> VisitServiceClientBase: + return LocalVisitServiceClient() + + +@pytest.fixture +def visit_directory_provider( + visit_service_client: VisitServiceClientBase, +) -> VisitDirectoryProvider: + return VisitDirectoryProvider("example", Path("/tmp"), visit_service_client) + + +@pytest.mark.asyncio +async def test_client_can_view_collection( + visit_service_client: VisitServiceClientBase, +) -> None: + collection = await visit_service_client.get_current_collection() + assert collection == DataCollectionIdentifier(collectionNumber=0) + + +@pytest.mark.asyncio +async def test_client_can_create_collection( + visit_service_client: VisitServiceClientBase, +) -> None: + collection = await visit_service_client.create_new_collection() + assert collection == DataCollectionIdentifier(collectionNumber=1) + + +@pytest.mark.asyncio +async def test_update_sets_collection_number( + visit_directory_provider: VisitDirectoryProvider, +) -> None: + await visit_directory_provider.update() + assert visit_directory_provider() == DirectoryInfo( + directory_path="/tmp", + filename_prefix="example-1", + ) + + +@pytest.mark.asyncio +async def test_update_sets_collection_number_multi( + 
visit_directory_provider: VisitDirectoryProvider, +) -> None: + await visit_directory_provider.update() + assert visit_directory_provider() == DirectoryInfo( + directory_path="/tmp", + filename_prefix="example-1", + ) + await visit_directory_provider.update() + assert visit_directory_provider() == DirectoryInfo( + directory_path="/tmp", + filename_prefix="example-2", + ) diff --git a/tests/preprocessors/test_attach_metadata.py b/tests/preprocessors/test_attach_metadata.py index 9f3a1f1d7..87d1f0a27 100644 --- a/tests/preprocessors/test_attach_metadata.py +++ b/tests/preprocessors/test_attach_metadata.py @@ -1,29 +1,78 @@ from pathlib import Path -from typing import Any, Dict, List, Mapping +from typing import Any, Callable, Dict, List, Mapping +import bluesky.plan_stubs as bps import bluesky.plans as bp import pytest from bluesky import RunEngine +from bluesky.preprocessors import ( + run_decorator, + run_wrapper, + set_run_key_decorator, + set_run_key_wrapper, + stage_wrapper, +) from bluesky.protocols import HasName, Readable, Reading, Status, Triggerable from event_model.documents.event_descriptor import DataKey from ophyd.status import StatusBase -from ophyd_async.core import DirectoryProvider, StaticDirectoryProvider +from ophyd_async.core import DirectoryProvider from blueapi.core import DataEvent, MsgGenerator +from blueapi.data_management.visit_directory_provider import ( + DataCollectionIdentifier, + VisitDirectoryProvider, + VisitServiceClient, +) from blueapi.preprocessors.attach_metadata import DATA_SESSION, attach_metadata DATA_DIRECTORY = Path("/tmp") DATA_GROUP_NAME = "test" -RUN_0 = DATA_DIRECTORY / f"{DATA_GROUP_NAME}" -RUN_1 = DATA_DIRECTORY / f"{DATA_GROUP_NAME}" -RUN_2 = DATA_DIRECTORY / f"{DATA_GROUP_NAME}" +RUN_0 = DATA_DIRECTORY / f"{DATA_GROUP_NAME}-0" +RUN_1 = DATA_DIRECTORY / f"{DATA_GROUP_NAME}-1" +RUN_2 = DATA_DIRECTORY / f"{DATA_GROUP_NAME}-2" + + +class MockVisitServiceClient(VisitServiceClient): + _count: int + _fail: bool + + def 
__init__(self) -> None: + super().__init__("http://example.com") + self._count = 0 + self._fail = False + + def always_fail(self) -> None: + self._fail = True + + async def create_new_collection(self) -> DataCollectionIdentifier: + if self._fail: + raise ConnectionError() + + count = self._count + self._count += 1 + return DataCollectionIdentifier(collectionNumber=count) + + async def get_current_collection(self) -> DataCollectionIdentifier: + if self._fail: + raise ConnectionError() + + return DataCollectionIdentifier(collectionNumber=self._count) @pytest.fixture -def provider() -> DirectoryProvider: - return StaticDirectoryProvider(str(DATA_DIRECTORY), DATA_GROUP_NAME) +def client() -> VisitServiceClient: + return MockVisitServiceClient() + + +@pytest.fixture +def provider(client: VisitServiceClient) -> VisitDirectoryProvider: + return VisitDirectoryProvider( + data_directory=DATA_DIRECTORY, + data_group_name=DATA_GROUP_NAME, + client=client, + ) @pytest.fixture @@ -77,7 +126,7 @@ def parent(self) -> None: @pytest.fixture(params=[1, 2]) -def detectors(request, provider: DirectoryProvider) -> List[Readable]: +def detectors(request, provider: VisitDirectoryProvider) -> List[Readable]: number_of_detectors = request.param return [ FakeDetector( @@ -88,6 +137,228 @@ def detectors(request, provider: DirectoryProvider) -> List[Readable]: ] +def simple_run(detectors: List[Readable]) -> MsgGenerator: + yield from bp.count(detectors) + + +def multi_run(detectors: List[Readable]) -> MsgGenerator: + yield from bp.count(detectors) + yield from bp.count(detectors) + + +def multi_nested_plan(detectors: List[Readable]) -> MsgGenerator: + yield from simple_run(detectors) + yield from simple_run(detectors) + + +def multi_run_single_stage(detectors: List[Readable]) -> MsgGenerator: + def stageless_count() -> MsgGenerator: + return (yield from bps.one_shot(detectors)) + + def inner_plan() -> MsgGenerator: + yield from run_wrapper(stageless_count()) + yield from 
run_wrapper(stageless_count()) + + yield from stage_wrapper(inner_plan(), detectors) + + +def multi_run_single_stage_multi_group( + detectors: List[Readable], +) -> MsgGenerator: + def stageless_count() -> MsgGenerator: + return (yield from bps.one_shot(detectors)) + + def inner_plan() -> MsgGenerator: + yield from run_wrapper(stageless_count(), md={DATA_SESSION: 1}) + yield from run_wrapper(stageless_count(), md={DATA_SESSION: 1}) + yield from run_wrapper(stageless_count(), md={DATA_SESSION: 2}) + yield from run_wrapper(stageless_count(), md={DATA_SESSION: 2}) + + yield from stage_wrapper(inner_plan(), detectors) + + +@run_decorator(md={DATA_SESSION: 12345}) +@set_run_key_decorator("outer") +def nested_run_with_metadata(detectors: List[Readable]) -> MsgGenerator: + yield from set_run_key_wrapper(bp.count(detectors), "inner") + yield from set_run_key_wrapper(bp.count(detectors), "inner") + + +@run_decorator() +@set_run_key_decorator("outer") +def nested_run_without_metadata( + detectors: List[Readable], +) -> MsgGenerator: + yield from set_run_key_wrapper(bp.count(detectors), "inner") + yield from set_run_key_wrapper(bp.count(detectors), "inner") + + +def test_simple_run_gets_scan_number( + run_engine: RunEngine, + detectors: List[Readable], + provider: DirectoryProvider, +) -> None: + docs = collect_docs( + run_engine, + simple_run(detectors), + provider, + ) + assert docs[0].name == "start" + assert docs[0].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-0" + assert_all_detectors_used_collection_numbers(docs, detectors, [RUN_0]) + + +@pytest.mark.parametrize("plan", [multi_run, multi_nested_plan]) +def test_multi_run_gets_scan_numbers( + run_engine: RunEngine, + detectors: List[Readable], + plan: Callable[[List[Readable]], MsgGenerator], + provider: DirectoryProvider, +) -> None: + """Test is here to demonstrate that multi run plans will overwrite files.""" + docs = collect_docs( + run_engine, + plan(detectors), + provider, + ) + start_docs = find_start_docs(docs) + 
assert len(start_docs) == 2 + assert start_docs[0].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-0" + assert start_docs[1].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-0" + assert_all_detectors_used_collection_numbers(docs, detectors, [RUN_0, RUN_0]) + + +def test_multi_run_single_stage( + run_engine: RunEngine, + detectors: List[Readable], + provider: DirectoryProvider, +) -> None: + docs = collect_docs( + run_engine, + multi_run_single_stage(detectors), + provider, + ) + start_docs = find_start_docs(docs) + assert len(start_docs) == 2 + assert start_docs[0].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-0" + assert start_docs[1].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-0" + assert_all_detectors_used_collection_numbers( + docs, + detectors, + [ + RUN_0, + RUN_0, + ], + ) + + +def test_multi_run_single_stage_multi_group( + run_engine: RunEngine, + detectors: List[Readable], + provider: DirectoryProvider, +) -> None: + docs = collect_docs( + run_engine, + multi_run_single_stage_multi_group(detectors), + provider, + ) + start_docs = find_start_docs(docs) + assert len(start_docs) == 4 + assert start_docs[0].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-0" + assert start_docs[1].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-0" + assert start_docs[2].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-0" + assert start_docs[3].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-0" + assert_all_detectors_used_collection_numbers( + docs, + detectors, + [ + RUN_0, + RUN_0, + RUN_0, + RUN_0, + ], + ) + + +def test_nested_run_with_metadata( + run_engine: RunEngine, + detectors: List[Readable], + provider: DirectoryProvider, +) -> None: + """Test is here to demonstrate that nested runs will be treated as a single run. + + That means detectors in such runs will overwrite files. 
+ """ + docs = collect_docs( + run_engine, + nested_run_with_metadata(detectors), + provider, + ) + start_docs = find_start_docs(docs) + assert len(start_docs) == 3 + assert start_docs[0].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-0" + assert start_docs[1].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-0" + assert start_docs[2].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-0" + assert_all_detectors_used_collection_numbers(docs, detectors, [RUN_0, RUN_0]) + + +def test_nested_run_without_metadata( + run_engine: RunEngine, + detectors: List[Readable], + provider: DirectoryProvider, +) -> None: + """Test is here to demonstrate that nested runs will be treated as a single run. + + That means detectors in such runs will overwrite files. + """ + docs = collect_docs( + run_engine, + nested_run_without_metadata(detectors), + provider, + ) + start_docs = find_start_docs(docs) + assert len(start_docs) == 3 + assert start_docs[0].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-0" + assert start_docs[1].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-0" + assert start_docs[2].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-0" + assert_all_detectors_used_collection_numbers(docs, detectors, [RUN_0, RUN_0]) + + +def test_visit_directory_provider_fails( + run_engine: RunEngine, + detectors: List[Readable], + provider: DirectoryProvider, + client: MockVisitServiceClient, +) -> None: + client.always_fail() + with pytest.raises(ValueError): + collect_docs( + run_engine, + simple_run(detectors), + provider, + ) + + +def test_visit_directory_provider_fails_after_one_sucess( + run_engine: RunEngine, + detectors: List[Readable], + provider: DirectoryProvider, + client: MockVisitServiceClient, +) -> None: + collect_docs( + run_engine, + simple_run(detectors), + provider, + ) + client.always_fail() + with pytest.raises(ValueError): + collect_docs( + run_engine, + simple_run(detectors), + provider, + ) + + def collect_docs( run_engine: RunEngine, plan: MsgGenerator, @@ -103,13 +374,25 @@ def on_event(name: str, doc: 
Mapping[str, Any]) -> None: return events -def test_attach_metadata_attaches_correct_data_session( - detectors: List[Readable], provider: DirectoryProvider, run_engine: RunEngine -): - docs = collect_docs( - run_engine, - attach_metadata(bp.count(detectors), provider), - provider, - ) - assert docs[0].name == "start" - assert docs[0].doc.get(DATA_SESSION) == DATA_GROUP_NAME +def assert_all_detectors_used_collection_numbers( + docs: List[DataEvent], + detectors: List[Readable], + source_history: List[Path], +) -> None: + descriptors = find_descriptor_docs(docs) + assert len(descriptors) == len(source_history) + + for descriptor, expected_source in zip(descriptors, source_history): + for detector in detectors: + source = descriptor.doc.get("data_keys", {}).get(f"{detector.name}_data")[ + "source" + ] + assert Path(source) == expected_source + + +def find_start_docs(docs: List[DataEvent]) -> List[DataEvent]: + return list(filter(lambda event: event.name == "start", docs)) + + +def find_descriptor_docs(docs: List[DataEvent]) -> List[DataEvent]: + return list(filter(lambda event: event.name == "descriptor", docs))