-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Reintroduce VisitDirectoryProvider and update in preprocessor (#325)
Make calls to numtracker in preprocessor for unique file ids
- Loading branch information
1 parent
9724c0c
commit 3ec7b08
Showing
6 changed files
with
520 additions
and
27 deletions.
There are no files selected for viewing
Empty file.
128 changes: 128 additions & 0 deletions
128
src/blueapi/data_management/visit_directory_provider.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
import logging | ||
from abc import ABC, abstractmethod | ||
from pathlib import Path | ||
from typing import Optional | ||
|
||
from aiohttp import ClientSession | ||
from ophyd_async.core import DirectoryInfo, DirectoryProvider | ||
from pydantic import BaseModel | ||
|
||
|
||
class DataCollectionIdentifier(BaseModel): | ||
collectionNumber: int | ||
|
||
|
||
class VisitServiceClientBase(ABC): | ||
""" | ||
Object responsible for I/O in determining collection number | ||
""" | ||
|
||
@abstractmethod | ||
async def create_new_collection(self) -> DataCollectionIdentifier: | ||
"""Create new collection""" | ||
|
||
@abstractmethod | ||
async def get_current_collection(self) -> DataCollectionIdentifier: | ||
"""Get current collection""" | ||
|
||
|
||
class VisitServiceClient(VisitServiceClientBase): | ||
_url: str | ||
|
||
def __init__(self, url: str) -> None: | ||
self._url = url | ||
|
||
async def create_new_collection(self) -> DataCollectionIdentifier: | ||
async with ClientSession() as session: | ||
async with session.post(f"{self._url}/numtracker") as response: | ||
if response.status == 200: | ||
json = await response.json() | ||
return DataCollectionIdentifier.parse_obj(json) | ||
else: | ||
raise Exception(response.status) | ||
|
||
async def get_current_collection(self) -> DataCollectionIdentifier: | ||
async with ClientSession() as session: | ||
async with session.get(f"{self._url}/numtracker") as response: | ||
if response.status == 200: | ||
json = await response.json() | ||
return DataCollectionIdentifier.parse_obj(json) | ||
else: | ||
raise Exception(response.status) | ||
|
||
|
||
class LocalVisitServiceClient(VisitServiceClientBase): | ||
_count: int | ||
|
||
def __init__(self) -> None: | ||
self._count = 0 | ||
|
||
async def create_new_collection(self) -> DataCollectionIdentifier: | ||
self._count += 1 | ||
return DataCollectionIdentifier(collectionNumber=self._count) | ||
|
||
async def get_current_collection(self) -> DataCollectionIdentifier: | ||
return DataCollectionIdentifier(collectionNumber=self._count) | ||
|
||
|
||
class VisitDirectoryProvider(DirectoryProvider): | ||
""" | ||
Gets information from a remote service to construct the path that detectors | ||
should write to, and determine how their files should be named. | ||
""" | ||
|
||
_data_group_name: str | ||
_data_directory: Path | ||
|
||
_client: VisitServiceClientBase | ||
_current_collection: Optional[DirectoryInfo] | ||
_session: Optional[ClientSession] | ||
|
||
def __init__( | ||
self, | ||
data_group_name: str, | ||
data_directory: Path, | ||
client: VisitServiceClientBase, | ||
): | ||
self._data_group_name = data_group_name | ||
self._data_directory = data_directory | ||
self._client = client | ||
|
||
self._current_collection = None | ||
self._session = None | ||
|
||
async def update(self) -> None: | ||
""" | ||
Calls the visit service to create a new data collection in the current visit. | ||
""" | ||
# TODO: After visit service is more feature complete: | ||
# TODO: Allow selecting visit as part of the request to BlueAPI | ||
# TODO: Consume visit information from BlueAPI and pass down to this class | ||
# TODO: Query visit service to get information about visit and data collection | ||
# TODO: Use AuthN information as part of verification with visit service | ||
|
||
try: | ||
collection_id_info = await self._client.create_new_collection() | ||
self._current_collection = self._generate_directory_info(collection_id_info) | ||
except Exception as ex: | ||
# TODO: The catch all is needed because the RunEngine will not | ||
# currently handle it, see | ||
# https://github.com/bluesky/bluesky/pull/1623 | ||
self._current_collection = None | ||
logging.exception(ex) | ||
|
||
def _generate_directory_info( | ||
self, | ||
collection_id_info: DataCollectionIdentifier, | ||
) -> DirectoryInfo: | ||
collection_id = collection_id_info.collectionNumber | ||
file_prefix = f"{self._data_group_name}-{collection_id}" | ||
return DirectoryInfo(str(self._data_directory), file_prefix) | ||
|
||
def __call__(self) -> DirectoryInfo: | ||
if self._current_collection is not None: | ||
return self._current_collection | ||
else: | ||
raise ValueError( | ||
"No current collection, update() needs to be called at least once" | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
from pathlib import Path | ||
|
||
import pytest | ||
from ophyd_async.core import DirectoryInfo | ||
|
||
from blueapi.data_management.visit_directory_provider import ( | ||
DataCollectionIdentifier, | ||
LocalVisitServiceClient, | ||
VisitDirectoryProvider, | ||
VisitServiceClientBase, | ||
) | ||
|
||
|
||
@pytest.fixture | ||
def visit_service_client() -> VisitServiceClientBase: | ||
return LocalVisitServiceClient() | ||
|
||
|
||
@pytest.fixture | ||
def visit_directory_provider( | ||
visit_service_client: VisitServiceClientBase, | ||
) -> VisitDirectoryProvider: | ||
return VisitDirectoryProvider("example", Path("/tmp"), visit_service_client) | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_client_can_view_collection( | ||
visit_service_client: VisitServiceClientBase, | ||
) -> None: | ||
collection = await visit_service_client.get_current_collection() | ||
assert collection == DataCollectionIdentifier(collectionNumber=0) | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_client_can_create_collection( | ||
visit_service_client: VisitServiceClientBase, | ||
) -> None: | ||
collection = await visit_service_client.create_new_collection() | ||
assert collection == DataCollectionIdentifier(collectionNumber=1) | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_update_sets_collection_number( | ||
visit_directory_provider: VisitDirectoryProvider, | ||
) -> None: | ||
await visit_directory_provider.update() | ||
assert visit_directory_provider() == DirectoryInfo( | ||
directory_path="/tmp", | ||
filename_prefix="example-1", | ||
) | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_update_sets_collection_number_multi( | ||
visit_directory_provider: VisitDirectoryProvider, | ||
) -> None: | ||
await visit_directory_provider.update() | ||
assert visit_directory_provider() == DirectoryInfo( | ||
directory_path="/tmp", | ||
filename_prefix="example-1", | ||
) | ||
await visit_directory_provider.update() | ||
assert visit_directory_provider() == DirectoryInfo( | ||
directory_path="/tmp", | ||
filename_prefix="example-2", | ||
) |
Oops, something went wrong.