Skip to content

Commit

Permalink
Reintroduce VisitDirectoryProvider and update in preprocessor (#325)
Browse files Browse the repository at this point in the history
Make calls to numtracker in preprocessor for unique file ids
  • Loading branch information
rosesyrett authored Oct 30, 2023
1 parent 9724c0c commit 3ec7b08
Show file tree
Hide file tree
Showing 6 changed files with 520 additions and 27 deletions.
Empty file.
128 changes: 128 additions & 0 deletions src/blueapi/data_management/visit_directory_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
import logging
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Optional

from aiohttp import ClientSession
from ophyd_async.core import DirectoryInfo, DirectoryProvider
from pydantic import BaseModel


class DataCollectionIdentifier(BaseModel):
collectionNumber: int


class VisitServiceClientBase(ABC):
"""
Object responsible for I/O in determining collection number
"""

@abstractmethod
async def create_new_collection(self) -> DataCollectionIdentifier:
"""Create new collection"""

@abstractmethod
async def get_current_collection(self) -> DataCollectionIdentifier:
"""Get current collection"""


class VisitServiceClient(VisitServiceClientBase):
_url: str

def __init__(self, url: str) -> None:
self._url = url

async def create_new_collection(self) -> DataCollectionIdentifier:
async with ClientSession() as session:
async with session.post(f"{self._url}/numtracker") as response:
if response.status == 200:
json = await response.json()
return DataCollectionIdentifier.parse_obj(json)
else:
raise Exception(response.status)

async def get_current_collection(self) -> DataCollectionIdentifier:
async with ClientSession() as session:
async with session.get(f"{self._url}/numtracker") as response:
if response.status == 200:
json = await response.json()
return DataCollectionIdentifier.parse_obj(json)
else:
raise Exception(response.status)


class LocalVisitServiceClient(VisitServiceClientBase):
_count: int

def __init__(self) -> None:
self._count = 0

async def create_new_collection(self) -> DataCollectionIdentifier:
self._count += 1
return DataCollectionIdentifier(collectionNumber=self._count)

async def get_current_collection(self) -> DataCollectionIdentifier:
return DataCollectionIdentifier(collectionNumber=self._count)


class VisitDirectoryProvider(DirectoryProvider):
"""
Gets information from a remote service to construct the path that detectors
should write to, and determine how their files should be named.
"""

_data_group_name: str
_data_directory: Path

_client: VisitServiceClientBase
_current_collection: Optional[DirectoryInfo]
_session: Optional[ClientSession]

def __init__(
self,
data_group_name: str,
data_directory: Path,
client: VisitServiceClientBase,
):
self._data_group_name = data_group_name
self._data_directory = data_directory
self._client = client

self._current_collection = None
self._session = None

async def update(self) -> None:
"""
Calls the visit service to create a new data collection in the current visit.
"""
# TODO: After visit service is more feature complete:
# TODO: Allow selecting visit as part of the request to BlueAPI
# TODO: Consume visit information from BlueAPI and pass down to this class
# TODO: Query visit service to get information about visit and data collection
# TODO: Use AuthN information as part of verification with visit service

try:
collection_id_info = await self._client.create_new_collection()
self._current_collection = self._generate_directory_info(collection_id_info)
except Exception as ex:
# TODO: The catch all is needed because the RunEngine will not
# currently handle it, see
# https://github.com/bluesky/bluesky/pull/1623
self._current_collection = None
logging.exception(ex)

def _generate_directory_info(
self,
collection_id_info: DataCollectionIdentifier,
) -> DirectoryInfo:
collection_id = collection_id_info.collectionNumber
file_prefix = f"{self._data_group_name}-{collection_id}"
return DirectoryInfo(str(self._data_directory), file_prefix)

def __call__(self) -> DirectoryInfo:
if self._current_collection is not None:
return self._current_collection
else:
raise ValueError(
"No current collection, update() needs to be called at least once"
)
11 changes: 7 additions & 4 deletions src/blueapi/preprocessors/attach_metadata.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,25 @@
import bluesky.plan_stubs as bps
import bluesky.preprocessors as bpp
from bluesky.utils import make_decorator
from ophyd_async.core import DirectoryProvider

from blueapi.core import MsgGenerator
from blueapi.data_management.visit_directory_provider import VisitDirectoryProvider

DATA_SESSION = "data_session"
DATA_GROUPS = "data_groups"


def attach_metadata(
plan: MsgGenerator,
provider: DirectoryProvider,
provider: VisitDirectoryProvider,
) -> MsgGenerator:
"""
Attach data session metadata to the runs within a plan and make it correlate
with an ophyd-async DirectoryProvider.
This calls the directory provider and ensures the start document contains
the correct data session.
This updates the directory provider (which in turn makes a call to to a service
to figure out which scan number we are using for such a scan), and ensures the
start document contains the correct data session.
Args:
plan: The plan to preprocess
Expand All @@ -29,6 +31,7 @@ def attach_metadata(
Yields:
Iterator[Msg]: Plan messages
"""
yield from bps.wait_for([provider.update])
directory_info = provider()
yield from bpp.inject_md_wrapper(
plan, md={DATA_SESSION: directory_info.filename_prefix}
Expand Down
23 changes: 18 additions & 5 deletions src/blueapi/service/handler.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import logging
from typing import Mapping, Optional

from ophyd_async.core import StaticDirectoryProvider

from blueapi.config import ApplicationConfig
from blueapi.core import BlueskyContext
from blueapi.core.event import EventStream
from blueapi.data_management.visit_directory_provider import (
LocalVisitServiceClient,
VisitDirectoryProvider,
VisitServiceClient,
VisitServiceClientBase,
)
from blueapi.messaging import StompMessagingTemplate
from blueapi.messaging.base import MessagingTemplate
from blueapi.preprocessors.attach_metadata import attach_metadata
Expand Down Expand Up @@ -88,9 +92,18 @@ def setup_handler(
plan_wrappers = []

if config:
provider = StaticDirectoryProvider(
filename_prefix=f"{config.env.data_writing.group_name}-blueapi",
directory_path=str(config.env.data_writing.visit_directory),
visit_service_client: VisitServiceClientBase
if config.env.data_writing.visit_service_url is not None:
visit_service_client = VisitServiceClient(
config.env.data_writing.visit_service_url
)
else:
visit_service_client = LocalVisitServiceClient()

provider = VisitDirectoryProvider(
data_group_name=config.env.data_writing.group_name,
data_directory=config.env.data_writing.visit_directory,
client=visit_service_client,
)

# Make all dodal devices created by the context use provider if they can
Expand Down
66 changes: 66 additions & 0 deletions tests/data_management/test_visit_directory_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from pathlib import Path

import pytest
from ophyd_async.core import DirectoryInfo

from blueapi.data_management.visit_directory_provider import (
DataCollectionIdentifier,
LocalVisitServiceClient,
VisitDirectoryProvider,
VisitServiceClientBase,
)


@pytest.fixture
def visit_service_client() -> VisitServiceClientBase:
return LocalVisitServiceClient()


@pytest.fixture
def visit_directory_provider(
visit_service_client: VisitServiceClientBase,
) -> VisitDirectoryProvider:
return VisitDirectoryProvider("example", Path("/tmp"), visit_service_client)


@pytest.mark.asyncio
async def test_client_can_view_collection(
visit_service_client: VisitServiceClientBase,
) -> None:
collection = await visit_service_client.get_current_collection()
assert collection == DataCollectionIdentifier(collectionNumber=0)


@pytest.mark.asyncio
async def test_client_can_create_collection(
visit_service_client: VisitServiceClientBase,
) -> None:
collection = await visit_service_client.create_new_collection()
assert collection == DataCollectionIdentifier(collectionNumber=1)


@pytest.mark.asyncio
async def test_update_sets_collection_number(
visit_directory_provider: VisitDirectoryProvider,
) -> None:
await visit_directory_provider.update()
assert visit_directory_provider() == DirectoryInfo(
directory_path="/tmp",
filename_prefix="example-1",
)


@pytest.mark.asyncio
async def test_update_sets_collection_number_multi(
visit_directory_provider: VisitDirectoryProvider,
) -> None:
await visit_directory_provider.update()
assert visit_directory_provider() == DirectoryInfo(
directory_path="/tmp",
filename_prefix="example-1",
)
await visit_directory_provider.update()
assert visit_directory_provider() == DirectoryInfo(
directory_path="/tmp",
filename_prefix="example-2",
)
Loading

0 comments on commit 3ec7b08

Please sign in to comment.